mirror of
https://github.com/langgenius/dify.git
synced 2026-04-25 05:06:15 +08:00
Merge remote-tracking branch 'myori/main' into feat/collaboration2
This commit is contained in:
@ -0,0 +1,26 @@
|
||||
"""add qdrant_endpoint to tidb_auth_bindings
|
||||
|
||||
Revision ID: 8574b23a38fd
|
||||
Revises: 6b5f9f8b1a2c
|
||||
Create Date: 2026-04-14 15:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8574b23a38fd"
|
||||
down_revision = "6b5f9f8b1a2c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("qdrant_endpoint", sa.String(length=512), nullable=True))
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op:
|
||||
batch_op.drop_column("qdrant_endpoint")
|
||||
@ -1305,6 +1305,7 @@ class TidbAuthBinding(TypeBase):
|
||||
)
|
||||
account: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
qdrant_endpoint: Mapped[str | None] = mapped_column(String(512), nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
@ -7,6 +8,8 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
import qdrant_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from flask import current_app
|
||||
from httpx import DigestAuth
|
||||
from pydantic import BaseModel
|
||||
@ -421,13 +424,16 @@ class TidbOnQdrantVector(BaseVector):
|
||||
|
||||
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
|
||||
logger.info("init_vector: tenant_id=%s, dataset_id=%s", dataset.tenant_id, dataset.id)
|
||||
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
|
||||
if not tidb_auth_binding:
|
||||
logger.info("No existing TidbAuthBinding for tenant %s, acquiring lock", dataset.tenant_id)
|
||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
|
||||
if tidb_auth_binding:
|
||||
logger.info("Found binding after lock: cluster_id=%s", tidb_auth_binding.cluster_id)
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
else:
|
||||
@ -437,11 +443,18 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
.limit(1)
|
||||
)
|
||||
if idle_tidb_auth_binding:
|
||||
logger.info(
|
||||
"Assigning idle cluster %s to tenant %s",
|
||||
idle_tidb_auth_binding.cluster_id,
|
||||
dataset.tenant_id,
|
||||
)
|
||||
idle_tidb_auth_binding.active = True
|
||||
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
||||
db.session.commit()
|
||||
tidb_auth_binding = idle_tidb_auth_binding
|
||||
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
||||
else:
|
||||
logger.info("No idle clusters available, creating new cluster for tenant %s", dataset.tenant_id)
|
||||
new_cluster = TidbService.create_tidb_serverless_cluster(
|
||||
dify_config.TIDB_PROJECT_ID or "",
|
||||
dify_config.TIDB_API_URL or "",
|
||||
@ -450,21 +463,39 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
dify_config.TIDB_PRIVATE_KEY or "",
|
||||
dify_config.TIDB_REGION or "",
|
||||
)
|
||||
logger.info(
|
||||
"New cluster created: cluster_id=%s, qdrant_endpoint=%s",
|
||||
new_cluster["cluster_id"],
|
||||
new_cluster.get("qdrant_endpoint"),
|
||||
)
|
||||
new_tidb_auth_binding = TidbAuthBinding(
|
||||
cluster_id=new_cluster["cluster_id"],
|
||||
cluster_name=new_cluster["cluster_name"],
|
||||
account=new_cluster["account"],
|
||||
password=new_cluster["password"],
|
||||
qdrant_endpoint=new_cluster.get("qdrant_endpoint"),
|
||||
tenant_id=dataset.tenant_id,
|
||||
active=True,
|
||||
status=TidbAuthBindingStatus.ACTIVE,
|
||||
)
|
||||
db.session.add(new_tidb_auth_binding)
|
||||
db.session.commit()
|
||||
tidb_auth_binding = new_tidb_auth_binding
|
||||
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
|
||||
else:
|
||||
logger.info("Existing binding found: cluster_id=%s", tidb_auth_binding.cluster_id)
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
qdrant_url = (
|
||||
(tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None) or dify_config.TIDB_ON_QDRANT_URL or ""
|
||||
)
|
||||
logger.info(
|
||||
"Using qdrant endpoint: %s (from_binding=%s, fallback_global=%s)",
|
||||
qdrant_url,
|
||||
tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None,
|
||||
dify_config.TIDB_ON_QDRANT_URL,
|
||||
)
|
||||
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
@ -479,7 +510,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
collection_name=collection_name,
|
||||
group_id=dataset.id,
|
||||
config=TidbOnQdrantConfig(
|
||||
endpoint=dify_config.TIDB_ON_QDRANT_URL or "",
|
||||
endpoint=qdrant_url,
|
||||
api_key=TIDB_ON_QDRANT_API_KEY,
|
||||
root_path=str(config.root_path),
|
||||
timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
@ -12,6 +13,8 @@ from extensions.ext_redis import redis_client
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Reuse a pooled HTTP client for all TiDB Cloud requests to minimize connection churn
|
||||
_tidb_http_client: httpx.Client = get_pooled_http_client(
|
||||
"tidb:cloud",
|
||||
@ -20,6 +23,46 @@ _tidb_http_client: httpx.Client = get_pooled_http_client(
|
||||
|
||||
|
||||
class TidbService:
|
||||
@staticmethod
|
||||
def extract_qdrant_endpoint(cluster_response: dict) -> str | None:
|
||||
"""Extract the qdrant endpoint URL from a Get Cluster API response.
|
||||
|
||||
Reads ``endpoints.public.host`` (e.g. ``gateway01.xx.tidbcloud.com``),
|
||||
prepends ``qdrant-`` and wraps it as an ``https://`` URL.
|
||||
"""
|
||||
endpoints = cluster_response.get("endpoints") or {}
|
||||
public = endpoints.get("public") or {}
|
||||
host = public.get("host")
|
||||
if host:
|
||||
return f"https://qdrant-{host}"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def fetch_qdrant_endpoint(api_url: str, public_key: str, private_key: str, cluster_id: str) -> str | None:
|
||||
"""Call Get Cluster API and extract the qdrant endpoint.
|
||||
|
||||
Use ``extract_qdrant_endpoint`` instead when you already have
|
||||
the cluster response to avoid a redundant API call.
|
||||
"""
|
||||
try:
|
||||
logger.info("Fetching qdrant endpoint for cluster %s", cluster_id)
|
||||
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
|
||||
if not cluster_response:
|
||||
logger.warning("Empty response from Get Cluster API for cluster %s", cluster_id)
|
||||
return None
|
||||
qdrant_url = TidbService.extract_qdrant_endpoint(cluster_response)
|
||||
if qdrant_url:
|
||||
logger.info("Resolved qdrant endpoint for cluster %s: %s", cluster_id, qdrant_url)
|
||||
return qdrant_url
|
||||
logger.warning(
|
||||
"No endpoints.public.host found for cluster %s, response keys: %s",
|
||||
cluster_id,
|
||||
list(cluster_response.keys()),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch qdrant endpoint for cluster %s", cluster_id)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def create_tidb_serverless_cluster(
|
||||
project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
|
||||
@ -57,6 +100,7 @@ class TidbService:
|
||||
"rootPassword": password,
|
||||
}
|
||||
|
||||
logger.info("Creating TiDB serverless cluster: display_name=%s, region=%s", display_name, region)
|
||||
response = _tidb_http_client.post(
|
||||
f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key)
|
||||
)
|
||||
@ -64,21 +108,39 @@ class TidbService:
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
cluster_id = response_data["clusterId"]
|
||||
logger.info("Cluster created, cluster_id=%s, waiting for ACTIVE state", cluster_id)
|
||||
retry_count = 0
|
||||
max_retries = 30
|
||||
while retry_count < max_retries:
|
||||
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
|
||||
if cluster_response["state"] == "ACTIVE":
|
||||
user_prefix = cluster_response["userPrefix"]
|
||||
qdrant_endpoint = TidbService.extract_qdrant_endpoint(cluster_response)
|
||||
logger.info(
|
||||
"Cluster %s is ACTIVE, user_prefix=%s, qdrant_endpoint=%s",
|
||||
cluster_id,
|
||||
user_prefix,
|
||||
qdrant_endpoint,
|
||||
)
|
||||
return {
|
||||
"cluster_id": cluster_id,
|
||||
"cluster_name": display_name,
|
||||
"account": f"{user_prefix}.root",
|
||||
"password": password,
|
||||
"qdrant_endpoint": qdrant_endpoint,
|
||||
}
|
||||
time.sleep(30) # wait 30 seconds before retrying
|
||||
logger.info(
|
||||
"Cluster %s state=%s, retry %d/%d",
|
||||
cluster_id,
|
||||
cluster_response["state"],
|
||||
retry_count + 1,
|
||||
max_retries,
|
||||
)
|
||||
time.sleep(30)
|
||||
retry_count += 1
|
||||
logger.error("Cluster %s did not become ACTIVE after %d retries", cluster_id, max_retries)
|
||||
else:
|
||||
logger.error("Failed to create cluster: status=%d, body=%s", response.status_code, response.text)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
@ -243,19 +305,29 @@ class TidbService:
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
cluster_infos = []
|
||||
logger.info("Batch created %d clusters", len(response_data.get("clusters", [])))
|
||||
for item in response_data["clusters"]:
|
||||
cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
|
||||
cached_password = redis_client.get(cache_key)
|
||||
if not cached_password:
|
||||
logger.warning("No cached password for cluster %s, skipping", item["displayName"])
|
||||
continue
|
||||
qdrant_endpoint = TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"])
|
||||
logger.info(
|
||||
"Batch cluster %s: qdrant_endpoint=%s",
|
||||
item["clusterId"],
|
||||
qdrant_endpoint,
|
||||
)
|
||||
cluster_info = {
|
||||
"cluster_id": item["clusterId"],
|
||||
"cluster_name": item["displayName"],
|
||||
"account": "root",
|
||||
"password": cached_password.decode("utf-8"),
|
||||
"qdrant_endpoint": qdrant_endpoint,
|
||||
}
|
||||
cluster_infos.append(cluster_info)
|
||||
return cluster_infos
|
||||
else:
|
||||
logger.error("Batch create failed: status=%d, body=%s", response.status_code, response.text)
|
||||
response.raise_for_status()
|
||||
return []
|
||||
|
||||
@ -114,14 +114,12 @@ class TestTidbOnQdrantVectorDeleteByIds:
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
|
||||
def test_delete_by_ids_with_large_batch(self, vector_instance):
|
||||
"""Test deletion with a large batch of IDs."""
|
||||
# Create 1000 IDs
|
||||
def test_delete_by_ids_with_exactly_1000(self, vector_instance):
|
||||
"""Test deletion with exactly 1000 IDs triggers a single batch."""
|
||||
ids = [f"doc_{i}" for i in range(1000)]
|
||||
|
||||
vector_instance.delete_by_ids(ids)
|
||||
|
||||
# Verify single delete call with all IDs
|
||||
vector_instance._client.delete.assert_called_once()
|
||||
call_args = vector_instance._client.delete.call_args
|
||||
|
||||
@ -129,11 +127,28 @@ class TestTidbOnQdrantVectorDeleteByIds:
|
||||
filter_obj = filter_selector.filter
|
||||
field_condition = filter_obj.must[0]
|
||||
|
||||
# Verify all 1000 IDs are in the batch
|
||||
assert len(field_condition.match.any) == 1000
|
||||
assert "doc_0" in field_condition.match.any
|
||||
assert "doc_999" in field_condition.match.any
|
||||
|
||||
def test_delete_by_ids_splits_into_batches(self, vector_instance):
|
||||
"""Test deletion with >1000 IDs triggers multiple batched calls."""
|
||||
ids = [f"doc_{i}" for i in range(2500)]
|
||||
|
||||
vector_instance.delete_by_ids(ids)
|
||||
|
||||
assert vector_instance._client.delete.call_count == 3
|
||||
|
||||
batches = []
|
||||
for call in vector_instance._client.delete.call_args_list:
|
||||
filter_selector = call[1]["points_selector"]
|
||||
field_condition = filter_selector.filter.must[0]
|
||||
batches.append(field_condition.match.any)
|
||||
|
||||
assert len(batches[0]) == 1000
|
||||
assert len(batches[1]) == 1000
|
||||
assert len(batches[2]) == 500
|
||||
|
||||
def test_delete_by_ids_filter_structure(self, vector_instance):
|
||||
"""Test that the filter structure is correctly constructed."""
|
||||
ids = ["doc1", "doc2"]
|
||||
@ -157,3 +172,57 @@ class TestTidbOnQdrantVectorDeleteByIds:
|
||||
# Verify MatchAny structure
|
||||
assert isinstance(field_condition.match, rest.MatchAny)
|
||||
assert field_condition.match.any == ids
|
||||
|
||||
|
||||
class TestInitVectorEndpointSelection:
|
||||
"""Test that init_vector selects the correct qdrant endpoint.
|
||||
|
||||
We avoid importing the full module (which triggers Flask app context)
|
||||
by testing the endpoint selection logic directly on TidbOnQdrantConfig.
|
||||
"""
|
||||
|
||||
def test_uses_binding_endpoint_when_present(self):
|
||||
binding_endpoint = "https://qdrant-custom.tidb.com"
|
||||
global_url = "https://qdrant-global.tidb.com"
|
||||
|
||||
qdrant_url = binding_endpoint or global_url or ""
|
||||
|
||||
assert qdrant_url == "https://qdrant-custom.tidb.com"
|
||||
config = TidbOnQdrantConfig(endpoint=qdrant_url)
|
||||
assert config.endpoint == "https://qdrant-custom.tidb.com"
|
||||
|
||||
def test_falls_back_to_global_when_binding_endpoint_is_none(self):
|
||||
binding_endpoint = None
|
||||
global_url = "https://qdrant-global.tidb.com"
|
||||
|
||||
qdrant_url = binding_endpoint or global_url or ""
|
||||
|
||||
assert qdrant_url == "https://qdrant-global.tidb.com"
|
||||
config = TidbOnQdrantConfig(endpoint=qdrant_url)
|
||||
assert config.endpoint == "https://qdrant-global.tidb.com"
|
||||
|
||||
def test_falls_back_to_empty_when_both_none(self):
|
||||
binding_endpoint = None
|
||||
global_url = None
|
||||
|
||||
qdrant_url = binding_endpoint or global_url or ""
|
||||
|
||||
assert qdrant_url == ""
|
||||
config = TidbOnQdrantConfig(endpoint=qdrant_url)
|
||||
assert config.endpoint == ""
|
||||
|
||||
def test_binding_endpoint_takes_precedence_over_global(self):
|
||||
binding_endpoint = "https://qdrant-ap-southeast.tidb.com"
|
||||
global_url = "https://qdrant-us-east.tidb.com"
|
||||
|
||||
qdrant_url = binding_endpoint or global_url or ""
|
||||
|
||||
assert qdrant_url == "https://qdrant-ap-southeast.tidb.com"
|
||||
|
||||
def test_empty_string_binding_endpoint_falls_back_to_global(self):
|
||||
binding_endpoint = ""
|
||||
global_url = "https://qdrant-global.tidb.com"
|
||||
|
||||
qdrant_url = binding_endpoint or global_url or ""
|
||||
|
||||
assert qdrant_url == "https://qdrant-global.tidb.com"
|
||||
|
||||
@ -0,0 +1,218 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from dify_vdb_tidb_on_qdrant.tidb_service import TidbService
|
||||
|
||||
|
||||
class TestExtractQdrantEndpoint:
|
||||
"""Unit tests for TidbService.extract_qdrant_endpoint."""
|
||||
|
||||
def test_returns_endpoint_when_host_present(self):
|
||||
response = {"endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}}}
|
||||
result = TidbService.extract_qdrant_endpoint(response)
|
||||
assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com"
|
||||
|
||||
def test_returns_none_when_host_missing(self):
|
||||
response = {"endpoints": {"public": {}}}
|
||||
assert TidbService.extract_qdrant_endpoint(response) is None
|
||||
|
||||
def test_returns_none_when_public_missing(self):
|
||||
response = {"endpoints": {}}
|
||||
assert TidbService.extract_qdrant_endpoint(response) is None
|
||||
|
||||
def test_returns_none_when_endpoints_missing(self):
|
||||
assert TidbService.extract_qdrant_endpoint({}) is None
|
||||
|
||||
|
||||
class TestFetchQdrantEndpoint:
|
||||
"""Unit tests for TidbService.fetch_qdrant_endpoint."""
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
def test_returns_endpoint_when_host_present(self, mock_get_cluster):
|
||||
mock_get_cluster.return_value = {
|
||||
"endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}}
|
||||
}
|
||||
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
|
||||
assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com"
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
def test_returns_none_when_cluster_response_is_none(self, mock_get_cluster):
|
||||
mock_get_cluster.return_value = None
|
||||
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
def test_returns_none_when_host_missing(self, mock_get_cluster):
|
||||
mock_get_cluster.return_value = {"endpoints": {"public": {}}}
|
||||
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
def test_returns_none_when_endpoints_missing(self, mock_get_cluster):
|
||||
mock_get_cluster.return_value = {}
|
||||
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
def test_returns_none_on_exception(self, mock_get_cluster):
|
||||
mock_get_cluster.side_effect = RuntimeError("network error")
|
||||
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
|
||||
|
||||
|
||||
class TestCreateTidbServerlessClusterQdrantEndpoint:
|
||||
"""Verify that create_tidb_serverless_cluster includes qdrant_endpoint in its result."""
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_get_cluster):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
|
||||
mock_get_cluster.return_value = {
|
||||
"state": "ACTIVE",
|
||||
"userPrefix": "pfx",
|
||||
"endpoints": {"public": {"host": "gw.tidbcloud.com", "port": 4000}},
|
||||
}
|
||||
|
||||
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||
|
||||
assert result is not None
|
||||
assert result["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com"
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_result_qdrant_endpoint_none_when_no_endpoints(self, mock_config, mock_http, mock_get_cluster):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
|
||||
mock_get_cluster.return_value = {"state": "ACTIVE", "userPrefix": "pfx"}
|
||||
|
||||
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||
|
||||
assert result is not None
|
||||
assert result["qdrant_endpoint"] is None
|
||||
|
||||
|
||||
class TestBatchCreateTidbServerlessClusterQdrantEndpoint:
|
||||
"""Verify that batch_create includes qdrant_endpoint per cluster."""
|
||||
|
||||
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_batch_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_redis, mock_fetch_ep):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
cluster_name = "abc123"
|
||||
mock_http.post.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": cluster_name}]},
|
||||
)
|
||||
mock_redis.setex = MagicMock()
|
||||
mock_redis.get.return_value = b"password123"
|
||||
|
||||
result = TidbService.batch_create_tidb_serverless_cluster(
|
||||
batch_size=1,
|
||||
project_id="proj",
|
||||
api_url="url",
|
||||
iam_url="iam",
|
||||
public_key="pub",
|
||||
private_key="priv",
|
||||
region="us-east-1",
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com"
|
||||
|
||||
|
||||
class TestCreateTidbServerlessClusterRetry:
|
||||
"""Cover retry/logging paths in create_tidb_serverless_cluster."""
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_polls_until_active(self, mock_config, mock_http, mock_get_cluster):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
|
||||
mock_get_cluster.side_effect = [
|
||||
{"state": "CREATING", "userPrefix": ""},
|
||||
{"state": "ACTIVE", "userPrefix": "pfx", "endpoints": {"public": {"host": "gw.tidb.com"}}},
|
||||
]
|
||||
|
||||
with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"):
|
||||
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||
|
||||
assert result is not None
|
||||
assert result["qdrant_endpoint"] == "https://qdrant-gw.tidb.com"
|
||||
assert mock_get_cluster.call_count == 2
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_returns_none_after_max_retries(self, mock_config, mock_http, mock_get_cluster):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
|
||||
mock_get_cluster.return_value = {"state": "CREATING", "userPrefix": ""}
|
||||
|
||||
with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"):
|
||||
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||
|
||||
assert result is None
|
||||
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_raises_on_post_failure(self, mock_config, mock_http):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_response = MagicMock(status_code=400, text="Bad Request")
|
||||
mock_response.raise_for_status.side_effect = Exception("HTTP 400")
|
||||
mock_http.post.return_value = mock_response
|
||||
|
||||
with pytest.raises(Exception, match="HTTP 400"):
|
||||
TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||
|
||||
|
||||
class TestBatchCreateEdgeCases:
|
||||
"""Cover logging/edge-case branches in batch_create."""
|
||||
|
||||
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None)
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_skips_cluster_when_no_cached_password(self, mock_config, mock_http, mock_redis, mock_fetch_ep):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_http.post.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": "name1"}]},
|
||||
)
|
||||
mock_redis.setex = MagicMock()
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
result = TidbService.batch_create_tidb_serverless_cluster(
|
||||
batch_size=1,
|
||||
project_id="proj",
|
||||
api_url="url",
|
||||
iam_url="iam",
|
||||
public_key="pub",
|
||||
private_key="priv",
|
||||
region="us-east-1",
|
||||
)
|
||||
|
||||
assert len(result) == 0
|
||||
mock_fetch_ep.assert_not_called()
|
||||
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_raises_on_post_failure(self, mock_config, mock_http, mock_redis):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_response = MagicMock(status_code=500, text="Server Error")
|
||||
mock_response.raise_for_status.side_effect = Exception("HTTP 500")
|
||||
mock_http.post.return_value = mock_response
|
||||
mock_redis.setex = MagicMock()
|
||||
|
||||
with pytest.raises(Exception, match="HTTP 500"):
|
||||
TidbService.batch_create_tidb_serverless_cluster(
|
||||
batch_size=1,
|
||||
project_id="proj",
|
||||
api_url="url",
|
||||
iam_url="iam",
|
||||
public_key="pub",
|
||||
private_key="priv",
|
||||
region="us-east-1",
|
||||
)
|
||||
@ -57,6 +57,7 @@ def create_clusters(batch_size):
|
||||
cluster_name=new_cluster["cluster_name"],
|
||||
account=new_cluster["account"],
|
||||
password=new_cluster["password"],
|
||||
qdrant_endpoint=new_cluster.get("qdrant_endpoint"),
|
||||
active=False,
|
||||
status=TidbAuthBindingStatus.CREATING,
|
||||
)
|
||||
|
||||
@ -455,7 +455,7 @@ class AppDslService:
|
||||
app.updated_by = account.id
|
||||
|
||||
self._session.add(app)
|
||||
self._session.commit()
|
||||
self._session.flush()
|
||||
app_was_created.send(app, account=account)
|
||||
|
||||
# save dependencies
|
||||
|
||||
@ -0,0 +1,507 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import AppTriggerStatus, AppTriggerType
|
||||
from models.model import App
|
||||
from models.trigger import AppTrigger, WorkflowWebhookTrigger
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
|
||||
|
||||
class WebhookServiceRelationshipFactory:
|
||||
@staticmethod
|
||||
def create_account_and_tenant(db_session_with_containers: Session) -> tuple[Account, Tenant]:
|
||||
account = Account(
|
||||
name=f"Account {uuid4()}",
|
||||
email=f"webhook-{uuid4()}@example.com",
|
||||
password="hashed-password",
|
||||
password_salt="salt",
|
||||
interface_language="en-US",
|
||||
timezone="UTC",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant = Tenant(name=f"Tenant {uuid4()}", plan="basic", status="normal")
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.current_tenant = tenant
|
||||
return account, tenant
|
||||
|
||||
@staticmethod
|
||||
def create_app(db_session_with_containers: Session, tenant: Tenant, account: Account) -> App:
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name=f"Webhook App {uuid4()}",
|
||||
description="",
|
||||
mode="workflow",
|
||||
icon_type="emoji",
|
||||
icon="bot",
|
||||
icon_background="#FFFFFF",
|
||||
enable_site=False,
|
||||
enable_api=True,
|
||||
api_rpm=100,
|
||||
api_rph=100,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
is_universal=False,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
return app
|
||||
|
||||
@staticmethod
|
||||
def create_workflow(
|
||||
db_session_with_containers: Session,
|
||||
*,
|
||||
app: App,
|
||||
account: Account,
|
||||
node_ids: list[str],
|
||||
version: str,
|
||||
) -> Workflow:
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": node_id,
|
||||
"data": {
|
||||
"type": TRIGGER_WEBHOOK_NODE_TYPE,
|
||||
"title": f"Webhook {node_id}",
|
||||
"method": "post",
|
||||
"content_type": "application/json",
|
||||
"headers": [],
|
||||
"params": [],
|
||||
"body": [],
|
||||
"status_code": 200,
|
||||
"response_body": '{"status": "ok"}',
|
||||
"timeout": 30,
|
||||
},
|
||||
}
|
||||
for node_id in node_ids
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
workflow = Workflow(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
type="workflow",
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps({}),
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
version=version,
|
||||
)
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.commit()
|
||||
return workflow
|
||||
|
||||
@staticmethod
|
||||
def create_webhook_trigger(
|
||||
db_session_with_containers: Session,
|
||||
*,
|
||||
app: App,
|
||||
account: Account,
|
||||
node_id: str,
|
||||
webhook_id: str | None = None,
|
||||
) -> WorkflowWebhookTrigger:
|
||||
webhook_trigger = WorkflowWebhookTrigger(
|
||||
app_id=app.id,
|
||||
node_id=node_id,
|
||||
tenant_id=app.tenant_id,
|
||||
webhook_id=webhook_id or uuid4().hex[:24],
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(webhook_trigger)
|
||||
db_session_with_containers.commit()
|
||||
return webhook_trigger
|
||||
|
||||
@staticmethod
|
||||
def create_app_trigger(
|
||||
db_session_with_containers: Session,
|
||||
*,
|
||||
app: App,
|
||||
node_id: str,
|
||||
status: AppTriggerStatus,
|
||||
) -> AppTrigger:
|
||||
app_trigger = AppTrigger(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
node_id=node_id,
|
||||
trigger_type=AppTriggerType.TRIGGER_WEBHOOK,
|
||||
provider_name="webhook",
|
||||
title=f"Webhook {node_id}",
|
||||
status=status,
|
||||
)
|
||||
db_session_with_containers.add(app_trigger)
|
||||
db_session_with_containers.commit()
|
||||
return app_trigger
|
||||
|
||||
|
||||
class TestWebhookServiceLookupWithContainers:
|
||||
def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_missing(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
|
||||
)
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="node-1"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="App trigger not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_rate_limited(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
|
||||
)
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="node-1"
|
||||
)
|
||||
factory.create_app_trigger(
|
||||
db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.RATE_LIMITED
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="rate limited"):
|
||||
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_disabled(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
|
||||
)
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="node-1"
|
||||
)
|
||||
factory.create_app_trigger(
|
||||
db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.DISABLED
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="disabled"):
|
||||
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_raises_when_workflow_missing(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="node-1"
|
||||
)
|
||||
factory.create_app_trigger(
|
||||
db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.ENABLED
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_returns_debug_draft_workflow(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
factory.create_workflow(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
account=account,
|
||||
node_ids=["published-node"],
|
||||
version="2026-04-14.001",
|
||||
)
|
||||
draft_workflow = factory.create_workflow(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
account=account,
|
||||
node_ids=["debug-node"],
|
||||
version=Workflow.VERSION_DRAFT,
|
||||
)
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="debug-node"
|
||||
)
|
||||
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
|
||||
webhook_trigger.webhook_id,
|
||||
is_debug=True,
|
||||
)
|
||||
|
||||
assert got_trigger.id == webhook_trigger.id
|
||||
assert got_workflow.id == draft_workflow.id
|
||||
assert got_node_config["id"] == "debug-node"
|
||||
|
||||
|
||||
class TestWebhookServiceTriggerExecutionWithContainers:
|
||||
def test_trigger_workflow_execution_triggers_async_workflow_successfully(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
workflow = factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
|
||||
)
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="node-1"
|
||||
)
|
||||
|
||||
end_user = SimpleNamespace(id=str(uuid4()))
|
||||
webhook_data = {"body": {"value": 1}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
|
||||
return_value=end_user,
|
||||
),
|
||||
patch("services.trigger.webhook_service.QuotaType.TRIGGER.consume") as mock_consume,
|
||||
patch("services.trigger.webhook_service.AsyncWorkflowService.trigger_workflow_async") as mock_trigger,
|
||||
):
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
|
||||
|
||||
mock_consume.assert_called_once_with(webhook_trigger.tenant_id)
|
||||
mock_trigger.assert_called_once()
|
||||
trigger_args = mock_trigger.call_args.args
|
||||
assert trigger_args[1] is end_user
|
||||
assert trigger_args[2].workflow_id == workflow.id
|
||||
assert trigger_args[2].root_node_id == webhook_trigger.node_id
|
||||
|
||||
def test_trigger_workflow_execution_marks_tenant_rate_limited_when_quota_exceeded(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
workflow = factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
|
||||
)
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="node-1"
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
|
||||
return_value=SimpleNamespace(id=str(uuid4())),
|
||||
),
|
||||
patch(
|
||||
"services.trigger.webhook_service.QuotaType.TRIGGER.consume",
|
||||
side_effect=QuotaExceededError(feature="trigger", tenant_id=tenant.id, required=1),
|
||||
),
|
||||
patch(
|
||||
"services.trigger.webhook_service.AppTriggerService.mark_tenant_triggers_rate_limited"
|
||||
) as mock_mark_rate_limited,
|
||||
):
|
||||
with pytest.raises(QuotaExceededError):
|
||||
WebhookService.trigger_workflow_execution(
|
||||
webhook_trigger,
|
||||
{"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"},
|
||||
workflow,
|
||||
)
|
||||
|
||||
mock_mark_rate_limited.assert_called_once_with(tenant.id)
|
||||
|
||||
def test_trigger_workflow_execution_logs_and_reraises_unexpected_errors(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
workflow = factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
|
||||
)
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="node-1"
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
|
||||
side_effect=RuntimeError("boom"),
|
||||
),
|
||||
patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception,
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
WebhookService.trigger_workflow_execution(
|
||||
webhook_trigger,
|
||||
{"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"},
|
||||
workflow,
|
||||
)
|
||||
|
||||
mock_logger_exception.assert_called_once()
|
||||
|
||||
|
||||
class TestWebhookServiceRelationshipSyncWithContainers:
|
||||
def test_sync_webhook_relationships_raises_when_workflow_exceeds_node_limit(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
node_ids = [f"node-{index}" for index in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)]
|
||||
workflow = factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=node_ids, version=Workflow.VERSION_DRAFT
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="maximum webhook node limit"):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
def test_sync_webhook_relationships_raises_when_lock_not_acquired(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
workflow = factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version=Workflow.VERSION_DRAFT
|
||||
)
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
|
||||
with patch("services.trigger.webhook_service.redis_client.lock", return_value=lock):
|
||||
with pytest.raises(RuntimeError, match="Failed to acquire lock"):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
def test_sync_webhook_relationships_creates_missing_records_and_deletes_stale_records(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
stale_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
account=account,
|
||||
node_id="node-stale",
|
||||
webhook_id="stale-webhook-id-000001",
|
||||
)
|
||||
stale_trigger_id = stale_trigger.id
|
||||
workflow = factory.create_workflow(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
account=account,
|
||||
node_ids=["node-new"],
|
||||
version=Workflow.VERSION_DRAFT,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="new-webhook-id-000001"
|
||||
):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
records = db_session_with_containers.scalars(
|
||||
select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.app_id == app.id)
|
||||
).all()
|
||||
|
||||
assert [record.node_id for record in records] == ["node-new"]
|
||||
assert records[0].webhook_id == "new-webhook-id-000001"
|
||||
assert db_session_with_containers.get(WorkflowWebhookTrigger, stale_trigger_id) is None
|
||||
|
||||
def test_sync_webhook_relationships_sets_redis_cache_for_new_record(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
workflow = factory.create_workflow(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
account=account,
|
||||
node_ids=["node-cache"],
|
||||
version=Workflow.VERSION_DRAFT,
|
||||
)
|
||||
cache_key = f"{WebhookService.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:node-cache"
|
||||
|
||||
with patch(
|
||||
"services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="cache-webhook-id-00001"
|
||||
):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
cached_payload = WebhookServiceRelationshipFactory._read_cache(cache_key)
|
||||
assert cached_payload is not None
|
||||
assert cached_payload["node_id"] == "node-cache"
|
||||
assert cached_payload["webhook_id"] == "cache-webhook-id-00001"
|
||||
|
||||
def test_sync_webhook_relationships_logs_when_lock_release_fails(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
workflow = factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=[], version=Workflow.VERSION_DRAFT
|
||||
)
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.release.side_effect = RuntimeError("release failed")
|
||||
|
||||
with (
|
||||
patch("services.trigger.webhook_service.redis_client.lock", return_value=lock),
|
||||
patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception,
|
||||
):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
mock_logger_exception.assert_called_once()
|
||||
|
||||
|
||||
def _read_cache(cache_key: str) -> dict[str, str] | None:
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
cached = redis_client.get(cache_key)
|
||||
if not cached:
|
||||
return None
|
||||
if isinstance(cached, bytes):
|
||||
cached = cached.decode("utf-8")
|
||||
return json.loads(cached)
|
||||
|
||||
|
||||
WebhookServiceRelationshipFactory._read_cache = staticmethod(_read_cache)
|
||||
@ -1,5 +1,5 @@
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@ -13,11 +13,6 @@ from core.workflow.nodes.trigger_webhook.entities import (
|
||||
WebhookData,
|
||||
WebhookParameter,
|
||||
)
|
||||
from models.enums import AppTriggerStatus
|
||||
from models.model import App
|
||||
from models.trigger import WorkflowWebhookTrigger
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.trigger import webhook_service as service_module
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
|
||||
@ -39,156 +34,13 @@ class _FakeQuery:
|
||||
return self._result
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class _SessionmakerContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def begin(self) -> "_SessionmakerContext":
|
||||
return self
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app() -> Flask:
|
||||
return Flask(__name__)
|
||||
|
||||
|
||||
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
|
||||
monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session))
|
||||
|
||||
|
||||
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
|
||||
return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _workflow(**kwargs: Any) -> Workflow:
|
||||
return cast(Workflow, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _app(**kwargs: Any) -> App:
|
||||
return cast(App, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
class TestWebhookServiceLookup:
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.return_value = None
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
with pytest.raises(ValueError, match="Webhook not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, None]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
with pytest.raises(ValueError, match="App trigger not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
with pytest.raises(ValueError, match="rate limited"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
with pytest.raises(ValueError, match="disabled"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
assert got_trigger is webhook_trigger
|
||||
assert got_workflow is workflow
|
||||
assert got_node_config == {"data": {"key": "value"}}
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, workflow]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
|
||||
"webhook-1",
|
||||
is_debug=True,
|
||||
)
|
||||
|
||||
assert got_trigger is webhook_trigger
|
||||
assert got_workflow is workflow
|
||||
assert got_node_config == {"data": {"mode": "debug"}}
|
||||
def _workflow_trigger(**kwargs: Any) -> Any:
|
||||
return SimpleNamespace(**kwargs)
|
||||
|
||||
|
||||
class TestWebhookServiceExtractionFallbacks:
|
||||
@ -420,237 +272,6 @@ class TestWebhookServiceValidationAndConversion:
|
||||
assert result["webhook_body"] == {"b": 2}
|
||||
|
||||
|
||||
class TestWebhookServiceExecutionAndSync:
|
||||
def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
webhook_data = {"body": {"x": 1}}
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService,
|
||||
"get_or_create_end_user_by_type",
|
||||
MagicMock(return_value=end_user),
|
||||
)
|
||||
quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock()))
|
||||
monkeypatch.setattr(service_module, "QuotaType", quota_type)
|
||||
trigger_async_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock)
|
||||
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
|
||||
|
||||
trigger_async_mock.assert_called_once()
|
||||
|
||||
def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService,
|
||||
"get_or_create_end_user_by_type",
|
||||
MagicMock(return_value=SimpleNamespace(id="end-user-1")),
|
||||
)
|
||||
quota_type = SimpleNamespace(
|
||||
TRIGGER=SimpleNamespace(
|
||||
consume=MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1))
|
||||
)
|
||||
)
|
||||
monkeypatch.setattr(service_module, "QuotaType", quota_type)
|
||||
mark_rate_limited_mock = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock
|
||||
)
|
||||
|
||||
with pytest.raises(QuotaExceededError):
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
|
||||
|
||||
mark_rate_limited_mock.assert_called_once_with("tenant-1")
|
||||
|
||||
def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService,
|
||||
"get_or_create_end_user_by_type",
|
||||
MagicMock(side_effect=RuntimeError("boom")),
|
||||
)
|
||||
logger_exception_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
|
||||
|
||||
logger_exception_mock.assert_called_once()
|
||||
|
||||
def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit(self) -> None:
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(
|
||||
walk_nodes=lambda _node_type: [
|
||||
(f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)
|
||||
]
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="maximum webhook node limit"):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
|
||||
with pytest.raises(RuntimeError, match="Failed to acquire lock"):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])
|
||||
|
||||
class _WorkflowWebhookTrigger:
|
||||
app_id = "app_id"
|
||||
tenant_id = "tenant_id"
|
||||
webhook_id = "webhook_id"
|
||||
node_id = "node_id"
|
||||
|
||||
def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
|
||||
self.id = None
|
||||
self.app_id = app_id
|
||||
self.tenant_id = tenant_id
|
||||
self.node_id = node_id
|
||||
self.webhook_id = webhook_id
|
||||
self.created_by = created_by
|
||||
|
||||
class _Select:
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_Select":
|
||||
return self
|
||||
|
||||
class _Session:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[Any] = []
|
||||
self.deleted: list[Any] = []
|
||||
self.commit_count = 0
|
||||
self.existing_records = [SimpleNamespace(node_id="node-stale")]
|
||||
|
||||
def scalars(self, _stmt: Any) -> Any:
|
||||
return SimpleNamespace(all=lambda: self.existing_records)
|
||||
|
||||
def add(self, obj: Any) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
def flush(self) -> None:
|
||||
for idx, obj in enumerate(self.added, start=1):
|
||||
if obj.id is None:
|
||||
obj.id = f"rec-{idx}"
|
||||
|
||||
def commit(self) -> None:
|
||||
self.commit_count += 1
|
||||
|
||||
def delete(self, obj: Any) -> None:
|
||||
self.deleted.append(obj)
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.release.return_value = None
|
||||
|
||||
fake_session = _Session()
|
||||
|
||||
monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
|
||||
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
|
||||
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
redis_set_mock = MagicMock()
|
||||
redis_delete_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock)
|
||||
monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock)
|
||||
monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
assert len(fake_session.added) == 1
|
||||
assert len(fake_session.deleted) == 1
|
||||
redis_set_mock.assert_called_once()
|
||||
redis_delete_mock.assert_called_once()
|
||||
lock.release.assert_called_once()
|
||||
|
||||
def test_sync_webhook_relationships_should_log_when_lock_release_fails(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [])
|
||||
|
||||
class _Select:
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_Select":
|
||||
return self
|
||||
|
||||
class _Session:
|
||||
def scalars(self, _stmt: Any) -> Any:
|
||||
return SimpleNamespace(all=lambda: [])
|
||||
|
||||
def commit(self) -> None:
|
||||
return None
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.release.side_effect = RuntimeError("release failed")
|
||||
|
||||
logger_exception_mock = MagicMock()
|
||||
|
||||
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
|
||||
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
|
||||
_patch_session(monkeypatch, _Session())
|
||||
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
assert logger_exception_mock.call_count == 1
|
||||
|
||||
|
||||
class TestWebhookServiceUtilities:
|
||||
def test_generate_webhook_response_should_fallback_when_response_body_is_not_json(self) -> None:
|
||||
node_config = {"data": {"status_code": 200, "response_body": "{bad-json"}}
|
||||
|
||||
@ -7,8 +7,6 @@ import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datase
|
||||
import DatasetCardFooter from '../components/dataset-card-footer'
|
||||
import Description from '../components/description'
|
||||
import DatasetCard from '../index'
|
||||
import OperationItem from '../operation-item'
|
||||
import Operations from '../operations'
|
||||
|
||||
// Mock external hooks only
|
||||
vi.mock('@/hooks/use-format-time-from-now', () => ({
|
||||
@ -62,8 +60,8 @@ vi.mock('../components/tag-area', () => ({
|
||||
<div ref={ref} data-testid="tag-area" onClick={onClick} />
|
||||
)),
|
||||
}))
|
||||
vi.mock('../components/operations-popover', () => ({
|
||||
default: () => <div data-testid="operations-popover" />,
|
||||
vi.mock('../components/operations-dropdown', () => ({
|
||||
default: () => <div data-testid="operations-dropdown" />,
|
||||
}))
|
||||
|
||||
// Factory function for DataSet mock data
|
||||
@ -233,152 +231,6 @@ describe('DatasetCard Integration', () => {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// Integration tests for OperationItem component
|
||||
describe('OperationItem', () => {
|
||||
const MockIcon = ({ className }: { className?: string }) => (
|
||||
<svg data-testid="mock-icon" className={className} />
|
||||
)
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should render icon and name', () => {
|
||||
render(<OperationItem Icon={MockIcon as never} name="Edit" />)
|
||||
expect(screen.getByText('Edit')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('mock-icon')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('User Interactions', () => {
|
||||
it('should call handleClick when clicked', () => {
|
||||
const handleClick = vi.fn()
|
||||
render(<OperationItem Icon={MockIcon as never} name="Delete" handleClick={handleClick} />)
|
||||
|
||||
const item = screen.getByText('Delete').closest('div')
|
||||
fireEvent.click(item!)
|
||||
|
||||
expect(handleClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should prevent default and stop propagation on click', () => {
|
||||
const handleClick = vi.fn()
|
||||
render(<OperationItem Icon={MockIcon as never} name="Action" handleClick={handleClick} />)
|
||||
|
||||
const item = screen.getByText('Action').closest('div')
|
||||
const event = new MouseEvent('click', { bubbles: true, cancelable: true })
|
||||
const preventDefaultSpy = vi.spyOn(event, 'preventDefault')
|
||||
const stopPropagationSpy = vi.spyOn(event, 'stopPropagation')
|
||||
|
||||
item!.dispatchEvent(event)
|
||||
|
||||
expect(preventDefaultSpy).toHaveBeenCalled()
|
||||
expect(stopPropagationSpy).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should not throw when handleClick is undefined', () => {
|
||||
render(<OperationItem Icon={MockIcon as never} name="No handler" />)
|
||||
const item = screen.getByText('No handler').closest('div')
|
||||
expect(() => {
|
||||
fireEvent.click(item!)
|
||||
}).not.toThrow()
|
||||
})
|
||||
|
||||
it('should handle empty name', () => {
|
||||
render(<OperationItem Icon={MockIcon as never} name="" />)
|
||||
expect(screen.getByTestId('mock-icon')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// Integration tests for Operations component
|
||||
describe('Operations', () => {
|
||||
const defaultProps = {
|
||||
showDelete: true,
|
||||
showExportPipeline: true,
|
||||
openRenameModal: vi.fn(),
|
||||
handleExportPipeline: vi.fn(),
|
||||
detectIsUsedByApp: vi.fn(),
|
||||
}
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should always render edit operation', () => {
|
||||
render(<Operations {...defaultProps} />)
|
||||
expect(screen.getByText(/operation\.edit/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render export pipeline when showExportPipeline is true', () => {
|
||||
render(<Operations {...defaultProps} showExportPipeline={true} />)
|
||||
expect(screen.getByText(/exportPipeline/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render export pipeline when showExportPipeline is false', () => {
|
||||
render(<Operations {...defaultProps} showExportPipeline={false} />)
|
||||
expect(screen.queryByText(/exportPipeline/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render delete when showDelete is true', () => {
|
||||
render(<Operations {...defaultProps} showDelete={true} />)
|
||||
expect(screen.getByText(/operation\.delete/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render delete when showDelete is false', () => {
|
||||
render(<Operations {...defaultProps} showDelete={false} />)
|
||||
expect(screen.queryByText(/operation\.delete/)).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('User Interactions', () => {
|
||||
it('should call openRenameModal when edit is clicked', () => {
|
||||
const openRenameModal = vi.fn()
|
||||
render(<Operations {...defaultProps} openRenameModal={openRenameModal} />)
|
||||
|
||||
const editItem = screen.getByText(/operation\.edit/).closest('div')
|
||||
fireEvent.click(editItem!)
|
||||
|
||||
expect(openRenameModal).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should call handleExportPipeline when export is clicked', () => {
|
||||
const handleExportPipeline = vi.fn()
|
||||
render(<Operations {...defaultProps} handleExportPipeline={handleExportPipeline} />)
|
||||
|
||||
const exportItem = screen.getByText(/exportPipeline/).closest('div')
|
||||
fireEvent.click(exportItem!)
|
||||
|
||||
expect(handleExportPipeline).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should call detectIsUsedByApp when delete is clicked', () => {
|
||||
const detectIsUsedByApp = vi.fn()
|
||||
render(<Operations {...defaultProps} detectIsUsedByApp={detectIsUsedByApp} />)
|
||||
|
||||
const deleteItem = screen.getByText(/operation\.delete/).closest('div')
|
||||
fireEvent.click(deleteItem!)
|
||||
|
||||
expect(detectIsUsedByApp).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should render only edit when both showDelete and showExportPipeline are false', () => {
|
||||
render(<Operations {...defaultProps} showDelete={false} showExportPipeline={false} />)
|
||||
expect(screen.getByText(/operation\.edit/)).toBeInTheDocument()
|
||||
expect(screen.queryByText(/exportPipeline/)).not.toBeInTheDocument()
|
||||
expect(screen.queryByText(/operation\.delete/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render divider before delete section when showDelete is true', () => {
|
||||
const { container } = render(<Operations {...defaultProps} showDelete={true} />)
|
||||
expect(container.querySelector('.bg-divider-subtle')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render divider when showDelete is false', () => {
|
||||
const { container } = render(<Operations {...defaultProps} showDelete={false} />)
|
||||
expect(container.querySelector('.bg-divider-subtle')).toBeNull()
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('DatasetCard Component', () => {
|
||||
|
||||
@ -1,7 +1,16 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { DropdownMenu } from '@/app/components/base/ui/dropdown-menu'
|
||||
import Operations from '../operations'
|
||||
|
||||
function renderInMenu(ui: React.ReactElement) {
|
||||
return render(
|
||||
<DropdownMenu open>
|
||||
{ui}
|
||||
</DropdownMenu>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('Operations', () => {
|
||||
const defaultProps = {
|
||||
showDelete: true,
|
||||
@ -17,100 +26,65 @@ describe('Operations', () => {
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should render without crashing', () => {
|
||||
render(<Operations {...defaultProps} />)
|
||||
// Edit operation should always be visible
|
||||
renderInMenu(<Operations {...defaultProps} />)
|
||||
expect(screen.getByText(/operation\.edit/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render edit operation', () => {
|
||||
render(<Operations {...defaultProps} />)
|
||||
renderInMenu(<Operations {...defaultProps} />)
|
||||
expect(screen.getByText(/operation\.edit/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render export pipeline operation when showExportPipeline is true', () => {
|
||||
render(<Operations {...defaultProps} showExportPipeline={true} />)
|
||||
renderInMenu(<Operations {...defaultProps} showExportPipeline={true} />)
|
||||
expect(screen.getByText(/exportPipeline/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render export pipeline operation when showExportPipeline is false', () => {
|
||||
render(<Operations {...defaultProps} showExportPipeline={false} />)
|
||||
renderInMenu(<Operations {...defaultProps} showExportPipeline={false} />)
|
||||
expect(screen.queryByText(/exportPipeline/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render delete operation when showDelete is true', () => {
|
||||
render(<Operations {...defaultProps} showDelete={true} />)
|
||||
renderInMenu(<Operations {...defaultProps} showDelete={true} />)
|
||||
expect(screen.getByText(/operation\.delete/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render delete operation when showDelete is false', () => {
|
||||
render(<Operations {...defaultProps} showDelete={false} />)
|
||||
renderInMenu(<Operations {...defaultProps} showDelete={false} />)
|
||||
expect(screen.queryByText(/operation\.delete/)).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Props', () => {
|
||||
it('should render divider when showDelete is true', () => {
|
||||
const { container } = render(<Operations {...defaultProps} showDelete={true} />)
|
||||
const divider = container.querySelector('.bg-divider-subtle')
|
||||
expect(divider).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render divider when showDelete is false', () => {
|
||||
const { container } = render(<Operations {...defaultProps} showDelete={false} />)
|
||||
// Should not have the divider-subtle one (the separator before delete)
|
||||
expect(container.querySelector('.bg-divider-subtle')).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('User Interactions', () => {
|
||||
it('should call openRenameModal when edit is clicked', () => {
|
||||
const openRenameModal = vi.fn()
|
||||
render(<Operations {...defaultProps} openRenameModal={openRenameModal} />)
|
||||
|
||||
const editItem = screen.getByText(/operation\.edit/).closest('div')
|
||||
fireEvent.click(editItem!)
|
||||
renderInMenu(<Operations {...defaultProps} openRenameModal={openRenameModal} />)
|
||||
|
||||
fireEvent.click(screen.getByText(/operation\.edit/))
|
||||
expect(openRenameModal).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should call handleExportPipeline when export is clicked', () => {
|
||||
const handleExportPipeline = vi.fn()
|
||||
render(<Operations {...defaultProps} handleExportPipeline={handleExportPipeline} />)
|
||||
|
||||
const exportItem = screen.getByText(/exportPipeline/).closest('div')
|
||||
fireEvent.click(exportItem!)
|
||||
renderInMenu(<Operations {...defaultProps} handleExportPipeline={handleExportPipeline} />)
|
||||
|
||||
fireEvent.click(screen.getByText(/exportPipeline/))
|
||||
expect(handleExportPipeline).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should call detectIsUsedByApp when delete is clicked', () => {
|
||||
const detectIsUsedByApp = vi.fn()
|
||||
render(<Operations {...defaultProps} detectIsUsedByApp={detectIsUsedByApp} />)
|
||||
|
||||
const deleteItem = screen.getByText(/operation\.delete/).closest('div')
|
||||
fireEvent.click(deleteItem!)
|
||||
renderInMenu(<Operations {...defaultProps} detectIsUsedByApp={detectIsUsedByApp} />)
|
||||
|
||||
fireEvent.click(screen.getByText(/operation\.delete/))
|
||||
expect(detectIsUsedByApp).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Styles', () => {
|
||||
it('should have correct container styling', () => {
|
||||
const { container } = render(<Operations {...defaultProps} />)
|
||||
const operationsContainer = container.firstChild
|
||||
expect(operationsContainer).toHaveClass(
|
||||
'relative',
|
||||
'flex',
|
||||
'w-full',
|
||||
'flex-col',
|
||||
'rounded-xl',
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should render only edit when both showDelete and showExportPipeline are false', () => {
|
||||
render(<Operations {...defaultProps} showDelete={false} showExportPipeline={false} />)
|
||||
renderInMenu(<Operations {...defaultProps} showDelete={false} showExportPipeline={false} />)
|
||||
expect(screen.getByText(/operation\.edit/)).toBeInTheDocument()
|
||||
expect(screen.queryByText(/exportPipeline/)).not.toBeInTheDocument()
|
||||
expect(screen.queryByText(/operation\.delete/)).not.toBeInTheDocument()
|
||||
|
||||
@ -80,7 +80,7 @@ describe('CornerLabels', () => {
|
||||
const dataset = createMockDataset({ embedding_available: false })
|
||||
const { container } = render(<CornerLabels dataset={dataset} />)
|
||||
const labelContainer = container.firstChild as HTMLElement
|
||||
expect(labelContainer).toHaveClass('absolute', 'right-0', 'top-0', 'z-10')
|
||||
expect(labelContainer).toHaveClass('absolute', 'right-0', 'top-0', 'z-5')
|
||||
})
|
||||
|
||||
it('should have correct positioning for pipeline label', () => {
|
||||
@ -90,7 +90,7 @@ describe('CornerLabels', () => {
|
||||
})
|
||||
const { container } = render(<CornerLabels dataset={dataset} />)
|
||||
const labelContainer = container.firstChild as HTMLElement
|
||||
expect(labelContainer).toHaveClass('absolute', 'right-0', 'top-0', 'z-10')
|
||||
expect(labelContainer).toHaveClass('absolute', 'right-0', 'top-0', 'z-5')
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import { fireEvent, render } from '@testing-library/react'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { IndexingType } from '@/app/components/datasets/create/step-two'
|
||||
import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datasets'
|
||||
import OperationsPopover from '../operations-popover'
|
||||
import OperationsDropdown from '../operations-dropdown'
|
||||
|
||||
describe('OperationsPopover', () => {
|
||||
describe('OperationsDropdown', () => {
|
||||
const createMockDataset = (overrides: Partial<DataSet> = {}): DataSet => ({
|
||||
id: 'dataset-1',
|
||||
name: 'Test Dataset',
|
||||
@ -42,102 +42,143 @@ describe('OperationsPopover', () => {
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should render without crashing', () => {
|
||||
const { container } = render(<OperationsPopover {...defaultProps} />)
|
||||
const { container } = render(<OperationsDropdown {...defaultProps} />)
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render the more icon button', () => {
|
||||
const { container } = render(<OperationsPopover {...defaultProps} />)
|
||||
const moreIcon = container.querySelector('svg')
|
||||
const { container } = render(<OperationsDropdown {...defaultProps} />)
|
||||
const moreIcon = container.querySelector('.i-ri-more-fill')
|
||||
expect(moreIcon).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render in hidden state initially (group-hover)', () => {
|
||||
const { container } = render(<OperationsPopover {...defaultProps} />)
|
||||
const { container } = render(<OperationsDropdown {...defaultProps} />)
|
||||
const wrapper = container.firstChild as HTMLElement
|
||||
expect(wrapper).toHaveClass('hidden', 'group-hover:block')
|
||||
expect(wrapper).toHaveClass(
|
||||
'invisible',
|
||||
'pointer-events-none',
|
||||
'group-hover:visible',
|
||||
'group-hover:pointer-events-auto',
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Props', () => {
|
||||
it('should show delete option when not workspace dataset operator', () => {
|
||||
render(<OperationsPopover {...defaultProps} isCurrentWorkspaceDatasetOperator={false} />)
|
||||
render(<OperationsDropdown {...defaultProps} isCurrentWorkspaceDatasetOperator={false} />)
|
||||
|
||||
const triggerButton = document.querySelector('[class*="cursor-pointer"]')
|
||||
if (triggerButton)
|
||||
fireEvent.click(triggerButton)
|
||||
|
||||
// showDelete should be true (inverse of isCurrentWorkspaceDatasetOperator)
|
||||
// This means delete operation will be visible
|
||||
})
|
||||
|
||||
it('should hide delete option when is workspace dataset operator', () => {
|
||||
render(<OperationsPopover {...defaultProps} isCurrentWorkspaceDatasetOperator={true} />)
|
||||
render(<OperationsDropdown {...defaultProps} isCurrentWorkspaceDatasetOperator={true} />)
|
||||
|
||||
const triggerButton = document.querySelector('[class*="cursor-pointer"]')
|
||||
if (triggerButton)
|
||||
fireEvent.click(triggerButton)
|
||||
|
||||
// showDelete should be false
|
||||
})
|
||||
|
||||
it('should show export pipeline when runtime_mode is rag_pipeline', () => {
|
||||
const dataset = createMockDataset({ runtime_mode: 'rag_pipeline' })
|
||||
render(<OperationsPopover {...defaultProps} dataset={dataset} />)
|
||||
render(<OperationsDropdown {...defaultProps} dataset={dataset} />)
|
||||
|
||||
const triggerButton = document.querySelector('[class*="cursor-pointer"]')
|
||||
if (triggerButton)
|
||||
fireEvent.click(triggerButton)
|
||||
|
||||
// showExportPipeline should be true
|
||||
})
|
||||
|
||||
it('should hide export pipeline when runtime_mode is not rag_pipeline', () => {
|
||||
const dataset = createMockDataset({ runtime_mode: 'general' })
|
||||
render(<OperationsPopover {...defaultProps} dataset={dataset} />)
|
||||
render(<OperationsDropdown {...defaultProps} dataset={dataset} />)
|
||||
|
||||
const triggerButton = document.querySelector('[class*="cursor-pointer"]')
|
||||
if (triggerButton)
|
||||
fireEvent.click(triggerButton)
|
||||
|
||||
// showExportPipeline should be false
|
||||
})
|
||||
})
|
||||
|
||||
describe('Styles', () => {
|
||||
it('should have correct positioning styles', () => {
|
||||
const { container } = render(<OperationsPopover {...defaultProps} />)
|
||||
const { container } = render(<OperationsDropdown {...defaultProps} />)
|
||||
const wrapper = container.firstChild as HTMLElement
|
||||
expect(wrapper).toHaveClass('absolute', 'right-2', 'top-2', 'z-15')
|
||||
expect(wrapper).toHaveClass('absolute', 'right-2', 'top-2', 'z-5')
|
||||
})
|
||||
|
||||
it('should keep the trigger mounted when closed so menu exit animations retain an anchor', () => {
|
||||
const { container } = render(<OperationsDropdown {...defaultProps} />)
|
||||
const wrapper = container.firstChild as HTMLElement
|
||||
const trigger = container.querySelector('[aria-label="Dataset operations"]')
|
||||
|
||||
expect(wrapper).not.toHaveClass('hidden')
|
||||
expect(trigger).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should have icon with correct size classes', () => {
|
||||
const { container } = render(<OperationsPopover {...defaultProps} />)
|
||||
const icon = container.querySelector('svg')
|
||||
const { container } = render(<OperationsDropdown {...defaultProps} />)
|
||||
const icon = container.querySelector('.i-ri-more-fill')
|
||||
expect(icon).toHaveClass('h-5', 'w-5', 'text-text-tertiary')
|
||||
})
|
||||
|
||||
it('should have aria-label on trigger for accessibility', () => {
|
||||
const { container } = render(<OperationsDropdown {...defaultProps} />)
|
||||
const trigger = container.querySelector('[aria-label="Dataset operations"]')
|
||||
expect(trigger).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should expose visible keyboard focus styles on the trigger', () => {
|
||||
const { container } = render(<OperationsDropdown {...defaultProps} />)
|
||||
const trigger = container.querySelector('[aria-label="Dataset operations"]')
|
||||
expect(trigger).toHaveClass(
|
||||
'focus-visible:outline-hidden',
|
||||
'focus-visible:ring-1',
|
||||
'focus-visible:ring-inset',
|
||||
'focus-visible:ring-components-input-border-hover',
|
||||
)
|
||||
})
|
||||
|
||||
it('should use a solid trigger background without backdrop blur on hover states', () => {
|
||||
const { container } = render(<OperationsDropdown {...defaultProps} />)
|
||||
const trigger = container.querySelector('[aria-label="Dataset operations"]')
|
||||
expect(trigger).toHaveClass('bg-components-button-secondary-bg')
|
||||
expect(trigger).not.toHaveClass('hover:backdrop-blur-[5px]', 'backdrop-blur-[5px]')
|
||||
})
|
||||
})
|
||||
|
||||
describe('User Interactions', () => {
|
||||
it('should keep outside interactions available when the menu is open', () => {
|
||||
const onOutsideClick = vi.fn()
|
||||
|
||||
render(
|
||||
<div>
|
||||
<button type="button" onClick={onOutsideClick}>Outside action</button>
|
||||
<OperationsDropdown {...defaultProps} />
|
||||
</div>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByLabelText('Dataset operations'))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'Outside action' }))
|
||||
|
||||
expect(onOutsideClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should pass openRenameModal to Operations', () => {
|
||||
const openRenameModal = vi.fn()
|
||||
render(<OperationsPopover {...defaultProps} openRenameModal={openRenameModal} />)
|
||||
|
||||
// The openRenameModal should be passed to Operations component
|
||||
expect(openRenameModal).not.toHaveBeenCalled() // Initially not called
|
||||
render(<OperationsDropdown {...defaultProps} openRenameModal={openRenameModal} />)
|
||||
expect(openRenameModal).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should pass handleExportPipeline to Operations', () => {
|
||||
const handleExportPipeline = vi.fn()
|
||||
render(<OperationsPopover {...defaultProps} handleExportPipeline={handleExportPipeline} />)
|
||||
|
||||
render(<OperationsDropdown {...defaultProps} handleExportPipeline={handleExportPipeline} />)
|
||||
expect(handleExportPipeline).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should pass detectIsUsedByApp to Operations', () => {
|
||||
const detectIsUsedByApp = vi.fn()
|
||||
render(<OperationsPopover {...defaultProps} detectIsUsedByApp={detectIsUsedByApp} />)
|
||||
|
||||
render(<OperationsDropdown {...defaultProps} detectIsUsedByApp={detectIsUsedByApp} />)
|
||||
expect(detectIsUsedByApp).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@ -145,13 +186,13 @@ describe('OperationsPopover', () => {
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle dataset with external provider', () => {
|
||||
const dataset = createMockDataset({ provider: 'external' })
|
||||
const { container } = render(<OperationsPopover {...defaultProps} dataset={dataset} />)
|
||||
const { container } = render(<OperationsDropdown {...defaultProps} dataset={dataset} />)
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle dataset with undefined runtime_mode', () => {
|
||||
const dataset = createMockDataset({ runtime_mode: undefined })
|
||||
const { container } = render(<OperationsPopover {...defaultProps} dataset={dataset} />)
|
||||
const { container } = render(<OperationsDropdown {...defaultProps} dataset={dataset} />)
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@ -14,7 +14,7 @@ const CornerLabels = ({ dataset }: CornerLabelsProps) => {
|
||||
return (
|
||||
<CornerLabel
|
||||
label={t('cornerLabel.unavailable', { ns: 'dataset' })}
|
||||
className="absolute right-0 top-0 z-10"
|
||||
className="absolute top-0 right-0 z-5"
|
||||
labelClassName="rounded-tr-xl"
|
||||
/>
|
||||
)
|
||||
@ -24,7 +24,7 @@ const CornerLabels = ({ dataset }: CornerLabelsProps) => {
|
||||
return (
|
||||
<CornerLabel
|
||||
label={t('cornerLabel.pipeline', { ns: 'dataset' })}
|
||||
className="absolute right-0 top-0 z-10"
|
||||
className="absolute top-0 right-0 z-5"
|
||||
labelClassName="rounded-tr-xl"
|
||||
/>
|
||||
)
|
||||
|
||||
@ -0,0 +1,68 @@
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import * as React from 'react'
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuTrigger,
|
||||
} from '@/app/components/base/ui/dropdown-menu'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import Operations from '../operations'
|
||||
|
||||
type OperationsDropdownProps = {
|
||||
dataset: DataSet
|
||||
isCurrentWorkspaceDatasetOperator: boolean
|
||||
openRenameModal: () => void
|
||||
handleExportPipeline: (include?: boolean) => void
|
||||
detectIsUsedByApp: () => void
|
||||
}
|
||||
|
||||
const OperationsDropdown = ({
|
||||
dataset,
|
||||
isCurrentWorkspaceDatasetOperator,
|
||||
openRenameModal,
|
||||
handleExportPipeline,
|
||||
detectIsUsedByApp,
|
||||
}: OperationsDropdownProps) => {
|
||||
const [open, setOpen] = React.useState(false)
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'absolute top-2 right-2 z-5',
|
||||
open
|
||||
? 'pointer-events-auto visible'
|
||||
: 'pointer-events-none invisible group-hover:pointer-events-auto group-hover:visible',
|
||||
)}
|
||||
onClick={e => e.stopPropagation()}
|
||||
>
|
||||
<DropdownMenu modal={false} open={open} onOpenChange={setOpen}>
|
||||
<DropdownMenuTrigger
|
||||
className={cn(
|
||||
'inline-flex size-9 cursor-pointer items-center justify-center radius-lg border-[0.5px]',
|
||||
'border-components-actionbar-border bg-components-button-secondary-bg p-0 shadow-lg ring-2 shadow-shadow-shadow-5 ring-components-button-secondary-bg ring-inset',
|
||||
'transition-colors hover:border-components-actionbar-border hover:bg-state-base-hover',
|
||||
'focus-visible:bg-state-base-hover focus-visible:ring-1 focus-visible:ring-components-input-border-hover focus-visible:outline-hidden focus-visible:ring-inset',
|
||||
open && 'bg-state-base-hover',
|
||||
)}
|
||||
aria-label="Dataset operations"
|
||||
>
|
||||
<span className="i-ri-more-fill h-5 w-5 text-text-tertiary" />
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent
|
||||
placement="bottom-end"
|
||||
popupClassName="min-w-[186px]"
|
||||
>
|
||||
<Operations
|
||||
showDelete={!isCurrentWorkspaceDatasetOperator}
|
||||
showExportPipeline={dataset.runtime_mode === 'rag_pipeline'}
|
||||
openRenameModal={openRenameModal}
|
||||
handleExportPipeline={handleExportPipeline}
|
||||
detectIsUsedByApp={detectIsUsedByApp}
|
||||
/>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default React.memo(OperationsDropdown)
|
||||
@ -1,52 +0,0 @@
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import { RiMoreFill } from '@remixicon/react'
|
||||
import * as React from 'react'
|
||||
import CustomPopover from '@/app/components/base/popover'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import Operations from '../operations'
|
||||
|
||||
type OperationsPopoverProps = {
|
||||
dataset: DataSet
|
||||
isCurrentWorkspaceDatasetOperator: boolean
|
||||
openRenameModal: () => void
|
||||
handleExportPipeline: (include?: boolean) => void
|
||||
detectIsUsedByApp: () => void
|
||||
}
|
||||
|
||||
const OperationsPopover = ({
|
||||
dataset,
|
||||
isCurrentWorkspaceDatasetOperator,
|
||||
openRenameModal,
|
||||
handleExportPipeline,
|
||||
detectIsUsedByApp,
|
||||
}: OperationsPopoverProps) => (
|
||||
<div className="absolute right-2 top-2 z-15 hidden group-hover:block">
|
||||
<CustomPopover
|
||||
htmlContent={(
|
||||
<Operations
|
||||
showDelete={!isCurrentWorkspaceDatasetOperator}
|
||||
showExportPipeline={dataset.runtime_mode === 'rag_pipeline'}
|
||||
openRenameModal={openRenameModal}
|
||||
handleExportPipeline={handleExportPipeline}
|
||||
detectIsUsedByApp={detectIsUsedByApp}
|
||||
/>
|
||||
)}
|
||||
className="z-20 min-w-[186px]"
|
||||
popupClassName="rounded-xl bg-none shadow-none ring-0 min-w-[186px]"
|
||||
position="br"
|
||||
trigger="click"
|
||||
btnElement={(
|
||||
<div className="flex size-8 items-center justify-center radius-lg hover:bg-state-base-hover">
|
||||
<RiMoreFill className="h-5 w-5 text-text-tertiary" />
|
||||
</div>
|
||||
)}
|
||||
btnClassName={open =>
|
||||
cn(
|
||||
'size-9 cursor-pointer justify-center radius-lg border-[0.5px] border-components-actionbar-border bg-components-actionbar-bg p-0 shadow-lg shadow-shadow-shadow-5 ring-2 ring-inset ring-components-actionbar-bg hover:border-components-actionbar-border',
|
||||
open ? 'border-components-actionbar-border bg-state-base-hover' : '',
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
|
||||
export default React.memo(OperationsPopover)
|
||||
@ -9,7 +9,7 @@ import DatasetCardFooter from './components/dataset-card-footer'
|
||||
import DatasetCardHeader from './components/dataset-card-header'
|
||||
import DatasetCardModals from './components/dataset-card-modals'
|
||||
import Description from './components/description'
|
||||
import OperationsPopover from './components/operations-popover'
|
||||
import OperationsDropdown from './components/operations-dropdown'
|
||||
import TagArea from './components/tag-area'
|
||||
import { useDatasetCardState } from './hooks/use-dataset-card-state'
|
||||
|
||||
@ -82,7 +82,7 @@ const DatasetCard = ({
|
||||
onClick={handleTagAreaClick}
|
||||
/>
|
||||
<DatasetCardFooter dataset={dataset} />
|
||||
<OperationsPopover
|
||||
<OperationsDropdown
|
||||
dataset={dataset}
|
||||
isCurrentWorkspaceDatasetOperator={isCurrentWorkspaceDatasetOperator}
|
||||
openRenameModal={openRenameModal}
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import { RiDeleteBinLine, RiEditLine, RiFileDownloadLine } from '@remixicon/react'
|
||||
import * as React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import OperationItem from './operation-item'
|
||||
import {
|
||||
DropdownMenuItem,
|
||||
DropdownMenuSeparator,
|
||||
} from '@/app/components/base/ui/dropdown-menu'
|
||||
|
||||
type OperationsProps = {
|
||||
showDelete: boolean
|
||||
@ -22,34 +23,27 @@ const Operations = ({
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div className="relative flex w-full flex-col rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg shadow-shadow-shadow-5">
|
||||
<div className="flex flex-col p-1">
|
||||
<OperationItem
|
||||
Icon={RiEditLine}
|
||||
name={t('operation.edit', { ns: 'common' })}
|
||||
handleClick={openRenameModal}
|
||||
/>
|
||||
{showExportPipeline && (
|
||||
<OperationItem
|
||||
Icon={RiFileDownloadLine}
|
||||
name={t('operations.exportPipeline', { ns: 'datasetPipeline' })}
|
||||
handleClick={handleExportPipeline}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<>
|
||||
<DropdownMenuItem onClick={openRenameModal}>
|
||||
<span aria-hidden className="i-ri-edit-line size-4 text-text-tertiary" />
|
||||
{t('operation.edit', { ns: 'common' })}
|
||||
</DropdownMenuItem>
|
||||
{showExportPipeline && (
|
||||
<DropdownMenuItem onClick={handleExportPipeline}>
|
||||
<span aria-hidden className="i-ri-file-download-line size-4 text-text-tertiary" />
|
||||
{t('operations.exportPipeline', { ns: 'datasetPipeline' })}
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
{showDelete && (
|
||||
<>
|
||||
<Divider type="horizontal" className="my-0 bg-divider-subtle" />
|
||||
<div className="flex flex-col p-1">
|
||||
<OperationItem
|
||||
Icon={RiDeleteBinLine}
|
||||
name={t('operation.delete', { ns: 'common' })}
|
||||
handleClick={detectIsUsedByApp}
|
||||
/>
|
||||
</div>
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuItem destructive onClick={detectIsUsedByApp}>
|
||||
<span aria-hidden className="i-ri-delete-bin-line size-4" />
|
||||
{t('operation.delete', { ns: 'common' })}
|
||||
</DropdownMenuItem>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -5,7 +5,6 @@ import { useBoolean, useDebounceFn } from 'ahooks'
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
import { ApiConnectionMod } from '@/app/components/base/icons/src/vender/solid/development'
|
||||
import Input from '@/app/components/base/input'
|
||||
import TagManagementModal from '@/app/components/base/tag-management'
|
||||
import TagFilter from '@/app/components/base/tag-management/filter'
|
||||
@ -85,11 +84,11 @@ const List = () => {
|
||||
}
|
||||
<div className="h-4 w-px bg-divider-regular" />
|
||||
<Button
|
||||
className="shadows-shadow-xs gap-0.5"
|
||||
className="gap-0.5 shadow-xs"
|
||||
onClick={() => setShowExternalApiPanel(true)}
|
||||
>
|
||||
<ApiConnectionMod className="h-4 w-4 text-components-button-secondary-text" />
|
||||
<div className="flex items-center justify-center gap-1 px-0.5 system-sm-medium text-components-button-secondary-text">{t('externalAPIPanelTitle', { ns: 'dataset' })}</div>
|
||||
<span className="i-custom-vender-solid-development-api-connection-mod h-4 w-4 text-components-button-secondary-text" />
|
||||
<span className="flex items-center justify-center gap-1 px-0.5 system-sm-medium text-components-button-secondary-text">{t('externalAPIPanelTitle', { ns: 'dataset' })}</span>
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@ -4503,11 +4503,6 @@
|
||||
"count": 3
|
||||
}
|
||||
},
|
||||
"app/components/datasets/list/dataset-card/components/corner-labels.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"app/components/datasets/list/dataset-card/components/dataset-card-footer.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
@ -4526,14 +4521,6 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/datasets/list/dataset-card/components/operations-popover.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"app/components/datasets/list/dataset-card/components/tag-area.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 1
|
||||
|
||||
Reference in New Issue
Block a user