Merge remote-tracking branch 'myori/main' into feat/collaboration2

This commit is contained in:
hjlarry
2026-04-15 16:59:11 +08:00
21 changed files with 1134 additions and 725 deletions

View File

@ -0,0 +1,26 @@
"""add qdrant_endpoint to tidb_auth_bindings
Revision ID: 8574b23a38fd
Revises: 6b5f9f8b1a2c
Create Date: 2026-04-14 15:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "8574b23a38fd"  # unique ID of this migration
down_revision = "6b5f9f8b1a2c"  # parent revision this migration applies on top of
branch_labels = None
depends_on = None
def upgrade():
    """Add the nullable ``qdrant_endpoint`` column to ``tidb_auth_bindings``."""
    # Build the column first, then apply it inside a batch context so the
    # migration also works on backends without transactional ALTER TABLE.
    qdrant_endpoint_column = sa.Column("qdrant_endpoint", sa.String(length=512), nullable=True)
    with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op:
        batch_op.add_column(qdrant_endpoint_column)
def downgrade():
    """Revert the migration by dropping the ``qdrant_endpoint`` column."""
    with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op:
        batch_op.drop_column("qdrant_endpoint")

View File

@ -1305,6 +1305,7 @@ class TidbAuthBinding(TypeBase):
)
account: Mapped[str] = mapped_column(String(255), nullable=False)
password: Mapped[str] = mapped_column(String(255), nullable=False)
qdrant_endpoint: Mapped[str | None] = mapped_column(String(512), nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)

View File

@ -1,4 +1,5 @@
import json
import logging
import os
import uuid
from collections.abc import Generator, Iterable, Sequence
@ -7,6 +8,8 @@ from typing import TYPE_CHECKING, Any
import httpx
import qdrant_client
logger = logging.getLogger(__name__)
from flask import current_app
from httpx import DigestAuth
from pydantic import BaseModel
@ -421,13 +424,16 @@ class TidbOnQdrantVector(BaseVector):
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
logger.info("init_vector: tenant_id=%s, dataset_id=%s", dataset.tenant_id, dataset.id)
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if not tidb_auth_binding:
logger.info("No existing TidbAuthBinding for tenant %s, acquiring lock", dataset.tenant_id)
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if tidb_auth_binding:
logger.info("Found binding after lock: cluster_id=%s", tidb_auth_binding.cluster_id)
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
else:
@ -437,11 +443,18 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
.limit(1)
)
if idle_tidb_auth_binding:
logger.info(
"Assigning idle cluster %s to tenant %s",
idle_tidb_auth_binding.cluster_id,
dataset.tenant_id,
)
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
tidb_auth_binding = idle_tidb_auth_binding
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else:
logger.info("No idle clusters available, creating new cluster for tenant %s", dataset.tenant_id)
new_cluster = TidbService.create_tidb_serverless_cluster(
dify_config.TIDB_PROJECT_ID or "",
dify_config.TIDB_API_URL or "",
@ -450,21 +463,39 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
dify_config.TIDB_PRIVATE_KEY or "",
dify_config.TIDB_REGION or "",
)
logger.info(
"New cluster created: cluster_id=%s, qdrant_endpoint=%s",
new_cluster["cluster_id"],
new_cluster.get("qdrant_endpoint"),
)
new_tidb_auth_binding = TidbAuthBinding(
cluster_id=new_cluster["cluster_id"],
cluster_name=new_cluster["cluster_name"],
account=new_cluster["account"],
password=new_cluster["password"],
qdrant_endpoint=new_cluster.get("qdrant_endpoint"),
tenant_id=dataset.tenant_id,
active=True,
status=TidbAuthBindingStatus.ACTIVE,
)
db.session.add(new_tidb_auth_binding)
db.session.commit()
tidb_auth_binding = new_tidb_auth_binding
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
else:
logger.info("Existing binding found: cluster_id=%s", tidb_auth_binding.cluster_id)
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
qdrant_url = (
(tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None) or dify_config.TIDB_ON_QDRANT_URL or ""
)
logger.info(
"Using qdrant endpoint: %s (from_binding=%s, fallback_global=%s)",
qdrant_url,
tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None,
dify_config.TIDB_ON_QDRANT_URL,
)
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
@ -479,7 +510,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
collection_name=collection_name,
group_id=dataset.id,
config=TidbOnQdrantConfig(
endpoint=dify_config.TIDB_ON_QDRANT_URL or "",
endpoint=qdrant_url,
api_key=TIDB_ON_QDRANT_API_KEY,
root_path=str(config.root_path),
timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,

View File

@ -1,3 +1,4 @@
import logging
import time
import uuid
from collections.abc import Sequence
@ -12,6 +13,8 @@ from extensions.ext_redis import redis_client
from models.dataset import TidbAuthBinding
from models.enums import TidbAuthBindingStatus
logger = logging.getLogger(__name__)
# Reuse a pooled HTTP client for all TiDB Cloud requests to minimize connection churn
_tidb_http_client: httpx.Client = get_pooled_http_client(
"tidb:cloud",
@ -20,6 +23,46 @@ _tidb_http_client: httpx.Client = get_pooled_http_client(
class TidbService:
@staticmethod
def extract_qdrant_endpoint(cluster_response: dict) -> str | None:
"""Extract the qdrant endpoint URL from a Get Cluster API response.
Reads ``endpoints.public.host`` (e.g. ``gateway01.xx.tidbcloud.com``),
prepends ``qdrant-`` and wraps it as an ``https://`` URL.
"""
endpoints = cluster_response.get("endpoints") or {}
public = endpoints.get("public") or {}
host = public.get("host")
if host:
return f"https://qdrant-{host}"
return None
@staticmethod
def fetch_qdrant_endpoint(api_url: str, public_key: str, private_key: str, cluster_id: str) -> str | None:
    """Call Get Cluster API and extract the qdrant endpoint.

    Use ``extract_qdrant_endpoint`` instead when you already have
    the cluster response to avoid a redundant API call.
    Returns None on an empty response, a missing host, or any error.
    """
    try:
        logger.info("Fetching qdrant endpoint for cluster %s", cluster_id)
        cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
        if not cluster_response:
            logger.warning("Empty response from Get Cluster API for cluster %s", cluster_id)
        else:
            qdrant_url = TidbService.extract_qdrant_endpoint(cluster_response)
            if qdrant_url:
                logger.info("Resolved qdrant endpoint for cluster %s: %s", cluster_id, qdrant_url)
                return qdrant_url
            logger.warning(
                "No endpoints.public.host found for cluster %s, response keys: %s",
                cluster_id,
                list(cluster_response.keys()),
            )
    except Exception:
        # Best-effort lookup: log the failure and fall through to None.
        logger.exception("Failed to fetch qdrant endpoint for cluster %s", cluster_id)
    return None
@staticmethod
def create_tidb_serverless_cluster(
project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
@ -57,6 +100,7 @@ class TidbService:
"rootPassword": password,
}
logger.info("Creating TiDB serverless cluster: display_name=%s, region=%s", display_name, region)
response = _tidb_http_client.post(
f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key)
)
@ -64,21 +108,39 @@ class TidbService:
if response.status_code == 200:
response_data = response.json()
cluster_id = response_data["clusterId"]
logger.info("Cluster created, cluster_id=%s, waiting for ACTIVE state", cluster_id)
retry_count = 0
max_retries = 30
while retry_count < max_retries:
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
if cluster_response["state"] == "ACTIVE":
user_prefix = cluster_response["userPrefix"]
qdrant_endpoint = TidbService.extract_qdrant_endpoint(cluster_response)
logger.info(
"Cluster %s is ACTIVE, user_prefix=%s, qdrant_endpoint=%s",
cluster_id,
user_prefix,
qdrant_endpoint,
)
return {
"cluster_id": cluster_id,
"cluster_name": display_name,
"account": f"{user_prefix}.root",
"password": password,
"qdrant_endpoint": qdrant_endpoint,
}
time.sleep(30) # wait 30 seconds before retrying
logger.info(
"Cluster %s state=%s, retry %d/%d",
cluster_id,
cluster_response["state"],
retry_count + 1,
max_retries,
)
time.sleep(30)
retry_count += 1
logger.error("Cluster %s did not become ACTIVE after %d retries", cluster_id, max_retries)
else:
logger.error("Failed to create cluster: status=%d, body=%s", response.status_code, response.text)
response.raise_for_status()
@staticmethod
@ -243,19 +305,29 @@ class TidbService:
if response.status_code == 200:
response_data = response.json()
cluster_infos = []
logger.info("Batch created %d clusters", len(response_data.get("clusters", [])))
for item in response_data["clusters"]:
cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
cached_password = redis_client.get(cache_key)
if not cached_password:
logger.warning("No cached password for cluster %s, skipping", item["displayName"])
continue
qdrant_endpoint = TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"])
logger.info(
"Batch cluster %s: qdrant_endpoint=%s",
item["clusterId"],
qdrant_endpoint,
)
cluster_info = {
"cluster_id": item["clusterId"],
"cluster_name": item["displayName"],
"account": "root",
"password": cached_password.decode("utf-8"),
"qdrant_endpoint": qdrant_endpoint,
}
cluster_infos.append(cluster_info)
return cluster_infos
else:
logger.error("Batch create failed: status=%d, body=%s", response.status_code, response.text)
response.raise_for_status()
return []

View File

@ -114,14 +114,12 @@ class TestTidbOnQdrantVectorDeleteByIds:
assert exc_info.value.status_code == 500
def test_delete_by_ids_with_large_batch(self, vector_instance):
"""Test deletion with a large batch of IDs."""
# Create 1000 IDs
def test_delete_by_ids_with_exactly_1000(self, vector_instance):
"""Test deletion with exactly 1000 IDs triggers a single batch."""
ids = [f"doc_{i}" for i in range(1000)]
vector_instance.delete_by_ids(ids)
# Verify single delete call with all IDs
vector_instance._client.delete.assert_called_once()
call_args = vector_instance._client.delete.call_args
@ -129,11 +127,28 @@ class TestTidbOnQdrantVectorDeleteByIds:
filter_obj = filter_selector.filter
field_condition = filter_obj.must[0]
# Verify all 1000 IDs are in the batch
assert len(field_condition.match.any) == 1000
assert "doc_0" in field_condition.match.any
assert "doc_999" in field_condition.match.any
def test_delete_by_ids_splits_into_batches(self, vector_instance):
    """Deleting more than 1000 IDs must fan out into multiple batched calls."""
    doc_ids = [f"doc_{i}" for i in range(2500)]
    vector_instance.delete_by_ids(doc_ids)
    # 2500 IDs at a batch size of 1000 -> three delete calls.
    assert vector_instance._client.delete.call_count == 3
    batch_sizes = [
        len(call[1]["points_selector"].filter.must[0].match.any)
        for call in vector_instance._client.delete.call_args_list
    ]
    assert batch_sizes == [1000, 1000, 500]
def test_delete_by_ids_filter_structure(self, vector_instance):
"""Test that the filter structure is correctly constructed."""
ids = ["doc1", "doc2"]
@ -157,3 +172,57 @@ class TestTidbOnQdrantVectorDeleteByIds:
# Verify MatchAny structure
assert isinstance(field_condition.match, rest.MatchAny)
assert field_condition.match.any == ids
class TestInitVectorEndpointSelection:
    """Checks the qdrant endpoint fallback chain used by init_vector.

    Importing the full module would pull in a Flask app context, so the
    selection expression (binding endpoint -> global URL -> "") is exercised
    directly against TidbOnQdrantConfig instead.
    """

    @staticmethod
    def _select_endpoint(binding_endpoint, global_url):
        # Mirrors the exact fallback expression init_vector evaluates.
        return binding_endpoint or global_url or ""

    def test_uses_binding_endpoint_when_present(self):
        selected = self._select_endpoint("https://qdrant-custom.tidb.com", "https://qdrant-global.tidb.com")
        assert selected == "https://qdrant-custom.tidb.com"
        config = TidbOnQdrantConfig(endpoint=selected)
        assert config.endpoint == "https://qdrant-custom.tidb.com"

    def test_falls_back_to_global_when_binding_endpoint_is_none(self):
        selected = self._select_endpoint(None, "https://qdrant-global.tidb.com")
        assert selected == "https://qdrant-global.tidb.com"
        config = TidbOnQdrantConfig(endpoint=selected)
        assert config.endpoint == "https://qdrant-global.tidb.com"

    def test_falls_back_to_empty_when_both_none(self):
        selected = self._select_endpoint(None, None)
        assert selected == ""
        config = TidbOnQdrantConfig(endpoint=selected)
        assert config.endpoint == ""

    def test_binding_endpoint_takes_precedence_over_global(self):
        selected = self._select_endpoint("https://qdrant-ap-southeast.tidb.com", "https://qdrant-us-east.tidb.com")
        assert selected == "https://qdrant-ap-southeast.tidb.com"

    def test_empty_string_binding_endpoint_falls_back_to_global(self):
        selected = self._select_endpoint("", "https://qdrant-global.tidb.com")
        assert selected == "https://qdrant-global.tidb.com"

View File

@ -0,0 +1,218 @@
from unittest.mock import MagicMock, patch
import pytest
from dify_vdb_tidb_on_qdrant.tidb_service import TidbService
class TestExtractQdrantEndpoint:
    """Unit tests for TidbService.extract_qdrant_endpoint."""

    def test_returns_endpoint_when_host_present(self):
        # A public host is rewritten to an https qdrant-prefixed URL.
        cluster_response = {"endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}}}
        endpoint = TidbService.extract_qdrant_endpoint(cluster_response)
        assert endpoint == "https://qdrant-gateway01.us-east-1.tidbcloud.com"

    def test_returns_none_when_host_missing(self):
        assert TidbService.extract_qdrant_endpoint({"endpoints": {"public": {}}}) is None

    def test_returns_none_when_public_missing(self):
        assert TidbService.extract_qdrant_endpoint({"endpoints": {}}) is None

    def test_returns_none_when_endpoints_missing(self):
        assert TidbService.extract_qdrant_endpoint({}) is None
class TestFetchQdrantEndpoint:
    """Unit tests for TidbService.fetch_qdrant_endpoint."""

    @patch.object(TidbService, "get_tidb_serverless_cluster")
    def test_returns_endpoint_when_host_present(self, mock_cluster_api):
        mock_cluster_api.return_value = {
            "endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}}
        }
        endpoint = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
        assert endpoint == "https://qdrant-gateway01.us-east-1.tidbcloud.com"

    @patch.object(TidbService, "get_tidb_serverless_cluster")
    def test_returns_none_when_cluster_response_is_none(self, mock_cluster_api):
        mock_cluster_api.return_value = None
        assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None

    @patch.object(TidbService, "get_tidb_serverless_cluster")
    def test_returns_none_when_host_missing(self, mock_cluster_api):
        mock_cluster_api.return_value = {"endpoints": {"public": {}}}
        assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None

    @patch.object(TidbService, "get_tidb_serverless_cluster")
    def test_returns_none_when_endpoints_missing(self, mock_cluster_api):
        mock_cluster_api.return_value = {}
        assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None

    @patch.object(TidbService, "get_tidb_serverless_cluster")
    def test_returns_none_on_exception(self, mock_cluster_api):
        # Any error from the cluster API is swallowed and reported as None.
        mock_cluster_api.side_effect = RuntimeError("network error")
        assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
class TestCreateTidbServerlessClusterQdrantEndpoint:
    """Verify that create_tidb_serverless_cluster includes qdrant_endpoint in its result."""

    # NOTE: @patch decorators apply bottom-up, so the mock arguments arrive
    # in reverse order: (config, http client, get_cluster).
    @patch.object(TidbService, "get_tidb_serverless_cluster")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
    def test_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_get_cluster):
        mock_config.TIDB_SPEND_LIMIT = 10
        # Creation POST succeeds and the cluster is immediately ACTIVE.
        mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
        mock_get_cluster.return_value = {
            "state": "ACTIVE",
            "userPrefix": "pfx",
            "endpoints": {"public": {"host": "gw.tidbcloud.com", "port": 4000}},
        }
        result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
        assert result is not None
        # The public host is prefixed with "qdrant-" and wrapped as https.
        assert result["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com"

    @patch.object(TidbService, "get_tidb_serverless_cluster")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
    def test_result_qdrant_endpoint_none_when_no_endpoints(self, mock_config, mock_http, mock_get_cluster):
        mock_config.TIDB_SPEND_LIMIT = 10
        mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
        # Without an "endpoints" key, the qdrant endpoint cannot be derived.
        mock_get_cluster.return_value = {"state": "ACTIVE", "userPrefix": "pfx"}
        result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
        assert result is not None
        assert result["qdrant_endpoint"] is None
class TestBatchCreateTidbServerlessClusterQdrantEndpoint:
    """Verify that batch_create includes qdrant_endpoint per cluster."""

    # Decorators apply bottom-up: args arrive as (config, http, redis, fetch_ep).
    @patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
    def test_batch_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_redis, mock_fetch_ep):
        mock_config.TIDB_SPEND_LIMIT = 10
        cluster_name = "abc123"
        # Batch-create POST reports a single created cluster.
        mock_http.post.return_value = MagicMock(
            status_code=200,
            json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": cluster_name}]},
        )
        mock_redis.setex = MagicMock()
        # A cached root password must exist for the cluster to be included.
        mock_redis.get.return_value = b"password123"
        result = TidbService.batch_create_tidb_serverless_cluster(
            batch_size=1,
            project_id="proj",
            api_url="url",
            iam_url="iam",
            public_key="pub",
            private_key="priv",
            region="us-east-1",
        )
        assert len(result) == 1
        assert result[0]["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com"
class TestCreateTidbServerlessClusterRetry:
    """Cover retry/logging paths in create_tidb_serverless_cluster."""

    # Decorators apply bottom-up: args arrive as (config, http, get_cluster).
    @patch.object(TidbService, "get_tidb_serverless_cluster")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
    def test_polls_until_active(self, mock_config, mock_http, mock_get_cluster):
        mock_config.TIDB_SPEND_LIMIT = 10
        mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
        # First poll sees CREATING, the second sees ACTIVE with a public host.
        mock_get_cluster.side_effect = [
            {"state": "CREATING", "userPrefix": ""},
            {"state": "ACTIVE", "userPrefix": "pfx", "endpoints": {"public": {"host": "gw.tidb.com"}}},
        ]
        # Patch sleep so the retry backoff does not slow the test down.
        with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"):
            result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
        assert result is not None
        assert result["qdrant_endpoint"] == "https://qdrant-gw.tidb.com"
        assert mock_get_cluster.call_count == 2

    @patch.object(TidbService, "get_tidb_serverless_cluster")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
    def test_returns_none_after_max_retries(self, mock_config, mock_http, mock_get_cluster):
        mock_config.TIDB_SPEND_LIMIT = 10
        mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
        # Cluster never leaves CREATING, so polling exhausts its retry budget.
        mock_get_cluster.return_value = {"state": "CREATING", "userPrefix": ""}
        with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"):
            result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
        assert result is None

    @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
    def test_raises_on_post_failure(self, mock_config, mock_http):
        mock_config.TIDB_SPEND_LIMIT = 10
        # A non-200 creation response must propagate via raise_for_status.
        mock_response = MagicMock(status_code=400, text="Bad Request")
        mock_response.raise_for_status.side_effect = Exception("HTTP 400")
        mock_http.post.return_value = mock_response
        with pytest.raises(Exception, match="HTTP 400"):
            TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
class TestBatchCreateEdgeCases:
    """Cover logging/edge-case branches in batch_create."""

    # Decorators apply bottom-up: args arrive as (config, http, redis, fetch_ep).
    @patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None)
    @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
    def test_skips_cluster_when_no_cached_password(self, mock_config, mock_http, mock_redis, mock_fetch_ep):
        mock_config.TIDB_SPEND_LIMIT = 10
        mock_http.post.return_value = MagicMock(
            status_code=200,
            json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": "name1"}]},
        )
        mock_redis.setex = MagicMock()
        # A missing cached password causes the cluster to be skipped entirely.
        mock_redis.get.return_value = None
        result = TidbService.batch_create_tidb_serverless_cluster(
            batch_size=1,
            project_id="proj",
            api_url="url",
            iam_url="iam",
            public_key="pub",
            private_key="priv",
            region="us-east-1",
        )
        assert len(result) == 0
        # Skipped clusters must not trigger a qdrant endpoint lookup.
        mock_fetch_ep.assert_not_called()

    @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
    def test_raises_on_post_failure(self, mock_config, mock_http, mock_redis):
        mock_config.TIDB_SPEND_LIMIT = 10
        # A non-200 batch-create response must propagate via raise_for_status.
        mock_response = MagicMock(status_code=500, text="Server Error")
        mock_response.raise_for_status.side_effect = Exception("HTTP 500")
        mock_http.post.return_value = mock_response
        mock_redis.setex = MagicMock()
        with pytest.raises(Exception, match="HTTP 500"):
            TidbService.batch_create_tidb_serverless_cluster(
                batch_size=1,
                project_id="proj",
                api_url="url",
                iam_url="iam",
                public_key="pub",
                private_key="priv",
                region="us-east-1",
            )

View File

@ -57,6 +57,7 @@ def create_clusters(batch_size):
cluster_name=new_cluster["cluster_name"],
account=new_cluster["account"],
password=new_cluster["password"],
qdrant_endpoint=new_cluster.get("qdrant_endpoint"),
active=False,
status=TidbAuthBindingStatus.CREATING,
)

View File

@ -455,7 +455,7 @@ class AppDslService:
app.updated_by = account.id
self._session.add(app)
self._session.commit()
self._session.flush()
app_was_created.send(app, account=account)
# save dependencies

View File

@ -0,0 +1,507 @@
from __future__ import annotations
import json
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import AppTriggerStatus, AppTriggerType
from models.model import App
from models.trigger import AppTrigger, WorkflowWebhookTrigger
from models.workflow import Workflow
from services.errors.app import QuotaExceededError
from services.trigger.webhook_service import WebhookService
class WebhookServiceRelationshipFactory:
    """Builds the Account/Tenant/App/Workflow/trigger rows webhook tests need."""

    @staticmethod
    def create_account_and_tenant(db_session_with_containers: Session) -> tuple[Account, Tenant]:
        """Persist an owner account joined to a fresh tenant and return both."""
        account = Account(
            name=f"Account {uuid4()}",
            email=f"webhook-{uuid4()}@example.com",
            password="hashed-password",
            password_salt="salt",
            interface_language="en-US",
            timezone="UTC",
        )
        db_session_with_containers.add(account)
        db_session_with_containers.commit()
        tenant = Tenant(name=f"Tenant {uuid4()}", plan="basic", status="normal")
        db_session_with_containers.add(tenant)
        db_session_with_containers.commit()
        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        db_session_with_containers.add(join)
        db_session_with_containers.commit()
        # current_tenant is set in memory only; it is not persisted above.
        account.current_tenant = tenant
        return account, tenant

    @staticmethod
    def create_app(db_session_with_containers: Session, tenant: Tenant, account: Account) -> App:
        """Persist a minimal workflow-mode App owned by *account* in *tenant*."""
        app = App(
            tenant_id=tenant.id,
            name=f"Webhook App {uuid4()}",
            description="",
            mode="workflow",
            icon_type="emoji",
            icon="bot",
            icon_background="#FFFFFF",
            enable_site=False,
            enable_api=True,
            api_rpm=100,
            api_rph=100,
            is_demo=False,
            is_public=False,
            is_universal=False,
            created_by=account.id,
            updated_by=account.id,
        )
        db_session_with_containers.add(app)
        db_session_with_containers.commit()
        return app

    @staticmethod
    def create_workflow(
        db_session_with_containers: Session,
        *,
        app: App,
        account: Account,
        node_ids: list[str],
        version: str,
    ) -> Workflow:
        """Persist a Workflow whose graph holds one webhook-trigger node per entry in *node_ids*."""
        graph = {
            "nodes": [
                {
                    "id": node_id,
                    "data": {
                        "type": TRIGGER_WEBHOOK_NODE_TYPE,
                        "title": f"Webhook {node_id}",
                        "method": "post",
                        "content_type": "application/json",
                        "headers": [],
                        "params": [],
                        "body": [],
                        "status_code": 200,
                        "response_body": '{"status": "ok"}',
                        "timeout": 30,
                    },
                }
                for node_id in node_ids
            ],
            "edges": [],
        }
        workflow = Workflow(
            tenant_id=app.tenant_id,
            app_id=app.id,
            type="workflow",
            graph=json.dumps(graph),
            features=json.dumps({}),
            created_by=account.id,
            updated_by=account.id,
            environment_variables=[],
            conversation_variables=[],
            version=version,
        )
        db_session_with_containers.add(workflow)
        db_session_with_containers.commit()
        return workflow

    @staticmethod
    def create_webhook_trigger(
        db_session_with_containers: Session,
        *,
        app: App,
        account: Account,
        node_id: str,
        webhook_id: str | None = None,
    ) -> WorkflowWebhookTrigger:
        """Persist a WorkflowWebhookTrigger for *node_id*.

        When *webhook_id* is omitted, a random 24-character hex ID is used.
        """
        webhook_trigger = WorkflowWebhookTrigger(
            app_id=app.id,
            node_id=node_id,
            tenant_id=app.tenant_id,
            webhook_id=webhook_id or uuid4().hex[:24],
            created_by=account.id,
        )
        db_session_with_containers.add(webhook_trigger)
        db_session_with_containers.commit()
        return webhook_trigger

    @staticmethod
    def create_app_trigger(
        db_session_with_containers: Session,
        *,
        app: App,
        node_id: str,
        status: AppTriggerStatus,
    ) -> AppTrigger:
        """Persist an AppTrigger row for *node_id* in the given *status*."""
        app_trigger = AppTrigger(
            tenant_id=app.tenant_id,
            app_id=app.id,
            node_id=node_id,
            trigger_type=AppTriggerType.TRIGGER_WEBHOOK,
            provider_name="webhook",
            title=f"Webhook {node_id}",
            status=status,
        )
        db_session_with_containers.add(app_trigger)
        db_session_with_containers.commit()
        return app_trigger
class TestWebhookServiceLookupWithContainers:
    """Integration tests for WebhookService.get_webhook_trigger_and_workflow."""

    def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_missing(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        # The flask fixture is only needed for its app-context side effect.
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        factory.create_workflow(
            db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
        )
        webhook_trigger = factory.create_webhook_trigger(
            db_session_with_containers, app=app, account=account, node_id="node-1"
        )
        # No AppTrigger row was created, so lookup must fail.
        with pytest.raises(ValueError, match="App trigger not found"):
            WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)

    def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_rate_limited(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        factory.create_workflow(
            db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
        )
        webhook_trigger = factory.create_webhook_trigger(
            db_session_with_containers, app=app, account=account, node_id="node-1"
        )
        # A RATE_LIMITED trigger must be rejected with a matching message.
        factory.create_app_trigger(
            db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.RATE_LIMITED
        )
        with pytest.raises(ValueError, match="rate limited"):
            WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)

    def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_disabled(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        factory.create_workflow(
            db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
        )
        webhook_trigger = factory.create_webhook_trigger(
            db_session_with_containers, app=app, account=account, node_id="node-1"
        )
        # A DISABLED trigger must be rejected as well.
        factory.create_app_trigger(
            db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.DISABLED
        )
        with pytest.raises(ValueError, match="disabled"):
            WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)

    def test_get_webhook_trigger_and_workflow_raises_when_workflow_missing(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        # Trigger rows exist but no workflow was ever persisted for the app.
        webhook_trigger = factory.create_webhook_trigger(
            db_session_with_containers, app=app, account=account, node_id="node-1"
        )
        factory.create_app_trigger(
            db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.ENABLED
        )
        with pytest.raises(ValueError, match="Workflow not found"):
            WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)

    def test_get_webhook_trigger_and_workflow_returns_debug_draft_workflow(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        # Published workflow holds a different node than the draft one.
        factory.create_workflow(
            db_session_with_containers,
            app=app,
            account=account,
            node_ids=["published-node"],
            version="2026-04-14.001",
        )
        draft_workflow = factory.create_workflow(
            db_session_with_containers,
            app=app,
            account=account,
            node_ids=["debug-node"],
            version=Workflow.VERSION_DRAFT,
        )
        webhook_trigger = factory.create_webhook_trigger(
            db_session_with_containers, app=app, account=account, node_id="debug-node"
        )
        # is_debug=True must resolve against the draft workflow's graph.
        got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
            webhook_trigger.webhook_id,
            is_debug=True,
        )
        assert got_trigger.id == webhook_trigger.id
        assert got_workflow.id == draft_workflow.id
        assert got_node_config["id"] == "debug-node"
class TestWebhookServiceTriggerExecutionWithContainers:
    """Container-backed tests for WebhookService.trigger_workflow_execution."""

    @staticmethod
    def _arrange(session: Session):
        """Create the tenant/app/workflow/webhook-trigger graph shared by every test in this class."""
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(session)
        app = factory.create_app(session, tenant, account)
        workflow = factory.create_workflow(
            session, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
        )
        webhook_trigger = factory.create_webhook_trigger(session, app=app, account=account, node_id="node-1")
        return tenant, workflow, webhook_trigger

    def test_trigger_workflow_execution_triggers_async_workflow_successfully(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        """Happy path: quota is consumed for the tenant and the async workflow is dispatched
        with the resolved end user, workflow id and webhook node as the root node."""
        del flask_app_with_containers  # fixture requested only for its setup side effects
        _tenant, workflow, webhook_trigger = self._arrange(db_session_with_containers)
        end_user = SimpleNamespace(id=str(uuid4()))
        webhook_data = {"body": {"value": 1}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}
        with (
            patch(
                "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
                return_value=end_user,
            ),
            patch("services.trigger.webhook_service.QuotaType.TRIGGER.consume") as mock_consume,
            patch("services.trigger.webhook_service.AsyncWorkflowService.trigger_workflow_async") as mock_trigger,
        ):
            WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
        mock_consume.assert_called_once_with(webhook_trigger.tenant_id)
        mock_trigger.assert_called_once()
        positional_args = mock_trigger.call_args.args
        # arg[1] is the resolved end user; arg[2] carries the workflow/root-node routing info.
        assert positional_args[1] is end_user
        assert positional_args[2].workflow_id == workflow.id
        assert positional_args[2].root_node_id == webhook_trigger.node_id

    def test_trigger_workflow_execution_marks_tenant_rate_limited_when_quota_exceeded(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        """Quota exhaustion must re-raise QuotaExceededError and flag the tenant's triggers
        as rate limited."""
        del flask_app_with_containers
        tenant, workflow, webhook_trigger = self._arrange(db_session_with_containers)
        with (
            patch(
                "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
                return_value=SimpleNamespace(id=str(uuid4())),
            ),
            patch(
                "services.trigger.webhook_service.QuotaType.TRIGGER.consume",
                side_effect=QuotaExceededError(feature="trigger", tenant_id=tenant.id, required=1),
            ),
            patch(
                "services.trigger.webhook_service.AppTriggerService.mark_tenant_triggers_rate_limited"
            ) as mock_mark_rate_limited,
        ):
            with pytest.raises(QuotaExceededError):
                WebhookService.trigger_workflow_execution(
                    webhook_trigger,
                    {"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"},
                    workflow,
                )
        mock_mark_rate_limited.assert_called_once_with(tenant.id)

    def test_trigger_workflow_execution_logs_and_reraises_unexpected_errors(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        """Unexpected failures are logged via logger.exception and propagated unchanged."""
        del flask_app_with_containers
        _tenant, workflow, webhook_trigger = self._arrange(db_session_with_containers)
        with (
            patch(
                "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
                side_effect=RuntimeError("boom"),
            ),
            patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception,
        ):
            with pytest.raises(RuntimeError, match="boom"):
                WebhookService.trigger_workflow_execution(
                    webhook_trigger,
                    {"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"},
                    workflow,
                )
        mock_logger_exception.assert_called_once()
class TestWebhookServiceRelationshipSyncWithContainers:
    """Container-backed tests for WebhookService.sync_webhook_relationships:
    node-count limits, lock handling, DB record reconciliation and Redis caching."""

    def test_sync_webhook_relationships_raises_when_workflow_exceeds_node_limit(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        """A workflow with one webhook node over the cap must be rejected with ValueError."""
        del flask_app_with_containers  # fixture requested only for its setup side effects
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        # Exactly one node beyond the allowed maximum.
        node_ids = [f"node-{index}" for index in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)]
        workflow = factory.create_workflow(
            db_session_with_containers, app=app, account=account, node_ids=node_ids, version=Workflow.VERSION_DRAFT
        )
        with pytest.raises(ValueError, match="maximum webhook node limit"):
            WebhookService.sync_webhook_relationships(app, workflow)

    def test_sync_webhook_relationships_raises_when_lock_not_acquired(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        """If the Redis lock cannot be acquired, sync must fail fast with RuntimeError."""
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        workflow = factory.create_workflow(
            db_session_with_containers, app=app, account=account, node_ids=["node-1"], version=Workflow.VERSION_DRAFT
        )
        # Simulate contention: lock.acquire reports failure.
        lock = MagicMock()
        lock.acquire.return_value = False
        with patch("services.trigger.webhook_service.redis_client.lock", return_value=lock):
            with pytest.raises(RuntimeError, match="Failed to acquire lock"):
                WebhookService.sync_webhook_relationships(app, workflow)

    def test_sync_webhook_relationships_creates_missing_records_and_deletes_stale_records(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        """Sync reconciles DB state: new workflow nodes gain trigger rows, rows for
        nodes no longer present in the workflow are deleted."""
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        # Pre-existing trigger for a node the new workflow no longer contains.
        stale_trigger = factory.create_webhook_trigger(
            db_session_with_containers,
            app=app,
            account=account,
            node_id="node-stale",
            webhook_id="stale-webhook-id-000001",
        )
        stale_trigger_id = stale_trigger.id
        workflow = factory.create_workflow(
            db_session_with_containers,
            app=app,
            account=account,
            node_ids=["node-new"],
            version=Workflow.VERSION_DRAFT,
        )
        with patch(
            "services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="new-webhook-id-000001"
        ):
            WebhookService.sync_webhook_relationships(app, workflow)
        # Drop cached ORM state so the assertions read what sync actually committed.
        db_session_with_containers.expire_all()
        records = db_session_with_containers.scalars(
            select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.app_id == app.id)
        ).all()
        assert [record.node_id for record in records] == ["node-new"]
        assert records[0].webhook_id == "new-webhook-id-000001"
        assert db_session_with_containers.get(WorkflowWebhookTrigger, stale_trigger_id) is None

    def test_sync_webhook_relationships_sets_redis_cache_for_new_record(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        """A newly created trigger row must also be written to the Redis node cache."""
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        workflow = factory.create_workflow(
            db_session_with_containers,
            app=app,
            account=account,
            node_ids=["node-cache"],
            version=Workflow.VERSION_DRAFT,
        )
        # Mirrors the service's internal cache-key layout: "<prefix>:<app_id>:<node_id>".
        cache_key = f"{WebhookService.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:node-cache"
        with patch(
            "services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="cache-webhook-id-00001"
        ):
            WebhookService.sync_webhook_relationships(app, workflow)
        # _read_cache is attached to the factory at module scope (see bottom of this module).
        cached_payload = WebhookServiceRelationshipFactory._read_cache(cache_key)
        assert cached_payload is not None
        assert cached_payload["node_id"] == "node-cache"
        assert cached_payload["webhook_id"] == "cache-webhook-id-00001"

    def test_sync_webhook_relationships_logs_when_lock_release_fails(
        self, db_session_with_containers: Session, flask_app_with_containers
    ):
        """A failing lock.release must be logged (logger.exception) but not raised."""
        del flask_app_with_containers
        factory = WebhookServiceRelationshipFactory
        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
        app = factory.create_app(db_session_with_containers, tenant, account)
        workflow = factory.create_workflow(
            db_session_with_containers, app=app, account=account, node_ids=[], version=Workflow.VERSION_DRAFT
        )
        lock = MagicMock()
        lock.acquire.return_value = True
        lock.release.side_effect = RuntimeError("release failed")
        with (
            patch("services.trigger.webhook_service.redis_client.lock", return_value=lock),
            patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception,
        ):
            WebhookService.sync_webhook_relationships(app, workflow)
        mock_logger_exception.assert_called_once()
def _read_cache(cache_key: str) -> dict[str, str] | None:
    """Fetch and JSON-decode a cached webhook payload; return None when the key is absent."""
    # Imported lazily so the module can be collected without a live Redis extension.
    from extensions.ext_redis import redis_client

    raw = redis_client.get(cache_key)
    if not raw:
        return None
    text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
    return json.loads(text)


# Expose the helper through the factory so tests can call it as a static method.
WebhookServiceRelationshipFactory._read_cache = staticmethod(_read_cache)

View File

@ -1,5 +1,5 @@
from types import SimpleNamespace
from typing import Any, cast
from typing import Any
from unittest.mock import MagicMock
import pytest
@ -13,11 +13,6 @@ from core.workflow.nodes.trigger_webhook.entities import (
WebhookData,
WebhookParameter,
)
from models.enums import AppTriggerStatus
from models.model import App
from models.trigger import WorkflowWebhookTrigger
from models.workflow import Workflow
from services.errors.app import QuotaExceededError
from services.trigger import webhook_service as service_module
from services.trigger.webhook_service import WebhookService
@ -39,156 +34,13 @@ class _FakeQuery:
return self._result
class _SessionContext:
def __init__(self, session: Any) -> None:
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _SessionmakerContext:
def __init__(self, session: Any) -> None:
self._session = session
def begin(self) -> "_SessionmakerContext":
return self
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
@pytest.fixture
def flask_app() -> Flask:
return Flask(__name__)
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session))
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs))
def _workflow(**kwargs: Any) -> Workflow:
return cast(Workflow, SimpleNamespace(**kwargs))
def _app(**kwargs: Any) -> App:
return cast(App, SimpleNamespace(**kwargs))
class TestWebhookServiceLookup:
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_session = MagicMock()
fake_session.scalar.return_value = None
_patch_session(monkeypatch, fake_session)
with pytest.raises(ValueError, match="Webhook not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, None]
_patch_session(monkeypatch, fake_session)
with pytest.raises(ValueError, match="App trigger not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session)
with pytest.raises(ValueError, match="rate limited"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session)
with pytest.raises(ValueError, match="disabled"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
_patch_session(monkeypatch, fake_session)
with pytest.raises(ValueError, match="Workflow not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
_patch_session(monkeypatch, fake_session)
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")
assert got_trigger is webhook_trigger
assert got_workflow is workflow
assert got_node_config == {"data": {"key": "value"}}
def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, workflow]
_patch_session(monkeypatch, fake_session)
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
"webhook-1",
is_debug=True,
)
assert got_trigger is webhook_trigger
assert got_workflow is workflow
assert got_node_config == {"data": {"mode": "debug"}}
def _workflow_trigger(**kwargs: Any) -> Any:
return SimpleNamespace(**kwargs)
class TestWebhookServiceExtractionFallbacks:
@ -420,237 +272,6 @@ class TestWebhookServiceValidationAndConversion:
assert result["webhook_body"] == {"b": 2}
class TestWebhookServiceExecutionAndSync:
def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
webhook_data = {"body": {"x": 1}}
session = MagicMock()
_patch_session(monkeypatch, session)
end_user = SimpleNamespace(id="end-user-1")
monkeypatch.setattr(
service_module.EndUserService,
"get_or_create_end_user_by_type",
MagicMock(return_value=end_user),
)
quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock()))
monkeypatch.setattr(service_module, "QuotaType", quota_type)
trigger_async_mock = MagicMock()
monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock)
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
trigger_async_mock.assert_called_once()
def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
session = MagicMock()
_patch_session(monkeypatch, session)
monkeypatch.setattr(
service_module.EndUserService,
"get_or_create_end_user_by_type",
MagicMock(return_value=SimpleNamespace(id="end-user-1")),
)
quota_type = SimpleNamespace(
TRIGGER=SimpleNamespace(
consume=MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1))
)
)
monkeypatch.setattr(service_module, "QuotaType", quota_type)
mark_rate_limited_mock = MagicMock()
monkeypatch.setattr(
service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock
)
with pytest.raises(QuotaExceededError):
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
mark_rate_limited_mock.assert_called_once_with("tenant-1")
def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
session = MagicMock()
_patch_session(monkeypatch, session)
monkeypatch.setattr(
service_module.EndUserService,
"get_or_create_end_user_by_type",
MagicMock(side_effect=RuntimeError("boom")),
)
logger_exception_mock = MagicMock()
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
with pytest.raises(RuntimeError, match="boom"):
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
logger_exception_mock.assert_called_once()
def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit(self) -> None:
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(
walk_nodes=lambda _node_type: [
(f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)
]
)
with pytest.raises(ValueError, match="maximum webhook node limit"):
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])
lock = MagicMock()
lock.acquire.return_value = False
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
with pytest.raises(RuntimeError, match="Failed to acquire lock"):
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])
class _WorkflowWebhookTrigger:
app_id = "app_id"
tenant_id = "tenant_id"
webhook_id = "webhook_id"
node_id = "node_id"
def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
self.id = None
self.app_id = app_id
self.tenant_id = tenant_id
self.node_id = node_id
self.webhook_id = webhook_id
self.created_by = created_by
class _Select:
def where(self, *args: Any, **kwargs: Any) -> "_Select":
return self
class _Session:
def __init__(self) -> None:
self.added: list[Any] = []
self.deleted: list[Any] = []
self.commit_count = 0
self.existing_records = [SimpleNamespace(node_id="node-stale")]
def scalars(self, _stmt: Any) -> Any:
return SimpleNamespace(all=lambda: self.existing_records)
def add(self, obj: Any) -> None:
self.added.append(obj)
def flush(self) -> None:
for idx, obj in enumerate(self.added, start=1):
if obj.id is None:
obj.id = f"rec-{idx}"
def commit(self) -> None:
self.commit_count += 1
def delete(self, obj: Any) -> None:
self.deleted.append(obj)
lock = MagicMock()
lock.acquire.return_value = True
lock.release.return_value = None
fake_session = _Session()
monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
redis_set_mock = MagicMock()
redis_delete_mock = MagicMock()
monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock)
monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock)
monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
_patch_session(monkeypatch, fake_session)
WebhookService.sync_webhook_relationships(app, workflow)
assert len(fake_session.added) == 1
assert len(fake_session.deleted) == 1
redis_set_mock.assert_called_once()
redis_delete_mock.assert_called_once()
lock.release.assert_called_once()
def test_sync_webhook_relationships_should_log_when_lock_release_fails(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [])
class _Select:
def where(self, *args: Any, **kwargs: Any) -> "_Select":
return self
class _Session:
def scalars(self, _stmt: Any) -> Any:
return SimpleNamespace(all=lambda: [])
def commit(self) -> None:
return None
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = RuntimeError("release failed")
logger_exception_mock = MagicMock()
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
_patch_session(monkeypatch, _Session())
WebhookService.sync_webhook_relationships(app, workflow)
assert logger_exception_mock.call_count == 1
class TestWebhookServiceUtilities:
def test_generate_webhook_response_should_fallback_when_response_body_is_not_json(self) -> None:
node_config = {"data": {"status_code": 200, "response_body": "{bad-json"}}

View File

@ -7,8 +7,6 @@ import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datase
import DatasetCardFooter from '../components/dataset-card-footer'
import Description from '../components/description'
import DatasetCard from '../index'
import OperationItem from '../operation-item'
import Operations from '../operations'
// Mock external hooks only
vi.mock('@/hooks/use-format-time-from-now', () => ({
@ -62,8 +60,8 @@ vi.mock('../components/tag-area', () => ({
<div ref={ref} data-testid="tag-area" onClick={onClick} />
)),
}))
vi.mock('../components/operations-popover', () => ({
default: () => <div data-testid="operations-popover" />,
vi.mock('../components/operations-dropdown', () => ({
default: () => <div data-testid="operations-dropdown" />,
}))
// Factory function for DataSet mock data
@ -233,152 +231,6 @@ describe('DatasetCard Integration', () => {
})
})
})
// Integration tests for OperationItem component
describe('OperationItem', () => {
const MockIcon = ({ className }: { className?: string }) => (
<svg data-testid="mock-icon" className={className} />
)
describe('Rendering', () => {
it('should render icon and name', () => {
render(<OperationItem Icon={MockIcon as never} name="Edit" />)
expect(screen.getByText('Edit')).toBeInTheDocument()
expect(screen.getByTestId('mock-icon')).toBeInTheDocument()
})
})
describe('User Interactions', () => {
it('should call handleClick when clicked', () => {
const handleClick = vi.fn()
render(<OperationItem Icon={MockIcon as never} name="Delete" handleClick={handleClick} />)
const item = screen.getByText('Delete').closest('div')
fireEvent.click(item!)
expect(handleClick).toHaveBeenCalledTimes(1)
})
it('should prevent default and stop propagation on click', () => {
const handleClick = vi.fn()
render(<OperationItem Icon={MockIcon as never} name="Action" handleClick={handleClick} />)
const item = screen.getByText('Action').closest('div')
const event = new MouseEvent('click', { bubbles: true, cancelable: true })
const preventDefaultSpy = vi.spyOn(event, 'preventDefault')
const stopPropagationSpy = vi.spyOn(event, 'stopPropagation')
item!.dispatchEvent(event)
expect(preventDefaultSpy).toHaveBeenCalled()
expect(stopPropagationSpy).toHaveBeenCalled()
})
})
describe('Edge Cases', () => {
it('should not throw when handleClick is undefined', () => {
render(<OperationItem Icon={MockIcon as never} name="No handler" />)
const item = screen.getByText('No handler').closest('div')
expect(() => {
fireEvent.click(item!)
}).not.toThrow()
})
it('should handle empty name', () => {
render(<OperationItem Icon={MockIcon as never} name="" />)
expect(screen.getByTestId('mock-icon')).toBeInTheDocument()
})
})
})
// Integration tests for Operations component
describe('Operations', () => {
const defaultProps = {
showDelete: true,
showExportPipeline: true,
openRenameModal: vi.fn(),
handleExportPipeline: vi.fn(),
detectIsUsedByApp: vi.fn(),
}
describe('Rendering', () => {
it('should always render edit operation', () => {
render(<Operations {...defaultProps} />)
expect(screen.getByText(/operation\.edit/)).toBeInTheDocument()
})
it('should render export pipeline when showExportPipeline is true', () => {
render(<Operations {...defaultProps} showExportPipeline={true} />)
expect(screen.getByText(/exportPipeline/)).toBeInTheDocument()
})
it('should not render export pipeline when showExportPipeline is false', () => {
render(<Operations {...defaultProps} showExportPipeline={false} />)
expect(screen.queryByText(/exportPipeline/)).not.toBeInTheDocument()
})
it('should render delete when showDelete is true', () => {
render(<Operations {...defaultProps} showDelete={true} />)
expect(screen.getByText(/operation\.delete/)).toBeInTheDocument()
})
it('should not render delete when showDelete is false', () => {
render(<Operations {...defaultProps} showDelete={false} />)
expect(screen.queryByText(/operation\.delete/)).not.toBeInTheDocument()
})
})
describe('User Interactions', () => {
it('should call openRenameModal when edit is clicked', () => {
const openRenameModal = vi.fn()
render(<Operations {...defaultProps} openRenameModal={openRenameModal} />)
const editItem = screen.getByText(/operation\.edit/).closest('div')
fireEvent.click(editItem!)
expect(openRenameModal).toHaveBeenCalledTimes(1)
})
it('should call handleExportPipeline when export is clicked', () => {
const handleExportPipeline = vi.fn()
render(<Operations {...defaultProps} handleExportPipeline={handleExportPipeline} />)
const exportItem = screen.getByText(/exportPipeline/).closest('div')
fireEvent.click(exportItem!)
expect(handleExportPipeline).toHaveBeenCalledTimes(1)
})
it('should call detectIsUsedByApp when delete is clicked', () => {
const detectIsUsedByApp = vi.fn()
render(<Operations {...defaultProps} detectIsUsedByApp={detectIsUsedByApp} />)
const deleteItem = screen.getByText(/operation\.delete/).closest('div')
fireEvent.click(deleteItem!)
expect(detectIsUsedByApp).toHaveBeenCalledTimes(1)
})
})
describe('Edge Cases', () => {
it('should render only edit when both showDelete and showExportPipeline are false', () => {
render(<Operations {...defaultProps} showDelete={false} showExportPipeline={false} />)
expect(screen.getByText(/operation\.edit/)).toBeInTheDocument()
expect(screen.queryByText(/exportPipeline/)).not.toBeInTheDocument()
expect(screen.queryByText(/operation\.delete/)).not.toBeInTheDocument()
})
it('should render divider before delete section when showDelete is true', () => {
const { container } = render(<Operations {...defaultProps} showDelete={true} />)
expect(container.querySelector('.bg-divider-subtle')).toBeInTheDocument()
})
it('should not render divider when showDelete is false', () => {
const { container } = render(<Operations {...defaultProps} showDelete={false} />)
expect(container.querySelector('.bg-divider-subtle')).toBeNull()
})
})
})
})
describe('DatasetCard Component', () => {

View File

@ -1,7 +1,16 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import { DropdownMenu } from '@/app/components/base/ui/dropdown-menu'
import Operations from '../operations'
function renderInMenu(ui: React.ReactElement) {
return render(
<DropdownMenu open>
{ui}
</DropdownMenu>,
)
}
describe('Operations', () => {
const defaultProps = {
showDelete: true,
@ -17,100 +26,65 @@ describe('Operations', () => {
describe('Rendering', () => {
it('should render without crashing', () => {
render(<Operations {...defaultProps} />)
// Edit operation should always be visible
renderInMenu(<Operations {...defaultProps} />)
expect(screen.getByText(/operation\.edit/)).toBeInTheDocument()
})
it('should render edit operation', () => {
render(<Operations {...defaultProps} />)
renderInMenu(<Operations {...defaultProps} />)
expect(screen.getByText(/operation\.edit/)).toBeInTheDocument()
})
it('should render export pipeline operation when showExportPipeline is true', () => {
render(<Operations {...defaultProps} showExportPipeline={true} />)
renderInMenu(<Operations {...defaultProps} showExportPipeline={true} />)
expect(screen.getByText(/exportPipeline/)).toBeInTheDocument()
})
it('should not render export pipeline operation when showExportPipeline is false', () => {
render(<Operations {...defaultProps} showExportPipeline={false} />)
renderInMenu(<Operations {...defaultProps} showExportPipeline={false} />)
expect(screen.queryByText(/exportPipeline/)).not.toBeInTheDocument()
})
it('should render delete operation when showDelete is true', () => {
render(<Operations {...defaultProps} showDelete={true} />)
renderInMenu(<Operations {...defaultProps} showDelete={true} />)
expect(screen.getByText(/operation\.delete/)).toBeInTheDocument()
})
it('should not render delete operation when showDelete is false', () => {
render(<Operations {...defaultProps} showDelete={false} />)
renderInMenu(<Operations {...defaultProps} showDelete={false} />)
expect(screen.queryByText(/operation\.delete/)).not.toBeInTheDocument()
})
})
describe('Props', () => {
it('should render divider when showDelete is true', () => {
const { container } = render(<Operations {...defaultProps} showDelete={true} />)
const divider = container.querySelector('.bg-divider-subtle')
expect(divider).toBeInTheDocument()
})
it('should not render divider when showDelete is false', () => {
const { container } = render(<Operations {...defaultProps} showDelete={false} />)
// Should not have the divider-subtle one (the separator before delete)
expect(container.querySelector('.bg-divider-subtle')).toBeNull()
})
})
describe('User Interactions', () => {
it('should call openRenameModal when edit is clicked', () => {
const openRenameModal = vi.fn()
render(<Operations {...defaultProps} openRenameModal={openRenameModal} />)
const editItem = screen.getByText(/operation\.edit/).closest('div')
fireEvent.click(editItem!)
renderInMenu(<Operations {...defaultProps} openRenameModal={openRenameModal} />)
fireEvent.click(screen.getByText(/operation\.edit/))
expect(openRenameModal).toHaveBeenCalledTimes(1)
})
it('should call handleExportPipeline when export is clicked', () => {
const handleExportPipeline = vi.fn()
render(<Operations {...defaultProps} handleExportPipeline={handleExportPipeline} />)
const exportItem = screen.getByText(/exportPipeline/).closest('div')
fireEvent.click(exportItem!)
renderInMenu(<Operations {...defaultProps} handleExportPipeline={handleExportPipeline} />)
fireEvent.click(screen.getByText(/exportPipeline/))
expect(handleExportPipeline).toHaveBeenCalledTimes(1)
})
it('should call detectIsUsedByApp when delete is clicked', () => {
const detectIsUsedByApp = vi.fn()
render(<Operations {...defaultProps} detectIsUsedByApp={detectIsUsedByApp} />)
const deleteItem = screen.getByText(/operation\.delete/).closest('div')
fireEvent.click(deleteItem!)
renderInMenu(<Operations {...defaultProps} detectIsUsedByApp={detectIsUsedByApp} />)
fireEvent.click(screen.getByText(/operation\.delete/))
expect(detectIsUsedByApp).toHaveBeenCalledTimes(1)
})
})
describe('Styles', () => {
it('should have correct container styling', () => {
const { container } = render(<Operations {...defaultProps} />)
const operationsContainer = container.firstChild
expect(operationsContainer).toHaveClass(
'relative',
'flex',
'w-full',
'flex-col',
'rounded-xl',
)
})
})
describe('Edge Cases', () => {
it('should render only edit when both showDelete and showExportPipeline are false', () => {
render(<Operations {...defaultProps} showDelete={false} showExportPipeline={false} />)
renderInMenu(<Operations {...defaultProps} showDelete={false} showExportPipeline={false} />)
expect(screen.getByText(/operation\.edit/)).toBeInTheDocument()
expect(screen.queryByText(/exportPipeline/)).not.toBeInTheDocument()
expect(screen.queryByText(/operation\.delete/)).not.toBeInTheDocument()

View File

@ -80,7 +80,7 @@ describe('CornerLabels', () => {
const dataset = createMockDataset({ embedding_available: false })
const { container } = render(<CornerLabels dataset={dataset} />)
const labelContainer = container.firstChild as HTMLElement
expect(labelContainer).toHaveClass('absolute', 'right-0', 'top-0', 'z-10')
expect(labelContainer).toHaveClass('absolute', 'right-0', 'top-0', 'z-5')
})
it('should have correct positioning for pipeline label', () => {
@ -90,7 +90,7 @@ describe('CornerLabels', () => {
})
const { container } = render(<CornerLabels dataset={dataset} />)
const labelContainer = container.firstChild as HTMLElement
expect(labelContainer).toHaveClass('absolute', 'right-0', 'top-0', 'z-10')
expect(labelContainer).toHaveClass('absolute', 'right-0', 'top-0', 'z-5')
})
})

View File

@ -1,11 +1,11 @@
import type { DataSet } from '@/models/datasets'
import { fireEvent, render } from '@testing-library/react'
import { fireEvent, render, screen } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import { IndexingType } from '@/app/components/datasets/create/step-two'
import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datasets'
import OperationsPopover from '../operations-popover'
import OperationsDropdown from '../operations-dropdown'
describe('OperationsPopover', () => {
describe('OperationsDropdown', () => {
const createMockDataset = (overrides: Partial<DataSet> = {}): DataSet => ({
id: 'dataset-1',
name: 'Test Dataset',
@ -42,102 +42,143 @@ describe('OperationsPopover', () => {
describe('Rendering', () => {
it('should render without crashing', () => {
const { container } = render(<OperationsPopover {...defaultProps} />)
const { container } = render(<OperationsDropdown {...defaultProps} />)
expect(container.firstChild).toBeInTheDocument()
})
it('should render the more icon button', () => {
const { container } = render(<OperationsPopover {...defaultProps} />)
const moreIcon = container.querySelector('svg')
const { container } = render(<OperationsDropdown {...defaultProps} />)
const moreIcon = container.querySelector('.i-ri-more-fill')
expect(moreIcon).toBeInTheDocument()
})
it('should render in hidden state initially (group-hover)', () => {
const { container } = render(<OperationsPopover {...defaultProps} />)
const { container } = render(<OperationsDropdown {...defaultProps} />)
const wrapper = container.firstChild as HTMLElement
expect(wrapper).toHaveClass('hidden', 'group-hover:block')
expect(wrapper).toHaveClass(
'invisible',
'pointer-events-none',
'group-hover:visible',
'group-hover:pointer-events-auto',
)
})
})
describe('Props', () => {
it('should show delete option when not workspace dataset operator', () => {
render(<OperationsPopover {...defaultProps} isCurrentWorkspaceDatasetOperator={false} />)
render(<OperationsDropdown {...defaultProps} isCurrentWorkspaceDatasetOperator={false} />)
const triggerButton = document.querySelector('[class*="cursor-pointer"]')
if (triggerButton)
fireEvent.click(triggerButton)
// showDelete should be true (inverse of isCurrentWorkspaceDatasetOperator)
// This means delete operation will be visible
})
it('should hide delete option when is workspace dataset operator', () => {
render(<OperationsPopover {...defaultProps} isCurrentWorkspaceDatasetOperator={true} />)
render(<OperationsDropdown {...defaultProps} isCurrentWorkspaceDatasetOperator={true} />)
const triggerButton = document.querySelector('[class*="cursor-pointer"]')
if (triggerButton)
fireEvent.click(triggerButton)
// showDelete should be false
})
it('should show export pipeline when runtime_mode is rag_pipeline', () => {
const dataset = createMockDataset({ runtime_mode: 'rag_pipeline' })
render(<OperationsPopover {...defaultProps} dataset={dataset} />)
render(<OperationsDropdown {...defaultProps} dataset={dataset} />)
const triggerButton = document.querySelector('[class*="cursor-pointer"]')
if (triggerButton)
fireEvent.click(triggerButton)
// showExportPipeline should be true
})
it('should hide export pipeline when runtime_mode is not rag_pipeline', () => {
const dataset = createMockDataset({ runtime_mode: 'general' })
render(<OperationsPopover {...defaultProps} dataset={dataset} />)
render(<OperationsDropdown {...defaultProps} dataset={dataset} />)
const triggerButton = document.querySelector('[class*="cursor-pointer"]')
if (triggerButton)
fireEvent.click(triggerButton)
// showExportPipeline should be false
})
})
describe('Styles', () => {
it('should have correct positioning styles', () => {
const { container } = render(<OperationsPopover {...defaultProps} />)
const { container } = render(<OperationsDropdown {...defaultProps} />)
const wrapper = container.firstChild as HTMLElement
expect(wrapper).toHaveClass('absolute', 'right-2', 'top-2', 'z-15')
expect(wrapper).toHaveClass('absolute', 'right-2', 'top-2', 'z-5')
})
it('should keep the trigger mounted when closed so menu exit animations retain an anchor', () => {
const { container } = render(<OperationsDropdown {...defaultProps} />)
const wrapper = container.firstChild as HTMLElement
const trigger = container.querySelector('[aria-label="Dataset operations"]')
expect(wrapper).not.toHaveClass('hidden')
expect(trigger).toBeInTheDocument()
})
it('should have icon with correct size classes', () => {
const { container } = render(<OperationsPopover {...defaultProps} />)
const icon = container.querySelector('svg')
const { container } = render(<OperationsDropdown {...defaultProps} />)
const icon = container.querySelector('.i-ri-more-fill')
expect(icon).toHaveClass('h-5', 'w-5', 'text-text-tertiary')
})
it('should have aria-label on trigger for accessibility', () => {
const { container } = render(<OperationsDropdown {...defaultProps} />)
const trigger = container.querySelector('[aria-label="Dataset operations"]')
expect(trigger).toBeInTheDocument()
})
it('should expose visible keyboard focus styles on the trigger', () => {
const { container } = render(<OperationsDropdown {...defaultProps} />)
const trigger = container.querySelector('[aria-label="Dataset operations"]')
expect(trigger).toHaveClass(
'focus-visible:outline-hidden',
'focus-visible:ring-1',
'focus-visible:ring-inset',
'focus-visible:ring-components-input-border-hover',
)
})
it('should use a solid trigger background without backdrop blur on hover states', () => {
const { container } = render(<OperationsDropdown {...defaultProps} />)
const trigger = container.querySelector('[aria-label="Dataset operations"]')
expect(trigger).toHaveClass('bg-components-button-secondary-bg')
expect(trigger).not.toHaveClass('hover:backdrop-blur-[5px]', 'backdrop-blur-[5px]')
})
})
describe('User Interactions', () => {
it('should keep outside interactions available when the menu is open', () => {
const onOutsideClick = vi.fn()
render(
<div>
<button type="button" onClick={onOutsideClick}>Outside action</button>
<OperationsDropdown {...defaultProps} />
</div>,
)
fireEvent.click(screen.getByLabelText('Dataset operations'))
fireEvent.click(screen.getByRole('button', { name: 'Outside action' }))
expect(onOutsideClick).toHaveBeenCalledTimes(1)
})
it('should pass openRenameModal to Operations', () => {
const openRenameModal = vi.fn()
render(<OperationsPopover {...defaultProps} openRenameModal={openRenameModal} />)
// The openRenameModal should be passed to Operations component
expect(openRenameModal).not.toHaveBeenCalled() // Initially not called
render(<OperationsDropdown {...defaultProps} openRenameModal={openRenameModal} />)
expect(openRenameModal).not.toHaveBeenCalled()
})
it('should pass handleExportPipeline to Operations', () => {
const handleExportPipeline = vi.fn()
render(<OperationsPopover {...defaultProps} handleExportPipeline={handleExportPipeline} />)
render(<OperationsDropdown {...defaultProps} handleExportPipeline={handleExportPipeline} />)
expect(handleExportPipeline).not.toHaveBeenCalled()
})
it('should pass detectIsUsedByApp to Operations', () => {
const detectIsUsedByApp = vi.fn()
render(<OperationsPopover {...defaultProps} detectIsUsedByApp={detectIsUsedByApp} />)
render(<OperationsDropdown {...defaultProps} detectIsUsedByApp={detectIsUsedByApp} />)
expect(detectIsUsedByApp).not.toHaveBeenCalled()
})
})
@ -145,13 +186,13 @@ describe('OperationsPopover', () => {
describe('Edge Cases', () => {
it('should handle dataset with external provider', () => {
const dataset = createMockDataset({ provider: 'external' })
const { container } = render(<OperationsPopover {...defaultProps} dataset={dataset} />)
const { container } = render(<OperationsDropdown {...defaultProps} dataset={dataset} />)
expect(container.firstChild).toBeInTheDocument()
})
it('should handle dataset with undefined runtime_mode', () => {
const dataset = createMockDataset({ runtime_mode: undefined })
const { container } = render(<OperationsPopover {...defaultProps} dataset={dataset} />)
const { container } = render(<OperationsDropdown {...defaultProps} dataset={dataset} />)
expect(container.firstChild).toBeInTheDocument()
})
})

View File

@ -14,7 +14,7 @@ const CornerLabels = ({ dataset }: CornerLabelsProps) => {
return (
<CornerLabel
label={t('cornerLabel.unavailable', { ns: 'dataset' })}
className="absolute right-0 top-0 z-10"
className="absolute top-0 right-0 z-5"
labelClassName="rounded-tr-xl"
/>
)
@ -24,7 +24,7 @@ const CornerLabels = ({ dataset }: CornerLabelsProps) => {
return (
<CornerLabel
label={t('cornerLabel.pipeline', { ns: 'dataset' })}
className="absolute right-0 top-0 z-10"
className="absolute top-0 right-0 z-5"
labelClassName="rounded-tr-xl"
/>
)

View File

@ -0,0 +1,68 @@
import type { DataSet } from '@/models/datasets'
import * as React from 'react'
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuTrigger,
} from '@/app/components/base/ui/dropdown-menu'
import { cn } from '@/utils/classnames'
import Operations from '../operations'
/** Props for the hover-revealed operations dropdown on a dataset card. */
type OperationsDropdownProps = {
  /** Dataset this card represents; its `runtime_mode` decides menu items. */
  dataset: DataSet
  /** When true the user is a workspace dataset operator, so delete is hidden. */
  isCurrentWorkspaceDatasetOperator: boolean
  /** Opens the rename modal for this dataset. */
  openRenameModal: () => void
  /** Exports the dataset's pipeline; `include` flag semantics — presumably whether to include secrets, TODO confirm against callers. */
  handleExportPipeline: (include?: boolean) => void
  /** Runs the is-used-by-app check as the first step of deletion. */
  detectIsUsedByApp: () => void
}

/**
 * Dropdown menu of dataset-card operations (edit / export pipeline / delete).
 *
 * The wrapper stays mounted but invisible until the parent `.group` is
 * hovered or the menu is open, so the menu keeps an anchor for its exit
 * animation instead of being unmounted via `hidden`.
 */
const OperationsDropdown = ({
  dataset,
  isCurrentWorkspaceDatasetOperator,
  openRenameModal,
  handleExportPipeline,
  detectIsUsedByApp,
}: OperationsDropdownProps) => {
  // Controlled open state: keeps the trigger visible while the menu is open,
  // even after the pointer leaves the card (group-hover alone would hide it).
  const [open, setOpen] = React.useState(false)
  return (
    <div
      className={cn(
        'absolute top-2 right-2 z-5',
        open
          ? 'pointer-events-auto visible'
          : 'pointer-events-none invisible group-hover:pointer-events-auto group-hover:visible',
      )}
      // Stop propagation so opening the menu does not also trigger the
      // card's own click handler (e.g. navigating into the dataset).
      onClick={e => e.stopPropagation()}
    >
      {/* modal={false}: outside elements stay interactive while the menu is open. */}
      <DropdownMenu modal={false} open={open} onOpenChange={setOpen}>
        <DropdownMenuTrigger
          className={cn(
            'inline-flex size-9 cursor-pointer items-center justify-center radius-lg border-[0.5px]',
            'border-components-actionbar-border bg-components-button-secondary-bg p-0 shadow-lg ring-2 shadow-shadow-shadow-5 ring-components-button-secondary-bg ring-inset',
            'transition-colors hover:border-components-actionbar-border hover:bg-state-base-hover',
            'focus-visible:bg-state-base-hover focus-visible:ring-1 focus-visible:ring-components-input-border-hover focus-visible:outline-hidden focus-visible:ring-inset',
            // Keep the hover background applied while the menu is open.
            open && 'bg-state-base-hover',
          )}
          aria-label="Dataset operations"
        >
          <span className="i-ri-more-fill h-5 w-5 text-text-tertiary" />
        </DropdownMenuTrigger>
        <DropdownMenuContent
          placement="bottom-end"
          popupClassName="min-w-[186px]"
        >
          <Operations
            // Workspace dataset operators may not delete datasets.
            showDelete={!isCurrentWorkspaceDatasetOperator}
            // Pipeline export only applies to RAG-pipeline datasets.
            showExportPipeline={dataset.runtime_mode === 'rag_pipeline'}
            openRenameModal={openRenameModal}
            handleExportPipeline={handleExportPipeline}
            detectIsUsedByApp={detectIsUsedByApp}
          />
        </DropdownMenuContent>
      </DropdownMenu>
    </div>
  )
}

export default React.memo(OperationsDropdown)

View File

@ -1,52 +0,0 @@
import type { DataSet } from '@/models/datasets'
import { RiMoreFill } from '@remixicon/react'
import * as React from 'react'
import CustomPopover from '@/app/components/base/popover'
import { cn } from '@/utils/classnames'
import Operations from '../operations'
/** Props for the legacy popover-based dataset-card operations menu. */
type OperationsPopoverProps = {
  /** Dataset this card represents; its `runtime_mode` decides menu items. */
  dataset: DataSet
  /** When true the user is a workspace dataset operator, so delete is hidden. */
  isCurrentWorkspaceDatasetOperator: boolean
  /** Opens the rename modal for this dataset. */
  openRenameModal: () => void
  /** Exports the dataset's pipeline; `include` flag semantics — presumably whether to include secrets, TODO confirm against callers. */
  handleExportPipeline: (include?: boolean) => void
  /** Runs the is-used-by-app check as the first step of deletion. */
  detectIsUsedByApp: () => void
}

/**
 * Popover menu of dataset-card operations (edit / export pipeline / delete).
 *
 * The wrapper uses `hidden group-hover:block`, so the whole trigger is
 * unmounted from layout until the parent `.group` card is hovered.
 */
const OperationsPopover = ({
  dataset,
  isCurrentWorkspaceDatasetOperator,
  openRenameModal,
  handleExportPipeline,
  detectIsUsedByApp,
}: OperationsPopoverProps) => (
  <div className="absolute right-2 top-2 z-15 hidden group-hover:block">
    <CustomPopover
      htmlContent={(
        <Operations
          // Workspace dataset operators may not delete datasets.
          showDelete={!isCurrentWorkspaceDatasetOperator}
          // Pipeline export only applies to RAG-pipeline datasets.
          showExportPipeline={dataset.runtime_mode === 'rag_pipeline'}
          openRenameModal={openRenameModal}
          handleExportPipeline={handleExportPipeline}
          detectIsUsedByApp={detectIsUsedByApp}
        />
      )}
      className="z-20 min-w-[186px]"
      popupClassName="rounded-xl bg-none shadow-none ring-0 min-w-[186px]"
      position="br"
      trigger="click"
      btnElement={(
        <div className="flex size-8 items-center justify-center radius-lg hover:bg-state-base-hover">
          <RiMoreFill className="h-5 w-5 text-text-tertiary" />
        </div>
      )}
      // Highlight the trigger while the popover is open.
      btnClassName={open =>
        cn(
          'size-9 cursor-pointer justify-center radius-lg border-[0.5px] border-components-actionbar-border bg-components-actionbar-bg p-0 shadow-lg shadow-shadow-shadow-5 ring-2 ring-inset ring-components-actionbar-bg hover:border-components-actionbar-border',
          open ? 'border-components-actionbar-border bg-state-base-hover' : '',
        )}
    />
  </div>
)

export default React.memo(OperationsPopover)

View File

@ -9,7 +9,7 @@ import DatasetCardFooter from './components/dataset-card-footer'
import DatasetCardHeader from './components/dataset-card-header'
import DatasetCardModals from './components/dataset-card-modals'
import Description from './components/description'
import OperationsPopover from './components/operations-popover'
import OperationsDropdown from './components/operations-dropdown'
import TagArea from './components/tag-area'
import { useDatasetCardState } from './hooks/use-dataset-card-state'
@ -82,7 +82,7 @@ const DatasetCard = ({
onClick={handleTagAreaClick}
/>
<DatasetCardFooter dataset={dataset} />
<OperationsPopover
<OperationsDropdown
dataset={dataset}
isCurrentWorkspaceDatasetOperator={isCurrentWorkspaceDatasetOperator}
openRenameModal={openRenameModal}

View File

@ -1,8 +1,9 @@
import { RiDeleteBinLine, RiEditLine, RiFileDownloadLine } from '@remixicon/react'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import Divider from '@/app/components/base/divider'
import OperationItem from './operation-item'
import {
DropdownMenuItem,
DropdownMenuSeparator,
} from '@/app/components/base/ui/dropdown-menu'
type OperationsProps = {
showDelete: boolean
@ -22,34 +23,27 @@ const Operations = ({
const { t } = useTranslation()
return (
<div className="relative flex w-full flex-col rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg shadow-shadow-shadow-5">
<div className="flex flex-col p-1">
<OperationItem
Icon={RiEditLine}
name={t('operation.edit', { ns: 'common' })}
handleClick={openRenameModal}
/>
{showExportPipeline && (
<OperationItem
Icon={RiFileDownloadLine}
name={t('operations.exportPipeline', { ns: 'datasetPipeline' })}
handleClick={handleExportPipeline}
/>
)}
</div>
<>
<DropdownMenuItem onClick={openRenameModal}>
<span aria-hidden className="i-ri-edit-line size-4 text-text-tertiary" />
{t('operation.edit', { ns: 'common' })}
</DropdownMenuItem>
{showExportPipeline && (
<DropdownMenuItem onClick={handleExportPipeline}>
<span aria-hidden className="i-ri-file-download-line size-4 text-text-tertiary" />
{t('operations.exportPipeline', { ns: 'datasetPipeline' })}
</DropdownMenuItem>
)}
{showDelete && (
<>
<Divider type="horizontal" className="my-0 bg-divider-subtle" />
<div className="flex flex-col p-1">
<OperationItem
Icon={RiDeleteBinLine}
name={t('operation.delete', { ns: 'common' })}
handleClick={detectIsUsedByApp}
/>
</div>
<DropdownMenuSeparator />
<DropdownMenuItem destructive onClick={detectIsUsedByApp}>
<span aria-hidden className="i-ri-delete-bin-line size-4" />
{t('operation.delete', { ns: 'common' })}
</DropdownMenuItem>
</>
)}
</div>
</>
)
}

View File

@ -5,7 +5,6 @@ import { useBoolean, useDebounceFn } from 'ahooks'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import { ApiConnectionMod } from '@/app/components/base/icons/src/vender/solid/development'
import Input from '@/app/components/base/input'
import TagManagementModal from '@/app/components/base/tag-management'
import TagFilter from '@/app/components/base/tag-management/filter'
@ -85,11 +84,11 @@ const List = () => {
}
<div className="h-4 w-px bg-divider-regular" />
<Button
className="shadows-shadow-xs gap-0.5"
className="gap-0.5 shadow-xs"
onClick={() => setShowExternalApiPanel(true)}
>
<ApiConnectionMod className="h-4 w-4 text-components-button-secondary-text" />
<div className="flex items-center justify-center gap-1 px-0.5 system-sm-medium text-components-button-secondary-text">{t('externalAPIPanelTitle', { ns: 'dataset' })}</div>
<span className="i-custom-vender-solid-development-api-connection-mod h-4 w-4 text-components-button-secondary-text" />
<span className="flex items-center justify-center gap-1 px-0.5 system-sm-medium text-components-button-secondary-text">{t('externalAPIPanelTitle', { ns: 'dataset' })}</span>
</Button>
</div>
</div>

View File

@ -4503,11 +4503,6 @@
"count": 3
}
},
"app/components/datasets/list/dataset-card/components/corner-labels.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 2
}
},
"app/components/datasets/list/dataset-card/components/dataset-card-footer.tsx": {
"no-restricted-imports": {
"count": 1
@ -4526,14 +4521,6 @@
"count": 1
}
},
"app/components/datasets/list/dataset-card/components/operations-popover.tsx": {
"no-restricted-imports": {
"count": 1
},
"tailwindcss/enforce-consistent-class-order": {
"count": 2
}
},
"app/components/datasets/list/dataset-card/components/tag-area.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 1