Compare commits

12 Commits

Author SHA1 Message Date
b62965034e refactor: document_indexing_sync_task split db session (#32129)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-09 17:16:17 +08:00
016d72a8c6 fix: fix trigger output schema miss (#32116) 2026-02-09 17:16:08 +08:00
125f7e3ab4 refactor: document_indexing_update_task split database session (#32105)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-09 10:51:45 +08:00
400ed2fd72 refactor: partition Celery task sessions into smaller, discrete execu… (#32085)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-08 21:05:03 +08:00
840a8f3fc2 perf: use batch delete method instead of single delete (#32036)
Co-authored-by: fatelei <fatelei@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: FFXN <lizy@dify.ai>
2026-02-06 15:13:17 +08:00
b4a5296fd1 fix: fix tool type is miss (#32042) 2026-02-06 14:38:54 +08:00
fcb53383df fix: fix agent node tool type is not right (#32008)
Infer real tool type via querying relevant database tables.

The root cause for incorrect `type` field is still not clear.
2026-02-06 11:25:29 +08:00
540e1db83c perf(api): Optimize the response time of AppListApi endpoint (#31999) 2026-02-06 10:46:25 +08:00
2f75e38c08 fix: fix miss use db.session (#31971) 2026-02-05 15:59:37 +08:00
cd03e0a9ef fix: fix delete_draft_variables_batch cycle forever (#31934)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-04 19:42:50 +08:00
df2421d187 fix: auto summary env (#31930) 2026-02-04 19:42:26 +08:00
0ba321d840 chore: bump version in docker-compose and package manager to 1.12.1 (#31947) 2026-02-04 19:41:50 +08:00
72 changed files with 2883 additions and 5564 deletions
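
Note: commit fcb53383df above fixes an agent node's tool type by inferring the real type from the database instead of trusting the stored field (whose corruption root cause is still unknown, per the commit body). The sketch below is a hypothetical illustration of such a lookup, not Dify's actual code; the provider model names and probe order are assumptions.

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class BuiltinToolProvider(Base):  # stand-in table, name assumed
    __tablename__ = "builtin_tool_providers"
    id: Mapped[str] = mapped_column(String, primary_key=True)

class ApiToolProvider(Base):  # stand-in table, name assumed
    __tablename__ = "api_tool_providers"
    id: Mapped[str] = mapped_column(String, primary_key=True)

def infer_tool_type(session: Session, tool_id: str) -> str | None:
    # Probe each provider table; the first table that knows the id wins.
    for tool_type, model in (("builtin", BuiltinToolProvider), ("api", ApiToolProvider)):
        if session.scalar(select(model.id).where(model.id == tool_id)) is not None:
            return tool_type
    return None  # unknown: caller keeps the stored (possibly wrong) type

if __name__ == "__main__":
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        session.add(ApiToolProvider(id="tool-123"))
        session.commit()
        print(infer_tool_type(session, "tool-123"))  # -> "api"

Note: the dominant theme of this compare is splitting long-lived Celery task database sessions into short, discrete transactions and replacing row-at-a-time deletes with chunked bulk deletes (commits b62965034e, 125f7e3ab4, 400ed2fd72, 840a8f3fc2). Below is a minimal sketch of that pattern, assuming a session_factory like the core.db.session_factory used in the diffs that follow; Segment is a stand-in model, not Dify's.

from sqlalchemy import delete, select

BATCH_SIZE = 1000  # same chunk bound the diffs below use

def clean_segments(session_factory, Segment, document_id: str) -> None:
    # Step 1: one short read-only transaction to snapshot the ids.
    with session_factory.create_session() as session:
        segment_ids = session.scalars(
            select(Segment.id).where(Segment.document_id == document_id)
        ).all()

    # Step 2: delete in bounded chunks, one short transaction per chunk, so a
    # failing chunk rolls back only itself and row locks are held briefly.
    for i in range(0, len(segment_ids), BATCH_SIZE):
        batch = segment_ids[i : i + BATCH_SIZE]
        with session_factory.create_session() as session, session.begin():
            session.execute(delete(Segment).where(Segment.id.in_(batch)))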

View File

@@ -136,6 +136,7 @@ ignore_imports =
core.workflow.nodes.llm.llm_utils -> models.provider
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
core.workflow.nodes.llm.node -> core.tools.signature
core.workflow.nodes.template_transform.template_transform_node -> configs
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
core.workflow.nodes.tool.tool_node -> core.tools.tool_manager

View File

@@ -38,7 +38,6 @@ from . import (
extension,
feature,
init_validate,
notification,
ping,
setup,
spec,
@@ -183,7 +182,6 @@ __all__ = [
"model_config",
"model_providers",
"models",
"notification",
"oauth",
"oauth_server",
"ops_trace",

View File

@@ -1,5 +1,3 @@
import csv
import io
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
@@ -8,7 +6,7 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from constants.languages import supported_language
@@ -18,7 +16,6 @@ from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
from services.billing_service import BillingService
P = ParamSpec("P")
R = TypeVar("R")
@@ -280,113 +277,3 @@ class DeleteExploreBannerApi(Resource):
db.session.commit()
return {"result": "success"}, 204
class SaveNotificationContentPayload(BaseModel):
content: str = Field(...)
class SaveNotificationUserPayload(BaseModel):
user_email: list[str] = Field(...)
console_ns.schema_model(
SaveNotificationContentPayload.__name__,
SaveNotificationContentPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
SaveNotificationUserPayload.__name__,
SaveNotificationUserPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/admin/save_notification_content")
class SaveNotificationContentApi(Resource):
@console_ns.doc("save_notification_content")
@console_ns.doc(description="Save a notification content")
@console_ns.expect(console_ns.models[SaveNotificationContentPayload.__name__])
@console_ns.response(200, "Notification content saved successfully")
@only_edition_cloud
@admin_required
def post(self):
payload = SaveNotificationContentPayload.model_validate(console_ns.payload)
BillingService.save_notification_content(payload.content)
return {"result": "success"}, 200
@console_ns.route("/admin/save_notification_user")
class SaveNotificationUserApi(Resource):
@console_ns.doc("save_notification_user")
@console_ns.doc(description="Save notification users via JSON body or file upload. "
"JSON: {\"user_email\": [\"a@example.com\", ...]}. "
"File: multipart/form-data with a 'file' field (CSV or TXT, one email per line).")
@console_ns.response(200, "Notification users saved successfully")
@only_edition_cloud
@admin_required
def post(self):
# Determine input mode: file upload or JSON body
if "file" in request.files:
emails = self._parse_emails_from_file()
else:
payload = SaveNotificationUserPayload.model_validate(console_ns.payload)
emails = payload.user_email
if not emails:
raise BadRequest("No valid email addresses provided.")
# Use batch API for bulk insert (chunks of 1000 per request to billing service)
result = BillingService.save_notification_users_batch(emails)
return {
"result": "success",
"total": len(emails),
"succeeded": result["succeeded"],
"failed_chunks": result["failed_chunks"],
}, 200
@staticmethod
def _parse_emails_from_file() -> list[str]:
"""Parse email addresses from an uploaded CSV or TXT file."""
file = request.files["file"]
if not file.filename:
raise BadRequest("Uploaded file has no filename.")
filename_lower = file.filename.lower()
if not filename_lower.endswith((".csv", ".txt")):
raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")
# Read file content
try:
content = file.read().decode("utf-8")
except UnicodeDecodeError:
try:
file.seek(0)
content = file.read().decode("gbk")
except UnicodeDecodeError:
raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.")
emails: list[str] = []
if filename_lower.endswith(".csv"):
reader = csv.reader(io.StringIO(content))
for row in reader:
for cell in row:
cell = cell.strip()
emails.append(cell)
else:
# TXT file: one email per line
for line in content.splitlines():
line = line.strip()
emails.append(line)
# Deduplicate while preserving order
seen: set[str] = set()
unique_emails: list[str] = []
for email in emails:
email_lower = email.lower()
if email_lower not in seen:
seen.add(email_lower)
unique_emails.append(email)
return unique_emails

View File

@@ -1,3 +1,4 @@
import logging
import uuid
from datetime import datetime
from typing import Any, Literal, TypeAlias
@@ -54,6 +55,8 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co
register_enum_models(console_ns, IconType)
_logger = logging.getLogger(__name__)
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
@@ -499,6 +502,7 @@ class AppListApi(Resource):
select(Workflow).where(
Workflow.version == Workflow.VERSION_DRAFT,
Workflow.app_id.in_(workflow_capable_app_ids),
Workflow.tenant_id == current_tenant_id,
)
)
.scalars()
@@ -510,12 +514,14 @@
NodeType.TRIGGER_PLUGIN,
}
for workflow in draft_workflows:
node_id = None
try:
for _, node_data in workflow.walk_nodes():
for node_id, node_data in workflow.walk_nodes():
if node_data.get("type") in trigger_node_types:
draft_trigger_app_ids.add(str(workflow.app_id))
break
except Exception:
_logger.exception("error while walking nodes, workflow_id=%s, node_id=%s", workflow.id, node_id)
continue
for app in app_pagination.items:

View File

@@ -1,16 +1,15 @@
import logging
from typing import Any, Literal, cast
from typing import Any, cast
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common.fields import Parameters as ParametersResponse
from controllers.common.fields import Site as SiteResponse
from controllers.common.schema import get_or_create_model
from controllers.console import api, console_ns
from controllers.console import api
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
@@ -118,56 +117,7 @@ workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipel
workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy)
# Pydantic models for request validation
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowRunRequest(BaseModel):
inputs: dict
files: list | None = None
class ChatRequest(BaseModel):
inputs: dict
query: str
files: list | None = None
conversation_id: str | None = None
parent_message_id: str | None = None
retriever_from: str = "explore_app"
class TextToSpeechRequest(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None
class CompletionRequest(BaseModel):
inputs: dict
query: str = ""
files: list | None = None
response_mode: Literal["blocking", "streaming"] | None = None
retriever_from: str = "explore_app"
# Register schemas for Swagger documentation
console_ns.schema_model(
WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class TrialAppWorkflowRunApi(TrialAppResource):
@console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
def post(self, trial_app):
"""
Run workflow
@@ -179,8 +129,10 @@ class TrialAppWorkflowRunApi(TrialAppResource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
request_data = WorkflowRunRequest.model_validate(console_ns.payload)
args = request_data.model_dump()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args()
assert current_user is not None
try:
app_id = app_model.id
@@ -231,7 +183,6 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
class TrialChatApi(TrialAppResource):
@console_ns.expect(console_ns.models[ChatRequest.__name__])
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
@@ -239,14 +190,14 @@
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
request_data = ChatRequest.model_validate(console_ns.payload)
args = request_data.model_dump()
# Validate UUID values if provided
if args.get("conversation_id"):
args["conversation_id"] = uuid_value(args["conversation_id"])
if args.get("parent_message_id"):
args["parent_message_id"] = uuid_value(args["parent_message_id"])
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
args["auto_generate_name"] = False
@@ -369,16 +320,20 @@ class TrialChatAudioApi(TrialAppResource):
class TrialChatTextApi(TrialAppResource):
@console_ns.expect(console_ns.models[TextToSpeechRequest.__name__])
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
try:
request_data = TextToSpeechRequest.model_validate(console_ns.payload)
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
args = parser.parse_args()
message_id = request_data.message_id
text = request_data.text
voice = request_data.voice
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
@@ -416,15 +371,19 @@ class TrialChatTextApi(TrialAppResource):
class TrialCompletionApi(TrialAppResource):
@console_ns.expect(console_ns.models[CompletionRequest.__name__])
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
if app_model.mode != "completion":
raise NotCompletionAppError()
request_data = CompletionRequest.model_validate(console_ns.payload)
args = request_data.model_dump()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False

View File

@@ -1,26 +0,0 @@
from flask_restx import Resource
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@console_ns.route("/notification")
class NotificationApi(Resource):
@console_ns.doc("get_notification")
@console_ns.doc(description="Get notification for the current user")
@console_ns.doc(
responses={
200: "Success",
401: "Unauthorized",
}
)
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, _ = current_account_with_tenant()
notification = BillingService.read_notification(current_user.email)
return notification

View File

@@ -47,7 +47,6 @@ class DifyNodeFactory(NodeFactory):
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
template_transform_max_output_length: int | None = None,
http_request_http_client: HttpClientProtocol | None = None,
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
http_request_file_manager: FileManagerProtocol | None = None,
@@ -69,9 +68,6 @@
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
self._template_transform_max_output_length = (
template_transform_max_output_length or dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
)
self._http_request_http_client = http_request_http_client or ssrf_proxy
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
self._http_request_file_manager = http_request_file_manager or file_manager
@@ -126,7 +122,6 @@
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
template_renderer=self._template_renderer,
max_output_length=self._template_transform_max_output_length,
)
if node_type == NodeType.HTTP_REQUEST:

View File

@@ -1,6 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
@@ -15,13 +16,12 @@ if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM
_template_renderer: Jinja2TemplateRenderer
_max_output_length: int
def __init__(
self,
@@ -31,7 +31,6 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
graph_runtime_state: "GraphRuntimeState",
*,
template_renderer: Jinja2TemplateRenderer | None = None,
max_output_length: int | None = None,
) -> None:
super().__init__(
id=id,
@@ -41,10 +40,6 @@
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
if max_output_length is not None and max_output_length <= 0:
raise ValueError("max_output_length must be a positive integer")
self._max_output_length = max_output_length or DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""
@@ -74,11 +69,11 @@
except TemplateRenderError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
if len(rendered) > self._max_output_length:
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
return NodeRunResult(
inputs=variables,
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Output length exceeds {self._max_output_length} characters",
error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters",
)
return NodeRunResult(

View File

@@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.12.0"
version = "1.12.1"
requires-python = ">=3.11,<3.13"
dependencies = [
@@ -81,7 +81,7 @@ dependencies = [
"starlette==0.49.1",
"tiktoken~=0.9.0",
"transformers~=4.56.1",
"unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
"unstructured[docx,epub,md,ppt,pptx]~=0.18.18",
"yarl~=1.18.3",
"webvtt-py~=0.5.1",
"sseclient-py~=1.8.0",

View File

@@ -393,35 +393,3 @@ class BillingService:
for item in data:
tenant_whitelist.append(item["tenant_id"])
return tenant_whitelist
@classmethod
def read_notification(cls, user_email: str):
params = {"user_email": user_email}
return cls._send_request("GET", "/notification/read", params=params)
@classmethod
def save_notification_user(cls, user_email: str):
json = {"user_email": user_email}
return cls._send_request("POST", "/notification/new-notification-user", json=json)
@classmethod
def save_notification_users_batch(cls, user_emails: list[str]) -> dict:
"""Batch save notification users in chunks of 1000."""
chunk_size = 1000
total_succeeded = 0
failed_chunks: list[dict] = []
for i in range(0, len(user_emails), chunk_size):
chunk = user_emails[i : i + chunk_size]
try:
resp = cls._send_request("POST", "/notification/batch-notification-users", json={"user_emails": chunk})
total_succeeded += resp.get("count", len(chunk))
except Exception as e:
failed_chunks.append({"offset": i, "count": len(chunk), "error": str(e)})
return {"succeeded": total_succeeded, "failed_chunks": failed_chunks}
@classmethod
def save_notification_content(cls, content: str):
json = {"content": content}
return cls._send_request("POST", "/notification/new-notification", json=json)

View File

@@ -1,7 +1,6 @@
from flask_login import current_user
from configs import dify_config
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
from services.account_service import TenantService
@@ -54,12 +53,7 @@ class WorkspaceService:
from services.credit_pool_service import CreditPoolService
paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid")
# if the tenant is not on the sandbox plan and the paid pool is not full, use the paid pool
if (
feature.billing.subscription.plan != CloudPlan.SANDBOX
and paid_pool is not None
and paid_pool.quota_limit > paid_pool.quota_used
):
if paid_pool:
tenant_info["trial_credits"] = paid_pool.quota_limit
tenant_info["trial_credits_used"] = paid_pool.quota_used
else:

View File

@@ -6,7 +6,6 @@ from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@@ -58,5 +57,3 @@ def add_annotation_to_index_task(
)
except Exception:
logger.exception("Build index for annotation failed")
finally:
db.session.close()

View File

@@ -5,7 +5,6 @@ import click
from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@@ -40,5 +39,3 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Annotation deleted index failed")
finally:
db.session.close()

View File

@@ -6,7 +6,6 @@ from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@@ -59,5 +58,3 @@ def update_annotation_to_index_task(
)
except Exception:
logger.exception("Build index for annotation failed")
finally:
db.session.close()

View File

@@ -14,6 +14,9 @@ from models.model import UploadFile
logger = logging.getLogger(__name__)
# Batch size for database operations to keep transactions short
BATCH_SIZE = 1000
@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
@@ -31,63 +34,179 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
if not doc_form:
raise ValueError("doc_form is required")
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
).delete(synchronize_session=False)
storage_keys_to_delete: list[str] = []
index_node_ids: list[str] = []
segment_ids: list[str] = []
total_image_upload_file_ids: list[str] = []
try:
# ============ Step 1: Query segment and file data (short read-only transaction) ============
with session_factory.create_session() as session:
# Get segments info
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
segment_ids = [segment.id for segment in segments]
# Collect image file IDs from segment content
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
for image_file in image_files:
try:
if image_file and image_file.key:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file.id,
)
stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(stmt)
session.delete(segment)
total_image_upload_file_ids.extend(image_upload_file_ids)
# Query storage keys for image files
if total_image_upload_file_ids:
image_files = session.scalars(
select(UploadFile).where(UploadFile.id.in_(total_image_upload_file_ids))
).all()
storage_keys_to_delete.extend([f.key for f in image_files if f and f.key])
# Query storage keys for document files
if file_ids:
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
for file in files:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
session.execute(stmt)
storage_keys_to_delete.extend([f.key for f in files if f and f.key])
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned documents when documents deleted latency: {end_at - start_at}",
fg="green",
# ============ Step 2: Clean vector index (external service, fresh session for dataset) ============
if index_node_ids:
try:
# Fetch dataset in a fresh session to avoid DetachedInstanceError
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.warning("Dataset not found for vector index cleanup, dataset_id: %s", dataset_id)
else:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
except Exception:
logger.exception(
"Failed to clean vector index for dataset_id: %s, document_ids: %s, index_node_ids count: %d",
dataset_id,
document_ids,
len(index_node_ids),
)
)
# ============ Step 3: Delete metadata binding (separate short transaction) ============
try:
with session_factory.create_session() as session:
deleted_count = (
session.query(DatasetMetadataBinding)
.where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
)
.delete(synchronize_session=False)
)
session.commit()
logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id)
except Exception:
logger.exception("Cleaned documents when documents deleted failed")
logger.exception(
"Failed to delete metadata bindings for dataset_id: %s, document_ids: %s",
dataset_id,
document_ids,
)
# ============ Step 4: Batch delete UploadFile records (multiple short transactions) ============
if total_image_upload_file_ids:
failed_batches = 0
total_batches = (len(total_image_upload_file_ids) + BATCH_SIZE - 1) // BATCH_SIZE
for i in range(0, len(total_image_upload_file_ids), BATCH_SIZE):
batch = total_image_upload_file_ids[i : i + BATCH_SIZE]
try:
with session_factory.create_session() as session:
stmt = delete(UploadFile).where(UploadFile.id.in_(batch))
session.execute(stmt)
session.commit()
except Exception:
failed_batches += 1
logger.exception(
"Failed to delete image UploadFile batch %d-%d for dataset_id: %s",
i,
i + len(batch),
dataset_id,
)
if failed_batches > 0:
logger.warning(
"Image UploadFile deletion: %d/%d batches failed for dataset_id: %s",
failed_batches,
total_batches,
dataset_id,
)
# ============ Step 5: Batch delete DocumentSegment records (multiple short transactions) ============
if segment_ids:
failed_batches = 0
total_batches = (len(segment_ids) + BATCH_SIZE - 1) // BATCH_SIZE
for i in range(0, len(segment_ids), BATCH_SIZE):
batch = segment_ids[i : i + BATCH_SIZE]
try:
with session_factory.create_session() as session:
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(batch))
session.execute(segment_delete_stmt)
session.commit()
except Exception:
failed_batches += 1
logger.exception(
"Failed to delete DocumentSegment batch %d-%d for dataset_id: %s, document_ids: %s",
i,
i + len(batch),
dataset_id,
document_ids,
)
if failed_batches > 0:
logger.warning(
"DocumentSegment deletion: %d/%d batches failed, document_ids: %s",
failed_batches,
total_batches,
document_ids,
)
# ============ Step 6: Delete document-associated files (separate short transaction) ============
if file_ids:
try:
with session_factory.create_session() as session:
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
session.execute(stmt)
session.commit()
except Exception:
logger.exception(
"Failed to delete document UploadFile records for dataset_id: %s, file_ids: %s",
dataset_id,
file_ids,
)
# ============ Step 7: Delete storage files (I/O operations, no DB transaction) ============
storage_delete_failures = 0
for storage_key in storage_keys_to_delete:
try:
storage.delete(storage_key)
except Exception:
storage_delete_failures += 1
logger.exception("Failed to delete file from storage, key: %s", storage_key)
if storage_delete_failures > 0:
logger.warning(
"Storage file deletion completed with %d failures out of %d total files for dataset_id: %s",
storage_delete_failures,
len(storage_keys_to_delete),
dataset_id,
)
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned documents when documents deleted latency: {end_at - start_at:.2f}s, "
f"dataset_id: {dataset_id}, document_ids: {document_ids}, "
f"segments: {len(segment_ids)}, image_files: {len(total_image_upload_file_ids)}, "
f"storage_files: {len(storage_keys_to_delete)}",
fg="green",
)
)
except Exception:
logger.exception(
"Batch clean documents failed for dataset_id: %s, document_ids: %s",
dataset_id,
document_ids,
)

View File

@@ -48,6 +48,11 @@ def batch_create_segment_to_index_task(
indexing_cache_key = f"segment_batch_import_{job_id}"
# Initialize variables with default values
upload_file_key: str | None = None
dataset_config: dict | None = None
document_config: dict | None = None
with session_factory.create_session() as session:
try:
dataset = session.get(Dataset, dataset_id)
@@ -69,86 +74,115 @@
if not upload_file:
raise ValueError("UploadFile not found.")
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
dataset_config = {
"id": dataset.id,
"indexing_technique": dataset.indexing_technique,
"tenant_id": dataset.tenant_id,
"embedding_model_provider": dataset.embedding_model_provider,
"embedding_model": dataset.embedding_model,
}
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if dataset_document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
document_config = {
"id": dataset_document.id,
"doc_form": dataset_document.doc_form,
"word_count": dataset_document.word_count or 0,
}
document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
upload_file_key = upload_file.key
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(
texts=[segment["content"] for segment in content]
)
except Exception:
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")
return
# Ensure required variables are set before proceeding
if upload_file_key is None or dataset_config is None or document_config is None:
logger.error("Required configuration not set due to session error")
redis_client.setex(indexing_cache_key, 600, "error")
return
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file_key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file_key, file_path)
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if document_config["doc_form"] == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
tokens_list = [0] * len(content)
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == dataset_document.id)
.scalar()
)
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
completed_at=naive_utc_now(),
)
if dataset_document.doc_form == "qa_model":
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
session.add(segment_document)
document_segments.append(segment_document)
document_segments = []
embedding_model = None
if dataset_config["indexing_technique"] == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset_config["tenant_id"],
provider=dataset_config["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=dataset_config["embedding_model"],
)
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content])
else:
tokens_list = [0] * len(content)
with session_factory.create_session() as session, session.begin():
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document_config["id"])
.scalar()
)
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
completed_at=naive_utc_now(),
)
if document_config["doc_form"] == "qa_model":
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
session.add(segment_document)
document_segments.append(segment_document)
with session_factory.create_session() as session, session.begin():
dataset_document = session.get(Document, document_id)
if dataset_document:
assert dataset_document.word_count is not None
dataset_document.word_count += word_count_change
session.add(dataset_document)
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")
with session_factory.create_session() as session:
dataset = session.get(Dataset, dataset_id)
if dataset:
VectorService.create_segments_vector(None, document_segments, dataset, document_config["doc_form"])
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)
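
Note: the rewrite above snapshots ORM rows into plain dicts (dataset_config, document_config) while the session is open, so the steps that run after the session closes never touch a detached instance. A minimal sketch of that pattern, with stand-in names:

def load_dataset_config(session_factory, Dataset, dataset_id: str) -> dict | None:
    # Copy the scalar attributes you need before the session closes; reading
    # the dict later can never trigger a lazy load or DetachedInstanceError.
    with session_factory.create_session() as session:
        dataset = session.get(Dataset, dataset_id)
        if dataset is None:
            return None
        return {
            "id": dataset.id,
            "tenant_id": dataset.tenant_id,
            "indexing_technique": dataset.indexing_technique,
        }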

View File

@@ -28,6 +28,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
"""
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
start_at = time.perf_counter()
total_attachment_files = []
with session_factory.create_session() as session:
try:
@@ -47,78 +48,91 @@
SegmentAttachmentBinding.document_id == document_id,
)
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings])
index_node_ids = [segment.index_node_id for segment in segments]
segment_contents = [segment.content for segment in segments]
except Exception:
logger.exception("Cleaned document when document deleted failed")
return
# check segment is exist
if index_node_ids:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.scalars(
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
).all()
for image_file in image_files:
if image_file is None:
continue
try:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file.id,
)
total_image_files = []
with session_factory.create_session() as session, session.begin():
for segment_content in segment_contents:
image_upload_file_ids = get_image_upload_file_ids(segment_content)
image_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))).all()
total_image_files.extend([image_file.key for image_file in image_files])
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(image_file_delete_stmt)
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(image_file_delete_stmt)
session.delete(segment)
with session_factory.create_session() as session, session.begin():
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
session.execute(segment_delete_stmt)
session.commit()
if file_id:
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
session.delete(file)
# delete segment attachments
if attachments_with_bindings:
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
session.execute(attachment_file_delete_stmt)
binding_delete_stmt = delete(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.id.in_(binding_ids)
)
session.execute(binding_delete_stmt)
# delete dataset metadata binding
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
)
)
for image_file_key in total_image_files:
try:
storage.delete(image_file_key)
except Exception:
logger.exception("Cleaned document when document deleted failed")
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file_key,
)
with session_factory.create_session() as session, session.begin():
if file_id:
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
session.delete(file)
with session_factory.create_session() as session, session.begin():
# delete segment attachments
if attachment_ids:
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
session.execute(attachment_file_delete_stmt)
if binding_ids:
binding_delete_stmt = delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.id.in_(binding_ids))
session.execute(binding_delete_stmt)
for attachment_file_key in total_attachment_files:
try:
storage.delete(attachment_file_key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
attachment_file_key,
)
with session_factory.create_session() as session, session.begin():
# delete dataset metadata binding
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
)
)
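
Note: a recurring detail in the rewrites above is that storage deletions (network I/O) are collected as keys during the database phase and executed only after every transaction has closed. A minimal sketch of that deferral, with stand-in names:

from sqlalchemy import delete, select

def delete_files(session_factory, UploadFile, storage, file_ids: list[str]) -> None:
    # DB phase: read the storage keys and delete the rows inside one short
    # transaction; no network I/O happens while it is open.
    with session_factory.create_session() as session, session.begin():
        files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
        keys = [f.key for f in files if f.key]
        session.execute(delete(UploadFile).where(UploadFile.id.in_(file_ids)))

    # I/O phase: storage calls run after the transaction has closed, so a
    # slow or failing delete cannot hold locks or abort the DB work.
    for key in keys:
        try:
            storage.delete(key)
        except Exception:
            # log-and-continue, as the tasks above do
            continue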

View File

@@ -23,40 +23,40 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
"""
logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
total_index_node_ids = []
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
session.execute(document_delete_stmt)
document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
session.execute(document_delete_stmt)
for document_id in document_ids:
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
for document_id in document_ids:
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
total_index_node_ids.extend([segment.index_node_id for segment in segments])
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
),
fg="green",
)
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
index_processor.clean(
dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
except Exception:
logger.exception("Cleaned document when import form notion document deleted failed")
with session_factory.create_session() as session, session.begin():
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
session.execute(segment_delete_stmt)
end_at = time.perf_counter()
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
),
fg="green",
)
)

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import delete
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -67,8 +68,14 @@ def delete_segment_from_index_task(
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
for binding in segment_attachment_bindings:
session.delete(binding)
segment_attachment_bind_ids = [i.id for i in segment_attachment_bindings]
for i in range(0, len(segment_attachment_bind_ids), 1000):
segment_attachment_bind_delete_stmt = delete(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.id.in_(segment_attachment_bind_ids[i : i + 1000])
)
session.execute(segment_attachment_bind_delete_stmt)
# delete upload file
session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
session.commit()

View File

@@ -27,104 +27,129 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
"""
logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
start_at = time.perf_counter()
tenant_id = None
with session_factory.create_session() as session:
with session_factory.create_session() as session, session.begin():
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
if document.indexing_status == "parsing":
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
return
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
data_source_info = document.data_source_info_dict
if document.data_source_type == "notion_import":
if (
not data_source_info
or "notion_page_id" not in data_source_info
or "notion_workspace_id" not in data_source_info
):
raise ValueError("no notion page found")
workspace_id = data_source_info["notion_workspace_id"]
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
credential_id = data_source_info.get("credential_id")
if document.data_source_type != "notion_import":
logger.info(click.style(f"Document {document_id} is not a notion_import, skipping", fg="yellow"))
return
# Get credentials from datasource provider
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=document.tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
if (
not data_source_info
or "notion_page_id" not in data_source_info
or "notion_workspace_id" not in data_source_info
):
raise ValueError("no notion page found")
if not credential:
logger.error(
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
document_id,
document.tenant_id,
credential_id,
)
workspace_id = data_source_info["notion_workspace_id"]
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
credential_id = data_source_info.get("credential_id")
tenant_id = document.tenant_id
index_type = document.doc_form
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
# Get credentials from datasource provider
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
if not credential:
logger.error(
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
document_id,
tenant_id,
credential_id,
)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if document:
document.indexing_status = "error"
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
document.stopped_at = naive_utc_now()
session.commit()
return
return
loader = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=document.tenant_id,
)
loader = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=tenant_id,
)
last_edited_time = loader.get_notion_last_edited_time()
last_edited_time = loader.get_notion_last_edited_time()
if last_edited_time == page_edited_time:
logger.info(click.style(f"Document {document_id} content unchanged, skipping sync", fg="yellow"))
return
# check the page is updated
if last_edited_time != page_edited_time:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.commit()
logger.info(click.style(f"Document {document_id} content changed, starting sync", fg="green"))
# delete all document segment and index
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
try:
index_processor = IndexProcessorFactory(index_type).init_index_processor()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green"))
except Exception:
logger.exception("Failed to clean vector index for document %s", document_id)
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if not document:
logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow"))
return
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
data_source_info = document.data_source_info_dict
data_source_info["last_edited_time"] = last_edited_time
document.data_source_info = data_source_info
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
end_at = time.perf_counter()
logger.info(
click.style(
"Cleaned document when document update data source or process rule: {} latency: {}".format(
document_id, end_at - start_at
),
fg="green",
)
)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
session.execute(segment_delete_stmt)
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
logger.info(click.style(f"Deleted segments for document {document_id}", fg="green"))
try:
indexing_runner = IndexingRunner()
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
if document:
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"Sync completed for document {document_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception as e:
logger.exception("document_indexing_sync_task failed for document_id: %s", document_id)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
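
Note: the error path above opens a fresh short session to flip the document into an error state, so the status write cannot be poisoned by whatever broke the main flow. A hedged generalization of that pattern, with stand-in names:

from libs.datetime_utils import naive_utc_now  # same helper the tasks above use

def run_marking_errors(session_factory, Document, document_id: str, work) -> None:
    try:
        work()  # the long-running indexing work, done with no session held open
    except Exception as e:
        # Record the failure in its own short transaction; the tasks above
        # log the exception and swallow it rather than re-raising.
        with session_factory.create_session() as session, session.begin():
            document = session.query(Document).filter_by(id=document_id).first()
            if document:
                document.indexing_status = "error"
                document.error = str(e)
                document.stopped_at = naive_utc_now()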

View File

@@ -81,26 +81,35 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
session.commit()
return
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
# Phase 1: Update status to parsing (short transaction)
with session_factory.create_session() as session, session.begin():
documents = (
session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all()
)
for document in documents:
if document:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
session.add(document)
session.commit()
# Transaction committed and closed
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
# Phase 2: Execute indexing (no transaction - IndexingRunner creates its own sessions)
has_error = False
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
has_error = True
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
has_error = True
if not has_error:
with session_factory.create_session() as session:
# Trigger summary index generation for completed documents if enabled
# Only generate for high_quality indexing technique and when summary_index_setting is enabled
# Re-query dataset to get latest summary_index_setting (in case it was updated)
@@ -115,17 +124,18 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
# expire all session to get latest document's indexing status
session.expire_all()
# Check each document's indexing status and trigger summary generation if completed
for document_id in document_ids:
# Re-query document to get latest status (IndexingRunner may have updated it)
document = (
session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
documents = (
session.query(Document)
.where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
.all()
)
for document in documents:
if document:
logger.info(
"Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s",
document_id,
document.id,
document.indexing_status,
document.doc_form,
document.need_summary,
@@ -136,46 +146,36 @@
and document.need_summary is True
):
try:
generate_summary_index_task.delay(dataset.id, document_id, None)
generate_summary_index_task.delay(dataset.id, document.id, None)
logger.info(
"Queued summary index generation task for document %s in dataset %s "
"after indexing completed",
document_id,
document.id,
dataset.id,
)
except Exception:
logger.exception(
"Failed to queue summary index generation task for document %s",
document_id,
document.id,
)
# Don't fail the entire indexing process if summary task queuing fails
else:
logger.info(
"Skipping summary generation for document %s: "
"status=%s, doc_form=%s, need_summary=%s",
document_id,
document.id,
document.indexing_status,
document.doc_form,
document.need_summary,
)
else:
logger.warning("Document %s not found after indexing", document_id)
else:
logger.info(
"Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
dataset.id,
summary_index_setting.get("enable") if summary_index_setting else None,
)
logger.warning("Document %s not found after indexing", document.id)
else:
logger.info(
"Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
dataset.id,
dataset.indexing_technique,
)
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
def _document_indexing_with_tenant_queue(

View File

@ -8,7 +8,6 @@ from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
@ -27,7 +26,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Start update document: {document_id}", fg="green"))
start_at = time.perf_counter()
with session_factory.create_session() as session:
with session_factory.create_session() as session, session.begin():
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
@ -36,27 +35,20 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.commit()
# delete all document segments and indexes
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
return
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_type = document.doc_form
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
db.session.commit()
clean_success = False
try:
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if index_node_ids:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
end_at = time.perf_counter()
logger.info(
click.style(
@ -66,15 +58,21 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
fg="green",
)
)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
clean_success = True
except Exception:
logger.exception("Failed to clean document index during update, document_id: %s", document_id)
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
if clean_success:
with session_factory.create_session() as session, session.begin():
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
session.execute(segment_delete_stmt)
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)

View File

@ -259,8 +259,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
def del_workflow_archive_log(workflow_archive_log_id: str):
db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
def del_workflow_archive_log(session, workflow_archive_log_id: str):
session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
synchronize_session=False
)
@ -420,7 +420,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
total_files_deleted = 0
while True:
with session_factory.create_session() as session:
with session_factory.create_session() as session, session.begin():
# Get a batch of draft variable IDs along with their file_ids
query_sql = """
SELECT id, file_id FROM workflow_draft_variables
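
The hunk is truncated above, but the shape of the loop is clear: each batch is selected and deleted inside its own session.begin() transaction, so progress is durable even if a later batch fails, and every iteration re-reads fresh rows. A self-contained sketch (sqlite stand-in; Var is an illustrative model, not workflow_draft_variables):

from sqlalchemy import String, create_engine, delete, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Var(Base):
    __tablename__ = "vars"
    id: Mapped[int] = mapped_column(primary_key=True)
    app_id: Mapped[str] = mapped_column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)


def delete_vars_batch(app_id: str, batch_size: int = 1000) -> int:
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    total = 0
    while True:
        # session.begin() commits each batch on exit, so partial progress
        # survives a crash in a later batch.
        with Session(engine) as session, session.begin():
            ids = session.scalars(
                select(Var.id).where(Var.app_id == app_id).limit(batch_size)
            ).all()
            if not ids:
                return total
            session.execute(delete(Var).where(Var.id.in_(ids)))
            total += len(ids)


with Session(engine) as session, session.begin():
    session.add_all([Var(app_id="app-1") for _ in range(25)])
assert delete_vars_batch("app-1", batch_size=10) == 25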

View File

@ -6,9 +6,8 @@ improving performance by offloading storage operations to background workers.
"""
from celery import shared_task # type: ignore[import-untyped]
from sqlalchemy.orm import Session
from extensions.ext_database import db
from core.db.session_factory import session_factory
from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService
@ -17,6 +16,6 @@ def save_workflow_execution_task(
self,
deletions: list[DraftVarFileDeletion],
):
with Session(bind=db.engine) as session, session.begin():
with session_factory.create_session() as session, session.begin():
srv = WorkflowDraftVariableService(session=session)
srv.delete_workflow_draft_variable_file(deletions=deletions)

View File

@ -10,7 +10,10 @@ from models import Tenant
from models.enums import CreatorUserRole
from models.model import App, UploadFile
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch
from tasks.remove_app_and_related_data_task import (
_delete_draft_variables,
delete_draft_variables_batch,
)
@pytest.fixture
@ -297,12 +300,18 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
data = setup_offload_test_data
app_id = data["app"].id
upload_file_ids = [uf.id for uf in data["upload_files"]]
variable_file_ids = [vf.id for vf in data["variable_files"]]
mock_storage.delete.return_value = None
with session_factory.create_session() as session:
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = session.query(WorkflowDraftVariableFile).count()
upload_files_before = session.query(UploadFile).count()
var_files_before = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
)
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_before == 3
assert var_files_before == 2
assert upload_files_before == 2
@ -315,8 +324,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
assert draft_vars_after == 0
with session_factory.create_session() as session:
var_files_after = session.query(WorkflowDraftVariableFile).count()
upload_files_after = session.query(UploadFile).count()
var_files_after = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert var_files_after == 0
assert upload_files_after == 0
@ -329,6 +342,8 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
data = setup_offload_test_data
app_id = data["app"].id
upload_file_ids = [uf.id for uf in data["upload_files"]]
variable_file_ids = [vf.id for vf in data["variable_files"]]
mock_storage.delete.side_effect = [Exception("Storage error"), None]
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
@ -339,8 +354,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
assert draft_vars_after == 0
with session_factory.create_session() as session:
var_files_after = session.query(WorkflowDraftVariableFile).count()
upload_files_after = session.query(UploadFile).count()
var_files_after = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert var_files_after == 0
assert upload_files_after == 0
@ -395,3 +414,275 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
if app2_obj:
session.delete(app2_obj)
session.commit()
class TestDeleteDraftVariablesSessionCommit:
"""Test suite to verify session commit behavior in delete_draft_variables_batch."""
@pytest.fixture
def setup_offload_test_data(self, app_and_tenant):
"""Create test data with offload files for session commit tests."""
from core.variables.types import SegmentType
from libs.datetime_utils import naive_utc_now
tenant, app = app_and_tenant
with session_factory.create_session() as session:
upload_file1 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file1.json",
name="file1.json",
size=1024,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
upload_file2 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file2.json",
name="file2.json",
size=2048,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
session.add(upload_file1)
session.add(upload_file2)
session.flush()
var_file1 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file1.id,
size=1024,
length=10,
value_type=SegmentType.STRING,
)
var_file2 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file2.id,
size=2048,
length=20,
value_type=SegmentType.OBJECT,
)
session.add(var_file1)
session.add(var_file2)
session.flush()
draft_var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_1",
name="large_var_1",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file1.id,
)
draft_var2 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_2",
name="large_var_2",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file2.id,
)
draft_var3 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_3",
name="regular_var",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(draft_var1)
session.add(draft_var2)
session.add(draft_var3)
session.commit()
data = {
"app": app,
"tenant": tenant,
"upload_files": [upload_file1, upload_file2],
"variable_files": [var_file1, var_file2],
"draft_variables": [draft_var1, draft_var2, draft_var3],
}
yield data
with session_factory.create_session() as session:
for table, ids in [
(WorkflowDraftVariable, [v.id for v in data["draft_variables"]]),
(WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]),
(UploadFile, [uf.id for uf in data["upload_files"]]),
]:
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
session.execute(cleanup_query)
session.commit()
@pytest.fixture
def setup_commit_test_data(self, app_and_tenant):
"""Create test data for session commit tests."""
tenant, app = app_and_tenant
variable_ids: list[str] = []
with session_factory.create_session() as session:
variables = []
for i in range(10):
var = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var)
variables.append(var)
session.commit()
variable_ids = [v.id for v in variables]
yield {
"app": app,
"tenant": tenant,
"variable_ids": variable_ids,
}
with session_factory.create_session() as session:
cleanup_query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.id.in_(variable_ids))
.execution_options(synchronize_session=False)
)
session.execute(cleanup_query)
session.commit()
def test_session_commit_is_called_after_each_batch(self, setup_commit_test_data):
"""Test that session.begin() is used for automatic transaction management."""
data = setup_commit_test_data
app_id = data["app"].id
# Since session.begin() is used, the transaction is automatically committed
# when the with block exits successfully. We verify this by checking that
# data is actually persisted.
deleted_count = delete_draft_variables_batch(app_id, batch_size=3)
# Verify all data was deleted (proves transaction was committed)
with session_factory.create_session() as session:
remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert deleted_count == 10
assert remaining_count == 0
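
For readers unfamiliar with the begin-once idiom this test relies on, here is a standalone demonstration of SQLAlchemy's session.begin() semantics (sqlite in-memory; Item is illustrative): commit on clean exit, rollback on exception.

from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Item(Base):
    __tablename__ = "items"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

# Commit-on-success: the with-block exit issues the COMMIT.
with Session(engine) as session, session.begin():
    session.add(Item(name="kept"))

# Rollback-on-error: the raise inside the block discards the INSERT.
try:
    with Session(engine) as session, session.begin():
        session.add(Item(name="discarded"))
        raise RuntimeError("boom")
except RuntimeError:
    pass

with Session(engine) as session:
    assert [i.name for i in session.query(Item).all()] == ["kept"]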
def test_data_persisted_after_batch_deletion(self, setup_commit_test_data):
"""Test that data is actually persisted to database after batch deletion with commits."""
data = setup_commit_test_data
app_id = data["app"].id
variable_ids = data["variable_ids"]
# Verify initial state
with session_factory.create_session() as session:
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert initial_count == 10
# Perform deletion with small batch size to force multiple commits
deleted_count = delete_draft_variables_batch(app_id, batch_size=3)
assert deleted_count == 10
# Verify all data is deleted in a new session (proves commits worked)
with session_factory.create_session() as session:
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert final_count == 0
# Verify specific IDs are deleted
with session_factory.create_session() as session:
remaining_vars = (
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count()
)
assert remaining_vars == 0
def test_session_commit_with_empty_dataset(self, setup_commit_test_data):
"""Test session behavior when deleting from an empty dataset."""
nonexistent_app_id = str(uuid.uuid4())
# Should not raise any errors and should return 0
deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=10)
assert deleted_count == 0
def test_session_commit_with_single_batch(self, setup_commit_test_data):
"""Test that commit happens correctly when all data fits in a single batch."""
data = setup_commit_test_data
app_id = data["app"].id
with session_factory.create_session() as session:
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert initial_count == 10
# Delete all in a single batch
deleted_count = delete_draft_variables_batch(app_id, batch_size=100)
assert deleted_count == 10
# Verify data is persisted
with session_factory.create_session() as session:
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert final_count == 0
def test_invalid_batch_size_raises_error(self, setup_commit_test_data):
"""Test that invalid batch size raises ValueError."""
data = setup_commit_test_data
app_id = data["app"].id
with pytest.raises(ValueError, match="batch_size must be positive"):
delete_draft_variables_batch(app_id, batch_size=0)
with pytest.raises(ValueError, match="batch_size must be positive"):
delete_draft_variables_batch(app_id, batch_size=-1)
@patch("extensions.ext_storage.storage")
def test_session_commit_with_offload_data_cleanup(self, mock_storage, setup_offload_test_data):
"""Test that session commits correctly when cleaning up offload data."""
data = setup_offload_test_data
app_id = data["app"].id
upload_file_ids = [uf.id for uf in data["upload_files"]]
mock_storage.delete.return_value = None
# Verify initial state
with session_factory.create_session() as session:
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
.count()
)
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_before == 3
assert var_files_before == 2
assert upload_files_before == 2
# Delete variables with offload data
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
assert deleted_count == 3
# Verify all data is persisted (deleted) in new session
with session_factory.create_session() as session:
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_after = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
.count()
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_after == 0
assert var_files_after == 0
assert upload_files_after == 0
# Verify storage cleanup was called
assert mock_storage.delete.call_count == 2

View File

@ -605,26 +605,20 @@ class TestBatchCreateSegmentToIndexTask:
mock_storage.download.side_effect = mock_download
# Execute the task
# Execute the task - should raise ValueError for empty CSV
job_id = str(uuid.uuid4())
batch_create_segment_to_index_task(
job_id=job_id,
upload_file_id=upload_file.id,
dataset_id=dataset.id,
document_id=document.id,
tenant_id=tenant.id,
user_id=account.id,
)
with pytest.raises(ValueError, match="The CSV file is empty"):
batch_create_segment_to_index_task(
job_id=job_id,
upload_file_id=upload_file.id,
dataset_id=dataset.id,
document_id=document.id,
tenant_id=tenant.id,
user_id=account.id,
)
# Verify error handling
# Check Redis cache was set to error status
from extensions.ext_redis import redis_client
cache_key = f"segment_batch_import_{job_id}"
cache_value = redis_client.get(cache_key)
assert cache_value == b"error"
# Verify no segments were created
# Since the exception was raised, no segments should be created
from extensions.ext_database import db
segments = db.session.query(DocumentSegment).all()
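
The reworked test leans on pytest.raises with match=, where the pattern is a regex searched against the exception message. A tiny standalone sketch (import_rows is a hypothetical helper, not the task under test):

import pytest


def import_rows(rows: list[dict]) -> int:
    if not rows:
        raise ValueError("The CSV file is empty.")
    return len(rows)


def test_empty_csv_raises():
    # match= is searched (re.search) against str(excinfo.value)
    with pytest.raises(ValueError, match="CSV file is empty"):
        import_rows([])


test_empty_csv_raises()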

View File

@ -153,8 +153,7 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task
clean_notion_document_task(document_ids, dataset.id)
# Verify documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id.in_(document_ids)).count() == 0
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(document_ids))
@ -162,9 +161,9 @@ class TestCleanNotionDocumentTask:
== 0
)
# Verify index processor was called for each document
# Verify index processor was called
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
assert mock_processor.clean.call_count == len(document_ids)
mock_processor.clean.assert_called_once()
# This test successfully verifies:
# 1. Document records are properly deleted from the database
@ -186,12 +185,12 @@ class TestCleanNotionDocumentTask:
non_existent_dataset_id = str(uuid.uuid4())
document_ids = [str(uuid.uuid4()), str(uuid.uuid4())]
# Execute cleanup task with non-existent dataset
clean_notion_document_task(document_ids, non_existent_dataset_id)
# Execute cleanup task with non-existent dataset - expect exception
with pytest.raises(Exception, match="Document has no dataset"):
clean_notion_document_task(document_ids, non_existent_dataset_id)
# Verify that the index processor was not called
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_processor.clean.assert_not_called()
# Verify that the index processor factory was not used
mock_index_processor_factory.return_value.init_index_processor.assert_not_called()
def test_clean_notion_document_task_empty_document_list(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
@ -229,9 +228,13 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task with empty document list
clean_notion_document_task([], dataset.id)
# Verify that the index processor was not called
# Verify that the index processor was called once with empty node list
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_processor.clean.assert_not_called()
assert mock_processor.clean.call_count == 1
args, kwargs = mock_processor.clean.call_args
# args: (dataset, total_index_node_ids)
assert isinstance(args[0], Dataset)
assert args[1] == []
def test_clean_notion_document_task_with_different_index_types(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
@ -315,8 +318,7 @@ class TestCleanNotionDocumentTask:
# Note: This test successfully verifies cleanup with different document types.
# The task properly handles various index types and document configurations.
# Verify documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id == document.id)
@ -404,8 +406,7 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task
clean_notion_document_task([document.id], dataset.id)
# Verify documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
== 0
@ -508,8 +509,7 @@ class TestCleanNotionDocumentTask:
clean_notion_document_task(documents_to_clean, dataset.id)
# Verify only specified documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id.in_(documents_to_clean)).count() == 0
# Verify only specified documents' segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(documents_to_clean))
@ -697,11 +697,12 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.commit()
# Mock index processor to raise an exception
mock_index_processor = mock_index_processor_factory.init_index_processor.return_value
mock_index_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_index_processor.clean.side_effect = Exception("Index processor error")
# Execute cleanup task - it should handle the exception gracefully
clean_notion_document_task([document.id], dataset.id)
# Execute cleanup task - current implementation propagates the exception
with pytest.raises(Exception, match="Index processor error"):
clean_notion_document_task([document.id], dataset.id)
# Note: This test demonstrates the task's error handling capability.
# Even with external service errors, the database operations complete successfully.
@ -803,8 +804,7 @@ class TestCleanNotionDocumentTask:
all_document_ids = [doc.id for doc in documents]
clean_notion_document_task(all_document_ids, dataset.id)
# Verify all documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0
# Verify all segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
== 0
@ -914,8 +914,7 @@ class TestCleanNotionDocumentTask:
clean_notion_document_task([target_document.id], target_dataset.id)
# Verify only documents from target dataset are deleted
assert db_session_with_containers.query(Document).filter(Document.id == target_document.id).count() == 0
# Verify only documents' segments from target dataset are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id == target_document.id)
@ -1030,8 +1029,7 @@ class TestCleanNotionDocumentTask:
all_document_ids = [doc.id for doc in documents]
clean_notion_document_task(all_document_ids, dataset.id)
# Verify all documents and segments are deleted regardless of status
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0
# Verify all segments are deleted regardless of status
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
== 0
@ -1142,8 +1140,7 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task
clean_notion_document_task([document.id], dataset.id)
# Verify documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
== 0

View File

@ -0,0 +1,182 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from tasks.document_indexing_update_task import document_indexing_update_task
class TestDocumentIndexingUpdateTask:
@pytest.fixture
def mock_external_dependencies(self):
"""Patch external collaborators used by the update task.
- IndexProcessorFactory.init_index_processor().clean(...)
- IndexingRunner.run([...])
"""
with (
patch("tasks.document_indexing_update_task.IndexProcessorFactory") as mock_factory,
patch("tasks.document_indexing_update_task.IndexingRunner") as mock_runner,
):
processor_instance = MagicMock()
mock_factory.return_value.init_index_processor.return_value = processor_instance
runner_instance = MagicMock()
mock_runner.return_value = runner_instance
yield {
"factory": mock_factory,
"processor": processor_instance,
"runner": mock_runner,
"runner_instance": runner_instance,
}
def _create_dataset_document_with_segments(self, db_session_with_containers, *, segment_count: int = 2):
fake = Faker()
# Account and tenant
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(name=fake.company(), status="normal")
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Dataset and document
dataset = Dataset(
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=64),
data_source_type="upload_file",
indexing_technique="high_quality",
created_by=account.id,
)
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
document = Document(
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_by=account.id,
indexing_status="waiting",
enabled=True,
doc_form="text_model",
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Segments
node_ids = []
for i in range(segment_count):
node_id = f"node-{i + 1}"
seg = DocumentSegment(
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
position=i,
content=fake.text(max_nb_chars=32),
answer=None,
word_count=10,
tokens=5,
index_node_id=node_id,
status="completed",
created_by=account.id,
)
db_session_with_containers.add(seg)
node_ids.append(node_id)
db_session_with_containers.commit()
# Refresh to ensure ORM state
db_session_with_containers.refresh(dataset)
db_session_with_containers.refresh(document)
return dataset, document, node_ids
def test_cleans_segments_and_reindexes(self, db_session_with_containers, mock_external_dependencies):
dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)
# Act
document_indexing_update_task(dataset.id, document.id)
# Ensure we see committed changes from another session
db_session_with_containers.expire_all()
# Assert document status updated before reindex
updated = db_session_with_containers.query(Document).where(Document.id == document.id).first()
assert updated.indexing_status == "parsing"
assert updated.processing_started_at is not None
# Segments should be deleted
remaining = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
)
assert remaining == 0
# Assert index processor clean was called with expected args
clean_call = mock_external_dependencies["processor"].clean.call_args
assert clean_call is not None
args, kwargs = clean_call
# args[0] is a Dataset instance (from another session) — validate by id
assert getattr(args[0], "id", None) == dataset.id
# args[1] should contain our node_ids
assert set(args[1]) == set(node_ids)
assert kwargs.get("with_keywords") is True
assert kwargs.get("delete_child_chunks") is True
# Assert indexing runner invoked with the updated document
run_call = mock_external_dependencies["runner_instance"].run.call_args
assert run_call is not None
run_docs = run_call[0][0]
assert len(run_docs) == 1
first = run_docs[0]
assert getattr(first, "id", None) == document.id
def test_clean_error_is_logged_and_indexing_continues(self, db_session_with_containers, mock_external_dependencies):
dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)
# Force clean to raise; task should continue to indexing
mock_external_dependencies["processor"].clean.side_effect = Exception("boom")
document_indexing_update_task(dataset.id, document.id)
# Ensure we see committed changes from another session
db_session_with_containers.expire_all()
# Indexing should still be triggered
mock_external_dependencies["runner_instance"].run.assert_called_once()
# Segments should remain (since clean failed before DB delete)
remaining = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
)
assert remaining > 0
def test_document_not_found_noop(self, db_session_with_containers, mock_external_dependencies):
fake = Faker()
# Act with non-existent document id
document_indexing_update_task(dataset_id=fake.uuid4(), document_id=fake.uuid4())
# Neither processor nor runner should be called
mock_external_dependencies["processor"].clean.assert_not_called()
mock_external_dependencies["runner_instance"].run.assert_not_called()

View File

@ -217,6 +217,7 @@ class TestTemplateTransformNode:
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
@patch("core.workflow.nodes.template_transform.template_transform_node.MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH", 10)
def test_run_output_length_exceeds_limit(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
@ -230,7 +231,6 @@ class TestTemplateTransformNode:
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
max_output_length=10,
)
result = node._run()
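
The change above patches the module-level constant instead of threading a max_output_length argument through the constructor. A self-contained sketch of constant patching with unittest.mock (mymod, LIMIT, and check are hypothetical stand-ins):

import sys
import types
from unittest.mock import patch

# Hypothetical module standing in for the template transform node's module.
mymod = types.ModuleType("mymod")
mymod.LIMIT = 1000
mymod.check = lambda s: len(s) <= mymod.LIMIT  # reads LIMIT at call time
sys.modules["mymod"] = mymod

# Same shape as @patch("...MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH", 10):
# passing a new value replaces the attribute instead of injecting a MagicMock.
with patch("mymod.LIMIT", 10):
    assert mymod.check("x" * 10) is True
    assert mymod.check("x" * 11) is False

assert mymod.LIMIT == 1000  # restored after the context exits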

View File

@ -4,7 +4,7 @@ from typing import Any
from uuid import uuid4
import pytest
from hypothesis import given, settings
from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st
from core.file import File, FileTransferMethod, FileType
@ -493,7 +493,7 @@ def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]:
)
@settings(max_examples=50)
@settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None)
@given(_scalar_value())
def test_build_segment_and_extract_values_for_scalar_types(value):
seg = variable_factory.build_segment(value)
@ -504,7 +504,7 @@ def test_build_segment_and_extract_values_for_scalar_types(value):
assert seg.value == value
@settings(max_examples=50)
@settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None)
@given(values=st.lists(_scalar_value(), max_size=20))
def test_build_segment_and_extract_values_for_array_types(values):
seg = variable_factory.build_segment(values)
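
For context on the settings change: suppress_health_check disables the named Hypothesis health checks, and deadline=None removes the per-example time budget. A minimal runnable example with the same options, collectable by pytest:

from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st


@settings(
    max_examples=30,
    suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much],
    deadline=None,
)
@given(st.integers())
def test_int_str_roundtrip(x: int):
    assert int(str(x)) == x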

View File

@ -83,23 +83,127 @@ def mock_documents(document_ids, dataset_id):
def mock_db_session():
"""Mock database session via session_factory.create_session()."""
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
session = MagicMock()
# Ensure tests that expect session.close() to be called can observe it via the context manager
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
# Link __exit__ to session.close so "close" expectations reflect context manager teardown
sessions = [] # Track all created sessions
# Shared mock data that all sessions will access
shared_mock_data = {"dataset": None, "documents": None, "doc_iter": None}
def _exit_side_effect(*args, **kwargs):
session.close()
def create_session_side_effect():
session = MagicMock()
session.close = MagicMock()
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
# Track commit calls
commit_mock = MagicMock()
session.commit = commit_mock
cm = MagicMock()
cm.__enter__.return_value = session
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
yield session
def _exit_side_effect(*args, **kwargs):
session.close()
cm.__exit__.side_effect = _exit_side_effect
# Support session.begin() for transactions
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
def begin_exit_side_effect(*args, **kwargs):
# Auto-commit on transaction exit (like SQLAlchemy)
session.commit()
# Also mark wrapper's commit as called
if sessions:
sessions[0].commit()
begin_cm.__exit__ = MagicMock(side_effect=begin_exit_side_effect)
session.begin = MagicMock(return_value=begin_cm)
sessions.append(session)
# Setup query with side_effect to handle both Dataset and Document queries
def query_side_effect(*args):
query = MagicMock()
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
where_result = MagicMock()
where_result.first.return_value = shared_mock_data["dataset"]
query.where = MagicMock(return_value=where_result)
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
# Support both .first() and .all() calls with chaining
where_result = MagicMock()
where_result.where = MagicMock(return_value=where_result)
# Create an iterator for .first() calls if not exists
if shared_mock_data["doc_iter"] is None:
docs = shared_mock_data["documents"] or [None]
shared_mock_data["doc_iter"] = iter(docs)
where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
docs_or_empty = shared_mock_data["documents"] or []
where_result.all = MagicMock(return_value=docs_or_empty)
query.where = MagicMock(return_value=where_result)
else:
query.where = MagicMock(return_value=query)
return query
session.query = MagicMock(side_effect=query_side_effect)
return cm
mock_sf.create_session.side_effect = create_session_side_effect
# Create a wrapper that behaves like the first session but has access to all sessions
class SessionWrapper:
def __init__(self):
self._sessions = sessions
self._shared_data = shared_mock_data
# Create a default session for setup phase
self._default_session = MagicMock()
self._default_session.close = MagicMock()
self._default_session.commit = MagicMock()
# Support session.begin() for default session too
begin_cm = MagicMock()
begin_cm.__enter__.return_value = self._default_session
def default_begin_exit_side_effect(*args, **kwargs):
self._default_session.commit()
begin_cm.__exit__ = MagicMock(side_effect=default_begin_exit_side_effect)
self._default_session.begin = MagicMock(return_value=begin_cm)
def default_query_side_effect(*args):
query = MagicMock()
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
where_result = MagicMock()
where_result.first.return_value = shared_mock_data["dataset"]
query.where = MagicMock(return_value=where_result)
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
where_result = MagicMock()
where_result.where = MagicMock(return_value=where_result)
if shared_mock_data["doc_iter"] is None:
docs = shared_mock_data["documents"] or [None]
shared_mock_data["doc_iter"] = iter(docs)
where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
docs_or_empty = shared_mock_data["documents"] or []
where_result.all = MagicMock(return_value=docs_or_empty)
query.where = MagicMock(return_value=where_result)
else:
query.where = MagicMock(return_value=query)
return query
self._default_session.query = MagicMock(side_effect=default_query_side_effect)
def __getattr__(self, name):
# Forward all attribute access to the first session, or default if none created yet
target_session = self._sessions[0] if self._sessions else self._default_session
return getattr(target_session, name)
@property
def all_sessions(self):
"""Access all created sessions for testing."""
return self._sessions
wrapper = SessionWrapper()
yield wrapper
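
The fixture above builds mock context managers by hand; the core trick is that `with cm as s:` binds whatever cm.__enter__ returns. A minimal standalone sketch:

from unittest.mock import MagicMock

session = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session   # what `with cm as s:` binds to s
cm.__exit__.return_value = False      # do not swallow exceptions

with cm as s:
    assert s is session
cm.__exit__.assert_called_once()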
@pytest.fixture
@ -252,18 +356,9 @@ class TestTaskEnqueuing:
use the deprecated function.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
# Return documents one by one for each call
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -304,21 +399,9 @@ class TestBatchProcessing:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# Create an iterator for documents
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
# Return documents one by one for each call
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -357,19 +440,9 @@ class TestBatchProcessing:
doc.stopped_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_feature_service.get_features.return_value.billing.enabled = True
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL
@ -407,19 +480,9 @@ class TestBatchProcessing:
doc.stopped_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_feature_service.get_features.return_value.billing.enabled = True
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX
@ -444,7 +507,10 @@ class TestBatchProcessing:
"""
# Arrange
document_ids = []
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# Set shared mock data with empty documents list
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = []
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -482,19 +548,9 @@ class TestProgressTracking:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -528,19 +584,9 @@ class TestProgressTracking:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -635,19 +681,9 @@ class TestErrorHandling:
doc.stopped_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Set up to trigger vector space limit error
mock_feature_service.get_features.return_value.billing.enabled = True
@ -674,17 +710,9 @@ class TestErrorHandling:
Errors during indexing should be caught and logged, but not crash the task.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = Exception("Indexing failed")
@ -708,17 +736,9 @@ class TestErrorHandling:
but not treated as a failure.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Make IndexingRunner raise DocumentIsPausedError
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused")
@ -853,17 +873,9 @@ class TestTaskCancellation:
Session cleanup should happen in finally block.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -883,17 +895,9 @@ class TestTaskCancellation:
Session cleanup should happen even when errors occur.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = Exception("Test error")
@ -962,6 +966,7 @@ class TestAdvancedScenarios:
document_ids = [str(uuid.uuid4()) for _ in range(3)]
# Create only 2 documents (simulate one missing)
# The new code uses .all() which will only return existing documents
mock_documents = []
for i, doc_id in enumerate([document_ids[0], document_ids[2]]): # Skip middle one
doc = MagicMock(spec=Document)
@ -971,21 +976,9 @@ class TestAdvancedScenarios:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# Create iterator that returns None for missing document
doc_responses = [mock_documents[0], None, mock_documents[1]]
doc_iter = iter(doc_responses)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data - .all() will only return existing documents
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -1075,19 +1068,9 @@ class TestAdvancedScenarios:
doc.stopped_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Set vector space exactly at limit
mock_feature_service.get_features.return_value.billing.enabled = True
@ -1219,19 +1202,9 @@ class TestAdvancedScenarios:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Billing disabled - limits should not be checked
mock_feature_service.get_features.return_value.billing.enabled = False
@ -1273,19 +1246,9 @@ class TestIntegration:
# Set up rpop to return None for concurrency check (no more tasks)
mock_redis.rpop.side_effect = [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -1321,19 +1284,9 @@ class TestIntegration:
# Set up rpop to return None for concurrency check (no more tasks)
mock_redis.rpop.side_effect = [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -1415,17 +1368,9 @@ class TestEdgeCases:
mock_document.indexing_status = "waiting"
mock_document.processing_started_at = None
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: mock_document
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = [mock_document]
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -1465,17 +1410,9 @@ class TestEdgeCases:
mock_document.indexing_status = "waiting"
mock_document.processing_started_at = None
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: mock_document
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = [mock_document]
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -1555,19 +1492,9 @@ class TestEdgeCases:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Set vector space limit to 0 (unlimited)
mock_feature_service.get_features.return_value.billing.enabled = True
@ -1612,19 +1539,9 @@ class TestEdgeCases:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Set negative vector space limit
mock_feature_service.get_features.return_value.billing.enabled = True
@ -1675,19 +1592,9 @@ class TestPerformanceScenarios:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Configure billing with sufficient limits
mock_feature_service.get_features.return_value.billing.enabled = True
@ -1826,19 +1733,9 @@ class TestRobustness:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error")
@ -1866,7 +1763,7 @@ class TestRobustness:
- No exceptions occur
Expected behavior:
- Database session is closed
- All database sessions are closed
- No connection leaks
"""
# Arrange
@ -1879,19 +1776,9 @@ class TestRobustness:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -1899,10 +1786,11 @@ class TestRobustness:
# Act
_document_indexing(dataset_id, document_ids)
# Assert
assert mock_db_session.close.called
# Verify close is called exactly once
assert mock_db_session.close.call_count == 1
# Assert - All created sessions should be closed
# The code creates multiple sessions: validation, Phase 1 (parsing), Phase 3 (summary)
assert len(mock_db_session.all_sessions) >= 1
for session in mock_db_session.all_sessions:
assert session.close.called, "All sessions should be closed"
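For context, the shape these assertions verify is roughly the following (a minimal sketch; the phase bodies are illustrative, not the task's actual helpers):
def _document_indexing(dataset_id, document_ids):
    # Validation: a short-lived session, closed before heavy work begins.
    with session_factory.create_session() as session:
        dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
        if not dataset:
            return
    # Phase 1 (parsing): a fresh session, so long-running work does not
    # pin the validation session's connection.
    with session_factory.create_session() as session:
        for document_id in document_ids:
            ...  # mark each document as parsing
    # Phase 3 (summary): again a fresh, short-lived session.
    with session_factory.create_session() as session:
        ...  # persist summary/status updates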
def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis):
"""

View File

@ -109,25 +109,87 @@ def mock_document_segments(document_id):
@pytest.fixture
def mock_db_session():
"""Mock database session via session_factory.create_session()."""
"""Mock database session via session_factory.create_session().
After session split refactor, the code calls create_session() multiple times.
This fixture creates shared query mocks so all sessions use the same
query configuration, simulating database persistence across sessions.
The fixture automatically converts side_effect to cycle to prevent StopIteration.
Tests configure mocks the same way as before, but behind the scenes the values
are cycled infinitely for all sessions.
"""
from itertools import cycle
with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf:
session = MagicMock()
# Ensure tests can observe session.close() via context manager teardown
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
sessions = []
def _exit_side_effect(*args, **kwargs):
session.close()
# Shared query mocks - all sessions use these
shared_query = MagicMock()
shared_filter_by = MagicMock()
shared_scalars_result = MagicMock()
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
# Custom mock for .first() that auto-cycles its side_effect
class CyclicMock(MagicMock):
def __setattr__(self, name, value):
if name == "side_effect" and value is not None:
# Convert list/tuple to infinite cycle
if isinstance(value, (list, tuple)):
value = cycle(value)
super().__setattr__(name, value)
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
session.scalars.return_value = MagicMock()
yield session
# where.return_value is aliased to shared_query below, so attach the
# cyclic mock to shared_query.first, where chained lookups resolve
shared_query.first = CyclicMock()
shared_filter_by.first = CyclicMock()
def _create_session():
"""Create a new mock session for each create_session() call."""
session = MagicMock()
session.close = MagicMock()
session.commit = MagicMock()
# Mock session.begin() context manager
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
def _begin_exit_side_effect(exc_type, exc, tb):
# commit on success
if exc_type is None:
session.commit()
# return False to propagate exceptions
return False
begin_cm.__exit__.side_effect = _begin_exit_side_effect
session.begin.return_value = begin_cm
# Mock create_session() context manager
cm = MagicMock()
cm.__enter__.return_value = session
def _exit_side_effect(exc_type, exc, tb):
session.close()
return False
cm.__exit__.side_effect = _exit_side_effect
# All sessions use the same shared query mocks
session.query.return_value = shared_query
shared_query.where.return_value = shared_query
shared_query.filter_by.return_value = shared_filter_by
session.scalars.return_value = shared_scalars_result
sessions.append(session)
# Attach helpers on the first created session for assertions across all sessions
if len(sessions) == 1:
session.get_all_sessions = lambda: sessions
session.any_close_called = lambda: any(s.close.called for s in sessions)
session.any_commit_called = lambda: any(s.commit.called for s in sessions)
return cm
mock_sf.create_session.side_effect = _create_session
# Create first session and return it
_create_session()
yield sessions[0]
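A usage sketch for this fixture (hypothetical test body; run_task stands in for the Celery task under test):
def test_sketch(mock_db_session):
    doc, ds = MagicMock(), MagicMock()
    # A plain list suffices: CyclicMock turns it into an infinite cycle,
    # so every session the task creates re-reads the same sequence
    # instead of raising StopIteration.
    mock_db_session.query.return_value.where.return_value.first.side_effect = [doc, ds]
    run_task("dataset-id", "document-id")  # hypothetical invocation
    # Helpers attached to the first session aggregate over all sessions.
    assert mock_db_session.any_close_called()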
@pytest.fixture
@ -186,8 +248,8 @@ class TestDocumentIndexingSyncTask:
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
mock_db_session.close.assert_called_once()
# Assert - at least one session should have been closed
assert mock_db_session.any_close_called()
def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
"""Test that task raises error when notion_workspace_id is missing."""
@ -230,6 +292,7 @@ class TestDocumentIndexingSyncTask:
"""Test that task handles missing credentials by updating document status."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_datasource_provider_service.get_datasource_credentials.return_value = None
# Act
@ -239,8 +302,8 @@ class TestDocumentIndexingSyncTask:
assert mock_document.indexing_status == "error"
assert "Datasource credential not found" in mock_document.error
assert mock_document.stopped_at is not None
mock_db_session.commit.assert_called()
mock_db_session.close.assert_called()
assert mock_db_session.any_commit_called()
assert mock_db_session.any_close_called()
def test_page_not_updated(
self,
@ -254,6 +317,7 @@ class TestDocumentIndexingSyncTask:
"""Test that task does nothing when page has not been updated."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
# Return same time as stored in document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
@ -263,8 +327,8 @@ class TestDocumentIndexingSyncTask:
# Assert
# Document status should remain unchanged
assert mock_document.indexing_status == "completed"
# Session should still be closed via context manager teardown
assert mock_db_session.close.called
# At least one session should have been closed via context manager teardown
assert mock_db_session.any_close_called()
def test_successful_sync_when_page_updated(
self,
@ -281,7 +345,20 @@ class TestDocumentIndexingSyncTask:
):
"""Test successful sync flow when Notion page has been updated."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
# Set exact sequence of returns across calls to `.first()`:
# 1) document (initial fetch)
# 2) dataset (pre-check)
# 3) dataset (cleaning phase)
# 4) document (pre-indexing update)
# 5) document (indexing runner fetch)
mock_db_session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
mock_dataset,
mock_document,
mock_document,
]
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
# NotionExtractor returns updated time
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
@ -299,28 +376,40 @@ class TestDocumentIndexingSyncTask:
mock_processor.clean.assert_called_once()
# Verify segments were deleted from database in batch (DELETE FROM document_segments)
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
# Aggregate execute calls across all created sessions
execute_sqls = []
for s in mock_db_session.get_all_sessions():
execute_sqls.extend([" ".join(str(c[0][0]).split()) for c in s.execute.call_args_list])
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
# Verify indexing runner was called
mock_indexing_runner.run.assert_called_once_with([mock_document])
# Verify session operations
assert mock_db_session.commit.called
mock_db_session.close.assert_called_once()
# Verify session operations (across any created session)
assert mock_db_session.any_commit_called()
assert mock_db_session.any_close_called()
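The batch-delete check above corresponds to the move from per-row deletes to a single statement; roughly (a sketch using SQLAlchemy Core delete(); the document_id predicate is an assumption about the exact filter):
from sqlalchemy import delete

with session_factory.create_session() as session:
    with session.begin():
        # One DELETE FROM document_segments WHERE ... instead of N
        # single-row session.delete() calls.
        session.execute(
            delete(DocumentSegment).where(DocumentSegment.document_id == document.id)
        )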
def test_dataset_not_found_during_cleaning(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_indexing_runner,
mock_document,
mock_dataset,
dataset_id,
document_id,
):
"""Test that task handles dataset not found during cleaning phase."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None]
# Sequence: document (initial), dataset (pre-check), None (cleaning), document (update), document (indexing)
mock_db_session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
None,
mock_document,
mock_document,
]
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
@ -329,8 +418,8 @@ class TestDocumentIndexingSyncTask:
# Assert
# Document should still be set to parsing
assert mock_document.indexing_status == "parsing"
# Session should be closed after error
mock_db_session.close.assert_called_once()
# At least one session should be closed after error
assert mock_db_session.any_close_called()
def test_cleaning_error_continues_to_indexing(
self,
@ -346,8 +435,14 @@ class TestDocumentIndexingSyncTask:
):
"""Test that indexing continues even if cleaning fails."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error")
from itertools import cycle
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
# Make the cleaning step fail but not the segment fetch
processor = mock_index_processor_factory.return_value.init_index_processor.return_value
processor.clean.side_effect = Exception("Cleaning error")
mock_db_session.scalars.return_value.all.return_value = []
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
@ -356,7 +451,7 @@ class TestDocumentIndexingSyncTask:
# Assert
# Indexing should still be attempted despite cleaning error
mock_indexing_runner.run.assert_called_once_with([mock_document])
mock_db_session.close.assert_called_once()
assert mock_db_session.any_close_called()
def test_indexing_runner_document_paused_error(
self,
@ -373,7 +468,10 @@ class TestDocumentIndexingSyncTask:
):
"""Test that DocumentIsPausedError is handled gracefully."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
from itertools import cycle
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
@ -383,7 +481,7 @@ class TestDocumentIndexingSyncTask:
# Assert
# Session should be closed after handling error
mock_db_session.close.assert_called_once()
assert mock_db_session.any_close_called()
def test_indexing_runner_general_error(
self,
@ -400,7 +498,10 @@ class TestDocumentIndexingSyncTask:
):
"""Test that general exceptions during indexing are handled."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
from itertools import cycle
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
mock_indexing_runner.run.side_effect = Exception("Indexing error")
@ -410,7 +511,7 @@ class TestDocumentIndexingSyncTask:
# Assert
# Session should be closed after error
mock_db_session.close.assert_called_once()
assert mock_db_session.any_close_called()
def test_notion_extractor_initialized_with_correct_params(
self,
@ -517,7 +618,14 @@ class TestDocumentIndexingSyncTask:
):
"""Test that index processor clean is called with correct parameters."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
# Sequence: document (initial), dataset (pre-check), dataset (cleaning), document (update), document (indexing)
mock_db_session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
mock_dataset,
mock_document,
mock_document,
]
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"

View File

@ -350,7 +350,7 @@ class TestDeleteWorkflowArchiveLogs:
mock_query.where.return_value = mock_delete_query
mock_db.session.query.return_value = mock_query
delete_func("log-1")
delete_func(mock_db.session, "log-1")
mock_db.session.query.assert_called_once_with(WorkflowArchiveLog)
mock_query.where.assert_called_once()

api/uv.lock (generated)
View File

@ -1368,7 +1368,7 @@ wheels = [
[[package]]
name = "dify-api"
version = "1.12.0"
version = "1.12.1"
source = { virtual = "." }
dependencies = [
{ name = "aliyun-log-python-sdk" },
@ -1653,7 +1653,7 @@ requires-dist = [
{ name = "starlette", specifier = "==0.49.1" },
{ name = "tiktoken", specifier = "~=0.9.0" },
{ name = "transformers", specifier = "~=4.56.1" },
{ name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" },
{ name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.18.18" },
{ name = "weave", specifier = ">=0.52.16" },
{ name = "weaviate-client", specifier = "==4.17.0" },
{ name = "webvtt-py", specifier = "~=0.5.1" },
@ -6814,12 +6814,12 @@ wheels = [
[[package]]
name = "unstructured"
version = "0.16.25"
version = "0.18.31"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "backoff" },
{ name = "beautifulsoup4" },
{ name = "chardet" },
{ name = "charset-normalizer" },
{ name = "dataclasses-json" },
{ name = "emoji" },
{ name = "filetype" },
@ -6827,6 +6827,7 @@ dependencies = [
{ name = "langdetect" },
{ name = "lxml" },
{ name = "nltk" },
{ name = "numba" },
{ name = "numpy" },
{ name = "psutil" },
{ name = "python-iso639" },
@ -6839,9 +6840,9 @@ dependencies = [
{ name = "unstructured-client" },
{ name = "wrapt" },
]
sdist = { url = "https://files.pythonhosted.org/packages/64/31/98c4c78e305d1294888adf87fd5ee30577a4c393951341ca32b43f167f1e/unstructured-0.16.25.tar.gz", hash = "sha256:73b9b0f51dbb687af572ecdb849a6811710b9cac797ddeab8ee80fa07d8aa5e6", size = 1683097, upload-time = "2025-03-07T11:19:39.507Z" }
sdist = { url = "https://files.pythonhosted.org/packages/a9/5f/64285bd69a538bc28753f1423fcaa9d64cd79a9e7c097171b1f0d27e9cdb/unstructured-0.18.31.tar.gz", hash = "sha256:af4bbe32d1894ae6e755f0da6fc0dd307a1d0adeebe0e7cc6278f6cf744339ca", size = 1707700, upload-time = "2026-01-27T15:33:05.378Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/12/4f/ad08585b5c8a33c82ea119494c4d3023f4796958c56e668b15cc282ec0a0/unstructured-0.16.25-py3-none-any.whl", hash = "sha256:14719ccef2830216cf1c5bf654f75e2bf07b17ca5dcee9da5ac74618130fd337", size = 1769286, upload-time = "2025-03-07T11:19:37.299Z" },
{ url = "https://files.pythonhosted.org/packages/c8/4a/9c43f39d9e443c9bc3f2e379b305bca27110adc653b071221b3132c18de5/unstructured-0.18.31-py3-none-any.whl", hash = "sha256:fab4641176cb9b192ed38048758aa0d9843121d03626d18f42275afb31e5b2d3", size = 1794889, upload-time = "2026-01-27T15:33:03.136Z" },
]
[package.optional-dependencies]

View File

@ -21,7 +21,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.12.0
image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@ -63,7 +63,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.12.0
image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@ -102,7 +102,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.12.0
image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@ -132,7 +132,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.12.0
image: langgenius/dify-web:1.12.1
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -707,7 +707,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.12.0
image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@ -749,7 +749,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.12.0
image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@ -788,7 +788,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.12.0
image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@ -818,7 +818,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.12.0
image: langgenius/dify-web:1.12.1
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -109,6 +109,7 @@ const AgentTools: FC = () => {
tool_parameters: paramsWithDefaultValue,
notAuthor: !tool.is_team_authorization,
enabled: true,
type: tool.provider_type as CollectionType,
}
}
const handleSelectTool = (tool: ToolDefaultValue) => {

View File

@ -1,19 +1,27 @@
'use client'
import type { FileUpload } from '@/app/components/base/features/types'
import type { App } from '@/types/app'
import { useRef } from 'react'
import * as React from 'react'
import { useMemo, useRef } from 'react'
import { useTranslation } from 'react-i18next'
import Loading from '@/app/components/base/loading'
import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
import AppInputsForm from '@/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-form'
import { useAppInputsFormSchema } from '@/app/components/plugins/plugin-detail-panel/app-selector/hooks/use-app-inputs-form-schema'
import { BlockEnum, InputVarType, SupportUploadFileTypes } from '@/app/components/workflow/types'
import { useAppDetail } from '@/service/use-apps'
import { useFileUploadConfig } from '@/service/use-common'
import { useAppWorkflow } from '@/service/use-workflow'
import { AppModeEnum, Resolution } from '@/types/app'
import { cn } from '@/utils/classnames'
type Props = {
value?: {
app_id: string
inputs: Record<string, unknown>
inputs: Record<string, any>
}
appDetail: App
onFormChange: (value: Record<string, unknown>) => void
onFormChange: (value: Record<string, any>) => void
}
const AppInputsPanel = ({
@ -22,33 +30,155 @@ const AppInputsPanel = ({
onFormChange,
}: Props) => {
const { t } = useTranslation()
const inputsRef = useRef<Record<string, unknown>>(value?.inputs || {})
const inputsRef = useRef<any>(value?.inputs || {})
const isBasicApp = appDetail.mode !== AppModeEnum.ADVANCED_CHAT && appDetail.mode !== AppModeEnum.WORKFLOW
const { data: fileUploadConfig } = useFileUploadConfig()
const { data: currentApp, isFetching: isAppLoading } = useAppDetail(appDetail.id)
const { data: currentWorkflow, isFetching: isWorkflowLoading } = useAppWorkflow(isBasicApp ? '' : appDetail.id)
const isLoading = isAppLoading || isWorkflowLoading
const { inputFormSchema, isLoading } = useAppInputsFormSchema({ appDetail })
const basicAppFileConfig = useMemo(() => {
let fileConfig: FileUpload
if (isBasicApp)
fileConfig = currentApp?.model_config?.file_upload as FileUpload
else
fileConfig = currentWorkflow?.features?.file_upload as FileUpload
return {
image: {
detail: fileConfig?.image?.detail || Resolution.high,
enabled: !!fileConfig?.image?.enabled,
number_limits: fileConfig?.image?.number_limits || 3,
transfer_methods: fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'],
},
enabled: !!(fileConfig?.enabled || fileConfig?.image?.enabled),
allowed_file_types: fileConfig?.allowed_file_types || [SupportUploadFileTypes.image],
allowed_file_extensions: fileConfig?.allowed_file_extensions || [...FILE_EXTS[SupportUploadFileTypes.image]].map(ext => `.${ext}`),
allowed_file_upload_methods: fileConfig?.allowed_file_upload_methods || fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'],
number_limits: fileConfig?.number_limits || fileConfig?.image?.number_limits || 3,
}
}, [currentApp?.model_config?.file_upload, currentWorkflow?.features?.file_upload, isBasicApp])
const handleFormChange = (newValue: Record<string, unknown>) => {
inputsRef.current = newValue
onFormChange(newValue)
const inputFormSchema = useMemo(() => {
if (!currentApp)
return []
let inputFormSchema = []
if (isBasicApp) {
inputFormSchema = currentApp.model_config?.user_input_form?.filter((item: any) => !item.external_data_tool).map((item: any) => {
if (item.paragraph) {
return {
...item.paragraph,
type: 'paragraph',
required: false,
}
}
if (item.number) {
return {
...item.number,
type: 'number',
required: false,
}
}
if (item.checkbox) {
return {
...item.checkbox,
type: 'checkbox',
required: false,
}
}
if (item.select) {
return {
...item.select,
type: 'select',
required: false,
}
}
if (item['file-list']) {
return {
...item['file-list'],
type: 'file-list',
required: false,
fileUploadConfig,
}
}
if (item.file) {
return {
...item.file,
type: 'file',
required: false,
fileUploadConfig,
}
}
if (item.json_object) {
return {
...item.json_object,
type: 'json_object',
}
}
return {
...item['text-input'],
type: 'text-input',
required: false,
}
}) || []
}
else {
const startNode = currentWorkflow?.graph?.nodes.find(node => node.data.type === BlockEnum.Start) as any
inputFormSchema = startNode?.data.variables.map((variable: any) => {
if (variable.type === InputVarType.multiFiles) {
return {
...variable,
required: false,
fileUploadConfig,
}
}
if (variable.type === InputVarType.singleFile) {
return {
...variable,
required: false,
fileUploadConfig,
}
}
return {
...variable,
required: false,
}
}) || []
}
if ((currentApp.mode === AppModeEnum.COMPLETION || currentApp.mode === AppModeEnum.WORKFLOW) && basicAppFileConfig.enabled) {
inputFormSchema.push({
label: 'Image Upload',
variable: '#image#',
type: InputVarType.singleFile,
required: false,
...basicAppFileConfig,
fileUploadConfig,
})
}
return inputFormSchema || []
}, [basicAppFileConfig, currentApp, currentWorkflow, fileUploadConfig, isBasicApp])
const handleFormChange = (value: Record<string, any>) => {
inputsRef.current = value
onFormChange(value)
}
const hasInputs = inputFormSchema.length > 0
return (
<div className={cn('flex max-h-[240px] flex-col rounded-b-2xl border-t border-divider-subtle pb-4')}>
{isLoading && <div className="pt-3"><Loading type="app" /></div>}
{!isLoading && (
<div className="system-sm-semibold mb-2 mt-3 flex h-6 shrink-0 items-center px-4 text-text-secondary">
{t('appSelector.params', { ns: 'app' })}
</div>
<div className="system-sm-semibold mb-2 mt-3 flex h-6 shrink-0 items-center px-4 text-text-secondary">{t('appSelector.params', { ns: 'app' })}</div>
)}
{!isLoading && !hasInputs && (
{!isLoading && !inputFormSchema.length && (
<div className="flex h-16 flex-col items-center justify-center">
<div className="system-sm-regular text-text-tertiary">
{t('appSelector.noParams', { ns: 'app' })}
</div>
<div className="system-sm-regular text-text-tertiary">{t('appSelector.noParams', { ns: 'app' })}</div>
</div>
)}
{!isLoading && hasInputs && (
{!isLoading && !!inputFormSchema.length && (
<div className="grow overflow-y-auto">
<AppInputsForm
inputs={value?.inputs || {}}

View File

@ -1,211 +0,0 @@
'use client'
import type { FileUpload } from '@/app/components/base/features/types'
import type { FileUploadConfigResponse } from '@/models/common'
import type { App } from '@/types/app'
import type { FetchWorkflowDraftResponse } from '@/types/workflow'
import { useMemo } from 'react'
import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
import { BlockEnum, InputVarType, SupportUploadFileTypes } from '@/app/components/workflow/types'
import { useAppDetail } from '@/service/use-apps'
import { useFileUploadConfig } from '@/service/use-common'
import { useAppWorkflow } from '@/service/use-workflow'
import { AppModeEnum, Resolution } from '@/types/app'
const BASIC_INPUT_TYPE_MAP: Record<string, string> = {
'paragraph': 'paragraph',
'number': 'number',
'checkbox': 'checkbox',
'select': 'select',
'file-list': 'file-list',
'file': 'file',
'json_object': 'json_object',
}
const FILE_INPUT_TYPES = new Set(['file-list', 'file'])
const WORKFLOW_FILE_VAR_TYPES = new Set([InputVarType.multiFiles, InputVarType.singleFile])
type InputSchemaItem = {
label?: string
variable?: string
type: string
required: boolean
fileUploadConfig?: FileUploadConfigResponse
[key: string]: unknown
}
function isBasicAppMode(mode: string): boolean {
return mode !== AppModeEnum.ADVANCED_CHAT && mode !== AppModeEnum.WORKFLOW
}
function supportsImageUpload(mode: string): boolean {
return mode === AppModeEnum.COMPLETION || mode === AppModeEnum.WORKFLOW
}
function buildFileConfig(fileConfig: FileUpload | undefined) {
return {
image: {
detail: fileConfig?.image?.detail || Resolution.high,
enabled: !!fileConfig?.image?.enabled,
number_limits: fileConfig?.image?.number_limits || 3,
transfer_methods: fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'],
},
enabled: !!(fileConfig?.enabled || fileConfig?.image?.enabled),
allowed_file_types: fileConfig?.allowed_file_types || [SupportUploadFileTypes.image],
allowed_file_extensions: fileConfig?.allowed_file_extensions
|| [...FILE_EXTS[SupportUploadFileTypes.image]].map(ext => `.${ext}`),
allowed_file_upload_methods: fileConfig?.allowed_file_upload_methods
|| fileConfig?.image?.transfer_methods
|| ['local_file', 'remote_url'],
number_limits: fileConfig?.number_limits || fileConfig?.image?.number_limits || 3,
}
}
function mapBasicAppInputItem(
item: Record<string, unknown>,
fileUploadConfig?: FileUploadConfigResponse,
): InputSchemaItem | null {
for (const [key, type] of Object.entries(BASIC_INPUT_TYPE_MAP)) {
if (!item[key])
continue
const inputData = item[key] as Record<string, unknown>
const needsFileConfig = FILE_INPUT_TYPES.has(key)
return {
...inputData,
type,
required: false,
...(needsFileConfig && { fileUploadConfig }),
}
}
const textInput = item['text-input'] as Record<string, unknown> | undefined
if (!textInput)
return null
return {
...textInput,
type: 'text-input',
required: false,
}
}
function mapWorkflowVariable(
variable: Record<string, unknown>,
fileUploadConfig?: FileUploadConfigResponse,
): InputSchemaItem {
const needsFileConfig = WORKFLOW_FILE_VAR_TYPES.has(variable.type as InputVarType)
return {
...variable,
type: variable.type as string,
required: false,
...(needsFileConfig && { fileUploadConfig }),
}
}
function createImageUploadSchema(
basicFileConfig: ReturnType<typeof buildFileConfig>,
fileUploadConfig?: FileUploadConfigResponse,
): InputSchemaItem {
return {
label: 'Image Upload',
variable: '#image#',
type: InputVarType.singleFile,
required: false,
...basicFileConfig,
fileUploadConfig,
}
}
function buildBasicAppSchema(
currentApp: App,
fileUploadConfig?: FileUploadConfigResponse,
): InputSchemaItem[] {
const userInputForm = currentApp.model_config?.user_input_form as Array<Record<string, unknown>> | undefined
if (!userInputForm)
return []
return userInputForm
.filter((item: Record<string, unknown>) => !item.external_data_tool)
.map((item: Record<string, unknown>) => mapBasicAppInputItem(item, fileUploadConfig))
.filter((item): item is InputSchemaItem => item !== null)
}
function buildWorkflowSchema(
workflow: FetchWorkflowDraftResponse,
fileUploadConfig?: FileUploadConfigResponse,
): InputSchemaItem[] {
const startNode = workflow.graph?.nodes.find(
node => node.data.type === BlockEnum.Start,
) as { data: { variables: Array<Record<string, unknown>> } } | undefined
if (!startNode?.data.variables)
return []
return startNode.data.variables.map(
variable => mapWorkflowVariable(variable, fileUploadConfig),
)
}
type UseAppInputsFormSchemaParams = {
appDetail: App
}
type UseAppInputsFormSchemaResult = {
inputFormSchema: InputSchemaItem[]
isLoading: boolean
fileUploadConfig?: FileUploadConfigResponse
}
export function useAppInputsFormSchema({
appDetail,
}: UseAppInputsFormSchemaParams): UseAppInputsFormSchemaResult {
const isBasicApp = isBasicAppMode(appDetail.mode)
const { data: fileUploadConfig } = useFileUploadConfig()
const { data: currentApp, isFetching: isAppLoading } = useAppDetail(appDetail.id)
const { data: currentWorkflow, isFetching: isWorkflowLoading } = useAppWorkflow(
isBasicApp ? '' : appDetail.id,
)
const isLoading = isAppLoading || isWorkflowLoading
const inputFormSchema = useMemo(() => {
if (!currentApp)
return []
if (!isBasicApp && !currentWorkflow)
return []
// Build base schema based on app type
// Note: currentWorkflow is guaranteed to be defined here due to the early return above
const baseSchema = isBasicApp
? buildBasicAppSchema(currentApp, fileUploadConfig)
: buildWorkflowSchema(currentWorkflow!, fileUploadConfig)
if (!supportsImageUpload(currentApp.mode))
return baseSchema
const rawFileConfig = isBasicApp
? currentApp.model_config?.file_upload as FileUpload
: currentWorkflow?.features?.file_upload as FileUpload
const basicFileConfig = buildFileConfig(rawFileConfig)
if (!basicFileConfig.enabled)
return baseSchema
return [
...baseSchema,
createImageUploadSchema(basicFileConfig, fileUploadConfig),
]
}, [currentApp, currentWorkflow, fileUploadConfig, isBasicApp])
return {
inputFormSchema,
isLoading,
fileUploadConfig,
}
}

View File

@ -6,6 +6,7 @@ import Toast from '@/app/components/base/toast'
import { PluginSource } from '../types'
import DetailHeader from './detail-header'
// Use vi.hoisted for mock functions used in vi.mock factories
const {
mockSetShowUpdatePluginModal,
mockRefreshModelProviders,

View File

@ -1,2 +1,416 @@
// Re-export from refactored module for backward compatibility
export { default } from './detail-header/index'
import type { PluginDetail } from '../types'
import {
RiArrowLeftRightLine,
RiBugLine,
RiCloseLine,
RiHardDrive3Line,
} from '@remixicon/react'
import { useBoolean } from 'ahooks'
import * as React from 'react'
import { useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import ActionButton from '@/app/components/base/action-button'
import { trackEvent } from '@/app/components/base/amplitude'
import Badge from '@/app/components/base/badge'
import Button from '@/app/components/base/button'
import Confirm from '@/app/components/base/confirm'
import { Github } from '@/app/components/base/icons/src/public/common'
import { BoxSparkleFill } from '@/app/components/base/icons/src/vender/plugin'
import Toast from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
import { AuthCategory, PluginAuth } from '@/app/components/plugins/plugin-auth'
import OperationDropdown from '@/app/components/plugins/plugin-detail-panel/operation-dropdown'
import PluginInfo from '@/app/components/plugins/plugin-page/plugin-info'
import UpdateFromMarketplace from '@/app/components/plugins/update-plugin/from-market-place'
import PluginVersionPicker from '@/app/components/plugins/update-plugin/plugin-version-picker'
import { API_PREFIX } from '@/config'
import { useAppContext } from '@/context/app-context'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useGetLanguage, useLocale } from '@/context/i18n'
import { useModalContext } from '@/context/modal-context'
import { useProviderContext } from '@/context/provider-context'
import useTheme from '@/hooks/use-theme'
import { uninstallPlugin } from '@/service/plugins'
import { useAllToolProviders, useInvalidateAllToolProviders } from '@/service/use-tools'
import { cn } from '@/utils/classnames'
import { getMarketplaceUrl } from '@/utils/var'
import { AutoUpdateLine } from '../../base/icons/src/vender/system'
import Verified from '../base/badges/verified'
import DeprecationNotice from '../base/deprecation-notice'
import Icon from '../card/base/card-icon'
import Description from '../card/base/description'
import OrgInfo from '../card/base/org-info'
import Title from '../card/base/title'
import { useGitHubReleases } from '../install-plugin/hooks'
import useReferenceSetting from '../plugin-page/use-reference-setting'
import { AUTO_UPDATE_MODE } from '../reference-setting-modal/auto-update-setting/types'
import { convertUTCDaySecondsToLocalSeconds, timeOfDayToDayjs } from '../reference-setting-modal/auto-update-setting/utils'
import { PluginCategoryEnum, PluginSource } from '../types'
const i18nPrefix = 'action'
type Props = {
detail: PluginDetail
isReadmeView?: boolean
onHide?: () => void
onUpdate?: (isDelete?: boolean) => void
}
const DetailHeader = ({
detail,
isReadmeView = false,
onHide,
onUpdate,
}: Props) => {
const { t } = useTranslation()
const { userProfile: { timezone } } = useAppContext()
const { theme } = useTheme()
const locale = useGetLanguage()
const currentLocale = useLocale()
const { checkForUpdates, fetchReleases } = useGitHubReleases()
const { setShowUpdatePluginModal } = useModalContext()
const { refreshModelProviders } = useProviderContext()
const invalidateAllToolProviders = useInvalidateAllToolProviders()
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
const {
id,
source,
tenant_id,
version,
latest_unique_identifier,
latest_version,
meta,
plugin_id,
status,
deprecated_reason,
alternative_plugin_id,
} = detail
const { author, category, name, label, description, icon, icon_dark, verified, tool } = detail.declaration || detail
const isTool = category === PluginCategoryEnum.tool
const providerBriefInfo = tool?.identity
const providerKey = `${plugin_id}/${providerBriefInfo?.name}`
const { data: collectionList = [] } = useAllToolProviders(isTool)
const provider = useMemo(() => {
return collectionList.find(collection => collection.name === providerKey)
}, [collectionList, providerKey])
const isFromGitHub = source === PluginSource.github
const isFromMarketplace = source === PluginSource.marketplace
const [isShow, setIsShow] = useState(false)
const [targetVersion, setTargetVersion] = useState({
version: latest_version,
unique_identifier: latest_unique_identifier,
})
const hasNewVersion = useMemo(() => {
if (isFromMarketplace)
return !!latest_version && latest_version !== version
return false
}, [isFromMarketplace, latest_version, version])
const iconFileName = theme === 'dark' && icon_dark ? icon_dark : icon
const iconSrc = iconFileName
? (iconFileName.startsWith('http') ? iconFileName : `${API_PREFIX}/workspaces/current/plugin/icon?tenant_id=${tenant_id}&filename=${iconFileName}`)
: ''
const detailUrl = useMemo(() => {
if (isFromGitHub)
return `https://github.com/${meta!.repo}`
if (isFromMarketplace)
return getMarketplaceUrl(`/plugins/${author}/${name}`, { language: currentLocale, theme })
return ''
}, [author, currentLocale, isFromGitHub, isFromMarketplace, meta, name, theme])
const [isShowUpdateModal, {
setTrue: showUpdateModal,
setFalse: hideUpdateModal,
}] = useBoolean(false)
const { referenceSetting } = useReferenceSetting()
const { auto_upgrade: autoUpgradeInfo } = referenceSetting || {}
const isAutoUpgradeEnabled = useMemo(() => {
if (!enable_marketplace)
return false
if (!autoUpgradeInfo || !isFromMarketplace)
return false
if (autoUpgradeInfo.strategy_setting === 'disabled')
return false
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.update_all)
return true
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.partial && autoUpgradeInfo.include_plugins.includes(plugin_id))
return true
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.exclude && !autoUpgradeInfo.exclude_plugins.includes(plugin_id))
return true
return false
}, [autoUpgradeInfo, enable_marketplace, isFromMarketplace, plugin_id])
const [isDowngrade, setIsDowngrade] = useState(false)
const handleUpdate = async (isDowngrade?: boolean) => {
if (isFromMarketplace) {
setIsDowngrade(!!isDowngrade)
showUpdateModal()
return
}
const owner = meta!.repo.split('/')[0] || author
const repo = meta!.repo.split('/')[1] || name
const fetchedReleases = await fetchReleases(owner, repo)
if (fetchedReleases.length === 0)
return
const { needUpdate, toastProps } = checkForUpdates(fetchedReleases, meta!.version)
Toast.notify(toastProps)
if (needUpdate) {
setShowUpdatePluginModal({
onSaveCallback: () => {
onUpdate?.()
},
payload: {
type: PluginSource.github,
category: detail.declaration.category,
github: {
originalPackageInfo: {
id: detail.plugin_unique_identifier,
repo: meta!.repo,
version: meta!.version,
package: meta!.package,
releases: fetchedReleases,
},
},
},
})
}
}
const handleUpdatedFromMarketplace = () => {
onUpdate?.()
hideUpdateModal()
}
const [isShowPluginInfo, {
setTrue: showPluginInfo,
setFalse: hidePluginInfo,
}] = useBoolean(false)
const [isShowDeleteConfirm, {
setTrue: showDeleteConfirm,
setFalse: hideDeleteConfirm,
}] = useBoolean(false)
const [deleting, {
setTrue: showDeleting,
setFalse: hideDeleting,
}] = useBoolean(false)
const handleDelete = useCallback(async () => {
showDeleting()
const res = await uninstallPlugin(id)
hideDeleting()
if (res.success) {
hideDeleteConfirm()
onUpdate?.(true)
if (PluginCategoryEnum.model.includes(category))
refreshModelProviders()
if (PluginCategoryEnum.tool.includes(category))
invalidateAllToolProviders()
trackEvent('plugin_uninstalled', { plugin_id, plugin_name: name })
}
}, [showDeleting, id, hideDeleting, hideDeleteConfirm, onUpdate, category, refreshModelProviders, invalidateAllToolProviders, plugin_id, name])
return (
<div className={cn('shrink-0 border-b border-divider-subtle bg-components-panel-bg p-4 pb-3', isReadmeView && 'border-b-0 bg-transparent p-0')}>
<div className="flex">
<div className={cn('overflow-hidden rounded-xl border border-components-panel-border-subtle', isReadmeView && 'bg-components-panel-bg')}>
<Icon src={iconSrc} />
</div>
<div className="ml-3 w-0 grow">
<div className="flex h-5 items-center">
<Title title={label[locale]} />
{verified && !isReadmeView && <Verified className="ml-0.5 h-4 w-4" text={t('marketplace.verifiedTip', { ns: 'plugin' })} />}
{!!version && (
<PluginVersionPicker
disabled={!isFromMarketplace || isReadmeView}
isShow={isShow}
onShowChange={setIsShow}
pluginID={plugin_id}
currentVersion={version}
onSelect={(state) => {
setTargetVersion(state)
handleUpdate(state.isDowngrade)
}}
trigger={(
<Badge
className={cn(
'mx-1',
isShow && 'bg-state-base-hover',
(isShow || isFromMarketplace) && 'hover:bg-state-base-hover',
)}
uppercase={false}
text={(
<>
<div>{isFromGitHub ? meta!.version : version}</div>
{isFromMarketplace && !isReadmeView && <RiArrowLeftRightLine className="ml-1 h-3 w-3 text-text-tertiary" />}
</>
)}
hasRedCornerMark={hasNewVersion}
/>
)}
/>
)}
{/* Auto update info */}
{isAutoUpgradeEnabled && !isReadmeView && (
<Tooltip popupContent={t('autoUpdate.nextUpdateTime', { ns: 'plugin', time: timeOfDayToDayjs(convertUTCDaySecondsToLocalSeconds(autoUpgradeInfo?.upgrade_time_of_day || 0, timezone!)).format('hh:mm A') })}>
{/* wrap in a div to fix the tooltip not showing on hover */}
<div>
<Badge className="mr-1 cursor-pointer px-1">
<AutoUpdateLine className="size-3" />
</Badge>
</div>
</Tooltip>
)}
{(hasNewVersion || isFromGitHub) && (
<Button
variant="secondary-accent"
size="small"
className="!h-5"
onClick={() => {
if (isFromMarketplace) {
setTargetVersion({
version: latest_version,
unique_identifier: latest_unique_identifier,
})
}
handleUpdate()
}}
>
{t('detailPanel.operation.update', { ns: 'plugin' })}
</Button>
)}
</div>
<div className="mb-1 flex h-4 items-center justify-between">
<div className="mt-0.5 flex items-center">
<OrgInfo
packageNameClassName="w-auto"
orgName={author}
packageName={name?.includes('/') ? (name.split('/').pop() || '') : name}
/>
{!!source && (
<>
<div className="system-xs-regular ml-1 mr-0.5 text-text-quaternary">·</div>
{source === PluginSource.marketplace && (
<Tooltip popupContent={t('detailPanel.categoryTip.marketplace', { ns: 'plugin' })}>
<div><BoxSparkleFill className="h-3.5 w-3.5 text-text-tertiary hover:text-text-accent" /></div>
</Tooltip>
)}
{source === PluginSource.github && (
<Tooltip popupContent={t('detailPanel.categoryTip.github', { ns: 'plugin' })}>
<div><Github className="h-3.5 w-3.5 text-text-secondary hover:text-text-primary" /></div>
</Tooltip>
)}
{source === PluginSource.local && (
<Tooltip popupContent={t('detailPanel.categoryTip.local', { ns: 'plugin' })}>
<div><RiHardDrive3Line className="h-3.5 w-3.5 text-text-tertiary" /></div>
</Tooltip>
)}
{source === PluginSource.debugging && (
<Tooltip popupContent={t('detailPanel.categoryTip.debugging', { ns: 'plugin' })}>
<div><RiBugLine className="h-3.5 w-3.5 text-text-tertiary hover:text-text-warning" /></div>
</Tooltip>
)}
</>
)}
</div>
</div>
</div>
{!isReadmeView && (
<div className="flex gap-1">
<OperationDropdown
source={source}
onInfo={showPluginInfo}
onCheckVersion={handleUpdate}
onRemove={showDeleteConfirm}
detailUrl={detailUrl}
/>
<ActionButton onClick={onHide}>
<RiCloseLine className="h-4 w-4" />
</ActionButton>
</div>
)}
</div>
{isFromMarketplace && (
<DeprecationNotice
status={status}
deprecatedReason={deprecated_reason}
alternativePluginId={alternative_plugin_id}
alternativePluginURL={getMarketplaceUrl(`/plugins/${alternative_plugin_id}`, { language: currentLocale, theme })}
className="mt-3"
/>
)}
{!isReadmeView && <Description className="mb-2 mt-3 h-auto" text={description[locale]} descriptionLineRows={2}></Description>}
{
category === PluginCategoryEnum.tool && !isReadmeView && (
<PluginAuth
pluginPayload={{
provider: provider?.name || '',
category: AuthCategory.tool,
providerType: provider?.type || '',
detail,
}}
/>
)
}
{isShowPluginInfo && (
<PluginInfo
repository={isFromGitHub ? meta?.repo : ''}
release={version}
packageName={meta?.package || ''}
onHide={hidePluginInfo}
/>
)}
{isShowDeleteConfirm && (
<Confirm
isShow
title={t(`${i18nPrefix}.delete`, { ns: 'plugin' })}
content={(
<div>
{t(`${i18nPrefix}.deleteContentLeft`, { ns: 'plugin' })}
<span className="system-md-semibold">{label[locale]}</span>
{t(`${i18nPrefix}.deleteContentRight`, { ns: 'plugin' })}
<br />
</div>
)}
onCancel={hideDeleteConfirm}
onConfirm={handleDelete}
isLoading={deleting}
isDisabled={deleting}
/>
)}
{
isShowUpdateModal && (
<UpdateFromMarketplace
pluginId={plugin_id}
payload={{
category: detail.declaration.category,
originalPackageInfo: {
id: detail.plugin_unique_identifier,
payload: detail.declaration,
},
targetPackageInfo: {
id: targetVersion.unique_identifier,
version: targetVersion.version,
},
}}
onCancel={hideUpdateModal}
onSave={handleUpdatedFromMarketplace}
isShowDowngradeWarningModal={isDowngrade && isAutoUpgradeEnabled}
/>
)
}
</div>
)
}
export default DetailHeader

View File

@ -1,539 +0,0 @@
import type { PluginDetail } from '../../../types'
import type { ModalStates, VersionTarget } from '../hooks'
import { fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { PluginSource } from '../../../types'
import HeaderModals from './header-modals'
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
vi.mock('@/context/i18n', () => ({
useGetLanguage: () => 'en_US',
}))
vi.mock('@/app/components/base/confirm', () => ({
default: ({ isShow, title, onCancel, onConfirm, isLoading }: {
isShow: boolean
title: string
onCancel: () => void
onConfirm: () => void
isLoading: boolean
}) => isShow
? (
<div data-testid="delete-confirm">
<div data-testid="delete-title">{title}</div>
<button data-testid="confirm-cancel" onClick={onCancel}>Cancel</button>
<button data-testid="confirm-ok" onClick={onConfirm} disabled={isLoading}>Confirm</button>
</div>
)
: null,
}))
vi.mock('@/app/components/plugins/plugin-page/plugin-info', () => ({
default: ({ repository, release, packageName, onHide }: {
repository: string
release: string
packageName: string
onHide: () => void
}) => (
<div data-testid="plugin-info">
<div data-testid="plugin-info-repo">{repository}</div>
<div data-testid="plugin-info-release">{release}</div>
<div data-testid="plugin-info-package">{packageName}</div>
<button data-testid="plugin-info-close" onClick={onHide}>Close</button>
</div>
),
}))
vi.mock('@/app/components/plugins/update-plugin/from-market-place', () => ({
default: ({ pluginId, onSave, onCancel, isShowDowngradeWarningModal }: {
pluginId: string
onSave: () => void
onCancel: () => void
isShowDowngradeWarningModal: boolean
}) => (
<div data-testid="update-modal">
<div data-testid="update-plugin-id">{pluginId}</div>
<div data-testid="update-downgrade-warning">{String(isShowDowngradeWarningModal)}</div>
<button data-testid="update-modal-save" onClick={onSave}>Save</button>
<button data-testid="update-modal-cancel" onClick={onCancel}>Cancel</button>
</div>
),
}))
const createPluginDetail = (overrides: Partial<PluginDetail> = {}): PluginDetail => ({
id: 'test-id',
created_at: '2024-01-01',
updated_at: '2024-01-02',
name: 'Test Plugin',
plugin_id: 'test-plugin',
plugin_unique_identifier: 'test-uid',
declaration: {
author: 'test-author',
name: 'test-plugin-name',
category: 'tool',
label: { en_US: 'Test Plugin Label' },
description: { en_US: 'Test description' },
icon: 'icon.png',
verified: true,
} as unknown as PluginDetail['declaration'],
installation_id: 'install-1',
tenant_id: 'tenant-1',
endpoints_setups: 0,
endpoints_active: 0,
version: '1.0.0',
latest_version: '2.0.0',
latest_unique_identifier: 'new-uid',
source: PluginSource.marketplace,
meta: undefined,
status: 'active',
deprecated_reason: '',
alternative_plugin_id: '',
...overrides,
})
const createModalStatesMock = (overrides: Partial<ModalStates> = {}): ModalStates => ({
isShowUpdateModal: false,
showUpdateModal: vi.fn<() => void>(),
hideUpdateModal: vi.fn<() => void>(),
isShowPluginInfo: false,
showPluginInfo: vi.fn<() => void>(),
hidePluginInfo: vi.fn<() => void>(),
isShowDeleteConfirm: false,
showDeleteConfirm: vi.fn<() => void>(),
hideDeleteConfirm: vi.fn<() => void>(),
deleting: false,
showDeleting: vi.fn<() => void>(),
hideDeleting: vi.fn<() => void>(),
...overrides,
})
const createTargetVersion = (overrides: Partial<VersionTarget> = {}): VersionTarget => ({
version: '2.0.0',
unique_identifier: 'new-uid',
...overrides,
})
describe('HeaderModals', () => {
let mockOnUpdatedFromMarketplace: () => void
let mockOnDelete: () => void
beforeEach(() => {
vi.clearAllMocks()
mockOnUpdatedFromMarketplace = vi.fn<() => void>()
mockOnDelete = vi.fn<() => void>()
})
describe('Plugin Info Modal', () => {
it('should not render plugin info modal when isShowPluginInfo is false', () => {
const modalStates = createModalStatesMock({ isShowPluginInfo: false })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.queryByTestId('plugin-info')).not.toBeInTheDocument()
})
it('should render plugin info modal when isShowPluginInfo is true', () => {
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('plugin-info')).toBeInTheDocument()
})
it('should pass GitHub repo to plugin info for GitHub source', () => {
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
const detail = createPluginDetail({
source: PluginSource.github,
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'test-pkg' },
})
render(
<HeaderModals
detail={detail}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('plugin-info-repo')).toHaveTextContent('owner/repo')
})
it('should pass empty string for repo for non-GitHub source', () => {
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
render(
<HeaderModals
detail={createPluginDetail({ source: PluginSource.marketplace })}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('plugin-info-repo')).toHaveTextContent('')
})
it('should call hidePluginInfo when close button is clicked', () => {
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
fireEvent.click(screen.getByTestId('plugin-info-close'))
expect(modalStates.hidePluginInfo).toHaveBeenCalled()
})
})
describe('Delete Confirm Modal', () => {
it('should not render delete confirm when isShowDeleteConfirm is false', () => {
const modalStates = createModalStatesMock({ isShowDeleteConfirm: false })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.queryByTestId('delete-confirm')).not.toBeInTheDocument()
})
it('should render delete confirm when isShowDeleteConfirm is true', () => {
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
})
it('should show correct delete title', () => {
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('delete-title')).toHaveTextContent('action.delete')
})
it('should call hideDeleteConfirm when cancel is clicked', () => {
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
fireEvent.click(screen.getByTestId('confirm-cancel'))
expect(modalStates.hideDeleteConfirm).toHaveBeenCalled()
})
it('should call onDelete when confirm is clicked', () => {
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
fireEvent.click(screen.getByTestId('confirm-ok'))
expect(mockOnDelete).toHaveBeenCalled()
})
it('should disable confirm button when deleting', () => {
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true, deleting: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('confirm-ok')).toBeDisabled()
})
})
describe('Update Modal', () => {
it('should not render update modal when isShowUpdateModal is false', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: false })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.queryByTestId('update-modal')).not.toBeInTheDocument()
})
it('should render update modal when isShowUpdateModal is true', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('update-modal')).toBeInTheDocument()
})
it('should pass plugin id to update modal', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
render(
<HeaderModals
detail={createPluginDetail({ plugin_id: 'my-plugin-id' })}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('update-plugin-id')).toHaveTextContent('my-plugin-id')
})
it('should call onUpdatedFromMarketplace when save is clicked', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
fireEvent.click(screen.getByTestId('update-modal-save'))
expect(mockOnUpdatedFromMarketplace).toHaveBeenCalled()
})
it('should call hideUpdateModal when cancel is clicked', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
fireEvent.click(screen.getByTestId('update-modal-cancel'))
expect(modalStates.hideUpdateModal).toHaveBeenCalled()
})
it('should show downgrade warning when isDowngrade and isAutoUpgradeEnabled are true', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={true}
isAutoUpgradeEnabled={true}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('true')
})
it('should not show downgrade warning when only isDowngrade is true', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={true}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('false')
})
it('should not show downgrade warning when only isAutoUpgradeEnabled is true', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={true}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('false')
})
})
describe('Multiple Modals', () => {
it('should render multiple modals when multiple are open', () => {
const modalStates = createModalStatesMock({
isShowPluginInfo: true,
isShowDeleteConfirm: true,
isShowUpdateModal: true,
})
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('plugin-info')).toBeInTheDocument()
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
expect(screen.getByTestId('update-modal')).toBeInTheDocument()
})
})
describe('Edge Cases', () => {
it('should handle undefined target version values', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={{ version: undefined, unique_identifier: undefined }}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('update-modal')).toBeInTheDocument()
})
it('should handle empty meta for GitHub source', () => {
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
const detail = createPluginDetail({
source: PluginSource.github,
meta: undefined,
})
render(
<HeaderModals
detail={detail}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
// With no meta, both fields should render blank; jest-dom rejects
// toHaveTextContent('') since it always matches
expect(screen.getByTestId('plugin-info-repo')).toBeEmptyDOMElement()
expect(screen.getByTestId('plugin-info-package')).toBeEmptyDOMElement()
})
})
})

View File

@ -1,107 +0,0 @@
'use client'
import type { FC } from 'react'
import type { PluginDetail } from '../../../types'
import type { ModalStates, VersionTarget } from '../hooks'
import { useTranslation } from 'react-i18next'
import Confirm from '@/app/components/base/confirm'
import PluginInfo from '@/app/components/plugins/plugin-page/plugin-info'
import UpdateFromMarketplace from '@/app/components/plugins/update-plugin/from-market-place'
import { useGetLanguage } from '@/context/i18n'
import { PluginSource } from '../../../types'
const i18nPrefix = 'action'
type HeaderModalsProps = {
detail: PluginDetail
modalStates: ModalStates
targetVersion: VersionTarget
isDowngrade: boolean
isAutoUpgradeEnabled: boolean
onUpdatedFromMarketplace: () => void
onDelete: () => void
}
const HeaderModals: FC<HeaderModalsProps> = ({
detail,
modalStates,
targetVersion,
isDowngrade,
isAutoUpgradeEnabled,
onUpdatedFromMarketplace,
onDelete,
}) => {
const { t } = useTranslation()
const locale = useGetLanguage()
const { source, version, meta } = detail
const { label } = detail.declaration || detail
const isFromGitHub = source === PluginSource.github
const {
isShowUpdateModal,
hideUpdateModal,
isShowPluginInfo,
hidePluginInfo,
isShowDeleteConfirm,
hideDeleteConfirm,
deleting,
} = modalStates
return (
<>
{/* Plugin Info Modal */}
{isShowPluginInfo && (
<PluginInfo
repository={isFromGitHub ? meta?.repo : ''}
release={version}
packageName={meta?.package || ''}
onHide={hidePluginInfo}
/>
)}
{/* Delete Confirm Modal */}
{isShowDeleteConfirm && (
<Confirm
isShow
title={t(`${i18nPrefix}.delete`, { ns: 'plugin' })}
content={(
<div>
{t(`${i18nPrefix}.deleteContentLeft`, { ns: 'plugin' })}
<span className="system-md-semibold">{label[locale]}</span>
{t(`${i18nPrefix}.deleteContentRight`, { ns: 'plugin' })}
<br />
</div>
)}
onCancel={hideDeleteConfirm}
onConfirm={onDelete}
isLoading={deleting}
isDisabled={deleting}
/>
)}
{/* Update from Marketplace Modal */}
{isShowUpdateModal && (
<UpdateFromMarketplace
pluginId={detail.plugin_id}
payload={{
category: detail.declaration?.category ?? '',
originalPackageInfo: {
id: detail.plugin_unique_identifier,
payload: detail.declaration ?? undefined,
},
targetPackageInfo: {
id: targetVersion.unique_identifier || '',
version: targetVersion.version || '',
},
}}
onCancel={hideUpdateModal}
onSave={onUpdatedFromMarketplace}
isShowDowngradeWarningModal={isDowngrade && isAutoUpgradeEnabled}
/>
)}
</>
)
}
export default HeaderModals
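
All three modals above share a mount-on-demand shape: the boolean gates mounting rather than visibility, so a closed modal contributes nothing to the tree and reopens with fresh internal state. A minimal sketch of the pattern in isolation (component and prop names here are illustrative, not from the codebase):

import type { FC } from 'react'

type GatedProps = {
  isShow: boolean
  onHide: () => void
}

// Mounted only while shown; unmounting discards transient modal state,
// so every open starts from a clean slate.
const GatedModal: FC<GatedProps> = ({ isShow, onHide }) => {
  if (!isShow)
    return null
  return (
    <div role="dialog">
      <button onClick={onHide}>Close</button>
    </div>
  )
}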

View File

@ -1,2 +0,0 @@
export { default as HeaderModals } from './header-modals'
export { default as PluginSourceBadge } from './plugin-source-badge'

View File

@ -1,200 +0,0 @@
import { render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { PluginSource } from '../../../types'
import PluginSourceBadge from './plugin-source-badge'
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
vi.mock('@/app/components/base/tooltip', () => ({
default: ({ children, popupContent }: { children: React.ReactNode, popupContent: string }) => (
<div data-testid="tooltip" data-content={popupContent}>
{children}
</div>
),
}))
describe('PluginSourceBadge', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('Source Icon Rendering', () => {
it('should render marketplace source badge', () => {
render(<PluginSourceBadge source={PluginSource.marketplace} />)
const tooltip = screen.getByTestId('tooltip')
expect(tooltip).toBeInTheDocument()
expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.marketplace')
})
it('should render github source badge', () => {
render(<PluginSourceBadge source={PluginSource.github} />)
const tooltip = screen.getByTestId('tooltip')
expect(tooltip).toBeInTheDocument()
expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.github')
})
it('should render local source badge', () => {
render(<PluginSourceBadge source={PluginSource.local} />)
const tooltip = screen.getByTestId('tooltip')
expect(tooltip).toBeInTheDocument()
expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.local')
})
it('should render debugging source badge', () => {
render(<PluginSourceBadge source={PluginSource.debugging} />)
const tooltip = screen.getByTestId('tooltip')
expect(tooltip).toBeInTheDocument()
expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.debugging')
})
})
describe('Separator Rendering', () => {
it('should render separator dot before marketplace badge', () => {
const { container } = render(<PluginSourceBadge source={PluginSource.marketplace} />)
const separator = container.querySelector('.text-text-quaternary')
expect(separator).toBeInTheDocument()
expect(separator?.textContent).toBe('·')
})
it('should render separator dot before github badge', () => {
const { container } = render(<PluginSourceBadge source={PluginSource.github} />)
const separator = container.querySelector('.text-text-quaternary')
expect(separator).toBeInTheDocument()
expect(separator?.textContent).toBe('·')
})
it('should render separator dot before local badge', () => {
const { container } = render(<PluginSourceBadge source={PluginSource.local} />)
const separator = container.querySelector('.text-text-quaternary')
expect(separator).toBeInTheDocument()
})
it('should render separator dot before debugging badge', () => {
const { container } = render(<PluginSourceBadge source={PluginSource.debugging} />)
const separator = container.querySelector('.text-text-quaternary')
expect(separator).toBeInTheDocument()
})
})
describe('Tooltip Content', () => {
it('should show marketplace tooltip', () => {
render(<PluginSourceBadge source={PluginSource.marketplace} />)
expect(screen.getByTestId('tooltip')).toHaveAttribute(
'data-content',
'detailPanel.categoryTip.marketplace',
)
})
it('should show github tooltip', () => {
render(<PluginSourceBadge source={PluginSource.github} />)
expect(screen.getByTestId('tooltip')).toHaveAttribute(
'data-content',
'detailPanel.categoryTip.github',
)
})
it('should show local tooltip', () => {
render(<PluginSourceBadge source={PluginSource.local} />)
expect(screen.getByTestId('tooltip')).toHaveAttribute(
'data-content',
'detailPanel.categoryTip.local',
)
})
it('should show debugging tooltip', () => {
render(<PluginSourceBadge source={PluginSource.debugging} />)
expect(screen.getByTestId('tooltip')).toHaveAttribute(
'data-content',
'detailPanel.categoryTip.debugging',
)
})
})
describe('Icon Element Structure', () => {
it('should render icon inside tooltip for marketplace', () => {
render(<PluginSourceBadge source={PluginSource.marketplace} />)
const tooltip = screen.getByTestId('tooltip')
const iconWrapper = tooltip.querySelector('div')
expect(iconWrapper).toBeInTheDocument()
})
it('should render icon inside tooltip for github', () => {
render(<PluginSourceBadge source={PluginSource.github} />)
const tooltip = screen.getByTestId('tooltip')
const iconWrapper = tooltip.querySelector('div')
expect(iconWrapper).toBeInTheDocument()
})
it('should render icon inside tooltip for local', () => {
render(<PluginSourceBadge source={PluginSource.local} />)
const tooltip = screen.getByTestId('tooltip')
const iconWrapper = tooltip.querySelector('div')
expect(iconWrapper).toBeInTheDocument()
})
it('should render icon inside tooltip for debugging', () => {
render(<PluginSourceBadge source={PluginSource.debugging} />)
const tooltip = screen.getByTestId('tooltip')
const iconWrapper = tooltip.querySelector('div')
expect(iconWrapper).toBeInTheDocument()
})
})
describe('Lookup Table Coverage', () => {
it('should handle all PluginSource enum values', () => {
const allSources = Object.values(PluginSource)
allSources.forEach((source) => {
// Every enum member must render without throwing, whether it maps
// to a badge or to nothing
expect(() => render(<PluginSourceBadge source={source} />)).not.toThrow()
})
})
})
describe('Invalid Source Handling', () => {
it('should return null for unknown source type', () => {
// Use type assertion to test invalid source value
const invalidSource = 'unknown_source' as PluginSource
const { container } = render(<PluginSourceBadge source={invalidSource} />)
// Should render nothing (empty container)
expect(container.firstChild).toBeNull()
})
it('should not render separator for invalid source', () => {
const invalidSource = 'invalid' as PluginSource
const { container } = render(<PluginSourceBadge source={invalidSource} />)
const separator = container.querySelector('.text-text-quaternary')
expect(separator).not.toBeInTheDocument()
})
it('should not render tooltip for invalid source', () => {
const invalidSource = '' as PluginSource
render(<PluginSourceBadge source={invalidSource} />)
expect(screen.queryByTestId('tooltip')).not.toBeInTheDocument()
})
})
})

View File

@ -1,59 +0,0 @@
'use client'
import type { FC, ReactNode } from 'react'
import {
RiBugLine,
RiHardDrive3Line,
} from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { Github } from '@/app/components/base/icons/src/public/common'
import { BoxSparkleFill } from '@/app/components/base/icons/src/vender/plugin'
import Tooltip from '@/app/components/base/tooltip'
import { PluginSource } from '../../../types'
type SourceConfig = {
icon: ReactNode
tipKey: string
}
type PluginSourceBadgeProps = {
source: PluginSource
}
const SOURCE_CONFIG_MAP: Record<PluginSource, SourceConfig | null> = {
[PluginSource.marketplace]: {
icon: <BoxSparkleFill className="h-3.5 w-3.5 text-text-tertiary hover:text-text-accent" />,
tipKey: 'detailPanel.categoryTip.marketplace',
},
[PluginSource.github]: {
icon: <Github className="h-3.5 w-3.5 text-text-secondary hover:text-text-primary" />,
tipKey: 'detailPanel.categoryTip.github',
},
[PluginSource.local]: {
icon: <RiHardDrive3Line className="h-3.5 w-3.5 text-text-tertiary" />,
tipKey: 'detailPanel.categoryTip.local',
},
[PluginSource.debugging]: {
icon: <RiBugLine className="h-3.5 w-3.5 text-text-tertiary hover:text-text-warning" />,
tipKey: 'detailPanel.categoryTip.debugging',
},
}
const PluginSourceBadge: FC<PluginSourceBadgeProps> = ({ source }) => {
const { t } = useTranslation()
const config = SOURCE_CONFIG_MAP[source]
if (!config)
return null
return (
<>
<div className="system-xs-regular ml-1 mr-0.5 text-text-quaternary">·</div>
<Tooltip popupContent={t(config.tipKey as never, { ns: 'plugin' })}>
<div>{config.icon}</div>
</Tooltip>
</>
)
}
export default PluginSourceBadge
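
A note on why the invalid-source tests pass the type checker: `Record<PluginSource, SourceConfig | null>` guarantees totality only at compile time. A value smuggled past the checker (a cast, a raw API payload) simply misses the table at runtime, the lookup yields `undefined`, and the `if (!config) return null` guard absorbs it. A detached sketch of the same behavior (types and names here are illustrative):

type Source = 'marketplace' | 'github' | 'local' | 'debugging'

const TIP_KEYS: Record<Source, string> = {
  marketplace: 'categoryTip.marketplace',
  github: 'categoryTip.github',
  local: 'categoryTip.local',
  debugging: 'categoryTip.debugging',
}

// The compile-time type says `string`, but an out-of-union value falls
// through to `undefined`, so a runtime guard is still required.
function tipKeyFor(source: Source): string | null {
  const key = TIP_KEYS[source] as string | undefined
  return key ?? null
}

tipKeyFor('github') // 'categoryTip.github'
tipKeyFor('bogus' as unknown as Source) // null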

View File

@ -1,3 +0,0 @@
export { useDetailHeaderState } from './use-detail-header-state'
export type { ModalStates, UseDetailHeaderStateReturn, VersionPickerState, VersionTarget } from './use-detail-header-state'
export { usePluginOperations } from './use-plugin-operations'

View File

@ -1,409 +0,0 @@
import type { PluginDetail } from '../../../types'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { PluginSource } from '../../../types'
import { useDetailHeaderState } from './use-detail-header-state'
let mockEnableMarketplace = true
vi.mock('@/context/global-public-context', () => ({
useGlobalPublicStore: (selector: (state: { systemFeatures: { enable_marketplace: boolean } }) => unknown) =>
selector({ systemFeatures: { enable_marketplace: mockEnableMarketplace } }),
}))
let mockAutoUpgradeInfo: {
strategy_setting: string
upgrade_mode: string
include_plugins: string[]
exclude_plugins: string[]
upgrade_time_of_day: number
} | null = null
vi.mock('../../../plugin-page/use-reference-setting', () => ({
default: () => ({
referenceSetting: mockAutoUpgradeInfo ? { auto_upgrade: mockAutoUpgradeInfo } : null,
}),
}))
vi.mock('../../../reference-setting-modal/auto-update-setting/types', () => ({
AUTO_UPDATE_MODE: {
update_all: 'update_all',
partial: 'partial',
exclude: 'exclude',
},
}))
const createPluginDetail = (overrides: Partial<PluginDetail> = {}): PluginDetail => ({
id: 'test-id',
created_at: '2024-01-01',
updated_at: '2024-01-02',
name: 'Test Plugin',
plugin_id: 'test-plugin',
plugin_unique_identifier: 'test-uid',
declaration: {
author: 'test-author',
name: 'test-plugin-name',
category: 'tool',
label: { en_US: 'Test Plugin Label' },
description: { en_US: 'Test description' },
icon: 'icon.png',
verified: true,
} as unknown as PluginDetail['declaration'],
installation_id: 'install-1',
tenant_id: 'tenant-1',
endpoints_setups: 0,
endpoints_active: 0,
version: '1.0.0',
latest_version: '1.0.0',
latest_unique_identifier: 'test-uid',
source: PluginSource.marketplace,
meta: undefined,
status: 'active',
deprecated_reason: '',
alternative_plugin_id: '',
...overrides,
})
describe('useDetailHeaderState', () => {
beforeEach(() => {
vi.clearAllMocks()
mockAutoUpgradeInfo = null
mockEnableMarketplace = true
})
describe('Source Type Detection', () => {
it('should detect marketplace source', () => {
const detail = createPluginDetail({ source: PluginSource.marketplace })
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.isFromMarketplace).toBe(true)
expect(result.current.isFromGitHub).toBe(false)
})
it('should detect GitHub source', () => {
const detail = createPluginDetail({
source: PluginSource.github,
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
})
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.isFromGitHub).toBe(true)
expect(result.current.isFromMarketplace).toBe(false)
})
it('should detect local source', () => {
const detail = createPluginDetail({ source: PluginSource.local })
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.isFromGitHub).toBe(false)
expect(result.current.isFromMarketplace).toBe(false)
})
})
describe('Version State', () => {
it('should detect new version available for marketplace plugin', () => {
const detail = createPluginDetail({
version: '1.0.0',
latest_version: '2.0.0',
source: PluginSource.marketplace,
})
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.hasNewVersion).toBe(true)
})
it('should not detect new version when versions match', () => {
const detail = createPluginDetail({
version: '1.0.0',
latest_version: '1.0.0',
source: PluginSource.marketplace,
})
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.hasNewVersion).toBe(false)
})
it('should not detect new version for non-marketplace source', () => {
const detail = createPluginDetail({
version: '1.0.0',
latest_version: '2.0.0',
source: PluginSource.github,
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
})
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.hasNewVersion).toBe(false)
})
it('should not detect new version when latest_version is empty', () => {
const detail = createPluginDetail({
version: '1.0.0',
latest_version: '',
source: PluginSource.marketplace,
})
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.hasNewVersion).toBe(false)
})
})
describe('Version Picker State', () => {
it('should initialize version picker as hidden', () => {
const detail = createPluginDetail()
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.versionPicker.isShow).toBe(false)
})
it('should toggle version picker visibility', () => {
const detail = createPluginDetail()
const { result } = renderHook(() => useDetailHeaderState(detail))
act(() => {
result.current.versionPicker.setIsShow(true)
})
expect(result.current.versionPicker.isShow).toBe(true)
act(() => {
result.current.versionPicker.setIsShow(false)
})
expect(result.current.versionPicker.isShow).toBe(false)
})
it('should update target version', () => {
const detail = createPluginDetail()
const { result } = renderHook(() => useDetailHeaderState(detail))
act(() => {
result.current.versionPicker.setTargetVersion({
version: '2.0.0',
unique_identifier: 'new-uid',
})
})
expect(result.current.versionPicker.targetVersion.version).toBe('2.0.0')
expect(result.current.versionPicker.targetVersion.unique_identifier).toBe('new-uid')
})
it('should set isDowngrade when provided in target version', () => {
const detail = createPluginDetail()
const { result } = renderHook(() => useDetailHeaderState(detail))
act(() => {
result.current.versionPicker.setTargetVersion({
version: '0.5.0',
unique_identifier: 'old-uid',
isDowngrade: true,
})
})
expect(result.current.versionPicker.isDowngrade).toBe(true)
})
})
describe('Modal States', () => {
it('should initialize all modals as hidden', () => {
const detail = createPluginDetail()
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.modalStates.isShowUpdateModal).toBe(false)
expect(result.current.modalStates.isShowPluginInfo).toBe(false)
expect(result.current.modalStates.isShowDeleteConfirm).toBe(false)
expect(result.current.modalStates.deleting).toBe(false)
})
it('should toggle update modal', () => {
const detail = createPluginDetail()
const { result } = renderHook(() => useDetailHeaderState(detail))
act(() => {
result.current.modalStates.showUpdateModal()
})
expect(result.current.modalStates.isShowUpdateModal).toBe(true)
act(() => {
result.current.modalStates.hideUpdateModal()
})
expect(result.current.modalStates.isShowUpdateModal).toBe(false)
})
it('should toggle plugin info modal', () => {
const detail = createPluginDetail()
const { result } = renderHook(() => useDetailHeaderState(detail))
act(() => {
result.current.modalStates.showPluginInfo()
})
expect(result.current.modalStates.isShowPluginInfo).toBe(true)
act(() => {
result.current.modalStates.hidePluginInfo()
})
expect(result.current.modalStates.isShowPluginInfo).toBe(false)
})
it('should toggle delete confirm modal', () => {
const detail = createPluginDetail()
const { result } = renderHook(() => useDetailHeaderState(detail))
act(() => {
result.current.modalStates.showDeleteConfirm()
})
expect(result.current.modalStates.isShowDeleteConfirm).toBe(true)
act(() => {
result.current.modalStates.hideDeleteConfirm()
})
expect(result.current.modalStates.isShowDeleteConfirm).toBe(false)
})
it('should toggle deleting state', () => {
const detail = createPluginDetail()
const { result } = renderHook(() => useDetailHeaderState(detail))
act(() => {
result.current.modalStates.showDeleting()
})
expect(result.current.modalStates.deleting).toBe(true)
act(() => {
result.current.modalStates.hideDeleting()
})
expect(result.current.modalStates.deleting).toBe(false)
})
})
describe('Auto Upgrade Detection', () => {
it('should disable auto upgrade when marketplace is disabled', () => {
mockEnableMarketplace = false
mockAutoUpgradeInfo = {
strategy_setting: 'enabled',
upgrade_mode: 'update_all',
include_plugins: [],
exclude_plugins: [],
upgrade_time_of_day: 36000,
}
const detail = createPluginDetail({ source: PluginSource.marketplace })
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.isAutoUpgradeEnabled).toBe(false)
})
it('should disable auto upgrade when strategy is disabled', () => {
mockAutoUpgradeInfo = {
strategy_setting: 'disabled',
upgrade_mode: 'update_all',
include_plugins: [],
exclude_plugins: [],
upgrade_time_of_day: 36000,
}
const detail = createPluginDetail({ source: PluginSource.marketplace })
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.isAutoUpgradeEnabled).toBe(false)
})
it('should enable auto upgrade for update_all mode', () => {
mockAutoUpgradeInfo = {
strategy_setting: 'enabled',
upgrade_mode: 'update_all',
include_plugins: [],
exclude_plugins: [],
upgrade_time_of_day: 36000,
}
const detail = createPluginDetail({ source: PluginSource.marketplace })
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.isAutoUpgradeEnabled).toBe(true)
})
it('should enable auto upgrade for partial mode when plugin is included', () => {
mockAutoUpgradeInfo = {
strategy_setting: 'enabled',
upgrade_mode: 'partial',
include_plugins: ['test-plugin'],
exclude_plugins: [],
upgrade_time_of_day: 36000,
}
const detail = createPluginDetail({ source: PluginSource.marketplace })
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.isAutoUpgradeEnabled).toBe(true)
})
it('should disable auto upgrade for partial mode when plugin is not included', () => {
mockAutoUpgradeInfo = {
strategy_setting: 'enabled',
upgrade_mode: 'partial',
include_plugins: ['other-plugin'],
exclude_plugins: [],
upgrade_time_of_day: 36000,
}
const detail = createPluginDetail({ source: PluginSource.marketplace })
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.isAutoUpgradeEnabled).toBe(false)
})
it('should enable auto upgrade for exclude mode when plugin is not excluded', () => {
mockAutoUpgradeInfo = {
strategy_setting: 'enabled',
upgrade_mode: 'exclude',
include_plugins: [],
exclude_plugins: ['other-plugin'],
upgrade_time_of_day: 36000,
}
const detail = createPluginDetail({ source: PluginSource.marketplace })
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.isAutoUpgradeEnabled).toBe(true)
})
it('should disable auto upgrade for exclude mode when plugin is excluded', () => {
mockAutoUpgradeInfo = {
strategy_setting: 'enabled',
upgrade_mode: 'exclude',
include_plugins: [],
exclude_plugins: ['test-plugin'],
upgrade_time_of_day: 36000,
}
const detail = createPluginDetail({ source: PluginSource.marketplace })
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.isAutoUpgradeEnabled).toBe(false)
})
it('should disable auto upgrade for non-marketplace source', () => {
mockAutoUpgradeInfo = {
strategy_setting: 'enabled',
upgrade_mode: 'update_all',
include_plugins: [],
exclude_plugins: [],
upgrade_time_of_day: 36000,
}
const detail = createPluginDetail({
source: PluginSource.github,
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
})
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.isAutoUpgradeEnabled).toBe(false)
})
it('should disable auto upgrade when no auto upgrade info', () => {
mockAutoUpgradeInfo = null
const detail = createPluginDetail({ source: PluginSource.marketplace })
const { result } = renderHook(() => useDetailHeaderState(detail))
expect(result.current.isAutoUpgradeEnabled).toBe(false)
})
})
})

View File

@ -1,132 +0,0 @@
'use client'
import type { PluginDetail } from '../../../types'
import { useBoolean } from 'ahooks'
import { useCallback, useMemo, useState } from 'react'
import { useGlobalPublicStore } from '@/context/global-public-context'
import useReferenceSetting from '../../../plugin-page/use-reference-setting'
import { AUTO_UPDATE_MODE } from '../../../reference-setting-modal/auto-update-setting/types'
import { PluginSource } from '../../../types'
export type VersionTarget = {
version: string | undefined
unique_identifier: string | undefined
isDowngrade?: boolean
}
export type ModalStates = {
isShowUpdateModal: boolean
showUpdateModal: () => void
hideUpdateModal: () => void
isShowPluginInfo: boolean
showPluginInfo: () => void
hidePluginInfo: () => void
isShowDeleteConfirm: boolean
showDeleteConfirm: () => void
hideDeleteConfirm: () => void
deleting: boolean
showDeleting: () => void
hideDeleting: () => void
}
export type VersionPickerState = {
isShow: boolean
setIsShow: (show: boolean) => void
targetVersion: VersionTarget
setTargetVersion: (version: VersionTarget) => void
isDowngrade: boolean
setIsDowngrade: (downgrade: boolean) => void
}
export type UseDetailHeaderStateReturn = {
modalStates: ModalStates
versionPicker: VersionPickerState
hasNewVersion: boolean
isAutoUpgradeEnabled: boolean
isFromGitHub: boolean
isFromMarketplace: boolean
}
export const useDetailHeaderState = (detail: PluginDetail): UseDetailHeaderStateReturn => {
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
const { referenceSetting } = useReferenceSetting()
const {
source,
version,
latest_version,
latest_unique_identifier,
plugin_id,
} = detail
const isFromGitHub = source === PluginSource.github
const isFromMarketplace = source === PluginSource.marketplace
const [isShow, setIsShow] = useState(false)
const [targetVersion, setTargetVersion] = useState<VersionTarget>({
version: latest_version,
unique_identifier: latest_unique_identifier,
})
const [isDowngrade, setIsDowngrade] = useState(false)
const [isShowUpdateModal, { setTrue: showUpdateModal, setFalse: hideUpdateModal }] = useBoolean(false)
const [isShowPluginInfo, { setTrue: showPluginInfo, setFalse: hidePluginInfo }] = useBoolean(false)
const [isShowDeleteConfirm, { setTrue: showDeleteConfirm, setFalse: hideDeleteConfirm }] = useBoolean(false)
const [deleting, { setTrue: showDeleting, setFalse: hideDeleting }] = useBoolean(false)
const hasNewVersion = useMemo(() => {
if (isFromMarketplace)
return !!latest_version && latest_version !== version
return false
}, [isFromMarketplace, latest_version, version])
const { auto_upgrade: autoUpgradeInfo } = referenceSetting || {}
const isAutoUpgradeEnabled = useMemo(() => {
if (!enable_marketplace || !autoUpgradeInfo || !isFromMarketplace)
return false
if (autoUpgradeInfo.strategy_setting === 'disabled')
return false
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.update_all)
return true
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.partial && autoUpgradeInfo.include_plugins.includes(plugin_id))
return true
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.exclude && !autoUpgradeInfo.exclude_plugins.includes(plugin_id))
return true
return false
}, [autoUpgradeInfo, plugin_id, isFromMarketplace, enable_marketplace])
const handleSetTargetVersion = useCallback((version: VersionTarget) => {
setTargetVersion(version)
if (version.isDowngrade !== undefined)
setIsDowngrade(version.isDowngrade)
}, [])
return {
modalStates: {
isShowUpdateModal,
showUpdateModal,
hideUpdateModal,
isShowPluginInfo,
showPluginInfo,
hidePluginInfo,
isShowDeleteConfirm,
showDeleteConfirm,
hideDeleteConfirm,
deleting,
showDeleting,
hideDeleting,
},
versionPicker: {
isShow,
setIsShow,
targetVersion,
setTargetVersion: handleSetTargetVersion,
isDowngrade,
setIsDowngrade,
},
hasNewVersion,
isAutoUpgradeEnabled,
isFromGitHub,
isFromMarketplace,
}
}
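
The eligibility check is a pure function of the workspace setting, the plugin id, and the source, which is exactly how the spec above exercises it mode by mode. A detached sketch that mirrors the `useMemo` (the standalone types and function name are illustrative, not part of the codebase):

type AutoUpgradeInfo = {
  strategy_setting: 'enabled' | 'disabled'
  upgrade_mode: 'update_all' | 'partial' | 'exclude'
  include_plugins: string[]
  exclude_plugins: string[]
}

// Mirrors the useMemo above: marketplace plugins only, gated on the
// global flag, then resolved per upgrade mode.
function autoUpgradeEnabled(
  marketplaceEnabled: boolean,
  isFromMarketplace: boolean,
  info: AutoUpgradeInfo | null,
  pluginId: string,
): boolean {
  if (!marketplaceEnabled || !info || !isFromMarketplace)
    return false
  if (info.strategy_setting === 'disabled')
    return false
  if (info.upgrade_mode === 'update_all')
    return true
  if (info.upgrade_mode === 'partial')
    return info.include_plugins.includes(pluginId)
  if (info.upgrade_mode === 'exclude')
    return !info.exclude_plugins.includes(pluginId)
  return false
}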

View File

@ -1,549 +0,0 @@
import type { PluginDetail } from '../../../types'
import type { ModalStates, VersionTarget } from './use-detail-header-state'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import * as amplitude from '@/app/components/base/amplitude'
import Toast from '@/app/components/base/toast'
import { PluginSource } from '../../../types'
import { usePluginOperations } from './use-plugin-operations'
type VersionPickerMock = {
setTargetVersion: (version: VersionTarget) => void
setIsDowngrade: (downgrade: boolean) => void
}
const {
mockSetShowUpdatePluginModal,
mockRefreshModelProviders,
mockInvalidateAllToolProviders,
mockUninstallPlugin,
mockFetchReleases,
mockCheckForUpdates,
} = vi.hoisted(() => {
return {
mockSetShowUpdatePluginModal: vi.fn(),
mockRefreshModelProviders: vi.fn(),
mockInvalidateAllToolProviders: vi.fn(),
mockUninstallPlugin: vi.fn(() => Promise.resolve({ success: true })),
mockFetchReleases: vi.fn(() => Promise.resolve([{ tag_name: 'v2.0.0' }])),
mockCheckForUpdates: vi.fn(() => ({ needUpdate: true, toastProps: { type: 'success', message: 'Update available' } })),
}
})
vi.mock('@/context/modal-context', () => ({
useModalContext: () => ({
setShowUpdatePluginModal: mockSetShowUpdatePluginModal,
}),
}))
vi.mock('@/context/provider-context', () => ({
useProviderContext: () => ({
refreshModelProviders: mockRefreshModelProviders,
}),
}))
vi.mock('@/service/plugins', () => ({
uninstallPlugin: mockUninstallPlugin,
}))
vi.mock('@/service/use-tools', () => ({
useInvalidateAllToolProviders: () => mockInvalidateAllToolProviders,
}))
vi.mock('../../../install-plugin/hooks', () => ({
useGitHubReleases: () => ({
checkForUpdates: mockCheckForUpdates,
fetchReleases: mockFetchReleases,
}),
}))
const createPluginDetail = (overrides: Partial<PluginDetail> = {}): PluginDetail => ({
id: 'test-id',
created_at: '2024-01-01',
updated_at: '2024-01-02',
name: 'Test Plugin',
plugin_id: 'test-plugin',
plugin_unique_identifier: 'test-uid',
declaration: {
author: 'test-author',
name: 'test-plugin-name',
category: 'tool',
label: { en_US: 'Test Plugin Label' },
description: { en_US: 'Test description' },
icon: 'icon.png',
verified: true,
} as unknown as PluginDetail['declaration'],
installation_id: 'install-1',
tenant_id: 'tenant-1',
endpoints_setups: 0,
endpoints_active: 0,
version: '1.0.0',
latest_version: '2.0.0',
latest_unique_identifier: 'new-uid',
source: PluginSource.marketplace,
meta: undefined,
status: 'active',
deprecated_reason: '',
alternative_plugin_id: '',
...overrides,
})
const createModalStatesMock = (): ModalStates => ({
isShowUpdateModal: false,
showUpdateModal: vi.fn(),
hideUpdateModal: vi.fn(),
isShowPluginInfo: false,
showPluginInfo: vi.fn(),
hidePluginInfo: vi.fn(),
isShowDeleteConfirm: false,
showDeleteConfirm: vi.fn(),
hideDeleteConfirm: vi.fn(),
deleting: false,
showDeleting: vi.fn(),
hideDeleting: vi.fn(),
})
const createVersionPickerMock = (): VersionPickerMock => ({
setTargetVersion: vi.fn<(version: VersionTarget) => void>(),
setIsDowngrade: vi.fn<(downgrade: boolean) => void>(),
})
describe('usePluginOperations', () => {
let modalStates: ModalStates
let versionPicker: VersionPickerMock
let mockOnUpdate: (isDelete?: boolean) => void
beforeEach(() => {
vi.clearAllMocks()
modalStates = createModalStatesMock()
versionPicker = createVersionPickerMock()
mockOnUpdate = vi.fn()
vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() }))
vi.spyOn(amplitude, 'trackEvent').mockImplementation(() => {})
})
describe('Marketplace Update Flow', () => {
it('should show update modal for marketplace plugin', async () => {
const detail = createPluginDetail({ source: PluginSource.marketplace })
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: true,
onUpdate: mockOnUpdate,
}),
)
await act(async () => {
await result.current.handleUpdate()
})
expect(modalStates.showUpdateModal).toHaveBeenCalled()
})
it('should set isDowngrade when downgrading', async () => {
const detail = createPluginDetail({ source: PluginSource.marketplace })
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: true,
onUpdate: mockOnUpdate,
}),
)
await act(async () => {
await result.current.handleUpdate(true)
})
expect(versionPicker.setIsDowngrade).toHaveBeenCalledWith(true)
expect(modalStates.showUpdateModal).toHaveBeenCalled()
})
it('should call onUpdate and hide modal on successful marketplace update', () => {
const detail = createPluginDetail({ source: PluginSource.marketplace })
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: true,
onUpdate: mockOnUpdate,
}),
)
act(() => {
result.current.handleUpdatedFromMarketplace()
})
expect(mockOnUpdate).toHaveBeenCalled()
expect(modalStates.hideUpdateModal).toHaveBeenCalled()
})
})
describe('GitHub Update Flow', () => {
it('should fetch releases from GitHub', async () => {
const detail = createPluginDetail({
source: PluginSource.github,
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
})
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: false,
onUpdate: mockOnUpdate,
}),
)
await act(async () => {
await result.current.handleUpdate()
})
expect(mockFetchReleases).toHaveBeenCalledWith('owner', 'repo')
})
it('should check for updates after fetching releases', async () => {
const detail = createPluginDetail({
source: PluginSource.github,
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
})
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: false,
onUpdate: mockOnUpdate,
}),
)
await act(async () => {
await result.current.handleUpdate()
})
expect(mockCheckForUpdates).toHaveBeenCalled()
expect(Toast.notify).toHaveBeenCalled()
})
it('should show update plugin modal when update is needed', async () => {
const detail = createPluginDetail({
source: PluginSource.github,
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
})
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: false,
onUpdate: mockOnUpdate,
}),
)
await act(async () => {
await result.current.handleUpdate()
})
expect(mockSetShowUpdatePluginModal).toHaveBeenCalled()
})
it('should not show modal when no releases found', async () => {
mockFetchReleases.mockResolvedValueOnce([])
const detail = createPluginDetail({
source: PluginSource.github,
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
})
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: false,
onUpdate: mockOnUpdate,
}),
)
await act(async () => {
await result.current.handleUpdate()
})
expect(mockSetShowUpdatePluginModal).not.toHaveBeenCalled()
})
it('should not show modal when no update needed', async () => {
mockCheckForUpdates.mockReturnValueOnce({
needUpdate: false,
toastProps: { type: 'info', message: 'Already up to date' },
})
const detail = createPluginDetail({
source: PluginSource.github,
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
})
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: false,
onUpdate: mockOnUpdate,
}),
)
await act(async () => {
await result.current.handleUpdate()
})
expect(mockSetShowUpdatePluginModal).not.toHaveBeenCalled()
})
it('should use author and name as fallback for repo parsing', async () => {
const detail = createPluginDetail({
source: PluginSource.github,
meta: { repo: '/', version: 'v1.0.0', package: 'pkg' },
declaration: {
author: 'fallback-author',
name: 'fallback-name',
category: 'tool',
label: { en_US: 'Test' },
description: { en_US: 'Test' },
icon: 'icon.png',
verified: true,
} as unknown as PluginDetail['declaration'],
})
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: false,
onUpdate: mockOnUpdate,
}),
)
await act(async () => {
await result.current.handleUpdate()
})
expect(mockFetchReleases).toHaveBeenCalledWith('fallback-author', 'fallback-name')
})
})
describe('Delete Flow', () => {
it('should call uninstallPlugin with correct id', async () => {
const detail = createPluginDetail({ id: 'plugin-to-delete' })
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: true,
onUpdate: mockOnUpdate,
}),
)
await act(async () => {
await result.current.handleDelete()
})
expect(mockUninstallPlugin).toHaveBeenCalledWith('plugin-to-delete')
})
it('should show and hide deleting state during delete', async () => {
const detail = createPluginDetail()
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: true,
onUpdate: mockOnUpdate,
}),
)
await act(async () => {
await result.current.handleDelete()
})
expect(modalStates.showDeleting).toHaveBeenCalled()
expect(modalStates.hideDeleting).toHaveBeenCalled()
})
it('should call onUpdate with true after successful delete', async () => {
const detail = createPluginDetail()
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: true,
onUpdate: mockOnUpdate,
}),
)
await act(async () => {
await result.current.handleDelete()
})
expect(mockOnUpdate).toHaveBeenCalledWith(true)
})
it('should hide delete confirm after successful delete', async () => {
const detail = createPluginDetail()
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: true,
onUpdate: mockOnUpdate,
}),
)
await act(async () => {
await result.current.handleDelete()
})
expect(modalStates.hideDeleteConfirm).toHaveBeenCalled()
})
it('should refresh model providers when deleting model plugin', async () => {
const detail = createPluginDetail({
declaration: {
author: 'test-author',
name: 'test-plugin-name',
category: 'model',
label: { en_US: 'Test' },
description: { en_US: 'Test' },
icon: 'icon.png',
verified: true,
} as unknown as PluginDetail['declaration'],
})
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: true,
onUpdate: mockOnUpdate,
}),
)
await act(async () => {
await result.current.handleDelete()
})
expect(mockRefreshModelProviders).toHaveBeenCalled()
})
it('should invalidate tool providers when deleting tool plugin', async () => {
const detail = createPluginDetail({
declaration: {
author: 'test-author',
name: 'test-plugin-name',
category: 'tool',
label: { en_US: 'Test' },
description: { en_US: 'Test' },
icon: 'icon.png',
verified: true,
} as unknown as PluginDetail['declaration'],
})
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: true,
onUpdate: mockOnUpdate,
}),
)
await act(async () => {
await result.current.handleDelete()
})
expect(mockInvalidateAllToolProviders).toHaveBeenCalled()
})
it('should track plugin uninstalled event', async () => {
const detail = createPluginDetail()
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: true,
onUpdate: mockOnUpdate,
}),
)
await act(async () => {
await result.current.handleDelete()
})
expect(amplitude.trackEvent).toHaveBeenCalledWith('plugin_uninstalled', expect.objectContaining({
plugin_id: 'test-plugin',
plugin_name: 'test-plugin-name',
}))
})
it('should not call onUpdate when delete fails', async () => {
mockUninstallPlugin.mockResolvedValueOnce({ success: false })
const detail = createPluginDetail()
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: true,
onUpdate: mockOnUpdate,
}),
)
await act(async () => {
await result.current.handleDelete()
})
expect(mockOnUpdate).not.toHaveBeenCalled()
})
})
describe('Optional onUpdate Callback', () => {
it('should not throw when onUpdate is not provided for marketplace update', () => {
const detail = createPluginDetail()
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: true,
}),
)
expect(() => {
result.current.handleUpdatedFromMarketplace()
}).not.toThrow()
})
it('should not throw when onUpdate is not provided for delete', async () => {
const detail = createPluginDetail()
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: true,
}),
)
await expect(
act(async () => {
await result.current.handleDelete()
}),
).resolves.not.toThrow()
})
})
})

View File

@ -1,143 +0,0 @@
'use client'
import type { PluginDetail } from '../../../types'
import type { ModalStates, VersionTarget } from './use-detail-header-state'
import { useCallback } from 'react'
import { trackEvent } from '@/app/components/base/amplitude'
import Toast from '@/app/components/base/toast'
import { useModalContext } from '@/context/modal-context'
import { useProviderContext } from '@/context/provider-context'
import { uninstallPlugin } from '@/service/plugins'
import { useInvalidateAllToolProviders } from '@/service/use-tools'
import { useGitHubReleases } from '../../../install-plugin/hooks'
import { PluginCategoryEnum, PluginSource } from '../../../types'
type UsePluginOperationsParams = {
detail: PluginDetail
modalStates: ModalStates
versionPicker: {
setTargetVersion: (version: VersionTarget) => void
setIsDowngrade: (downgrade: boolean) => void
}
isFromMarketplace: boolean
onUpdate?: (isDelete?: boolean) => void
}
type UsePluginOperationsReturn = {
handleUpdate: (isDowngrade?: boolean) => Promise<void>
handleUpdatedFromMarketplace: () => void
handleDelete: () => Promise<void>
}
export const usePluginOperations = ({
detail,
modalStates,
versionPicker,
isFromMarketplace,
onUpdate,
}: UsePluginOperationsParams): UsePluginOperationsReturn => {
const { checkForUpdates, fetchReleases } = useGitHubReleases()
const { setShowUpdatePluginModal } = useModalContext()
const { refreshModelProviders } = useProviderContext()
const invalidateAllToolProviders = useInvalidateAllToolProviders()
const { id, meta, plugin_id } = detail
const { author, category, name } = detail.declaration || detail
const handleUpdate = useCallback(async (isDowngrade?: boolean) => {
if (isFromMarketplace) {
versionPicker.setIsDowngrade(!!isDowngrade)
modalStates.showUpdateModal()
return
}
if (!meta?.repo || !meta?.version || !meta?.package) {
Toast.notify({
type: 'error',
message: 'Missing plugin metadata for GitHub update',
})
return
}
const owner = meta.repo.split('/')[0] || author
const repo = meta.repo.split('/')[1] || name
const fetchedReleases = await fetchReleases(owner, repo)
if (fetchedReleases.length === 0)
return
const { needUpdate, toastProps } = checkForUpdates(fetchedReleases, meta.version)
Toast.notify(toastProps)
if (needUpdate) {
setShowUpdatePluginModal({
onSaveCallback: () => {
onUpdate?.()
},
payload: {
type: PluginSource.github,
category,
github: {
originalPackageInfo: {
id: detail.plugin_unique_identifier,
repo: meta.repo,
version: meta.version,
package: meta.package,
releases: fetchedReleases,
},
},
},
})
}
}, [
isFromMarketplace,
meta,
author,
name,
fetchReleases,
checkForUpdates,
setShowUpdatePluginModal,
detail,
onUpdate,
modalStates,
versionPicker,
])
const handleUpdatedFromMarketplace = useCallback(() => {
onUpdate?.()
modalStates.hideUpdateModal()
}, [onUpdate, modalStates])
const handleDelete = useCallback(async () => {
modalStates.showDeleting()
const res = await uninstallPlugin(id)
modalStates.hideDeleting()
if (res.success) {
modalStates.hideDeleteConfirm()
onUpdate?.(true)
if (PluginCategoryEnum.model.includes(category))
refreshModelProviders()
if (PluginCategoryEnum.tool.includes(category))
invalidateAllToolProviders()
trackEvent('plugin_uninstalled', { plugin_id, plugin_name: name })
}
}, [
id,
category,
plugin_id,
name,
modalStates,
onUpdate,
refreshModelProviders,
invalidateAllToolProviders,
])
return {
handleUpdate,
handleUpdatedFromMarketplace,
handleDelete,
}
}
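
The GitHub path derives owner and repo from `meta.repo`, falling back to the declared author and plugin name whenever a segment is empty — the `meta.repo === '/'` case in the spec above. The fallback in isolation (hypothetical helper, not part of the file above):

function parseRepo(repo: string, author: string, name: string) {
  // '/'.split('/') yields ['', ''], so both segments fall back
  const [ownerPart, repoPart] = repo.split('/')
  return { owner: ownerPart || author, repo: repoPart || name }
}

parseRepo('owner/repo', 'a', 'n') // { owner: 'owner', repo: 'repo' }
parseRepo('/', 'fallback-author', 'fallback-name') // both fallbacks apply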

View File

@ -1,286 +0,0 @@
'use client'
import type { PluginDetail } from '../../types'
import {
RiArrowLeftRightLine,
RiCloseLine,
} from '@remixicon/react'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import ActionButton from '@/app/components/base/action-button'
import Badge from '@/app/components/base/badge'
import Button from '@/app/components/base/button'
import Tooltip from '@/app/components/base/tooltip'
import { AuthCategory, PluginAuth } from '@/app/components/plugins/plugin-auth'
import OperationDropdown from '@/app/components/plugins/plugin-detail-panel/operation-dropdown'
import PluginVersionPicker from '@/app/components/plugins/update-plugin/plugin-version-picker'
import { API_PREFIX } from '@/config'
import { useAppContext } from '@/context/app-context'
import { useGetLanguage, useLocale } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { useAllToolProviders } from '@/service/use-tools'
import { cn } from '@/utils/classnames'
import { getMarketplaceUrl } from '@/utils/var'
import { AutoUpdateLine } from '../../../base/icons/src/vender/system'
import Verified from '../../base/badges/verified'
import DeprecationNotice from '../../base/deprecation-notice'
import Icon from '../../card/base/card-icon'
import Description from '../../card/base/description'
import OrgInfo from '../../card/base/org-info'
import Title from '../../card/base/title'
import useReferenceSetting from '../../plugin-page/use-reference-setting'
import { convertUTCDaySecondsToLocalSeconds, timeOfDayToDayjs } from '../../reference-setting-modal/auto-update-setting/utils'
import { PluginCategoryEnum, PluginSource } from '../../types'
import { HeaderModals, PluginSourceBadge } from './components'
import { useDetailHeaderState, usePluginOperations } from './hooks'
type Props = {
detail: PluginDetail
isReadmeView?: boolean
onHide?: () => void
onUpdate?: (isDelete?: boolean) => void
}
const getIconSrc = (icon: string | undefined, iconDark: string | undefined, theme: string, tenantId: string): string => {
const iconFileName = theme === 'dark' && iconDark ? iconDark : icon
if (!iconFileName)
return ''
return iconFileName.startsWith('http')
? iconFileName
: `${API_PREFIX}/workspaces/current/plugin/icon?tenant_id=${tenantId}&filename=${iconFileName}`
}
const getDetailUrl = (
source: PluginSource,
meta: PluginDetail['meta'],
author: string,
name: string,
locale: string,
theme: string,
): string => {
if (source === PluginSource.github) {
const repo = meta?.repo
if (!repo)
return ''
return `https://github.com/${repo}`
}
if (source === PluginSource.marketplace)
return getMarketplaceUrl(`/plugins/${author}/${name}`, { language: locale, theme })
return ''
}
const DetailHeader = ({
detail,
isReadmeView = false,
onHide,
onUpdate,
}: Props) => {
const { t } = useTranslation()
const { userProfile: { timezone } } = useAppContext()
const { theme } = useTheme()
const locale = useGetLanguage()
const currentLocale = useLocale()
const { referenceSetting } = useReferenceSetting()
const {
source,
tenant_id,
version,
latest_version,
latest_unique_identifier,
meta,
plugin_id,
status,
deprecated_reason,
alternative_plugin_id,
} = detail
const { author, category, name, label, description, icon, icon_dark, verified, tool } = detail.declaration || detail
const {
modalStates,
versionPicker,
hasNewVersion,
isAutoUpgradeEnabled,
isFromGitHub,
isFromMarketplace,
} = useDetailHeaderState(detail)
const {
handleUpdate,
handleUpdatedFromMarketplace,
handleDelete,
} = usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace,
onUpdate,
})
const isTool = category === PluginCategoryEnum.tool
const providerBriefInfo = tool?.identity
const providerKey = `${plugin_id}/${providerBriefInfo?.name}`
const { data: collectionList = [] } = useAllToolProviders(isTool)
const provider = useMemo(() => {
return collectionList.find(collection => collection.name === providerKey)
}, [collectionList, providerKey])
const iconSrc = getIconSrc(icon, icon_dark, theme, tenant_id)
const detailUrl = getDetailUrl(source, meta, author, name, currentLocale, theme)
const { auto_upgrade: autoUpgradeInfo } = referenceSetting || {}
const handleVersionSelect = (state: { version: string, unique_identifier: string, isDowngrade?: boolean }) => {
versionPicker.setTargetVersion(state)
handleUpdate(state.isDowngrade)
}
const handleTriggerLatestUpdate = () => {
if (isFromMarketplace) {
versionPicker.setTargetVersion({
version: latest_version,
unique_identifier: latest_unique_identifier,
})
}
handleUpdate()
}
return (
<div className={cn('shrink-0 border-b border-divider-subtle bg-components-panel-bg p-4 pb-3', isReadmeView && 'border-b-0 bg-transparent p-0')}>
<div className="flex">
{/* Plugin Icon */}
<div className={cn('overflow-hidden rounded-xl border border-components-panel-border-subtle', isReadmeView && 'bg-components-panel-bg')}>
<Icon src={iconSrc} />
</div>
{/* Plugin Info */}
<div className="ml-3 w-0 grow">
{/* Title Row */}
<div className="flex h-5 items-center">
<Title title={label[locale]} />
{verified && !isReadmeView && <Verified className="ml-0.5 h-4 w-4" text={t('marketplace.verifiedTip', { ns: 'plugin' })} />}
{/* Version Picker */}
{!!version && (
<PluginVersionPicker
disabled={!isFromMarketplace || isReadmeView}
isShow={versionPicker.isShow}
onShowChange={versionPicker.setIsShow}
pluginID={plugin_id}
currentVersion={version}
onSelect={handleVersionSelect}
trigger={(
<Badge
className={cn(
'mx-1',
versionPicker.isShow && 'bg-state-base-hover',
(versionPicker.isShow || isFromMarketplace) && 'hover:bg-state-base-hover',
)}
uppercase={false}
text={(
<>
<div>{isFromGitHub ? (meta?.version ?? version ?? '') : version}</div>
{isFromMarketplace && !isReadmeView && <RiArrowLeftRightLine className="ml-1 h-3 w-3 text-text-tertiary" />}
</>
)}
hasRedCornerMark={hasNewVersion}
/>
)}
/>
)}
{/* Auto Update Badge */}
{isAutoUpgradeEnabled && !isReadmeView && (
<Tooltip popupContent={t('autoUpdate.nextUpdateTime', { ns: 'plugin', time: timeOfDayToDayjs(convertUTCDaySecondsToLocalSeconds(autoUpgradeInfo?.upgrade_time_of_day || 0, timezone!)).format('hh:mm A') })}>
<div>
<Badge className="mr-1 cursor-pointer px-1">
<AutoUpdateLine className="size-3" />
</Badge>
</div>
</Tooltip>
)}
{/* Update Button */}
{(hasNewVersion || isFromGitHub) && (
<Button
variant="secondary-accent"
size="small"
className="!h-5"
onClick={handleTriggerLatestUpdate}
>
{t('detailPanel.operation.update', { ns: 'plugin' })}
</Button>
)}
</div>
{/* Org Info Row */}
<div className="mb-1 flex h-4 items-center justify-between">
<div className="mt-0.5 flex items-center">
<OrgInfo
packageNameClassName="w-auto"
orgName={author}
packageName={name?.includes('/') ? (name.split('/').pop() || '') : name}
/>
{!!source && <PluginSourceBadge source={source} />}
</div>
</div>
</div>
{/* Action Buttons */}
{!isReadmeView && (
<div className="flex gap-1">
<OperationDropdown
source={source}
onInfo={modalStates.showPluginInfo}
onCheckVersion={handleUpdate}
onRemove={modalStates.showDeleteConfirm}
detailUrl={detailUrl}
/>
<ActionButton onClick={onHide}>
<RiCloseLine className="h-4 w-4" />
</ActionButton>
</div>
)}
</div>
{/* Deprecation Notice */}
{isFromMarketplace && (
<DeprecationNotice
status={status}
deprecatedReason={deprecated_reason}
alternativePluginId={alternative_plugin_id}
alternativePluginURL={getMarketplaceUrl(`/plugins/${alternative_plugin_id}`, { language: currentLocale, theme })}
className="mt-3"
/>
)}
{/* Description */}
{!isReadmeView && <Description className="mb-2 mt-3 h-auto" text={description[locale]} descriptionLineRows={2} />}
{/* Plugin Auth for Tools */}
{isTool && !isReadmeView && (
<PluginAuth
pluginPayload={{
provider: provider?.name || '',
category: AuthCategory.tool,
providerType: provider?.type || '',
detail,
}}
/>
)}
{/* Modals */}
<HeaderModals
detail={detail}
modalStates={modalStates}
targetVersion={versionPicker.targetVersion}
isDowngrade={versionPicker.isDowngrade}
isAutoUpgradeEnabled={isAutoUpgradeEnabled}
onUpdatedFromMarketplace={handleUpdatedFromMarketplace}
onDelete={handleDelete}
/>
</div>
)
}
export default DetailHeader
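
For reference, `getIconSrc` above prefers the dark-mode variant when the theme is dark, passes absolute URLs through untouched, and otherwise builds a workspace-scoped icon URL. A few illustrative calls (values are made up):

getIconSrc('icon.png', undefined, 'light', 't1')
// -> `${API_PREFIX}/workspaces/current/plugin/icon?tenant_id=t1&filename=icon.png`

getIconSrc('icon.png', 'icon-dark.png', 'dark', 't1')
// -> workspace URL for 'icon-dark.png' (dark variant wins)

getIconSrc('https://cdn.example.com/icon.png', undefined, 'dark', 't1')
// -> 'https://cdn.example.com/icon.png' (http(s) URLs pass through)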

View File

@ -2,10 +2,15 @@ import type { TriggerSubscriptionBuilder } from '@/app/components/workflow/block
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
// Import after mocks
import { SupportedCreationMethods } from '@/app/components/plugins/types'
import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types'
import { CommonCreateModal } from './common-modal'
// ============================================================================
// Type Definitions
// ============================================================================
type PluginDetail = {
plugin_id: string
provider: string
@ -28,6 +33,10 @@ type TriggerLogEntity = {
level: 'info' | 'warn' | 'error'
}
// ============================================================================
// Mock Factory Functions
// ============================================================================
function createMockPluginDetail(overrides: Partial<PluginDetail> = {}): PluginDetail {
return {
plugin_id: 'test-plugin-id',
@ -65,12 +74,18 @@ function createMockLogData(logs: TriggerLogEntity[] = []): { logs: TriggerLogEnt
return { logs }
}
// ============================================================================
// Mock Setup
// ============================================================================
// Mock plugin store
const mockPluginDetail = createMockPluginDetail()
const mockUsePluginStore = vi.fn(() => mockPluginDetail)
vi.mock('../../store', () => ({
usePluginStore: () => mockUsePluginStore(),
}))
// Mock subscription list hook
const mockRefetch = vi.fn()
vi.mock('../use-subscription-list', () => ({
useSubscriptionList: () => ({
@ -78,11 +93,13 @@ vi.mock('../use-subscription-list', () => ({
}),
}))
// Mock service hooks
const mockVerifyCredentials = vi.fn()
const mockCreateBuilder = vi.fn()
const mockBuildSubscription = vi.fn()
const mockUpdateBuilder = vi.fn()
// Configurable pending states
let mockIsVerifyingCredentials = false
let mockIsBuilding = false
const setMockPendingStates = (verifying: boolean, building: boolean) => {
@ -112,15 +129,18 @@ vi.mock('@/service/use-triggers', () => ({
}),
}))
// Mock error parser
const mockParsePluginErrorMessage = vi.fn().mockResolvedValue(null)
vi.mock('@/utils/error-parser', () => ({
parsePluginErrorMessage: (...args: unknown[]) => mockParsePluginErrorMessage(...args),
}))
// Mock URL validation
vi.mock('@/utils/urlValidation', () => ({
isPrivateOrLocalAddress: vi.fn().mockReturnValue(false),
}))
// Mock toast
const mockToastNotify = vi.fn()
vi.mock('@/app/components/base/toast', () => ({
default: {
@ -128,6 +148,7 @@ vi.mock('@/app/components/base/toast', () => ({
},
}))
// Mock Modal component
vi.mock('@/app/components/base/modal/modal', () => ({
default: ({
children,
@ -158,6 +179,7 @@ vi.mock('@/app/components/base/modal/modal', () => ({
),
}))
// Configurable form mock values
type MockFormValuesConfig = {
values: Record<string, unknown>
isCheckValidated: boolean
@ -168,6 +190,7 @@ let mockFormValuesConfig: MockFormValuesConfig = {
}
let mockGetFormReturnsNull = false
// Separate validation configs for different forms
let mockSubscriptionFormValidated = true
let mockAutoParamsFormValidated = true
let mockManualPropsFormValidated = true
@ -184,6 +207,7 @@ const setMockFormValidation = (subscription: boolean, autoParams: boolean, manua
mockManualPropsFormValidated = manualProps
}
// Mock BaseForm component with ref support
vi.mock('@/app/components/base/form/components/base', async () => {
const React = await import('react')
@ -195,6 +219,7 @@ vi.mock('@/app/components/base/form/components/base', async () => {
type MockBaseFormProps = { formSchemas: Array<{ name: string }>, onChange?: () => void }
function MockBaseFormInner({ formSchemas, onChange }: MockBaseFormProps, ref: React.ForwardedRef<MockFormRef>) {
// Determine which form this is based on schema
const isSubscriptionForm = formSchemas.some((s: { name: string }) => s.name === 'subscription_name')
const isAutoParamsForm = formSchemas.some((s: { name: string }) =>
['repo_name', 'branch', 'repo', 'text_field', 'dynamic_field', 'bool_field', 'text_input_field', 'unknown_field', 'count'].includes(s.name),
@ -240,10 +265,12 @@ vi.mock('@/app/components/base/form/components/base', async () => {
}
})
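The BaseForm mock forwards a ref so tests can drive the same getFormValues/setFields surface the modal uses. A condensed sketch of that ref-forwarding shape, assuming a ref contract like the one exercised in this file (the real mock adds schema-based branching on top):

// Condensed sketch of a ref-forwarding form mock; the MockFormRef fields are
// assumed from the calls made elsewhere in this file.
const MockBaseForm = React.forwardRef<MockFormRef, MockBaseFormProps>(
  function MockBaseForm({ formSchemas, onChange }, ref) {
    React.useImperativeHandle(ref, () => ({
      getFormValues: () => mockFormValuesConfig,
      setFields: () => {},
      getForm: () => null, // shape assumed; real mock may return a form object
    }))
    return (
      <div>
        {formSchemas.map(schema => (
          <input
            key={schema.name}
            data-testid={`form-field-${schema.name}`}
            onChange={onChange}
          />
        ))}
      </div>
    )
  },
)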
// Mock EncryptedBottom component
vi.mock('@/app/components/base/encrypted-bottom', () => ({
EncryptedBottom: () => <div data-testid="encrypted-bottom">Encrypted</div>,
}))
// Mock LogViewer component
vi.mock('../log-viewer', () => ({
default: ({ logs }: { logs: TriggerLogEntity[] }) => (
<div data-testid="log-viewer">
@ -254,6 +281,7 @@ vi.mock('../log-viewer', () => ({
),
}))
// Mock debounce
vi.mock('es-toolkit/compat', () => ({
debounce: (fn: (...args: unknown[]) => unknown) => {
const debouncedFn = (...args: unknown[]) => fn(...args)
@ -262,6 +290,10 @@ vi.mock('es-toolkit/compat', () => ({
},
}))
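Replacing debounce with a pass-through keeps the tests synchronous while still exposing the .cancel the component calls in its cleanup effect. The same idea in isolation, assuming only the es-toolkit surface this file touches:

// Sketch: synchronous debounce stand-in with a no-op cancel.
function immediateDebounce<F extends (...args: never[]) => unknown>(fn: F) {
  return Object.assign((...args: Parameters<F>) => fn(...args), {
    cancel: () => {}, // matches debouncedUpdate.cancel() in the component's cleanup
  })
}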
// ============================================================================
// Test Suites
// ============================================================================
describe('CommonCreateModal', () => {
const defaultProps = {
onClose: vi.fn(),
@ -409,8 +441,7 @@ describe('CommonCreateModal', () => {
})
it('should call onConfirm handler when confirm button is clicked', () => {
// Provide builder so the guard passes and credentials check is reached
render(<CommonCreateModal {...defaultProps} builder={createMockSubscriptionBuilder()} />)
render(<CommonCreateModal {...defaultProps} />)
fireEvent.click(screen.getByTestId('modal-confirm'))
@ -1212,22 +1243,13 @@ describe('CommonCreateModal', () => {
render(<CommonCreateModal {...defaultProps} createType={SupportedCreationMethods.MANUAL} />)
// Wait for createBuilder to complete and state to update
await waitFor(() => {
expect(mockCreateBuilder).toHaveBeenCalled()
})
// Allow React to process the state update from createBuilder
await act(async () => {})
const input = screen.getByTestId('form-field-webhook_url')
fireEvent.change(input, { target: { value: 'https://example.com/webhook' } })
// Wait for updateBuilder to be called, then check the toast
await waitFor(() => {
expect(mockUpdateBuilder).toHaveBeenCalled()
})
await waitFor(() => {
expect(mockToastNotify).toHaveBeenCalledWith({
type: 'error',
@ -1428,8 +1450,7 @@ describe('CommonCreateModal', () => {
})
mockUsePluginStore.mockReturnValue(detailWithCredentials)
// Provide builder so the guard passes and credentials check is reached
render(<CommonCreateModal {...defaultProps} builder={createMockSubscriptionBuilder()} />)
render(<CommonCreateModal {...defaultProps} />)
fireEvent.click(screen.getByTestId('modal-confirm'))

View File

@ -1,19 +1,32 @@
'use client'
import type { FormRefObject } from '@/app/components/base/form/types'
import type { TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
import type { BuildTriggerSubscriptionPayload } from '@/service/use-triggers'
import { RiLoader2Line } from '@remixicon/react'
import { debounce } from 'es-toolkit/compat'
import * as React from 'react'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
// import { CopyFeedbackNew } from '@/app/components/base/copy-feedback'
import { EncryptedBottom } from '@/app/components/base/encrypted-bottom'
import { BaseForm } from '@/app/components/base/form/components/base'
import { FormTypeEnum } from '@/app/components/base/form/types'
import Modal from '@/app/components/base/modal/modal'
import Toast from '@/app/components/base/toast'
import { SupportedCreationMethods } from '@/app/components/plugins/types'
import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types'
import {
ConfigurationStepContent,
MultiSteps,
VerifyStepContent,
} from './components/modal-steps'
import {
ApiKeyStep,
MODAL_TITLE_KEY_MAP,
useCommonModalState,
} from './hooks/use-common-modal-state'
useBuildTriggerSubscription,
useCreateTriggerSubscriptionBuilder,
useTriggerSubscriptionBuilderLogs,
useUpdateTriggerSubscriptionBuilder,
useVerifyAndUpdateTriggerSubscriptionBuilder,
} from '@/service/use-triggers'
import { parsePluginErrorMessage } from '@/utils/error-parser'
import { isPrivateOrLocalAddress } from '@/utils/urlValidation'
import { usePluginStore } from '../../store'
import LogViewer from '../log-viewer'
import { useSubscriptionList } from '../use-subscription-list'
type Props = {
onClose: () => void
@ -21,33 +34,316 @@ type Props = {
builder?: TriggerSubscriptionBuilder
}
const CREDENTIAL_TYPE_MAP: Record<SupportedCreationMethods, TriggerCredentialTypeEnum> = {
[SupportedCreationMethods.APIKEY]: TriggerCredentialTypeEnum.ApiKey,
[SupportedCreationMethods.OAUTH]: TriggerCredentialTypeEnum.Oauth2,
[SupportedCreationMethods.MANUAL]: TriggerCredentialTypeEnum.Unauthorized,
}
const MODAL_TITLE_KEY_MAP: Record<
SupportedCreationMethods,
'modal.apiKey.title' | 'modal.oauth.title' | 'modal.manual.title'
> = {
[SupportedCreationMethods.APIKEY]: 'modal.apiKey.title',
[SupportedCreationMethods.OAUTH]: 'modal.oauth.title',
[SupportedCreationMethods.MANUAL]: 'modal.manual.title',
}
enum ApiKeyStep {
Verify = 'verify',
Configuration = 'configuration',
}
const defaultFormValues = { values: {}, isCheckValidated: false }
const normalizeFormType = (type: FormTypeEnum | string): FormTypeEnum => {
if (Object.values(FormTypeEnum).includes(type as FormTypeEnum))
return type as FormTypeEnum
switch (type) {
case 'string':
case 'text':
return FormTypeEnum.textInput
case 'password':
case 'secret':
return FormTypeEnum.secretInput
case 'number':
case 'integer':
return FormTypeEnum.textNumber
case 'boolean':
return FormTypeEnum.boolean
default:
return FormTypeEnum.textInput
}
}
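normalizeFormType coerces loose backend type strings into the FormTypeEnum values BaseForm understands, falling back to a plain text input for anything unrecognized. Illustrative mappings, assuming the enum members shown above:

// Illustrative only:
normalizeFormType('secret')              // -> FormTypeEnum.secretInput
normalizeFormType('integer')             // -> FormTypeEnum.textNumber
normalizeFormType(FormTypeEnum.boolean)  // -> passed through unchanged
normalizeFormType('unknown-kind')        // -> FormTypeEnum.textInput (fallback)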
const StatusStep = ({ isActive, text }: { isActive: boolean, text: string }) => {
return (
<div className={`system-2xs-semibold-uppercase flex items-center gap-1 ${isActive
? 'text-state-accent-solid'
: 'text-text-tertiary'}`}
>
{/* Active indicator dot */}
{isActive && (
<div className="h-1 w-1 rounded-full bg-state-accent-solid"></div>
)}
{text}
</div>
)
}
const MultiSteps = ({ currentStep }: { currentStep: ApiKeyStep }) => {
const { t } = useTranslation()
return (
<div className="mb-6 flex w-1/3 items-center gap-2">
<StatusStep isActive={currentStep === ApiKeyStep.Verify} text={t('modal.steps.verify', { ns: 'pluginTrigger' })} />
<div className="h-px w-3 shrink-0 bg-divider-deep"></div>
<StatusStep isActive={currentStep === ApiKeyStep.Configuration} text={t('modal.steps.configuration', { ns: 'pluginTrigger' })} />
</div>
)
}
export const CommonCreateModal = ({ onClose, createType, builder }: Props) => {
const { t } = useTranslation()
const detail = usePluginStore(state => state.detail)
const { refetch } = useSubscriptionList()
const {
currentStep,
subscriptionBuilder,
isVerifyingCredentials,
isBuilding,
formRefs,
detail,
manualPropertiesSchema,
autoCommonParametersSchema,
apiKeyCredentialsSchema,
logData,
confirmButtonText,
handleConfirm,
handleManualPropertiesChange,
handleApiKeyCredentialsChange,
} = useCommonModalState({
createType,
builder,
onClose,
})
const [currentStep, setCurrentStep] = useState<ApiKeyStep>(createType === SupportedCreationMethods.APIKEY ? ApiKeyStep.Verify : ApiKeyStep.Configuration)
const isApiKeyType = createType === SupportedCreationMethods.APIKEY
const isVerifyStep = currentStep === ApiKeyStep.Verify
const isConfigurationStep = currentStep === ApiKeyStep.Configuration
const [subscriptionBuilder, setSubscriptionBuilder] = useState<TriggerSubscriptionBuilder | undefined>(builder)
const isInitializedRef = useRef(false)
const { mutate: verifyCredentials, isPending: isVerifyingCredentials } = useVerifyAndUpdateTriggerSubscriptionBuilder()
const { mutateAsync: createBuilder /* isPending: isCreatingBuilder */ } = useCreateTriggerSubscriptionBuilder()
const { mutate: buildSubscription, isPending: isBuilding } = useBuildTriggerSubscription()
const { mutate: updateBuilder } = useUpdateTriggerSubscriptionBuilder()
const manualPropertiesSchema = detail?.declaration?.trigger?.subscription_schema || [] // used by manual creation
const manualPropertiesFormRef = React.useRef<FormRefObject>(null)
const subscriptionFormRef = React.useRef<FormRefObject>(null)
const autoCommonParametersSchema = detail?.declaration?.trigger?.subscription_constructor?.parameters || [] // used by apikey and oauth creation
const autoCommonParametersFormRef = React.useRef<FormRefObject>(null)
const apiKeyCredentialsSchema = useMemo(() => {
const rawSchema = detail?.declaration?.trigger?.subscription_constructor?.credentials_schema || []
return rawSchema.map(schema => ({
...schema,
tooltip: schema.help,
}))
}, [detail?.declaration?.trigger?.subscription_constructor?.credentials_schema])
const apiKeyCredentialsFormRef = React.useRef<FormRefObject>(null)
const { data: logData } = useTriggerSubscriptionBuilderLogs(
detail?.provider || '',
subscriptionBuilder?.id || '',
{
enabled: createType === SupportedCreationMethods.MANUAL,
refetchInterval: 3000,
},
)
useEffect(() => {
const initializeBuilder = async () => {
isInitializedRef.current = true
try {
const response = await createBuilder({
provider: detail?.provider || '',
credential_type: CREDENTIAL_TYPE_MAP[createType],
})
setSubscriptionBuilder(response.subscription_builder)
}
catch (error) {
console.error('createBuilder error:', error)
Toast.notify({
type: 'error',
message: t('modal.errors.createFailed', { ns: 'pluginTrigger' }),
})
}
}
if (!isInitializedRef.current && !subscriptionBuilder && detail?.provider)
initializeBuilder()
}, [subscriptionBuilder, detail?.provider, createType, createBuilder, t])
useEffect(() => {
if (subscriptionBuilder?.endpoint && subscriptionFormRef.current && currentStep === ApiKeyStep.Configuration) {
const form = subscriptionFormRef.current.getForm()
if (form)
form.setFieldValue('callback_url', subscriptionBuilder.endpoint)
if (isPrivateOrLocalAddress(subscriptionBuilder.endpoint)) {
console.warn('callback_url is private or local address', subscriptionBuilder.endpoint)
subscriptionFormRef.current?.setFields([{
name: 'callback_url',
warnings: [t('modal.form.callbackUrl.privateAddressWarning', { ns: 'pluginTrigger' })],
}])
}
else {
subscriptionFormRef.current?.setFields([{
name: 'callback_url',
warnings: [],
}])
}
}
}, [subscriptionBuilder?.endpoint, currentStep, t])
const debouncedUpdate = useMemo(
() => debounce((provider: string, builderId: string, properties: Record<string, unknown>) => {
updateBuilder(
{
provider,
subscriptionBuilderId: builderId,
properties,
},
{
onError: async (error: unknown) => {
const errorMessage = await parsePluginErrorMessage(error) || t('modal.errors.updateFailed', { ns: 'pluginTrigger' })
console.error('Failed to update subscription builder:', error)
Toast.notify({
type: 'error',
message: errorMessage,
})
},
},
)
}, 500),
[updateBuilder, t],
)
const handleManualPropertiesChange = useCallback(() => {
if (!subscriptionBuilder || !detail?.provider)
return
const formValues = manualPropertiesFormRef.current?.getFormValues({ needCheckValidatedValues: false }) || { values: {}, isCheckValidated: true }
debouncedUpdate(detail.provider, subscriptionBuilder.id, formValues.values)
}, [subscriptionBuilder, detail?.provider, debouncedUpdate])
useEffect(() => {
return () => {
debouncedUpdate.cancel()
}
}, [debouncedUpdate])
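Memoizing the debounced function and cancelling it on unmount is the usual guard against a trailing update firing after the modal closes. The pattern stripped to its core, with persistProperties standing in for the updateBuilder call above (a hypothetical name):

// Sketch: memoized debounce plus cancel-on-unmount; persistProperties is a placeholder.
const debouncedPersist = useMemo(
  () => debounce((properties: Record<string, unknown>) => {
    persistProperties(properties)
  }, 500),
  [persistProperties],
)
useEffect(() => () => debouncedPersist.cancel(), [debouncedPersist])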
const handleVerify = () => {
const apiKeyCredentialsFormValues = apiKeyCredentialsFormRef.current?.getFormValues({}) || defaultFormValues
const credentials = apiKeyCredentialsFormValues.values
if (!Object.keys(credentials).length) {
Toast.notify({
type: 'error',
message: 'Please fill in all required credentials',
})
return
}
apiKeyCredentialsFormRef.current?.setFields([{
name: Object.keys(credentials)[0],
errors: [],
}])
verifyCredentials(
{
provider: detail?.provider || '',
subscriptionBuilderId: subscriptionBuilder?.id || '',
credentials,
},
{
onSuccess: () => {
Toast.notify({
type: 'success',
message: t('modal.apiKey.verify.success', { ns: 'pluginTrigger' }),
})
setCurrentStep(ApiKeyStep.Configuration)
},
onError: async (error: unknown) => {
const errorMessage = await parsePluginErrorMessage(error) || t('modal.apiKey.verify.error', { ns: 'pluginTrigger' })
apiKeyCredentialsFormRef.current?.setFields([{
name: Object.keys(credentials)[0],
errors: [errorMessage],
}])
},
},
)
}
const handleCreate = () => {
if (!subscriptionBuilder) {
Toast.notify({
type: 'error',
message: 'Subscription builder not found',
})
return
}
const subscriptionFormValues = subscriptionFormRef.current?.getFormValues({})
if (!subscriptionFormValues?.isCheckValidated)
return
const subscriptionNameValue = subscriptionFormValues?.values?.subscription_name as string
const params: BuildTriggerSubscriptionPayload = {
provider: detail?.provider || '',
subscriptionBuilderId: subscriptionBuilder.id,
name: subscriptionNameValue,
}
if (createType !== SupportedCreationMethods.MANUAL) {
if (autoCommonParametersSchema.length > 0) {
const autoCommonParametersFormValues = autoCommonParametersFormRef.current?.getFormValues({}) || defaultFormValues
if (!autoCommonParametersFormValues?.isCheckValidated)
return
params.parameters = autoCommonParametersFormValues.values
}
}
else if (manualPropertiesSchema.length > 0) {
const manualFormValues = manualPropertiesFormRef.current?.getFormValues({}) || defaultFormValues
if (!manualFormValues?.isCheckValidated)
return
}
buildSubscription(
params,
{
onSuccess: () => {
Toast.notify({
type: 'success',
message: t('subscription.createSuccess', { ns: 'pluginTrigger' }),
})
onClose()
refetch?.()
},
onError: async (error: unknown) => {
const errorMessage = await parsePluginErrorMessage(error) || t('subscription.createFailed', { ns: 'pluginTrigger' })
Toast.notify({
type: 'error',
message: errorMessage,
})
},
},
)
}
const handleConfirm = () => {
if (currentStep === ApiKeyStep.Verify)
handleVerify()
else
handleCreate()
}
const handleApiKeyCredentialsChange = () => {
apiKeyCredentialsFormRef.current?.setFields([{
name: apiKeyCredentialsSchema[0].name,
errors: [],
}])
}
const confirmButtonText = useMemo(() => {
if (currentStep === ApiKeyStep.Verify)
return isVerifyingCredentials ? t('modal.common.verifying', { ns: 'pluginTrigger' }) : t('modal.common.verify', { ns: 'pluginTrigger' })
return isBuilding ? t('modal.common.creating', { ns: 'pluginTrigger' }) : t('modal.common.create', { ns: 'pluginTrigger' })
}, [currentStep, isVerifyingCredentials, isBuilding, t])
return (
<Modal
@ -57,36 +353,121 @@ export const CommonCreateModal = ({ onClose, createType, builder }: Props) => {
onCancel={onClose}
onConfirm={handleConfirm}
disabled={isVerifyingCredentials || isBuilding}
bottomSlot={isVerifyStep ? <EncryptedBottom /> : null}
bottomSlot={currentStep === ApiKeyStep.Verify ? <EncryptedBottom /> : null}
size={createType === SupportedCreationMethods.MANUAL ? 'md' : 'sm'}
containerClassName="min-h-[360px]"
clickOutsideNotClose
>
{isApiKeyType && <MultiSteps currentStep={currentStep} />}
{isVerifyStep && (
<VerifyStepContent
apiKeyCredentialsSchema={apiKeyCredentialsSchema}
apiKeyCredentialsFormRef={formRefs.apiKeyCredentialsFormRef}
onChange={handleApiKeyCredentialsChange}
/>
{createType === SupportedCreationMethods.APIKEY && <MultiSteps currentStep={currentStep} />}
{currentStep === ApiKeyStep.Verify && (
<>
{apiKeyCredentialsSchema.length > 0 && (
<div className="mb-4">
<BaseForm
formSchemas={apiKeyCredentialsSchema}
ref={apiKeyCredentialsFormRef}
labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary"
preventDefaultSubmit={true}
formClassName="space-y-4"
onChange={handleApiKeyCredentialsChange}
/>
</div>
)}
</>
)}
{currentStep === ApiKeyStep.Configuration && (
<div className="max-h-[70vh]">
<BaseForm
formSchemas={[
{
name: 'subscription_name',
label: t('modal.form.subscriptionName.label', { ns: 'pluginTrigger' }),
placeholder: t('modal.form.subscriptionName.placeholder', { ns: 'pluginTrigger' }),
type: FormTypeEnum.textInput,
required: true,
},
{
name: 'callback_url',
label: t('modal.form.callbackUrl.label', { ns: 'pluginTrigger' }),
placeholder: t('modal.form.callbackUrl.placeholder', { ns: 'pluginTrigger' }),
type: FormTypeEnum.textInput,
required: false,
default: subscriptionBuilder?.endpoint || '',
disabled: true,
tooltip: t('modal.form.callbackUrl.tooltip', { ns: 'pluginTrigger' }),
showCopy: true,
},
]}
ref={subscriptionFormRef}
labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary"
formClassName="space-y-4 mb-4"
/>
{/* <div className='system-xs-regular mb-6 mt-[-1rem] text-text-tertiary'>
{t('pluginTrigger.modal.form.callbackUrl.description')}
</div> */}
{createType !== SupportedCreationMethods.MANUAL && autoCommonParametersSchema.length > 0 && (
<BaseForm
formSchemas={autoCommonParametersSchema.map((schema) => {
const normalizedType = normalizeFormType(schema.type as FormTypeEnum | string)
return {
...schema,
tooltip: schema.description,
type: normalizedType,
dynamicSelectParams: normalizedType === FormTypeEnum.dynamicSelect
? {
plugin_id: detail?.plugin_id || '',
provider: detail?.provider || '',
action: 'provider',
parameter: schema.name,
credential_id: subscriptionBuilder?.id || '',
}
: undefined,
fieldClassName: schema.type === FormTypeEnum.boolean ? 'flex items-center justify-between' : undefined,
labelClassName: schema.type === FormTypeEnum.boolean ? 'mb-0' : undefined,
}
})}
ref={autoCommonParametersFormRef}
labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary"
formClassName="space-y-4"
/>
)}
{createType === SupportedCreationMethods.MANUAL && (
<>
{manualPropertiesSchema.length > 0 && (
<div className="mb-6">
<BaseForm
formSchemas={manualPropertiesSchema.map(schema => ({
...schema,
tooltip: schema.description,
}))}
ref={manualPropertiesFormRef}
labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary"
formClassName="space-y-4"
onChange={handleManualPropertiesChange}
/>
</div>
)}
<div className="mb-6">
<div className="mb-3 flex items-center gap-2">
<div className="system-xs-medium-uppercase text-text-tertiary">
{t('modal.manual.logs.title', { ns: 'pluginTrigger' })}
</div>
<div className="h-px flex-1 bg-gradient-to-r from-divider-regular to-transparent" />
</div>
{isConfigurationStep && (
<ConfigurationStepContent
createType={createType}
subscriptionBuilder={subscriptionBuilder}
subscriptionFormRef={formRefs.subscriptionFormRef}
autoCommonParametersSchema={autoCommonParametersSchema}
autoCommonParametersFormRef={formRefs.autoCommonParametersFormRef}
manualPropertiesSchema={manualPropertiesSchema}
manualPropertiesFormRef={formRefs.manualPropertiesFormRef}
onManualPropertiesChange={handleManualPropertiesChange}
logs={logData?.logs || []}
pluginId={detail?.plugin_id || ''}
pluginName={detail?.name || ''}
provider={detail?.provider || ''}
/>
<div className="mb-1 flex items-center justify-center gap-1 rounded-lg bg-background-section p-3">
<div className="h-3.5 w-3.5">
<RiLoader2Line className="h-full w-full animate-spin" />
</div>
<div className="system-xs-regular text-text-tertiary">
{t('modal.manual.logs.loading', { ns: 'pluginTrigger', pluginName: detail?.name || '' })}
</div>
</div>
<LogViewer logs={logData?.logs || []} />
</div>
</>
)}
</div>
)}
</Modal>
)

View File

@ -1,304 +0,0 @@
'use client'
import type { FormRefObject, FormSchema } from '@/app/components/base/form/types'
import type { TriggerLogEntity, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
import { RiLoader2Line } from '@remixicon/react'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { BaseForm } from '@/app/components/base/form/components/base'
import { FormTypeEnum } from '@/app/components/base/form/types'
import { SupportedCreationMethods } from '@/app/components/plugins/types'
import LogViewer from '../../log-viewer'
import { ApiKeyStep } from '../hooks/use-common-modal-state'
export type SchemaItem = Partial<FormSchema> & Record<string, unknown> & {
name: string
}
type StatusStepProps = {
isActive: boolean
text: string
}
export const StatusStep = ({ isActive, text }: StatusStepProps) => {
return (
<div className={`system-2xs-semibold-uppercase flex items-center gap-1 ${isActive
? 'text-state-accent-solid'
: 'text-text-tertiary'}`}
>
{isActive && (
<div className="h-1 w-1 rounded-full bg-state-accent-solid"></div>
)}
{text}
</div>
)
}
type MultiStepsProps = {
currentStep: ApiKeyStep
}
export const MultiSteps = ({ currentStep }: MultiStepsProps) => {
const { t } = useTranslation()
return (
<div className="mb-6 flex w-1/3 items-center gap-2">
<StatusStep isActive={currentStep === ApiKeyStep.Verify} text={t('modal.steps.verify', { ns: 'pluginTrigger' })} />
<div className="h-px w-3 shrink-0 bg-divider-deep"></div>
<StatusStep isActive={currentStep === ApiKeyStep.Configuration} text={t('modal.steps.configuration', { ns: 'pluginTrigger' })} />
</div>
)
}
type VerifyStepContentProps = {
apiKeyCredentialsSchema: SchemaItem[]
apiKeyCredentialsFormRef: React.RefObject<FormRefObject | null>
onChange: () => void
}
export const VerifyStepContent = ({
apiKeyCredentialsSchema,
apiKeyCredentialsFormRef,
onChange,
}: VerifyStepContentProps) => {
if (!apiKeyCredentialsSchema.length)
return null
return (
<div className="mb-4">
<BaseForm
formSchemas={apiKeyCredentialsSchema as FormSchema[]}
ref={apiKeyCredentialsFormRef}
labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary"
preventDefaultSubmit={true}
formClassName="space-y-4"
onChange={onChange}
/>
</div>
)
}
type SubscriptionFormProps = {
subscriptionFormRef: React.RefObject<FormRefObject | null>
endpoint?: string
}
export const SubscriptionForm = ({
subscriptionFormRef,
endpoint,
}: SubscriptionFormProps) => {
const { t } = useTranslation()
const formSchemas = React.useMemo(() => [
{
name: 'subscription_name',
label: t('modal.form.subscriptionName.label', { ns: 'pluginTrigger' }),
placeholder: t('modal.form.subscriptionName.placeholder', { ns: 'pluginTrigger' }),
type: FormTypeEnum.textInput,
required: true,
},
{
name: 'callback_url',
label: t('modal.form.callbackUrl.label', { ns: 'pluginTrigger' }),
placeholder: t('modal.form.callbackUrl.placeholder', { ns: 'pluginTrigger' }),
type: FormTypeEnum.textInput,
required: false,
default: endpoint || '',
disabled: true,
tooltip: t('modal.form.callbackUrl.tooltip', { ns: 'pluginTrigger' }),
showCopy: true,
},
], [endpoint, t])
return (
<BaseForm
formSchemas={formSchemas}
ref={subscriptionFormRef}
labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary"
formClassName="space-y-4 mb-4"
/>
)
}
const normalizeFormType = (type: FormTypeEnum | string): FormTypeEnum => {
if (Object.values(FormTypeEnum).includes(type as FormTypeEnum))
return type as FormTypeEnum
const TYPE_MAP: Record<string, FormTypeEnum> = {
string: FormTypeEnum.textInput,
text: FormTypeEnum.textInput,
password: FormTypeEnum.secretInput,
secret: FormTypeEnum.secretInput,
number: FormTypeEnum.textNumber,
integer: FormTypeEnum.textNumber,
boolean: FormTypeEnum.boolean,
}
return TYPE_MAP[type] || FormTypeEnum.textInput
}
type AutoParametersFormProps = {
schemas: SchemaItem[]
formRef: React.RefObject<FormRefObject | null>
pluginId: string
provider: string
credentialId: string
}
export const AutoParametersForm = ({
schemas,
formRef,
pluginId,
provider,
credentialId,
}: AutoParametersFormProps) => {
const formSchemas = React.useMemo(() =>
schemas.map((schema) => {
const normalizedType = normalizeFormType((schema.type || FormTypeEnum.textInput) as FormTypeEnum | string)
return {
...schema,
tooltip: schema.description,
type: normalizedType,
dynamicSelectParams: normalizedType === FormTypeEnum.dynamicSelect
? {
plugin_id: pluginId,
provider,
action: 'provider',
parameter: schema.name,
credential_id: credentialId,
}
: undefined,
fieldClassName: normalizedType === FormTypeEnum.boolean ? 'flex items-center justify-between' : undefined,
labelClassName: normalizedType === FormTypeEnum.boolean ? 'mb-0' : undefined,
}
}) as FormSchema[], [schemas, pluginId, provider, credentialId])
if (!schemas.length)
return null
return (
<BaseForm
formSchemas={formSchemas}
ref={formRef}
labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary"
formClassName="space-y-4"
/>
)
}
type ManualPropertiesSectionProps = {
schemas: SchemaItem[]
formRef: React.RefObject<FormRefObject | null>
onChange: () => void
logs: TriggerLogEntity[]
pluginName: string
}
export const ManualPropertiesSection = ({
schemas,
formRef,
onChange,
logs,
pluginName,
}: ManualPropertiesSectionProps) => {
const { t } = useTranslation()
const formSchemas = React.useMemo(() =>
schemas.map(schema => ({
...schema,
tooltip: schema.description,
})) as FormSchema[], [schemas])
return (
<>
{schemas.length > 0 && (
<div className="mb-6">
<BaseForm
formSchemas={formSchemas}
ref={formRef}
labelClassName="system-sm-medium mb-2 flex items-center gap-1 text-text-primary"
formClassName="space-y-4"
onChange={onChange}
/>
</div>
)}
<div className="mb-6">
<div className="mb-3 flex items-center gap-2">
<div className="system-xs-medium-uppercase text-text-tertiary">
{t('modal.manual.logs.title', { ns: 'pluginTrigger' })}
</div>
<div className="h-px flex-1 bg-gradient-to-r from-divider-regular to-transparent" />
</div>
<div className="mb-1 flex items-center justify-center gap-1 rounded-lg bg-background-section p-3">
<div className="h-3.5 w-3.5">
<RiLoader2Line className="h-full w-full animate-spin" />
</div>
<div className="system-xs-regular text-text-tertiary">
{t('modal.manual.logs.loading', { ns: 'pluginTrigger', pluginName })}
</div>
</div>
<LogViewer logs={logs} />
</div>
</>
)
}
type ConfigurationStepContentProps = {
createType: SupportedCreationMethods
subscriptionBuilder?: TriggerSubscriptionBuilder
subscriptionFormRef: React.RefObject<FormRefObject | null>
autoCommonParametersSchema: SchemaItem[]
autoCommonParametersFormRef: React.RefObject<FormRefObject | null>
manualPropertiesSchema: SchemaItem[]
manualPropertiesFormRef: React.RefObject<FormRefObject | null>
onManualPropertiesChange: () => void
logs: TriggerLogEntity[]
pluginId: string
pluginName: string
provider: string
}
export const ConfigurationStepContent = ({
createType,
subscriptionBuilder,
subscriptionFormRef,
autoCommonParametersSchema,
autoCommonParametersFormRef,
manualPropertiesSchema,
manualPropertiesFormRef,
onManualPropertiesChange,
logs,
pluginId,
pluginName,
provider,
}: ConfigurationStepContentProps) => {
const isManualType = createType === SupportedCreationMethods.MANUAL
return (
<div className="max-h-[70vh]">
<SubscriptionForm
subscriptionFormRef={subscriptionFormRef}
endpoint={subscriptionBuilder?.endpoint}
/>
{!isManualType && autoCommonParametersSchema.length > 0 && (
<AutoParametersForm
schemas={autoCommonParametersSchema}
formRef={autoCommonParametersFormRef}
pluginId={pluginId}
provider={provider}
credentialId={subscriptionBuilder?.id || ''}
/>
)}
{isManualType && (
<ManualPropertiesSection
schemas={manualPropertiesSchema}
formRef={manualPropertiesFormRef}
onChange={onManualPropertiesChange}
logs={logs}
pluginName={pluginName}
/>
)}
</div>
)
}

View File

@ -1,401 +0,0 @@
'use client'
import type { SimpleDetail } from '../../../store'
import type { SchemaItem } from '../components/modal-steps'
import type { FormRefObject } from '@/app/components/base/form/types'
import type { TriggerLogEntity, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
import type { BuildTriggerSubscriptionPayload } from '@/service/use-triggers'
import { debounce } from 'es-toolkit/compat'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Toast from '@/app/components/base/toast'
import { SupportedCreationMethods } from '@/app/components/plugins/types'
import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types'
import {
useBuildTriggerSubscription,
useCreateTriggerSubscriptionBuilder,
useTriggerSubscriptionBuilderLogs,
useUpdateTriggerSubscriptionBuilder,
useVerifyAndUpdateTriggerSubscriptionBuilder,
} from '@/service/use-triggers'
import { parsePluginErrorMessage } from '@/utils/error-parser'
import { isPrivateOrLocalAddress } from '@/utils/urlValidation'
import { usePluginStore } from '../../../store'
import { useSubscriptionList } from '../../use-subscription-list'
// ============================================================================
// Types
// ============================================================================
export enum ApiKeyStep {
Verify = 'verify',
Configuration = 'configuration',
}
export const CREDENTIAL_TYPE_MAP: Record<SupportedCreationMethods, TriggerCredentialTypeEnum> = {
[SupportedCreationMethods.APIKEY]: TriggerCredentialTypeEnum.ApiKey,
[SupportedCreationMethods.OAUTH]: TriggerCredentialTypeEnum.Oauth2,
[SupportedCreationMethods.MANUAL]: TriggerCredentialTypeEnum.Unauthorized,
}
export const MODAL_TITLE_KEY_MAP: Record<
SupportedCreationMethods,
'modal.apiKey.title' | 'modal.oauth.title' | 'modal.manual.title'
> = {
[SupportedCreationMethods.APIKEY]: 'modal.apiKey.title',
[SupportedCreationMethods.OAUTH]: 'modal.oauth.title',
[SupportedCreationMethods.MANUAL]: 'modal.manual.title',
}
type UseCommonModalStateParams = {
createType: SupportedCreationMethods
builder?: TriggerSubscriptionBuilder
onClose: () => void
}
type FormRefs = {
manualPropertiesFormRef: React.RefObject<FormRefObject | null>
subscriptionFormRef: React.RefObject<FormRefObject | null>
autoCommonParametersFormRef: React.RefObject<FormRefObject | null>
apiKeyCredentialsFormRef: React.RefObject<FormRefObject | null>
}
type UseCommonModalStateReturn = {
// State
currentStep: ApiKeyStep
subscriptionBuilder: TriggerSubscriptionBuilder | undefined
isVerifyingCredentials: boolean
isBuilding: boolean
// Form refs
formRefs: FormRefs
// Computed values
detail: SimpleDetail | undefined
manualPropertiesSchema: SchemaItem[]
autoCommonParametersSchema: SchemaItem[]
apiKeyCredentialsSchema: SchemaItem[]
logData: { logs: TriggerLogEntity[] } | undefined
confirmButtonText: string
// Handlers
handleVerify: () => void
handleCreate: () => void
handleConfirm: () => void
handleManualPropertiesChange: () => void
handleApiKeyCredentialsChange: () => void
}
const DEFAULT_FORM_VALUES = { values: {}, isCheckValidated: false }
// ============================================================================
// Hook Implementation
// ============================================================================
export const useCommonModalState = ({
createType,
builder,
onClose,
}: UseCommonModalStateParams): UseCommonModalStateReturn => {
const { t } = useTranslation()
const detail = usePluginStore(state => state.detail)
const { refetch } = useSubscriptionList()
// State
const [currentStep, setCurrentStep] = useState<ApiKeyStep>(
createType === SupportedCreationMethods.APIKEY ? ApiKeyStep.Verify : ApiKeyStep.Configuration,
)
const [subscriptionBuilder, setSubscriptionBuilder] = useState<TriggerSubscriptionBuilder | undefined>(builder)
const isInitializedRef = useRef(false)
// Form refs
const manualPropertiesFormRef = useRef<FormRefObject>(null)
const subscriptionFormRef = useRef<FormRefObject>(null)
const autoCommonParametersFormRef = useRef<FormRefObject>(null)
const apiKeyCredentialsFormRef = useRef<FormRefObject>(null)
// Mutations
const { mutate: verifyCredentials, isPending: isVerifyingCredentials } = useVerifyAndUpdateTriggerSubscriptionBuilder()
const { mutateAsync: createBuilder } = useCreateTriggerSubscriptionBuilder()
const { mutate: buildSubscription, isPending: isBuilding } = useBuildTriggerSubscription()
const { mutate: updateBuilder } = useUpdateTriggerSubscriptionBuilder()
// Schemas
const manualPropertiesSchema = detail?.declaration?.trigger?.subscription_schema || []
const autoCommonParametersSchema = detail?.declaration?.trigger?.subscription_constructor?.parameters || []
const apiKeyCredentialsSchema = useMemo(() => {
const rawSchema = detail?.declaration?.trigger?.subscription_constructor?.credentials_schema || []
return rawSchema.map(schema => ({
...schema,
tooltip: schema.help,
}))
}, [detail?.declaration?.trigger?.subscription_constructor?.credentials_schema])
// Log data for manual mode
const { data: logData } = useTriggerSubscriptionBuilderLogs(
detail?.provider || '',
subscriptionBuilder?.id || '',
{
enabled: createType === SupportedCreationMethods.MANUAL,
refetchInterval: 3000,
},
)
// Debounced update for manual properties
const debouncedUpdate = useMemo(
() => debounce((provider: string, builderId: string, properties: Record<string, unknown>) => {
updateBuilder(
{
provider,
subscriptionBuilderId: builderId,
properties,
},
{
onError: async (error: unknown) => {
const errorMessage = await parsePluginErrorMessage(error) || t('modal.errors.updateFailed', { ns: 'pluginTrigger' })
console.error('Failed to update subscription builder:', error)
Toast.notify({
type: 'error',
message: errorMessage,
})
},
},
)
}, 500),
[updateBuilder, t],
)
// Initialize builder
useEffect(() => {
const initializeBuilder = async () => {
isInitializedRef.current = true
try {
const response = await createBuilder({
provider: detail?.provider || '',
credential_type: CREDENTIAL_TYPE_MAP[createType],
})
setSubscriptionBuilder(response.subscription_builder)
}
catch (error) {
console.error('createBuilder error:', error)
Toast.notify({
type: 'error',
message: t('modal.errors.createFailed', { ns: 'pluginTrigger' }),
})
}
}
if (!isInitializedRef.current && !subscriptionBuilder && detail?.provider)
initializeBuilder()
}, [subscriptionBuilder, detail?.provider, createType, createBuilder, t])
// Cleanup debounced function
useEffect(() => {
return () => {
debouncedUpdate.cancel()
}
}, [debouncedUpdate])
// Update endpoint in form when endpoint changes
useEffect(() => {
if (!subscriptionBuilder?.endpoint || !subscriptionFormRef.current || currentStep !== ApiKeyStep.Configuration)
return
const form = subscriptionFormRef.current.getForm()
if (form)
form.setFieldValue('callback_url', subscriptionBuilder.endpoint)
const warnings = isPrivateOrLocalAddress(subscriptionBuilder.endpoint)
? [t('modal.form.callbackUrl.privateAddressWarning', { ns: 'pluginTrigger' })]
: []
subscriptionFormRef.current?.setFields([{
name: 'callback_url',
warnings,
}])
}, [subscriptionBuilder?.endpoint, currentStep, t])
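The warning above hinges on isPrivateOrLocalAddress flagging callback endpoints a remote trigger provider could never reach. A hypothetical sketch of such a check; the real helper lives in @/utils/urlValidation and may differ:

// Hypothetical sketch only; not the actual '@/utils/urlValidation' implementation.
function isPrivateOrLocalAddressSketch(url: string): boolean {
  try {
    const { hostname } = new URL(url)
    return (
      hostname === 'localhost'
      || /^127\./.test(hostname) // loopback
      || /^10\./.test(hostname) // RFC 1918
      || /^192\.168\./.test(hostname) // RFC 1918
      || /^172\.(1[6-9]|2\d|3[01])\./.test(hostname) // RFC 1918
    )
  }
  catch {
    return false // not a parseable URL
  }
}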
// Handle manual properties change
const handleManualPropertiesChange = useCallback(() => {
if (!subscriptionBuilder || !detail?.provider)
return
const formValues = manualPropertiesFormRef.current?.getFormValues({ needCheckValidatedValues: false })
|| { values: {}, isCheckValidated: true }
debouncedUpdate(detail.provider, subscriptionBuilder.id, formValues.values)
}, [subscriptionBuilder, detail?.provider, debouncedUpdate])
// Handle API key credentials change
const handleApiKeyCredentialsChange = useCallback(() => {
if (!apiKeyCredentialsSchema.length)
return
apiKeyCredentialsFormRef.current?.setFields([{
name: apiKeyCredentialsSchema[0].name,
errors: [],
}])
}, [apiKeyCredentialsSchema])
// Handle verify
const handleVerify = useCallback(() => {
// Guard against uninitialized state
if (!detail?.provider || !subscriptionBuilder?.id) {
Toast.notify({
type: 'error',
message: 'Subscription builder not initialized',
})
return
}
const apiKeyCredentialsFormValues = apiKeyCredentialsFormRef.current?.getFormValues({}) || DEFAULT_FORM_VALUES
const credentials = apiKeyCredentialsFormValues.values
if (!Object.keys(credentials).length) {
Toast.notify({
type: 'error',
message: 'Please fill in all required credentials',
})
return
}
apiKeyCredentialsFormRef.current?.setFields([{
name: Object.keys(credentials)[0],
errors: [],
}])
verifyCredentials(
{
provider: detail.provider,
subscriptionBuilderId: subscriptionBuilder.id,
credentials,
},
{
onSuccess: () => {
Toast.notify({
type: 'success',
message: t('modal.apiKey.verify.success', { ns: 'pluginTrigger' }),
})
setCurrentStep(ApiKeyStep.Configuration)
},
onError: async (error: unknown) => {
const errorMessage = await parsePluginErrorMessage(error) || t('modal.apiKey.verify.error', { ns: 'pluginTrigger' })
apiKeyCredentialsFormRef.current?.setFields([{
name: Object.keys(credentials)[0],
errors: [errorMessage],
}])
},
},
)
}, [detail?.provider, subscriptionBuilder?.id, verifyCredentials, t])
// Handle create
const handleCreate = useCallback(() => {
if (!subscriptionBuilder) {
Toast.notify({
type: 'error',
message: 'Subscription builder not found',
})
return
}
const subscriptionFormValues = subscriptionFormRef.current?.getFormValues({})
if (!subscriptionFormValues?.isCheckValidated)
return
const subscriptionNameValue = subscriptionFormValues?.values?.subscription_name as string
const params: BuildTriggerSubscriptionPayload = {
provider: detail?.provider || '',
subscriptionBuilderId: subscriptionBuilder.id,
name: subscriptionNameValue,
}
if (createType !== SupportedCreationMethods.MANUAL) {
if (autoCommonParametersSchema.length > 0) {
const autoCommonParametersFormValues = autoCommonParametersFormRef.current?.getFormValues({}) || DEFAULT_FORM_VALUES
if (!autoCommonParametersFormValues?.isCheckValidated)
return
params.parameters = autoCommonParametersFormValues.values
}
}
else if (manualPropertiesSchema.length > 0) {
const manualFormValues = manualPropertiesFormRef.current?.getFormValues({}) || DEFAULT_FORM_VALUES
if (!manualFormValues?.isCheckValidated)
return
}
buildSubscription(
params,
{
onSuccess: () => {
Toast.notify({
type: 'success',
message: t('subscription.createSuccess', { ns: 'pluginTrigger' }),
})
onClose()
refetch?.()
},
onError: async (error: unknown) => {
const errorMessage = await parsePluginErrorMessage(error) || t('subscription.createFailed', { ns: 'pluginTrigger' })
Toast.notify({
type: 'error',
message: errorMessage,
})
},
},
)
}, [
subscriptionBuilder,
detail?.provider,
createType,
autoCommonParametersSchema.length,
manualPropertiesSchema.length,
buildSubscription,
onClose,
refetch,
t,
])
// Handle confirm (dispatch based on step)
const handleConfirm = useCallback(() => {
if (currentStep === ApiKeyStep.Verify)
handleVerify()
else
handleCreate()
}, [currentStep, handleVerify, handleCreate])
// Confirm button text
const confirmButtonText = useMemo(() => {
if (currentStep === ApiKeyStep.Verify) {
return isVerifyingCredentials
? t('modal.common.verifying', { ns: 'pluginTrigger' })
: t('modal.common.verify', { ns: 'pluginTrigger' })
}
return isBuilding
? t('modal.common.creating', { ns: 'pluginTrigger' })
: t('modal.common.create', { ns: 'pluginTrigger' })
}, [currentStep, isVerifyingCredentials, isBuilding, t])
return {
currentStep,
subscriptionBuilder,
isVerifyingCredentials,
isBuilding,
formRefs: {
manualPropertiesFormRef,
subscriptionFormRef,
autoCommonParametersFormRef,
apiKeyCredentialsFormRef,
},
detail,
manualPropertiesSchema,
autoCommonParametersSchema,
apiKeyCredentialsSchema,
logData,
confirmButtonText,
handleVerify,
handleCreate,
handleConfirm,
handleManualPropertiesChange,
handleApiKeyCredentialsChange,
}
}

View File

@ -1,719 +0,0 @@
import type { TriggerOAuthConfig, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
import { act, renderHook, waitFor } from '@testing-library/react'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types'
import {
AuthorizationStatusEnum,
ClientTypeEnum,
getErrorMessage,
useOAuthClientState,
} from './use-oauth-client-state'
// ============================================================================
// Mock Factory Functions
// ============================================================================
function createMockOAuthConfig(overrides: Partial<TriggerOAuthConfig> = {}): TriggerOAuthConfig {
return {
configured: true,
custom_configured: false,
custom_enabled: false,
system_configured: true,
redirect_uri: 'https://example.com/oauth/callback',
params: {
client_id: 'default-client-id',
client_secret: 'default-client-secret',
},
oauth_client_schema: [
{ name: 'client_id', type: 'text-input' as unknown, required: true, label: { 'en-US': 'Client ID' } as unknown },
{ name: 'client_secret', type: 'secret-input' as unknown, required: true, label: { 'en-US': 'Client Secret' } as unknown },
] as TriggerOAuthConfig['oauth_client_schema'],
...overrides,
}
}
function createMockSubscriptionBuilder(overrides: Partial<TriggerSubscriptionBuilder> = {}): TriggerSubscriptionBuilder {
return {
id: 'builder-123',
name: 'Test Builder',
provider: 'test-provider',
credential_type: TriggerCredentialTypeEnum.Oauth2,
credentials: {},
endpoint: 'https://example.com/callback',
parameters: {},
properties: {},
workflows_in_use: 0,
...overrides,
}
}
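Both factories follow the defaults-plus-overrides pattern, so each test spells out only the fields it actually asserts on. For example:

// Only the fields under test are overridden; everything else keeps its default.
const apiKeyBuilder = createMockSubscriptionBuilder({
  id: 'builder-api-key',
  credential_type: TriggerCredentialTypeEnum.ApiKey,
})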
// ============================================================================
// Mock Setup
// ============================================================================
const mockInitiateOAuth = vi.fn()
const mockVerifyBuilder = vi.fn()
const mockConfigureOAuth = vi.fn()
const mockDeleteOAuth = vi.fn()
vi.mock('@/service/use-triggers', () => ({
useInitiateTriggerOAuth: () => ({
mutate: mockInitiateOAuth,
}),
useVerifyAndUpdateTriggerSubscriptionBuilder: () => ({
mutate: mockVerifyBuilder,
}),
useConfigureTriggerOAuth: () => ({
mutate: mockConfigureOAuth,
}),
useDeleteTriggerOAuth: () => ({
mutate: mockDeleteOAuth,
}),
}))
const mockOpenOAuthPopup = vi.fn()
vi.mock('@/hooks/use-oauth', () => ({
openOAuthPopup: (url: string, callback: (data: unknown) => void) => mockOpenOAuthPopup(url, callback),
}))
const mockToastNotify = vi.fn()
vi.mock('@/app/components/base/toast', () => ({
default: {
notify: (params: unknown) => mockToastNotify(params),
},
}))
// ============================================================================
// Test Suites
// ============================================================================
describe('getErrorMessage', () => {
it('should extract message from Error instance', () => {
const error = new Error('Test error message')
expect(getErrorMessage(error, 'fallback')).toBe('Test error message')
})
it('should extract message from object with message property', () => {
const error = { message: 'Object error message' }
expect(getErrorMessage(error, 'fallback')).toBe('Object error message')
})
it('should return fallback when error is empty object', () => {
expect(getErrorMessage({}, 'fallback')).toBe('fallback')
})
it('should return fallback when error.message is not a string', () => {
expect(getErrorMessage({ message: 123 }, 'fallback')).toBe('fallback')
})
it('should return fallback when error.message is empty string', () => {
expect(getErrorMessage({ message: '' }, 'fallback')).toBe('fallback')
})
it('should return fallback when error is null', () => {
expect(getErrorMessage(null, 'fallback')).toBe('fallback')
})
it('should return fallback when error is undefined', () => {
expect(getErrorMessage(undefined, 'fallback')).toBe('fallback')
})
it('should return fallback when error is a primitive', () => {
expect(getErrorMessage('string error', 'fallback')).toBe('fallback')
expect(getErrorMessage(123, 'fallback')).toBe('fallback')
})
})
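Taken together, these cases pin down the whole contract: a non-empty string message wins, everything else returns the fallback. A minimal implementation consistent with the tests, as a sketch rather than a copy of the real export:

// Sketch consistent with the cases above; the actual export lives in
// use-oauth-client-state.ts and may be written differently.
function getErrorMessageSketch(error: unknown, fallback: string): string {
  if (typeof error === 'object' && error !== null && 'message' in error) {
    const message = (error as { message: unknown }).message
    if (typeof message === 'string' && message.length > 0)
      return message
  }
  return fallback
}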
describe('useOAuthClientState', () => {
const defaultParams = {
oauthConfig: createMockOAuthConfig(),
providerName: 'test-provider',
onClose: vi.fn(),
showOAuthCreateModal: vi.fn(),
}
beforeEach(() => {
vi.clearAllMocks()
})
afterEach(() => {
vi.clearAllMocks()
})
describe('Initial State', () => {
it('should default to Default client type when system_configured is true', () => {
const { result } = renderHook(() => useOAuthClientState(defaultParams))
expect(result.current.clientType).toBe(ClientTypeEnum.Default)
})
it('should default to Custom client type when system_configured is false', () => {
const config = createMockOAuthConfig({ system_configured: false })
const { result } = renderHook(() => useOAuthClientState({
...defaultParams,
oauthConfig: config,
}))
expect(result.current.clientType).toBe(ClientTypeEnum.Custom)
})
it('should have undefined authorizationStatus initially', () => {
const { result } = renderHook(() => useOAuthClientState(defaultParams))
expect(result.current.authorizationStatus).toBeUndefined()
})
it('should provide clientFormRef', () => {
const { result } = renderHook(() => useOAuthClientState(defaultParams))
expect(result.current.clientFormRef).toBeDefined()
expect(result.current.clientFormRef.current).toBeNull()
})
})
describe('OAuth Client Schema', () => {
it('should compute schema with default values from params', () => {
const config = createMockOAuthConfig({
params: {
client_id: 'my-client-id',
client_secret: 'my-secret',
},
})
const { result } = renderHook(() => useOAuthClientState({
...defaultParams,
oauthConfig: config,
}))
expect(result.current.oauthClientSchema).toHaveLength(2)
expect(result.current.oauthClientSchema[0].default).toBe('my-client-id')
expect(result.current.oauthClientSchema[1].default).toBe('my-secret')
})
it('should return empty array when oauth_client_schema is empty', () => {
const config = createMockOAuthConfig({
oauth_client_schema: [],
})
const { result } = renderHook(() => useOAuthClientState({
...defaultParams,
oauthConfig: config,
}))
expect(result.current.oauthClientSchema).toEqual([])
})
it('should return empty array when params is undefined', () => {
const config = createMockOAuthConfig({
params: undefined as unknown as TriggerOAuthConfig['params'],
})
const { result } = renderHook(() => useOAuthClientState({
...defaultParams,
oauthConfig: config,
}))
expect(result.current.oauthClientSchema).toEqual([])
})
it('should preserve original schema default when param key not found', () => {
const config = createMockOAuthConfig({
params: {
client_id: 'only-client-id',
client_secret: '', // empty
},
oauth_client_schema: [
{ name: 'client_id', type: 'text-input' as unknown, required: true, label: {} as unknown, default: 'original-default' },
{ name: 'extra_field', type: 'text-input' as unknown, required: false, label: {} as unknown, default: 'extra-default' },
] as TriggerOAuthConfig['oauth_client_schema'],
})
const { result } = renderHook(() => useOAuthClientState({
...defaultParams,
oauthConfig: config,
}))
// client_id should be overridden
expect(result.current.oauthClientSchema[0].default).toBe('only-client-id')
// extra_field should keep original default since key not in params
expect(result.current.oauthClientSchema[1].default).toBe('extra-default')
})
})
describe('Confirm Button Text', () => {
it('should show saveAndAuth text by default', () => {
const { result } = renderHook(() => useOAuthClientState(defaultParams))
expect(result.current.confirmButtonText).toBe('plugin.auth.saveAndAuth')
})
it('should show authorizing text when status is Pending', async () => {
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
mockInitiateOAuth.mockImplementation(() => {
// Don't resolve - stays pending
})
const { result } = renderHook(() => useOAuthClientState(defaultParams))
act(() => {
result.current.handleSave(true)
})
await waitFor(() => {
expect(result.current.confirmButtonText).toBe('pluginTrigger.modal.common.authorizing')
})
})
})
describe('setClientType', () => {
it('should update client type when called', () => {
const { result } = renderHook(() => useOAuthClientState(defaultParams))
act(() => {
result.current.setClientType(ClientTypeEnum.Custom)
})
expect(result.current.clientType).toBe(ClientTypeEnum.Custom)
})
it('should toggle between client types', () => {
const { result } = renderHook(() => useOAuthClientState(defaultParams))
act(() => {
result.current.setClientType(ClientTypeEnum.Custom)
})
expect(result.current.clientType).toBe(ClientTypeEnum.Custom)
act(() => {
result.current.setClientType(ClientTypeEnum.Default)
})
expect(result.current.clientType).toBe(ClientTypeEnum.Default)
})
})
describe('handleRemove', () => {
it('should call deleteOAuth with provider name', () => {
const { result } = renderHook(() => useOAuthClientState(defaultParams))
act(() => {
result.current.handleRemove()
})
expect(mockDeleteOAuth).toHaveBeenCalledWith(
'test-provider',
expect.any(Object),
)
})
it('should call onClose and show success toast on success', () => {
mockDeleteOAuth.mockImplementation((provider, { onSuccess }) => onSuccess())
const onClose = vi.fn()
const { result } = renderHook(() => useOAuthClientState({
...defaultParams,
onClose,
}))
act(() => {
result.current.handleRemove()
})
expect(onClose).toHaveBeenCalled()
expect(mockToastNotify).toHaveBeenCalledWith({
type: 'success',
message: 'pluginTrigger.modal.oauth.remove.success',
})
})
it('should show error toast with error message on failure', () => {
mockDeleteOAuth.mockImplementation((provider, { onError }) => {
onError(new Error('Delete failed'))
})
const { result } = renderHook(() => useOAuthClientState(defaultParams))
act(() => {
result.current.handleRemove()
})
expect(mockToastNotify).toHaveBeenCalledWith({
type: 'error',
message: 'Delete failed',
})
})
})
describe('handleSave', () => {
it('should call configureOAuth with enabled: false for Default type', () => {
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
const { result } = renderHook(() => useOAuthClientState(defaultParams))
act(() => {
result.current.handleSave(false)
})
expect(mockConfigureOAuth).toHaveBeenCalledWith(
expect.objectContaining({
provider: 'test-provider',
enabled: false,
}),
expect.any(Object),
)
})
it('should call configureOAuth with enabled: true for Custom type', () => {
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
const config = createMockOAuthConfig({ system_configured: false })
const { result } = renderHook(() => useOAuthClientState({
...defaultParams,
oauthConfig: config,
}))
// Mock the form ref
const mockFormRef = {
getFormValues: () => ({
values: { client_id: 'new-id', client_secret: 'new-secret' },
isCheckValidated: true,
}),
}
// @ts-expect-error - mocking ref
result.current.clientFormRef.current = mockFormRef
act(() => {
result.current.handleSave(false)
})
expect(mockConfigureOAuth).toHaveBeenCalledWith(
expect.objectContaining({
enabled: true,
}),
expect.any(Object),
)
})
it('should show success toast and call onClose when needAuth is false', () => {
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
const onClose = vi.fn()
const { result } = renderHook(() => useOAuthClientState({
...defaultParams,
onClose,
}))
act(() => {
result.current.handleSave(false)
})
expect(onClose).toHaveBeenCalled()
expect(mockToastNotify).toHaveBeenCalledWith({
type: 'success',
message: 'pluginTrigger.modal.oauth.save.success',
})
})
it('should trigger authorization when needAuth is true', () => {
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => {
onSuccess({
authorization_url: 'https://oauth.example.com/authorize',
subscription_builder: createMockSubscriptionBuilder(),
})
})
const { result } = renderHook(() => useOAuthClientState(defaultParams))
act(() => {
result.current.handleSave(true)
})
expect(mockInitiateOAuth).toHaveBeenCalledWith(
'test-provider',
expect.any(Object),
)
})
})
describe('handleAuthorization', () => {
it('should set status to Pending and call initiateOAuth', () => {
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
mockInitiateOAuth.mockImplementation(() => {})
const { result } = renderHook(() => useOAuthClientState(defaultParams))
act(() => {
result.current.handleSave(true)
})
expect(result.current.authorizationStatus).toBe(AuthorizationStatusEnum.Pending)
expect(mockInitiateOAuth).toHaveBeenCalled()
})
it('should open OAuth popup on success', () => {
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => {
onSuccess({
authorization_url: 'https://oauth.example.com/authorize',
subscription_builder: createMockSubscriptionBuilder(),
})
})
const { result } = renderHook(() => useOAuthClientState(defaultParams))
act(() => {
result.current.handleSave(true)
})
expect(mockOpenOAuthPopup).toHaveBeenCalledWith(
'https://oauth.example.com/authorize',
expect.any(Function),
)
})
it('should set status to Failed and show error toast on error', () => {
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
mockInitiateOAuth.mockImplementation((provider, { onError }) => {
onError(new Error('OAuth failed'))
})
const { result } = renderHook(() => useOAuthClientState(defaultParams))
act(() => {
result.current.handleSave(true)
})
expect(result.current.authorizationStatus).toBe(AuthorizationStatusEnum.Failed)
expect(mockToastNotify).toHaveBeenCalledWith({
type: 'error',
message: 'pluginTrigger.modal.oauth.authorization.authFailed',
})
})
it('should call onClose and showOAuthCreateModal on callback success', () => {
const onClose = vi.fn()
const showOAuthCreateModal = vi.fn()
const builder = createMockSubscriptionBuilder()
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => {
onSuccess({
authorization_url: 'https://oauth.example.com/authorize',
subscription_builder: builder,
})
})
mockOpenOAuthPopup.mockImplementation((url, callback) => {
callback({ success: true })
})
const { result } = renderHook(() => useOAuthClientState({
...defaultParams,
onClose,
showOAuthCreateModal,
}))
act(() => {
result.current.handleSave(true)
})
expect(onClose).toHaveBeenCalled()
expect(showOAuthCreateModal).toHaveBeenCalledWith(builder)
expect(mockToastNotify).toHaveBeenCalledWith({
type: 'success',
message: 'pluginTrigger.modal.oauth.authorization.authSuccess',
})
})
it('should not call callbacks when OAuth callback returns falsy', () => {
const onClose = vi.fn()
const showOAuthCreateModal = vi.fn()
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => {
onSuccess({
authorization_url: 'https://oauth.example.com/authorize',
subscription_builder: createMockSubscriptionBuilder(),
})
})
mockOpenOAuthPopup.mockImplementation((url, callback) => {
callback(null)
})
const { result } = renderHook(() => useOAuthClientState({
...defaultParams,
onClose,
showOAuthCreateModal,
}))
act(() => {
result.current.handleSave(true)
})
expect(onClose).not.toHaveBeenCalled()
expect(showOAuthCreateModal).not.toHaveBeenCalled()
})
})
describe('Polling Effect', () => {
it('should start polling after authorization starts', async () => {
vi.useFakeTimers({ shouldAdvanceTime: true })
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => {
onSuccess({
authorization_url: 'https://oauth.example.com/authorize',
subscription_builder: createMockSubscriptionBuilder(),
})
})
mockVerifyBuilder.mockImplementation((params, { onSuccess }) => {
onSuccess({ verified: false })
})
const { result } = renderHook(() => useOAuthClientState(defaultParams))
act(() => {
result.current.handleSave(true)
})
// Advance timer to trigger first poll
await act(async () => {
vi.advanceTimersByTime(3000)
})
expect(mockVerifyBuilder).toHaveBeenCalled()
vi.useRealTimers()
})
it('should set status to Success when verified', async () => {
vi.useFakeTimers({ shouldAdvanceTime: true })
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => {
onSuccess({
authorization_url: 'https://oauth.example.com/authorize',
subscription_builder: createMockSubscriptionBuilder(),
})
})
mockVerifyBuilder.mockImplementation((params, { onSuccess }) => {
onSuccess({ verified: true })
})
const { result } = renderHook(() => useOAuthClientState(defaultParams))
act(() => {
result.current.handleSave(true)
})
await act(async () => {
vi.advanceTimersByTime(3000)
})
await waitFor(() => {
expect(result.current.authorizationStatus).toBe(AuthorizationStatusEnum.Success)
})
vi.useRealTimers()
})
it('should continue polling on error', async () => {
vi.useFakeTimers({ shouldAdvanceTime: true })
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => {
onSuccess({
authorization_url: 'https://oauth.example.com/authorize',
subscription_builder: createMockSubscriptionBuilder(),
})
})
mockVerifyBuilder.mockImplementation((params, { onError }) => {
onError(new Error('Verify failed'))
})
const { result } = renderHook(() => useOAuthClientState(defaultParams))
act(() => {
result.current.handleSave(true)
})
await act(async () => {
vi.advanceTimersByTime(3000)
})
expect(mockVerifyBuilder).toHaveBeenCalled()
// Status should still be Pending
expect(result.current.authorizationStatus).toBe(AuthorizationStatusEnum.Pending)
vi.useRealTimers()
})
it('should stop polling when verified', async () => {
vi.useFakeTimers({ shouldAdvanceTime: true })
mockConfigureOAuth.mockImplementation((params, { onSuccess }) => onSuccess())
mockInitiateOAuth.mockImplementation((provider, { onSuccess }) => {
onSuccess({
authorization_url: 'https://oauth.example.com/authorize',
subscription_builder: createMockSubscriptionBuilder(),
})
})
mockVerifyBuilder.mockImplementation((params, { onSuccess }) => {
onSuccess({ verified: true })
})
const { result } = renderHook(() => useOAuthClientState(defaultParams))
act(() => {
result.current.handleSave(true)
})
// First poll - should verify
await act(async () => {
vi.advanceTimersByTime(3000)
})
expect(mockVerifyBuilder).toHaveBeenCalledTimes(1)
// Second poll - should not happen as interval is cleared
await act(async () => {
vi.advanceTimersByTime(3000)
})
// Still only 1 call because polling stopped
expect(mockVerifyBuilder).toHaveBeenCalledTimes(1)
vi.useRealTimers()
})
})
describe('Edge Cases', () => {
it('should handle undefined oauthConfig', () => {
const { result } = renderHook(() => useOAuthClientState({
...defaultParams,
oauthConfig: undefined,
}))
expect(result.current.clientType).toBe(ClientTypeEnum.Custom)
expect(result.current.oauthClientSchema).toEqual([])
})
it('should handle empty providerName', () => {
const { result } = renderHook(() => useOAuthClientState({
...defaultParams,
providerName: '',
}))
// Should not throw
expect(result.current.clientType).toBe(ClientTypeEnum.Default)
})
})
})
describe('Enum Exports', () => {
it('should export AuthorizationStatusEnum', () => {
expect(AuthorizationStatusEnum.Pending).toBe('pending')
expect(AuthorizationStatusEnum.Success).toBe('success')
expect(AuthorizationStatusEnum.Failed).toBe('failed')
})
it('should export ClientTypeEnum', () => {
expect(ClientTypeEnum.Default).toBe('default')
expect(ClientTypeEnum.Custom).toBe('custom')
})
})
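
The polling behavior these suites exercise reduces to a small pattern: re-check verification every 3 seconds, swallow errors, and stop once verified. Below is a distilled, promise-based sketch of that pattern; the hook under test actually uses mutation callbacks, and the function and parameter names here are illustrative, not part of the codebase.

// Distilled sketch of the verification-polling pattern tested above.
// verify() stands in for the verifyBuilder mutation; names are illustrative.
const POLL_INTERVAL_MS = 3000

function startVerificationPolling(
  verify: () => Promise<{ verified: boolean }>,
  onVerified: () => void,
): () => void {
  const pollInterval = setInterval(async () => {
    try {
      const { verified } = await verify()
      if (verified) {
        onVerified()
        clearInterval(pollInterval) // stop polling once verified
      }
    }
    catch {
      // Keep polling on error - authorization may still be in progress
    }
  }, POLL_INTERVAL_MS)
  return () => clearInterval(pollInterval) // caller-side cleanup, like the effect teardown
}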

View File

@ -1,241 +0,0 @@
'use client'
import type { FormRefObject } from '@/app/components/base/form/types'
import type { TriggerOAuthClientParams, TriggerOAuthConfig, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
import type { ConfigureTriggerOAuthPayload } from '@/service/use-triggers'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Toast from '@/app/components/base/toast'
import { openOAuthPopup } from '@/hooks/use-oauth'
import {
useConfigureTriggerOAuth,
useDeleteTriggerOAuth,
useInitiateTriggerOAuth,
useVerifyAndUpdateTriggerSubscriptionBuilder,
} from '@/service/use-triggers'
export enum AuthorizationStatusEnum {
Pending = 'pending',
Success = 'success',
Failed = 'failed',
}
export enum ClientTypeEnum {
Default = 'default',
Custom = 'custom',
}
const POLL_INTERVAL_MS = 3000
// Extract error message from various error formats
export const getErrorMessage = (error: unknown, fallback: string): string => {
if (error instanceof Error && error.message)
return error.message
if (typeof error === 'object' && error && 'message' in error) {
const message = (error as { message?: string }).message
if (typeof message === 'string' && message)
return message
}
return fallback
}
type UseOAuthClientStateParams = {
oauthConfig?: TriggerOAuthConfig
providerName: string
onClose: () => void
showOAuthCreateModal: (builder: TriggerSubscriptionBuilder) => void
}
type UseOAuthClientStateReturn = {
// State
clientType: ClientTypeEnum
setClientType: (type: ClientTypeEnum) => void
authorizationStatus: AuthorizationStatusEnum | undefined
// Refs
clientFormRef: React.RefObject<FormRefObject | null>
// Computed values
oauthClientSchema: TriggerOAuthConfig['oauth_client_schema']
confirmButtonText: string
// Handlers
handleAuthorization: () => void
handleRemove: () => void
handleSave: (needAuth: boolean) => void
}
export const useOAuthClientState = ({
oauthConfig,
providerName,
onClose,
showOAuthCreateModal,
}: UseOAuthClientStateParams): UseOAuthClientStateReturn => {
const { t } = useTranslation()
// State management
const [subscriptionBuilder, setSubscriptionBuilder] = useState<TriggerSubscriptionBuilder | undefined>()
const [authorizationStatus, setAuthorizationStatus] = useState<AuthorizationStatusEnum>()
const [clientType, setClientType] = useState<ClientTypeEnum>(
oauthConfig?.system_configured ? ClientTypeEnum.Default : ClientTypeEnum.Custom,
)
const clientFormRef = useRef<FormRefObject>(null)
// Mutations
const { mutate: initiateOAuth } = useInitiateTriggerOAuth()
const { mutate: verifyBuilder } = useVerifyAndUpdateTriggerSubscriptionBuilder()
const { mutate: configureOAuth } = useConfigureTriggerOAuth()
const { mutate: deleteOAuth } = useDeleteTriggerOAuth()
// Compute OAuth client schema with default values
const oauthClientSchema = useMemo(() => {
const { oauth_client_schema, params } = oauthConfig || {}
if (!oauth_client_schema?.length || !params)
return []
const paramKeys = Object.keys(params)
return oauth_client_schema.map(schema => ({
...schema,
default: paramKeys.includes(schema.name) ? params[schema.name] : schema.default,
}))
}, [oauthConfig])
// Compute confirm button text based on authorization status
const confirmButtonText = useMemo(() => {
if (authorizationStatus === AuthorizationStatusEnum.Pending)
return t('modal.common.authorizing', { ns: 'pluginTrigger' })
if (authorizationStatus === AuthorizationStatusEnum.Success)
return t('modal.oauth.authorization.waitingJump', { ns: 'pluginTrigger' })
return t('auth.saveAndAuth', { ns: 'plugin' })
}, [authorizationStatus, t])
// Authorization handler
const handleAuthorization = useCallback(() => {
setAuthorizationStatus(AuthorizationStatusEnum.Pending)
initiateOAuth(providerName, {
onSuccess: (response) => {
setSubscriptionBuilder(response.subscription_builder)
openOAuthPopup(response.authorization_url, (callbackData) => {
if (!callbackData)
return
Toast.notify({
type: 'success',
message: t('modal.oauth.authorization.authSuccess', { ns: 'pluginTrigger' }),
})
onClose()
showOAuthCreateModal(response.subscription_builder)
})
},
onError: () => {
setAuthorizationStatus(AuthorizationStatusEnum.Failed)
Toast.notify({
type: 'error',
message: t('modal.oauth.authorization.authFailed', { ns: 'pluginTrigger' }),
})
},
})
}, [providerName, initiateOAuth, onClose, showOAuthCreateModal, t])
// Remove handler
const handleRemove = useCallback(() => {
deleteOAuth(providerName, {
onSuccess: () => {
onClose()
Toast.notify({
type: 'success',
message: t('modal.oauth.remove.success', { ns: 'pluginTrigger' }),
})
},
onError: (error: unknown) => {
Toast.notify({
type: 'error',
message: getErrorMessage(error, t('modal.oauth.remove.failed', { ns: 'pluginTrigger' })),
})
},
})
}, [providerName, deleteOAuth, onClose, t])
// Save handler
const handleSave = useCallback((needAuth: boolean) => {
const isCustom = clientType === ClientTypeEnum.Custom
const params: ConfigureTriggerOAuthPayload = {
provider: providerName,
enabled: isCustom,
}
if (isCustom && oauthClientSchema?.length) {
const clientFormValues = clientFormRef.current?.getFormValues({}) as {
values: TriggerOAuthClientParams
isCheckValidated: boolean
} | undefined
// Handle missing ref or form values
if (!clientFormValues || !clientFormValues.isCheckValidated)
return
const clientParams = { ...clientFormValues.values }
// Preserve hidden values if unchanged
if (clientParams.client_id === oauthConfig?.params.client_id)
clientParams.client_id = '[__HIDDEN__]'
if (clientParams.client_secret === oauthConfig?.params.client_secret)
clientParams.client_secret = '[__HIDDEN__]'
params.client_params = clientParams
}
configureOAuth(params, {
onSuccess: () => {
if (needAuth) {
handleAuthorization()
return
}
onClose()
Toast.notify({
type: 'success',
message: t('modal.oauth.save.success', { ns: 'pluginTrigger' }),
})
},
})
}, [clientType, providerName, oauthClientSchema, oauthConfig?.params, configureOAuth, handleAuthorization, onClose, t])
// Polling effect for authorization verification
useEffect(() => {
const shouldPoll = providerName
&& subscriptionBuilder
&& authorizationStatus === AuthorizationStatusEnum.Pending
if (!shouldPoll)
return
const pollInterval = setInterval(() => {
verifyBuilder(
{
provider: providerName,
subscriptionBuilderId: subscriptionBuilder.id,
},
{
onSuccess: (response) => {
if (response.verified) {
setAuthorizationStatus(AuthorizationStatusEnum.Success)
clearInterval(pollInterval)
}
},
onError: () => {
// Continue polling on error - auth might still be in progress
},
},
)
}, POLL_INTERVAL_MS)
return () => clearInterval(pollInterval)
}, [subscriptionBuilder, authorizationStatus, verifyBuilder, providerName])
return {
clientType,
setClientType,
authorizationStatus,
clientFormRef,
oauthClientSchema,
confirmButtonText,
handleAuthorization,
handleRemove,
handleSave,
}
}
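
For orientation, the sketch below shows one way a modal component could consume this hook. The component name, markup, and button wiring are illustrative assumptions; the actual consumer, oauth-client.tsx, appears later in this diff.

'use client'
// Illustrative consumer sketch only - not the real OAuthClientSettingsModal.
import type { TriggerOAuthConfig, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
import { ClientTypeEnum, useOAuthClientState } from './hooks/use-oauth-client-state'

type ExampleProps = {
  oauthConfig?: TriggerOAuthConfig
  providerName: string
  onClose: () => void
  showOAuthCreateModal: (builder: TriggerSubscriptionBuilder) => void
}

const ExampleOAuthSettings = ({ oauthConfig, providerName, onClose, showOAuthCreateModal }: ExampleProps) => {
  const {
    clientType,
    setClientType,
    confirmButtonText,
    handleRemove,
    handleSave,
  } = useOAuthClientState({ oauthConfig, providerName, onClose, showOAuthCreateModal })

  return (
    <div>
      {/* Switch between the system-configured client and a custom one */}
      <button onClick={() => setClientType(ClientTypeEnum.Custom)}>Use custom client</button>
      {/* handleSave(true) persists the client config, then launches the OAuth popup and polling */}
      <button onClick={() => handleSave(true)}>{confirmButtonText}</button>
      {/* Removal only applies to a configured custom client */}
      {clientType === ClientTypeEnum.Custom && (
        <button onClick={handleRemove}>Remove</button>
      )}
    </div>
  )
}

export default ExampleOAuthSettings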

View File

@ -6,6 +6,9 @@ import { SupportedCreationMethods } from '@/app/components/plugins/types'
import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types'
import { CreateButtonType, CreateSubscriptionButton, DEFAULT_METHOD } from './index'
// ==================== Mock Setup ====================
// Mock shared state for portal
let mockPortalOpenState = false
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
@ -33,18 +36,21 @@ vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
},
}))
// Mock Toast
vi.mock('@/app/components/base/toast', () => ({
default: {
notify: vi.fn(),
},
}))
// Mock zustand store
let mockStoreDetail: SimpleDetail | undefined
vi.mock('../../store', () => ({
usePluginStore: (selector: (state: { detail: SimpleDetail | undefined }) => SimpleDetail | undefined) =>
selector({ detail: mockStoreDetail }),
}))
// Mock subscription list hook
const mockSubscriptions: TriggerSubscription[] = []
const mockRefetch = vi.fn()
vi.mock('../use-subscription-list', () => ({
@ -54,6 +60,7 @@ vi.mock('../use-subscription-list', () => ({
}),
}))
// Mock trigger service hooks
let mockProviderInfo: { data: TriggerProviderApiEntity | undefined } = { data: undefined }
let mockOAuthConfig: { data: TriggerOAuthConfig | undefined, refetch: () => void } = { data: undefined, refetch: vi.fn() }
const mockInitiateOAuth = vi.fn()
@ -66,12 +73,14 @@ vi.mock('@/service/use-triggers', () => ({
}),
}))
// Mock OAuth popup
vi.mock('@/hooks/use-oauth', () => ({
openOAuthPopup: vi.fn((url: string, callback: (data?: unknown) => void) => {
callback({ success: true, subscriptionId: 'test-subscription' })
}),
}))
// Mock child modals
vi.mock('./common-modal', () => ({
CommonCreateModal: ({ createType, onClose, builder }: {
createType: SupportedCreationMethods
@ -119,6 +128,7 @@ vi.mock('./oauth-client', () => ({
),
}))
// Mock CustomSelect
vi.mock('@/app/components/base/select/custom', () => ({
default: ({ options, value, onChange, CustomTrigger, CustomOption, containerProps }: {
options: Array<{ value: string, label: string, show: boolean, extra?: React.ReactNode, tag?: React.ReactNode }>
@ -150,6 +160,11 @@ vi.mock('@/app/components/base/select/custom', () => ({
),
}))
// ==================== Test Utilities ====================
/**
* Factory function to create a TriggerProviderApiEntity with defaults
*/
const createProviderInfo = (overrides: Partial<TriggerProviderApiEntity> = {}): TriggerProviderApiEntity => ({
author: 'test-author',
name: 'test-provider',
@ -164,6 +179,9 @@ const createProviderInfo = (overrides: Partial<TriggerProviderApiEntity> = {}):
...overrides,
})
/**
* Factory function to create a TriggerOAuthConfig with defaults
*/
const createOAuthConfig = (overrides: Partial<TriggerOAuthConfig> = {}): TriggerOAuthConfig => ({
configured: false,
custom_configured: false,
@ -178,6 +196,9 @@ const createOAuthConfig = (overrides: Partial<TriggerOAuthConfig> = {}): Trigger
...overrides,
})
/**
* Factory function to create a SimpleDetail with defaults
*/
const createStoreDetail = (overrides: Partial<SimpleDetail> = {}): SimpleDetail => ({
plugin_id: 'test-plugin',
name: 'Test Plugin',
@ -188,6 +209,9 @@ const createStoreDetail = (overrides: Partial<SimpleDetail> = {}): SimpleDetail
...overrides,
})
/**
* Factory function to create a TriggerSubscription with defaults
*/
const createSubscription = (overrides: Partial<TriggerSubscription> = {}): TriggerSubscription => ({
id: 'test-subscription',
name: 'Test Subscription',
@ -201,10 +225,16 @@ const createSubscription = (overrides: Partial<TriggerSubscription> = {}): Trigg
...overrides,
})
/**
* Factory function to create default props
*/
const createDefaultProps = (overrides: Partial<Parameters<typeof CreateSubscriptionButton>[0]> = {}) => ({
...overrides,
})
/**
* Helper to set up mock data for testing
*/
const setupMocks = (config: {
providerInfo?: TriggerProviderApiEntity
oauthConfig?: TriggerOAuthConfig
@ -219,6 +249,8 @@ const setupMocks = (config: {
mockSubscriptions.push(...config.subscriptions)
}
// ==================== Tests ====================
describe('CreateSubscriptionButton', () => {
beforeEach(() => {
vi.clearAllMocks()
@ -226,6 +258,7 @@ describe('CreateSubscriptionButton', () => {
setupMocks()
})
// ==================== Rendering Tests ====================
describe('Rendering', () => {
it('should render null when supportedMethods is empty', () => {
// Arrange
@ -289,6 +322,7 @@ describe('CreateSubscriptionButton', () => {
})
})
// ==================== Props Testing ====================
describe('Props', () => {
it('should apply default buttonType as FULL_BUTTON', () => {
// Arrange
@ -321,6 +355,7 @@ describe('CreateSubscriptionButton', () => {
})
})
// ==================== State Management ====================
describe('State Management', () => {
it('should show CommonCreateModal when selectedCreateInfo is set', async () => {
// Arrange
@ -439,6 +474,7 @@ describe('CreateSubscriptionButton', () => {
})
})
// ==================== Memoization Logic ====================
describe('Memoization - buttonTextMap', () => {
it('should display correct button text for OAUTH method', () => {
// Arrange

View File

@ -2,7 +2,7 @@ import type { Option } from '@/app/components/base/select/custom'
import type { TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
import { RiAddLine, RiEqualizer2Line } from '@remixicon/react'
import { useBoolean } from 'ahooks'
import { useCallback, useMemo, useState } from 'react'
import { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { ActionButton, ActionButtonState } from '@/app/components/base/action-button'
import Badge from '@/app/components/base/badge'
@ -18,7 +18,11 @@ import { usePluginStore } from '../../store'
import { useSubscriptionList } from '../use-subscription-list'
import { CommonCreateModal } from './common-modal'
import { OAuthClientSettingsModal } from './oauth-client'
import { CreateButtonType, DEFAULT_METHOD } from './types'
export enum CreateButtonType {
FULL_BUTTON = 'full-button',
ICON_BUTTON = 'icon-button',
}
type Props = {
className?: string
@ -28,6 +32,8 @@ type Props = {
const MAX_COUNT = 10
export const DEFAULT_METHOD = 'default'
export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BUTTON, shape = 'square' }: Props) => {
const { t } = useTranslation()
const { subscriptions } = useSubscriptionList()
@ -37,7 +43,7 @@ export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BU
const detail = usePluginStore(state => state.detail)
const { data: providerInfo } = useTriggerProviderInfo(detail?.provider || '')
const supportedMethods = useMemo(() => providerInfo?.supported_creation_methods || [], [providerInfo?.supported_creation_methods])
const supportedMethods = providerInfo?.supported_creation_methods || []
const { data: oauthConfig, refetch: refetchOAuthConfig } = useTriggerOAuthConfig(detail?.provider || '', supportedMethods.includes(SupportedCreationMethods.OAUTH))
const { mutate: initiateOAuth } = useInitiateTriggerOAuth()
@ -57,11 +63,11 @@ export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BU
}
}, [t])
const onClickClientSettings = useCallback((e: React.MouseEvent<HTMLDivElement | HTMLButtonElement>) => {
const onClickClientSettings = (e: React.MouseEvent<HTMLDivElement | HTMLButtonElement>) => {
e.stopPropagation()
e.preventDefault()
showClientSettingsModal()
}, [showClientSettingsModal])
}
const allOptions = useMemo(() => {
const showCustomBadge = oauthConfig?.custom_enabled && oauthConfig?.custom_configured
@ -98,7 +104,7 @@ export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BU
show: supportedMethods.includes(SupportedCreationMethods.MANUAL),
},
]
}, [t, oauthConfig, supportedMethods, methodType, onClickClientSettings])
}, [t, oauthConfig, supportedMethods, methodType])
const onChooseCreateType = async (type: SupportedCreationMethods) => {
if (type === SupportedCreationMethods.OAUTH) {
@ -154,7 +160,7 @@ export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BU
<CustomSelect<Option & { show: boolean, extra?: React.ReactNode, tag?: React.ReactNode }>
options={allOptions.filter(option => option.show)}
value={methodType}
onChange={value => onChooseCreateType(value as SupportedCreationMethods)}
onChange={value => onChooseCreateType(value as any)}
containerProps={{
open: (methodType === DEFAULT_METHOD || (methodType === SupportedCreationMethods.OAUTH && supportedMethods.length === 1)) ? undefined : false,
placement: 'bottom-start',
@ -248,5 +254,3 @@ export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BU
</>
)
}
export { CreateButtonType, DEFAULT_METHOD } from './types'

View File

@ -3,14 +3,24 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types'
// Import after mocks
import { OAuthClientSettingsModal } from './oauth-client'
// ============================================================================
// Type Definitions
// ============================================================================
type PluginDetail = {
plugin_id: string
provider: string
name: string
}
// ============================================================================
// Mock Factory Functions
// ============================================================================
function createMockOAuthConfig(overrides: Partial<TriggerOAuthConfig> = {}): TriggerOAuthConfig {
return {
configured: true,
@ -54,12 +64,18 @@ function createMockSubscriptionBuilder(overrides: Partial<TriggerSubscriptionBui
}
}
// ============================================================================
// Mock Setup
// ============================================================================
// Mock plugin store
const mockPluginDetail = createMockPluginDetail()
const mockUsePluginStore = vi.fn(() => mockPluginDetail)
vi.mock('../../store', () => ({
usePluginStore: () => mockUsePluginStore(),
}))
// Mock service hooks
const mockInitiateOAuth = vi.fn()
const mockVerifyBuilder = vi.fn()
const mockConfigureOAuth = vi.fn()
@ -80,11 +96,13 @@ vi.mock('@/service/use-triggers', () => ({
}),
}))
// Mock OAuth popup
const mockOpenOAuthPopup = vi.fn()
vi.mock('@/hooks/use-oauth', () => ({
openOAuthPopup: (url: string, callback: (data: unknown) => void) => mockOpenOAuthPopup(url, callback),
}))
// Mock toast
const mockToastNotify = vi.fn()
vi.mock('@/app/components/base/toast', () => ({
default: {
@ -92,6 +110,7 @@ vi.mock('@/app/components/base/toast', () => ({
},
}))
// Mock clipboard API
const mockClipboardWriteText = vi.fn()
Object.assign(navigator, {
clipboard: {
@ -99,6 +118,7 @@ Object.assign(navigator, {
},
})
// Mock Modal component
vi.mock('@/app/components/base/modal/modal', () => ({
default: ({
children,
@ -141,6 +161,24 @@ vi.mock('@/app/components/base/modal/modal', () => ({
),
}))
// Mock Button component
vi.mock('@/app/components/base/button', () => ({
default: ({ children, onClick, variant, className }: {
children: React.ReactNode
onClick?: () => void
variant?: string
className?: string
}) => (
<button
data-testid={`button-${variant || 'default'}`}
onClick={onClick}
className={className}
>
{children}
</button>
),
}))
// Configurable form mock values
let mockFormValues: { values: Record<string, string>, isCheckValidated: boolean } = {
values: { client_id: 'test-client-id', client_secret: 'test-client-secret' },
isCheckValidated: true,
@ -172,6 +210,29 @@ vi.mock('@/app/components/base/form/components/base', () => ({
}),
}))
// Mock OptionCard component
vi.mock('@/app/components/workflow/nodes/_base/components/option-card', () => ({
default: ({ title, onSelect, selected, className }: {
title: string
onSelect: () => void
selected: boolean
className?: string
}) => (
<div
data-testid={`option-card-${title}`}
onClick={onSelect}
className={`${className} ${selected ? 'selected' : ''}`}
data-selected={selected}
>
{title}
</div>
),
}))
// ============================================================================
// Test Suites
// ============================================================================
describe('OAuthClientSettingsModal', () => {
const defaultProps = {
oauthConfig: createMockOAuthConfig(),
@ -183,6 +244,7 @@ describe('OAuthClientSettingsModal', () => {
vi.clearAllMocks()
mockUsePluginStore.mockReturnValue(mockPluginDetail)
mockClipboardWriteText.mockResolvedValue(undefined)
// Reset form values to default
setMockFormValues({
values: { client_id: 'test-client-id', client_secret: 'test-client-secret' },
isCheckValidated: true,
@ -203,8 +265,8 @@ describe('OAuthClientSettingsModal', () => {
it('should render client type selector when system_configured is true', () => {
render(<OAuthClientSettingsModal {...defaultProps} />)
expect(screen.getByText('pluginTrigger.subscription.addType.options.oauth.default')).toBeInTheDocument()
expect(screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom')).toBeInTheDocument()
expect(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.default')).toBeInTheDocument()
expect(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom')).toBeInTheDocument()
})
it('should not render client type selector when system_configured is false', () => {
@ -214,7 +276,7 @@ describe('OAuthClientSettingsModal', () => {
render(<OAuthClientSettingsModal {...defaultProps} oauthConfig={configWithoutSystemConfigured} />)
expect(screen.queryByText('pluginTrigger.subscription.addType.options.oauth.default')).not.toBeInTheDocument()
expect(screen.queryByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.default')).not.toBeInTheDocument()
})
it('should render redirect URI info when custom client type is selected', () => {
@ -257,29 +319,29 @@ describe('OAuthClientSettingsModal', () => {
it('should default to Default client type when system_configured is true', () => {
render(<OAuthClientSettingsModal {...defaultProps} />)
const defaultCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.default').closest('div')
expect(defaultCard).toHaveClass('border-[1.5px]')
const defaultCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.default')
expect(defaultCard).toHaveAttribute('data-selected', 'true')
})
it('should switch to Custom client type when Custom card is clicked', () => {
render(<OAuthClientSettingsModal {...defaultProps} />)
const customCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')
fireEvent.click(customCard!)
const customCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom')
fireEvent.click(customCard)
expect(customCard).toHaveClass('border-[1.5px]')
expect(customCard).toHaveAttribute('data-selected', 'true')
})
it('should switch back to Default client type when Default card is clicked', () => {
render(<OAuthClientSettingsModal {...defaultProps} />)
const customCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')
fireEvent.click(customCard!)
const customCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom')
fireEvent.click(customCard)
const defaultCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.default').closest('div')
fireEvent.click(defaultCard!)
const defaultCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.default')
fireEvent.click(defaultCard)
expect(defaultCard).toHaveClass('border-[1.5px]')
expect(defaultCard).toHaveAttribute('data-selected', 'true')
})
})
@ -790,8 +852,8 @@ describe('OAuthClientSettingsModal', () => {
render(<OAuthClientSettingsModal {...defaultProps} />)
// Switch to custom
const customCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')
fireEvent.click(customCard!)
const customCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom')
fireEvent.click(customCard)
fireEvent.click(screen.getByTestId('modal-cancel'))
@ -992,7 +1054,7 @@ describe('OAuthClientSettingsModal', () => {
render(<OAuthClientSettingsModal {...defaultProps} />)
// Switch to custom type
const customCard = screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')!
const customCard = screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom')
fireEvent.click(customCard)
fireEvent.click(screen.getByTestId('modal-cancel'))
@ -1015,7 +1077,7 @@ describe('OAuthClientSettingsModal', () => {
render(<OAuthClientSettingsModal {...defaultProps} />)
// Switch to custom type
fireEvent.click(screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')!)
fireEvent.click(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom'))
fireEvent.click(screen.getByTestId('modal-cancel'))
@ -1042,7 +1104,7 @@ describe('OAuthClientSettingsModal', () => {
render(<OAuthClientSettingsModal {...defaultProps} />)
// Switch to custom type
fireEvent.click(screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')!)
fireEvent.click(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom'))
fireEvent.click(screen.getByTestId('modal-cancel'))
@ -1069,7 +1131,7 @@ describe('OAuthClientSettingsModal', () => {
render(<OAuthClientSettingsModal {...defaultProps} />)
// Switch to custom type
fireEvent.click(screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')!)
fireEvent.click(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom'))
fireEvent.click(screen.getByTestId('modal-cancel'))
@ -1096,7 +1158,7 @@ describe('OAuthClientSettingsModal', () => {
render(<OAuthClientSettingsModal {...defaultProps} />)
// Switch to custom type
fireEvent.click(screen.getByText('pluginTrigger.subscription.addType.options.oauth.custom').closest('div')!)
fireEvent.click(screen.getByTestId('option-card-pluginTrigger.subscription.addType.options.oauth.custom'))
fireEvent.click(screen.getByTestId('modal-cancel'))

View File

@ -1,17 +1,27 @@
'use client'
import type { TriggerOAuthConfig, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
import type { FormRefObject } from '@/app/components/base/form/types'
import type { TriggerOAuthClientParams, TriggerOAuthConfig, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
import type { ConfigureTriggerOAuthPayload } from '@/service/use-triggers'
import {
RiClipboardLine,
RiInformation2Fill,
} from '@remixicon/react'
import * as React from 'react'
import { useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import { BaseForm } from '@/app/components/base/form/components/base'
import Modal from '@/app/components/base/modal/modal'
import Toast from '@/app/components/base/toast'
import OptionCard from '@/app/components/workflow/nodes/_base/components/option-card'
import { openOAuthPopup } from '@/hooks/use-oauth'
import {
useConfigureTriggerOAuth,
useDeleteTriggerOAuth,
useInitiateTriggerOAuth,
useVerifyAndUpdateTriggerSubscriptionBuilder,
} from '@/service/use-triggers'
import { usePluginStore } from '../../store'
import { ClientTypeEnum, useOAuthClientState } from './hooks/use-oauth-client-state'
type Props = {
oauthConfig?: TriggerOAuthConfig
@ -19,38 +29,169 @@ type Props = {
showOAuthCreateModal: (builder: TriggerSubscriptionBuilder) => void
}
const CLIENT_TYPE_OPTIONS = [ClientTypeEnum.Default, ClientTypeEnum.Custom] as const
enum AuthorizationStatusEnum {
Pending = 'pending',
Success = 'success',
Failed = 'failed',
}
enum ClientTypeEnum {
Default = 'default',
Custom = 'custom',
}
export const OAuthClientSettingsModal = ({ oauthConfig, onClose, showOAuthCreateModal }: Props) => {
const { t } = useTranslation()
const detail = usePluginStore(state => state.detail)
const { system_configured, params, oauth_client_schema } = oauthConfig || {}
const [subscriptionBuilder, setSubscriptionBuilder] = useState<TriggerSubscriptionBuilder | undefined>()
const [authorizationStatus, setAuthorizationStatus] = useState<AuthorizationStatusEnum>()
const [clientType, setClientType] = useState<ClientTypeEnum>(system_configured ? ClientTypeEnum.Default : ClientTypeEnum.Custom)
const clientFormRef = React.useRef<FormRefObject>(null)
const oauthClientSchema = useMemo(() => {
if (oauth_client_schema && oauth_client_schema.length > 0 && params) {
const oauthConfigParamKeys = Object.keys(params || {})
for (const schema of oauth_client_schema) {
if (oauthConfigParamKeys.includes(schema.name))
schema.default = params?.[schema.name]
}
return oauth_client_schema
}
return []
}, [oauth_client_schema, params])
const providerName = detail?.provider || ''
const { mutate: initiateOAuth } = useInitiateTriggerOAuth()
const { mutate: verifyBuilder } = useVerifyAndUpdateTriggerSubscriptionBuilder()
const { mutate: configureOAuth } = useConfigureTriggerOAuth()
const { mutate: deleteOAuth } = useDeleteTriggerOAuth()
const {
clientType,
setClientType,
clientFormRef,
oauthClientSchema,
confirmButtonText,
handleRemove,
handleSave,
} = useOAuthClientState({
oauthConfig,
providerName,
onClose,
showOAuthCreateModal,
})
const confirmButtonText = useMemo(() => {
if (authorizationStatus === AuthorizationStatusEnum.Pending)
return t('modal.common.authorizing', { ns: 'pluginTrigger' })
if (authorizationStatus === AuthorizationStatusEnum.Success)
return t('modal.oauth.authorization.waitingJump', { ns: 'pluginTrigger' })
return t('auth.saveAndAuth', { ns: 'plugin' })
}, [authorizationStatus, t])
const isCustomClient = clientType === ClientTypeEnum.Custom
const showRemoveButton = oauthConfig?.custom_enabled && oauthConfig?.params && isCustomClient
const showRedirectInfo = isCustomClient && oauthConfig?.redirect_uri
const showClientForm = isCustomClient && oauthClientSchema.length > 0
const getErrorMessage = (error: unknown, fallback: string) => {
if (error instanceof Error && error.message)
return error.message
if (typeof error === 'object' && error && 'message' in error) {
const message = (error as { message?: string }).message
if (typeof message === 'string' && message)
return message
}
return fallback
}
const handleCopyRedirectUri = () => {
navigator.clipboard.writeText(oauthConfig?.redirect_uri || '')
Toast.notify({
type: 'success',
message: t('actionMsg.copySuccessfully', { ns: 'common' }),
const handleAuthorization = () => {
setAuthorizationStatus(AuthorizationStatusEnum.Pending)
initiateOAuth(providerName, {
onSuccess: (response) => {
setSubscriptionBuilder(response.subscription_builder)
openOAuthPopup(response.authorization_url, (callbackData) => {
if (callbackData) {
Toast.notify({
type: 'success',
message: t('modal.oauth.authorization.authSuccess', { ns: 'pluginTrigger' }),
})
onClose()
showOAuthCreateModal(response.subscription_builder)
}
})
},
onError: () => {
setAuthorizationStatus(AuthorizationStatusEnum.Failed)
Toast.notify({
type: 'error',
message: t('modal.oauth.authorization.authFailed', { ns: 'pluginTrigger' }),
})
},
})
}
useEffect(() => {
if (providerName && subscriptionBuilder && authorizationStatus === AuthorizationStatusEnum.Pending) {
const pollInterval = setInterval(() => {
verifyBuilder(
{
provider: providerName,
subscriptionBuilderId: subscriptionBuilder.id,
},
{
onSuccess: (response) => {
if (response.verified) {
setAuthorizationStatus(AuthorizationStatusEnum.Success)
clearInterval(pollInterval)
}
},
onError: () => {
// Continue polling - auth might still be in progress
},
},
)
}, 3000)
return () => clearInterval(pollInterval)
}
}, [subscriptionBuilder, authorizationStatus, verifyBuilder, providerName, t])
const handleRemove = () => {
deleteOAuth(providerName, {
onSuccess: () => {
onClose()
Toast.notify({
type: 'success',
message: t('modal.oauth.remove.success', { ns: 'pluginTrigger' }),
})
},
onError: (error: unknown) => {
Toast.notify({
type: 'error',
message: getErrorMessage(error, t('modal.oauth.remove.failed', { ns: 'pluginTrigger' })),
})
},
})
}
const handleSave = (needAuth: boolean) => {
const isCustom = clientType === ClientTypeEnum.Custom
const params: ConfigureTriggerOAuthPayload = {
provider: providerName,
enabled: isCustom,
}
if (isCustom) {
const clientFormValues = clientFormRef.current?.getFormValues({}) as { values: TriggerOAuthClientParams, isCheckValidated: boolean } | undefined
if (!clientFormValues || !clientFormValues.isCheckValidated)
return
const clientParams = clientFormValues.values
if (clientParams.client_id === oauthConfig?.params.client_id)
clientParams.client_id = '[__HIDDEN__]'
if (clientParams.client_secret === oauthConfig?.params.client_secret)
clientParams.client_secret = '[__HIDDEN__]'
params.client_params = clientParams
}
configureOAuth(params, {
onSuccess: () => {
if (needAuth) {
handleAuthorization()
}
else {
onClose()
Toast.notify({
type: 'success',
message: t('modal.oauth.save.success', { ns: 'pluginTrigger' }),
})
}
},
})
}
@ -67,25 +208,25 @@ export const OAuthClientSettingsModal = ({ oauthConfig, onClose, showOAuthCreate
onClose={onClose}
onCancel={() => handleSave(false)}
onConfirm={() => handleSave(true)}
footerSlot={showRemoveButton && (
<div className="grow">
<Button
variant="secondary"
className="text-components-button-destructive-secondary-text"
onClick={handleRemove}
>
{t('operation.remove', { ns: 'common' })}
</Button>
</div>
)}
footerSlot={
oauthConfig?.custom_enabled && oauthConfig?.params && clientType === ClientTypeEnum.Custom && (
<div className="grow">
<Button
variant="secondary"
className="text-components-button-destructive-secondary-text"
// disabled={disabled || doingAction || !editValues}
onClick={handleRemove}
>
{t('operation.remove', { ns: 'common' })}
</Button>
</div>
)
}
>
<div className="system-sm-medium mb-2 text-text-secondary">
{t('subscription.addType.options.oauth.clientTitle', { ns: 'pluginTrigger' })}
</div>
<div className="system-sm-medium mb-2 text-text-secondary">{t('subscription.addType.options.oauth.clientTitle', { ns: 'pluginTrigger' })}</div>
{oauthConfig?.system_configured && (
<div className="mb-4 flex w-full items-start justify-between gap-2">
{CLIENT_TYPE_OPTIONS.map(option => (
{[ClientTypeEnum.Default, ClientTypeEnum.Custom].map(option => (
<OptionCard
key={option}
title={t(`subscription.addType.options.oauth.${option}`, { ns: 'pluginTrigger' })}
@ -96,8 +237,7 @@ export const OAuthClientSettingsModal = ({ oauthConfig, onClose, showOAuthCreate
))}
</div>
)}
{showRedirectInfo && (
{clientType === ClientTypeEnum.Custom && oauthConfig?.redirect_uri && (
<div className="mb-4 flex items-start gap-3 rounded-xl bg-background-section-burn p-4">
<div className="rounded-lg border-[0.5px] border-components-card-border bg-components-card-bg p-2 shadow-xs shadow-shadow-shadow-3">
<RiInformation2Fill className="h-5 w-5 shrink-0 text-text-accent" />
@ -107,12 +247,18 @@ export const OAuthClientSettingsModal = ({ oauthConfig, onClose, showOAuthCreate
{t('modal.oauthRedirectInfo', { ns: 'pluginTrigger' })}
</div>
<div className="system-sm-medium my-1.5 break-all leading-4">
{oauthConfig?.redirect_uri}
{oauthConfig.redirect_uri}
</div>
<Button
variant="secondary"
size="small"
onClick={handleCopyRedirectUri}
onClick={() => {
navigator.clipboard.writeText(oauthConfig.redirect_uri)
Toast.notify({
type: 'success',
message: t('actionMsg.copySuccessfully', { ns: 'common' }),
})
}}
>
<RiClipboardLine className="mr-1 h-[14px] w-[14px]" />
{t('operation.copy', { ns: 'common' })}
@ -120,8 +266,7 @@ export const OAuthClientSettingsModal = ({ oauthConfig, onClose, showOAuthCreate
</div>
</div>
)}
{showClientForm && (
{clientType === ClientTypeEnum.Custom && oauthClientSchema.length > 0 && (
<BaseForm
formSchemas={oauthClientSchema}
ref={clientFormRef}

View File

@ -1,6 +0,0 @@
export enum CreateButtonType {
FULL_BUTTON = 'full-button',
ICON_BUTTON = 'icon-button',
}
export const DEFAULT_METHOD = 'default'

View File

@ -129,6 +129,7 @@ export const useToolSelectorState = ({
extra: {
description: tool.tool_description,
},
type: tool.provider_type,
}
}, [])

View File

@ -87,6 +87,7 @@ export type ToolValue = {
enabled?: boolean
extra?: { description?: string } & Record<string, unknown>
credential_id?: string
type?: string
}
export type DataSourceItem = {

View File

@ -196,19 +196,19 @@ describe('useDocLink', () => {
const { result } = renderHook(() => useDocLink())
const url = result.current('/api-reference/annotations/create-annotation')
expect(url).toBe(`${defaultDocBaseUrl}/api-reference/annotations/create-annotation`)
expect(url).toBe(`${defaultDocBaseUrl}/en/api-reference/annotations/create-annotation`)
})
it('should keep original path when no translation exists for non-English locale', () => {
vi.mocked(useTranslation).mockReturnValue({
i18n: { language: 'zh-Hans' },
i18n: { language: 'ja-JP' },
} as ReturnType<typeof useTranslation>)
vi.mocked(getDocLanguage).mockReturnValue('zh')
vi.mocked(getDocLanguage).mockReturnValue('ja')
const { result } = renderHook(() => useDocLink())
// This path has no Japanese translation
const url = result.current('/api-reference/annotations/create-annotation')
expect(url).toBe(`${defaultDocBaseUrl}/api-reference/标注管理/创建标注`)
expect(url).toBe(`${defaultDocBaseUrl}/ja/api-reference/annotations/create-annotation`)
})
it('should remove language prefix when translation is applied', () => {

View File

@ -35,13 +35,12 @@ export const useDocLink = (baseUrl?: string): ((path?: DocPathWithoutLang, pathM
let targetPath = (pathMap) ? pathMap[locale] || pathUrl : pathUrl
let languagePrefix = `/${docLanguage}`
if (targetPath.startsWith('/api-reference/')) {
languagePrefix = ''
if (docLanguage !== 'en') {
const translatedPath = apiReferencePathTranslations[targetPath]?.[docLanguage]
if (translatedPath) {
targetPath = translatedPath
}
// Translate API reference paths for non-English locales
if (targetPath.startsWith('/api-reference/') && docLanguage !== 'en') {
const translatedPath = apiReferencePathTranslations[targetPath]?.[docLanguage as 'zh' | 'ja']
if (translatedPath) {
targetPath = translatedPath
languagePrefix = ''
}
}
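
Going by the updated tests earlier in this diff, the rule this branch implements is: non-English /api-reference/ paths keep their language prefix unless a hand-maintained translation exists, in which case the translated path is used and the prefix is dropped. A standalone distillation follows; the one-entry translation table is an illustrative excerpt, not the real apiReferencePathTranslations module.

// Standalone distillation of the branch above; not the real module.
const exampleTranslations: Record<string, Partial<Record<'zh' | 'ja', string>>> = {
  '/api-reference/annotations/create-annotation': {
    zh: '/api-reference/标注管理/创建标注', // no ja entry, so ja falls through
  },
}

function resolveDocPath(targetPath: string, docLanguage: string): string {
  let languagePrefix = `/${docLanguage}`
  if (targetPath.startsWith('/api-reference/') && docLanguage !== 'en') {
    const translated = exampleTranslations[targetPath]?.[docLanguage as 'zh' | 'ja']
    if (translated) {
      targetPath = translated
      languagePrefix = '' // translated paths already encode the locale
    }
  }
  return `${languagePrefix}${targetPath}`
}

// resolveDocPath('/api-reference/annotations/create-annotation', 'zh')
//   -> '/api-reference/标注管理/创建标注'
// resolveDocPath('/api-reference/annotations/create-annotation', 'ja')
//   -> '/ja/api-reference/annotations/create-annotation'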

View File

@ -2445,6 +2445,11 @@
"count": 8
}
},
"app/components/plugins/plugin-detail-panel/app-selector/app-inputs-panel.tsx": {
"ts/no-explicit-any": {
"count": 8
}
},
"app/components/plugins/plugin-detail-panel/datasource-action-list.tsx": {
"ts/no-explicit-any": {
"count": 1
@ -2498,6 +2503,14 @@
"count": 2
}
},
"app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx": {
"react-refresh/only-export-components": {
"count": 1
},
"ts/no-explicit-any": {
"count": 1
}
},
"app/components/plugins/plugin-detail-panel/subscription-list/delete-confirm.tsx": {
"ts/no-explicit-any": {
"count": 1

View File

@ -1,7 +1,7 @@
{
"name": "dify-web",
"type": "module",
"version": "1.12.0",
"version": "1.12.1",
"private": true,
"packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a",
"imports": {

View File

@ -1,7 +1,6 @@
import type { App, AppCategory } from '@/models/explore'
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useLocale } from '@/context/i18n'
import { AccessMode } from '@/models/access-control'
import { fetchAppList, fetchBanners, fetchInstalledAppList, getAppAccessModeByAppId, uninstallApp, updatePinStatus } from './explore'
import { AppSourceType, fetchAppMeta, fetchAppParams } from './share'
@ -14,9 +13,8 @@ type ExploreAppListData = {
}
export const useExploreAppList = () => {
const locale = useLocale()
return useQuery<ExploreAppListData>({
queryKey: [NAME_SPACE, 'appList', locale],
queryKey: [NAME_SPACE, 'appList'],
queryFn: async () => {
const { categories, recommended_apps } = await fetchAppList()
return {

View File

@ -9,7 +9,6 @@ import type {
} from '@/types/workflow'
import { get, post } from './base'
import { getFlowPrefix } from './utils'
import { sanitizeWorkflowDraftPayload } from './workflow-payload'
export const fetchWorkflowDraft = (url: string) => {
return get(url, {}, { silent: true }) as Promise<FetchWorkflowDraftResponse>
@ -19,8 +18,7 @@ export const syncWorkflowDraft = ({ url, params }: {
url: string
params: Pick<FetchWorkflowDraftResponse, 'graph' | 'features' | 'environment_variables' | 'conversation_variables'>
}) => {
const sanitized = sanitizeWorkflowDraftPayload(params)
return post<CommonResponse & { updated_at: number, hash: string }>(url, { body: sanitized }, { silent: true })
return post<CommonResponse & { updated_at: number, hash: string }>(url, { body: params }, { silent: true })
}
export const fetchNodesDefaultConfigs = (url: string) => {