Mirror of https://github.com/langgenius/dify.git, synced 2026-02-06 03:35:36 +08:00

Compare commits: feat/vibe-...fix/workfl (126 commits)
| SHA1 |
|---|
| 031565579a |
| aa7fe42615 |
| b55c0ec4de |
| 8b50c0d920 |
| 47f8de3f8e |
| 491fa9923b |
| ce2c41bbf5 |
| 920db69ef2 |
| ac222a4dd4 |
| 840a975fef |
| 9fb72c151c |
| 603a896c49 |
| 41177757e6 |
| 4f826b4641 |
| 3216b67bfa |
| 7828508b30 |
| b8cb5f5ea2 |
| 5bc99995fc |
| a433d5ed36 |
| b58d9e030a |
| a4db322440 |
| 24b280a0ed |
| 90fe9abab7 |
| ba568a634d |
| f33d99ea01 |
| 4346f61b0c |
| f90fa2b186 |
| b7e752078c |
| 5a7dfd15b8 |
| 89abea26f9 |
| 95d68437d1 |
| d6a787497f |
| 0cf7827f2a |
| cf7fae393c |
| 5c0df4a3ef |
| 5a3ceb240e |
| 4e7226dc39 |
| 7815d33871 |
| 03e3acfc71 |
| fedd097f63 |
| aeb41d3b2c |
| 5bf0251554 |
| f79512ec78 |
| c27df88417 |
| 8aeef36e2d |
| 25ac69afc5 |
| 7d1ad7e03a |
| 62f46fc55c |
| 2626e773d9 |
| b9ac7af9c5 |
| 74cfe77674 |
| 4f2cd40498 |
| 0934b89da9 |
| 3bcfb4031a |
| ceb6914793 |
| dbfc47e8b0 |
| c2473d85dc |
| 5ce3a04a2c |
| c30af58ac4 |
| 8f414af34e |
| b48a10d7ec |
| 91532ef429 |
| 24ebe2f5c6 |
| 7f40f178ed |
| e98c1adfbf |
| 78198c6452 |
| 6fff46bc29 |
| 3d414678e3 |
| d76ad15fca |
| 144ef0880a |
| 11259617fa |
| caa30ddcc0 |
| 8ec4233611 |
| e482588ef8 |
| b66bd5f5a8 |
| c8abe1c306 |
| eca26a9b9b |
| febc9b930d |
| d13638f6e4 |
| b4eef76c14 |
| cbf7f646d9 |
| c58647d39c |
| f6be9cd90d |
| 4275aa729f |
| 360f3bb32f |
| 8519b16cfc |
| f00d823f9f |
| e48419937b |
| 5eaf0c733a |
| f561656a89 |
| f01f555146 |
| 47d0e400ae |
| 8724ba04aa |
| 6fd001c660 |
| e8e386a6b9 |
| eba5eac3fa |
| 19008dce13 |
| 92011d0a31 |
| a51ced0a4f |
| dad8e408b0 |
| 0ed0a31ed6 |
| d941201a3e |
| dd988d42c2 |
| a43d2ec4f0 |
| 7c12e923b6 |
| b9f1d65d4f |
| b4e2af96e2 |
| 9d38af6d99 |
| 0772d49257 |
| 67eb8c052d |
| 5c4028d557 |
| 55e6bca11c |
| 67657c2f48 |
| e8f9d64651 |
| 1f8c730259 |
| 8d45755303 |
| 6342d196e8 |
| 5dc5709d58 |
| 99d19cd3db |
| fa92548cf6 |
| 41428432cc |
| b3a869b91b |
| f911199c8e |
| 056095238b |
| c8ae6e39d2 |
| 61f8647f37 |
```diff
@@ -1,12 +1,8 @@
 import logging
-from collections.abc import Sequence
 
 from flask_restx import Resource
 from pydantic import BaseModel, Field
 
-logger = logging.getLogger(__name__)
-
-
 from controllers.console import console_ns
 from controllers.console.app.error import (
     CompletionRequestError,
@@ -23,7 +19,6 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP
 from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
 from core.llm_generator.llm_generator import LLMGenerator
 from core.model_runtime.errors.invoke import InvokeError
-from core.workflow.generator import WorkflowGenerator
 from extensions.ext_database import db
 from libs.login import current_account_with_tenant, login_required
 from models import App
@@ -46,30 +41,6 @@ class InstructionTemplatePayload(BaseModel):
     type: str = Field(..., description="Instruction template type")
 
 
-class PreviousWorkflow(BaseModel):
-    """Previous workflow attempt for regeneration context."""
-
-    nodes: list[dict[str, Any]] = Field(default_factory=list, description="Previously generated nodes")
-    edges: list[dict[str, Any]] = Field(default_factory=list, description="Previously generated edges")
-    warnings: list[str] = Field(default_factory=list, description="Warnings from previous generation")
-
-
-class FlowchartGeneratePayload(BaseModel):
-    instruction: str = Field(..., description="Workflow flowchart generation instruction")
-    model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
-    available_nodes: list[dict[str, Any]] = Field(default_factory=list, description="Available node types")
-    existing_nodes: list[dict[str, Any]] = Field(default_factory=list, description="Existing workflow nodes")
-    existing_edges: list[dict[str, Any]] = Field(default_factory=list, description="Existing workflow edges")
-    available_tools: list[dict[str, Any]] = Field(default_factory=list, description="Available tools")
-    selected_node_ids: list[str] = Field(default_factory=list, description="IDs of selected nodes for context")
-    previous_workflow: PreviousWorkflow | None = Field(default=None, description="Previous workflow for regeneration")
-    regenerate_mode: bool = Field(default=False, description="Whether this is a regeneration request")
-    # Language preference for generated content (node titles, descriptions)
-    language: str | None = Field(default=None, description="Preferred language for generated content")
-    # Available models that user has configured (for LLM/question-classifier nodes)
-    available_models: list[dict[str, Any]] = Field(default_factory=list, description="User's configured models")
-
-
 def reg(cls: type[BaseModel]):
     console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
 
@@ -79,7 +50,6 @@ reg(RuleCodeGeneratePayload)
 reg(RuleStructuredOutputPayload)
 reg(InstructionGeneratePayload)
 reg(InstructionTemplatePayload)
-reg(FlowchartGeneratePayload)
 reg(ModelConfig)
 
 
@@ -270,52 +240,6 @@ class InstructionGenerateApi(Resource):
             raise CompletionRequestError(e.description)
 
 
-@console_ns.route("/flowchart-generate")
-class FlowchartGenerateApi(Resource):
-    @console_ns.doc("generate_workflow_flowchart")
-    @console_ns.doc(description="Generate workflow flowchart using LLM with intent classification")
-    @console_ns.expect(console_ns.models[FlowchartGeneratePayload.__name__])
-    @console_ns.response(200, "Flowchart generated successfully")
-    @console_ns.response(400, "Invalid request parameters")
-    @console_ns.response(402, "Provider quota exceeded")
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self):
-        args = FlowchartGeneratePayload.model_validate(console_ns.payload)
-        _, current_tenant_id = current_account_with_tenant()
-
-        try:
-            # Convert PreviousWorkflow to dict if present
-            previous_workflow_dict = args.previous_workflow.model_dump() if args.previous_workflow else None
-
-            result = WorkflowGenerator.generate_workflow_flowchart(
-                tenant_id=current_tenant_id,
-                instruction=args.instruction,
-                model_config=args.model_config_data,
-                available_nodes=args.available_nodes,
-                existing_nodes=args.existing_nodes,
-                existing_edges=args.existing_edges,
-                available_tools=args.available_tools,
-                selected_node_ids=args.selected_node_ids,
-                previous_workflow=previous_workflow_dict,
-                regenerate_mode=args.regenerate_mode,
-                preferred_language=args.language,
-                available_models=args.available_models,
-            )
-
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
-        except QuotaExceededError:
-            raise ProviderQuotaExceededError()
-        except ModelCurrentlyNotSupportError:
-            raise ProviderModelCurrentlyNotSupportError()
-        except InvokeError as e:
-            raise CompletionRequestError(e.description)
-
-        return result
-
-
 @console_ns.route("/instruction-generate/template")
 class InstructionGenerationTemplateApi(Resource):
     @console_ns.doc("get_instruction_template")
```
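A detail worth noting in the removed `FlowchartGeneratePayload`: the field is named `model_config_data` with `alias="model_config"` because Pydantic v2 reserves the attribute name `model_config` for model configuration, so the JSON key has to be mapped onto a differently named attribute. A minimal sketch of the same pattern (`PayloadSketch` is illustrative, not from the diff):

```python
from typing import Any

from pydantic import BaseModel, Field


class PayloadSketch(BaseModel):
    # "model_config" is reserved by Pydantic v2, so the JSON key is mapped
    # onto a differently named attribute via an alias.
    model_config_data: dict[str, Any] = Field(..., alias="model_config")


# model_validate() accepts the aliased key from the request body.
payload = PayloadSketch.model_validate({"model_config": {"provider": "openai"}})
print(payload.model_config_data)  # {'provider': 'openai'}
```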
```diff
@@ -1,16 +1,15 @@
 import logging
-from typing import Any, Literal, cast
+from typing import Any, cast
 
 from flask import request
-from flask_restx import Resource, fields, marshal, marshal_with
-from pydantic import BaseModel
+from flask_restx import Resource, fields, marshal, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
 import services
 from controllers.common.fields import Parameters as ParametersResponse
 from controllers.common.fields import Site as SiteResponse
 from controllers.common.schema import get_or_create_model
-from controllers.console import api, console_ns
+from controllers.console import api
 from controllers.console.app.error import (
     AppUnavailableError,
     AudioTooLargeError,
@@ -118,56 +117,7 @@ workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipel
 workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy)
 
 
-# Pydantic models for request validation
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
-
-class WorkflowRunRequest(BaseModel):
-    inputs: dict
-    files: list | None = None
-
-
-class ChatRequest(BaseModel):
-    inputs: dict
-    query: str
-    files: list | None = None
-    conversation_id: str | None = None
-    parent_message_id: str | None = None
-    retriever_from: str = "explore_app"
-
-
-class TextToSpeechRequest(BaseModel):
-    message_id: str | None = None
-    voice: str | None = None
-    text: str | None = None
-    streaming: bool | None = None
-
-
-class CompletionRequest(BaseModel):
-    inputs: dict
-    query: str = ""
-    files: list | None = None
-    response_mode: Literal["blocking", "streaming"] | None = None
-    retriever_from: str = "explore_app"
-
-
-# Register schemas for Swagger documentation
-console_ns.schema_model(
-    WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
-)
-console_ns.schema_model(
-    ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
-)
-console_ns.schema_model(
-    TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
-)
-console_ns.schema_model(
-    CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
-)
-
-
 class TrialAppWorkflowRunApi(TrialAppResource):
-    @console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
     def post(self, trial_app):
         """
         Run workflow
@@ -179,8 +129,10 @@ class TrialAppWorkflowRunApi(TrialAppResource):
         if app_mode != AppMode.WORKFLOW:
             raise NotWorkflowAppError()
 
-        request_data = WorkflowRunRequest.model_validate(console_ns.payload)
-        args = request_data.model_dump()
+        parser = reqparse.RequestParser()
+        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("files", type=list, required=False, location="json")
+        args = parser.parse_args()
         assert current_user is not None
         try:
             app_id = app_model.id
@@ -231,7 +183,6 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
 
 
 class TrialChatApi(TrialAppResource):
-    @console_ns.expect(console_ns.models[ChatRequest.__name__])
     @trial_feature_enable
     def post(self, trial_app):
         app_model = trial_app
@@ -239,14 +190,14 @@ class TrialChatApi(TrialAppResource):
         if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
-        request_data = ChatRequest.model_validate(console_ns.payload)
-        args = request_data.model_dump()
-
-        # Validate UUID values if provided
-        if args.get("conversation_id"):
-            args["conversation_id"] = uuid_value(args["conversation_id"])
-        if args.get("parent_message_id"):
-            args["parent_message_id"] = uuid_value(args["parent_message_id"])
+        parser = reqparse.RequestParser()
+        parser.add_argument("inputs", type=dict, required=True, location="json")
+        parser.add_argument("query", type=str, required=True, location="json")
+        parser.add_argument("files", type=list, required=False, location="json")
+        parser.add_argument("conversation_id", type=uuid_value, location="json")
+        parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
+        parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
+        args = parser.parse_args()
 
         args["auto_generate_name"] = False
 
@@ -369,16 +320,20 @@ class TrialChatAudioApi(TrialAppResource):
 
 
 class TrialChatTextApi(TrialAppResource):
-    @console_ns.expect(console_ns.models[TextToSpeechRequest.__name__])
     @trial_feature_enable
     def post(self, trial_app):
        app_model = trial_app
         try:
-            request_data = TextToSpeechRequest.model_validate(console_ns.payload)
+            parser = reqparse.RequestParser()
+            parser.add_argument("message_id", type=str, required=False, location="json")
+            parser.add_argument("voice", type=str, location="json")
+            parser.add_argument("text", type=str, location="json")
+            parser.add_argument("streaming", type=bool, location="json")
+            args = parser.parse_args()
 
-            message_id = request_data.message_id
-            text = request_data.text
-            voice = request_data.voice
+            message_id = args.get("message_id", None)
+            text = args.get("text", None)
+            voice = args.get("voice", None)
             if not isinstance(current_user, Account):
                 raise ValueError("current_user must be an Account instance")
 
@@ -416,15 +371,19 @@ class TrialChatTextApi(TrialAppResource):
 
 
 class TrialCompletionApi(TrialAppResource):
-    @console_ns.expect(console_ns.models[CompletionRequest.__name__])
     @trial_feature_enable
     def post(self, trial_app):
         app_model = trial_app
         if app_model.mode != "completion":
             raise NotCompletionAppError()
 
-        request_data = CompletionRequest.model_validate(console_ns.payload)
-        args = request_data.model_dump()
+        parser = reqparse.RequestParser()
+        parser.add_argument("inputs", type=dict, required=True, location="json")
+        parser.add_argument("query", type=str, location="json", default="")
+        parser.add_argument("files", type=list, required=False, location="json")
+        parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+        parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
+        args = parser.parse_args()
 
         streaming = args["response_mode"] == "streaming"
         args["auto_generate_name"] = False
 
```
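The hunks above consistently replace `Model.model_validate(console_ns.payload)` with a `reqparse.RequestParser`. Both accept the same JSON bodies but fail differently: `parse_args()` aborts with a 400 on its own, while Pydantic raises a `ValidationError` that the controller (or an error handler) has to translate into an HTTP response. A standalone sketch of the Pydantic side, reusing the `WorkflowRunRequest` model removed above:

```python
from pydantic import BaseModel, ValidationError


class WorkflowRunRequest(BaseModel):
    # Mirrors the removed Pydantic model from the diff.
    inputs: dict
    files: list | None = None


# Equivalent of WorkflowRunRequest.model_validate(console_ns.payload):
try:
    args = WorkflowRunRequest.model_validate({"files": None}).model_dump()
except ValidationError as e:
    # reqparse would abort(400) inside parse_args(); with Pydantic the caller
    # has to catch and translate this exception itself.
    print(e.errors()[0]["loc"])  # ('inputs',) - the missing required field
```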
```diff
@@ -1,27 +1,14 @@
 from typing import Literal
+from uuid import UUID
 
-from flask import request
-from flask_restx import Namespace, Resource, fields, marshal_with
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden
 
-from controllers.common.schema import register_schema_models
-from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
+from controllers.fastopenapi import console_router
 from libs.login import current_account_with_tenant, login_required
 from services.tag_service import TagService
 
-dataset_tag_fields = {
-    "id": fields.String,
-    "name": fields.String,
-    "type": fields.String,
-    "binding_count": fields.String,
-}
-
-
-def build_dataset_tag_fields(api_or_ns: Namespace):
-    return api_or_ns.model("DataSetTag", dataset_tag_fields)
-
 
 class TagBasePayload(BaseModel):
     name: str = Field(description="Tag name", min_length=1, max_length=50)
@@ -45,115 +32,129 @@ class TagListQueryParam(BaseModel):
     keyword: str | None = Field(None, description="Search keyword")
 
 
-register_schema_models(
-    console_ns,
-    TagBasePayload,
-    TagBindingPayload,
-    TagBindingRemovePayload,
-    TagListQueryParam,
+class TagResponse(BaseModel):
+    id: str = Field(description="Tag ID")
+    name: str = Field(description="Tag name")
+    type: str = Field(description="Tag type")
+    binding_count: int = Field(description="Number of bindings")
+
+
+class TagBindingResult(BaseModel):
+    result: Literal["success"] = Field(description="Operation result", examples=["success"])
+
+
+@console_router.get(
+    "/tags",
+    response_model=list[TagResponse],
+    tags=["console"],
 )
+@setup_required
+@login_required
+@account_initialization_required
+def list_tags(query: TagListQueryParam) -> list[TagResponse]:
+    _, current_tenant_id = current_account_with_tenant()
+    tags = TagService.get_tags(query.type, current_tenant_id, query.keyword)
+
+    return [
+        TagResponse(
+            id=tag.id,
+            name=tag.name,
+            type=tag.type,
+            binding_count=int(tag.binding_count),
+        )
+        for tag in tags
+    ]
 
 
-@console_ns.route("/tags")
-class TagListApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @console_ns.doc(
-        params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
-    )
-    @marshal_with(dataset_tag_fields)
-    def get(self):
-        _, current_tenant_id = current_account_with_tenant()
-        raw_args = request.args.to_dict()
-        param = TagListQueryParam.model_validate(raw_args)
-        tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
+@console_router.post(
+    "/tags",
+    response_model=TagResponse,
+    tags=["console"],
+)
+@setup_required
+@login_required
+@account_initialization_required
+def create_tag(payload: TagBasePayload) -> TagResponse:
+    current_user, _ = current_account_with_tenant()
+    # The role of the current user in the tag table must be admin, owner, or editor
+    if not (current_user.has_edit_permission or current_user.is_dataset_editor):
+        raise Forbidden()
 
-        return tags, 200
+    tag = TagService.save_tags(payload.model_dump())
 
-    @console_ns.expect(console_ns.models[TagBasePayload.__name__])
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self):
-        current_user, _ = current_account_with_tenant()
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
-            raise Forbidden()
-
-        payload = TagBasePayload.model_validate(console_ns.payload or {})
-        tag = TagService.save_tags(payload.model_dump())
-
-        response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
-
-        return response, 200
+    return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=0)
 
 
-@console_ns.route("/tags/<uuid:tag_id>")
-class TagUpdateDeleteApi(Resource):
-    @console_ns.expect(console_ns.models[TagBasePayload.__name__])
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def patch(self, tag_id):
-        current_user, _ = current_account_with_tenant()
-        tag_id = str(tag_id)
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
-            raise Forbidden()
+@console_router.patch(
+    "/tags/<uuid:tag_id>",
+    response_model=TagResponse,
+    tags=["console"],
+)
+@setup_required
+@login_required
+@account_initialization_required
+def update_tag(tag_id: UUID, payload: TagBasePayload) -> TagResponse:
+    current_user, _ = current_account_with_tenant()
+    tag_id_str = str(tag_id)
+    # The role of the current user in the ta table must be admin, owner, or editor
+    if not (current_user.has_edit_permission or current_user.is_dataset_editor):
+        raise Forbidden()
 
-        payload = TagBasePayload.model_validate(console_ns.payload or {})
-        tag = TagService.update_tags(payload.model_dump(), tag_id)
+    tag = TagService.update_tags(payload.model_dump(), tag_id_str)
 
-        binding_count = TagService.get_tag_binding_count(tag_id)
+    binding_count = TagService.get_tag_binding_count(tag_id_str)
 
-        response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
-
-        return response, 200
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @edit_permission_required
-    def delete(self, tag_id):
-        tag_id = str(tag_id)
-
-        TagService.delete_tag(tag_id)
-
-        return 204
+    return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=binding_count)
 
 
-@console_ns.route("/tag-bindings/create")
-class TagBindingCreateApi(Resource):
-    @console_ns.expect(console_ns.models[TagBindingPayload.__name__])
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self):
-        current_user, _ = current_account_with_tenant()
-        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
-        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
-            raise Forbidden()
+@console_router.delete(
+    "/tags/<uuid:tag_id>",
+    tags=["console"],
+    status_code=204,
+)
+@setup_required
+@login_required
+@account_initialization_required
+@edit_permission_required
+def delete_tag(tag_id: UUID) -> None:
+    tag_id_str = str(tag_id)
 
-        payload = TagBindingPayload.model_validate(console_ns.payload or {})
-        TagService.save_tag_binding(payload.model_dump())
-
-        return {"result": "success"}, 200
+    TagService.delete_tag(tag_id_str)
 
 
-@console_ns.route("/tag-bindings/remove")
-class TagBindingDeleteApi(Resource):
-    @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self):
-        current_user, _ = current_account_with_tenant()
-        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
-        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
-            raise Forbidden()
+@console_router.post(
+    "/tag-bindings/create",
+    response_model=TagBindingResult,
+    tags=["console"],
+)
+@setup_required
+@login_required
+@account_initialization_required
+def create_tag_binding(payload: TagBindingPayload) -> TagBindingResult:
+    current_user, _ = current_account_with_tenant()
+    # The role of the current user in the tag table must be admin, owner, editor, or dataset_operator
+    if not (current_user.has_edit_permission or current_user.is_dataset_editor):
+        raise Forbidden()
 
-        payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
-        TagService.delete_tag_binding(payload.model_dump())
+    TagService.save_tag_binding(payload.model_dump())
 
-        return {"result": "success"}, 200
+    return TagBindingResult(result="success")
+
+
+@console_router.post(
+    "/tag-bindings/remove",
+    response_model=TagBindingResult,
+    tags=["console"],
+)
+@setup_required
+@login_required
+@account_initialization_required
+def delete_tag_binding(payload: TagBindingRemovePayload) -> TagBindingResult:
+    current_user, _ = current_account_with_tenant()
+    # The role of the current user in the tag table must be admin, owner, editor, or dataset_operator
+    if not (current_user.has_edit_permission or current_user.is_dataset_editor):
+        raise Forbidden()
+
+    TagService.delete_tag_binding(payload.model_dump())
+
+    return TagBindingResult(result="success")
```
```diff
@@ -171,10 +171,9 @@ def make_request(method: str, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETR
             # httpx may override the Host header when using a proxy
             headers = {k: v for k, v in headers.items() if k.lower() != "host"}
             if user_provided_host is not None:
-                headers["Host"] = user_provided_host
-
-            request = client.build_request(method, url, headers=headers, **kwargs)
-            response = client.send(request, follow_redirects=follow_redirects)
+                headers["host"] = user_provided_host
+            kwargs["headers"] = headers
+            response = client.request(method=method, url=url, **kwargs)
 
             # Check for SSRF protection by Squid proxy
             if response.status_code in (401, 403):
```
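The replaced lines change how the outgoing request is assembled. `client.build_request(...)` plus `client.send(...)` lets the caller pass `follow_redirects` per request, while `client.request(...)` builds and sends in one call, taking the headers through `kwargs` and falling back to the client's own redirect setting. A minimal sketch of the two httpx call shapes (placeholder URL, none of the surrounding retry or proxy logic):

```python
import httpx

url = "https://example.com/"  # placeholder target
headers = {"host": "example.com"}

with httpx.Client() as client:
    # Old shape: build the Request first, then send it explicitly,
    # controlling redirect behaviour per call.
    request = client.build_request("GET", url, headers=headers)
    response = client.send(request, follow_redirects=False)

    # New shape: one call; headers travel in kwargs and redirect behaviour
    # falls back to the client's own follow_redirects setting.
    kwargs = {"headers": headers}
    response = client.request(method="GET", url=url, **kwargs)
```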
```diff
@@ -1,5 +1,6 @@
 import json
+import logging
 import re
 from collections.abc import Sequence
 from typing import Protocol, cast
 
@@ -13,6 +14,8 @@ from core.llm_generator.prompts import (
     CONVERSATION_TITLE_PROMPT,
     GENERATOR_QA_PROMPT,
     JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
+    LLM_MODIFY_CODE_SYSTEM,
+    LLM_MODIFY_PROMPT_SYSTEM,
     PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
     SUGGESTED_QUESTIONS_MAX_TOKENS,
     SUGGESTED_QUESTIONS_TEMPERATURE,
@@ -29,7 +32,6 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.ops.utils import measure_time
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
-from core.workflow.generator import WorkflowGenerator
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models import App, Message, WorkflowNodeExecutionModel
@@ -283,35 +285,6 @@ class LLMGenerator:
 
         return rule_config
 
-    @classmethod
-    def generate_workflow_flowchart(
-        cls,
-        tenant_id: str,
-        instruction: str,
-        model_config: dict,
-        available_nodes: Sequence[dict[str, object]] | None = None,
-        existing_nodes: Sequence[dict[str, object]] | None = None,
-        available_tools: Sequence[dict[str, object]] | None = None,
-        selected_node_ids: Sequence[str] | None = None,
-        previous_workflow: dict[str, object] | None = None,
-        regenerate_mode: bool = False,
-        preferred_language: str | None = None,
-        available_models: Sequence[dict[str, object]] | None = None,
-    ):
-        return WorkflowGenerator.generate_workflow_flowchart(
-            tenant_id=tenant_id,
-            instruction=instruction,
-            model_config=model_config,
-            available_nodes=available_nodes,
-            existing_nodes=existing_nodes,
-            available_tools=available_tools,
-            selected_node_ids=selected_node_ids,
-            previous_workflow=previous_workflow,
-            regenerate_mode=regenerate_mode,
-            preferred_language=preferred_language,
-            available_models=available_models,
-        )
-
     @classmethod
     def generate_code(
         cls,
```
```diff
@@ -143,50 +143,6 @@ Based on task description, please create a well-structured prompt template that
 Please generate the full prompt template with at least 300 words and output only the prompt template.
 """  # noqa: E501
 
-WORKFLOW_FLOWCHART_PROMPT_TEMPLATE = """
-You are an expert workflow designer. Generate a Mermaid flowchart based on the user's request.
-
-Constraints:
-- Detect the language of the user's request. Generate all node titles in the same language as the user's input.
-- If the input language cannot be determined, use {{PREFERRED_LANGUAGE}} as the fallback language.
-- Use only node types listed in <available_nodes>.
-- Use only tools listed in <available_tools>. When using a tool node, set type=tool and tool=<tool_key>.
-- Tools may include MCP providers (provider_type=mcp). Tool selection still uses tool_key.
-- Prefer reusing node titles from <existing_nodes> when possible.
-- Output must be valid Mermaid flowchart syntax, no markdown, no extra text.
-- First line must be: flowchart LR
-- Every node must be declared on its own line using:
-  <id>["type=<type>|title=<title>|tool=<tool_key>"]
-- type is required and must match a type in <available_nodes>.
-- title is required for non-tool nodes.
-- tool is required only when type=tool, otherwise omit tool.
-- Declare all node lines before any edges.
-- Edges must use:
-  <id> --> <id>
-  <id> -->|true| <id>
-  <id> -->|false| <id>
-- Keep node ids unique and simple (N1, N2, ...).
-- For complex orchestration:
-  - Break the request into stages (ingest, transform, decision, action, output).
-  - Use IfElse for branching and label edges true/false only.
-  - Fan-in branches by connecting multiple nodes into a shared downstream node.
-  - Avoid cycles unless explicitly requested.
-  - Keep each branch complete with a clear downstream target.
-
-<user_request>
-{{TASK_DESCRIPTION}}
-</user_request>
-<available_nodes>
-{{AVAILABLE_NODES}}
-</available_nodes>
-<existing_nodes>
-{{EXISTING_NODES}}
-</existing_nodes>
-<available_tools>
-{{AVAILABLE_TOOLS}}
-</available_tools>
-"""
-
 RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """
 Here is a task description for which I would like you to create a high-quality prompt template for:
 <task_description>
```
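The removed `WORKFLOW_FLOWCHART_PROMPT_TEMPLATE` pins the model to a small grammar: first line `flowchart LR`, node declarations of the form `<id>["type=<type>|title=<title>|tool=<tool_key>"]` before any edges, then plain or `|true|`/`|false|`-labelled edges. As a rough illustration (my own sketch, not code from the repository), a validator for that grammar fits in a few lines:

```python
import re

# Node lines look like: N1["type=llm|title=Summarize"] or N2["type=tool|tool=some_tool"]
NODE_RE = re.compile(r'^(\w+)\["type=([^|"]+)(\|title=([^|"]+))?(\|tool=([^|"]+))?"\]$')
# Edges look like: N1 --> N2, or N1 -->|true| N2 / N1 -->|false| N2
EDGE_RE = re.compile(r"^(\w+) -->(\|(true|false)\|)? (\w+)$")


def check_flowchart(text: str) -> list[str]:
    """Return problems found in a generated flowchart, per the removed prompt's rules."""
    lines = [ln.strip() for ln in text.strip().splitlines() if ln.strip()]
    problems = []
    if not lines or lines[0] != "flowchart LR":
        problems.append("first line must be: flowchart LR")
        return problems
    ids, edges_started = set(), False
    for ln in lines[1:]:
        if m := NODE_RE.match(ln):
            if edges_started:
                problems.append(f"node declared after edges: {ln}")
            ids.add(m.group(1))
        elif m := EDGE_RE.match(ln):
            edges_started = True
            for node_id in (m.group(1), m.group(4)):
                if node_id not in ids:
                    problems.append(f"edge references undeclared node: {node_id}")
        else:
            problems.append(f"unrecognized line: {ln}")
    return problems


sample = """flowchart LR
N1["type=start|title=Start"]
N2["type=llm|title=Summarize"]
N1 --> N2"""
print(check_flowchart(sample))  # []
```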
```diff
@@ -1 +0,0 @@
-from .runner import WorkflowGenerator
```
```diff
@@ -1,29 +0,0 @@
-"""
-Vibe Workflow Generator Configuration Module.
-
-This module centralizes configuration for the Vibe workflow generation feature,
-including node schemas, fallback rules, and response templates.
-"""
-
-from core.workflow.generator.config.node_schemas import (
-    BUILTIN_NODE_SCHEMAS,
-    FALLBACK_RULES,
-    FIELD_NAME_CORRECTIONS,
-    NODE_TYPE_ALIASES,
-    get_builtin_node_schemas,
-    get_corrected_field_name,
-    validate_node_schemas,
-)
-from core.workflow.generator.config.responses import DEFAULT_SUGGESTIONS, OFF_TOPIC_RESPONSES
-
-__all__ = [
-    "BUILTIN_NODE_SCHEMAS",
-    "DEFAULT_SUGGESTIONS",
-    "FALLBACK_RULES",
-    "FIELD_NAME_CORRECTIONS",
-    "NODE_TYPE_ALIASES",
-    "OFF_TOPIC_RESPONSES",
-    "get_builtin_node_schemas",
-    "get_corrected_field_name",
-    "validate_node_schemas",
-]
```
@@ -1,501 +0,0 @@ (entire file removed; its content follows)

```python
"""
Unified Node Configuration for Vibe Workflow Generation.

This module centralizes all node-related configuration:
- Node schemas (parameter definitions)
- Fallback rules (keyword-based node type inference)
- Node type aliases (natural language to canonical type mapping)
- Field name corrections (LLM output normalization)
- Validation utilities

Note: These definitions are the single source of truth.
Frontend has a mirrored copy at web/app/components/workflow/hooks/use-workflow-vibe-config.ts
"""

from typing import Any

# =============================================================================
# NODE SCHEMAS
# =============================================================================

# Built-in node schemas with parameter definitions
# These help the model understand what config each node type requires
_HARDCODED_SCHEMAS: dict[str, dict[str, Any]] = {
    "http-request": {
        "description": "Send HTTP requests to external APIs or fetch web content",
        "required": ["url", "method"],
        "parameters": {
            "url": {
                "type": "string",
                "description": "Full URL including protocol (https://...)",
                "example": "{{#start.url#}} or https://api.example.com/data",
            },
            "method": {
                "type": "enum",
                "options": ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"],
                "description": "HTTP method",
            },
            "headers": {
                "type": "string",
                "description": "HTTP headers as newline-separated 'Key: Value' pairs",
                "example": "Content-Type: application/json\nAuthorization: Bearer {{#start.api_key#}}",
            },
            "params": {
                "type": "string",
                "description": "URL query parameters as newline-separated 'key: value' pairs",
            },
            "body": {
                "type": "object",
                "description": "Request body with type field required",
                "example": {"type": "none", "data": []},
            },
            "authorization": {
                "type": "object",
                "description": "Authorization config",
                "example": {"type": "no-auth"},
            },
            "timeout": {
                "type": "number",
                "description": "Request timeout in seconds",
                "default": 60,
            },
        },
        "outputs": ["body (response content)", "status_code", "headers"],
    },
    "code": {
        "description": "Execute Python or JavaScript code for custom logic",
        "required": ["code", "language"],
        "parameters": {
            "code": {
                "type": "string",
                "description": "Code to execute. Must define a main() function that returns a dict.",
            },
            "language": {
                "type": "enum",
                "options": ["python3", "javascript"],
            },
            "variables": {
                "type": "array",
                "description": "Input variables passed to the code",
                "item_schema": {"variable": "string", "value_selector": "array"},
            },
            "outputs": {
                "type": "object",
                "description": "Output variable definitions",
            },
        },
        "outputs": ["Variables defined in outputs schema"],
    },
    "llm": {
        "description": "Call a large language model for text generation/processing",
        "required": ["prompt_template"],
        "parameters": {
            "model": {
                "type": "object",
                "description": "Model configuration (provider, name, mode)",
            },
            "prompt_template": {
                "type": "array",
                "description": "Messages for the LLM",
                "item_schema": {
                    "role": "enum: system, user, assistant",
                    "text": "string - message content, can include {{#node_id.field#}} references",
                },
            },
            "context": {
                "type": "object",
                "description": "Optional context settings",
            },
            "memory": {
                "type": "object",
                "description": "Optional memory/conversation settings",
            },
        },
        "outputs": ["text (generated response)"],
    },
    "if-else": {
        "description": "Conditional branching based on conditions",
        "required": ["cases"],
        "parameters": {
            "cases": {
                "type": "array",
                "description": "List of condition cases. Each case defines when 'true' branch is taken.",
                "item_schema": {
                    "case_id": "string - unique case identifier (e.g., 'case_1')",
                    "logical_operator": "enum: and, or - how multiple conditions combine",
                    "conditions": {
                        "type": "array",
                        "item_schema": {
                            "variable_selector": "array of strings - path to variable, e.g. ['node_id', 'field']",
                            "comparison_operator": (
                                "enum: =, ≠, >, <, ≥, ≤, contains, not contains, is, is not, empty, not empty"
                            ),
                            "value": "string or number - value to compare against",
                        },
                    },
                },
            },
        },
        "outputs": ["Branches: true (first case conditions met), false (else/no case matched)"],
    },
    "knowledge-retrieval": {
        "description": "Query knowledge base for relevant content",
        "required": ["query_variable_selector", "dataset_ids"],
        "parameters": {
            "query_variable_selector": {
                "type": "array",
                "description": "Path to query variable, e.g. ['start', 'query']",
            },
            "dataset_ids": {
                "type": "array",
                "description": "List of knowledge base IDs to search",
            },
            "retrieval_mode": {
                "type": "enum",
                "options": ["single", "multiple"],
            },
        },
        "outputs": ["result (retrieved documents)"],
    },
    "template-transform": {
        "description": "Transform data using Jinja2 templates",
        "required": ["template", "variables"],
        "parameters": {
            "template": {
                "type": "string",
                "description": "Jinja2 template string. Use {{ variable_name }} to reference variables.",
            },
            "variables": {
                "type": "array",
                "description": "Input variables defined for the template",
                "item_schema": {
                    "variable": "string - variable name to use in template",
                    "value_selector": "array - path to source value, e.g. ['start', 'user_input']",
                },
            },
        },
        "outputs": ["output (transformed string)"],
    },
    "variable-aggregator": {
        "description": "Aggregate variables from multiple branches",
        "required": ["variables"],
        "parameters": {
            "variables": {
                "type": "array",
                "description": "List of variable selectors to aggregate",
                "item_schema": "array of strings - path to source variable, e.g. ['node_id', 'field']",
            },
        },
        "outputs": ["output (aggregated value)"],
    },
    "iteration": {
        "description": "Loop over array items",
        "required": ["iterator_selector"],
        "parameters": {
            "iterator_selector": {
                "type": "array",
                "description": "Path to array variable to iterate",
            },
        },
        "outputs": ["item (current iteration item)", "index (current index)"],
    },
    "parameter-extractor": {
        "description": "Extract structured parameters from user input using LLM",
        "required": ["query", "parameters"],
        "parameters": {
            "model": {
                "type": "object",
                "description": "Model configuration (provider, name, mode)",
            },
            "query": {
                "type": "array",
                "description": "Path to input text to extract parameters from, e.g. ['start', 'user_input']",
            },
            "parameters": {
                "type": "array",
                "description": "Parameters to extract from the input",
                "item_schema": {
                    "name": "string - parameter name (required)",
                    "type": (
                        "enum: string, number, boolean, array[string], array[number], array[object], array[boolean]"
                    ),
                    "description": "string - description of what to extract (required)",
                    "required": "boolean - whether this parameter is required (MUST be specified)",
                    "options": "array of strings (optional) - for enum-like selection",
                },
            },
            "instruction": {
                "type": "string",
                "description": "Additional instructions for extraction",
            },
            "reasoning_mode": {
                "type": "enum",
                "options": ["function_call", "prompt"],
                "description": "How to perform extraction (defaults to function_call)",
            },
        },
        "outputs": ["Extracted parameters as defined in parameters array", "__is_success", "__reason"],
    },
    "question-classifier": {
        "description": "Classify user input into predefined categories using LLM",
        "required": ["query", "classes"],
        "parameters": {
            "model": {
                "type": "object",
                "description": "Model configuration (provider, name, mode)",
            },
            "query": {
                "type": "array",
                "description": "Path to input text to classify, e.g. ['start', 'user_input']",
            },
            "classes": {
                "type": "array",
                "description": "Classification categories",
                "item_schema": {
                    "id": "string - unique class identifier",
                    "name": "string - class name/label",
                },
            },
            "instruction": {
                "type": "string",
                "description": "Additional instructions for classification",
            },
        },
        "outputs": ["class_name (selected class)"],
    },
}


def _get_dynamic_schemas() -> dict[str, dict[str, Any]]:
    """
    Dynamically load schemas from node classes.
    Uses lazy import to avoid circular dependency.
    """
    from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING

    schemas = {}
    for node_type, version_map in NODE_TYPE_CLASSES_MAPPING.items():
        # Get the latest version class
        node_cls = version_map.get(LATEST_VERSION)
        if not node_cls:
            continue

        # Get schema from the class
        schema = node_cls.get_default_config_schema()
        if schema:
            schemas[node_type.value] = schema

    return schemas


# Cache for built-in schemas (populated on first access)
_builtin_schemas_cache: dict[str, dict[str, Any]] | None = None


def get_builtin_node_schemas() -> dict[str, dict[str, Any]]:
    """
    Get the complete set of built-in node schemas.
    Combines hardcoded schemas with dynamically loaded ones.
    Results are cached after first call.
    """
    global _builtin_schemas_cache
    if _builtin_schemas_cache is None:
        _builtin_schemas_cache = {**_HARDCODED_SCHEMAS, **_get_dynamic_schemas()}
    return _builtin_schemas_cache


# For backward compatibility - but use get_builtin_node_schemas() for lazy loading
BUILTIN_NODE_SCHEMAS: dict[str, dict[str, Any]] = _HARDCODED_SCHEMAS.copy()


# =============================================================================
# FALLBACK RULES
# =============================================================================

# Keyword rules for smart fallback detection
# Maps node type to keywords that suggest using that node type as a fallback
FALLBACK_RULES: dict[str, list[str]] = {
    "http-request": [
        "http",
        "url",
        "web",
        "scrape",
        "scraper",
        "fetch",
        "api",
        "request",
        "download",
        "upload",
        "webhook",
        "endpoint",
        "rest",
        "get",
        "post",
    ],
    "code": [
        "code",
        "script",
        "calculate",
        "compute",
        "process",
        "transform",
        "parse",
        "convert",
        "format",
        "filter",
        "sort",
        "math",
        "logic",
    ],
    "llm": [
        "analyze",
        "summarize",
        "summary",
        "extract",
        "classify",
        "translate",
        "generate",
        "write",
        "rewrite",
        "explain",
        "answer",
        "chat",
    ],
}


# =============================================================================
# NODE TYPE ALIASES
# =============================================================================

# Node type aliases for inference from natural language
# Maps common terms to canonical node type names
NODE_TYPE_ALIASES: dict[str, str] = {
    # Start node aliases
    "start": "start",
    "begin": "start",
    "input": "start",
    # End node aliases
    "end": "end",
    "finish": "end",
    "output": "end",
    # LLM node aliases
    "llm": "llm",
    "ai": "llm",
    "gpt": "llm",
    "model": "llm",
    "chat": "llm",
    # Code node aliases
    "code": "code",
    "script": "code",
    "python": "code",
    "javascript": "code",
    # HTTP request node aliases
    "http-request": "http-request",
    "http": "http-request",
    "request": "http-request",
    "api": "http-request",
    "fetch": "http-request",
    "webhook": "http-request",
    # Conditional node aliases
    "if-else": "if-else",
    "condition": "if-else",
    "branch": "if-else",
    "switch": "if-else",
    # Loop node aliases
    "iteration": "iteration",
    "loop": "loop",
    "foreach": "iteration",
    # Tool node alias
    "tool": "tool",
}


# =============================================================================
# FIELD NAME CORRECTIONS
# =============================================================================

# Field name corrections for LLM-generated node configs
# Maps incorrect field names to correct ones for specific node types
FIELD_NAME_CORRECTIONS: dict[str, dict[str, str]] = {
    "http-request": {
        "text": "body",  # LLM might use "text" instead of "body"
        "content": "body",
        "response": "body",
    },
    "code": {
        "text": "result",  # LLM might use "text" instead of "result"
        "output": "result",
    },
    "llm": {
        "response": "text",
        "answer": "text",
    },
}


def get_corrected_field_name(node_type: str, field: str) -> str:
    """
    Get the corrected field name for a node type.

    Args:
        node_type: The type of the node (e.g., "http-request", "code")
        field: The field name to correct

    Returns:
        The corrected field name, or the original if no correction needed
    """
    corrections = FIELD_NAME_CORRECTIONS.get(node_type, {})
    return corrections.get(field, field)


# =============================================================================
# VALIDATION UTILITIES
# =============================================================================

# Node types that are internal and don't need schemas for LLM generation
_INTERNAL_NODE_TYPES: set[str] = {
    # Internal workflow nodes
    "answer",  # Internal to chatflow
    "loop",  # Uses iteration internally
    "assigner",  # Variable assignment utility
    "variable-assigner",  # Variable assignment utility
    "agent",  # Agent node (complex, handled separately)
    "document-extractor",  # Internal document processing
    "list-operator",  # Internal list operations
    # Iteration internal nodes
    "iteration-start",  # Internal to iteration loop
    "loop-start",  # Internal to loop
    "loop-end",  # Internal to loop
    # Trigger nodes (not user-creatable via LLM)
    "trigger-plugin",  # Plugin trigger
    "trigger-schedule",  # Scheduled trigger
    "trigger-webhook",  # Webhook trigger
    # Other internal nodes
    "datasource",  # Data source configuration
    "human-input",  # Human-in-the-loop node
    "knowledge-index",  # Knowledge indexing node
}


def validate_node_schemas() -> list[str]:
    """
    Validate that all registered node types have corresponding schemas.

    This function checks if BUILTIN_NODE_SCHEMAS covers all node types
    registered in NODE_TYPE_CLASSES_MAPPING, excluding internal node types.

    Returns:
        List of warning messages for missing schemas (empty if all valid)
    """
    from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING

    schemas = get_builtin_node_schemas()
    warnings = []
    for node_type in NODE_TYPE_CLASSES_MAPPING:
        type_value = node_type.value
        if type_value in _INTERNAL_NODE_TYPES:
            continue
        if type_value not in schemas:
            warnings.append(f"Missing schema for node type: {type_value}")
    return warnings
```
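Together, `NODE_TYPE_ALIASES` and `FIELD_NAME_CORRECTIONS` formed a small normalization layer over raw LLM output: the alias table canonicalizes the node type first, then the per-type correction table renames fields. A short sketch of how the two removed tables compose, assuming the module above is in scope (the `raw` payload is invented for illustration):

```python
raw = {"type": "api", "config": {"text": "..."}}  # hypothetical LLM output

node_type = NODE_TYPE_ALIASES.get(raw["type"], raw["type"])  # "api" -> "http-request"
config = {get_corrected_field_name(node_type, k): v for k, v in raw["config"].items()}

print(node_type)  # http-request
print(config)     # {'body': '...'} - "text" corrected to "body"
```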
@@ -1,72 +0,0 @@ (entire file removed; its content follows)

```python
"""
Response Templates for Vibe Workflow Generation.

This module defines templates for off-topic responses and default suggestions
to guide users back to workflow-related requests.
"""

# Off-topic response templates for different categories
# Each category has messages in multiple languages
OFF_TOPIC_RESPONSES: dict[str, dict[str, str]] = {
    "weather": {
        "en": (
            "I'm the workflow design assistant - I can't check the weather, "
            "but I can help you build AI workflows! For example, I could help you "
            "create a workflow that fetches weather data from an API."
        ),
        "zh": "我是工作流设计助手,无法查询天气。但我可以帮你创建一个从API获取天气数据的工作流!",
    },
    "math": {
        "en": (
            "I focus on workflow design rather than calculations. However, "
            "if you need calculations in a workflow, I can help you add a Code node "
            "that handles math operations!"
        ),
        "zh": "我专注于工作流设计而非计算。但如果您需要在工作流中进行计算,我可以帮您添加一个处理数学运算的代码节点!",
    },
    "joke": {
        "en": (
            "While I'd love to share a laugh, I'm specialized in workflow design. "
            "How about we create something fun instead - like a workflow that generates jokes using AI?"
        ),
        "zh": "虽然我很想讲笑话,但我专门从事工作流设计。不如我们创建一个有趣的东西——比如使用AI生成笑话的工作流?",
    },
    "translation": {
        "en": (
            "I can't translate directly, but I can help you build a translation workflow! "
            "Would you like to create one using an LLM node?"
        ),
        "zh": "我不能直接翻译,但我可以帮你构建一个翻译工作流!要创建一个使用LLM节点的翻译流程吗?",
    },
    "general_coding": {
        "en": (
            "I'm specialized in Dify workflow design rather than general coding help. "
            "But if you want to add code logic to your workflow, I can help you configure a Code node!"
        ),
        "zh": (
            "我专注于Dify工作流设计,而非通用编程帮助。但如果您想在工作流中添加代码逻辑,我可以帮您配置一个代码节点!"
        ),
    },
    "default": {
        "en": (
            "I'm the Dify workflow design assistant. I help create AI automation workflows, "
            "but I can't help with general questions. Would you like to create a workflow instead?"
        ),
        "zh": "我是Dify工作流设计助手。我帮助创建AI自动化工作流,但无法回答一般性问题。您想创建一个工作流吗?",
    },
}

# Default suggestions for off-topic requests
# These help guide users towards valid workflow requests
DEFAULT_SUGGESTIONS: dict[str, list[str]] = {
    "en": [
        "Create a chatbot workflow",
        "Build a document summarization pipeline",
        "Add email notification to workflow",
    ],
    "zh": [
        "创建一个聊天机器人工作流",
        "构建文档摘要处理流程",
        "添加邮件通知到工作流",
    ],
}
```
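A plausible lookup over the removed tables, falling back to the `default` category and then to English when a language key is missing (the helper is my illustration, not part of the module):

```python
def pick_off_topic_response(category: str, language: str) -> str:
    # Fall back to the "default" category, then to English, mirroring the
    # two-level dict shape of OFF_TOPIC_RESPONSES above.
    messages = OFF_TOPIC_RESPONSES.get(category, OFF_TOPIC_RESPONSES["default"])
    return messages.get(language, messages["en"])


print(pick_off_topic_response("math", "zh"))     # the Chinese math response
print(pick_off_topic_response("cooking", "fr"))  # default category, English fallback
```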
@ -1,733 +0,0 @@
|
||||
# =============================================================================
|
||||
# NEW FORMAT: depends_on based prompt (for use with GraphBuilder)
|
||||
# =============================================================================
|
||||
|
||||
BUILDER_SYSTEM_PROMPT_V2 = """<role>
|
||||
You are a Workflow Configuration Engineer.
|
||||
Your goal is to generate workflow node configurations with dependency declarations.
|
||||
The graph structure (edges, start/end nodes) will be automatically built from your output.
|
||||
</role>
|
||||
|
||||
<language_rules>
|
||||
- Detect the language of the user's request automatically (e.g., English, Chinese, Japanese, etc.).
|
||||
- Generate ALL node titles, descriptions, and user-facing text in the SAME language as the user's input.
|
||||
- If the input language is ambiguous or cannot be determined (e.g. code-only input),
|
||||
use {preferred_language} as the target language.
|
||||
</language_rules>
|
||||
|
||||
<inputs>
|
||||
<plan>
|
||||
{plan_context}
|
||||
</plan>
|
||||
|
||||
<tool_schemas>
|
||||
{tool_schemas}
|
||||
</tool_schemas>
|
||||
|
||||
<node_specs>
|
||||
{builtin_node_specs}
|
||||
</node_specs>
|
||||
|
||||
<available_models>
|
||||
{available_models}
|
||||
</available_models>
|
||||
|
||||
<workflow_context>
|
||||
<existing_nodes>
|
||||
{existing_nodes_context}
|
||||
</existing_nodes>
|
||||
<selected_nodes>
|
||||
{selected_nodes_context}
|
||||
</selected_nodes>
|
||||
</workflow_context>
|
||||
</inputs>
|
||||
|
||||
<critical_rules>
|
||||
1. **DO NOT generate start or end nodes** - they are automatically added
|
||||
2. **DO NOT generate edges** - they are automatically built from depends_on
|
||||
3. **Use depends_on array** to declare which nodes must run before this one
|
||||
4. **Leave depends_on empty []** for nodes that should start immediately (connect to start)
|
||||
</critical_rules>
|
||||
|
||||
<rules>
|
||||
1. **Configuration**:
|
||||
- You MUST fill ALL required parameters for every node.
|
||||
- Use `{{{{#node_id.field#}}}}` syntax to reference outputs from previous nodes in text fields.
|
||||
|
||||
2. **Dependency Declaration**:
|
||||
- Each node has a `depends_on` array listing node IDs that must complete before it runs
|
||||
- Empty depends_on `[]` means the node runs immediately after start
|
||||
- Example: `"depends_on": ["fetch_data"]` means this node waits for fetch_data to complete
|
||||
|
||||
3. **Variable References**:
|
||||
- For text fields (like prompts, queries): use string format `{{{{#node_id.field#}}}}`
|
||||
- Dependencies will be auto-inferred from variable references if not explicitly declared
|
||||
|
||||
4. **Tools**:
|
||||
- ONLY use the tools listed in `<tool_schemas>`.
|
||||
- If a planned tool is missing from schemas, fallback to `http-request` or `code`.
|
||||
|
||||
5. **Model Selection** (CRITICAL):
|
||||
- For LLM, question-classifier, and parameter-extractor nodes, you MUST include a "model" config.
|
||||
- You MUST use ONLY models from the `<available_models>` section above.
|
||||
- Copy the EXACT provider and name values from available_models.
|
||||
- NEVER use openai/gpt-4o, gpt-3.5-turbo, gpt-4, or any other models unless they appear in available_models.
|
||||
- If available_models is empty or shows "No models configured", omit the model config entirely.
|
||||
|
||||
6. **if-else Branching**:
|
||||
- Add `true_branch` and `false_branch` in config to specify target node IDs
|
||||
- Example: `"config": {{"cases": [...], "true_branch": "success_node", "false_branch": "fallback_node"}}`
|
||||
|
||||
7. **question-classifier Branching**:
|
||||
- Add `target` field to each class in the classes array
|
||||
- Example: `"classes": [{{"id": "tech", "name": "Tech", "target": "tech_handler"}}, ...]`
|
||||
|
||||
8. **Node Specifics**:
|
||||
- For `if-else` comparison_operator, use literal symbols: `≥`, `≤`, `=`, `≠` (NOT `>=` or `==`).
|
||||
</rules>
|
||||
|
||||
<output_format>
Return ONLY a JSON object with a `nodes` array. Each node has:
- id: unique identifier
- type: node type
- title: display name
- config: node configuration
- depends_on: array of node IDs this depends on

```json
{{{{
  "nodes": [
    {{{{
      "id": "fetch_data",
      "type": "http-request",
      "title": "Fetch Data",
      "config": {{"url": "{{{{#start.url#}}}}", "method": "GET"}},
      "depends_on": []
    }}}},
    {{{{
      "id": "analyze",
      "type": "llm",
      "title": "Analyze",
      "config": {{"prompt_template": [{{"role": "user", "text": "Analyze: {{{{#fetch_data.body#}}}}"}}]}},
      "depends_on": ["fetch_data"]
    }}}}
  ]
}}}}
```
</output_format>

<examples>
<example name="simple_linear">
```json
{{{{
  "nodes": [
    {{{{
      "id": "llm",
      "type": "llm",
      "title": "Generate Response",
      "config": {{{{
        "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
        "prompt_template": [{{"role": "user", "text": "Answer: {{{{#start.query#}}}}"}}]
      }}}},
      "depends_on": []
    }}}}
  ]
}}}}
```
</example>

<example name="parallel_then_merge">
```json
{{{{
  "nodes": [
    {{{{
      "id": "api1",
      "type": "http-request",
      "title": "Fetch API 1",
      "config": {{"url": "https://api1.example.com", "method": "GET"}},
      "depends_on": []
    }}}},
    {{{{
      "id": "api2",
      "type": "http-request",
      "title": "Fetch API 2",
      "config": {{"url": "https://api2.example.com", "method": "GET"}},
      "depends_on": []
    }}}},
    {{{{
      "id": "merge",
      "type": "llm",
      "title": "Merge Results",
      "config": {{{{
        "prompt_template": [{{"role": "user", "text": "Combine: {{{{#api1.body#}}}} and {{{{#api2.body#}}}}"}}]
      }}}},
      "depends_on": ["api1", "api2"]
    }}}}
  ]
}}}}
```
</example>

<example name="if_else_branching">
```json
{{{{
  "nodes": [
    {{{{
      "id": "check",
      "type": "if-else",
      "title": "Check Condition",
      "config": {{{{
        "cases": [{{{{
          "case_id": "case_1",
          "logical_operator": "and",
          "conditions": [{{{{
            "variable_selector": ["start", "score"],
            "comparison_operator": "≥",
            "value": "60"
          }}}}]
        }}}}],
        "true_branch": "pass_handler",
        "false_branch": "fail_handler"
      }}}},
      "depends_on": []
    }}}},
    {{{{
      "id": "pass_handler",
      "type": "llm",
      "title": "Pass Response",
      "config": {{"prompt_template": [{{"role": "user", "text": "Congratulations!"}}]}},
      "depends_on": []
    }}}},
    {{{{
      "id": "fail_handler",
      "type": "llm",
      "title": "Fail Response",
      "config": {{"prompt_template": [{{"role": "user", "text": "Try again."}}]}},
      "depends_on": []
    }}}}
  ]
}}}}
```
Note: pass_handler and fail_handler have empty depends_on because their connections come from if-else branches.
</example>

<example name="question_classifier">
```json
{{{{
  "nodes": [
    {{{{
      "id": "classifier",
      "type": "question-classifier",
      "title": "Classify Intent",
      "config": {{{{
        "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
        "query_variable_selector": ["start", "user_input"],
        "classes": [
          {{"id": "tech", "name": "Technical", "target": "tech_handler"}},
          {{"id": "billing", "name": "Billing", "target": "billing_handler"}},
          {{"id": "other", "name": "Other", "target": "other_handler"}}
        ]
      }}}},
      "depends_on": []
    }}}},
    {{{{
      "id": "tech_handler",
      "type": "llm",
      "title": "Tech Support",
      "config": {{"prompt_template": [{{"role": "user", "text": "Help with tech: {{{{#start.user_input#}}}}"}}]}},
      "depends_on": []
    }}}},
    {{{{
      "id": "billing_handler",
      "type": "llm",
      "title": "Billing Support",
      "config": {{"prompt_template": [{{"role": "user", "text": "Help with billing: {{{{#start.user_input#}}}}"}}]}},
      "depends_on": []
    }}}},
    {{{{
      "id": "other_handler",
      "type": "llm",
      "title": "General Support",
      "config": {{"prompt_template": [{{"role": "user", "text": "General help: {{{{#start.user_input#}}}}"}}]}},
      "depends_on": []
    }}}}
  ]
}}}}
```
Note: Handler nodes have empty depends_on because their connections come from classifier branches.
</example>
</examples>
"""

BUILDER_USER_PROMPT_V2 = """<instruction>
{instruction}
</instruction>

Generate the workflow nodes configuration. Remember:
1. Do NOT generate start or end nodes
2. Do NOT generate edges - use depends_on instead
3. For if-else: add true_branch/false_branch in config
4. For question-classifier: add target to each class
"""
# =============================================================================
# LEGACY FORMAT: edges-based prompt (backward compatible)
# =============================================================================

BUILDER_SYSTEM_PROMPT = """<role>
You are a Workflow Configuration Engineer.
Your goal is to implement the Architect's plan by generating a precise, runnable Dify Workflow JSON configuration.
</role>

<language_rules>
- Detect the language of the user's request automatically (e.g., English, Chinese, Japanese, etc.).
- Generate ALL node titles, descriptions, and user-facing text in the SAME language as the user's input.
- If the input language is ambiguous or cannot be determined (e.g., code-only input),
  use {preferred_language} as the target language.
</language_rules>

<inputs>
<plan>
{plan_context}
</plan>

<tool_schemas>
{tool_schemas}
</tool_schemas>

<node_specs>
{builtin_node_specs}
</node_specs>

<available_models>
{available_models}
</available_models>

<workflow_context>
<existing_nodes>
{existing_nodes_context}
</existing_nodes>
<existing_edges>
{existing_edges_context}
</existing_edges>
<selected_nodes>
{selected_nodes_context}
</selected_nodes>
</workflow_context>
</inputs>

<rules>
1. **Configuration**:
   - You MUST fill ALL required parameters for every node.
   - Use `{{{{#node_id.field#}}}}` syntax to reference outputs from previous nodes in text fields.
   - For 'start' node, define all necessary user inputs.

2. **Variable References**:
   - For text fields (like prompts, queries): use string format `{{{{#node_id.field#}}}}`
   - For 'end' node outputs: use `value_selector` array format `["node_id", "field"]`
   - Example: to reference 'llm' node's 'text' output in end node, use `["llm", "text"]`

3. **Tools**:
   - ONLY use the tools listed in `<tool_schemas>`.
   - If a planned tool is missing from schemas, fall back to `http-request` or `code`.

4. **Model Selection** (CRITICAL):
   - For LLM, question-classifier, and parameter-extractor nodes, you MUST include a "model" config.
   - You MUST use ONLY models from the `<available_models>` section above.
   - Copy the EXACT provider and name values from available_models.
   - NEVER use openai/gpt-4o, gpt-3.5-turbo, gpt-4, or any other models unless they appear in available_models.
   - If available_models is empty or shows "No models configured", omit the model config entirely.

5. **Node Specifics**:
   - For `if-else` comparison_operator, use literal symbols: `≥`, `≤`, `=`, `≠` (NOT `>=` or `==`).

6. **Modification Mode**:
   - If `<existing_nodes>` contains nodes, you are MODIFYING an existing workflow.
   - Keep nodes that are NOT mentioned in the user's instruction UNCHANGED.
   - Only modify/add/remove nodes that the user explicitly requested.
   - Preserve node IDs for unchanged nodes to maintain connections.
   - If user says "add X", append new nodes to existing workflow.
   - If user says "change Y to Z", only modify that specific node.
   - If user says "remove X", exclude that node from output.

   **Edge Modification**:
   - Use `<existing_edges>` to understand current node connections.
   - If user mentions "fix edge", "connect", "link", or "add connection",
     review existing_edges and correct missing/wrong connections.
   - For multi-branch nodes (if-else, question-classifier),
     ensure EACH branch has proper sourceHandle (e.g., "true"/"false") and target.
   - Common edge issues to fix:
     * Missing edge: Two nodes should connect but don't - add the edge
     * Wrong target: Edge points to wrong node - update the target
     * Missing sourceHandle: if-else/classifier branches lack sourceHandle - add "true"/"false"
     * Disconnected nodes: Node has no incoming or outgoing edges - connect it properly
   - When modifying edges, ensure the logical flow makes sense (start → middle → end).
   - ALWAYS output the complete edges array, even if only modifying one edge.

   **Validation Feedback** (Automatic Retry):
   - If `<validation_feedback>` is present, you are RETRYING after validation errors.
   - Focus ONLY on fixing the specific validation issues mentioned.
   - Keep everything else from the previous attempt UNCHANGED (preserve node IDs, edges, etc.).
   - Common validation issues and fixes:
     * "Missing required connection" → Add the missing edge
     * "Invalid node configuration" → Fix the specific node's config section
     * "Type mismatch in variable reference" → Correct the variable selector path
     * "Unknown variable" → Update variable reference to existing output
   - When fixing, make MINIMAL changes to address each specific error.

7. **Output**:
   - Return ONLY the JSON object with `nodes` and `edges`.
   - Do NOT generate Mermaid diagrams.
   - Do NOT generate explanations.
</rules>

<edge_rules priority="critical">
**EDGES ARE CRITICAL** - Every node except 'end' MUST have at least one outgoing edge.

1. **Linear Flow**: Simple source -> target connection
   ```
   {{"source": "node_a", "target": "node_b"}}
   ```

2. **question-classifier Branching**: Each class MUST have a separate edge with `sourceHandle` = class `id`
   - If you define classes: [{{"id": "cls_refund", "name": "Refund"}}, {{"id": "cls_inquiry", "name": "Inquiry"}}]
   - You MUST create edges:
     - {{"source": "classifier", "sourceHandle": "cls_refund", "target": "refund_handler"}}
     - {{"source": "classifier", "sourceHandle": "cls_inquiry", "target": "inquiry_handler"}}

3. **if-else Branching**: MUST have exactly TWO edges with sourceHandle "true" and "false"
   - {{"source": "condition", "sourceHandle": "true", "target": "true_branch"}}
   - {{"source": "condition", "sourceHandle": "false", "target": "false_branch"}}

4. **Branch Convergence**: Multiple branches can connect to same downstream node
   - Both true_branch and false_branch can connect to the same 'end' node

5. **NEVER leave orphan nodes**: Every node must be connected in the graph
</edge_rules>

<examples>
<example name="simple_linear">
```json
{{
  "nodes": [
    {{
      "id": "start",
      "type": "start",
      "title": "Start",
      "config": {{
        "variables": [{{"variable": "query", "label": "Query", "type": "text-input"}}]
      }}
    }},
    {{
      "id": "llm",
      "type": "llm",
      "title": "Generate Response",
      "config": {{
        "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
        "prompt_template": [{{"role": "user", "text": "Answer: {{{{#start.query#}}}}"}}]
      }}
    }},
    {{
      "id": "end",
      "type": "end",
      "title": "End",
      "config": {{
        "outputs": [
          {{"variable": "result", "value_selector": ["llm", "text"]}}
        ]
      }}
    }}
  ],
  "edges": [
    {{"source": "start", "target": "llm"}},
    {{"source": "llm", "target": "end"}}
  ]
}}
```
</example>

<example name="question_classifier_branching" description="Customer service with intent classification">
```json
{{
  "nodes": [
    {{
      "id": "start",
      "type": "start",
      "title": "Start",
      "config": {{
        "variables": [{{"variable": "user_input", "label": "User Message", "type": "text-input", "required": true}}]
      }}
    }},
    {{
      "id": "classifier",
      "type": "question-classifier",
      "title": "Classify Intent",
      "config": {{
        "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
        "query_variable_selector": ["start", "user_input"],
        "classes": [
          {{"id": "cls_refund", "name": "Refund Request"}},
          {{"id": "cls_inquiry", "name": "Product Inquiry"}},
          {{"id": "cls_complaint", "name": "Complaint"}},
          {{"id": "cls_other", "name": "Other"}}
        ],
        "instruction": "Classify the user's intent"
      }}
    }},
    {{
      "id": "handle_refund",
      "type": "llm",
      "title": "Handle Refund",
      "config": {{
        "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
        "prompt_template": [{{"role": "user", "text": "Extract order number and respond: {{{{#start.user_input#}}}}"}}]
      }}
    }},
    {{
      "id": "handle_inquiry",
      "type": "llm",
      "title": "Handle Inquiry",
      "config": {{
        "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
        "prompt_template": [{{"role": "user", "text": "Answer product question: {{{{#start.user_input#}}}}"}}]
      }}
    }},
    {{
      "id": "handle_complaint",
      "type": "llm",
      "title": "Handle Complaint",
      "config": {{
        "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
        "prompt_template": [{{"role": "user", "text": "Respond with empathy: {{{{#start.user_input#}}}}"}}]
      }}
    }},
    {{
      "id": "handle_other",
      "type": "llm",
      "title": "Handle Other",
      "config": {{
        "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
        "prompt_template": [{{"role": "user", "text": "Provide general response: {{{{#start.user_input#}}}}"}}]
      }}
    }},
    {{
      "id": "end",
      "type": "end",
      "title": "End",
      "config": {{
        "outputs": [{{"variable": "response", "value_selector": ["handle_refund", "text"]}}]
      }}
    }}
  ],
  "edges": [
    {{"source": "start", "target": "classifier"}},
    {{"source": "classifier", "sourceHandle": "cls_refund", "target": "handle_refund"}},
    {{"source": "classifier", "sourceHandle": "cls_inquiry", "target": "handle_inquiry"}},
    {{"source": "classifier", "sourceHandle": "cls_complaint", "target": "handle_complaint"}},
    {{"source": "classifier", "sourceHandle": "cls_other", "target": "handle_other"}},
    {{"source": "handle_refund", "target": "end"}},
    {{"source": "handle_inquiry", "target": "end"}},
    {{"source": "handle_complaint", "target": "end"}},
    {{"source": "handle_other", "target": "end"}}
  ]
}}
```
CRITICAL: Notice that each class id (cls_refund, cls_inquiry, etc.) becomes a sourceHandle in the edges!
</example>

<example name="if_else_branching" description="Conditional logic with if-else">
```json
{{
  "nodes": [
    {{
      "id": "start",
      "type": "start",
      "title": "Start",
      "config": {{
        "variables": [{{"variable": "years", "label": "Years of Experience", "type": "number", "required": true}}]
      }}
    }},
    {{
      "id": "check_experience",
      "type": "if-else",
      "title": "Check Experience",
      "config": {{
        "cases": [
          {{
            "case_id": "case_1",
            "logical_operator": "and",
            "conditions": [
              {{
                "variable_selector": ["start", "years"],
                "comparison_operator": "≥",
                "value": "3"
              }}
            ]
          }}
        ]
      }}
    }},
    {{
      "id": "qualified",
      "type": "llm",
      "title": "Qualified Response",
      "config": {{
        "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
        "prompt_template": [{{"role": "user", "text": "Generate qualified candidate response"}}]
      }}
    }},
    {{
      "id": "not_qualified",
      "type": "llm",
      "title": "Not Qualified Response",
      "config": {{
        "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
        "prompt_template": [{{"role": "user", "text": "Generate rejection response"}}]
      }}
    }},
    {{
      "id": "end",
      "type": "end",
      "title": "End",
      "config": {{
        "outputs": [{{"variable": "result", "value_selector": ["qualified", "text"]}}]
      }}
    }}
  ],
  "edges": [
    {{"source": "start", "target": "check_experience"}},
    {{"source": "check_experience", "sourceHandle": "true", "target": "qualified"}},
    {{"source": "check_experience", "sourceHandle": "false", "target": "not_qualified"}},
    {{"source": "qualified", "target": "end"}},
    {{"source": "not_qualified", "target": "end"}}
  ]
}}
```
CRITICAL: if-else MUST have exactly two edges with sourceHandle "true" and "false"!
</example>

<example name="parameter_extractor" description="Extract structured data from text">
```json
{{
  "nodes": [
    {{
      "id": "start",
      "type": "start",
      "title": "Start",
      "config": {{
        "variables": [{{"variable": "resume", "label": "Resume Text", "type": "paragraph", "required": true}}]
      }}
    }},
    {{
      "id": "extract",
      "type": "parameter-extractor",
      "title": "Extract Info",
      "config": {{
        "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
        "query": ["start", "resume"],
        "parameters": [
          {{"name": "name", "type": "string", "description": "Candidate name", "required": true}},
          {{"name": "years", "type": "number", "description": "Years of experience", "required": true}},
          {{"name": "skills", "type": "array[string]", "description": "List of skills", "required": true}}
        ],
        "instruction": "Extract candidate information from resume"
      }}
    }},
    {{
      "id": "process",
      "type": "llm",
      "title": "Process Data",
      "config": {{
        "model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
        "prompt_template": [{{"role": "user", "text": "Name: {{{{#extract.name#}}}}, Years: {{{{#extract.years#}}}}"}}]
      }}
    }},
    {{
      "id": "end",
      "type": "end",
      "title": "End",
      "config": {{
        "outputs": [{{"variable": "result", "value_selector": ["process", "text"]}}]
      }}
    }}
  ],
  "edges": [
    {{"source": "start", "target": "extract"}},
    {{"source": "extract", "target": "process"}},
    {{"source": "process", "target": "end"}}
  ]
}}
```
</example>
</examples>

<edge_checklist>
Before finalizing, verify:
1. [ ] Every node (except 'end') has at least one outgoing edge
2. [ ] 'start' node has exactly one outgoing edge
3. [ ] 'question-classifier' has one edge per class, each with sourceHandle = class id
4. [ ] 'if-else' has exactly two edges: sourceHandle "true" and sourceHandle "false"
5. [ ] All branches eventually connect to 'end' (directly or through other nodes)
6. [ ] No orphan nodes exist (every node is reachable from 'start')
</edge_checklist>
"""

BUILDER_USER_PROMPT = """<instruction>
{instruction}
</instruction>

Generate the full workflow configuration now. Pay special attention to:
1. Creating edges for ALL branches of question-classifier and if-else nodes
2. Using correct sourceHandle values for branching nodes
3. Ensuring every node is connected in the graph
"""


def format_existing_nodes(nodes: list[dict] | None) -> str:
    """Format existing workflow nodes for context."""
    if not nodes:
        return "No existing nodes in workflow (creating from scratch)."

    lines = []
    for node in nodes:
        node_id = node.get("id", "unknown")
        node_type = node.get("type", "unknown")
        title = node.get("title", "Untitled")
        lines.append(f"- [{node_id}] {title} ({node_type})")
    return "\n".join(lines)


def format_selected_nodes(
    selected_ids: list[str] | None,
    existing_nodes: list[dict] | None,
) -> str:
    """Format selected nodes for modification context."""
    if not selected_ids:
        return "No nodes selected (generating new workflow)."

    node_map = {n.get("id"): n for n in (existing_nodes or [])}
    lines = []
    for node_id in selected_ids:
        if node_id in node_map:
            node = node_map[node_id]
            lines.append(f"- [{node_id}] {node.get('title', 'Untitled')} ({node.get('type', 'unknown')})")
        else:
            lines.append(f"- [{node_id}] (not found in current workflow)")
    return "\n".join(lines)


def format_existing_edges(edges: list[dict] | None) -> str:
    """Format existing workflow edges to show connections."""
    if not edges:
        return "No existing edges (creating new workflow)."

    lines = []
    for edge in edges:
        source = edge.get("source", "unknown")
        target = edge.get("target", "unknown")
        source_handle = edge.get("sourceHandle", "")
        if source_handle:
            lines.append(f"- {source} ({source_handle}) -> {target}")
        else:
            lines.append(f"- {source} -> {target}")
    return "\n".join(lines)

@ -1,75 +0,0 @@
PLANNER_SYSTEM_PROMPT = """<role>
You are an expert Workflow Architect.
Your job is to analyze user requests and plan a high-level automation workflow.
</role>

<task>
1. **Classify Intent**:
   - Is the user asking to create an automation/workflow? -> Intent: "generate"
   - Is it general chat/weather/jokes? -> Intent: "off_topic"

2. **Plan Steps** (if intent is "generate"):
   - Break down the user's goal into logical steps.
   - For each step, identify if a specific capability/tool is needed.
   - Select the MOST RELEVANT tools from the available_tools list.
   - DO NOT configure parameters yet. Just identify the tool.

3. **Output Format**:
   Return a JSON object.
</task>

<available_tools>
{tools_summary}
</available_tools>

<response_format>
If intent is "generate":
```json
{{
  "intent": "generate",
  "plan_thought": "Brief explanation of the plan...",
  "steps": [
    {{ "step": 1, "description": "Fetch data from URL", "tool": "http-request" }},
    {{ "step": 2, "description": "Summarize content", "tool": "llm" }},
    {{ "step": 3, "description": "Search for info", "tool": "google_search" }}
  ],
  "required_tool_keys": ["google_search"]
}}
```
(Note: 'http-request', 'llm', and 'code' are built-in; list only external tools in required_tool_keys.)

If intent is "off_topic":
```json
{{
  "intent": "off_topic",
  "message": "I can only help you build workflows. Try asking me to 'Create a workflow that...'",
  "suggestions": ["Scrape a website", "Summarize a PDF"]
}}
```
</response_format>
"""

PLANNER_USER_PROMPT = """<user_request>
{instruction}
</user_request>
"""


def format_tools_for_planner(tools: list[dict]) -> str:
    """Format tools list for planner (lightweight: name + description only)."""
    if not tools:
        return "No external tools available."

    lines = []
    for t in tools:
        key = t.get("tool_key") or t.get("tool_name")
        provider = t.get("provider_id") or t.get("provider", "")
        desc = t.get("tool_description") or t.get("description", "")
        label = t.get("tool_label") or key

        # Format: - [provider/key] Label: Description
        full_key = f"{provider}/{key}" if provider else key
        lines.append(f"- [{full_key}] {label}: {desc}")

    return "\n".join(lines)

File diff suppressed because it is too large
@ -1,349 +0,0 @@
import json
import logging
import re
from collections.abc import Sequence

import json_repair

from core.model_manager import ModelManager
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.workflow.generator.prompts.builder_prompts import (
    BUILDER_SYSTEM_PROMPT,
    BUILDER_SYSTEM_PROMPT_V2,
    BUILDER_USER_PROMPT,
    BUILDER_USER_PROMPT_V2,
    format_existing_edges,
    format_existing_nodes,
    format_selected_nodes,
)
from core.workflow.generator.prompts.planner_prompts import (
    PLANNER_SYSTEM_PROMPT,
    PLANNER_USER_PROMPT,
    format_tools_for_planner,
)
from core.workflow.generator.prompts.vibe_prompts import (
    format_available_models,
    format_available_nodes,
    format_available_tools,
    parse_vibe_response,
)
from core.workflow.generator.utils.graph_builder import CyclicDependencyError, GraphBuilder
from core.workflow.generator.utils.mermaid_generator import generate_mermaid
from core.workflow.generator.utils.workflow_validator import ValidationHint, WorkflowValidator

logger = logging.getLogger(__name__)


class WorkflowGenerator:
    """
    Refactored Vibe Workflow Generator (Planner-Builder Architecture).
    Extracts Vibe logic from the monolithic LLMGenerator.
    """

    @classmethod
    def generate_workflow_flowchart(
        cls,
        tenant_id: str,
        instruction: str,
        model_config: dict,
        available_nodes: Sequence[dict[str, object]] | None = None,
        existing_nodes: Sequence[dict[str, object]] | None = None,
        existing_edges: Sequence[dict[str, object]] | None = None,
        available_tools: Sequence[dict[str, object]] | None = None,
        selected_node_ids: Sequence[str] | None = None,
        previous_workflow: dict[str, object] | None = None,
        regenerate_mode: bool = False,
        preferred_language: str | None = None,
        available_models: Sequence[dict[str, object]] | None = None,
        use_graph_builder: bool = False,
    ):
        """
        Generates a Dify Workflow Flowchart from a natural language instruction.

        Pipeline:
        1. Planner: Analyze intent & select tools.
        2. Context Filter: Filter relevant tools (reduce tokens).
        3. Builder: Generate node configurations.
        4. Renderer: Deterministic Mermaid generation.
        5. Validator: Check for errors & generate friendly hints.
        6. Graph Validation: Structural checks; validation errors trigger a retry of step 3.
        """
        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
            provider=model_config.get("provider", ""),
            model=model_config.get("name", ""),
        )
        model_parameters = model_config.get("completion_params", {})
        available_tools_list = list(available_tools) if available_tools else []

        # Check if this is modification mode (user is refining an existing workflow)
        has_existing_nodes = existing_nodes and len(list(existing_nodes)) > 0

        # --- STEP 1: PLANNER (skip in modification mode) ---
        if has_existing_nodes:
            # In modification mode, skip the Planner:
            # - User intent is clear: modify the existing workflow
            # - Tools are already in use (from existing nodes)
            # - No need for intent classification or tool selection
            plan_data = {"intent": "generate", "steps": [], "required_tool_keys": []}
            filtered_tools = available_tools_list  # Use all available tools
        else:
            # In creation mode, run the Planner to validate intent and select tools
            planner_tools_context = format_tools_for_planner(available_tools_list)
            planner_system = PLANNER_SYSTEM_PROMPT.format(tools_summary=planner_tools_context)
            planner_user = PLANNER_USER_PROMPT.format(instruction=instruction)

            try:
                response = model_instance.invoke_llm(
                    prompt_messages=[
                        SystemPromptMessage(content=planner_system),
                        UserPromptMessage(content=planner_user),
                    ],
                    model_parameters=model_parameters,
                    stream=False,
                )
                plan_content = response.message.content
                # Reuse parse_vibe_response logic or simple load
                plan_data = parse_vibe_response(plan_content)
            except Exception as e:
                logger.exception("Planner failed")
                return {"intent": "error", "error": f"Planning failed: {str(e)}"}

            if plan_data.get("intent") == "off_topic":
                return {
                    "intent": "off_topic",
                    "message": plan_data.get("message", "I can only help with workflow creation."),
                    "suggestions": plan_data.get("suggestions", []),
                }

            # --- STEP 2: CONTEXT FILTERING ---
            required_tools = plan_data.get("required_tool_keys", [])

            filtered_tools = []
            if required_tools:
                # Simple linear search (an optimized version would use a map)
                for tool in available_tools_list:
                    t_key = tool.get("tool_key") or tool.get("tool_name")
                    provider = tool.get("provider_id") or tool.get("provider")
                    full_key = f"{provider}/{t_key}" if provider else t_key

                    # Check if this tool is in the required list (match either full key or short name)
                    if t_key in required_tools or full_key in required_tools:
                        filtered_tools.append(tool)
            else:
                # Logic-only plan: no external tools needed
                filtered_tools = []
        # --- STEP 3: BUILDER (with retry loop) ---
        MAX_GLOBAL_RETRIES = 2  # Total attempts: 1 initial + 1 retry

        workflow_data = None
        mermaid_code = None
        all_warnings = []
        all_fixes = []
        retry_count = 0
        validation_hints = []

        for attempt in range(MAX_GLOBAL_RETRIES):
            retry_count = attempt
            logger.info("Generation attempt %s/%s", attempt + 1, MAX_GLOBAL_RETRIES)

            # Prepare context
            tool_schemas = format_available_tools(filtered_tools)
            node_specs = format_available_nodes(list(available_nodes) if available_nodes else [])
            existing_nodes_context = format_existing_nodes(list(existing_nodes) if existing_nodes else None)
            existing_edges_context = format_existing_edges(list(existing_edges) if existing_edges else None)
            selected_nodes_context = format_selected_nodes(
                list(selected_node_ids) if selected_node_ids else None, list(existing_nodes) if existing_nodes else None
            )

            # Build retry context
            retry_context = ""

            # NOTE: Manual regeneration/refinement mode removed.
            # Only automatic retry (validation errors) is handled.

            # For automatic retry (validation errors)
            if attempt > 0 and validation_hints:
                severe_issues = [h for h in validation_hints if h.severity == "error"]
                if severe_issues:
                    retry_context = "\n<validation_feedback>\n"
                    retry_context += "The previous generation had validation errors:\n"
                    for idx, hint in enumerate(severe_issues[:5], 1):
                        retry_context += f"{idx}. {hint.message}\n"
                    retry_context += "\nPlease fix these specific issues while keeping everything else UNCHANGED.\n"
                    retry_context += "</validation_feedback>\n"

            # Select prompt version based on the use_graph_builder flag
            if use_graph_builder:
                builder_system = BUILDER_SYSTEM_PROMPT_V2.format(
                    plan_context=json.dumps(plan_data.get("steps", []), indent=2),
                    tool_schemas=tool_schemas,
                    builtin_node_specs=node_specs,
                    available_models=format_available_models(list(available_models or [])),
                    preferred_language=preferred_language or "English",
                    existing_nodes_context=existing_nodes_context,
                    selected_nodes_context=selected_nodes_context,
                )
                builder_user = BUILDER_USER_PROMPT_V2.format(instruction=instruction) + retry_context
            else:
                builder_system = BUILDER_SYSTEM_PROMPT.format(
                    plan_context=json.dumps(plan_data.get("steps", []), indent=2),
                    tool_schemas=tool_schemas,
                    builtin_node_specs=node_specs,
                    available_models=format_available_models(list(available_models or [])),
                    preferred_language=preferred_language or "English",
                    existing_nodes_context=existing_nodes_context,
                    existing_edges_context=existing_edges_context,
                    selected_nodes_context=selected_nodes_context,
                )
                builder_user = BUILDER_USER_PROMPT.format(instruction=instruction) + retry_context

            try:
                build_res = model_instance.invoke_llm(
                    prompt_messages=[
                        SystemPromptMessage(content=builder_system),
                        UserPromptMessage(content=builder_user),
                    ],
                    model_parameters=model_parameters,
                    stream=False,
                )
                # Builder output is raw JSON nodes/edges
                build_content = build_res.message.content
                match = re.search(r"```(?:json)?\s*([\s\S]+?)```", build_content)
                if match:
                    build_content = match.group(1)

                workflow_data = json_repair.loads(build_content)

                if "nodes" not in workflow_data:
                    workflow_data["nodes"] = []

                # --- GraphBuilder mode: build graph from depends_on ---
                if use_graph_builder:
                    try:
                        # Extract nodes from LLM output (without start/end)
                        llm_nodes = workflow_data.get("nodes", [])

                        # Build complete graph with start/end and edges
                        complete_nodes, edges = GraphBuilder.build_graph(llm_nodes)

                        workflow_data["nodes"] = complete_nodes
                        workflow_data["edges"] = edges

                        logger.info(
                            "GraphBuilder: built %d nodes, %d edges from %d LLM nodes",
                            len(complete_nodes),
                            len(edges),
                            len(llm_nodes),
                        )

                    except CyclicDependencyError as e:
                        logger.warning("GraphBuilder: cyclic dependency detected: %s", e)
                        # Add to validation hints for retry
                        validation_hints.append(
                            ValidationHint(
                                node_id="",
                                field="depends_on",
                                message=f"Cyclic dependency detected: {e}. Please fix the dependency chain.",
                                severity="error",
                            )
                        )
                        if attempt == MAX_GLOBAL_RETRIES - 1:
                            return {
                                "intent": "error",
                                "error": "Failed to build workflow: cyclic dependency detected.",
                            }
                        continue  # Retry with error feedback

                    except Exception as e:
                        logger.exception("GraphBuilder failed on attempt %d", attempt + 1)
                        if attempt == MAX_GLOBAL_RETRIES - 1:
                            return {"intent": "error", "error": f"Graph building failed: {str(e)}"}
                        continue
                else:
                    # Legacy mode: edges come from the LLM output
                    if "edges" not in workflow_data:
                        workflow_data["edges"] = []

            except Exception as e:
                logger.exception("Builder failed on attempt %d", attempt + 1)
                if attempt == MAX_GLOBAL_RETRIES - 1:
                    return {"intent": "error", "error": f"Building failed: {str(e)}"}
                continue  # Try again

            # NOTE: NodeRepair and EdgeRepair have been removed.
            # Validation will detect structural issues, and the LLM will fix them on retry.
            # This is more accurate because the LLM understands the workflow context.

            # --- STEP 4: RENDERER (generate Mermaid early for validation) ---
            mermaid_code = generate_mermaid(workflow_data)

            # --- STEP 5: VALIDATOR ---
            is_valid, validation_hints = WorkflowValidator.validate(workflow_data, available_tools_list)

            # --- STEP 6: GRAPH VALIDATION (structural checks using graph algorithms) ---
            if attempt < MAX_GLOBAL_RETRIES - 1:
                try:
                    from core.workflow.generator.utils.graph_validator import GraphValidator

                    graph_result = GraphValidator.validate(workflow_data)

                    if not graph_result.success:
                        # Convert graph errors to validation hints
                        for graph_error in graph_result.errors:
                            validation_hints.append(
                                ValidationHint(
                                    node_id=graph_error.node_id,
                                    field="edges",
                                    message=f"[Graph] {graph_error.message}",
                                    severity="error",
                                )
                            )
                        # Also add warnings (dead ends) as hints
                        for graph_warning in graph_result.warnings:
                            validation_hints.append(
                                ValidationHint(
                                    node_id=graph_warning.node_id,
                                    field="edges",
                                    message=f"[Graph] {graph_warning.message}",
                                    severity="warning",
                                )
                            )
                except Exception as e:
                    logger.warning("Graph validation error: %s", e)

            # Collect all validation warnings
            all_warnings = [h.message for h in validation_hints]

            # Check if we should retry
            severe_issues = [h for h in validation_hints if h.severity == "error"]

            if not severe_issues or attempt == MAX_GLOBAL_RETRIES - 1:
                break

            # Has severe errors and retries remaining - continue to next attempt

        # Collect all validation warnings
        all_warnings = [h.message for h in validation_hints]

        # Add stability warning (as requested by user)
        stability_warning = "The generated workflow may require debugging."
        if preferred_language and preferred_language.startswith("zh"):
            stability_warning = "生成的 Workflow 可能需要调试。"
        all_warnings.append(stability_warning)

        return {
            "intent": "generate",
            "flowchart": mermaid_code,
            "nodes": workflow_data["nodes"],
            "edges": workflow_data["edges"],
            "message": plan_data.get("plan_thought", "Generated workflow based on your request."),
            "warnings": all_warnings,
            "tool_recommendations": [],  # Legacy field
            "error": "",
            "fixed_issues": all_fixes,  # Track what was auto-fixed
            "retry_count": retry_count,  # Track how many retries were needed
        }
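

# A minimal call sketch (hypothetical tenant and model values; real callers
# wire these in from the console API) showing the result shape returned above:
#
#   result = WorkflowGenerator.generate_workflow_flowchart(
#       tenant_id="tenant-123",
#       instruction="Fetch a URL and summarize it",
#       model_config={"provider": "openai", "name": "gpt-4o", "completion_params": {}},
#       use_graph_builder=True,
#   )
#   if result["intent"] == "generate":
#       nodes, edges, mermaid = result["nodes"], result["edges"], result["flowchart"]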

@ -1,217 +0,0 @@
"""
Type definitions for Vibe Workflow Generator.

This module provides:
- TypedDict classes for lightweight type hints (no runtime overhead)
- Pydantic models for runtime validation where needed

Usage:
    # For type hints only (no runtime validation):
    from core.workflow.generator.types import WorkflowNodeDict, WorkflowEdgeDict

    # For runtime validation:
    from core.workflow.generator.types import WorkflowNode, WorkflowEdge
"""

from typing import Any, TypedDict

from pydantic import BaseModel, Field

# ============================================================
# TypedDict definitions (lightweight, for type hints only)
# ============================================================


class WorkflowNodeDict(TypedDict, total=False):
    """
    Workflow node structure (TypedDict for hints).

    Attributes:
        id: Unique node identifier
        type: Node type (e.g., "start", "end", "llm", "if-else", "http-request")
        title: Human-readable node title
        config: Node-specific configuration
        data: Additional node data
    """

    id: str
    type: str
    title: str
    config: dict[str, Any]
    data: dict[str, Any]


class WorkflowEdgeDict(TypedDict, total=False):
    """
    Workflow edge structure (TypedDict for hints).

    Attributes:
        source: Source node ID
        target: Target node ID
        sourceHandle: Branch handle for if-else/question-classifier nodes
    """

    source: str
    target: str
    sourceHandle: str


class AvailableModelDict(TypedDict):
    """
    Available model structure.

    Attributes:
        provider: Model provider (e.g., "openai", "anthropic")
        model: Model name (e.g., "gpt-4", "claude-3")
    """

    provider: str
    model: str


class ToolParameterDict(TypedDict, total=False):
    """
    Tool parameter structure.

    Attributes:
        name: Parameter name
        type: Parameter type (e.g., "string", "number", "boolean")
        required: Whether parameter is required
        human_description: Human-readable description
        llm_description: LLM-oriented description
        options: Available options for enum-type parameters
    """

    name: str
    type: str
    required: bool
    human_description: str | dict[str, str]
    llm_description: str
    options: list[Any]


class AvailableToolDict(TypedDict, total=False):
    """
    Available tool structure.

    Attributes:
        provider_id: Tool provider ID
        provider: Tool provider name (alternative to provider_id)
        tool_key: Unique tool key
        tool_name: Tool name (alternative to tool_key)
        tool_description: Tool description
        description: Alternative description field
        is_team_authorization: Whether tool is configured/authorized
        parameters: List of tool parameters
    """

    provider_id: str
    provider: str
    tool_key: str
    tool_name: str
    tool_description: str
    description: str
    is_team_authorization: bool
    parameters: list[ToolParameterDict]


class WorkflowDataDict(TypedDict, total=False):
    """
    Complete workflow data structure.

    Attributes:
        nodes: List of workflow nodes
        edges: List of workflow edges
        warnings: List of warning messages
    """

    nodes: list[WorkflowNodeDict]
    edges: list[WorkflowEdgeDict]
    warnings: list[str]


# ============================================================
# Pydantic models (for runtime validation)
# ============================================================


class WorkflowNode(BaseModel):
    """
    Workflow node with runtime validation.

    Use this model when you need to validate node data at runtime.
    For lightweight type hints without validation, use WorkflowNodeDict.
    """

    id: str
    type: str
    title: str = ""
    config: dict[str, Any] = Field(default_factory=dict)
    data: dict[str, Any] = Field(default_factory=dict)


class WorkflowEdge(BaseModel):
    """
    Workflow edge with runtime validation.

    Use this model when you need to validate edge data at runtime.
    For lightweight type hints without validation, use WorkflowEdgeDict.
    """

    source: str
    target: str
    sourceHandle: str | None = None


class AvailableModel(BaseModel):
    """
    Available model with runtime validation.

    Use this model when you need to validate model data at runtime.
    For lightweight type hints without validation, use AvailableModelDict.
    """

    provider: str
    model: str


class ToolParameter(BaseModel):
    """Tool parameter with runtime validation."""

    name: str = ""
    type: str = "string"
    required: bool = False
    human_description: str | dict[str, str] = ""
    llm_description: str = ""
    options: list[Any] = Field(default_factory=list)


class AvailableTool(BaseModel):
    """
    Available tool with runtime validation.

    Use this model when you need to validate tool data at runtime.
    For lightweight type hints without validation, use AvailableToolDict.
    """

    provider_id: str = ""
    provider: str = ""
    tool_key: str = ""
    tool_name: str = ""
    tool_description: str = ""
    description: str = ""
    is_team_authorization: bool = False
    parameters: list[ToolParameter] = Field(default_factory=list)


class WorkflowData(BaseModel):
    """
    Complete workflow data with runtime validation.

    Use this model when you need to validate workflow data at runtime.
    For lightweight type hints without validation, use WorkflowDataDict.
    """

    nodes: list[WorkflowNode] = Field(default_factory=list)
    edges: list[WorkflowEdge] = Field(default_factory=list)
    warnings: list[str] = Field(default_factory=list)
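

# A short usage sketch (hypothetical values) contrasting the two layers:
# the TypedDict is a zero-cost annotation, while the Pydantic model validates.
#
#   edge_hint: WorkflowEdgeDict = {"source": "start", "target": "llm"}  # no runtime checks
#   edge = WorkflowEdge(source="start", target="llm")                   # validated
#   WorkflowEdge(source="start")  # raises pydantic.ValidationError: "target" is required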

@ -1,384 +0,0 @@
"""
Edge Repair Utility for Vibe Workflow Generation.

This module provides intelligent edge repair capabilities for generated workflows.
It can detect and fix common edge issues:
- Missing edges between sequential nodes
- Incomplete branches for question-classifier and if-else nodes
- Orphaned nodes without connections

The repair logic is deterministic and doesn't require LLM calls.
"""

import logging
from dataclasses import dataclass, field

from core.workflow.generator.types import WorkflowDataDict, WorkflowEdgeDict, WorkflowNodeDict

logger = logging.getLogger(__name__)


@dataclass
class RepairResult:
    """Result of edge repair operation."""

    nodes: list[WorkflowNodeDict]
    edges: list[WorkflowEdgeDict]
    repairs_made: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)

    @property
    def was_repaired(self) -> bool:
        """Check if any repairs were made."""
        return len(self.repairs_made) > 0


class EdgeRepair:
    """
    Intelligent edge repair for workflow graphs.

    Repairs are applied in order:
    1. Infer linear connections from node order (if no edges exist)
    2. Add missing branch edges for question-classifier
    3. Add missing branch edges for if-else
    4. Connect orphaned nodes
    """

    @classmethod
    def repair(cls, workflow_data: WorkflowDataDict) -> RepairResult:
        """
        Repair edges in the workflow data.

        Args:
            workflow_data: Dict containing 'nodes' and 'edges'

        Returns:
            RepairResult with repaired nodes, edges, and repair logs
        """
        nodes = list(workflow_data.get("nodes", []))
        edges = list(workflow_data.get("edges", []))
        repairs: list[str] = []
        warnings: list[str] = []

        logger.info("[EDGE REPAIR] Starting repair process for %s nodes, %s edges", len(nodes), len(edges))

        # Build node lookup
        node_map = {n.get("id"): n for n in nodes if n.get("id")}
        node_ids = set(node_map.keys())

        # 1. If no edges at all, infer linear chain
        if not edges and len(nodes) > 1:
            edges, inferred_repairs = cls._infer_linear_chain(nodes)
            repairs.extend(inferred_repairs)

        # 2. Build edge index for analysis
        outgoing_edges: dict[str, list[WorkflowEdgeDict]] = {}
        incoming_edges: dict[str, list[WorkflowEdgeDict]] = {}
        for edge in edges:
            src = edge.get("source")
            tgt = edge.get("target")
            if src:
                outgoing_edges.setdefault(src, []).append(edge)
            if tgt:
                incoming_edges.setdefault(tgt, []).append(edge)

        # 3. Repair question-classifier branches
        for node in nodes:
            if node.get("type") == "question-classifier":
                new_edges, branch_repairs, branch_warnings = cls._repair_classifier_branches(
                    node, edges, outgoing_edges, node_ids
                )
                edges.extend(new_edges)
                repairs.extend(branch_repairs)
                warnings.extend(branch_warnings)
                # Update outgoing index
                for edge in new_edges:
                    outgoing_edges.setdefault(edge.get("source"), []).append(edge)

        # 4. Repair if-else branches
        for node in nodes:
            if node.get("type") == "if-else":
                new_edges, branch_repairs, branch_warnings = cls._repair_if_else_branches(
                    node, edges, outgoing_edges, node_ids
                )
                edges.extend(new_edges)
                repairs.extend(branch_repairs)
                warnings.extend(branch_warnings)
                # Update outgoing index
                for edge in new_edges:
                    outgoing_edges.setdefault(edge.get("source"), []).append(edge)

        # 5. Connect orphaned nodes (nodes with no incoming edge, except start)
        new_edges, orphan_repairs = cls._connect_orphaned_nodes(nodes, edges, outgoing_edges, incoming_edges)
        edges.extend(new_edges)
        repairs.extend(orphan_repairs)

        # 6. Connect nodes with no outgoing edge to 'end' (except end nodes)
        new_edges, terminal_repairs = cls._connect_terminal_nodes(nodes, edges, outgoing_edges)
        edges.extend(new_edges)
        repairs.extend(terminal_repairs)

        if repairs:
            logger.info("[EDGE REPAIR] Completed with %s repairs:", len(repairs))
            for i, repair in enumerate(repairs, 1):
                logger.info("[EDGE REPAIR] %s. %s", i, repair)
        else:
            logger.info("[EDGE REPAIR] Completed - no repairs needed")

        return RepairResult(
            nodes=nodes,
            edges=edges,
            repairs_made=repairs,
            warnings=warnings,
        )

    @classmethod
    def _infer_linear_chain(cls, nodes: list[WorkflowNodeDict]) -> tuple[list[WorkflowEdgeDict], list[str]]:
        """
        Infer a linear chain of edges from node order.

        This is used when no edges are provided at all.
        """
        edges: list[WorkflowEdgeDict] = []
        repairs: list[str] = []

        # Filter to get ordered node IDs
        node_ids = [n.get("id") for n in nodes if n.get("id")]

        if len(node_ids) < 2:
            return edges, repairs

        # Create edges between consecutive nodes
        for i in range(len(node_ids) - 1):
            src = node_ids[i]
            tgt = node_ids[i + 1]
            edges.append({"source": src, "target": tgt})
            repairs.append(f"Inferred edge: {src} -> {tgt}")

        return edges, repairs
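
    # A worked example (hypothetical input): three edge-less nodes
    # [{"id": "start"}, {"id": "llm"}, {"id": "end"}] yield the chain
    # start -> llm -> end plus two "Inferred edge: ..." repair messages.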

    @classmethod
    def _repair_classifier_branches(
        cls,
        node: WorkflowNodeDict,
        edges: list[WorkflowEdgeDict],
        outgoing_edges: dict[str, list[WorkflowEdgeDict]],
        valid_node_ids: set[str],
    ) -> tuple[list[WorkflowEdgeDict], list[str], list[str]]:
        """
        Repair missing branches for question-classifier nodes.

        For each class that doesn't have an edge, create one pointing to 'end'.
        """
        new_edges: list[WorkflowEdgeDict] = []
        repairs: list[str] = []
        warnings: list[str] = []

        node_id = node.get("id")
        if not node_id:
            return new_edges, repairs, warnings

        config = node.get("config", {})
        classes = config.get("classes", [])

        if not classes:
            return new_edges, repairs, warnings

        # Get existing sourceHandles for this node
        existing_handles = set()
        for edge in outgoing_edges.get(node_id, []):
            handle = edge.get("sourceHandle")
            if handle:
                existing_handles.add(handle)

        # Find 'end' node as default target
        end_node_id = "end"
        if "end" not in valid_node_ids:
            # Try to find an end node
            for nid in valid_node_ids:
                if "end" in nid.lower():
                    end_node_id = nid
                    break

        # Add missing branches
        for cls_def in classes:
            if not isinstance(cls_def, dict):
                continue
            cls_id = cls_def.get("id")
            cls_name = cls_def.get("name", cls_id)

            if cls_id and cls_id not in existing_handles:
                new_edge = {
                    "source": node_id,
                    "sourceHandle": cls_id,
                    "target": end_node_id,
                }
                new_edges.append(new_edge)
                repairs.append(f"Added missing branch edge for class '{cls_name}' -> {end_node_id}")
                warnings.append(
                    f"Auto-connected question-classifier branch '{cls_name}' to '{end_node_id}'. "
                    "You may want to redirect this to a specific handler node."
                )

        return new_edges, repairs, warnings

    @classmethod
    def _repair_if_else_branches(
        cls,
        node: WorkflowNodeDict,
        edges: list[WorkflowEdgeDict],
        outgoing_edges: dict[str, list[WorkflowEdgeDict]],
        valid_node_ids: set[str],
    ) -> tuple[list[WorkflowEdgeDict], list[str], list[str]]:
        """
        Repair missing branches for if-else nodes.

        If-else in Dify uses case_id as sourceHandle for each condition,
        plus 'false' for the else branch.
        """
        new_edges: list[WorkflowEdgeDict] = []
        repairs: list[str] = []
        warnings: list[str] = []

        node_id = node.get("id")
        if not node_id:
            return new_edges, repairs, warnings

        # Get existing sourceHandles
        existing_handles = set()
        for edge in outgoing_edges.get(node_id, []):
            handle = edge.get("sourceHandle")
            if handle:
                existing_handles.add(handle)

        # Find 'end' node as default target
        end_node_id = "end"
        if "end" not in valid_node_ids:
            for nid in valid_node_ids:
                if "end" in nid.lower():
                    end_node_id = nid
                    break

        # Get required branches from config
        config = node.get("config", {})
        cases = config.get("cases", [])

        # Build required handles: each case_id + 'false' for else
        required_branches = set()
        for case in cases:
            case_id = case.get("case_id")
            if case_id:
                required_branches.add(case_id)
        required_branches.add("false")  # else branch

        # Add missing branches
        for branch in required_branches:
            if branch not in existing_handles:
                new_edge = {
                    "source": node_id,
                    "sourceHandle": branch,
                    "target": end_node_id,
                }
                new_edges.append(new_edge)
                repairs.append(f"Added missing if-else branch '{branch}' -> {end_node_id}")
                warnings.append(
                    f"Auto-connected if-else branch '{branch}' to '{end_node_id}'. "
                    "You may want to redirect this to a specific handler node."
                )

        return new_edges, repairs, warnings

    @classmethod
    def _connect_orphaned_nodes(
        cls,
        nodes: list[WorkflowNodeDict],
        edges: list[WorkflowEdgeDict],
        outgoing_edges: dict[str, list[WorkflowEdgeDict]],
        incoming_edges: dict[str, list[WorkflowEdgeDict]],
    ) -> tuple[list[WorkflowEdgeDict], list[str]]:
        """
        Connect orphaned nodes to the previous node in sequence.

        An orphaned node has no incoming edges and is not a 'start' node.
        """
        new_edges: list[WorkflowEdgeDict] = []
        repairs: list[str] = []

        node_ids = [n.get("id") for n in nodes if n.get("id")]
        node_types = {n.get("id"): n.get("type") for n in nodes}

        for i, node_id in enumerate(node_ids):
            node_type = node_types.get(node_id)

            # Skip start nodes - they don't need incoming edges
            if node_type == "start":
                continue

            # Check if node has incoming edges
            if node_id not in incoming_edges or not incoming_edges[node_id]:
                # Find previous node to connect from
                if i > 0:
                    prev_node_id = node_ids[i - 1]
                    new_edge = {"source": prev_node_id, "target": node_id}
                    new_edges.append(new_edge)
                    repairs.append(f"Connected orphaned node: {prev_node_id} -> {node_id}")

                    # Update incoming_edges for subsequent checks
                    incoming_edges.setdefault(node_id, []).append(new_edge)

        return new_edges, repairs

    @classmethod
    def _connect_terminal_nodes(
        cls,
        nodes: list[WorkflowNodeDict],
        edges: list[WorkflowEdgeDict],
        outgoing_edges: dict[str, list[WorkflowEdgeDict]],
    ) -> tuple[list[WorkflowEdgeDict], list[str]]:
        """
        Connect terminal nodes (no outgoing edges) to 'end'.

        A terminal node has no outgoing edges and is not an 'end' node.
        This ensures all branches eventually reach 'end'.
        """
        new_edges: list[WorkflowEdgeDict] = []
        repairs: list[str] = []

        # Find end node
        end_node_id = None
        node_ids = set()
        for n in nodes:
            nid = n.get("id")
            ntype = n.get("type")
            if nid:
                node_ids.add(nid)
            if ntype == "end":
                end_node_id = nid

        if not end_node_id:
            # No end node found, can't connect
            return new_edges, repairs

        for node in nodes:
            node_id = node.get("id")
            node_type = node.get("type")

            # Skip end nodes
            if node_type == "end":
                continue

            # Skip nodes that already have outgoing edges
            if outgoing_edges.get(node_id):
                continue

            # Connect to end
            new_edge = {"source": node_id, "target": end_node_id}
            new_edges.append(new_edge)
            repairs.append(f"Connected terminal node to end: {node_id} -> {end_node_id}")

            # Update for subsequent checks
            outgoing_edges.setdefault(node_id, []).append(new_edge)

        return new_edges, repairs
|
||||
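
    # Illustrative sketch (not part of the original module): the repair helpers
    # above all consume the same adjacency maps, keyed by node id. For a linear
    # graph start -> llm -> end, the expected shapes would be roughly:
    #
    #   outgoing_edges = {"start": [{"source": "start", "target": "llm"}],
    #                     "llm": [{"source": "llm", "target": "end"}]}
    #   incoming_edges = {"llm": [{"source": "start", "target": "llm"}],
    #                     "end": [{"source": "llm", "target": "end"}]}
    #
    # A node id missing from incoming_edges marks an orphan; a node id missing
    # from outgoing_edges marks a terminal node.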
@@ -1,615 +0,0 @@
"""
GraphBuilder: Automatic workflow graph construction from node list.

This module implements the core logic for building complete workflow graphs
from LLM-generated node lists with dependency declarations.

Key features:
- Automatic start/end node generation
- Dependency inference from variable references
- Topological sorting with cycle detection
- Special handling for branching nodes (if-else, question-classifier)
- Silent error recovery where possible
"""

import json
import logging
import re
import uuid
from collections import defaultdict
from typing import Any

logger = logging.getLogger(__name__)

# Pattern to match variable references like {{#node_id.field#}}
VAR_PATTERN = re.compile(r"\{\{#([^.#]+)\.[^#]+#\}\}")

# System variable prefixes to exclude from dependency inference
SYSTEM_VAR_PREFIXES = {"sys", "start", "env"}

# Node types that have special branching behavior
BRANCHING_NODE_TYPES = {"if-else", "question-classifier"}

# Container node types (iteration, loop) - these have internal subgraphs
# but behave as single-input-single-output nodes in the external graph
CONTAINER_NODE_TYPES = {"iteration", "loop"}


class GraphBuildError(Exception):
    """Raised when graph cannot be built due to unrecoverable errors."""

    pass


class CyclicDependencyError(GraphBuildError):
    """Raised when cyclic dependencies are detected."""

    pass


class GraphBuilder:
    """
    Builds complete workflow graphs from LLM-generated node lists.

    This class handles the conversion from a simplified node list format
    (with depends_on declarations) to a full workflow graph with nodes and edges.

    The LLM only needs to generate:
    - Node configurations with depends_on arrays
    - Branch targets in config for branching nodes

    The GraphBuilder automatically:
    - Adds start and end nodes
    - Generates all edges from dependencies
    - Infers implicit dependencies from variable references
    - Handles branching nodes (if-else, question-classifier)
    - Validates graph structure (no cycles, proper connectivity)
    """

    @classmethod
    def build_graph(
        cls,
        nodes: list[dict[str, Any]],
        start_config: dict[str, Any] | None = None,
        end_config: dict[str, Any] | None = None,
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
        """
        Build a complete workflow graph from a node list.

        Args:
            nodes: LLM-generated nodes (without start/end)
            start_config: Optional configuration for start node
            end_config: Optional configuration for end node

        Returns:
            Tuple of (complete_nodes, edges) where:
            - complete_nodes includes start, user nodes, and end
            - edges contains all connections

        Raises:
            CyclicDependencyError: If cyclic dependencies are detected
            GraphBuildError: If graph cannot be built
        """
        if not nodes:
            # Empty node list - create minimal workflow
            start_node = cls._create_start_node([], start_config)
            end_node = cls._create_end_node([], end_config)
            edge = cls._create_edge("start", "end")
            return [start_node, end_node], [edge]

        # Build node index for quick lookup
        node_map = {node["id"]: node for node in nodes}

        # Step 1: Extract explicit dependencies from depends_on
        dependencies = cls._extract_explicit_dependencies(nodes)

        # Step 2: Infer implicit dependencies from variable references
        dependencies = cls._infer_dependencies_from_variables(nodes, dependencies, node_map)

        # Step 3: Validate and fix dependencies (remove invalid references)
        dependencies = cls._validate_dependencies(dependencies, node_map)

        # Step 4: Topological sort (detects cycles)
        sorted_node_ids = cls._topological_sort(nodes, dependencies)

        # Step 5: Generate start node
        start_node = cls._create_start_node(nodes, start_config)

        # Step 6: Generate edges
        edges = cls._generate_edges(nodes, sorted_node_ids, dependencies, node_map)

        # Step 7: Find terminal nodes and generate end node
        terminal_nodes = cls._find_terminal_nodes(nodes, dependencies, node_map)
        end_node = cls._create_end_node(terminal_nodes, end_config)

        # Step 8: Add edges from terminal nodes to end
        for terminal_id in terminal_nodes:
            edges.append(cls._create_edge(terminal_id, "end"))

        # Step 9: Assemble complete node list
        all_nodes = [start_node, *nodes, end_node]

        return all_nodes, edges
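
    # Hypothetical usage sketch (node ids and configs are illustrative only):
    #
    #   nodes = [
    #       {"id": "llm_1", "type": "llm", "config": {}, "depends_on": []},
    #       {"id": "code_1", "type": "code", "config": {}, "depends_on": ["llm_1"]},
    #   ]
    #   all_nodes, edges = GraphBuilder.build_graph(nodes)
    #
    # would yield start/end nodes wrapped around the user nodes and edges
    # start -> llm_1 -> code_1 -> end, with cyclic depends_on declarations
    # rejected via CyclicDependencyError.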
    @classmethod
    def _extract_explicit_dependencies(
        cls,
        nodes: list[dict[str, Any]],
    ) -> dict[str, list[str]]:
        """
        Extract explicit dependencies from depends_on field.

        Args:
            nodes: List of nodes with optional depends_on field

        Returns:
            Dictionary mapping node_id -> list of dependency node_ids
        """
        dependencies: dict[str, list[str]] = {}

        for node in nodes:
            node_id = node.get("id", "")
            depends_on = node.get("depends_on", [])

            # Ensure depends_on is a list
            if isinstance(depends_on, str):
                depends_on = [depends_on] if depends_on else []
            elif not isinstance(depends_on, list):
                depends_on = []

            dependencies[node_id] = list(depends_on)

        return dependencies

    @classmethod
    def _infer_dependencies_from_variables(
        cls,
        nodes: list[dict[str, Any]],
        explicit_deps: dict[str, list[str]],
        node_map: dict[str, dict[str, Any]],
    ) -> dict[str, list[str]]:
        """
        Infer implicit dependencies from variable references in config.

        Scans node configurations for patterns like {{#node_id.field#}}
        and adds those as dependencies if not already declared.

        Args:
            nodes: List of nodes
            explicit_deps: Already extracted explicit dependencies
            node_map: Map of node_id -> node for validation

        Returns:
            Updated dependencies dictionary
        """
        for node in nodes:
            node_id = node.get("id", "")
            config = node.get("config", {})

            # Serialize config to search for variable references
            try:
                config_str = json.dumps(config, ensure_ascii=False)
            except (TypeError, ValueError):
                continue

            # Find all variable references
            referenced_nodes = set(VAR_PATTERN.findall(config_str))

            # Filter out system variables
            referenced_nodes -= SYSTEM_VAR_PREFIXES

            # Ensure node_id exists in dependencies
            if node_id not in explicit_deps:
                explicit_deps[node_id] = []

            # Add inferred dependencies
            for ref in referenced_nodes:
                # Skip self-references (e.g., loop nodes referencing their own outputs)
                if ref == node_id:
                    logger.debug(
                        "Skipping self-reference: %s -> %s",
                        node_id,
                        ref,
                    )
                    continue

                if ref in node_map and ref not in explicit_deps[node_id]:
                    explicit_deps[node_id].append(ref)
                    logger.debug(
                        "Inferred dependency: %s -> %s (from variable reference)",
                        node_id,
                        ref,
                    )

        return explicit_deps
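
    # Worked example of the inference step (illustrative values): given a config
    # containing the string "{{#llm_1.text#}}", VAR_PATTERN captures the node id
    # before the first dot:
    #
    #   VAR_PATTERN.findall('{"prompt": "{{#llm_1.text#}} {{#sys.query#}}"}')
    #   # -> ["llm_1", "sys"]; "sys" is then removed via SYSTEM_VAR_PREFIXES,
    #   # so only llm_1 becomes an inferred dependency.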
    @classmethod
    def _validate_dependencies(
        cls,
        dependencies: dict[str, list[str]],
        node_map: dict[str, dict[str, Any]],
    ) -> dict[str, list[str]]:
        """
        Validate dependencies and remove invalid references.

        Silent fix: References to non-existent nodes are removed.

        Args:
            dependencies: Dependencies to validate
            node_map: Map of valid node IDs

        Returns:
            Validated dependencies
        """
        valid_deps: dict[str, list[str]] = {}

        for node_id, deps in dependencies.items():
            valid_deps[node_id] = []
            for dep in deps:
                if dep in node_map:
                    valid_deps[node_id].append(dep)
                else:
                    logger.warning(
                        "Removed invalid dependency: %s -> %s (node does not exist)",
                        node_id,
                        dep,
                    )

        return valid_deps

    @classmethod
    def _topological_sort(
        cls,
        nodes: list[dict[str, Any]],
        dependencies: dict[str, list[str]],
    ) -> list[str]:
        """
        Perform topological sort on nodes based on dependencies.

        Uses Kahn's algorithm for cycle detection.

        Args:
            nodes: List of nodes
            dependencies: Dependency graph

        Returns:
            List of node IDs in topological order

        Raises:
            CyclicDependencyError: If cyclic dependencies are detected
        """
        # Build in-degree map
        in_degree: dict[str, int] = defaultdict(int)
        reverse_deps: dict[str, list[str]] = defaultdict(list)

        node_ids = {node["id"] for node in nodes}

        for node_id in node_ids:
            in_degree[node_id] = 0

        for node_id, deps in dependencies.items():
            for dep in deps:
                if dep in node_ids:
                    in_degree[node_id] += 1
                    reverse_deps[dep].append(node_id)

        # Start with nodes that have no dependencies
        queue = [nid for nid in node_ids if in_degree[nid] == 0]
        sorted_ids: list[str] = []

        while queue:
            current = queue.pop(0)
            sorted_ids.append(current)

            for dependent in reverse_deps[current]:
                in_degree[dependent] -= 1
                if in_degree[dependent] == 0:
                    queue.append(dependent)

        # Check for cycles
        if len(sorted_ids) != len(node_ids):
            remaining = node_ids - set(sorted_ids)
            raise CyclicDependencyError(f"Cyclic dependency detected involving nodes: {remaining}")

        return sorted_ids
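
    # Worked example of Kahn's algorithm as used above (illustrative graph):
    # dependencies = {"a": [], "b": ["a"], "c": ["a", "b"]} gives in-degrees
    # a=0, b=1, c=2. The queue starts as ["a"]; popping "a" decrements b and c
    # (b reaches 0 and is enqueued), popping "b" decrements c to 0, so the
    # result is ["a", "b", "c"]. With a cycle such as {"a": ["b"], "b": ["a"]}
    # the queue starts empty, len(sorted_ids) != len(node_ids), and
    # CyclicDependencyError is raised.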
    @classmethod
    def _generate_edges(
        cls,
        nodes: list[dict[str, Any]],
        sorted_node_ids: list[str],
        dependencies: dict[str, list[str]],
        node_map: dict[str, dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """
        Generate all edges based on dependencies and special node handling.

        Args:
            nodes: List of nodes
            sorted_node_ids: Topologically sorted node IDs
            dependencies: Dependency graph
            node_map: Map of node_id -> node

        Returns:
            List of edge dictionaries
        """
        edges: list[dict[str, Any]] = []
        nodes_with_incoming: set[str] = set()

        # Track which nodes have outgoing edges from branching
        branching_sources: set[str] = set()

        # First pass: Handle branching nodes
        for node in nodes:
            node_id = node.get("id", "")
            node_type = node.get("type", "")

            if node_type == "if-else":
                branch_edges = cls._handle_if_else_node(node)
                edges.extend(branch_edges)
                branching_sources.add(node_id)
                nodes_with_incoming.update(edge["target"] for edge in branch_edges)

            elif node_type == "question-classifier":
                branch_edges = cls._handle_question_classifier_node(node)
                edges.extend(branch_edges)
                branching_sources.add(node_id)
                nodes_with_incoming.update(edge["target"] for edge in branch_edges)

        # Second pass: Generate edges from dependencies
        for node_id in sorted_node_ids:
            deps = dependencies.get(node_id, [])

            if deps:
                # Connect from each dependency
                for dep_id in deps:
                    dep_node = node_map.get(dep_id, {})
                    dep_type = dep_node.get("type", "")

                    # Skip if dependency is a branching node (edges handled above)
                    if dep_type in BRANCHING_NODE_TYPES:
                        continue

                    edges.append(cls._create_edge(dep_id, node_id))
                    nodes_with_incoming.add(node_id)
            else:
                # No dependencies - connect from start
                # But skip if this node receives edges from branching nodes
                if node_id not in nodes_with_incoming:
                    edges.append(cls._create_edge("start", node_id))
                    nodes_with_incoming.add(node_id)

        return edges

    @classmethod
    def _handle_if_else_node(
        cls,
        node: dict[str, Any],
    ) -> list[dict[str, Any]]:
        """
        Handle if-else node branching.

        Expects config to contain true_branch and/or false_branch.

        Args:
            node: If-else node

        Returns:
            List of branch edges
        """
        edges: list[dict[str, Any]] = []
        node_id = node.get("id", "")
        config = node.get("config", {})

        true_branch = config.get("true_branch")
        false_branch = config.get("false_branch")

        if true_branch:
            edges.append(cls._create_edge(node_id, true_branch, source_handle="true"))

        if false_branch:
            edges.append(cls._create_edge(node_id, false_branch, source_handle="false"))

        # If no branches specified, log warning
        if not true_branch and not false_branch:
            logger.warning(
                "if-else node %s has no branch targets specified",
                node_id,
            )

        return edges

    @classmethod
    def _handle_question_classifier_node(
        cls,
        node: dict[str, Any],
    ) -> list[dict[str, Any]]:
        """
        Handle question-classifier node branching.

        Expects config.classes to contain class definitions with target fields.

        Args:
            node: Question-classifier node

        Returns:
            List of branch edges
        """
        edges: list[dict[str, Any]] = []
        node_id = node.get("id", "")
        config = node.get("config", {})
        classes = config.get("classes", [])

        if not classes:
            logger.warning(
                "question-classifier node %s has no classes defined",
                node_id,
            )
            return edges

        for cls_def in classes:
            class_id = cls_def.get("id", "")
            target = cls_def.get("target")

            if target:
                edges.append(cls._create_edge(node_id, target, source_handle=class_id))
            else:
                # Silent fix: Connect to end if no target specified
                edges.append(cls._create_edge(node_id, "end", source_handle=class_id))
                logger.debug(
                    "question-classifier class %s has no target, connecting to end",
                    class_id,
                )

        return edges

    @classmethod
    def _find_terminal_nodes(
        cls,
        nodes: list[dict[str, Any]],
        dependencies: dict[str, list[str]],
        node_map: dict[str, dict[str, Any]],
    ) -> list[str]:
        """
        Find nodes that should connect to the end node.

        Terminal nodes are those that:
        - Are not dependencies of any other node
        - Are not branching nodes (those connect to their branches)

        Args:
            nodes: List of nodes
            dependencies: Dependency graph
            node_map: Map of node_id -> node

        Returns:
            List of terminal node IDs
        """
        # Build set of all nodes that are depended upon
        depended_upon: set[str] = set()
        for deps in dependencies.values():
            depended_upon.update(deps)

        # Also track nodes that are branch targets
        branch_targets: set[str] = set()
        branching_nodes: set[str] = set()

        for node in nodes:
            node_id = node.get("id", "")
            node_type = node.get("type", "")
            config = node.get("config", {})

            if node_type == "if-else":
                branching_nodes.add(node_id)
                if config.get("true_branch"):
                    branch_targets.add(config["true_branch"])
                if config.get("false_branch"):
                    branch_targets.add(config["false_branch"])

            elif node_type == "question-classifier":
                branching_nodes.add(node_id)
                for cls_def in config.get("classes", []):
                    if cls_def.get("target"):
                        branch_targets.add(cls_def["target"])

        # Find terminal nodes
        terminal_nodes: list[str] = []
        for node in nodes:
            node_id = node.get("id", "")
            node_type = node.get("type", "")

            # Skip branching nodes - they don't connect to end directly
            if node_type in BRANCHING_NODE_TYPES:
                continue

            # Terminal if not depended upon and not a branch target that leads elsewhere
            if node_id not in depended_upon:
                terminal_nodes.append(node_id)

        # If no terminal nodes found (shouldn't happen), use all non-branching nodes
        if not terminal_nodes:
            terminal_nodes = [node["id"] for node in nodes if node.get("type") not in BRANCHING_NODE_TYPES]
            logger.warning("No terminal nodes found, using all non-branching nodes")

        return terminal_nodes

    @classmethod
    def _create_start_node(
        cls,
        nodes: list[dict[str, Any]],
        config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Create a start node.

        Args:
            nodes: User nodes (for potential config inference)
            config: Optional start node configuration

        Returns:
            Start node dictionary
        """
        return {
            "id": "start",
            "type": "start",
            "title": "Start",
            "config": config or {},
            "data": {},
        }

    @classmethod
    def _create_end_node(
        cls,
        terminal_nodes: list[str],
        config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Create an end node.

        Args:
            terminal_nodes: Nodes that will connect to end
            config: Optional end node configuration

        Returns:
            End node dictionary
        """
        return {
            "id": "end",
            "type": "end",
            "title": "End",
            "config": config or {},
            "data": {},
        }

    @classmethod
    def _create_edge(
        cls,
        source: str,
        target: str,
        source_handle: str | None = None,
    ) -> dict[str, Any]:
        """
        Create an edge dictionary.

        Args:
            source: Source node ID
            target: Target node ID
            source_handle: Optional handle for branching (e.g., "true", "false", class_id)

        Returns:
            Edge dictionary
        """
        edge: dict[str, Any] = {
            "id": f"{source}-{target}-{uuid.uuid4().hex[:8]}",
            "source": source,
            "target": target,
        }

        if source_handle:
            edge["sourceHandle"] = source_handle
        else:
            edge["sourceHandle"] = "source"

        edge["targetHandle"] = "target"

        return edge
@@ -1,280 +0,0 @@
"""
Graph Validator for Workflow Generation

Validates workflow graph structure using graph algorithms:
- Reachability from start node (BFS)
- Reachability to end node (reverse BFS)
- Branch edge validation for if-else and classifier nodes
"""

import time
from collections import deque
from dataclasses import dataclass, field


@dataclass
class GraphError:
    """Represents a structural error in the workflow graph."""

    node_id: str
    node_type: str
    error_type: str  # "unreachable", "dead_end", "cycle", "missing_start", "missing_end"
    message: str


@dataclass
class GraphValidationResult:
    """Result of graph validation."""

    success: bool
    errors: list[GraphError] = field(default_factory=list)
    warnings: list[GraphError] = field(default_factory=list)
    execution_time: float = 0.0
    stats: dict = field(default_factory=dict)


class GraphValidator:
    """
    Validates workflow graph structure using proper graph algorithms.

    Performs:
    1. Forward reachability analysis (BFS from start)
    2. Backward reachability analysis (reverse BFS from end)
    3. Branch edge validation for if-else and classifier nodes
    """

    @staticmethod
    def _build_adjacency(
        nodes: dict[str, dict], edges: list[dict]
    ) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
        """Build forward and reverse adjacency lists from edges."""
        outgoing: dict[str, list[str]] = {node_id: [] for node_id in nodes}
        incoming: dict[str, list[str]] = {node_id: [] for node_id in nodes}

        for edge in edges:
            source = edge.get("source")
            target = edge.get("target")
            if source in outgoing and target in incoming:
                outgoing[source].append(target)
                incoming[target].append(source)

        return outgoing, incoming

    @staticmethod
    def _bfs_reachable(start: str, adjacency: dict[str, list[str]]) -> set[str]:
        """BFS to find all nodes reachable from start node."""
        if start not in adjacency:
            return set()

        visited = set()
        queue = deque([start])
        visited.add(start)

        while queue:
            current = queue.popleft()
            for neighbor in adjacency.get(current, []):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)

        return visited
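
    # Quick illustration (hypothetical graph): with
    #   adjacency = {"start": ["a"], "a": ["b"], "b": [], "orphan": []}
    # _bfs_reachable("start", adjacency) returns {"start", "a", "b"}, leaving
    # "orphan" to be flagged as unreachable by validate() below. Running the
    # same BFS over the reverse adjacency from an end node yields the set of
    # nodes that can still reach the end.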
    @staticmethod
    def validate(workflow_data: dict) -> GraphValidationResult:
        """Validate workflow graph structure."""
        start_time = time.time()
        errors: list[GraphError] = []
        warnings: list[GraphError] = []

        nodes_list = workflow_data.get("nodes", [])
        edges_list = workflow_data.get("edges", [])
        nodes = {n["id"]: n for n in nodes_list if n.get("id")}

        # Find start and end nodes
        start_node_id = None
        end_node_ids = []

        for node_id, node in nodes.items():
            node_type = node.get("type")
            if node_type == "start":
                start_node_id = node_id
            elif node_type == "end":
                end_node_ids.append(node_id)

        # Check start node exists
        if not start_node_id:
            errors.append(
                GraphError(
                    node_id="workflow",
                    node_type="workflow",
                    error_type="missing_start",
                    message="Workflow has no start node",
                )
            )

        # Check end node exists
        if not end_node_ids:
            errors.append(
                GraphError(
                    node_id="workflow",
                    node_type="workflow",
                    error_type="missing_end",
                    message="Workflow has no end node",
                )
            )

        # If missing start or end, can't do reachability analysis
        if not start_node_id or not end_node_ids:
            execution_time = time.time() - start_time
            return GraphValidationResult(
                success=False,
                errors=errors,
                warnings=warnings,
                execution_time=execution_time,
                stats={"nodes": len(nodes), "edges": len(edges_list)},
            )

        # Build adjacency lists
        outgoing, incoming = GraphValidator._build_adjacency(nodes, edges_list)

        # --- FORWARD REACHABILITY: BFS from start ---
        reachable_from_start = GraphValidator._bfs_reachable(start_node_id, outgoing)

        # Find unreachable nodes
        unreachable_nodes = set(nodes.keys()) - reachable_from_start
        for node_id in unreachable_nodes:
            node = nodes[node_id]
            errors.append(
                GraphError(
                    node_id=node_id,
                    node_type=node.get("type", "unknown"),
                    error_type="unreachable",
                    message=f"Node '{node_id}' is not reachable from start node",
                )
            )

        # --- BACKWARD REACHABILITY: Reverse BFS from end nodes ---
        can_reach_end: set[str] = set()
        for end_id in end_node_ids:
            can_reach_end.update(GraphValidator._bfs_reachable(end_id, incoming))

        # Find dead-end nodes (can't reach any end node)
        dead_end_nodes = set(nodes.keys()) - can_reach_end
        for node_id in dead_end_nodes:
            if node_id in unreachable_nodes:
                continue
            node = nodes[node_id]
            warnings.append(
                GraphError(
                    node_id=node_id,
                    node_type=node.get("type", "unknown"),
                    error_type="dead_end",
                    message=f"Node '{node_id}' cannot reach any end node (dead end)",
                )
            )

        # --- Start node has outgoing edges? ---
        if not outgoing.get(start_node_id):
            errors.append(
                GraphError(
                    node_id=start_node_id,
                    node_type="start",
                    error_type="disconnected",
                    message="Start node has no outgoing connections",
                )
            )

        # --- End nodes have incoming edges? ---
        for end_id in end_node_ids:
            if not incoming.get(end_id):
                errors.append(
                    GraphError(
                        node_id=end_id,
                        node_type="end",
                        error_type="disconnected",
                        message="End node has no incoming connections",
                    )
                )

        # --- BRANCH EDGE VALIDATION ---
        edge_handles: dict[str, set[str]] = {}
        for edge in edges_list:
            source = edge.get("source")
            handle = edge.get("sourceHandle", "")
            if source:
                if source not in edge_handles:
                    edge_handles[source] = set()
                edge_handles[source].add(handle)

        # Check if-else and question-classifier nodes
        for node_id, node in nodes.items():
            node_type = node.get("type")

            if node_type == "if-else":
                handles = edge_handles.get(node_id, set())
                config = node.get("config", {})
                cases = config.get("cases", [])

                required_handles = set()
                for case in cases:
                    case_id = case.get("case_id")
                    if case_id:
                        required_handles.add(case_id)
                required_handles.add("false")

                missing = required_handles - handles
                for handle in missing:
                    errors.append(
                        GraphError(
                            node_id=node_id,
                            node_type=node_type,
                            error_type="missing_branch",
                            message=f"If-else node '{node_id}' missing edge for branch '{handle}'",
                        )
                    )

            elif node_type == "question-classifier":
                handles = edge_handles.get(node_id, set())
                config = node.get("config", {})
                classes = config.get("classes", [])

                required_handles = set()
                for cls in classes:
                    if isinstance(cls, dict):
                        cls_id = cls.get("id")
                        if cls_id:
                            required_handles.add(cls_id)

                missing = required_handles - handles
                for handle in missing:
                    cls_name = handle
                    for cls in classes:
                        if isinstance(cls, dict) and cls.get("id") == handle:
                            cls_name = cls.get("name", handle)
                            break
                    errors.append(
                        GraphError(
                            node_id=node_id,
                            node_type=node_type,
                            error_type="missing_branch",
                            message=f"Classifier '{node_id}' missing edge for class '{cls_name}'",
                        )
                    )

        execution_time = time.time() - start_time
        success = len(errors) == 0

        return GraphValidationResult(
            success=success,
            errors=errors,
            warnings=warnings,
            execution_time=execution_time,
            stats={
                "nodes": len(nodes),
                "edges": len(edges_list),
                "reachable_from_start": len(reachable_from_start),
                "can_reach_end": len(can_reach_end),
                "unreachable": len(unreachable_nodes),
                "dead_ends": len(dead_end_nodes - unreachable_nodes),
            },
        )
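
# Hedged usage sketch (minimal hand-written workflow, ids illustrative):
#
#   result = GraphValidator.validate({
#       "nodes": [{"id": "start", "type": "start"}, {"id": "end", "type": "end"}],
#       "edges": [{"source": "start", "target": "end"}],
#   })
#   assert result.success and not result.errors
#
# Dropping the edge would instead produce one "unreachable" error for "end",
# a "disconnected" error for each of start and end, and a "dead_end" warning
# for "start".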
@@ -1,113 +0,0 @@
import logging

from core.workflow.generator.types import WorkflowDataDict

logger = logging.getLogger(__name__)


def generate_mermaid(workflow_data: WorkflowDataDict) -> str:
    """
    Generate a Mermaid flowchart from workflow data consisting of nodes and edges.

    Args:
        workflow_data: Dict containing 'nodes' (list) and 'edges' (list)

    Returns:
        String containing the Mermaid flowchart syntax
    """
    nodes = workflow_data.get("nodes", [])
    edges = workflow_data.get("edges", [])

    lines = ["flowchart TD"]

    # 1. Define Nodes
    # Format: node_id["title<br/>type"] or similar
    # We will use the Vibe Workflow standard format: id["type=TYPE|title=TITLE"]
    # Or specifically for tool nodes: id["type=tool|title=TITLE|tool=TOOL_KEY"]

    # Map of original IDs to safe Mermaid IDs
    id_map = {}

    def get_safe_id(original_id: str) -> str:
        if original_id == "end":
            return "end_node"
        if original_id == "subgraph":
            return "subgraph_node"
        # Mermaid IDs should be alphanumeric.
        # If the ID has special chars, we might need to escape or hash, but Vibe usually generates simple IDs.
        # We'll trust standard IDs but handle the reserved keyword 'end'.
        return original_id

    for node in nodes:
        node_id = node.get("id")
        if not node_id:
            continue

        safe_id = get_safe_id(node_id)
        id_map[node_id] = safe_id

        node_type = node.get("type", "unknown")
        title = node.get("title", "Untitled")

        # Escape quotes in title
        safe_title = title.replace('"', "'")

        if node_type == "tool":
            config = node.get("config", {})
            # Try multiple fields for tool reference
            tool_ref = (
                config.get("tool_key")
                or config.get("tool")
                or config.get("tool_name")
                or node.get("tool_name")
                or "unknown"
            )
            node_def = f'{safe_id}["type={node_type}|title={safe_title}|tool={tool_ref}"]'
        else:
            node_def = f'{safe_id}["type={node_type}|title={safe_title}"]'

        lines.append(f" {node_def}")

    # 2. Define Edges
    # Format: source --> target

    # Track defined nodes to avoid edge errors
    defined_node_ids = {n.get("id") for n in nodes if n.get("id")}

    for edge in edges:
        source = edge.get("source")
        target = edge.get("target")

        # Skip invalid edges
        if not source or not target:
            continue

        if source not in defined_node_ids or target not in defined_node_ids:
            continue

        safe_source = id_map.get(source, source)
        safe_target = id_map.get(target, target)

        # Handle conditional branches (true/false) if present
        # In Dify workflow, sourceHandle is often used for this
        source_handle = edge.get("sourceHandle")
        label = ""

        if source_handle == "true":
            label = "|true|"
        elif source_handle == "false":
            label = "|false|"
        elif source_handle and source_handle != "source":
            # For question-classifier or other multi-path nodes
            # Clean up handle for display if needed
            safe_handle = str(source_handle).replace('"', "'")
            label = f"|{safe_handle}|"

        edge_line = f" {safe_source} -->{label} {safe_target}"
        lines.append(edge_line)

    # Start/End nodes are implicitly handled if they are in the 'nodes' list
    # If not, we might need to add them, but usually the Builder should produce them.

    result = "\n".join(lines)
    return result
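
# Example output (illustrative): for a two-node workflow
#   {"nodes": [{"id": "start", "type": "start", "title": "Start"},
#              {"id": "end", "type": "end", "title": "End"}],
#    "edges": [{"source": "start", "target": "end", "sourceHandle": "source"}]}
# generate_mermaid produces:
#
#   flowchart TD
#    start["type=start|title=Start"]
#    end_node["type=end|title=End"]
#    start --> end_node
#
# Note the reserved Mermaid keyword 'end' being remapped to 'end_node'.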
@@ -1,304 +0,0 @@
"""
Node Repair Utility for Vibe Workflow Generation.

This module provides intelligent node configuration repair capabilities.
It can detect and fix common node configuration issues:
- Invalid comparison operators in if-else nodes (e.g. '>=' -> '≥')
"""

import copy
import logging
import uuid
from dataclasses import dataclass, field

from core.workflow.generator.types import WorkflowNodeDict

logger = logging.getLogger(__name__)


@dataclass
class NodeRepairResult:
    """Result of node repair operation."""

    nodes: list[WorkflowNodeDict]
    repairs_made: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)

    @property
    def was_repaired(self) -> bool:
        """Check if any repairs were made."""
        return len(self.repairs_made) > 0


class NodeRepair:
    """
    Intelligent node configuration repair.
    """

    OPERATOR_MAP = {
        ">=": "≥",
        "<=": "≤",
        "!=": "≠",
        "==": "=",
    }

    TYPE_MAPPING = {
        "json": "object",
        "dict": "object",
        "dictionary": "object",
        "float": "number",
        "int": "number",
        "integer": "number",
        "double": "number",
        "str": "string",
        "text": "string",
        "bool": "boolean",
        "list": "array[object]",
        "array": "array[object]",
    }

    _REPAIR_HANDLERS = {
        "if-else": "_repair_if_else_operators",
        "variable-aggregator": "_repair_variable_aggregator_variables",
        "code": "_repair_code_node_config",
    }

    @classmethod
    def repair(
        cls,
        nodes: list[WorkflowNodeDict],
        llm_callback=None,
    ) -> NodeRepairResult:
        """
        Repair node configurations.

        Args:
            nodes: List of node dictionaries
            llm_callback: Optional callback(node, issue_desc) -> fixed_config_part

        Returns:
            NodeRepairResult with repaired nodes and logs
        """
        # Deep copy to avoid mutating original
        nodes = copy.deepcopy(nodes)
        repairs: list[str] = []
        warnings: list[str] = []

        logger.info("[NODE REPAIR] Starting repair process for %s nodes", len(nodes))

        for node in nodes:
            node_type = node.get("type")

            # 1. Rule-based repairs
            handler_name = cls._REPAIR_HANDLERS.get(node_type)
            if handler_name:
                handler = getattr(cls, handler_name)
                # Handlers are expected to accept (node, repairs, llm_callback=None);
                # fall back gracefully for handlers that don't take llm_callback.
                try:
                    handler(node, repairs, llm_callback=llm_callback)
                except TypeError:
                    # Fallback for handlers that don't accept llm_callback yet
                    handler(node, repairs)

            # Add other node type repairs here as needed

        if repairs:
            logger.info("[NODE REPAIR] Completed with %s repairs:", len(repairs))
            for i, repair in enumerate(repairs, 1):
                logger.info("[NODE REPAIR] %s. %s", i, repair)
        else:
            logger.info("[NODE REPAIR] Completed - no repairs needed")

        return NodeRepairResult(
            nodes=nodes,
            repairs_made=repairs,
            warnings=warnings,
        )
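
    # Hedged usage sketch (node content is illustrative):
    #
    #   result = NodeRepair.repair([{
    #       "id": "if_1",
    #       "type": "if-else",
    #       "config": {"cases": [{"conditions": [
    #           {"comparison_operator": ">=", "value": 10},
    #       ]}]},
    #   }])
    #   # result.was_repaired is True; the operator becomes "≥", the numeric
    #   # value is coerced to "10", and a missing case_id is generated.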
    @classmethod
    def _repair_if_else_operators(cls, node: WorkflowNodeDict, repairs: list[str], **kwargs):
        """
        Normalize comparison operators in if-else nodes.
        Also ensure an 'id' field exists for cases and conditions (frontend requirement).
        """
        node_id = node.get("id", "unknown")
        config = node.get("config", {})
        cases = config.get("cases", [])

        for case in cases:
            # Ensure case_id
            if "case_id" not in case:
                case["case_id"] = str(uuid.uuid4())
                repairs.append(f"Generated missing case_id for case in node '{node_id}'")

            conditions = case.get("conditions", [])
            for condition in conditions:
                # Ensure condition id
                if "id" not in condition:
                    condition["id"] = str(uuid.uuid4())
                    # Not logging this repair to avoid clutter, as it's a structural fix

                # Ensure value type (LLM might return int/float, but we need str/bool/list)
                val = condition.get("value")
                if isinstance(val, (int, float)) and not isinstance(val, bool):
                    condition["value"] = str(val)
                    repairs.append(f"Coerced numeric value to string in node '{node_id}'")

                op = condition.get("comparison_operator")
                if op in cls.OPERATOR_MAP:
                    new_op = cls.OPERATOR_MAP[op]
                    condition["comparison_operator"] = new_op
                    repairs.append(f"Normalized operator '{op}' to '{new_op}' in node '{node_id}'")

    @classmethod
    def _repair_variable_aggregator_variables(cls, node: WorkflowNodeDict, repairs: list[str]):
        """
        Repair variable-aggregator variables format.
        Converts dict format to list[list[str]] format.
        Expected: [["node_id", "field"], ["node_id2", "field2"]]
        May receive: [{"name": "...", "value_selector": ["node_id", "field"]}, ...]
        """
        node_id = node.get("id", "unknown")
        config = node.get("config", {})
        variables = config.get("variables", [])

        if not variables:
            return

        repaired = False
        repaired_variables = []

        for var in variables:
            if isinstance(var, dict):
                # Convert dict format to array format
                value_selector = var.get("value_selector") or var.get("selector") or var.get("path")
                if isinstance(value_selector, list) and len(value_selector) > 0:
                    repaired_variables.append(value_selector)
                    repaired = True
                else:
                    # Try to extract from name field - LLM may generate {"name": "node_id.field"}
                    name = var.get("name")
                    if isinstance(name, str) and "." in name:
                        # Try to parse "node_id.field" format
                        parts = name.split(".", 1)
                        if len(parts) == 2:
                            repaired_variables.append([parts[0], parts[1]])
                            repaired = True
                        else:
                            logger.warning(
                                "Variable aggregator node '%s' has invalid variable format: %s",
                                node_id,
                                var,
                            )
                            repaired_variables.append([])  # Empty array as fallback
                    else:
                        # If no valid selector or name, skip this variable
                        logger.warning(
                            "Variable aggregator node '%s' has invalid variable format: %s",
                            node_id,
                            var,
                        )
                        # Don't add empty array - skip invalid variables
            elif isinstance(var, list):
                # Already in correct format
                repaired_variables.append(var)
            else:
                # Unknown format, skip
                logger.warning("Variable aggregator node '%s' has unknown variable format: %s", node_id, var)
                # Don't add empty array - skip invalid variables

        if repaired:
            config["variables"] = repaired_variables
            repairs.append(f"Repaired variable-aggregator variables format in node '{node_id}'")

    @classmethod
    def _repair_code_node_config(cls, node: WorkflowNodeDict, repairs: list[str], llm_callback=None):
        """
        Repair code node configuration (outputs and variables).
        1. Outputs: Converts list format to dict format AND normalizes types.
        2. Variables: Ensures value_selector exists.
        """
        node_id = node.get("id", "unknown")
        config = node.get("config", {})

        if "variables" not in config:
            config["variables"] = []

        # --- Repair Variables ---
        variables = config.get("variables")
        if isinstance(variables, list):
            for var in variables:
                if isinstance(var, dict):
                    # Ensure value_selector exists (frontend crashes if missing)
                    if "value_selector" not in var:
                        var["value_selector"] = []
                        # Not logging trivial repairs

        # --- Repair Outputs ---
        outputs = config.get("outputs")

        if not outputs:
            return

        # Helper to normalize type
        def normalize_type(t: str) -> str:
            t_lower = str(t).lower()
            return cls.TYPE_MAPPING.get(t_lower, t)

        # 1. Handle Dict format (Standard) - Check for invalid types
        if isinstance(outputs, dict):
            changed = False
            for var_name, var_config in outputs.items():
                if isinstance(var_config, dict):
                    original_type = var_config.get("type")
                    if original_type:
                        new_type = normalize_type(original_type)
                        if new_type != original_type:
                            var_config["type"] = new_type
                            changed = True
                            repairs.append(
                                f"Normalized type '{original_type}' to '{new_type}' "
                                f"for var '{var_name}' in node '{node_id}'"
                            )
            return

        # 2. Handle List format (Repair needed)
        if isinstance(outputs, list):
            new_outputs = {}
            for item in outputs:
                if isinstance(item, dict):
                    var_name = item.get("variable") or item.get("name")
                    var_type = item.get("type")
                    if var_name and var_type:
                        norm_type = normalize_type(var_type)
                        new_outputs[var_name] = {"type": norm_type}
                        if norm_type != var_type:
                            repairs.append(
                                f"Normalized type '{var_type}' to '{norm_type}' "
                                f"during list conversion in node '{node_id}'"
                            )

            if new_outputs:
                config["outputs"] = new_outputs
                repairs.append(f"Repaired code node outputs format in node '{node_id}'")
            else:
                # Fallback: Try LLM if available
                if llm_callback:
                    try:
                        # Attempt to fix using LLM
                        fixed_outputs = llm_callback(
                            node,
                            "outputs must be a dictionary like {'var_name': {'type': 'string'}}, "
                            "but got a list or valid conversion failed.",
                        )
                        if isinstance(fixed_outputs, dict) and fixed_outputs:
                            config["outputs"] = fixed_outputs
                            repairs.append(f"Repaired code node outputs format using LLM in node '{node_id}'")
                            return
                    except Exception as e:
                        logger.warning("LLM fallback repair failed for node '%s': %s", node_id, e)

                # If conversion/LLM failed, set to empty dict
                config["outputs"] = {}
                repairs.append(f"Reset invalid code node outputs to empty dict in node '{node_id}'")
@@ -1,101 +0,0 @@
from dataclasses import dataclass

from core.workflow.generator.types import AvailableModelDict, AvailableToolDict, WorkflowDataDict
from core.workflow.generator.validation.context import ValidationContext
from core.workflow.generator.validation.engine import ValidationEngine
from core.workflow.generator.validation.rules import Severity


@dataclass
class ValidationHint:
    """Legacy compatibility class for validation hints."""

    node_id: str
    field: str
    message: str
    severity: str  # 'error', 'warning'
    suggestion: str | None = None
    node_type: str | None = None  # Added for test compatibility

    # Alias for potential old code using 'type' instead of 'severity'
    @property
    def type(self) -> str:
        return self.severity

    @property
    def element_id(self) -> str:
        return self.node_id


FriendlyHint = ValidationHint  # Alias for backward compatibility


class WorkflowValidator:
    """
    Validates the generated workflow configuration (nodes and edges).
    Wraps the new ValidationEngine for backward compatibility.
    """

    @classmethod
    def validate(
        cls,
        workflow_data: WorkflowDataDict,
        available_tools: list[AvailableToolDict],
        available_models: list[AvailableModelDict] | None = None,
    ) -> tuple[bool, list[ValidationHint]]:
        """
        Validate workflow data and return validity status and hints.

        Args:
            workflow_data: Dict containing 'nodes' and 'edges'
            available_tools: List of available tool configurations
            available_models: List of available models (added for Vibe compat)

        Returns:
            Tuple of (is_valid, hints); is_valid is True when no
            error-severity hints were produced.
        """
        nodes = workflow_data.get("nodes", [])
        edges = workflow_data.get("edges", [])

        # Create context
        context = ValidationContext(
            nodes=nodes,
            edges=edges,
            available_models=available_models or [],
            available_tools=available_tools or [],
        )

        # Run validation engine
        engine = ValidationEngine()
        result = engine.validate(context)

        # Convert engine errors to legacy hints
        hints: list[ValidationHint] = []

        error_count = 0
        warning_count = 0

        for error in result.all_errors:
            # Map severity
            severity = "error" if error.severity == Severity.ERROR else "warning"

            if severity == "error":
                error_count += 1
            else:
                warning_count += 1

            # Map field from message or details if possible (heuristic)
            field_name = error.details.get("field", "unknown")

            hints.append(
                ValidationHint(
                    node_id=error.node_id,
                    field=field_name,
                    message=error.message,
                    severity=severity,
                    suggestion=error.fix_hint,
                    node_type=error.node_type,
                )
            )

        return result.is_valid, hints
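
# Hedged usage sketch (inputs illustrative): the legacy entry point keeps the
# old tuple shape while delegating to the rule engine:
#
#   is_valid, hints = WorkflowValidator.validate(
#       {"nodes": [], "edges": []},
#       available_tools=[],
#   )
#   for hint in hints:
#       print(hint.severity, hint.node_id, hint.message)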
@@ -1,42 +0,0 @@
"""
Validation Rule Engine for Vibe Workflow Generation.

This module provides a declarative, schema-based validation system for
generated workflow nodes. It classifies errors into fixable (LLM can auto-fix)
and user-required (needs manual intervention) categories.

Usage:
    from core.workflow.generator.validation import ValidationEngine, ValidationContext

    context = ValidationContext(
        available_models=[...],
        available_tools=[...],
        nodes=[...],
        edges=[...],
    )
    engine = ValidationEngine()
    result = engine.validate(context)

    # Access classified errors
    fixable_errors = result.fixable_errors
    user_required_errors = result.user_required_errors
"""

from core.workflow.generator.validation.context import ValidationContext
from core.workflow.generator.validation.engine import ValidationEngine, ValidationResult
from core.workflow.generator.validation.rules import (
    RuleCategory,
    Severity,
    ValidationError,
    ValidationRule,
)

__all__ = [
    "RuleCategory",
    "Severity",
    "ValidationContext",
    "ValidationEngine",
    "ValidationError",
    "ValidationResult",
    "ValidationRule",
]
@@ -1,115 +0,0 @@
"""
Validation Context for the Rule Engine.

The ValidationContext holds all the data needed for validation:
- Generated nodes and edges
- Available models, tools, and datasets
- Node output schemas for variable reference validation
"""

from dataclasses import dataclass, field

from core.workflow.generator.types import (
    AvailableModelDict,
    AvailableToolDict,
    WorkflowEdgeDict,
    WorkflowNodeDict,
)


@dataclass
class ValidationContext:
    """
    Context object containing all data needed for validation.

    This is passed to each validation rule, providing access to:
    - The nodes being validated
    - Edge connections between nodes
    - Available external resources (models, tools)
    """

    # Generated workflow data
    nodes: list[WorkflowNodeDict] = field(default_factory=list)
    edges: list[WorkflowEdgeDict] = field(default_factory=list)

    # Available external resources
    available_models: list[AvailableModelDict] = field(default_factory=list)
    available_tools: list[AvailableToolDict] = field(default_factory=list)

    # Cached lookups (populated lazily)
    _node_map: dict[str, WorkflowNodeDict] | None = field(default=None, repr=False)
    _model_set: set[tuple[str, str]] | None = field(default=None, repr=False)
    _tool_set: set[str] | None = field(default=None, repr=False)
    _configured_tool_set: set[str] | None = field(default=None, repr=False)

    @property
    def node_map(self) -> dict[str, WorkflowNodeDict]:
        """Get a map of node_id -> node for quick lookup."""
        if self._node_map is None:
            self._node_map = {node.get("id", ""): node for node in self.nodes}
        return self._node_map

    @property
    def model_set(self) -> set[tuple[str, str]]:
        """Get a set of (provider, model_name) tuples for quick lookup."""
        if self._model_set is None:
            self._model_set = {(m.get("provider", ""), m.get("model", "")) for m in self.available_models}
        return self._model_set

    @property
    def tool_set(self) -> set[str]:
        """Get a set of all tool keys (both configured and unconfigured)."""
        if self._tool_set is None:
            self._tool_set = set()
            for tool in self.available_tools:
                provider = tool.get("provider_id") or tool.get("provider", "")
                tool_key = tool.get("tool_key") or tool.get("tool_name", "")
                if provider and tool_key:
                    self._tool_set.add(f"{provider}/{tool_key}")
                if tool_key:
                    self._tool_set.add(tool_key)
        return self._tool_set

    @property
    def configured_tool_set(self) -> set[str]:
        """Get a set of configured (authorized) tool keys."""
        if self._configured_tool_set is None:
            self._configured_tool_set = set()
            for tool in self.available_tools:
                if not tool.get("is_team_authorization", False):
                    continue
                provider = tool.get("provider_id") or tool.get("provider", "")
                tool_key = tool.get("tool_key") or tool.get("tool_name", "")
                if provider and tool_key:
                    self._configured_tool_set.add(f"{provider}/{tool_key}")
                if tool_key:
                    self._configured_tool_set.add(tool_key)
        return self._configured_tool_set

    def has_model(self, provider: str, model_name: str) -> bool:
        """Check if a model is available."""
        return (provider, model_name) in self.model_set

    def has_tool(self, tool_key: str) -> bool:
        """Check if a tool exists (configured or not)."""
        return tool_key in self.tool_set

    def is_tool_configured(self, tool_key: str) -> bool:
        """Check if a tool is configured and ready to use."""
        return tool_key in self.configured_tool_set

    def get_node(self, node_id: str) -> WorkflowNodeDict | None:
        """Get a node by its ID."""
        return self.node_map.get(node_id)

    def get_node_ids(self) -> set[str]:
        """Get all node IDs in the workflow."""
        return set(self.node_map.keys())

    def get_upstream_nodes(self, node_id: str) -> list[str]:
        """Get IDs of nodes that connect to this node (upstream)."""
        return [edge.get("source", "") for edge in self.edges if edge.get("target") == node_id]

    def get_downstream_nodes(self, node_id: str) -> list[str]:
        """Get IDs of nodes that this node connects to (downstream)."""
        return [edge.get("target", "") for edge in self.edges if edge.get("source") == node_id]
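
# Quick illustration (values hypothetical): lookups are cached after first use.
#
#   ctx = ValidationContext(
#       nodes=[{"id": "llm_1", "type": "llm"}],
#       edges=[{"source": "start", "target": "llm_1"}],
#       available_models=[{"provider": "openai", "model": "gpt-4o"}],
#   )
#   ctx.has_model("openai", "gpt-4o")   # True
#   ctx.get_upstream_nodes("llm_1")     # ["start"]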
@ -1,260 +0,0 @@
|
||||
"""
|
||||
Validation Engine - Core validation logic.
|
||||
|
||||
The ValidationEngine orchestrates rule execution and aggregates results.
|
||||
It provides a clean interface for validating workflow nodes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.generator.types import (
|
||||
AvailableModelDict,
|
||||
AvailableToolDict,
|
||||
WorkflowEdgeDict,
|
||||
WorkflowNodeDict,
|
||||
)
|
||||
from core.workflow.generator.validation.context import ValidationContext
|
||||
from core.workflow.generator.validation.rules import (
|
||||
RuleCategory,
|
||||
Severity,
|
||||
ValidationError,
|
||||
get_registry,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""
|
||||
Result of validation containing all errors classified by fixability.
|
||||
|
||||
Attributes:
|
||||
all_errors: All validation errors found
|
||||
fixable_errors: Errors that LLM can automatically fix
|
||||
user_required_errors: Errors that require user intervention
|
||||
warnings: Non-blocking warnings
|
||||
stats: Validation statistics
|
||||
"""
|
||||
|
||||
all_errors: list[ValidationError] = field(default_factory=list)
|
||||
fixable_errors: list[ValidationError] = field(default_factory=list)
|
||||
user_required_errors: list[ValidationError] = field(default_factory=list)
|
||||
warnings: list[ValidationError] = field(default_factory=list)
|
||||
stats: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def has_errors(self) -> bool:
|
||||
"""Check if there are any errors (excluding warnings)."""
|
||||
return len(self.fixable_errors) > 0 or len(self.user_required_errors) > 0
|
||||
|
||||
@property
|
||||
def has_fixable_errors(self) -> bool:
|
||||
"""Check if there are fixable errors."""
|
||||
return len(self.fixable_errors) > 0
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if validation passed (no errors, warnings are OK)."""
|
||||
return not self.has_errors
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for API response."""
|
||||
return {
|
||||
"fixable": [e.to_dict() for e in self.fixable_errors],
|
||||
"user_required": [e.to_dict() for e in self.user_required_errors],
|
||||
"warnings": [e.to_dict() for e in self.warnings],
|
||||
"all_warnings": [e.message for e in self.all_errors],
|
||||
"stats": self.stats,
|
||||
}
|
||||

    def get_error_messages(self) -> list[str]:
        """Get all error messages as strings."""
        return [e.message for e in self.all_errors]

    def get_fixable_by_node(self) -> dict[str, list[ValidationError]]:
        """Group fixable errors by node ID."""
        result: dict[str, list[ValidationError]] = {}
        for error in self.fixable_errors:
            if error.node_id not in result:
                result[error.node_id] = []
            result[error.node_id].append(error)
        return result


class ValidationEngine:
    """
    The main validation engine.

    Usage:
        engine = ValidationEngine()
        context = ValidationContext(nodes=[...], available_models=[...])
        result = engine.validate(context)
    """

    def __init__(self):
        self._registry = get_registry()

    def validate(self, context: ValidationContext) -> ValidationResult:
        """
        Validate all nodes in the context.

        Args:
            context: ValidationContext with nodes, edges, and available resources

        Returns:
            ValidationResult with classified errors
        """
        result = ValidationResult()
        stats = {
            "total_nodes": len(context.nodes),
            "total_rules_checked": 0,
            "total_errors": 0,
            "fixable_count": 0,
            "user_required_count": 0,
            "warning_count": 0,
        }

        # Validate each node
        for node in context.nodes:
            node_type = node.get("type", "unknown")
            node_id = node.get("id", "unknown")

            # Get applicable rules for this node type
            rules = self._registry.get_rules_for_node(node_type)

            for rule in rules:
                stats["total_rules_checked"] += 1

                try:
                    errors = rule.check(node, context)
                    for error in errors:
                        result.all_errors.append(error)
                        stats["total_errors"] += 1

                        # Classify by severity and fixability
                        if error.severity == Severity.WARNING:
                            result.warnings.append(error)
                            stats["warning_count"] += 1
                        elif error.is_fixable:
                            result.fixable_errors.append(error)
                            stats["fixable_count"] += 1
                        else:
                            result.user_required_errors.append(error)
                            stats["user_required_count"] += 1

                except Exception:
                    logger.exception(
                        "Rule '%s' failed for node '%s'",
                        rule.id,
                        node_id,
                    )
                    # Don't let a rule failure break the entire validation
                    continue

        # Validate edges separately
        edge_errors = self._validate_edges(context)
        for error in edge_errors:
            result.all_errors.append(error)
            stats["total_errors"] += 1
            if error.is_fixable:
                result.fixable_errors.append(error)
                stats["fixable_count"] += 1
            else:
                result.user_required_errors.append(error)
                stats["user_required_count"] += 1

        result.stats = stats

        return result

    def _validate_edges(self, context: ValidationContext) -> list[ValidationError]:
        """Validate edge connections."""
        errors: list[ValidationError] = []
        valid_node_ids = context.get_node_ids()

        for edge in context.edges:
            source = edge.get("source", "")
            target = edge.get("target", "")

            if source and source not in valid_node_ids:
                errors.append(
                    ValidationError(
                        rule_id="edge.source.invalid",
                        node_id=source,
                        node_type="edge",
                        category=RuleCategory.SEMANTIC,
                        severity=Severity.ERROR,
                        is_fixable=True,
                        message=f"Edge source '{source}' does not exist",
                        fix_hint="Update edge to reference existing node",
                    )
                )

            if target and target not in valid_node_ids:
                errors.append(
                    ValidationError(
                        rule_id="edge.target.invalid",
                        node_id=target,
                        node_type="edge",
                        category=RuleCategory.SEMANTIC,
                        severity=Severity.ERROR,
                        is_fixable=True,
                        message=f"Edge target '{target}' does not exist",
                        fix_hint="Update edge to reference existing node",
                    )
                )

        return errors

    def validate_single_node(
        self,
        node: WorkflowNodeDict,
        context: ValidationContext,
    ) -> list[ValidationError]:
        """
        Validate a single node.

        Useful for incremental validation when a node is added/modified.
        """
        node_type = node.get("type", "unknown")
        rules = self._registry.get_rules_for_node(node_type)

        errors: list[ValidationError] = []
        for rule in rules:
            try:
                errors.extend(rule.check(node, context))
            except Exception:
                logger.exception("Rule '%s' failed", rule.id)

        return errors


def validate_nodes(
    nodes: list[WorkflowNodeDict],
    edges: list[WorkflowEdgeDict] | None = None,
    available_models: list[AvailableModelDict] | None = None,
    available_tools: list[AvailableToolDict] | None = None,
) -> ValidationResult:
    """
    Convenience function to validate nodes without creating engine/context manually.

    Args:
        nodes: List of workflow nodes to validate
        edges: Optional list of edges
        available_models: Optional list of available models
        available_tools: Optional list of available tools

    Returns:
        ValidationResult with classified errors
    """
    context = ValidationContext(
        nodes=nodes,
        edges=edges or [],
        available_models=available_models or [],
        available_tools=available_tools or [],
    )
    engine = ValidationEngine()
    return engine.validate(context)
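
# --- Editor's illustrative sketch (not part of the original module; node
# shapes are assumed from the GraphBuilder tests further down in this diff) ---
if __name__ == "__main__":
    # An LLM node with an empty config should surface the fixable
    # "llm.prompt_template.required" error registered in rules.py.
    demo_result = validate_nodes(nodes=[{"id": "llm_1", "type": "llm", "config": {}}])
    for err in demo_result.fixable_errors:
        print(err.rule_id, "->", err.fix_hint)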
@ -1,947 +0,0 @@
"""
Validation Rules Definition and Registry.

This module defines:
- ValidationRule: The rule structure
- RuleCategory: Categories of validation rules
- Severity: Error severity levels
- ValidationError: Error output structure
- All built-in validation rules
"""

import re
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any

from core.workflow.generator.types import WorkflowNodeDict

if TYPE_CHECKING:
    from core.workflow.generator.validation.context import ValidationContext


class RuleCategory(Enum):
    """Categories of validation rules."""

    STRUCTURE = "structure"  # Field existence, types, formats
    SEMANTIC = "semantic"  # Variable references, edge connections
    REFERENCE = "reference"  # External resources (models, tools, datasets)


class Severity(Enum):
    """Severity levels for validation errors."""

    ERROR = "error"  # Must be fixed
    WARNING = "warning"  # Should be fixed but not blocking


@dataclass
class ValidationError:
    """
    Represents a validation error found during rule execution.

    Attributes:
        rule_id: The ID of the rule that generated this error
        node_id: The ID of the node with the error
        node_type: The type of the node
        category: The rule category
        severity: Error severity
        is_fixable: Whether LLM can auto-fix this error
        message: Human-readable error message
        fix_hint: Hint for LLM to fix the error
        details: Additional error details
    """

    rule_id: str
    node_id: str
    node_type: str
    category: RuleCategory
    severity: Severity
    is_fixable: bool
    message: str
    fix_hint: str = ""
    details: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for API response."""
        return {
            "rule_id": self.rule_id,
            "node_id": self.node_id,
            "node_type": self.node_type,
            "category": self.category.value,
            "severity": self.severity.value,
            "is_fixable": self.is_fixable,
            "message": self.message,
            "fix_hint": self.fix_hint,
            "details": self.details,
        }


# Type alias for rule check functions
RuleCheckFn = Callable[
    [WorkflowNodeDict, "ValidationContext"],
    list[ValidationError],
]


@dataclass
class ValidationRule:
    """
    A validation rule definition.

    Attributes:
        id: Unique rule identifier (e.g., "llm.model.required")
        node_types: List of node types this rule applies to, or ["*"] for all
        category: The rule category
        severity: Default severity for errors from this rule
        is_fixable: Whether errors from this rule can be auto-fixed by LLM
        check: The validation function
        description: Human-readable description of what this rule checks
        fix_hint: Default hint for fixing errors from this rule
    """

    id: str
    node_types: list[str]
    category: RuleCategory
    severity: Severity
    is_fixable: bool
    check: RuleCheckFn
    description: str = ""
    fix_hint: str = ""

    def applies_to(self, node_type: str) -> bool:
        """Check if this rule applies to a given node type."""
        return "*" in self.node_types or node_type in self.node_types


# =============================================================================
# Rule Registry
# =============================================================================


class RuleRegistry:
    """
    Registry for validation rules.

    Rules are registered here and can be retrieved by category or node type.
    """

    def __init__(self):
        self._rules: list[ValidationRule] = []

    def register(self, rule: ValidationRule) -> None:
        """Register a validation rule."""
        self._rules.append(rule)

    def get_rules_for_node(self, node_type: str) -> list[ValidationRule]:
        """Get all rules that apply to a given node type."""
        return [r for r in self._rules if r.applies_to(node_type)]

    def get_rules_by_category(self, category: RuleCategory) -> list[ValidationRule]:
        """Get all rules in a given category."""
        return [r for r in self._rules if r.category == category]

    def get_all_rules(self) -> list[ValidationRule]:
        """Get all registered rules."""
        return list(self._rules)


# Global rule registry instance
_registry = RuleRegistry()


def register_rule(rule: ValidationRule) -> ValidationRule:
    """Decorator/function to register a rule with the global registry."""
    _registry.register(rule)
    return rule


def get_registry() -> RuleRegistry:
    """Get the global rule registry."""
    return _registry
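
# --- Editor's illustrative sketch (hypothetical rule, not part of the original
# module): extending validation means writing a check function and passing a
# ValidationRule to register_rule(). Registration is left commented out so the
# example does not alter the shipped rule set. ---
def _example_title_check(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Editor's example: warn when a node has no human-readable title."""
    if node.get("title"):
        return []
    return [
        ValidationError(
            rule_id="node.title.recommended",
            node_id=node.get("id", "unknown"),
            node_type=node.get("type", "unknown"),
            category=RuleCategory.STRUCTURE,
            severity=Severity.WARNING,
            is_fixable=True,
            message="Node should have a 'title'",
            fix_hint="Add a short descriptive title to the node",
        )
    ]


# register_rule(
#     ValidationRule(
#         id="node.title.recommended",
#         node_types=["*"],
#         category=RuleCategory.STRUCTURE,
#         severity=Severity.WARNING,
#         is_fixable=True,
#         check=_example_title_check,
#         description="Editor's example rule - nodes should carry a title",
#     )
# )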


# =============================================================================
# Helper Functions for Rule Implementations
# =============================================================================

# Explicit placeholder value defined in prompt contract
# See: api/core/workflow/generator/prompts/vibe_prompts.py
PLACEHOLDER_VALUE = "__PLACEHOLDER__"

# Variable reference pattern: {{#node_id.field#}}
VARIABLE_REF_PATTERN = re.compile(r"\{\{#([^.#]+)\.([^#]+)#\}\}")


def is_placeholder(value: Any) -> bool:
    """Check if a value appears to be a placeholder."""
    if not isinstance(value, str):
        return False
    return value == PLACEHOLDER_VALUE or PLACEHOLDER_VALUE in value


def extract_variable_refs(text: str) -> list[tuple[str, str]]:
    """
    Extract variable references from text.

    Returns list of (node_id, field_name) tuples.
    """
    return VARIABLE_REF_PATTERN.findall(text)
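
# Editor's note - worked example of the two helpers above:
#   extract_variable_refs("Summarize {{#fetch.body#}}") == [("fetch", "body")]
#   is_placeholder("__PLACEHOLDER__/endpoint") is True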


def check_required_field(
    config: dict[str, Any],
    field_name: str,
    node_id: str,
    node_type: str,
    rule_id: str,
    fix_hint: str = "",
) -> ValidationError | None:
    """Helper to check if a required field exists and is non-empty."""
    value = config.get(field_name)
    if value is None or value == "" or (isinstance(value, list) and len(value) == 0):
        return ValidationError(
            rule_id=rule_id,
            node_id=node_id,
            node_type=node_type,
            category=RuleCategory.STRUCTURE,
            severity=Severity.ERROR,
            is_fixable=True,
            message=f"Node '{node_id}': missing required field '{field_name}'",
            fix_hint=fix_hint or f"Add '{field_name}' to the node config",
        )
    return None
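
# Editor's note - e.g. check_required_field({"prompt_template": []},
# "prompt_template", "llm_1", "llm", "llm.prompt_template.required") returns a
# fixable STRUCTURE error, since an empty list counts as missing.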


# =============================================================================
# Structure Rules - Field existence, types, formats
# =============================================================================


def _check_llm_prompt_template(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that LLM node has prompt_template."""
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    config = node.get("config", {})

    err = check_required_field(
        config,
        "prompt_template",
        node_id,
        "llm",
        "llm.prompt_template.required",
        "Add prompt_template with system and user messages",
    )
    if err:
        errors.append(err)

    return errors


def _check_http_request_url(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that http-request node has url and method."""
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    config = node.get("config", {})

    # Check url
    url = config.get("url", "")
    if not url:
        errors.append(
            ValidationError(
                rule_id="http.url.required",
                node_id=node_id,
                node_type="http-request",
                category=RuleCategory.STRUCTURE,
                severity=Severity.ERROR,
                is_fixable=True,
                message=f"Node '{node_id}': http-request missing required 'url'",
                fix_hint="Add url - use {{#start.url#}} or a concrete URL",
            )
        )
    elif is_placeholder(url):
        errors.append(
            ValidationError(
                rule_id="http.url.placeholder",
                node_id=node_id,
                node_type="http-request",
                category=RuleCategory.STRUCTURE,
                severity=Severity.ERROR,
                is_fixable=True,
                message=f"Node '{node_id}': url contains placeholder value",
                fix_hint="Replace placeholder with actual URL or variable reference",
            )
        )

    # Check method
    method = config.get("method", "")
    if not method:
        errors.append(
            ValidationError(
                rule_id="http.method.required",
                node_id=node_id,
                node_type="http-request",
                category=RuleCategory.STRUCTURE,
                severity=Severity.ERROR,
                is_fixable=True,
                message=f"Node '{node_id}': http-request missing 'method'",
                fix_hint="Add method: GET, POST, PUT, DELETE, or PATCH",
            )
        )

    return errors


def _check_code_node(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that code node has code and language."""
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    config = node.get("config", {})

    err = check_required_field(
        config,
        "code",
        node_id,
        "code",
        "code.code.required",
        "Add code with a main() function that returns a dict",
    )
    if err:
        errors.append(err)

    err = check_required_field(
        config,
        "language",
        node_id,
        "code",
        "code.language.required",
        "Add language: python3 or javascript",
    )
    if err:
        errors.append(err)

    return errors


def _check_question_classifier(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that question-classifier has classes."""
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    config = node.get("config", {})

    err = check_required_field(
        config,
        "classes",
        node_id,
        "question-classifier",
        "classifier.classes.required",
        "Add classes array with id and name for each classification",
    )
    if err:
        errors.append(err)

    return errors


def _check_parameter_extractor(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that parameter-extractor has parameters and instruction."""
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    config = node.get("config", {})

    err = check_required_field(
        config,
        "parameters",
        node_id,
        "parameter-extractor",
        "extractor.parameters.required",
        "Add parameters array with name, type, description fields",
    )
    if err:
        errors.append(err)
    else:
        # Check individual parameters for required fields
        parameters = config.get("parameters", [])
        if isinstance(parameters, list):
            for i, param in enumerate(parameters):
                if isinstance(param, dict):
                    # Check for 'required' field (boolean)
                    if "required" not in param:
                        errors.append(
                            ValidationError(
                                rule_id="extractor.param.required_field.missing",
                                node_id=node_id,
                                node_type="parameter-extractor",
                                category=RuleCategory.STRUCTURE,
                                severity=Severity.ERROR,
                                is_fixable=True,
                                message=f"Node '{node_id}': parameter[{i}] missing 'required' field",
                                fix_hint=f"Add 'required': True to parameter '{param.get('name', 'unknown')}'",
                                details={"param_index": i, "param_name": param.get("name")},
                            )
                        )

    # instruction is recommended but not strictly required
    if not config.get("instruction"):
        errors.append(
            ValidationError(
                rule_id="extractor.instruction.recommended",
                node_id=node_id,
                node_type="parameter-extractor",
                category=RuleCategory.STRUCTURE,
                severity=Severity.WARNING,
                is_fixable=True,
                message=f"Node '{node_id}': parameter-extractor should have 'instruction'",
                fix_hint="Add instruction describing what to extract",
            )
        )

    return errors


def _check_knowledge_retrieval(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that knowledge-retrieval has dataset_ids."""
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    config = node.get("config", {})

    dataset_ids = config.get("dataset_ids", [])
    if not dataset_ids:
        errors.append(
            ValidationError(
                rule_id="knowledge.dataset.required",
                node_id=node_id,
                node_type="knowledge-retrieval",
                category=RuleCategory.STRUCTURE,
                severity=Severity.ERROR,
                is_fixable=False,  # User must select knowledge base
                message=f"Node '{node_id}': knowledge-retrieval missing 'dataset_ids'",
                fix_hint="User must select knowledge bases in the UI",
            )
        )
    else:
        # Check for placeholder values
        for ds_id in dataset_ids:
            if is_placeholder(ds_id):
                errors.append(
                    ValidationError(
                        rule_id="knowledge.dataset.placeholder",
                        node_id=node_id,
                        node_type="knowledge-retrieval",
                        category=RuleCategory.STRUCTURE,
                        severity=Severity.ERROR,
                        is_fixable=False,
                        message=f"Node '{node_id}': dataset_ids contains placeholder",
                        fix_hint="User must replace placeholder with actual knowledge base ID",
                        details={"placeholder_value": ds_id},
                    )
                )
                break

    return errors


def _check_end_node(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that end node has outputs defined."""
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    config = node.get("config", {})

    outputs = config.get("outputs", [])
    if not outputs:
        errors.append(
            ValidationError(
                rule_id="end.outputs.recommended",
                node_id=node_id,
                node_type="end",
                category=RuleCategory.STRUCTURE,
                severity=Severity.WARNING,
                is_fixable=True,
                message="End node should define output variables",
                fix_hint="Add outputs array with variable and value_selector",
            )
        )

    return errors


# =============================================================================
# Semantic Rules - Variable references, edge connections
# =============================================================================


def _check_variable_references(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that variable references point to valid nodes."""
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    node_type = node.get("type", "unknown")
    config = node.get("config", {})

    # Get all valid node IDs (including 'start' which is always valid)
    valid_node_ids = ctx.get_node_ids()
    valid_node_ids.add("start")
    valid_node_ids.add("sys")  # System variables

    def check_text_for_refs(text: str, field_path: str) -> None:
        if not isinstance(text, str):
            return
        refs = extract_variable_refs(text)
        for ref_node_id, ref_field in refs:
            if ref_node_id not in valid_node_ids:
                errors.append(
                    ValidationError(
                        rule_id="variable.ref.invalid_node",
                        node_id=node_id,
                        node_type=node_type,
                        category=RuleCategory.SEMANTIC,
                        severity=Severity.ERROR,
                        is_fixable=True,
                        message=f"Node '{node_id}': references non-existent node '{ref_node_id}'",
                        fix_hint=f"Change {{{{#{ref_node_id}.{ref_field}#}}}} to reference a valid node",
                        details={"field_path": field_path, "invalid_ref": ref_node_id},
                    )
                )

    # Check prompt_template for LLM nodes
    prompt_template = config.get("prompt_template", [])
    if isinstance(prompt_template, list):
        for i, msg in enumerate(prompt_template):
            if isinstance(msg, dict):
                text = msg.get("text", "")
                check_text_for_refs(text, f"prompt_template[{i}].text")

    # Check instruction field
    instruction = config.get("instruction", "")
    check_text_for_refs(instruction, "instruction")

    # Check url for http-request
    url = config.get("url", "")
    check_text_for_refs(url, "url")

    return errors


# NOTE: _check_node_has_outgoing_edge removed - handled by GraphValidator


# NOTE: _check_node_has_incoming_edge removed - handled by GraphValidator


# NOTE: _check_question_classifier_branches removed - handled by EdgeRepair


# NOTE: _check_if_else_branches removed - handled by EdgeRepair


def _check_if_else_operators(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that if-else comparison operators are valid."""
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    node_type = node.get("type", "unknown")

    if node_type != "if-else":
        return errors

    valid_operators = {
        "contains",
        "not contains",
        "start with",
        "end with",
        "is",
        "is not",
        "empty",
        "not empty",
        "in",
        "not in",
        "all of",
        "=",
        "≠",
        ">",
        "<",
        "≥",
        "≤",
        "null",
        "not null",
        "exists",
        "not exists",
    }

    config = node.get("config", {})
    cases = config.get("cases", [])

    for case in cases:
        conditions = case.get("conditions", [])
        for condition in conditions:
            op = condition.get("comparison_operator")
            if op and op not in valid_operators:
                errors.append(
                    ValidationError(
                        rule_id="ifelse.operator.invalid",
                        node_id=node_id,
                        node_type=node_type,
                        category=RuleCategory.SEMANTIC,
                        severity=Severity.ERROR,
                        is_fixable=True,
                        message=f"Invalid operator '{op}' in if-else node",
                        fix_hint=f"Use one of: {', '.join(sorted(valid_operators))}",
                        details={"invalid_operator": op, "field": "config.cases.conditions.comparison_operator"},
                    )
                )

    return errors


def _check_edge_targets_exist(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that edge targets reference existing nodes."""
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    node_type = node.get("type", "unknown")

    valid_node_ids = ctx.get_node_ids()

    # Check all outgoing edges from this node
    for edge in ctx.edges:
        if edge.get("source") == node_id:
            target = edge.get("target")
            if target and target not in valid_node_ids:
                errors.append(
                    ValidationError(
                        rule_id="edge.target.invalid",
                        node_id=node_id,
                        node_type=node_type,
                        category=RuleCategory.SEMANTIC,
                        severity=Severity.ERROR,
                        is_fixable=True,
                        message=f"Edge from '{node_id}' targets non-existent node '{target}'",
                        fix_hint=f"Change edge target from '{target}' to an existing node",
                        details={"invalid_target": target, "field": "edges"},
                    )
                )

    return errors


# =============================================================================
# Reference Rules - External resources (models, tools, datasets)
# =============================================================================

# Node types that require model configuration
MODEL_REQUIRED_NODE_TYPES = {"llm", "question-classifier", "parameter-extractor"}


def _check_model_config(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that model configuration is valid."""
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    node_type = node.get("type", "unknown")
    config = node.get("config", {})

    if node_type not in MODEL_REQUIRED_NODE_TYPES:
        return errors

    model = config.get("model")

    # Check if model config exists
    if not model:
        if ctx.available_models:
            errors.append(
                ValidationError(
                    rule_id="model.required",
                    node_id=node_id,
                    node_type=node_type,
                    category=RuleCategory.REFERENCE,
                    severity=Severity.ERROR,
                    is_fixable=True,
                    message=f"Node '{node_id}' ({node_type}): missing required 'model' configuration",
                    fix_hint="Add model config using one of the available models",
                )
            )
        else:
            errors.append(
                ValidationError(
                    rule_id="model.no_available",
                    node_id=node_id,
                    node_type=node_type,
                    category=RuleCategory.REFERENCE,
                    severity=Severity.ERROR,
                    is_fixable=False,
                    message=f"Node '{node_id}' ({node_type}): needs model but no models available",
                    fix_hint="User must configure a model provider first",
                )
            )
        return errors

    # Check if model config is valid
    if isinstance(model, dict):
        provider = model.get("provider", "")
        name = model.get("name", "")

        # Check for placeholder values
        if is_placeholder(provider) or is_placeholder(name):
            if ctx.available_models:
                errors.append(
                    ValidationError(
                        rule_id="model.placeholder",
                        node_id=node_id,
                        node_type=node_type,
                        category=RuleCategory.REFERENCE,
                        severity=Severity.ERROR,
                        is_fixable=True,
                        message=f"Node '{node_id}': model config contains placeholder",
                        fix_hint="Replace placeholder with actual model from available_models",
                    )
                )
            return errors

        # Check if model exists in available_models
        if ctx.available_models and provider and name:
            if not ctx.has_model(provider, name):
                errors.append(
                    ValidationError(
                        rule_id="model.not_found",
                        node_id=node_id,
                        node_type=node_type,
                        category=RuleCategory.REFERENCE,
                        severity=Severity.ERROR,
                        is_fixable=True,
                        message=f"Node '{node_id}': model '{provider}/{name}' not in available models",
                        fix_hint="Replace with a model from available_models",
                        details={"provider": provider, "model": name},
                    )
                )

    return errors


def _check_tool_reference(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that tool references are valid and configured."""
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    node_type = node.get("type", "unknown")

    if node_type != "tool":
        return errors

    config = node.get("config", {})
    # Prefer tool_key, then tool_name; only fall back to the composed
    # "provider_id/tool_name" form when both parts exist. (The previous
    # unconditional concatenation yielded a truthy "/" even when both parts
    # were missing, which bypassed the required-field check below.)
    _provider_id = config.get("provider_id", "")
    _tool_name = config.get("tool_name", "")
    tool_ref = (
        config.get("tool_key")
        or _tool_name
        or (_provider_id + "/" + _tool_name if _provider_id and _tool_name else "")
    )

    if not tool_ref:
        errors.append(
            ValidationError(
                rule_id="tool.key.required",
                node_id=node_id,
                node_type=node_type,
                category=RuleCategory.REFERENCE,
                severity=Severity.ERROR,
                is_fixable=True,
                message=f"Node '{node_id}': tool node missing tool_key",
                fix_hint="Add tool_key from available_tools",
            )
        )
        return errors

    # Check if tool exists
    if not ctx.has_tool(tool_ref):
        errors.append(
            ValidationError(
                rule_id="tool.not_found",
                node_id=node_id,
                node_type=node_type,
                category=RuleCategory.REFERENCE,
                severity=Severity.ERROR,
                is_fixable=True,  # Can be replaced with http-request fallback
                message=f"Node '{node_id}': tool '{tool_ref}' not found",
                fix_hint="Use http-request or code node as fallback",
                details={"tool_ref": tool_ref},
            )
        )
    elif not ctx.is_tool_configured(tool_ref):
        errors.append(
            ValidationError(
                rule_id="tool.not_configured",
                node_id=node_id,
                node_type=node_type,
                category=RuleCategory.REFERENCE,
                severity=Severity.WARNING,
                is_fixable=False,  # User needs to configure
                message=f"Node '{node_id}': tool '{tool_ref}' requires configuration",
                fix_hint="Configure the tool in Tools settings",
                details={"tool_ref": tool_ref},
            )
        )

    return errors


# =============================================================================
# Register All Rules
# =============================================================================

# Structure Rules
register_rule(
    ValidationRule(
        id="llm.prompt_template.required",
        node_types=["llm"],
        category=RuleCategory.STRUCTURE,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_llm_prompt_template,
        description="LLM node must have prompt_template",
        fix_hint="Add prompt_template with system and user messages",
    )
)

register_rule(
    ValidationRule(
        id="http.config.required",
        node_types=["http-request"],
        category=RuleCategory.STRUCTURE,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_http_request_url,
        description="HTTP request node must have url and method",
        fix_hint="Add url and method to config",
    )
)

register_rule(
    ValidationRule(
        id="code.config.required",
        node_types=["code"],
        category=RuleCategory.STRUCTURE,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_code_node,
        description="Code node must have code and language",
        fix_hint="Add code with main() function and language",
    )
)

register_rule(
    ValidationRule(
        id="classifier.classes.required",
        node_types=["question-classifier"],
        category=RuleCategory.STRUCTURE,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_question_classifier,
        description="Question classifier must have classes",
        fix_hint="Add classes array with classification options",
    )
)

register_rule(
    ValidationRule(
        id="extractor.config.required",
        node_types=["parameter-extractor"],
        category=RuleCategory.STRUCTURE,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_parameter_extractor,
        description="Parameter extractor must have parameters",
        fix_hint="Add parameters array",
    )
)

register_rule(
    ValidationRule(
        id="knowledge.config.required",
        node_types=["knowledge-retrieval"],
        category=RuleCategory.STRUCTURE,
        severity=Severity.ERROR,
        is_fixable=False,
        check=_check_knowledge_retrieval,
        description="Knowledge retrieval must have dataset_ids",
        fix_hint="User must select knowledge base",
    )
)

register_rule(
    ValidationRule(
        id="end.outputs.check",
        node_types=["end"],
        category=RuleCategory.STRUCTURE,
        severity=Severity.WARNING,
        is_fixable=True,
        check=_check_end_node,
        description="End node should have outputs",
        fix_hint="Add outputs array",
    )
)

# Semantic Rules
register_rule(
    ValidationRule(
        id="variable.references.valid",
        node_types=["*"],
        category=RuleCategory.SEMANTIC,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_variable_references,
        description="Variable references must point to valid nodes",
        fix_hint="Fix variable reference to use valid node ID",
    )
)

# Edge Validation Rules
# NOTE: Edge connectivity and branch completeness are now handled by:
# - GraphValidator (BFS-based reachability analysis)
# - EdgeRepair (automatic branch edge repair)

register_rule(
    ValidationRule(
        id="edge.targets.valid",
        node_types=["*"],
        category=RuleCategory.SEMANTIC,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_edge_targets_exist,
        description="Edge targets must reference existing nodes",
        fix_hint="Change edge target to an existing node ID",
    )
)

# Reference Rules
register_rule(
    ValidationRule(
        id="model.config.valid",
        node_types=["llm", "question-classifier", "parameter-extractor"],
        category=RuleCategory.REFERENCE,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_model_config,
        description="Model configuration must be valid",
        fix_hint="Add valid model from available_models",
    )
)

register_rule(
    ValidationRule(
        id="tool.reference.valid",
        node_types=["tool"],
        category=RuleCategory.REFERENCE,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_tool_reference,
        description="Tool reference must be valid and configured",
        fix_hint="Use valid tool or fallback node",
    )
)

register_rule(
    ValidationRule(
        id="ifelse.operator.valid",
        node_types=["if-else"],
        category=RuleCategory.SEMANTIC,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_if_else_operators,
        description="If-else operators must be valid",
        fix_hint="Use standard operators like ≥, ≤, =, ≠",
    )
)
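
# --- Editor's note (sketch): importing this module registers every rule above
# with the global registry, so ValidationEngine() picks them up with no extra
# wiring:
#
#     from core.workflow.generator.validation.rules import get_registry
#     assert get_registry().get_rules_for_node("llm")  # structure + semantic + reference rules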
@ -199,14 +199,6 @@ class Node(Generic[NodeDataT]):

        return None

    @classmethod
    def get_default_config_schema(cls) -> dict[str, Any] | None:
        """
        Get the default configuration schema for the node.
        Used for LLM generation.
        """
        return None

    # Global registry populated via __init_subclass__
    _registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}

@ -1,5 +1,3 @@
from typing import Any

from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
@ -11,24 +9,6 @@ class EndNode(Node[EndNodeData]):
    node_type = NodeType.END
    execution_type = NodeExecutionType.RESPONSE

    @classmethod
    def get_default_config_schema(cls) -> dict[str, Any] | None:
        return {
            "description": "Workflow exit point - defines output variables",
            "required": ["outputs"],
            "parameters": {
                "outputs": {
                    "type": "array",
                    "description": "Output variables to return",
                    "item_schema": {
                        "variable": "string - output variable name",
                        "type": "enum: string, number, object, array",
                        "value_selector": "array - path to source value, e.g. ['node_id', 'field']",
                    },
                },
            },
        }

    @classmethod
    def version(cls) -> str:
        return "1"
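
    # Editor's sketch - a config instance that satisfies the schema above
    # (names and selector path are illustrative):
    #     {"outputs": [{"variable": "answer", "type": "string",
    #                   "value_selector": ["llm_1", "text"]}]}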
@ -14,27 +14,6 @@ class StartNode(Node[StartNodeData]):
    node_type = NodeType.START
    execution_type = NodeExecutionType.ROOT

    @classmethod
    def get_default_config_schema(cls) -> dict[str, Any] | None:
        return {
            "description": "Workflow entry point - defines input variables",
            "required": [],
            "parameters": {
                "variables": {
                    "type": "array",
                    "description": "Input variables for the workflow",
                    "item_schema": {
                        "variable": "string - variable name",
                        "label": "string - display label",
                        "type": "enum: text-input, paragraph, number, select, file, file-list",
                        "required": "boolean",
                        "max_length": "number (optional)",
                    },
                },
            },
            "outputs": ["All defined variables are available as {{#start.variable_name#}}"],
        }

    @classmethod
    def version(cls) -> str:
        return "1"
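
    # Editor's sketch - a matching variables entry (values illustrative):
    #     {"variables": [{"variable": "query", "label": "Query",
    #                     "type": "text-input", "required": True, "max_length": 200}]}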
@ -50,19 +50,6 @@ class ToolNode(Node[ToolNodeData]):
    def version(cls) -> str:
        return "1"

    @classmethod
    def get_default_config_schema(cls) -> dict[str, Any] | None:
        return {
            "description": "Execute an external tool",
            "required": ["provider_id", "tool_id", "tool_parameters"],
            "parameters": {
                "provider_id": {"type": "string"},
                "provider_type": {"type": "string"},
                "tool_id": {"type": "string"},
                "tool_parameters": {"type": "object"},
            },
        }

    def _run(self) -> Generator[NodeEventBase, None, None]:
        """
        Run the tool node
@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.12.0"
version = "1.11.4"
requires-python = ">=3.11,<3.13"

dependencies = [
@ -24,7 +24,7 @@ class TagService:
            escaped_keyword = escape_like_pattern(keyword)
            query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
        query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
        results: list = query.order_by(Tag.created_at.desc()).all()
        results = query.order_by(Tag.created_at.desc()).all()
        return results

    @staticmethod
@ -0,0 +1,222 @@
import builtins
import contextlib
import importlib
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest
from flask import Flask
from flask.views import MethodView

from extensions import ext_fastopenapi
from extensions.ext_database import db


@pytest.fixture
def app():
    app = Flask(__name__)
    app.config["TESTING"] = True
    app.config["SECRET_KEY"] = "test-secret"
    app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"

    db.init_app(app)

    return app


@pytest.fixture(autouse=True)
def fix_method_view_issue(monkeypatch):
    if not hasattr(builtins, "MethodView"):
        monkeypatch.setattr(builtins, "MethodView", MethodView, raising=False)


def _create_isolated_router():
    import controllers.fastopenapi

    router_class = type(controllers.fastopenapi.console_router)
    return router_class()


@contextlib.contextmanager
def _patch_auth_and_router(temp_router):
    def noop(func):
        return func

    default_user = MagicMock(has_edit_permission=True, is_dataset_editor=False)

    with (
        patch("controllers.fastopenapi.console_router", temp_router),
        patch("extensions.ext_fastopenapi.console_router", temp_router),
        patch("controllers.console.wraps.setup_required", side_effect=noop),
        patch("libs.login.login_required", side_effect=noop),
        patch("controllers.console.wraps.account_initialization_required", side_effect=noop),
        patch("controllers.console.wraps.edit_permission_required", side_effect=noop),
        patch("libs.login.current_account_with_tenant", return_value=(default_user, "tenant-id")),
        patch("configs.dify_config.EDITION", "CLOUD"),
    ):
        import extensions.ext_fastopenapi

        importlib.reload(extensions.ext_fastopenapi)

        yield


def _force_reload_module(target_module: str, alias_module: str):
    if target_module in sys.modules:
        del sys.modules[target_module]
    if alias_module in sys.modules:
        del sys.modules[alias_module]

    module = importlib.import_module(target_module)
    sys.modules[alias_module] = sys.modules[target_module]

    return module


def _dedupe_routes(router):
    seen = set()
    unique_routes = []
    for path, method, endpoint in reversed(router.get_routes()):
        key = (path, method, endpoint.__name__)
        if key in seen:
            continue
        seen.add(key)
        unique_routes.append((path, method, endpoint))
    router._routes = list(reversed(unique_routes))


def _cleanup_modules(target_module: str, alias_module: str):
    if target_module in sys.modules:
        del sys.modules[target_module]
    if alias_module in sys.modules:
        del sys.modules[alias_module]


@pytest.fixture
def mock_tags_module_env():
    target_module = "controllers.console.tag.tags"
    alias_module = "api.controllers.console.tag.tags"
    temp_router = _create_isolated_router()

    try:
        with _patch_auth_and_router(temp_router):
            tags_module = _force_reload_module(target_module, alias_module)
            _dedupe_routes(temp_router)
            yield tags_module
    finally:
        _cleanup_modules(target_module, alias_module)


def test_list_tags_success(app: Flask, mock_tags_module_env):
    # Arrange
    tag = SimpleNamespace(id="tag-1", name="Alpha", type="app", binding_count=2)
    with patch("controllers.console.tag.tags.TagService.get_tags", return_value=[tag]):
        ext_fastopenapi.init_app(app)
        client = app.test_client()

        # Act
        response = client.get("/console/api/tags?type=app&keyword=Alpha")

        # Assert
        assert response.status_code == 200
        assert response.get_json() == [
            {"id": "tag-1", "name": "Alpha", "type": "app", "binding_count": 2},
        ]


def test_create_tag_success(app: Flask, mock_tags_module_env):
    # Arrange
    tag = SimpleNamespace(id="tag-2", name="Beta", type="app")
    with patch("controllers.console.tag.tags.TagService.save_tags", return_value=tag) as mock_save:
        ext_fastopenapi.init_app(app)
        client = app.test_client()

        # Act
        response = client.post("/console/api/tags", json={"name": "Beta", "type": "app"})

        # Assert
        assert response.status_code == 200
        assert response.get_json() == {
            "id": "tag-2",
            "name": "Beta",
            "type": "app",
            "binding_count": 0,
        }
        mock_save.assert_called_once_with({"name": "Beta", "type": "app"})


def test_update_tag_success(app: Flask, mock_tags_module_env):
    # Arrange
    tag = SimpleNamespace(id="tag-3", name="Gamma", type="app")
    with (
        patch("controllers.console.tag.tags.TagService.update_tags", return_value=tag) as mock_update,
        patch("controllers.console.tag.tags.TagService.get_tag_binding_count", return_value=4),
    ):
        ext_fastopenapi.init_app(app)
        client = app.test_client()

        # Act
        response = client.patch(
            "/console/api/tags/11111111-1111-1111-1111-111111111111",
            json={"name": "Gamma", "type": "app"},
        )

        # Assert
        assert response.status_code == 200
        assert response.get_json() == {
            "id": "tag-3",
            "name": "Gamma",
            "type": "app",
            "binding_count": 4,
        }
        mock_update.assert_called_once_with(
            {"name": "Gamma", "type": "app"},
            "11111111-1111-1111-1111-111111111111",
        )


def test_delete_tag_success(app: Flask, mock_tags_module_env):
    # Arrange
    with patch("controllers.console.tag.tags.TagService.delete_tag") as mock_delete:
        ext_fastopenapi.init_app(app)
        client = app.test_client()

        # Act
        response = client.delete("/console/api/tags/11111111-1111-1111-1111-111111111111")

        # Assert
        assert response.status_code == 204
        mock_delete.assert_called_once_with("11111111-1111-1111-1111-111111111111")


def test_create_tag_binding_success(app: Flask, mock_tags_module_env):
    # Arrange
    payload = {"tag_ids": ["tag-1", "tag-2"], "target_id": "target-1", "type": "app"}
    with patch("controllers.console.tag.tags.TagService.save_tag_binding") as mock_bind:
        ext_fastopenapi.init_app(app)
        client = app.test_client()

        # Act
        response = client.post("/console/api/tag-bindings/create", json=payload)

        # Assert
        assert response.status_code == 200
        assert response.get_json() == {"result": "success"}
        mock_bind.assert_called_once_with(payload)


def test_delete_tag_binding_success(app: Flask, mock_tags_module_env):
    # Arrange
    payload = {"tag_id": "tag-1", "target_id": "target-1", "type": "app"}
    with patch("controllers.console.tag.tags.TagService.delete_tag_binding") as mock_unbind:
        ext_fastopenapi.init_app(app)
        client = app.test_client()

        # Act
        response = client.post("/console/api/tag-bindings/remove", json=payload)

        # Assert
        assert response.status_code == 200
        assert response.get_json() == {"result": "success"}
        mock_unbind.assert_called_once_with(payload)
@ -1,400 +0,0 @@
"""
Unit tests for GraphBuilder.

Tests the automatic graph construction from node lists with dependency declarations.
"""

import pytest

from core.workflow.generator.utils.graph_builder import (
    CyclicDependencyError,
    GraphBuilder,
)


class TestGraphBuilderBasic:
    """Basic functionality tests."""

    def test_empty_nodes_creates_minimal_workflow(self):
        """Empty node list creates start -> end workflow."""
        result_nodes, result_edges = GraphBuilder.build_graph([])

        assert len(result_nodes) == 2
        assert result_nodes[0]["type"] == "start"
        assert result_nodes[1]["type"] == "end"
        assert len(result_edges) == 1
        assert result_edges[0]["source"] == "start"
        assert result_edges[0]["target"] == "end"

    def test_simple_linear_workflow(self):
        """Simple linear workflow: start -> fetch -> process -> end."""
        nodes = [
            {"id": "fetch", "type": "http-request", "depends_on": []},
            {"id": "process", "type": "llm", "depends_on": ["fetch"]},
        ]
        result_nodes, result_edges = GraphBuilder.build_graph(nodes)

        # Should have: start + 2 user nodes + end = 4
        assert len(result_nodes) == 4
        assert result_nodes[0]["type"] == "start"
        assert result_nodes[-1]["type"] == "end"

        # Should have: start->fetch, fetch->process, process->end = 3
        assert len(result_edges) == 3

        # Verify edge connections
        edge_pairs = [(e["source"], e["target"]) for e in result_edges]
        assert ("start", "fetch") in edge_pairs
        assert ("fetch", "process") in edge_pairs
        assert ("process", "end") in edge_pairs


class TestParallelWorkflow:
    """Tests for parallel node handling."""

    def test_parallel_workflow(self):
        """Parallel workflow: multiple nodes from start, merging to one."""
        nodes = [
            {"id": "api1", "type": "http-request", "depends_on": []},
            {"id": "api2", "type": "http-request", "depends_on": []},
            {"id": "merge", "type": "llm", "depends_on": ["api1", "api2"]},
        ]
        result_nodes, result_edges = GraphBuilder.build_graph(nodes)

        # start should connect to both api1 and api2
        start_edges = [e for e in result_edges if e["source"] == "start"]
        assert len(start_edges) == 2

        start_targets = {e["target"] for e in start_edges}
        assert start_targets == {"api1", "api2"}

        # Both api1 and api2 should connect to merge
        merge_incoming = [e for e in result_edges if e["target"] == "merge"]
        assert len(merge_incoming) == 2

    def test_multiple_terminal_nodes(self):
        """Multiple terminal nodes all connect to end."""
        nodes = [
            {"id": "branch1", "type": "llm", "depends_on": []},
            {"id": "branch2", "type": "llm", "depends_on": []},
        ]
        result_nodes, result_edges = GraphBuilder.build_graph(nodes)

        # Both branches should connect to end
        end_incoming = [e for e in result_edges if e["target"] == "end"]
        assert len(end_incoming) == 2


class TestIfElseWorkflow:
    """Tests for if-else branching."""

    def test_if_else_workflow(self):
        """Conditional branching workflow."""
        nodes = [
            {
                "id": "check",
                "type": "if-else",
                "config": {"true_branch": "success", "false_branch": "fallback"},
                "depends_on": [],
            },
            {"id": "success", "type": "llm", "depends_on": []},
            {"id": "fallback", "type": "code", "depends_on": []},
        ]
        result_nodes, result_edges = GraphBuilder.build_graph(nodes)

        # Should have true and false branch edges
        branch_edges = [e for e in result_edges if e["source"] == "check"]
        assert len(branch_edges) == 2
        assert any(e.get("sourceHandle") == "true" for e in branch_edges)
        assert any(e.get("sourceHandle") == "false" for e in branch_edges)

        # Verify targets
        true_edge = next(e for e in branch_edges if e.get("sourceHandle") == "true")
        false_edge = next(e for e in branch_edges if e.get("sourceHandle") == "false")
        assert true_edge["target"] == "success"
        assert false_edge["target"] == "fallback"

    def test_if_else_missing_branch_no_error(self):
        """if-else with only true branch doesn't error (warning only)."""
        nodes = [
            {
                "id": "check",
                "type": "if-else",
                "config": {"true_branch": "success"},
                "depends_on": [],
            },
            {"id": "success", "type": "llm", "depends_on": []},
        ]
        # Should not raise
        result_nodes, result_edges = GraphBuilder.build_graph(nodes)

        # Should have one branch edge
        branch_edges = [e for e in result_edges if e["source"] == "check"]
        assert len(branch_edges) == 1
        assert branch_edges[0].get("sourceHandle") == "true"


class TestQuestionClassifierWorkflow:
    """Tests for question-classifier branching."""

    def test_question_classifier_workflow(self):
        """Question classifier with multiple classes."""
        nodes = [
            {
                "id": "classifier",
                "type": "question-classifier",
                "config": {
                    "query": ["start", "user_input"],
                    "classes": [
{"id": "tech", "name": "技术问题", "target": "tech_handler"},
|
||||
{"id": "sales", "name": "销售咨询", "target": "sales_handler"},
|
||||
{"id": "other", "name": "其他问题", "target": "other_handler"},
|
||||
                    ],
                },
                "depends_on": [],
            },
            {"id": "tech_handler", "type": "llm", "depends_on": []},
            {"id": "sales_handler", "type": "llm", "depends_on": []},
            {"id": "other_handler", "type": "llm", "depends_on": []},
        ]
        result_nodes, result_edges = GraphBuilder.build_graph(nodes)

        # Should have 3 branch edges from classifier
        classifier_edges = [e for e in result_edges if e["source"] == "classifier"]
        assert len(classifier_edges) == 3

        # Each should use class id as sourceHandle
        assert any(e.get("sourceHandle") == "tech" and e["target"] == "tech_handler" for e in classifier_edges)
        assert any(e.get("sourceHandle") == "sales" and e["target"] == "sales_handler" for e in classifier_edges)
        assert any(e.get("sourceHandle") == "other" and e["target"] == "other_handler" for e in classifier_edges)

    def test_question_classifier_missing_target(self):
        """Classes without target connect to end."""
        nodes = [
            {
                "id": "classifier",
                "type": "question-classifier",
                "config": {
                    "classes": [
{"id": "known", "name": "已知问题", "target": "handler"},
|
||||
{"id": "unknown", "name": "未知问题"}, # Missing target
|
||||
                    ],
                },
                "depends_on": [],
            },
            {"id": "handler", "type": "llm", "depends_on": []},
        ]
        result_nodes, result_edges = GraphBuilder.build_graph(nodes)

        # Missing target should connect to end
        classifier_edges = [e for e in result_edges if e["source"] == "classifier"]
        assert any(e.get("sourceHandle") == "unknown" and e["target"] == "end" for e in classifier_edges)


class TestVariableDependencyInference:
    """Tests for automatic dependency inference from variables."""

    def test_variable_dependency_inference(self):
        """Dependencies inferred from variable references."""
        nodes = [
            {"id": "fetch", "type": "http-request", "depends_on": []},
            {
                "id": "process",
                "type": "llm",
                "config": {"prompt_template": [{"text": "{{#fetch.body#}}"}]},
                # No explicit depends_on, but references fetch
            },
        ]
        result_nodes, result_edges = GraphBuilder.build_graph(nodes)

        # Should automatically infer process depends on fetch
        assert any(e["source"] == "fetch" and e["target"] == "process" for e in result_edges)

    def test_system_variable_not_inferred(self):
        """System variables (sys, start) not inferred as dependencies."""
        nodes = [
            {
                "id": "process",
                "type": "llm",
                "config": {"prompt_template": [{"text": "{{#sys.query#}} {{#start.input#}}"}]},
                "depends_on": [],
            },
        ]
        result_nodes, result_edges = GraphBuilder.build_graph(nodes)

        # Should connect to start, not create dependency on sys or start
        edge_sources = {e["source"] for e in result_edges}
        assert "sys" not in edge_sources
        assert "start" in edge_sources


class TestCycleDetection:
    """Tests for cyclic dependency detection."""

    def test_cyclic_dependency_detected(self):
        """Cyclic dependencies raise error."""
        nodes = [
            {"id": "a", "type": "llm", "depends_on": ["c"]},
            {"id": "b", "type": "llm", "depends_on": ["a"]},
            {"id": "c", "type": "llm", "depends_on": ["b"]},
        ]

        with pytest.raises(CyclicDependencyError):
            GraphBuilder.build_graph(nodes)

    def test_self_dependency_detected(self):
        """Self-dependency raises error."""
        nodes = [
            {"id": "a", "type": "llm", "depends_on": ["a"]},
        ]

        with pytest.raises(CyclicDependencyError):
            GraphBuilder.build_graph(nodes)


class TestErrorRecovery:
    """Tests for silent error recovery."""

    def test_invalid_dependency_removed(self):
        """Invalid dependencies (non-existent nodes) are silently removed."""
        nodes = [
            {"id": "process", "type": "llm", "depends_on": ["nonexistent"]},
        ]
        # Should not raise, invalid dependency silently removed
        result_nodes, result_edges = GraphBuilder.build_graph(nodes)

        # Process should connect from start (since invalid dep was removed)
        assert any(e["source"] == "start" and e["target"] == "process" for e in result_edges)

    def test_depends_on_as_string(self):
        """depends_on as string is converted to list."""
        nodes = [
            {"id": "fetch", "type": "http-request", "depends_on": []},
            {"id": "process", "type": "llm", "depends_on": "fetch"},  # String instead of list
        ]
        result_nodes, result_edges = GraphBuilder.build_graph(nodes)

        # Should work correctly
        assert any(e["source"] == "fetch" and e["target"] == "process" for e in result_edges)


class TestContainerNodes:
    """Tests for container nodes (iteration, loop)."""

    def test_iteration_node_as_regular_node(self):
        """Iteration nodes behave as regular single-in-single-out nodes."""
        nodes = [
            {"id": "prepare", "type": "code", "depends_on": []},
            {
                "id": "loop",
                "type": "iteration",
                "config": {"iterator_selector": ["prepare", "items"]},
                "depends_on": ["prepare"],
            },
            {"id": "process_result", "type": "llm", "depends_on": ["loop"]},
        ]
        result_nodes, result_edges = GraphBuilder.build_graph(nodes)

        # Should have standard edges: start->prepare, prepare->loop, loop->process_result, process_result->end
        edge_pairs = [(e["source"], e["target"]) for e in result_edges]
        assert ("start", "prepare") in edge_pairs
        assert ("prepare", "loop") in edge_pairs
        assert ("loop", "process_result") in edge_pairs
        assert ("process_result", "end") in edge_pairs

    def test_loop_node_as_regular_node(self):
        """Loop nodes behave as regular single-in-single-out nodes."""
        nodes = [
            {"id": "init", "type": "code", "depends_on": []},
            {
                "id": "repeat",
                "type": "loop",
                "config": {"loop_count": 5},
                "depends_on": ["init"],
            },
            {"id": "finish", "type": "llm", "depends_on": ["repeat"]},
        ]
        result_nodes, result_edges = GraphBuilder.build_graph(nodes)

        # Standard edge flow
        edge_pairs = [(e["source"], e["target"]) for e in result_edges]
        assert ("init", "repeat") in edge_pairs
        assert ("repeat", "finish") in edge_pairs

    def test_iteration_with_variable_inference(self):
        """Iteration node dependencies can be inferred from iterator_selector."""
        nodes = [
            {"id": "data_source", "type": "http-request", "depends_on": []},
            {
                "id": "process_each",
                "type": "iteration",
                "config": {
                    "iterator_selector": ["data_source", "items"],
                },
                # No explicit depends_on, but references data_source
            },
        ]
        result_nodes, result_edges = GraphBuilder.build_graph(nodes)

        # Should infer dependency from iterator_selector reference
        # Note: iterator_selector format is different from {{#...#}}, so this tests
        # that explicit depends_on is properly handled when not provided
        # In this case, process_each has no depends_on, so it connects to start
        edge_pairs = [(e["source"], e["target"]) for e in result_edges]
        # Without explicit depends_on, connects to start
        assert ("start", "process_each") in edge_pairs or ("data_source", "process_each") in edge_pairs

    def test_loop_node_self_reference_not_cycle(self):
|
||||
"""Loop nodes referencing their own outputs should not create cycle."""
|
||||
nodes = [
|
||||
{"id": "init", "type": "code", "depends_on": []},
|
||||
{
|
||||
"id": "my_loop",
|
||||
"type": "loop",
|
||||
"config": {
|
||||
"loop_count": 5,
|
||||
# Loop node referencing its own output (common pattern)
|
||||
"prompt": "Previous: {{#my_loop.output#}}, continue...",
|
||||
},
|
||||
"depends_on": ["init"],
|
||||
},
|
||||
{"id": "finish", "type": "llm", "depends_on": ["my_loop"]},
|
||||
]
|
||||
# Should NOT raise CyclicDependencyError
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
# Verify the graph is built correctly
|
||||
assert len(result_nodes) == 5 # start + 3 + end
|
||||
edge_pairs = [(e["source"], e["target"]) for e in result_edges]
|
||||
assert ("init", "my_loop") in edge_pairs
|
||||
assert ("my_loop", "finish") in edge_pairs
|
||||
|
||||
|
||||
class TestEdgeStructure:
|
||||
"""Tests for edge structure correctness."""
|
||||
|
||||
def test_edge_has_required_fields(self):
|
||||
"""Edges have all required fields."""
|
||||
nodes = [
|
||||
{"id": "node1", "type": "llm", "depends_on": []},
|
||||
]
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
for edge in result_edges:
|
||||
assert "id" in edge
|
||||
assert "source" in edge
|
||||
assert "target" in edge
|
||||
assert "sourceHandle" in edge
|
||||
assert "targetHandle" in edge
|
||||
|
||||
def test_edge_id_unique(self):
|
||||
"""Each edge has a unique ID."""
|
||||
nodes = [
|
||||
{"id": "a", "type": "llm", "depends_on": []},
|
||||
{"id": "b", "type": "llm", "depends_on": []},
|
||||
{"id": "c", "type": "llm", "depends_on": ["a", "b"]},
|
||||
]
|
||||
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
|
||||
|
||||
edge_ids = [e["id"] for e in result_edges]
|
||||
assert len(edge_ids) == len(set(edge_ids)) # All unique
|
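
Taken together, these tests pin down GraphBuilder.build_graph's contract: start/end nodes are synthesized, dependencies come from depends_on or inferred variable references, and only genuine cycles raise. A minimal usage sketch follows; the import path is an assumption, since the file's own import lines fall outside this hunk:

from core.workflow.generator.graph_builder import CyclicDependencyError, GraphBuilder  # assumed path

nodes = [
    {"id": "fetch", "type": "http-request", "depends_on": []},
    {"id": "summarize", "type": "llm", "config": {"prompt_template": [{"text": "{{#fetch.body#}}"}]}},
]
try:
    graph_nodes, graph_edges = GraphBuilder.build_graph(nodes)  # adds start/end, infers fetch -> summarize
except CyclicDependencyError:
    graph_nodes, graph_edges = [], []  # raised only for genuine cycles, not loop self-references
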
@ -1,287 +0,0 @@
"""
Unit tests for the Mermaid Generator.

Tests cover:
- Basic workflow rendering
- Reserved word handling ('end' → 'end_node')
- Question classifier multi-branch edges
- If-else branch labels
- Edge validation and skipping
- Tool node formatting
"""

from core.workflow.generator.utils.mermaid_generator import generate_mermaid


class TestBasicWorkflow:
    """Tests for basic workflow Mermaid generation."""

    def test_simple_start_end_workflow(self):
        """Test a simple Start → End workflow."""
        workflow_data = {
            "nodes": [
                {"id": "start", "type": "start", "title": "Start"},
                {"id": "end", "type": "end", "title": "End"},
            ],
            "edges": [{"source": "start", "target": "end"}],
        }
        result = generate_mermaid(workflow_data)

        assert "flowchart TD" in result
        assert 'start["type=start|title=Start"]' in result
        assert 'end_node["type=end|title=End"]' in result
        assert "start --> end_node" in result

    def test_start_llm_end_workflow(self):
        """Test a Start → LLM → End workflow."""
        workflow_data = {
            "nodes": [
                {"id": "start", "type": "start", "title": "Start"},
                {"id": "llm", "type": "llm", "title": "Generate"},
                {"id": "end", "type": "end", "title": "End"},
            ],
            "edges": [
                {"source": "start", "target": "llm"},
                {"source": "llm", "target": "end"},
            ],
        }
        result = generate_mermaid(workflow_data)

        assert 'llm["type=llm|title=Generate"]' in result
        assert "start --> llm" in result
        assert "llm --> end_node" in result

    def test_empty_workflow(self):
        """Test that an empty workflow returns minimal output."""
        workflow_data = {"nodes": [], "edges": []}
        result = generate_mermaid(workflow_data)

        assert result == "flowchart TD"

    def test_missing_keys_handled(self):
        """Test that a workflow with missing keys doesn't crash."""
        workflow_data = {}
        result = generate_mermaid(workflow_data)

        assert "flowchart TD" in result


class TestReservedWords:
    """Tests for reserved word handling in node IDs."""

    def test_end_node_id_is_replaced(self):
        """Test that the 'end' node ID is replaced with 'end_node'."""
        workflow_data = {
            "nodes": [{"id": "end", "type": "end", "title": "End"}],
            "edges": [],
        }
        result = generate_mermaid(workflow_data)

        # Should use end_node instead of end
        assert "end_node[" in result
        assert '"type=end|title=End"' in result

    def test_subgraph_node_id_is_replaced(self):
        """Test that the 'subgraph' node ID is replaced with 'subgraph_node'."""
        workflow_data = {
            "nodes": [{"id": "subgraph", "type": "code", "title": "Process"}],
            "edges": [],
        }
        result = generate_mermaid(workflow_data)

        assert "subgraph_node[" in result

    def test_edge_uses_safe_ids(self):
        """Test that edges correctly reference safe IDs after replacement."""
        workflow_data = {
            "nodes": [
                {"id": "start", "type": "start", "title": "Start"},
                {"id": "end", "type": "end", "title": "End"},
            ],
            "edges": [{"source": "start", "target": "end"}],
        }
        result = generate_mermaid(workflow_data)

        # The edge should use end_node, not end
        assert "start --> end_node" in result
        assert "start --> end\n" not in result


class TestBranchEdges:
    """Tests for branching node edge labels."""

    def test_question_classifier_source_handles(self):
        """Test question-classifier edges with sourceHandle labels."""
        workflow_data = {
            "nodes": [
                {"id": "classifier", "type": "question-classifier", "title": "Classify"},
                {"id": "refund", "type": "llm", "title": "Handle Refund"},
                {"id": "inquiry", "type": "llm", "title": "Handle Inquiry"},
            ],
            "edges": [
                {"source": "classifier", "target": "refund", "sourceHandle": "refund"},
                {"source": "classifier", "target": "inquiry", "sourceHandle": "inquiry"},
            ],
        }
        result = generate_mermaid(workflow_data)

        assert "classifier -->|refund| refund" in result
        assert "classifier -->|inquiry| inquiry" in result

    def test_if_else_true_false_handles(self):
        """Test if-else edges with true/false labels."""
        workflow_data = {
            "nodes": [
                {"id": "ifelse", "type": "if-else", "title": "Check"},
                {"id": "yes_branch", "type": "llm", "title": "Yes"},
                {"id": "no_branch", "type": "llm", "title": "No"},
            ],
            "edges": [
                {"source": "ifelse", "target": "yes_branch", "sourceHandle": "true"},
                {"source": "ifelse", "target": "no_branch", "sourceHandle": "false"},
            ],
        }
        result = generate_mermaid(workflow_data)

        assert "ifelse -->|true| yes_branch" in result
        assert "ifelse -->|false| no_branch" in result

    def test_source_handle_source_is_ignored(self):
        """Test that sourceHandle='source' doesn't add a label."""
        workflow_data = {
            "nodes": [
                {"id": "llm1", "type": "llm", "title": "LLM 1"},
                {"id": "llm2", "type": "llm", "title": "LLM 2"},
            ],
            "edges": [{"source": "llm1", "target": "llm2", "sourceHandle": "source"}],
        }
        result = generate_mermaid(workflow_data)

        # Should be a plain arrow without a label
        assert "llm1 --> llm2" in result
        assert "llm1 -->|source|" not in result


class TestEdgeValidation:
    """Tests for edge validation and error handling."""

    def test_edge_with_missing_source_is_skipped(self):
        """Test that an edge with a non-existent source node is skipped."""
        workflow_data = {
            "nodes": [{"id": "end", "type": "end", "title": "End"}],
            "edges": [{"source": "nonexistent", "target": "end"}],
        }
        result = generate_mermaid(workflow_data)

        # Should not contain the invalid edge
        assert "nonexistent" not in result
        assert "-->" not in result or "nonexistent" not in result

    def test_edge_with_missing_target_is_skipped(self):
        """Test that an edge with a non-existent target node is skipped."""
        workflow_data = {
            "nodes": [{"id": "start", "type": "start", "title": "Start"}],
            "edges": [{"source": "start", "target": "nonexistent"}],
        }
        result = generate_mermaid(workflow_data)

        # The edge should be skipped
        assert "start --> nonexistent" not in result

    def test_edge_without_source_or_target_is_skipped(self):
        """Test that an edge missing source or target is skipped."""
        workflow_data = {
            "nodes": [{"id": "start", "type": "start", "title": "Start"}],
            "edges": [{"source": "start"}, {"target": "start"}, {}],
        }
        result = generate_mermaid(workflow_data)

        # No edges should be rendered
        assert result.count("-->") == 0


class TestToolNodes:
    """Tests for tool node formatting."""

    def test_tool_node_includes_tool_key(self):
        """Test that a tool node includes tool_key in its label."""
        workflow_data = {
            "nodes": [
                {
                    "id": "search",
                    "type": "tool",
                    "title": "Search",
                    "config": {"tool_key": "google/search"},
                }
            ],
            "edges": [],
        }
        result = generate_mermaid(workflow_data)

        assert 'search["type=tool|title=Search|tool=google/search"]' in result

    def test_tool_node_with_tool_name_fallback(self):
        """Test that a tool node uses tool_name as a fallback."""
        workflow_data = {
            "nodes": [
                {
                    "id": "tool1",
                    "type": "tool",
                    "title": "My Tool",
                    "config": {"tool_name": "my_tool"},
                }
            ],
            "edges": [],
        }
        result = generate_mermaid(workflow_data)

        assert "tool=my_tool" in result

    def test_tool_node_missing_tool_key_shows_unknown(self):
        """Test that a tool node without tool_key shows 'unknown'."""
        workflow_data = {
            "nodes": [{"id": "tool1", "type": "tool", "title": "Tool", "config": {}}],
            "edges": [],
        }
        result = generate_mermaid(workflow_data)

        assert "tool=unknown" in result


class TestNodeFormatting:
    """Tests for node label formatting."""

    def test_quotes_in_title_are_escaped(self):
        """Test that double quotes in a title are replaced with single quotes."""
        workflow_data = {
            "nodes": [{"id": "llm", "type": "llm", "title": 'Say "Hello"'}],
            "edges": [],
        }
        result = generate_mermaid(workflow_data)

        # Double quotes should be replaced
        assert "Say 'Hello'" in result
        assert 'Say "Hello"' not in result

    def test_node_without_id_is_skipped(self):
        """Test that a node without an id is skipped."""
        workflow_data = {
            "nodes": [{"type": "llm", "title": "No ID"}],
            "edges": [],
        }
        result = generate_mermaid(workflow_data)

        # Should only have the flowchart header
        lines = [line for line in result.split("\n") if line.strip()]
        assert len(lines) == 1

    def test_node_default_values(self):
        """Test that a node with missing type/title uses defaults."""
        workflow_data = {
            "nodes": [{"id": "node1"}],
            "edges": [],
        }
        result = generate_mermaid(workflow_data)

        assert "type=unknown" in result
        assert "title=Untitled" in result
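
The label grammar these tests fix is easy to exercise directly. A minimal sketch, self-contained with the same import the file uses; only the substrings asserted above are guaranteed, not the exact whitespace of the rendered chart:

from core.workflow.generator.utils.mermaid_generator import generate_mermaid

workflow_data = {
    "nodes": [
        {"id": "start", "type": "start", "title": "Start"},
        {"id": "end", "type": "end", "title": "End"},  # reserved word: rendered as end_node
    ],
    "edges": [{"source": "start", "target": "end"}],
}
mermaid = generate_mermaid(workflow_data)
assert "flowchart TD" in mermaid
assert 'end_node["type=end|title=End"]' in mermaid
assert "start --> end_node" in mermaid
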
@ -1,81 +0,0 @@
from core.workflow.generator.utils.node_repair import NodeRepair


class TestNodeRepair:
    """Tests for the NodeRepair utility."""

    def test_repair_if_else_valid_operators(self):
        """Test that valid operators remain unchanged."""
        nodes = [
            {
                "id": "node1",
                "type": "if-else",
                "config": {
                    "cases": [
                        {
                            "conditions": [
                                {"comparison_operator": "≥", "value": "1"},
                                {"comparison_operator": "=", "value": "2"},
                            ]
                        }
                    ]
                },
            }
        ]
        result = NodeRepair.repair(nodes)
        assert result.was_repaired is False
        assert result.nodes == nodes

    def test_repair_if_else_invalid_operators(self):
        """Test that invalid operators are normalized."""
        nodes = [
            {
                "id": "node1",
                "type": "if-else",
                "config": {
                    "cases": [
                        {
                            "conditions": [
                                {"comparison_operator": ">=", "value": "1"},
                                {"comparison_operator": "<=", "value": "2"},
                                {"comparison_operator": "!=", "value": "3"},
                                {"comparison_operator": "==", "value": "4"},
                            ]
                        }
                    ]
                },
            }
        ]
        result = NodeRepair.repair(nodes)
        assert result.was_repaired is True
        assert len(result.repairs_made) == 4

        conditions = result.nodes[0]["config"]["cases"][0]["conditions"]
        assert conditions[0]["comparison_operator"] == "≥"
        assert conditions[1]["comparison_operator"] == "≤"
        assert conditions[2]["comparison_operator"] == "≠"
        assert conditions[3]["comparison_operator"] == "="

    def test_repair_ignores_other_nodes(self):
        """Test that other node types are ignored."""
        nodes = [{"id": "node1", "type": "llm", "config": {"some_field": ">="}}]
        result = NodeRepair.repair(nodes)
        assert result.was_repaired is False
        assert result.nodes[0]["config"]["some_field"] == ">="

    def test_repair_handles_missing_config(self):
        """Test robustness against missing fields."""
        nodes = [
            {
                "id": "node1",
                "type": "if-else",
                # Missing config
            },
            {
                "id": "node2",
                "type": "if-else",
                "config": {},  # Missing cases
            },
        ]
        result = NodeRepair.repair(nodes)
        assert result.was_repaired is False
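
For reference, the normalization these tests describe can be driven end to end; this sketch relies only on behavior asserted above:

from core.workflow.generator.utils.node_repair import NodeRepair

nodes = [
    {
        "id": "gate",
        "type": "if-else",
        "config": {"cases": [{"conditions": [{"comparison_operator": ">=", "value": "1"}]}]},
    }
]
result = NodeRepair.repair(nodes)
assert result.was_repaired is True  # ASCII ">=" is an invalid operator that gets rewritten
assert result.nodes[0]["config"]["cases"][0]["conditions"][0]["comparison_operator"] == "≥"
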
@ -1,99 +0,0 @@
"""
Tests for node schemas validation.

Ensures that the node configuration stays in sync with registered node types.
"""

from core.workflow.generator.config.node_schemas import (
    get_builtin_node_schemas,
    validate_node_schemas,
)


class TestNodeSchemasValidation:
    """Tests for node schema validation utilities."""

    def test_validate_node_schemas_returns_no_warnings(self):
        """Ensure all registered node types have corresponding schemas."""
        warnings = validate_node_schemas()
        # If this test fails, it means a new node type was added but
        # no schema was defined for it in node_schemas.py
        assert len(warnings) == 0, (
            f"Missing schemas for node types: {warnings}. "
            "Please add schemas for these node types in node_schemas.py "
            "or add them to _INTERNAL_NODE_TYPES if they don't need schemas."
        )

    def test_builtin_node_schemas_not_empty(self):
        """Ensure BUILTIN_NODE_SCHEMAS contains the expected node types."""
        # get_builtin_node_schemas() includes dynamic schemas
        all_schemas = get_builtin_node_schemas()
        assert len(all_schemas) > 0
        # Core node types should always be present
        expected_types = ["llm", "code", "http-request", "if-else"]
        for node_type in expected_types:
            assert node_type in all_schemas, f"Missing schema for core node type: {node_type}"

    def test_schema_structure(self):
        """Ensure each schema has the required fields."""
        all_schemas = get_builtin_node_schemas()
        for node_type, schema in all_schemas.items():
            assert "description" in schema, f"Missing 'description' in schema for {node_type}"
            # 'parameters' is optional, but if present it should be a dict
            if "parameters" in schema:
                assert isinstance(schema["parameters"], dict), (
                    f"'parameters' in schema for {node_type} should be a dict"
                )


class TestNodeSchemasMerged:
    """Tests to verify the merged configuration works correctly."""

    def test_fallback_rules_available(self):
        """Ensure FALLBACK_RULES is available from node_schemas."""
        from core.workflow.generator.config.node_schemas import FALLBACK_RULES

        assert len(FALLBACK_RULES) > 0
        assert "http-request" in FALLBACK_RULES
        assert "code" in FALLBACK_RULES
        assert "llm" in FALLBACK_RULES

    def test_node_type_aliases_available(self):
        """Ensure NODE_TYPE_ALIASES is available from node_schemas."""
        from core.workflow.generator.config.node_schemas import NODE_TYPE_ALIASES

        assert len(NODE_TYPE_ALIASES) > 0
        assert NODE_TYPE_ALIASES.get("gpt") == "llm"
        assert NODE_TYPE_ALIASES.get("api") == "http-request"

    def test_field_name_corrections_available(self):
        """Ensure FIELD_NAME_CORRECTIONS is available from node_schemas."""
        from core.workflow.generator.config.node_schemas import (
            FIELD_NAME_CORRECTIONS,
            get_corrected_field_name,
        )

        assert len(FIELD_NAME_CORRECTIONS) > 0
        # Test the helper function
        assert get_corrected_field_name("http-request", "text") == "body"
        assert get_corrected_field_name("llm", "response") == "text"
        assert get_corrected_field_name("code", "unknown") == "unknown"

    def test_config_init_exports(self):
        """Ensure the config __init__.py exports all needed symbols."""
        from core.workflow.generator.config import (
            BUILTIN_NODE_SCHEMAS,
            FALLBACK_RULES,
            FIELD_NAME_CORRECTIONS,
            NODE_TYPE_ALIASES,
            get_corrected_field_name,
            validate_node_schemas,
        )

        # Just verify the imports work
        assert BUILTIN_NODE_SCHEMAS is not None
        assert FALLBACK_RULES is not None
        assert FIELD_NAME_CORRECTIONS is not None
        assert NODE_TYPE_ALIASES is not None
        assert callable(get_corrected_field_name)
        assert callable(validate_node_schemas)
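
A short sketch of how these exports combine in practice, limited to behavior the tests assert:

from core.workflow.generator.config import (
    NODE_TYPE_ALIASES,
    get_corrected_field_name,
    validate_node_schemas,
)

assert len(validate_node_schemas()) == 0  # every registered node type has a schema
assert NODE_TYPE_ALIASES.get("gpt") == "llm"  # loose names normalize to canonical node types
assert get_corrected_field_name("llm", "response") == "text"  # common output-field slips are remapped
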
@ -1,172 +0,0 @@
"""
Unit tests for the Planner Prompts.

Tests cover:
- Tool formatting for planner context
- Edge cases with missing fields
- Empty tool lists
"""

from core.workflow.generator.prompts.planner_prompts import format_tools_for_planner


class TestFormatToolsForPlanner:
    """Tests for the format_tools_for_planner function."""

    def test_empty_tools_returns_default_message(self):
        """Test that an empty tools list returns the default message."""
        result = format_tools_for_planner([])

        assert result == "No external tools available."

    def test_none_tools_returns_default_message(self):
        """Test that a None tools list returns the default message."""
        result = format_tools_for_planner(None)

        assert result == "No external tools available."

    def test_single_tool_formatting(self):
        """Test that a single tool is formatted correctly."""
        tools = [
            {
                "provider_id": "google",
                "tool_key": "search",
                "tool_label": "Google Search",
                "tool_description": "Search the web using Google",
            }
        ]
        result = format_tools_for_planner(tools)

        assert "[google/search]" in result
        assert "Google Search" in result
        assert "Search the web using Google" in result

    def test_multiple_tools_formatting(self):
        """Test that multiple tools are formatted correctly."""
        tools = [
            {
                "provider_id": "google",
                "tool_key": "search",
                "tool_label": "Search",
                "tool_description": "Web search",
            },
            {
                "provider_id": "slack",
                "tool_key": "send_message",
                "tool_label": "Send Message",
                "tool_description": "Send a Slack message",
            },
        ]
        result = format_tools_for_planner(tools)

        lines = result.strip().split("\n")
        assert len(lines) == 2
        assert "[google/search]" in result
        assert "[slack/send_message]" in result

    def test_tool_without_provider_uses_key_only(self):
        """Test that a tool without provider_id uses tool_key only."""
        tools = [
            {
                "tool_key": "my_tool",
                "tool_label": "My Tool",
                "tool_description": "A custom tool",
            }
        ]
        result = format_tools_for_planner(tools)

        # Should format as [my_tool] without a provider prefix
        assert "[my_tool]" in result
        assert "My Tool" in result

    def test_tool_with_tool_name_fallback(self):
        """Test that a tool uses tool_name when tool_key is missing."""
        tools = [
            {
                "tool_name": "fallback_tool",
                "description": "Fallback description",
            }
        ]
        result = format_tools_for_planner(tools)

        assert "fallback_tool" in result
        assert "Fallback description" in result

    def test_tool_with_missing_description(self):
        """Test that a tool with a missing description doesn't crash."""
        tools = [
            {
                "provider_id": "test",
                "tool_key": "tool1",
                "tool_label": "Tool 1",
            }
        ]
        result = format_tools_for_planner(tools)

        assert "[test/tool1]" in result
        assert "Tool 1" in result

    def test_tool_with_all_missing_fields(self):
        """Test that a tool with all fields missing uses defaults."""
        tools = [{}]
        result = format_tools_for_planner(tools)

        # Should not crash; may produce minimal output
        assert isinstance(result, str)

    def test_tool_uses_provider_fallback(self):
        """Test that a tool uses 'provider' when 'provider_id' is missing."""
        tools = [
            {
                "provider": "openai",
                "tool_key": "dalle",
                "tool_label": "DALL-E",
                "tool_description": "Generate images",
            }
        ]
        result = format_tools_for_planner(tools)

        assert "[openai/dalle]" in result

    def test_tool_label_fallback_to_key(self):
        """Test that tool_label falls back to tool_key when missing."""
        tools = [
            {
                "provider_id": "test",
                "tool_key": "my_key",
                "tool_description": "Description here",
            }
        ]
        result = format_tools_for_planner(tools)

        # The label should fall back to the key
        assert "my_key" in result
        assert "Description here" in result


class TestPlannerPromptConstants:
    """Tests for planner prompt constant availability."""

    def test_planner_system_prompt_exists(self):
        """Test that PLANNER_SYSTEM_PROMPT is defined."""
        from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT

        assert PLANNER_SYSTEM_PROMPT is not None
        assert len(PLANNER_SYSTEM_PROMPT) > 0
        assert "{tools_summary}" in PLANNER_SYSTEM_PROMPT

    def test_planner_user_prompt_exists(self):
        """Test that PLANNER_USER_PROMPT is defined."""
        from core.workflow.generator.prompts.planner_prompts import PLANNER_USER_PROMPT

        assert PLANNER_USER_PROMPT is not None
        assert "{instruction}" in PLANNER_USER_PROMPT

    def test_planner_system_prompt_has_required_sections(self):
        """Test that PLANNER_SYSTEM_PROMPT has the required XML sections."""
        from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT

        assert "<role>" in PLANNER_SYSTEM_PROMPT
        assert "<task>" in PLANNER_SYSTEM_PROMPT
        assert "<available_tools>" in PLANNER_SYSTEM_PROMPT
        assert "<response_format>" in PLANNER_SYSTEM_PROMPT
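
Putting the two halves together — tool summary plus system prompt — looks roughly like this. str.replace is used instead of str.format to sidestep any literal braces elsewhere in the prompt; that is an assumption about safe usage, not a documented API:

from core.workflow.generator.prompts.planner_prompts import (
    PLANNER_SYSTEM_PROMPT,
    format_tools_for_planner,
)

tools = [
    {
        "provider_id": "google",
        "tool_key": "search",
        "tool_label": "Google Search",
        "tool_description": "Search the web",
    }
]
summary = format_tools_for_planner(tools)  # one "[google/search] ..." style line per tool
system_prompt = PLANNER_SYSTEM_PROMPT.replace("{tools_summary}", summary)
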
@ -1,510 +0,0 @@
|
||||
"""
|
||||
Unit tests for the Validation Rule Engine.
|
||||
|
||||
Tests cover:
|
||||
- Structure rules (required fields, types, formats)
|
||||
- Semantic rules (variable references, edge connections)
|
||||
- Reference rules (model exists, tool configured, dataset valid)
|
||||
- ValidationEngine integration
|
||||
"""
|
||||
|
||||
from core.workflow.generator.validation import (
|
||||
ValidationContext,
|
||||
ValidationEngine,
|
||||
)
|
||||
from core.workflow.generator.validation.rules import (
|
||||
extract_variable_refs,
|
||||
is_placeholder,
|
||||
)
|
||||
|
||||
|
||||
class TestPlaceholderDetection:
|
||||
"""Tests for placeholder detection utility."""
|
||||
|
||||
def test_detects_please_select(self):
|
||||
assert is_placeholder("PLEASE_SELECT_YOUR_MODEL") is True
|
||||
|
||||
def test_detects_your_prefix(self):
|
||||
assert is_placeholder("YOUR_API_KEY") is True
|
||||
|
||||
def test_detects_todo(self):
|
||||
assert is_placeholder("TODO: fill this in") is True
|
||||
|
||||
def test_detects_placeholder(self):
|
||||
assert is_placeholder("PLACEHOLDER_VALUE") is True
|
||||
|
||||
def test_detects_example_prefix(self):
|
||||
assert is_placeholder("EXAMPLE_URL") is True
|
||||
|
||||
def test_detects_replace_prefix(self):
|
||||
assert is_placeholder("REPLACE_WITH_ACTUAL") is True
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert is_placeholder("please_select") is True
|
||||
assert is_placeholder("Please_Select") is True
|
||||
|
||||
def test_valid_values_not_detected(self):
|
||||
assert is_placeholder("https://api.example.com") is False
|
||||
assert is_placeholder("gpt-4") is False
|
||||
assert is_placeholder("my_variable") is False
|
||||
|
||||
def test_non_string_returns_false(self):
|
||||
assert is_placeholder(123) is False
|
||||
assert is_placeholder(None) is False
|
||||
assert is_placeholder(["list"]) is False
|
||||
|
||||
|
||||
class TestVariableRefExtraction:
|
||||
"""Tests for variable reference extraction."""
|
||||
|
||||
def test_extracts_simple_ref(self):
|
||||
refs = extract_variable_refs("Hello {{#start.query#}}")
|
||||
assert refs == [("start", "query")]
|
||||
|
||||
def test_extracts_multiple_refs(self):
|
||||
refs = extract_variable_refs("{{#node1.output#}} and {{#node2.text#}}")
|
||||
assert refs == [("node1", "output"), ("node2", "text")]
|
||||
|
||||
def test_extracts_nested_field(self):
|
||||
refs = extract_variable_refs("{{#http_request.body#}}")
|
||||
assert refs == [("http_request", "body")]
|
||||
|
||||
def test_no_refs_returns_empty(self):
|
||||
refs = extract_variable_refs("No references here")
|
||||
assert refs == []
|
||||
|
||||
def test_handles_malformed_refs(self):
|
||||
refs = extract_variable_refs("{{#invalid}} and {{incomplete#}}")
|
||||
assert refs == []
|
||||
|
||||
|
||||
class TestValidationContext:
|
||||
"""Tests for ValidationContext."""
|
||||
|
||||
def test_node_map_lookup(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start"},
|
||||
{"id": "llm_1", "type": "llm"},
|
||||
]
|
||||
)
|
||||
assert ctx.get_node("start") == {"id": "start", "type": "start"}
|
||||
assert ctx.get_node("nonexistent") is None
|
||||
|
||||
def test_model_set(self):
|
||||
ctx = ValidationContext(
|
||||
available_models=[
|
||||
{"provider": "openai", "model": "gpt-4"},
|
||||
{"provider": "anthropic", "model": "claude-3"},
|
||||
]
|
||||
)
|
||||
assert ctx.has_model("openai", "gpt-4") is True
|
||||
assert ctx.has_model("anthropic", "claude-3") is True
|
||||
assert ctx.has_model("openai", "gpt-3.5") is False
|
||||
|
||||
def test_tool_set(self):
|
||||
ctx = ValidationContext(
|
||||
available_tools=[
|
||||
{"provider_id": "google", "tool_key": "search", "is_team_authorization": True},
|
||||
{"provider_id": "slack", "tool_key": "send_message", "is_team_authorization": False},
|
||||
]
|
||||
)
|
||||
assert ctx.has_tool("google/search") is True
|
||||
assert ctx.has_tool("search") is True
|
||||
assert ctx.is_tool_configured("google/search") is True
|
||||
assert ctx.is_tool_configured("slack/send_message") is False
|
||||
|
||||
def test_upstream_downstream_nodes(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start"},
|
||||
{"id": "llm", "type": "llm"},
|
||||
{"id": "end", "type": "end"},
|
||||
],
|
||||
edges=[
|
||||
{"source": "start", "target": "llm"},
|
||||
{"source": "llm", "target": "end"},
|
||||
],
|
||||
)
|
||||
assert ctx.get_upstream_nodes("llm") == ["start"]
|
||||
assert ctx.get_downstream_nodes("llm") == ["end"]
|
||||
|
||||
|
||||
class TestStructureRules:
|
||||
"""Tests for structure validation rules."""
|
||||
|
||||
def test_llm_missing_prompt_template(self):
|
||||
ctx = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}])
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
assert result.has_errors
|
||||
errors = [e for e in result.all_errors if e.rule_id == "llm.prompt_template.required"]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is True
|
||||
|
||||
def test_llm_with_prompt_template_passes(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"prompt_template": [
|
||||
{"role": "system", "text": "You are helpful"},
|
||||
{"role": "user", "text": "Hello"},
|
||||
]
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
# No prompt_template errors
|
||||
errors = [e for e in result.all_errors if "prompt_template" in e.rule_id]
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_http_request_missing_url(self):
|
||||
ctx = ValidationContext(nodes=[{"id": "http_1", "type": "http-request", "config": {}}])
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "http.url" in e.rule_id]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is True
|
||||
|
||||
def test_http_request_placeholder_url(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "http_1",
|
||||
"type": "http-request",
|
||||
"config": {"url": "PLEASE_SELECT_YOUR_URL", "method": "GET"},
|
||||
}
|
||||
]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "placeholder" in e.rule_id]
|
||||
assert len(errors) == 1
|
||||
|
||||
def test_code_node_missing_fields(self):
|
||||
ctx = ValidationContext(nodes=[{"id": "code_1", "type": "code", "config": {}}])
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
error_rules = {e.rule_id for e in result.all_errors}
|
||||
assert "code.code.required" in error_rules
|
||||
assert "code.language.required" in error_rules
|
||||
|
||||
def test_knowledge_retrieval_missing_dataset(self):
|
||||
ctx = ValidationContext(nodes=[{"id": "kb_1", "type": "knowledge-retrieval", "config": {}}])
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "knowledge.dataset" in e.rule_id]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is False # User must configure
|
||||
|
||||
|
||||
class TestSemanticRules:
|
||||
"""Tests for semantic validation rules."""
|
||||
|
||||
def test_valid_variable_reference(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"role": "user", "text": "Process: {{#start.query#}}"}]},
|
||||
},
|
||||
]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
# No variable reference errors
|
||||
errors = [e for e in result.all_errors if "variable.ref" in e.rule_id]
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_invalid_variable_reference(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"role": "user", "text": "Process: {{#nonexistent.field#}}"}]},
|
||||
},
|
||||
]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "variable.ref" in e.rule_id]
|
||||
assert len(errors) == 1
|
||||
assert "nonexistent" in errors[0].message
|
||||
|
||||
def test_edge_validation(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{"id": "end", "type": "end", "config": {}},
|
||||
],
|
||||
edges=[
|
||||
{"source": "start", "target": "end"},
|
||||
{"source": "nonexistent", "target": "end"},
|
||||
],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "edge" in e.rule_id]
|
||||
assert len(errors) == 1
|
||||
assert "nonexistent" in errors[0].message
|
||||
|
||||
|
||||
class TestReferenceRules:
|
||||
"""Tests for reference validation rules (models, tools)."""
|
||||
|
||||
def test_llm_missing_model_with_available(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
|
||||
}
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if e.rule_id == "model.required"]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is True
|
||||
|
||||
def test_llm_missing_model_no_available(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
|
||||
}
|
||||
],
|
||||
available_models=[], # No models available
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if e.rule_id == "model.no_available"]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is False
|
||||
|
||||
def test_llm_with_valid_model(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"prompt_template": [{"role": "user", "text": "Hi"}],
|
||||
"model": {"provider": "openai", "name": "gpt-4"},
|
||||
},
|
||||
}
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if "model" in e.rule_id]
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_llm_with_invalid_model(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"prompt_template": [{"role": "user", "text": "Hi"}],
|
||||
"model": {"provider": "openai", "name": "gpt-99"},
|
||||
},
|
||||
}
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if e.rule_id == "model.not_found"]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is True
|
||||
|
||||
def test_tool_node_not_found(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "tool_1",
|
||||
"type": "tool",
|
||||
"config": {"tool_key": "nonexistent/tool"},
|
||||
}
|
||||
],
|
||||
available_tools=[],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if e.rule_id == "tool.not_found"]
|
||||
assert len(errors) == 1
|
||||
|
||||
def test_tool_node_not_configured(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "tool_1",
|
||||
"type": "tool",
|
||||
"config": {"tool_key": "google/search"},
|
||||
}
|
||||
],
|
||||
available_tools=[{"provider_id": "google", "tool_key": "search", "is_team_authorization": False}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
errors = [e for e in result.all_errors if e.rule_id == "tool.not_configured"]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].is_fixable is False
|
||||
|
||||
|
||||
class TestValidationResult:
|
||||
"""Tests for ValidationResult classification."""
|
||||
|
||||
def test_has_errors(self):
|
||||
ctx = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}])
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
assert result.has_errors is True
|
||||
assert result.is_valid is False
|
||||
|
||||
def test_has_fixable_errors(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
|
||||
}
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
assert result.has_fixable_errors is True
|
||||
assert len(result.fixable_errors) > 0
|
||||
|
||||
def test_get_fixable_by_node(self):
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "llm_1", "type": "llm", "config": {}},
|
||||
{"id": "http_1", "type": "http-request", "config": {}},
|
||||
]
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
by_node = result.get_fixable_by_node()
|
||||
assert "llm_1" in by_node
|
||||
assert "http_1" in by_node
|
||||
|
||||
def test_to_dict(self):
|
||||
ctx = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}])
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
d = result.to_dict()
|
||||
assert "fixable" in d
|
||||
assert "user_required" in d
|
||||
assert "warnings" in d
|
||||
assert "all_warnings" in d
|
||||
assert "stats" in d
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""Integration tests for the full validation pipeline."""
|
||||
|
||||
def test_complete_workflow_validation(self):
|
||||
"""Test validation of a complete workflow."""
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{
|
||||
"id": "start",
|
||||
"type": "start",
|
||||
"config": {"variables": [{"variable": "query", "type": "text-input"}]},
|
||||
},
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"model": {"provider": "openai", "name": "gpt-4"},
|
||||
"prompt_template": [{"role": "user", "text": "{{#start.query#}}"}],
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "end",
|
||||
"type": "end",
|
||||
"config": {"outputs": [{"variable": "result", "value_selector": ["llm_1", "text"]}]},
|
||||
},
|
||||
],
|
||||
edges=[
|
||||
{"source": "start", "target": "llm_1"},
|
||||
{"source": "llm_1", "target": "end"},
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
# Should have no errors
|
||||
assert result.is_valid is True
|
||||
assert len(result.fixable_errors) == 0
|
||||
assert len(result.user_required_errors) == 0
|
||||
|
||||
def test_workflow_with_multiple_errors(self):
|
||||
"""Test workflow with multiple types of errors."""
|
||||
ctx = ValidationContext(
|
||||
nodes=[
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "llm_1",
|
||||
"type": "llm",
|
||||
"config": {}, # Missing prompt_template and model
|
||||
},
|
||||
{
|
||||
"id": "kb_1",
|
||||
"type": "knowledge-retrieval",
|
||||
"config": {"dataset_ids": ["PLEASE_SELECT_YOUR_DATASET"]},
|
||||
},
|
||||
{"id": "end", "type": "end", "config": {}},
|
||||
],
|
||||
available_models=[{"provider": "openai", "model": "gpt-4"}],
|
||||
)
|
||||
engine = ValidationEngine()
|
||||
result = engine.validate(ctx)
|
||||
|
||||
# Should have multiple errors
|
||||
assert result.has_errors is True
|
||||
assert len(result.fixable_errors) >= 2 # model, prompt_template
|
||||
assert len(result.user_required_errors) >= 1 # dataset placeholder
|
||||
|
||||
# Check stats
|
||||
assert result.stats["total_nodes"] == 4
|
||||
assert result.stats["total_errors"] >= 3
|
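
The engine's triage split — fixable versus user-required — is the part callers build on; a minimal sketch using only attributes exercised above:

from core.workflow.generator.validation import ValidationContext, ValidationEngine

ctx = ValidationContext(
    nodes=[{"id": "llm_1", "type": "llm", "config": {}}],
    available_models=[{"provider": "openai", "model": "gpt-4"}],
)
result = ValidationEngine().validate(ctx)
for error in result.fixable_errors:  # safe to auto-repair, e.g. a missing prompt_template
    print("fix:", error.rule_id, error.message)
for error in result.user_required_errors:  # needs human input, e.g. dataset selection
    print("ask:", error.rule_id, error.message)
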
||||
@ -1,434 +0,0 @@
|
||||
"""
|
||||
Unit tests for the Vibe Workflow Validator.
|
||||
|
||||
Tests cover:
|
||||
- Basic validation function
|
||||
- User-friendly validation hints
|
||||
- Edge cases and error handling
|
||||
"""
|
||||
|
||||
from core.workflow.generator.utils.workflow_validator import ValidationHint, WorkflowValidator
|
||||
|
||||
|
||||
class TestValidationHint:
|
||||
"""Tests for ValidationHint dataclass."""
|
||||
|
||||
def test_hint_creation(self):
|
||||
"""Test creating a validation hint."""
|
||||
hint = ValidationHint(
|
||||
node_id="llm_1",
|
||||
field="model",
|
||||
message="Model is not configured",
|
||||
severity="error",
|
||||
)
|
||||
assert hint.node_id == "llm_1"
|
||||
assert hint.field == "model"
|
||||
assert hint.message == "Model is not configured"
|
||||
assert hint.severity == "error"
|
||||
|
||||
def test_hint_with_suggestion(self):
|
||||
"""Test hint with suggestion."""
|
||||
hint = ValidationHint(
|
||||
node_id="http_1",
|
||||
field="url",
|
||||
message="URL is required",
|
||||
severity="error",
|
||||
suggestion="Add a valid URL like https://api.example.com",
|
||||
)
|
||||
assert hint.suggestion is not None
|
||||
|
||||
|
||||
class TestWorkflowValidatorBasic:
|
||||
"""Tests for basic validation scenarios."""
|
||||
|
||||
def test_empty_workflow_is_valid(self):
|
||||
"""Test empty workflow passes validation."""
|
||||
workflow_data = {"nodes": [], "edges": []}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
|
||||
# Empty but valid structure
|
||||
assert is_valid is True
|
||||
assert len(hints) == 0
|
||||
|
||||
def test_minimal_valid_workflow(self):
|
||||
"""Test minimal Start → End workflow."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{"id": "end", "type": "end", "config": {}},
|
||||
],
|
||||
"edges": [{"source": "start", "target": "end"}],
|
||||
}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
|
||||
assert is_valid is True
|
||||
|
||||
def test_complete_workflow_with_llm(self):
|
||||
"""Test complete workflow with LLM node."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "config": {"variables": []}},
|
||||
{
|
||||
"id": "llm",
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"model": {"provider": "openai", "name": "gpt-4"},
|
||||
"prompt_template": [{"role": "user", "text": "Hello"}],
|
||||
},
|
||||
},
|
||||
{"id": "end", "type": "end", "config": {"outputs": []}},
|
||||
],
|
||||
"edges": [
|
||||
{"source": "start", "target": "llm"},
|
||||
{"source": "llm", "target": "end"},
|
||||
],
|
||||
}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
|
||||
# Should pass with no critical errors
|
||||
errors = [h for h in hints if h.severity == "error"]
|
||||
assert len(errors) == 0
|
||||
|
||||
|
||||
class TestVariableReferenceValidation:
|
||||
"""Tests for variable reference validation."""
|
||||
|
||||
def test_valid_variable_reference(self):
|
||||
"""Test valid variable reference passes."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "llm",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"role": "user", "text": "Query: {{#start.query#}}"}]},
|
||||
},
|
||||
],
|
||||
"edges": [{"source": "start", "target": "llm"}],
|
||||
}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
|
||||
ref_errors = [h for h in hints if "reference" in h.message.lower()]
|
||||
assert len(ref_errors) == 0
|
||||
|
||||
def test_invalid_variable_reference(self):
|
||||
"""Test invalid variable reference generates hint."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "llm",
|
||||
"type": "llm",
|
||||
"config": {"prompt_template": [{"role": "user", "text": "{{#nonexistent.field#}}"}]},
|
||||
},
|
||||
],
|
||||
"edges": [{"source": "start", "target": "llm"}],
|
||||
}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
|
||||
# Should have a hint about invalid reference
|
||||
ref_hints = [h for h in hints if "nonexistent" in h.message or "reference" in h.message.lower()]
|
||||
assert len(ref_hints) >= 1
|
||||
|
||||
|
||||
class TestEdgeValidation:
|
||||
"""Tests for edge validation."""
|
||||
|
||||
def test_edge_with_invalid_source(self):
|
||||
"""Test edge with non-existent source generates hint."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "end", "type": "end", "config": {}}],
|
||||
"edges": [{"source": "nonexistent", "target": "end"}],
|
||||
}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
|
||||
# Should have hint about invalid edge
|
||||
edge_hints = [h for h in hints if "edge" in h.message.lower() or "source" in h.message.lower()]
|
||||
assert len(edge_hints) >= 1
|
||||
|
||||
def test_edge_with_invalid_target(self):
|
||||
"""Test edge with non-existent target generates hint."""
|
||||
workflow_data = {
|
||||
"nodes": [{"id": "start", "type": "start", "config": {}}],
|
||||
"edges": [{"source": "start", "target": "nonexistent"}],
|
||||
}
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [])
|
||||
|
||||
edge_hints = [h for h in hints if "edge" in h.message.lower() or "target" in h.message.lower()]
|
||||
assert len(edge_hints) >= 1
|
||||
|
||||
|
||||
class TestToolValidation:
|
||||
"""Tests for tool node validation."""
|
||||
|
||||
def test_tool_node_found_in_available(self):
|
||||
"""Test tool node that exists in available tools."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "tool1",
|
||||
"type": "tool",
|
||||
"config": {"tool_key": "google/search"},
|
||||
},
|
||||
{"id": "end", "type": "end", "config": {}},
|
||||
],
|
||||
"edges": [{"source": "start", "target": "tool1"}, {"source": "tool1", "target": "end"}],
|
||||
}
|
||||
available_tools = [{"provider_id": "google", "tool_key": "search", "is_team_authorization": True}]
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, available_tools)
|
||||
|
||||
tool_errors = [h for h in hints if h.severity == "error" and "tool" in h.message.lower()]
|
||||
assert len(tool_errors) == 0
|
||||
|
||||
def test_tool_node_not_found(self):
|
||||
"""Test tool node not in available tools generates hint."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "tool1",
|
||||
"type": "tool",
|
||||
"config": {"tool_key": "unknown/tool"},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
available_tools = []
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, available_tools)
|
||||
|
||||
tool_hints = [h for h in hints if "tool" in h.message.lower()]
|
||||
assert len(tool_hints) >= 1
|
||||
|
||||
|
||||
class TestQuestionClassifierValidation:
|
||||
"""Tests for question-classifier node validation."""
|
||||
|
||||
def test_question_classifier_with_classes(self):
|
||||
"""Test question-classifier with valid classes."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{"id": "start", "type": "start", "config": {}},
|
||||
{
|
||||
"id": "classifier",
|
||||
"type": "question-classifier",
|
||||
"config": {
|
||||
"classes": [
|
||||
{"id": "class1", "name": "Class 1"},
|
||||
{"id": "class2", "name": "Class 2"},
|
||||
],
|
||||
"model": {"provider": "openai", "name": "gpt-4", "mode": "chat"},
|
||||
},
|
||||
},
|
||||
{"id": "h1", "type": "llm", "config": {}},
|
||||
{"id": "h2", "type": "llm", "config": {}},
|
||||
{"id": "end", "type": "end", "config": {}},
|
||||
],
|
||||
"edges": [
|
||||
{"source": "start", "target": "classifier"},
|
||||
{"source": "classifier", "sourceHandle": "class1", "target": "h1"},
|
||||
{"source": "classifier", "sourceHandle": "class2", "target": "h2"},
|
||||
{"source": "h1", "target": "end"},
|
||||
{"source": "h2", "target": "end"},
|
||||
],
|
||||
}
|
||||
available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}]
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models)
|
||||
|
||||
class_errors = [h for h in hints if "class" in h.message.lower() and h.severity == "error"]
|
||||
assert len(class_errors) == 0
|
||||
|
||||
def test_question_classifier_missing_classes(self):
|
||||
"""Test question-classifier without classes generates hint."""
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "classifier",
|
||||
"type": "question-classifier",
|
||||
"config": {"model": {"provider": "openai", "name": "gpt-4", "mode": "chat"}},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}]
|
||||
is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models)
|
||||
|
||||
# Should have hint about missing classes
|
||||
class_hints = [h for h in hints if "class" in h.message.lower()]
|
||||
assert len(class_hints) >= 1
|
||||
|
||||
|
class TestHttpRequestValidation:
    """Tests for HTTP request node validation."""

    def test_http_request_with_url(self):
        """Test HTTP request with valid URL."""
        workflow_data = {
            "nodes": [
                {"id": "start", "type": "start", "config": {}},
                {
                    "id": "http",
                    "type": "http-request",
                    "config": {"url": "https://api.example.com", "method": "GET"},
                },
                {"id": "end", "type": "end", "config": {}},
            ],
            "edges": [{"source": "start", "target": "http"}, {"source": "http", "target": "end"}],
        }
        is_valid, hints = WorkflowValidator.validate(workflow_data, [])

        url_errors = [h for h in hints if "url" in h.message.lower() and h.severity == "error"]
        assert len(url_errors) == 0

    def test_http_request_missing_url(self):
        """Test HTTP request without URL generates hint."""
        workflow_data = {
            "nodes": [
                {
                    "id": "http",
                    "type": "http-request",
                    "config": {"method": "GET"},
                }
            ],
            "edges": [],
        }
        is_valid, hints = WorkflowValidator.validate(workflow_data, [])

        url_hints = [h for h in hints if "url" in h.message.lower()]
        assert len(url_hints) >= 1


class TestParameterExtractorValidation:
    """Tests for parameter-extractor node validation."""

    def test_parameter_extractor_valid_params(self):
        """Test parameter-extractor with valid parameters."""
        workflow_data = {
            "nodes": [
                {"id": "start", "type": "start", "config": {}},
                {
                    "id": "extractor",
                    "type": "parameter-extractor",
                    "config": {
                        "instruction": "Extract info",
                        "parameters": [
                            {
                                "name": "name",
                                "type": "string",
                                "description": "Name",
                                "required": True,
                            }
                        ],
                        "model": {"provider": "openai", "name": "gpt-4", "mode": "chat"},
                    },
                },
                {"id": "end", "type": "end", "config": {}},
            ],
            "edges": [{"source": "start", "target": "extractor"}, {"source": "extractor", "target": "end"}],
        }
        available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}]
        is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models)

        errors = [h for h in hints if h.severity == "error"]
        assert len(errors) == 0

    def test_parameter_extractor_missing_required_field(self):
        """Test parameter-extractor missing 'required' field in parameter item."""
        workflow_data = {
            "nodes": [
                {
                    "id": "extractor",
                    "type": "parameter-extractor",
                    "config": {
                        "instruction": "Extract info",
                        "parameters": [
                            {
                                "name": "name",
                                "type": "string",
                                "description": "Name",
                                # Missing 'required'
                            }
                        ],
                        "model": {"provider": "openai", "name": "gpt-4", "mode": "chat"},
                    },
                }
            ],
            "edges": [],
        }
        available_models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}]
        is_valid, hints = WorkflowValidator.validate(workflow_data, [], available_models=available_models)

        errors = [h for h in hints if "required" in h.message and h.severity == "error"]
        assert len(errors) >= 1
        assert "parameter-extractor" in errors[0].node_type


class TestIfElseValidation:
    """Tests for if-else node validation."""

    def test_if_else_valid_operators(self):
        """Test if-else with valid operators."""
        workflow_data = {
            "nodes": [
                {"id": "start", "type": "start", "config": {}},
                {
                    "id": "ifelse",
                    "type": "if-else",
                    "config": {
                        "cases": [{"case_id": "c1", "conditions": [{"comparison_operator": "≥", "value": "1"}]}]
                    },
                },
                {"id": "t", "type": "llm", "config": {}},
                {"id": "f", "type": "llm", "config": {}},
                {"id": "end", "type": "end", "config": {}},
            ],
            "edges": [
                {"source": "start", "target": "ifelse"},
                {"source": "ifelse", "sourceHandle": "true", "target": "t"},
                {"source": "ifelse", "sourceHandle": "false", "target": "f"},
                {"source": "t", "target": "end"},
                {"source": "f", "target": "end"},
            ],
        }
        is_valid, hints = WorkflowValidator.validate(workflow_data, [])
        errors = [h for h in hints if h.severity == "error"]
        # The bare LLM nodes ("config": {}) may produce unrelated model-config
        # errors when no available_models are supplied, so assert only on
        # operator validation here.
        operator_errors = [h for h in errors if "operator" in h.message]
        assert len(operator_errors) == 0

    def test_if_else_invalid_operators(self):
        """Test if-else with invalid operators."""
        workflow_data = {
            "nodes": [
                {"id": "start", "type": "start", "config": {}},
                {
                    "id": "ifelse",
                    "type": "if-else",
                    "config": {
                        "cases": [{"case_id": "c1", "conditions": [{"comparison_operator": ">=", "value": "1"}]}]
                    },
                },
                {"id": "t", "type": "llm", "config": {}},
                {"id": "f", "type": "llm", "config": {}},
                {"id": "end", "type": "end", "config": {}},
            ],
            "edges": [
                {"source": "start", "target": "ifelse"},
                {"source": "ifelse", "sourceHandle": "true", "target": "t"},
                {"source": "ifelse", "sourceHandle": "false", "target": "f"},
                {"source": "t", "target": "end"},
                {"source": "f", "target": "end"},
            ],
        }
        is_valid, hints = WorkflowValidator.validate(workflow_data, [])
        operator_errors = [h for h in hints if "operator" in h.message and h.severity == "error"]
        assert len(operator_errors) > 0
        assert "≥" in operator_errors[0].suggestion
api/uv.lock (generated, 2 lines changed)
@@ -1368,7 +1368,7 @@ wheels = [

[[package]]
name = "dify-api"
-version = "1.12.0"
+version = "1.11.4"
source = { virtual = "." }
dependencies = [
    { name = "aliyun-log-python-sdk" },
@@ -1,5 +1,5 @@
#!/bin/bash
-set -euxo pipefail
+set -x

SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/../.."
@@ -21,7 +21,7 @@ services:

  # API service
  api:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.11.4
    restart: always
    environment:
      # Use the shared environment variables.

@@ -63,7 +63,7 @@ services:
  # worker service
  # The Celery worker for processing all queues (dataset, workflow, mail, etc.)
  worker:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.11.4
    restart: always
    environment:
      # Use the shared environment variables.

@@ -102,7 +102,7 @@ services:
  # worker_beat service
  # Celery beat for scheduling periodic tasks.
  worker_beat:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.11.4
    restart: always
    environment:
      # Use the shared environment variables.

@@ -132,7 +132,7 @@ services:

  # Frontend web application.
  web:
-    image: langgenius/dify-web:1.12.0
+    image: langgenius/dify-web:1.11.4
    restart: always
    environment:
      CONSOLE_API_URL: ${CONSOLE_API_URL:-}

@@ -662,14 +662,13 @@ services:
      - "${IRIS_SUPER_SERVER_PORT:-1972}:1972"
      - "${IRIS_WEB_SERVER_PORT:-52773}:52773"
    volumes:
-      - ./volumes/iris:/durable
+      - ./volumes/iris:/opt/iris
      - ./iris/iris-init.script:/iris-init.script
      - ./iris/docker-entrypoint.sh:/custom-entrypoint.sh
    entrypoint: ["/custom-entrypoint.sh"]
    tty: true
    environment:
      TZ: ${IRIS_TIMEZONE:-UTC}
-      ISC_DATA_DIRECTORY: /durable/iris

  # Oracle vector database
  oracle:
@@ -707,7 +707,7 @@ services:

  # API service
  api:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.11.4
    restart: always
    environment:
      # Use the shared environment variables.

@@ -749,7 +749,7 @@ services:
  # worker service
  # The Celery worker for processing all queues (dataset, workflow, mail, etc.)
  worker:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.11.4
    restart: always
    environment:
      # Use the shared environment variables.

@@ -788,7 +788,7 @@ services:
  # worker_beat service
  # Celery beat for scheduling periodic tasks.
  worker_beat:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.11.4
    restart: always
    environment:
      # Use the shared environment variables.

@@ -818,7 +818,7 @@ services:

  # Frontend web application.
  web:
-    image: langgenius/dify-web:1.12.0
+    image: langgenius/dify-web:1.11.4
    restart: always
    environment:
      CONSOLE_API_URL: ${CONSOLE_API_URL:-}

@@ -1348,14 +1348,13 @@ services:
      - "${IRIS_SUPER_SERVER_PORT:-1972}:1972"
      - "${IRIS_WEB_SERVER_PORT:-52773}:52773"
    volumes:
-      - ./volumes/iris:/durable
+      - ./volumes/iris:/opt/iris
      - ./iris/iris-init.script:/iris-init.script
      - ./iris/docker-entrypoint.sh:/custom-entrypoint.sh
    entrypoint: ["/custom-entrypoint.sh"]
    tty: true
    environment:
      TZ: ${IRIS_TIMEZONE:-UTC}
-      ISC_DATA_DIRECTORY: /durable/iris

  # Oracle vector database
  oracle:
@@ -1,33 +1,15 @@
#!/bin/bash
set -e

-# IRIS configuration flag file (stored in durable directory to persist with data)
-IRIS_CONFIG_DONE="/durable/.iris-configured"
-
-# Function to wait for IRIS to be ready
-wait_for_iris() {
-    echo "Waiting for IRIS to be ready..."
-    local max_attempts=30
-    local attempt=1
-    while [ "$attempt" -le "$max_attempts" ]; do
-        if iris qlist IRIS 2>/dev/null | grep -q "running"; then
-            echo "IRIS is ready."
-            return 0
-        fi
-        echo "Attempt $attempt/$max_attempts: IRIS not ready yet, waiting..."
-        sleep 2
-        attempt=$((attempt + 1))
-    done
-    echo "ERROR: IRIS failed to start within expected time." >&2
-    return 1
-}
+# IRIS configuration flag file
+IRIS_CONFIG_DONE="/opt/iris/.iris-configured"

# Function to configure IRIS
configure_iris() {
    echo "Configuring IRIS for first-time setup..."

-    # Wait for IRIS to be fully started
-    wait_for_iris
    sleep 5

    # Execute the initialization script
    iris session IRIS < /iris-init.script
web/__tests__/goto-anything/command-selector.test.tsx (new file, 326 lines)
@@ -0,0 +1,326 @@
|
||||
import type { ActionItem } from '../../app/components/goto-anything/actions/types'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import CommandSelector from '../../app/components/goto-anything/command-selector'
|
||||
|
||||
vi.mock('cmdk', () => ({
|
||||
Command: {
|
||||
Group: ({ children, className }: any) => <div className={className}>{children}</div>,
|
||||
Item: ({ children, onSelect, value, className }: any) => (
|
||||
<div
|
||||
className={className}
|
||||
onClick={() => onSelect?.()}
|
||||
data-value={value}
|
||||
data-testid={`command-item-${value}`}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
},
|
||||
}))
|
||||
|
||||
describe('CommandSelector', () => {
|
||||
const mockActions: Record<string, ActionItem> = {
|
||||
app: {
|
||||
key: '@app',
|
||||
shortcut: '@app',
|
||||
title: 'Search Applications',
|
||||
description: 'Search apps',
|
||||
search: vi.fn(),
|
||||
},
|
||||
knowledge: {
|
||||
key: '@knowledge',
|
||||
shortcut: '@kb',
|
||||
title: 'Search Knowledge',
|
||||
description: 'Search knowledge bases',
|
||||
search: vi.fn(),
|
||||
},
|
||||
plugin: {
|
||||
key: '@plugin',
|
||||
shortcut: '@plugin',
|
||||
title: 'Search Plugins',
|
||||
description: 'Search plugins',
|
||||
search: vi.fn(),
|
||||
},
|
||||
node: {
|
||||
key: '@node',
|
||||
shortcut: '@node',
|
||||
title: 'Search Nodes',
|
||||
description: 'Search workflow nodes',
|
||||
search: vi.fn(),
|
||||
},
|
||||
}
|
||||
|
||||
const mockOnCommandSelect = vi.fn()
|
||||
const mockOnCommandValueChange = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('Basic Rendering', () => {
|
||||
it('should render all actions when no filter is provided', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render empty filter as showing all actions', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter=""
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Filtering Functionality', () => {
|
||||
it('should filter actions based on searchFilter - single match', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="k"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should filter actions with multiple matches', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="p"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@kb')).not.toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should be case-insensitive when filtering', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="APP"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@kb')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should match partial strings', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="od"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@kb')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Empty State', () => {
|
||||
it('should show empty state when no matches found', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="xyz"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@kb')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
|
||||
expect(screen.getByText('app.gotoAnything.tryDifferentSearch')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show empty state when filter is empty', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter=""
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText('app.gotoAnything.noMatchingCommands')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Selection and Highlight Management', () => {
|
||||
it('should call onCommandValueChange when filter changes and first item differs', () => {
|
||||
const { rerender } = render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter=""
|
||||
commandValue="@app"
|
||||
onCommandValueChange={mockOnCommandValueChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
rerender(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="k"
|
||||
commandValue="@app"
|
||||
onCommandValueChange={mockOnCommandValueChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(mockOnCommandValueChange).toHaveBeenCalledWith('@kb')
|
||||
})
|
||||
|
||||
it('should not call onCommandValueChange if current value still exists', () => {
|
||||
const { rerender } = render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter=""
|
||||
commandValue="@app"
|
||||
onCommandValueChange={mockOnCommandValueChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
rerender(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="a"
|
||||
commandValue="@app"
|
||||
onCommandValueChange={mockOnCommandValueChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(mockOnCommandValueChange).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle onCommandSelect callback correctly', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="k"
|
||||
/>,
|
||||
)
|
||||
|
||||
const knowledgeItem = screen.getByTestId('command-item-@kb')
|
||||
fireEvent.click(knowledgeItem)
|
||||
|
||||
expect(mockOnCommandSelect).toHaveBeenCalledWith('@kb')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle empty actions object', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={{}}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter=""
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle special characters in filter', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="@"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle undefined onCommandValueChange gracefully', () => {
|
||||
const { rerender } = render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter=""
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(() => {
|
||||
rerender(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="k"
|
||||
/>,
|
||||
)
|
||||
}).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Backward Compatibility', () => {
|
||||
it('should work without searchFilter prop (backward compatible)', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should work without commandValue and onCommandValueChange props', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="k"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
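The cases above pin down CommandSelector's filtering contract without showing the component itself. One implementation consistent with every assertion is a case-insensitive substring match over each action's shortcut; a minimal sketch (the function name and its placement are illustrative, not the component's actual internals):

```typescript
import type { ActionItem } from '../../app/components/goto-anything/actions/types'

// Hypothetical filter consistent with the tests above: 'k' keeps only @kb,
// 'p' keeps @app and @plugin, 'APP' matches case-insensitively, '@' keeps
// everything, and 'xyz' leaves nothing (the empty state).
const filterActions = (actions: Record<string, ActionItem>, searchFilter = ''): ActionItem[] => {
  const needle = searchFilter.toLowerCase()
  return Object.values(actions).filter(action => action.shortcut.toLowerCase().includes(needle))
}
```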
web/__tests__/goto-anything/match-action.test.ts (new file, 236 lines)
@@ -0,0 +1,236 @@
|
||||
import type { Mock } from 'vitest'
|
||||
import type { ActionItem } from '../../app/components/goto-anything/actions/types'
|
||||
|
||||
// Import after mocking to get mocked version
|
||||
import { matchAction } from '../../app/components/goto-anything/actions'
|
||||
import { slashCommandRegistry } from '../../app/components/goto-anything/actions/commands/registry'
|
||||
|
||||
// Mock the entire actions module to avoid import issues
|
||||
vi.mock('../../app/components/goto-anything/actions', () => ({
|
||||
matchAction: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('../../app/components/goto-anything/actions/commands/registry')
|
||||
|
||||
// Implement the actual matchAction logic for testing
|
||||
const actualMatchAction = (query: string, actions: Record<string, ActionItem>) => {
|
||||
const result = Object.values(actions).find((action) => {
|
||||
// Special handling for slash commands
|
||||
if (action.key === '/') {
|
||||
// Get all registered commands from the registry
|
||||
const allCommands = slashCommandRegistry.getAllCommands()
|
||||
|
||||
// Check if query matches any registered command
|
||||
return allCommands.some((cmd) => {
|
||||
const cmdPattern = `/${cmd.name}`
|
||||
|
||||
// For direct mode commands, don't match (keep in command selector)
|
||||
if (cmd.mode === 'direct')
|
||||
return false
|
||||
|
||||
// For submenu mode commands, match when complete command is entered
|
||||
return query === cmdPattern || query.startsWith(`${cmdPattern} `)
|
||||
})
|
||||
}
|
||||
|
||||
const reg = new RegExp(`^(${action.key}|${action.shortcut})(?:\\s|$)`)
|
||||
return reg.test(query)
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
// Replace mock with actual implementation
|
||||
;(matchAction as Mock).mockImplementation(actualMatchAction)
|
||||
|
||||
describe('matchAction Logic', () => {
|
||||
const mockActions: Record<string, ActionItem> = {
|
||||
app: {
|
||||
key: '@app',
|
||||
shortcut: '@a',
|
||||
title: 'Search Applications',
|
||||
description: 'Search apps',
|
||||
search: vi.fn(),
|
||||
},
|
||||
knowledge: {
|
||||
key: '@knowledge',
|
||||
shortcut: '@kb',
|
||||
title: 'Search Knowledge',
|
||||
description: 'Search knowledge bases',
|
||||
search: vi.fn(),
|
||||
},
|
||||
slash: {
|
||||
key: '/',
|
||||
shortcut: '/',
|
||||
title: 'Commands',
|
||||
description: 'Execute commands',
|
||||
search: vi.fn(),
|
||||
},
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([
|
||||
{ name: 'docs', mode: 'direct' },
|
||||
{ name: 'community', mode: 'direct' },
|
||||
{ name: 'feedback', mode: 'direct' },
|
||||
{ name: 'account', mode: 'direct' },
|
||||
{ name: 'theme', mode: 'submenu' },
|
||||
{ name: 'language', mode: 'submenu' },
|
||||
])
|
||||
})
|
||||
|
||||
describe('@ Actions Matching', () => {
|
||||
it('should match @app with key', () => {
|
||||
const result = matchAction('@app', mockActions)
|
||||
expect(result).toBe(mockActions.app)
|
||||
})
|
||||
|
||||
it('should match @app with shortcut', () => {
|
||||
const result = matchAction('@a', mockActions)
|
||||
expect(result).toBe(mockActions.app)
|
||||
})
|
||||
|
||||
it('should match @knowledge with key', () => {
|
||||
const result = matchAction('@knowledge', mockActions)
|
||||
expect(result).toBe(mockActions.knowledge)
|
||||
})
|
||||
|
||||
it('should match @knowledge with shortcut @kb', () => {
|
||||
const result = matchAction('@kb', mockActions)
|
||||
expect(result).toBe(mockActions.knowledge)
|
||||
})
|
||||
|
||||
it('should match with text after action', () => {
|
||||
const result = matchAction('@app search term', mockActions)
|
||||
expect(result).toBe(mockActions.app)
|
||||
})
|
||||
|
||||
it('should not match partial @ actions', () => {
|
||||
const result = matchAction('@ap', mockActions)
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Slash Commands Matching', () => {
|
||||
describe('Direct Mode Commands', () => {
|
||||
it('should not match direct mode commands', () => {
|
||||
const result = matchAction('/docs', mockActions)
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should not match direct mode with arguments', () => {
|
||||
const result = matchAction('/docs something', mockActions)
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should not match any direct mode command', () => {
|
||||
expect(matchAction('/community', mockActions)).toBeUndefined()
|
||||
expect(matchAction('/feedback', mockActions)).toBeUndefined()
|
||||
expect(matchAction('/account', mockActions)).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Submenu Mode Commands', () => {
|
||||
it('should match submenu mode commands exactly', () => {
|
||||
const result = matchAction('/theme', mockActions)
|
||||
expect(result).toBe(mockActions.slash)
|
||||
})
|
||||
|
||||
it('should match submenu mode with arguments', () => {
|
||||
const result = matchAction('/theme dark', mockActions)
|
||||
expect(result).toBe(mockActions.slash)
|
||||
})
|
||||
|
||||
it('should match all submenu commands', () => {
|
||||
expect(matchAction('/language', mockActions)).toBe(mockActions.slash)
|
||||
expect(matchAction('/language en', mockActions)).toBe(mockActions.slash)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Slash Without Command', () => {
|
||||
it('should not match single slash', () => {
|
||||
const result = matchAction('/', mockActions)
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should not match unregistered commands', () => {
|
||||
const result = matchAction('/unknown', mockActions)
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle empty query', () => {
|
||||
const result = matchAction('', mockActions)
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle whitespace only', () => {
|
||||
const result = matchAction(' ', mockActions)
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle regular text without actions', () => {
|
||||
const result = matchAction('search something', mockActions)
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle special characters', () => {
|
||||
const result = matchAction('#tag', mockActions)
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle multiple @ or /', () => {
|
||||
expect(matchAction('@@app', mockActions)).toBeUndefined()
|
||||
expect(matchAction('//theme', mockActions)).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Mode-based Filtering', () => {
|
||||
it('should filter direct mode commands from matching', () => {
|
||||
;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([
|
||||
{ name: 'test', mode: 'direct' },
|
||||
])
|
||||
|
||||
const result = matchAction('/test', mockActions)
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should allow submenu mode commands to match', () => {
|
||||
;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([
|
||||
{ name: 'test', mode: 'submenu' },
|
||||
])
|
||||
|
||||
const result = matchAction('/test', mockActions)
|
||||
expect(result).toBe(mockActions.slash)
|
||||
})
|
||||
|
||||
it('should treat undefined mode as submenu', () => {
|
||||
;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([
|
||||
{ name: 'test' }, // No mode specified
|
||||
])
|
||||
|
||||
const result = matchAction('/test', mockActions)
|
||||
expect(result).toBe(mockActions.slash)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Registry Integration', () => {
|
||||
it('should call getAllCommands when matching slash', () => {
|
||||
matchAction('/theme', mockActions)
|
||||
expect(slashCommandRegistry.getAllCommands).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not call getAllCommands for @ actions', () => {
|
||||
matchAction('@app', mockActions)
|
||||
expect(slashCommandRegistry.getAllCommands).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle empty command list', () => {
|
||||
;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([])
|
||||
const result = matchAction('/anything', mockActions)
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
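The anchoring regex inside actualMatchAction is easy to misread, so here is the same pattern instantiated for key '@app' and shortcut '@a', traced on its own:

```typescript
// The token must start the query and be followed by whitespace or the end
// of the string, which is why '@ap' and 'x @app' both fail to match.
const reg = new RegExp('^(@app|@a)(?:\\s|$)')
console.log(reg.test('@app'))     // true: exact key
console.log(reg.test('@a term'))  // true: shortcut plus arguments
console.log(reg.test('@ap'))      // false: partial token
console.log(reg.test('x @app'))   // false: not at the start
```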
@@ -9,8 +9,10 @@ import type { MockedFunction } from 'vitest'
 * 4. Ensure errors don't propagate to UI layer causing "search failed"
 */

-import { appScope, knowledgeScope, pluginScope, searchAnything } from '@/app/components/goto-anything/actions'
-import { searchApps, searchDatasets, searchPlugins } from '@/service/use-goto-anything'
+import { Actions, searchAnything } from '@/app/components/goto-anything/actions'
+import { fetchAppList } from '@/service/apps'
+import { postMarketplace } from '@/service/base'
+import { fetchDatasets } from '@/service/datasets'

// Mock react-i18next before importing modules that use it
vi.mock('react-i18next', () => ({

@@ -20,17 +22,22 @@ vi.mock('react-i18next', () => ({
  }),
}))

-// Mock the new oRPC-based service functions
-vi.mock('@/service/use-goto-anything', () => ({
-  searchApps: vi.fn(),
-  searchDatasets: vi.fn(),
-  searchPlugins: vi.fn(),
+// Mock API functions
+vi.mock('@/service/base', () => ({
+  postMarketplace: vi.fn(),
}))

-const mockSearchApps = searchApps as MockedFunction<typeof searchApps>
-const mockSearchDatasets = searchDatasets as MockedFunction<typeof searchDatasets>
-const mockSearchPlugins = searchPlugins as MockedFunction<typeof searchPlugins>
-const searchScopes = [appScope, knowledgeScope, pluginScope]
+vi.mock('@/service/apps', () => ({
+  fetchAppList: vi.fn(),
+}))
+
+vi.mock('@/service/datasets', () => ({
+  fetchDatasets: vi.fn(),
+}))
+
+const mockPostMarketplace = postMarketplace as MockedFunction<typeof postMarketplace>
+const mockFetchAppList = fetchAppList as MockedFunction<typeof fetchAppList>
+const mockFetchDatasets = fetchDatasets as MockedFunction<typeof fetchDatasets>

describe('GotoAnything Search Error Handling', () => {
  beforeEach(() => {

@@ -48,33 +55,45 @@ describe('GotoAnything Search Error Handling', () => {
  describe('@plugin search error handling', () => {
    it('should return empty array when API fails instead of throwing error', async () => {
      // Mock marketplace API failure (403 permission denied)
-      mockSearchPlugins.mockRejectedValue(new Error('HTTP 403: Forbidden'))
+      mockPostMarketplace.mockRejectedValue(new Error('HTTP 403: Forbidden'))

-      const result = await pluginScope.search('@plugin', 'test', 'en')
+      const pluginAction = Actions.plugin
+
+      // Directly call plugin action's search method
+      const result = await pluginAction.search('@plugin', 'test', 'en')

      // Should return empty array instead of throwing error
      expect(result).toEqual([])
-      expect(mockSearchPlugins).toHaveBeenCalledWith('test')
+      expect(mockPostMarketplace).toHaveBeenCalledWith('/plugins/search/advanced', {
+        body: {
+          page: 1,
+          page_size: 10,
+          query: 'test',
+          type: 'plugin',
+        },
+      })
    })

    it('should return empty array when user has no plugin data', async () => {
      // Mock marketplace returning empty data
-      mockSearchPlugins.mockResolvedValue({
-        data: { plugins: [], total: 0 },
+      mockPostMarketplace.mockResolvedValue({
+        data: { plugins: [] },
      })

-      const result = await pluginScope.search('@plugin', '', 'en')
+      const pluginAction = Actions.plugin
+      const result = await pluginAction.search('@plugin', '', 'en')

      expect(result).toEqual([])
    })

    it('should return empty array when API returns unexpected data structure', async () => {
      // Mock API returning unexpected data structure
-      mockSearchPlugins.mockResolvedValue({
+      mockPostMarketplace.mockResolvedValue({
        data: null,
-      } as any)
+      })

-      const result = await pluginScope.search('@plugin', 'test', 'en')
+      const pluginAction = Actions.plugin
+      const result = await pluginAction.search('@plugin', 'test', 'en')

      expect(result).toEqual([])
    })

@@ -83,18 +102,20 @@ describe('GotoAnything Search Error Handling', () => {
  describe('Other search types error handling', () => {
    it('@app search should return empty array when API fails', async () => {
      // Mock app API failure
-      mockSearchApps.mockRejectedValue(new Error('API Error'))
+      mockFetchAppList.mockRejectedValue(new Error('API Error'))

-      const result = await appScope.search('@app', 'test', 'en')
+      const appAction = Actions.app
+      const result = await appAction.search('@app', 'test', 'en')

      expect(result).toEqual([])
    })

    it('@knowledge search should return empty array when API fails', async () => {
      // Mock knowledge API failure
-      mockSearchDatasets.mockRejectedValue(new Error('API Error'))
+      mockFetchDatasets.mockRejectedValue(new Error('API Error'))

-      const result = await knowledgeScope.search('@knowledge', 'test', 'en')
+      const knowledgeAction = Actions.knowledge
+      const result = await knowledgeAction.search('@knowledge', 'test', 'en')

      expect(result).toEqual([])
    })

@@ -103,11 +124,11 @@ describe('GotoAnything Search Error Handling', () => {
  describe('Unified search entry error handling', () => {
    it('regular search (without @prefix) should return successful results even when partial APIs fail', async () => {
      // Set app and knowledge success, plugin failure
-      mockSearchApps.mockResolvedValue({ data: [], has_more: false, limit: 10, page: 1, total: 0 })
-      mockSearchDatasets.mockResolvedValue({ data: [], has_more: false, limit: 10, page: 1, total: 0 })
-      mockSearchPlugins.mockRejectedValue(new Error('Plugin API failed'))
+      mockFetchAppList.mockResolvedValue({ data: [], has_more: false, limit: 10, page: 1, total: 0 })
+      mockFetchDatasets.mockResolvedValue({ data: [], has_more: false, limit: 10, page: 1, total: 0 })
+      mockPostMarketplace.mockRejectedValue(new Error('Plugin API failed'))

-      const result = await searchAnything('en', 'test', undefined, searchScopes)
+      const result = await searchAnything('en', 'test')

      // Should return successful results even if plugin search fails
      expect(result).toEqual([])

@@ -116,9 +137,10 @@ describe('GotoAnything Search Error Handling', () => {
    it('@plugin dedicated search should return empty array when API fails', async () => {
      // Mock plugin API failure
-      mockSearchPlugins.mockRejectedValue(new Error('Plugin service unavailable'))
+      mockPostMarketplace.mockRejectedValue(new Error('Plugin service unavailable'))

-      const result = await searchAnything('en', '@plugin test', pluginScope, searchScopes)
+      const pluginAction = Actions.plugin
+      const result = await searchAnything('en', '@plugin test', pluginAction)

      // Should return empty array instead of throwing error
      expect(result).toEqual([])

@@ -126,9 +148,10 @@ describe('GotoAnything Search Error Handling', () => {
    it('@app dedicated search should return empty array when API fails', async () => {
      // Mock app API failure
-      mockSearchApps.mockRejectedValue(new Error('App service unavailable'))
+      mockFetchAppList.mockRejectedValue(new Error('App service unavailable'))

-      const result = await searchAnything('en', '@app test', appScope, searchScopes)
+      const appAction = Actions.app
+      const result = await searchAnything('en', '@app test', appAction)

      expect(result).toEqual([])
    })

@@ -137,14 +160,14 @@ describe('GotoAnything Search Error Handling', () => {
  describe('Error handling consistency validation', () => {
    it('all search types should return empty array when encountering errors', async () => {
      // Mock all APIs to fail
-      mockSearchPlugins.mockRejectedValue(new Error('Plugin API failed'))
-      mockSearchApps.mockRejectedValue(new Error('App API failed'))
-      mockSearchDatasets.mockRejectedValue(new Error('Dataset API failed'))
+      mockPostMarketplace.mockRejectedValue(new Error('Plugin API failed'))
+      mockFetchAppList.mockRejectedValue(new Error('App API failed'))
+      mockFetchDatasets.mockRejectedValue(new Error('Dataset API failed'))

      const actions = [
-        { name: '@plugin', action: pluginScope },
-        { name: '@app', action: appScope },
-        { name: '@knowledge', action: knowledgeScope },
+        { name: '@plugin', action: Actions.plugin },
+        { name: '@app', action: Actions.app },
+        { name: '@knowledge', action: Actions.knowledge },
      ]

      for (const { name, action } of actions) {

@@ -156,9 +179,9 @@ describe('GotoAnything Search Error Handling', () => {
  describe('Edge case testing', () => {
    it('empty search term should be handled properly', async () => {
-      mockSearchPlugins.mockResolvedValue({ data: { plugins: [], total: 0 } })
+      mockPostMarketplace.mockResolvedValue({ data: { plugins: [] } })

-      const result = await searchAnything('en', '@plugin ', pluginScope, searchScopes)
+      const result = await searchAnything('en', '@plugin ', Actions.plugin)
      expect(result).toEqual([])
    })

@@ -166,17 +189,17 @@ describe('GotoAnything Search Error Handling', () => {
      const timeoutError = new Error('Network timeout')
      timeoutError.name = 'TimeoutError'

-      mockSearchPlugins.mockRejectedValue(timeoutError)
+      mockPostMarketplace.mockRejectedValue(timeoutError)

-      const result = await searchAnything('en', '@plugin test', pluginScope, searchScopes)
+      const result = await searchAnything('en', '@plugin test', Actions.plugin)
      expect(result).toEqual([])
    })

    it('JSON parsing errors should be handled correctly', async () => {
      const parseError = new SyntaxError('Unexpected token in JSON')
-      mockSearchPlugins.mockRejectedValue(parseError)
+      mockPostMarketplace.mockRejectedValue(parseError)

-      const result = await searchAnything('en', '@plugin test', pluginScope, searchScopes)
+      const result = await searchAnything('en', '@plugin test', Actions.plugin)
      expect(result).toEqual([])
    })
  })
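Every case above asserts the same contract: a failing service call resolves to an empty list rather than rejecting. The guard these tests imply looks roughly like this (a sketch of the pattern, not the actions' actual code):

```typescript
// Any rejection from the underlying service is swallowed so the UI never
// surfaces a 'search failed' state; callers always receive an array.
const safeSearch = async <T>(run: () => Promise<T[]>): Promise<T[]> => {
  try {
    return await run()
  }
  catch {
    return []
  }
}
```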
@@ -3,7 +3,7 @@
import type { ReactNode } from 'react'
import Cookies from 'js-cookie'
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
-import { parseAsBoolean, useQueryState } from 'nuqs'
+import { parseAsString, useQueryState } from 'nuqs'
import { useCallback, useEffect, useState } from 'react'
import {
  EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,

@@ -28,7 +28,7 @@ export const AppInitializer = ({
  const [init, setInit] = useState(false)
  const [oauthNewUser, setOauthNewUser] = useQueryState(
    'oauth_new_user',
-    parseAsBoolean.withOptions({ history: 'replace' }),
+    parseAsString.withOptions({ history: 'replace' }),
  )

  const isSetupFinished = useCallback(async () => {

@@ -46,7 +46,7 @@ export const AppInitializer = ({
    (async () => {
      const action = searchParams.get('action')

-      if (oauthNewUser) {
+      if (oauthNewUser === 'true') {
        let utmInfo = null
        const utmInfoStr = Cookies.get('utm_info')
        if (utmInfoStr) {
@@ -10,15 +10,9 @@ type VersionSelectorProps = {
  versionLen: number
  value: number
  onChange: (index: number) => void
-  contentClassName?: string
}

-const VersionSelector: React.FC<VersionSelectorProps> = ({
-  versionLen,
-  value,
-  onChange,
-  contentClassName,
-}) => {
+const VersionSelector: React.FC<VersionSelectorProps> = ({ versionLen, value, onChange }) => {
  const { t } = useTranslation()
  const [isOpen, {
    setFalse: handleOpenFalse,

@@ -70,7 +64,6 @@ const VersionSelector: React.FC<VersionSelectorProps> = ({
      </PortalToFollowElemTrigger>
      <PortalToFollowElemContent className={cn(
        'z-[99]',
-        contentClassName,
      )}
      >
        <div
@@ -62,19 +62,19 @@ const AppCard = ({
          {app.description}
        </div>
      </div>
-      {(canCreate || isTrialApp) && (
+      {canCreate && (
        <div className={cn('absolute bottom-0 left-0 right-0 hidden bg-gradient-to-t from-components-panel-gradient-2 from-[60.27%] to-transparent p-4 pt-8 group-hover:flex')}>
-          <div className={cn('grid h-8 w-full grid-cols-1 items-center space-x-2', canCreate && 'grid-cols-2')}>
-            {canCreate && (
-              <Button variant="primary" onClick={() => onCreate()}>
-                <PlusIcon className="mr-1 h-4 w-4" />
-                <span className="text-xs">{t('newApp.useTemplate', { ns: 'app' })}</span>
+          <div className={cn('grid h-8 w-full grid-cols-1 items-center space-x-2', isTrialApp && 'grid-cols-2')}>
+            <Button variant="primary" onClick={() => onCreate()}>
+              <PlusIcon className="mr-1 h-4 w-4" />
+              <span className="text-xs">{t('newApp.useTemplate', { ns: 'app' })}</span>
            </Button>
-            {isTrialApp && (
-              <Button onClick={showTryAPPPanel(app.app_id)}>
-                <RiInformation2Line className="mr-1 size-4" />
-                <span>{t('appCard.try', { ns: 'explore' })}</span>
-              </Button>
-            )}
+            <Button onClick={showTryAPPPanel(app.app_id)}>
+              <RiInformation2Line className="mr-1 size-4" />
+              <span>{t('appCard.try', { ns: 'explore' })}</span>
+            </Button>
          </div>
        </div>
      )}
@@ -3,6 +3,8 @@ import type { FC, ReactNode } from 'react'
import type { SliceProps } from './type'
import { autoUpdate, flip, FloatingFocusManager, offset, shift, useDismiss, useFloating, useHover, useInteractions, useRole } from '@floating-ui/react'
import { RiDeleteBinLine } from '@remixicon/react'
+// @ts-expect-error no types available
+import lineClamp from 'line-clamp'
import { useState } from 'react'
import ActionButton, { ActionButtonState } from '@/app/components/base/action-button'
import { cn } from '@/utils/classnames'

@@ -56,8 +58,12 @@ export const EditSlice: FC<EditSliceProps> = (props) => {
    <>
      <SliceContainer
        {...rest}
-        className={cn('mr-0 line-clamp-4 block', className)}
-        ref={refs.setReference}
+        className={cn('mr-0 block', className)}
+        ref={(ref) => {
+          refs.setReference(ref)
+          if (ref)
+            lineClamp(ref, 4)
+        }}
        {...getReferenceProps()}
      >
        <SliceLabel
@@ -74,15 +74,11 @@ const AppCard = ({
        </div>
        {isExplore && (canCreate || isTrialApp) && (
          <div className={cn('absolute bottom-0 left-0 right-0 hidden bg-gradient-to-t from-components-panel-gradient-2 from-[60.27%] to-transparent p-4 pt-8 group-hover:flex')}>
-            <div className={cn('grid h-8 w-full grid-cols-1 space-x-2', canCreate && 'grid-cols-2')}>
-              {
-                canCreate && (
-                  <Button variant="primary" className="h-7" onClick={() => onCreate()}>
-                    <PlusIcon className="mr-1 h-4 w-4" />
-                    <span className="text-xs">{t('appCard.addToWorkspace', { ns: 'explore' })}</span>
-                  </Button>
-                )
-              }
+            <div className={cn('grid h-8 w-full grid-cols-2 space-x-2')}>
+              <Button variant="primary" className="h-7" onClick={() => onCreate()}>
+                <PlusIcon className="mr-1 h-4 w-4" />
+                <span className="text-xs">{t('appCard.addToWorkspace', { ns: 'explore' })}</span>
+              </Button>
              <Button className="h-7" onClick={showTryAPPPanel(app.app_id)}>
                <RiInformation2Line className="mr-1 size-4" />
                <span>{t('appCard.try', { ns: 'explore' })}</span>
@@ -16,14 +16,6 @@ vi.mock('react-i18next', () => ({
  }),
}))

-vi.mock('@/config', async (importOriginal) => {
-  const actual = await importOriginal() as object
-  return {
-    ...actual,
-    IS_CLOUD_EDITION: true,
-  }
-})
-
const mockUseGetTryAppInfo = vi.fn()

vi.mock('@/service/use-try-app', () => ({
@@ -14,14 +14,6 @@ vi.mock('react-i18next', () => ({
  }),
}))

-vi.mock('@/config', async (importOriginal) => {
-  const actual = await importOriginal() as object
-  return {
-    ...actual,
-    IS_CLOUD_EDITION: true,
-  }
-})
-
describe('Tab', () => {
  afterEach(() => {
    cleanup()
@@ -1,10 +1,9 @@
-import type { AppSearchResult, ScopeDescriptor } from './types'
+import type { ActionItem, AppSearchResult } from './types'
import type { App } from '@/types/app'
-import { searchApps } from '@/service/use-goto-anything'
+import { fetchAppList } from '@/service/apps'
import { getRedirectionPath } from '@/utils/app-redirection'
import { AppTypeIcon } from '../../app/type-selector'
import AppIcon from '../../base/app-icon'
-import { ACTION_KEYS } from '../constants'

const parser = (apps: App[]): AppSearchResult[] => {
  return apps.map(app => ({

@@ -36,14 +35,21 @@ const parser = (apps: App[]): AppSearchResult[] => {
  }))
}

-export const appScope: ScopeDescriptor = {
-  id: 'app',
-  shortcut: ACTION_KEYS.APP,
+export const appAction: ActionItem = {
+  key: '@app',
+  shortcut: '@app',
  title: 'Search Applications',
  description: 'Search and navigate to your applications',
  // action,
  search: async (_, searchTerm = '', _locale) => {
    try {
-      const response = await searchApps(searchTerm)
+      const response = await fetchAppList({
+        url: 'apps',
+        params: {
+          page: 1,
+          name: searchTerm,
+        },
+      })
      const apps = response?.data || []
      return parser(apps)
    }
@@ -1,59 +0,0 @@
import type { SlashCommandHandler } from './types'
import { RiSparklingFill } from '@remixicon/react'
import * as React from 'react'
import { getI18n } from 'react-i18next'
import { isInWorkflowPage, VIBE_COMMAND_EVENT } from '@/app/components/workflow/constants'
import { registerCommands, unregisterCommands } from './command-bus'

type BananaDeps = Record<string, never>

const BANANA_PROMPT_EXAMPLE = 'Summarize a document, classify sentiment, then notify Slack'

const dispatchVibeCommand = (input?: string) => {
  if (typeof document === 'undefined')
    return

  document.dispatchEvent(new CustomEvent(VIBE_COMMAND_EVENT, { detail: { dsl: input } }))
}

export const bananaCommand: SlashCommandHandler<BananaDeps> = {
  name: 'banana',
  description: getI18n().t('gotoAnything.actions.vibeDesc', { ns: 'app' }),
  mode: 'submenu',
  isAvailable: () => isInWorkflowPage(),

  async search(args: string, locale: string = 'en') {
    const trimmed = args.trim()
    const hasInput = !!trimmed

    return [{
      id: 'banana-vibe',
      title: getI18n().t('gotoAnything.actions.vibeTitle', { ns: 'app', lng: locale }) || 'Banana',
      description: hasInput
        ? getI18n().t('gotoAnything.actions.vibeDesc', { ns: 'app', lng: locale })
        : getI18n().t('gotoAnything.actions.vibeHint', { ns: 'app', lng: locale, prompt: BANANA_PROMPT_EXAMPLE }),
      type: 'command' as const,
      icon: (
        <div className="flex h-6 w-6 items-center justify-center rounded-md border-[0.5px] border-divider-regular bg-components-panel-bg">
          <RiSparklingFill className="h-4 w-4 text-text-tertiary" />
        </div>
      ),
      data: {
        command: 'workflow.vibe',
        args: { dsl: trimmed },
      },
    }]
  },

  register(_deps: BananaDeps) {
    registerCommands({
      'workflow.vibe': async (args) => {
        dispatchVibeCommand(args?.dsl)
      },
    })
  },

  unregister() {
    unregisterCommands(['workflow.vibe'])
  },
}
@@ -9,7 +9,7 @@ export {
export { slashCommandRegistry, SlashCommandRegistry } from './registry'

// Command system exports
-export { slashScope } from './slash'
+export { slashAction } from './slash'
export { registerSlashCommands, SlashCommandProvider, unregisterSlashCommands } from './slash'

export type { SlashCommandHandler } from './types'
@@ -1,13 +1,12 @@
import type { CommandSearchResult } from '../types'
import type { SlashCommandHandler } from './types'
-import type { Locale } from '@/i18n-config/language'
import { getI18n } from 'react-i18next'
import { languages } from '@/i18n-config/language'
import { registerCommands, unregisterCommands } from './command-bus'

// Language dependency types
type LanguageDeps = {
-  setLocale?: (locale: Locale, reloadPage?: boolean) => Promise<void>
+  setLocale?: (locale: string) => Promise<void>
}

const buildLanguageCommands = (query: string): CommandSearchResult[] => {
@@ -6,21 +6,20 @@ import type { SlashCommandHandler } from './types'
 * Responsible for managing registration, lookup, and search of all slash commands
 */
export class SlashCommandRegistry {
-  private commands = new Map<string, SlashCommandHandler<unknown>>()
-  private commandDeps = new Map<string, unknown>()
+  private commands = new Map<string, SlashCommandHandler>()
+  private commandDeps = new Map<string, any>()

  /**
   * Register command handler
   */
-  register<TDeps = unknown>(handler: SlashCommandHandler<TDeps>, deps?: TDeps) {
+  register<TDeps = any>(handler: SlashCommandHandler<TDeps>, deps?: TDeps) {
    // Register main command name
-    // Cast to unknown first, then to SlashCommandHandler<unknown> to handle generic type variance
-    this.commands.set(handler.name, handler as SlashCommandHandler<unknown>)
+    this.commands.set(handler.name, handler)

    // Register aliases
    if (handler.aliases) {
      handler.aliases.forEach((alias) => {
-        this.commands.set(alias, handler as SlashCommandHandler<unknown>)
+        this.commands.set(alias, handler)
      })
    }

@@ -58,7 +57,7 @@ export class SlashCommandRegistry {
  /**
   * Find command handler
   */
-  findCommand(commandName: string): SlashCommandHandler<unknown> | undefined {
+  findCommand(commandName: string): SlashCommandHandler | undefined {
    return this.commands.get(commandName)
  }

@@ -66,7 +65,7 @@ export class SlashCommandRegistry {
  /**
   * Smart partial command matching
   * Prioritize alias matching, then match command name prefix
   */
-  private findBestPartialMatch(partialName: string): SlashCommandHandler<unknown> | undefined {
+  private findBestPartialMatch(partialName: string): SlashCommandHandler | undefined {
    const lowerPartial = partialName.toLowerCase()

    // First check if any alias starts with this

@@ -82,7 +81,7 @@ export class SlashCommandRegistry {
  /**
   * Find handler by alias prefix
   */
-  private findHandlerByAliasPrefix(prefix: string): SlashCommandHandler<unknown> | undefined {
+  private findHandlerByAliasPrefix(prefix: string): SlashCommandHandler | undefined {
    for (const handler of this.getAllCommands()) {
      if (handler.aliases?.some(alias => alias.toLowerCase().startsWith(prefix)))
        return handler

@@ -93,7 +92,7 @@ export class SlashCommandRegistry {
  /**
   * Find handler by name prefix
   */
-  private findHandlerByNamePrefix(prefix: string): SlashCommandHandler<unknown> | undefined {
+  private findHandlerByNamePrefix(prefix: string): SlashCommandHandler | undefined {
    return this.getAllCommands().find(handler =>
      handler.name.toLowerCase().startsWith(prefix),
    )

@@ -102,8 +101,8 @@ export class SlashCommandRegistry {
  /**
   * Get all registered commands (deduplicated)
   */
-  getAllCommands(): SlashCommandHandler<unknown>[] {
-    const uniqueCommands = new Map<string, SlashCommandHandler<unknown>>()
+  getAllCommands(): SlashCommandHandler[] {
+    const uniqueCommands = new Map<string, SlashCommandHandler>()
    this.commands.forEach((handler) => {
      uniqueCommands.set(handler.name, handler)
    })

@@ -114,7 +113,7 @@ export class SlashCommandRegistry {
  /**
   * Get all available commands in current context (deduplicated and filtered)
   * Commands without isAvailable method are considered always available
   */
-  getAvailableCommands(): SlashCommandHandler<unknown>[] {
+  getAvailableCommands(): SlashCommandHandler[] {
    return this.getAllCommands().filter(handler => this.isCommandAvailable(handler))
  }

@@ -229,7 +228,7 @@ export class SlashCommandRegistry {
  /**
   * Get command dependencies
   */
-  getCommandDependencies(commandName: string): unknown {
+  getCommandDependencies(commandName: string): any {
    return this.commandDeps.get(commandName)
  }

@@ -237,7 +236,7 @@ export class SlashCommandRegistry {
  /**
   * Determine if a command is available in the current context.
   * Defaults to true when a handler does not implement the guard.
   */
-  private isCommandAvailable(handler: SlashCommandHandler<unknown>) {
+  private isCommandAvailable(handler: SlashCommandHandler) {
    return handler.isAvailable?.() ?? true
  }
}
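For orientation, a minimal usage sketch of the registry API as it stands after this hunk (the demo handler and its fields are illustrative only):

```typescript
import { slashCommandRegistry } from './registry'

// Hypothetical handler; an empty search result keeps the sketch minimal.
const demoCommand = {
  name: 'demo',
  aliases: ['dm'],
  description: 'Demo command (hypothetical)',
  search: async (_args: string, _locale = 'en') => [],
}

slashCommandRegistry.register(demoCommand)  // registers 'demo' and alias 'dm'
slashCommandRegistry.findCommand('dm')      // resolves the handler via its alias
slashCommandRegistry.getAllCommands()       // deduplicated by handler.name
slashCommandRegistry.unregister('demo')     // removes the registration
```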
@@ -1,13 +1,11 @@
'use client'
-import type { ScopeDescriptor } from '../types'
-import type { SlashCommandDependencies } from './types'
+import type { ActionItem } from '../types'
import { useTheme } from 'next-themes'
import { useEffect } from 'react'
import { getI18n } from 'react-i18next'
import { setLocaleOnClient } from '@/i18n-config'
-import { ACTION_KEYS } from '../../constants'
import { accountCommand } from './account'
-import { bananaCommand } from './banana'
import { executeCommand } from './command-bus'
import { communityCommand } from './community'
import { docsCommand } from './docs'
import { forumCommand } from './forum'

@@ -18,11 +16,17 @@ import { zenCommand } from './zen'

const i18n = getI18n()

-export const slashScope: ScopeDescriptor = {
-  id: 'slash',
-  shortcut: ACTION_KEYS.SLASH,
+export const slashAction: ActionItem = {
+  key: '/',
+  shortcut: '/',
  title: i18n.t('gotoAnything.actions.slashTitle', { ns: 'app' }),
  description: i18n.t('gotoAnything.actions.slashDesc', { ns: 'app' }),
+  action: (result) => {
+    if (result.type !== 'command')
+      return
+    const { command, args } = result.data
+    executeCommand(command, args)
+  },
  search: async (query, _searchTerm = '') => {
    // Delegate all search logic to the command registry system
    return slashCommandRegistry.search(query, i18n.language)

@@ -30,7 +34,7 @@ export const slashAction: ActionItem = {
}

// Register/unregister default handlers for slash commands with external dependencies.
-export const registerSlashCommands = (deps: SlashCommandDependencies) => {
+export const registerSlashCommands = (deps: Record<string, any>) => {
  // Register command handlers to the registry system with their respective dependencies
  slashCommandRegistry.register(themeCommand, { setTheme: deps.setTheme })
  slashCommandRegistry.register(languageCommand, { setLocale: deps.setLocale })

@@ -39,7 +43,6 @@ export const registerSlashCommands = (deps: Record<string, any>) => {
  slashCommandRegistry.register(communityCommand, {})
  slashCommandRegistry.register(accountCommand, {})
  slashCommandRegistry.register(zenCommand, {})
-  slashCommandRegistry.register(bananaCommand, {})
}

export const unregisterSlashCommands = () => {

@@ -51,7 +54,6 @@ export const unregisterSlashCommands = () => {
  slashCommandRegistry.unregister('community')
  slashCommandRegistry.unregister('account')
  slashCommandRegistry.unregister('zen')
-  slashCommandRegistry.unregister('banana')
}

export const SlashCommandProvider = () => {
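A sketch of how a caller wires these registrations up, matching the deps consumed above (the setTheme/setLocale bodies are placeholders, not the app's actual wiring):

```typescript
import { registerSlashCommands, unregisterSlashCommands } from './slash'

// deps is a plain Record<string, any> after this change; only the theme and
// language handlers actually read entries from it.
registerSlashCommands({
  setTheme: (_value: string) => { /* e.g. next-themes setTheme */ },
  setLocale: async (_locale: string) => { /* e.g. setLocaleOnClient */ },
})

// On teardown, drop every default handler again:
unregisterSlashCommands()
```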
@@ -1,11 +1,10 @@
import type { CommandSearchResult } from '../types'
import type { Locale } from '@/i18n-config/language'

/**
 * Slash command handler interface
 * Each slash command should implement this interface
 */
export type SlashCommandHandler<TDeps = unknown> = {
export type SlashCommandHandler<TDeps = any> = {
  /** Command name (e.g., 'theme', 'language') */
  name: string

@@ -52,31 +51,3 @@ export type SlashCommandHandler<TDeps = unknown> = {
   */
  unregister?: () => void
}

/**
 * Theme command dependencies
 */
export type ThemeCommandDeps = {
  setTheme?: (value: 'light' | 'dark' | 'system') => void
}

/**
 * Language command dependencies
 */
export type LanguageCommandDeps = {
  setLocale?: (locale: Locale, reloadPage?: boolean) => Promise<void>
}

/**
 * Commands without external dependencies
 */
export type NoDepsCommandDeps = Record<string, never>

/**
 * Union type of all slash command dependencies
 * Used for type-safe dependency injection in registerSlashCommands
 */
export type SlashCommandDependencies = {
  setTheme?: (value: 'light' | 'dark' | 'system') => void
  setLocale?: (locale: Locale, reloadPage?: boolean) => Promise<void>
}
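The hunk above swaps the handler's generic default from `TDeps = unknown` back to `TDeps = any` and deletes the per-command dependency types. For reference, a minimal sketch of how the removed typed variant linked a handler to its dependencies; the handler below is illustrative and not part of this diff, and it assumes the `SlashCommandHandler` shape documented later in this comparison (name, description, search, register, optional unregister):

```typescript
// Hypothetical handler, typed against the (removed) ThemeCommandDeps.
import type { SlashCommandHandler, ThemeCommandDeps } from './types'

export const themeHandler: SlashCommandHandler<ThemeCommandDeps> = {
  name: 'theme',
  description: 'Switch the color theme',
  async search() {
    // A real handler would return CommandSearchResult entries here.
    return []
  },
  register(deps) {
    // With TDeps = ThemeCommandDeps, deps.setTheme is fully typed:
    // ((value: 'light' | 'dark' | 'system') => void) | undefined.
    deps.setTheme?.('dark')
  },
  unregister() {},
}
```

With `TDeps = any`, the same call sites still compile, but typos in the dependency object are no longer caught at compile time.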
@@ -1,59 +0,0 @@
import type { SlashCommandHandler } from './types'
import { RiSparklingFill } from '@remixicon/react'
import * as React from 'react'
import { getI18n } from 'react-i18next'
import { isInWorkflowPage, VIBE_COMMAND_EVENT } from '@/app/components/workflow/constants'
import { registerCommands, unregisterCommands } from './command-bus'

type VibeDeps = Record<string, never>

const VIBE_PROMPT_EXAMPLE = 'Summarize a document, classify sentiment, then notify Slack'

const dispatchVibeCommand = (input?: string) => {
  if (typeof document === 'undefined')
    return

  document.dispatchEvent(new CustomEvent(VIBE_COMMAND_EVENT, { detail: { dsl: input } }))
}

export const vibeCommand: SlashCommandHandler<VibeDeps> = {
  name: 'vibe',
  description: getI18n().t('gotoAnything.actions.vibeDesc', { ns: 'app' }),
  mode: 'submenu',
  isAvailable: () => isInWorkflowPage(),

  async search(args: string, locale: string = 'en') {
    const trimmed = args.trim()
    const hasInput = !!trimmed

    return [{
      id: 'vibe',
      title: getI18n().t('gotoAnything.actions.vibeTitle', { ns: 'app', lng: locale }) || 'Vibe',
      description: hasInput
        ? getI18n().t('gotoAnything.actions.vibeDesc', { ns: 'app', lng: locale })
        : getI18n().t('gotoAnything.actions.vibeHint', { ns: 'app', lng: locale, prompt: VIBE_PROMPT_EXAMPLE }),
      type: 'command' as const,
      icon: (
        <div className="flex h-6 w-6 items-center justify-center rounded-md border-[0.5px] border-divider-regular bg-components-panel-bg">
          <RiSparklingFill className="h-4 w-4 text-text-tertiary" />
        </div>
      ),
      data: {
        command: 'workflow.vibe',
        args: { dsl: trimmed },
      },
    }]
  },

  register(_deps: VibeDeps) {
    registerCommands({
      'workflow.vibe': async (args) => {
        dispatchVibeCommand(args?.dsl)
      },
    })
  },

  unregister() {
    unregisterCommands(['workflow.vibe'])
  },
}
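The deleted `vibe` command talked to the workflow canvas through a DOM `CustomEvent` rather than a direct import, which kept the command registry decoupled from workflow state. A minimal sketch of the receiving side of that pattern, assuming the same `VIBE_COMMAND_EVENT` constant; the hook name is illustrative, not from the diff:

```typescript
import { useEffect } from 'react'
import { VIBE_COMMAND_EVENT } from '@/app/components/workflow/constants'

// Hypothetical listener hook: subscribes to the event that
// dispatchVibeCommand (above) fired on the document.
export const useVibeCommandListener = (onVibe: (dsl?: string) => void) => {
  useEffect(() => {
    const handler = (event: Event) => {
      const { dsl } = (event as CustomEvent<{ dsl?: string }>).detail ?? {}
      onVibe(dsl)
    }
    document.addEventListener(VIBE_COMMAND_EVENT, handler)
    return () => document.removeEventListener(VIBE_COMMAND_EVENT, handler)
  }, [onVibe])
}
```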
@@ -3,66 +3,228 @@
 *
 * This file defines the action registry for the goto-anything search system.
 * Actions handle different types of searches: apps, knowledge bases, plugins, workflow nodes, and commands.
 *
 * ## How to Add a New Slash Command
 *
 * 1. **Create Command Handler File** (in `./commands/` directory):
 * ```typescript
 * // commands/my-command.ts
 * import type { SlashCommandHandler } from './types'
 * import type { CommandSearchResult } from '../types'
 * import { registerCommands, unregisterCommands } from './command-bus'
 *
 * interface MyCommandDeps {
 *   myService?: (data: any) => Promise<void>
 * }
 *
 * export const myCommand: SlashCommandHandler<MyCommandDeps> = {
 *   name: 'mycommand',
 *   aliases: ['mc'], // Optional aliases
 *   description: 'My custom command description',
 *
 *   async search(args: string, locale: string = 'en') {
 *     // Return search results based on args
 *     return [{
 *       id: 'my-result',
 *       title: 'My Command Result',
 *       description: 'Description of the result',
 *       type: 'command' as const,
 *       data: { command: 'my.action', args: { value: args } }
 *     }]
 *   },
 *
 *   register(deps: MyCommandDeps) {
 *     registerCommands({
 *       'my.action': async (args) => {
 *         await deps.myService?.(args?.value)
 *       }
 *     })
 *   },
 *
 *   unregister() {
 *     unregisterCommands(['my.action'])
 *   }
 * }
 * ```
 *
 * **Example for Self-Contained Command (no external dependencies):**
 * ```typescript
 * // commands/calculator-command.ts
 * export const calculatorCommand: SlashCommandHandler = {
 *   name: 'calc',
 *   aliases: ['calculator'],
 *   description: 'Simple calculator',
 *
 *   async search(args: string) {
 *     if (!args.trim()) return []
 *     try {
 *       // Safe math evaluation (implement proper parser in real use)
 *       const result = Function('"use strict"; return (' + args + ')')()
 *       return [{
 *         id: 'calc-result',
 *         title: `${args} = ${result}`,
 *         description: 'Calculator result',
 *         type: 'command' as const,
 *         data: { command: 'calc.copy', args: { result: result.toString() } }
 *       }]
 *     } catch {
 *       return [{
 *         id: 'calc-error',
 *         title: 'Invalid expression',
 *         description: 'Please enter a valid math expression',
 *         type: 'command' as const,
 *         data: { command: 'calc.noop', args: {} }
 *       }]
 *     }
 *   },
 *
 *   register() {
 *     registerCommands({
 *       'calc.copy': (args) => navigator.clipboard.writeText(args.result),
 *       'calc.noop': () => {} // No operation
 *     })
 *   },
 *
 *   unregister() {
 *     unregisterCommands(['calc.copy', 'calc.noop'])
 *   }
 * }
 * ```
 *
 * 2. **Register Command** (in `./commands/slash.tsx`):
 * ```typescript
 * import { myCommand } from './my-command'
 * import { calculatorCommand } from './calculator-command' // For self-contained commands
 *
 * export const registerSlashCommands = (deps: Record<string, any>) => {
 *   slashCommandRegistry.register(themeCommand, { setTheme: deps.setTheme })
 *   slashCommandRegistry.register(languageCommand, { setLocale: deps.setLocale })
 *   slashCommandRegistry.register(myCommand, { myService: deps.myService }) // With dependencies
 *   slashCommandRegistry.register(calculatorCommand) // Self-contained, no dependencies
 * }
 *
 * export const unregisterSlashCommands = () => {
 *   slashCommandRegistry.unregister('theme')
 *   slashCommandRegistry.unregister('language')
 *   slashCommandRegistry.unregister('mycommand')
 *   slashCommandRegistry.unregister('calc') // Add this line
 * }
 * ```
 *
 *
 * 3. **Update SlashCommandProvider** (in `./commands/slash.tsx`):
 * ```typescript
 * export const SlashCommandProvider = () => {
 *   const theme = useTheme()
 *   const myService = useMyService() // Add external dependency if needed
 *
 *   useEffect(() => {
 *     registerSlashCommands({
 *       setTheme: theme.setTheme, // Required for theme command
 *       setLocale: setLocaleOnClient, // Required for language command
 *       myService: myService, // Required for your custom command
 *       // Note: calculatorCommand doesn't need dependencies, so not listed here
 *     })
 *     return () => unregisterSlashCommands()
 *   }, [theme.setTheme, myService]) // Update dependency array for all dynamic deps
 *
 *   return null
 * }
 * ```
 *
 * **Note:** Self-contained commands (like calculator) don't require dependencies but are
 * still registered through the same system for consistent lifecycle management.
 *
 * 4. **Usage**: Users can now type `/mycommand` or `/mc` to use your command
 *
 * ## Command System Architecture
 * - Commands are registered via `SlashCommandRegistry`
 * - Each command is self-contained with its own dependencies
 * - Commands support aliases for easier access
 * - Command execution is handled by the command bus system
 * - All commands should be registered through `SlashCommandProvider` for consistent lifecycle management
 *
 * ## Command Types
 * **Commands with External Dependencies:**
 * - Require external services, APIs, or React hooks
 * - Must provide dependencies in `SlashCommandProvider`
 * - Example: theme commands (needs useTheme), API commands (needs service)
 *
 * **Self-Contained Commands:**
 * - Pure logic operations, no external dependencies
 * - Still recommended to register through `SlashCommandProvider` for consistency
 * - Example: calculator, text manipulation commands
 *
 * ## Available Actions
 * - `@app` - Search applications
 * - `@knowledge` / `@kb` - Search knowledge bases
 * - `@plugin` - Search plugins
 * - `@node` - Search workflow nodes (workflow pages only)
 * - `/` - Execute slash commands (theme, language, etc.)
 */

import type { ScopeContext, ScopeDescriptor, SearchResult } from './types'
import { ACTION_KEYS } from '../constants'
import { appScope } from './app'
import { slashScope } from './commands'
import type { ActionItem, SearchResult } from './types'
import { appAction } from './app'
import { slashAction } from './commands'
import { slashCommandRegistry } from './commands/registry'
import { knowledgeScope } from './knowledge'
import { pluginScope } from './plugin'
import { registerRagPipelineNodeScope } from './rag-pipeline-nodes'
import { scopeRegistry, useScopeRegistry } from './scope-registry'
import { registerWorkflowNodeScope } from './workflow-nodes'
import { knowledgeAction } from './knowledge'
import { pluginAction } from './plugin'
import { ragPipelineNodesAction } from './rag-pipeline-nodes'
import { workflowNodesAction } from './workflow-nodes'

let scopesInitialized = false
// Create dynamic Actions based on context
export const createActions = (isWorkflowPage: boolean, isRagPipelinePage: boolean) => {
  const baseActions = {
    slash: slashAction,
    app: appAction,
    knowledge: knowledgeAction,
    plugin: pluginAction,
  }

export const initGotoAnythingScopes = () => {
  if (scopesInitialized)
    return
  // Add appropriate node search based on context
  if (isRagPipelinePage) {
    return {
      ...baseActions,
      node: ragPipelineNodesAction,
    }
  }
  else if (isWorkflowPage) {
    return {
      ...baseActions,
      node: workflowNodesAction,
    }
  }

  scopesInitialized = true

  scopeRegistry.register(slashScope)
  scopeRegistry.register(appScope)
  scopeRegistry.register(knowledgeScope)
  scopeRegistry.register(pluginScope)
  registerWorkflowNodeScope()
  registerRagPipelineNodeScope()
  // Default actions without node search
  return baseActions
}

export const useGotoAnythingScopes = (context: ScopeContext) => {
  initGotoAnythingScopes()
  return useScopeRegistry(context)
// Legacy export for backward compatibility
export const Actions = {
  slash: slashAction,
  app: appAction,
  knowledge: knowledgeAction,
  plugin: pluginAction,
  node: workflowNodesAction,
}

const isSlashScope = (scope: ScopeDescriptor) => {
  if (scope.shortcut === ACTION_KEYS.SLASH)
    return true
  return scope.aliases?.includes(ACTION_KEYS.SLASH) ?? false
}

const getScopeShortcuts = (scope: ScopeDescriptor) => [scope.shortcut, ...(scope.aliases ?? [])]

export const searchAnything = async (
  locale: string,
  query: string,
  scope: ScopeDescriptor | undefined,
  scopes: ScopeDescriptor[],
  actionItem?: ActionItem,
  dynamicActions?: Record<string, ActionItem>,
): Promise<SearchResult[]> => {
  const trimmedQuery = query.trim()

  if (scope) {
  if (actionItem) {
    const escapeRegExp = (value: string) => value.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')
    const shortcuts = getScopeShortcuts(scope).map(escapeRegExp)
    const prefixPattern = new RegExp(`^(${shortcuts.join('|')})\\s*`)
    const prefixPattern = new RegExp(`^(${escapeRegExp(actionItem.key)}|${escapeRegExp(actionItem.shortcut)})\\s*`)
    const searchTerm = trimmedQuery.replace(prefixPattern, '').trim()
    try {
      return await scope.search(query, searchTerm, locale)
      return await actionItem.search(query, searchTerm, locale)
    }
    catch (error) {
      console.warn(`Search failed for ${scope.id}:`, error)
      console.warn(`Search failed for ${actionItem.key}:`, error)
      return []
    }
  }
@@ -70,19 +232,19 @@ export const searchAnything = async (
  if (trimmedQuery.startsWith('@') || trimmedQuery.startsWith('/'))
    return []

  // Filter out slash commands from general search
  const searchScopes = scopes.filter(scope => !isSlashScope(scope))
  const globalSearchActions = Object.values(dynamicActions || Actions)
  // Exclude slash commands from general search results
    .filter(action => action.key !== '/')

  // Use Promise.allSettled to handle partial failures gracefully
  const searchPromises = searchScopes.map(async (action) => {
    const actionId = action.id
  const searchPromises = globalSearchActions.map(async (action) => {
    try {
      const results = await action.search(query, query, locale)
      return { success: true, data: results, actionType: actionId }
      return { success: true, data: results, actionType: action.key }
    }
    catch (error) {
      console.warn(`Search failed for ${actionId}:`, error)
      return { success: false, data: [], actionType: actionId, error }
      console.warn(`Search failed for ${action.key}:`, error)
      return { success: false, data: [], actionType: action.key, error }
    }
  })

@@ -96,7 +258,7 @@ export const searchAnything = async (
      allResults.push(...result.value.data)
    }
    else {
      const actionKey = searchScopes[index]?.id || 'unknown'
      const actionKey = globalSearchActions[index]?.key || 'unknown'
      failedActions.push(actionKey)
    }
  })
@@ -107,31 +269,31 @@ export const searchAnything = async (
  return allResults
}

// ...

export const matchAction = (query: string, scopes: ScopeDescriptor[]) => {
  const escapeRegExp = (value: string) => value.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')
  return scopes.find((scope) => {
export const matchAction = (query: string, actions: Record<string, ActionItem>) => {
  return Object.values(actions).find((action) => {
    // Special handling for slash commands
    if (isSlashScope(scope)) {
    if (action.key === '/') {
      // Get all registered commands from the registry
      const allCommands = slashCommandRegistry.getAllCommands()

      // Check if query matches any registered command
      return allCommands.some((cmd) => {
        const cmdPattern = `/${cmd.name}`

        // For direct mode commands, don't match (keep in command selector)
        if (cmd.mode === 'direct')
          return false

        // For submenu mode commands, match when complete command is entered
        return query === cmdPattern || query.startsWith(`${cmdPattern} `)
      })
    }

    // Check if query matches shortcut (exact or prefix)
    // Only match if it's the full shortcut followed by space
    const shortcuts = getScopeShortcuts(scope).map(escapeRegExp)
    const reg = new RegExp(`^(${shortcuts.join('|')})(?:\\s|$)`)
    const reg = new RegExp(`^(${action.key}|${action.shortcut})(?:\\s|$)`)
    return reg.test(query)
  })
}

export * from './commands'
export * from './scope-registry'
export * from './types'
export { appScope, knowledgeScope, pluginScope }
export { appAction, knowledgeAction, pluginAction, workflowNodesAction }
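Both sides of the `matchAction` hunk above resolve a query by testing its prefix against each entry's shortcut; note that the removed scope-based variant escapes regex metacharacters before interpolation, while the restored `ActionItem` variant interpolates `action.key` and `action.shortcut` into the pattern directly. A standalone illustration of the prefix rule with escaping (not code from the diff):

```typescript
// The rule: a query matches a shortcut only when the full shortcut is
// followed by whitespace or end-of-string.
const matches = (query: string, shortcut: string): boolean => {
  const escaped = shortcut.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')
  return new RegExp(`^(${escaped})(?:\\s|$)`).test(query)
}

console.log(matches('@app', '@app'))         // true  (exact)
console.log(matches('@app foo', '@app'))     // true  (shortcut + space)
console.log(matches('@application', '@app')) // false (prefix only)
```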
@@ -1,9 +1,8 @@
import type { KnowledgeSearchResult, ScopeDescriptor } from './types'
import type { ActionItem, KnowledgeSearchResult } from './types'
import type { DataSet } from '@/models/datasets'
import { searchDatasets } from '@/service/use-goto-anything'
import { fetchDatasets } from '@/service/datasets'
import { cn } from '@/utils/classnames'
import { Folder } from '../../base/icons/src/vender/solid/files'
import { ACTION_KEYS } from '../constants'

const EXTERNAL_PROVIDER = 'external' as const
const isExternalProvider = (provider: string): boolean => provider === EXTERNAL_PROVIDER
@@ -31,15 +30,22 @@ const parser = (datasets: DataSet[]): KnowledgeSearchResult[] => {
  })
}

export const knowledgeScope: ScopeDescriptor = {
  id: 'knowledge',
  shortcut: ACTION_KEYS.KNOWLEDGE,
  aliases: ['@kb'],
export const knowledgeAction: ActionItem = {
  key: '@knowledge',
  shortcut: '@kb',
  title: 'Search Knowledge Bases',
  description: 'Search and navigate to your knowledge bases',
  // action,
  search: async (_, searchTerm = '', _locale) => {
    try {
      const response = await searchDatasets(searchTerm)
      const response = await fetchDatasets({
        url: '/datasets',
        params: {
          page: 1,
          limit: 10,
          keyword: searchTerm,
        },
      })
      const datasets = response?.data || []
      return parser(datasets)
    }
@@ -1,10 +1,9 @@
import type { Plugin } from '../../plugins/types'
import type { PluginSearchResult, ScopeDescriptor } from './types'
import type { Plugin, PluginsFromMarketplaceResponse } from '../../plugins/types'
import type { ActionItem, PluginSearchResult } from './types'
import { renderI18nObject } from '@/i18n-config'
import { searchPlugins } from '@/service/use-goto-anything'
import { postMarketplace } from '@/service/base'
import Icon from '../../plugins/card/base/card-icon'
import { getPluginIconInMarketplace } from '../../plugins/marketplace/utils'
import { ACTION_KEYS } from '../constants'

const parser = (plugins: Plugin[], locale: string): PluginSearchResult[] => {
  return plugins.map((plugin) => {
@@ -19,14 +18,21 @@ const parser = (plugins: Plugin[], locale: string): PluginSearchResult[] => {
  })
}

export const pluginScope: ScopeDescriptor = {
  id: 'plugin',
  shortcut: ACTION_KEYS.PLUGIN,
export const pluginAction: ActionItem = {
  key: '@plugin',
  shortcut: '@plugin',
  title: 'Search Plugins',
  description: 'Search and navigate to your plugins',
  search: async (_, searchTerm = '', locale) => {
    try {
      const response = await searchPlugins(searchTerm)
      const response = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>('/plugins/search/advanced', {
        body: {
          page: 1,
          page_size: 10,
          query: searchTerm,
          type: 'plugin',
        },
      })

      if (!response?.data?.plugins) {
        console.warn('Plugin search: Unexpected response structure', response)
@@ -1,41 +1,24 @@
import type { ScopeSearchHandler } from './scope-registry'
import type { SearchResult } from './types'
import { ACTION_KEYS } from '../constants'
import { scopeRegistry } from './scope-registry'
import type { ActionItem } from './types'

const scopeId = 'rag-pipeline-node'
let scopeRegistered = false

const buildSearchHandler = (searchFn?: (searchTerm: string) => SearchResult[]): ScopeSearchHandler => {
  return async (_, searchTerm = '', _locale) => {
// Create the RAG pipeline nodes action
export const ragPipelineNodesAction: ActionItem = {
  key: '@node',
  shortcut: '@node',
  title: 'Search RAG Pipeline Nodes',
  description: 'Find and jump to nodes in the current RAG pipeline by name or type',
  searchFn: undefined, // Will be set by useRagPipelineSearch hook
  search: async (_, searchTerm = '', _locale) => {
    try {
      if (searchFn)
        return searchFn(searchTerm)
      // Use the searchFn if available (set by useRagPipelineSearch hook)
      if (ragPipelineNodesAction.searchFn)
        return ragPipelineNodesAction.searchFn(searchTerm)

      // If not in RAG pipeline context, return empty array
      return []
    }
    catch (error) {
      console.warn('RAG pipeline nodes search failed:', error)
      return []
    }
  }
}

export const registerRagPipelineNodeScope = () => {
  if (scopeRegistered)
    return

  scopeRegistered = true
  scopeRegistry.register({
    id: scopeId,
    shortcut: ACTION_KEYS.NODE,
    title: 'Search RAG Pipeline Nodes',
    description: 'Find and jump to nodes in the current RAG pipeline by name or type',
    isAvailable: context => context.isRagPipelinePage,
    search: buildSearchHandler(),
  })
}

export const setRagPipelineNodesSearchFn = (fn: (searchTerm: string) => SearchResult[]) => {
  registerRagPipelineNodeScope()
  scopeRegistry.updateSearchHandler(scopeId, buildSearchHandler(fn))
  },
}
@@ -1,123 +0,0 @@
import type { SearchResult } from './types'

import { useCallback, useMemo, useSyncExternalStore } from 'react'

export type ScopeContext = {
  isWorkflowPage: boolean
  isRagPipelinePage: boolean
  isAdmin?: boolean
}

export type ScopeSearchHandler = (
  query: string,
  searchTerm: string,
  locale?: string,
) => Promise<SearchResult[]> | SearchResult[]

export type ScopeDescriptor = {
  /**
   * Unique identifier for the scope (e.g. 'app', 'plugin')
   */
  id: string
  /**
   * Shortcut to trigger this scope (e.g. '@app')
   */
  shortcut: string
  /**
   * Additional shortcuts that map to this scope (e.g. ['@kb'])
   */
  aliases?: string[]
  /**
   * I18n key or string for the scope title
   */
  title: string
  /**
   * Description for help text
   */
  description: string
  /**
   * Search handler function
   */
  search: ScopeSearchHandler
  /**
   * Predicate to check if this scope is available in current context
   */
  isAvailable?: (context: ScopeContext) => boolean
}

type Listener = () => void

class ScopeRegistry {
  private scopes: Map<string, ScopeDescriptor> = new Map()
  private listeners: Set<Listener> = new Set()
  private version = 0

  register(scope: ScopeDescriptor) {
    this.scopes.set(scope.id, scope)
    this.notify()
  }

  unregister(id: string) {
    if (this.scopes.delete(id))
      this.notify()
  }

  getScope(id: string) {
    return this.scopes.get(id)
  }

  getScopes(context: ScopeContext): ScopeDescriptor[] {
    return Array.from(this.scopes.values())
      .filter(scope => !scope.isAvailable || scope.isAvailable(context))
      .sort((a, b) => a.shortcut.localeCompare(b.shortcut))
  }

  updateSearchHandler(id: string, search: ScopeSearchHandler) {
    const scope = this.scopes.get(id)
    if (!scope)
      return
    this.scopes.set(id, { ...scope, search })
    this.notify()
  }

  getVersion() {
    return this.version
  }

  subscribe(listener: Listener) {
    this.listeners.add(listener)
    return () => {
      this.listeners.delete(listener)
    }
  }

  private notify() {
    this.version += 1
    this.listeners.forEach(listener => listener())
  }
}

export const scopeRegistry = new ScopeRegistry()

export const useScopeRegistry = (context: ScopeContext) => {
  const subscribe = useCallback(
    (listener: Listener) => scopeRegistry.subscribe(listener),
    [],
  )

  const getSnapshot = useCallback(
    () => scopeRegistry.getVersion(),
    [],
  )

  const version = useSyncExternalStore(
    subscribe,
    getSnapshot,
    getSnapshot,
  )

  return useMemo(
    () => scopeRegistry.getScopes(context),
    [version, context.isWorkflowPage, context.isRagPipelinePage, context.isAdmin],
  )
}
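The removed `ScopeRegistry` follows React's external-store contract: a mutable store exposes `subscribe` plus a cheap snapshot (here, a monotonically increasing version number), and `useSyncExternalStore` re-renders subscribers whenever the snapshot changes. A minimal standalone sketch of that contract, with illustrative names, showing why bumping a version counter is enough to invalidate derived data:

```typescript
import { useSyncExternalStore } from 'react'

// Minimal external store with the same subscribe/version shape as the
// ScopeRegistry deleted above.
const createVersionedStore = () => {
  let version = 0
  const listeners = new Set<() => void>()
  return {
    // Call after any mutation; notifies every subscribed component.
    bump: () => {
      version += 1
      listeners.forEach(listener => listener())
    },
    subscribe: (listener: () => void) => {
      listeners.add(listener)
      return () => { listeners.delete(listener) }
    },
    getVersion: () => version,
  }
}

const store = createVersionedStore()

// Components using this hook re-render whenever store.bump() runs; the
// version value then feeds a useMemo that recomputes the derived list.
export const useStoreVersion = () =>
  useSyncExternalStore(store.subscribe, store.getVersion, store.getVersion)
```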
@@ -1,4 +1,5 @@
import type { ReactNode } from 'react'
import type { TypeWithI18N } from '../../base/form/types'
import type { Plugin } from '../../plugins/types'
import type { CommonNodeType } from '../../workflow/types'
import type { DataSet } from '@/models/datasets'
@@ -6,7 +7,7 @@ import type { App } from '@/types/app'

export type SearchResultType = 'app' | 'knowledge' | 'plugin' | 'workflow-node' | 'command'

export type BaseSearchResult<T = unknown> = {
export type BaseSearchResult<T = any> = {
  id: string
  title: string
  description?: string
@@ -38,8 +39,20 @@ export type WorkflowNodeSearchResult = {

export type CommandSearchResult = {
  type: 'command'
} & BaseSearchResult<{ command: string, args?: Record<string, unknown> }>
} & BaseSearchResult<{ command: string, args?: Record<string, any> }>

export type SearchResult = AppSearchResult | PluginSearchResult | KnowledgeSearchResult | WorkflowNodeSearchResult | CommandSearchResult

export type { ScopeContext, ScopeDescriptor } from './scope-registry'
export type ActionItem = {
  key: '@app' | '@knowledge' | '@plugin' | '@node' | '/'
  shortcut: string
  title: string | TypeWithI18N
  description: string
  action?: (data: SearchResult) => void
  searchFn?: (searchTerm: string) => SearchResult[]
  search: (
    query: string,
    searchTerm: string,
    locale?: string,
  ) => (Promise<SearchResult[]> | SearchResult[])
}
@@ -1,41 +1,24 @@
import type { ScopeSearchHandler } from './scope-registry'
import type { SearchResult } from './types'
import { ACTION_KEYS } from '../constants'
import { scopeRegistry } from './scope-registry'
import type { ActionItem } from './types'

const scopeId = 'workflow-node'
let scopeRegistered = false

const buildSearchHandler = (searchFn?: (searchTerm: string) => SearchResult[]): ScopeSearchHandler => {
  return async (_, searchTerm = '', _locale) => {
// Create the workflow nodes action
export const workflowNodesAction: ActionItem = {
  key: '@node',
  shortcut: '@node',
  title: 'Search Workflow Nodes',
  description: 'Find and jump to nodes in the current workflow by name or type',
  searchFn: undefined, // Will be set by useWorkflowSearch hook
  search: async (_, searchTerm = '', _locale) => {
    try {
      if (searchFn)
        return searchFn(searchTerm)
      // Use the searchFn if available (set by useWorkflowSearch hook)
      if (workflowNodesAction.searchFn)
        return workflowNodesAction.searchFn(searchTerm)

      // If not in workflow context, return empty array
      return []
    }
    catch (error) {
      console.warn('Workflow nodes search failed:', error)
      return []
    }
  }
}

export const registerWorkflowNodeScope = () => {
  if (scopeRegistered)
    return

  scopeRegistered = true
  scopeRegistry.register({
    id: scopeId,
    shortcut: ACTION_KEYS.NODE,
    title: 'Search Workflow Nodes',
    description: 'Find and jump to nodes in the current workflow by name or type',
    isAvailable: context => context.isWorkflowPage,
    search: buildSearchHandler(),
  })
}

export const setWorkflowNodesSearchFn = (fn: (searchTerm: string) => SearchResult[]) => {
  registerWorkflowNodeScope()
  scopeRegistry.updateSearchHandler(scopeId, buildSearchHandler(fn))
  },
}
@@ -1,5 +1,5 @@
import type { ScopeDescriptor } from './actions/scope-registry'
import { render, screen, waitFor } from '@testing-library/react'
import type { ActionItem } from './actions/types'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { Command } from 'cmdk'
import * as React from 'react'
@@ -22,315 +22,263 @@ vi.mock('./actions/commands/registry', () => ({
  },
}))

type CommandSelectorProps = React.ComponentProps<typeof CommandSelector>

const mockScopes: ScopeDescriptor[] = [
  {
    id: 'app',
const createActions = (): Record<string, ActionItem> => ({
  app: {
    key: '@app',
    shortcut: '@app',
    title: 'Search Applications',
    description: 'Search apps',
    title: 'Apps',
    search: vi.fn(),
  },
  {
    id: 'knowledge',
    shortcut: '@knowledge',
    title: 'Search Knowledge Bases',
    description: 'Search knowledge bases',
    search: vi.fn(),
  },
  {
    id: 'plugin',
    description: '',
  } as ActionItem,
  plugin: {
    key: '@plugin',
    shortcut: '@plugin',
    title: 'Search Plugins',
    description: 'Search plugins',
    title: 'Plugins',
    search: vi.fn(),
  },
  {
    id: 'workflow-node',
    shortcut: '@node',
    title: 'Search Nodes',
    description: 'Search workflow nodes',
    search: vi.fn(),
  },
]

const mockOnCommandSelect = vi.fn()
const mockOnCommandValueChange = vi.fn()

const buildCommandSelector = (props: Partial<CommandSelectorProps> = {}) => (
  <Command>
    <Command.List>
      <CommandSelector
        scopes={mockScopes}
        onCommandSelect={mockOnCommandSelect}
        {...props}
      />
    </Command.List>
  </Command>
)

const renderCommandSelector = (props: Partial<CommandSelectorProps> = {}) => {
  return render(buildCommandSelector(props))
}
    description: '',
  } as ActionItem,
})

describe('CommandSelector', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  it('should list contextual search actions and notify selection', async () => {
    const actions = createActions()
    const onSelect = vi.fn()

    render(
      <Command>
        <CommandSelector
          actions={actions}
          onCommandSelect={onSelect}
          searchFilter="app"
          originalQuery="@app"
        />
      </Command>,
    )

    const actionButton = screen.getByText('app.gotoAnything.actions.searchApplicationsDesc')
    await userEvent.click(actionButton)

    expect(onSelect).toHaveBeenCalledWith('@app')
  })

  describe('Basic Rendering', () => {
    it('should render all scopes when no filter is provided', () => {
      renderCommandSelector()
  it('should render slash commands when query starts with slash', async () => {
    const actions = createActions()
    const onSelect = vi.fn()

      expect(screen.getByText('@app')).toBeInTheDocument()
      expect(screen.getByText('@knowledge')).toBeInTheDocument()
      expect(screen.getByText('@plugin')).toBeInTheDocument()
      expect(screen.getByText('@node')).toBeInTheDocument()
    })
    render(
      <Command>
        <CommandSelector
          actions={actions}
          onCommandSelect={onSelect}
          searchFilter="zen"
          originalQuery="/zen"
        />
      </Command>,
    )

    it('should render empty filter as showing all scopes', () => {
      renderCommandSelector({ searchFilter: '' })
    const slashItem = await screen.findByText('app.gotoAnything.actions.zenDesc')
    await userEvent.click(slashItem)

      expect(screen.getByText('@app')).toBeInTheDocument()
      expect(screen.getByText('@knowledge')).toBeInTheDocument()
      expect(screen.getByText('@plugin')).toBeInTheDocument()
      expect(screen.getByText('@node')).toBeInTheDocument()
    })
  })

  describe('Filtering Functionality', () => {
    it('should filter scopes based on searchFilter - single match', () => {
      renderCommandSelector({ searchFilter: 'k' })

      expect(screen.queryByText('@app')).not.toBeInTheDocument()
      expect(screen.getByText('@knowledge')).toBeInTheDocument()
      expect(screen.queryByText('@plugin')).not.toBeInTheDocument()
      expect(screen.queryByText('@node')).not.toBeInTheDocument()
    })

    it('should filter scopes with multiple matches', () => {
      renderCommandSelector({ searchFilter: 'p' })

      expect(screen.getByText('@app')).toBeInTheDocument()
      expect(screen.queryByText('@knowledge')).not.toBeInTheDocument()
      expect(screen.getByText('@plugin')).toBeInTheDocument()
      expect(screen.queryByText('@node')).not.toBeInTheDocument()
    })

    it('should be case-insensitive when filtering', () => {
      renderCommandSelector({ searchFilter: 'APP' })

      expect(screen.getByText('@app')).toBeInTheDocument()
      expect(screen.queryByText('@knowledge')).not.toBeInTheDocument()
    })

    it('should match partial strings', () => {
      renderCommandSelector({ searchFilter: 'od' })

      expect(screen.queryByText('@app')).not.toBeInTheDocument()
      expect(screen.queryByText('@knowledge')).not.toBeInTheDocument()
      expect(screen.queryByText('@plugin')).not.toBeInTheDocument()
      expect(screen.getByText('@node')).toBeInTheDocument()
    })
  })

  describe('Empty State', () => {
    it('should show empty state when no matches found', () => {
      renderCommandSelector({ searchFilter: 'xyz' })

      expect(screen.queryByText('@app')).not.toBeInTheDocument()
      expect(screen.queryByText('@knowledge')).not.toBeInTheDocument()
      expect(screen.queryByText('@plugin')).not.toBeInTheDocument()
      expect(screen.queryByText('@node')).not.toBeInTheDocument()

      expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
      expect(screen.getByText('app.gotoAnything.tryDifferentSearch')).toBeInTheDocument()
    })

    it('should not show empty state when filter is empty', () => {
      renderCommandSelector({ searchFilter: '' })

      expect(screen.queryByText('app.gotoAnything.noMatchingCommands')).not.toBeInTheDocument()
    })
  })

  describe('Selection and Highlight Management', () => {
    it('should call onCommandValueChange when filter changes and first item differs', async () => {
      const { rerender } = renderCommandSelector({
        searchFilter: '',
        commandValue: '@app',
        onCommandValueChange: mockOnCommandValueChange,
      })

      rerender(buildCommandSelector({
        searchFilter: 'k',
        commandValue: '@app',
        onCommandValueChange: mockOnCommandValueChange,
      }))

      await waitFor(() => {
        expect(mockOnCommandValueChange).toHaveBeenCalledWith('@knowledge')
      })
    })

    it('should not call onCommandValueChange if current value still exists', async () => {
      const { rerender } = renderCommandSelector({
        searchFilter: '',
        commandValue: '@app',
        onCommandValueChange: mockOnCommandValueChange,
      })

      rerender(buildCommandSelector({
        searchFilter: 'a',
        commandValue: '@app',
        onCommandValueChange: mockOnCommandValueChange,
      }))

      await waitFor(() => {
        expect(mockOnCommandValueChange).not.toHaveBeenCalled()
      })
    })

    it('should handle onCommandSelect callback correctly', async () => {
      const user = userEvent.setup()
      renderCommandSelector({ searchFilter: 'k' })

      await user.click(screen.getByText('@knowledge'))

      expect(mockOnCommandSelect).toHaveBeenCalledWith('@knowledge')
    })
  })

  describe('Edge Cases', () => {
    it('should handle empty scopes array', () => {
      renderCommandSelector({ scopes: [] })

      expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
    })

    it('should handle special characters in filter', () => {
      renderCommandSelector({ searchFilter: '@' })

      expect(screen.getByText('@app')).toBeInTheDocument()
      expect(screen.getByText('@knowledge')).toBeInTheDocument()
      expect(screen.getByText('@plugin')).toBeInTheDocument()
      expect(screen.getByText('@node')).toBeInTheDocument()
    })

    it('should handle undefined onCommandValueChange gracefully', () => {
      const { rerender } = renderCommandSelector({ searchFilter: '' })

      expect(() => {
        rerender(buildCommandSelector({ searchFilter: 'k' }))
      }).not.toThrow()
    })
  })

  describe('User Interactions', () => {
    it('should list contextual scopes and notify selection', async () => {
      const user = userEvent.setup()
      renderCommandSelector({ searchFilter: 'app', originalQuery: '@app' })

      await user.click(screen.getByText('app.gotoAnything.actions.searchApplicationsDesc'))

      expect(mockOnCommandSelect).toHaveBeenCalledWith('@app')
    })

    it('should render slash commands when query starts with slash', async () => {
      const user = userEvent.setup()
      renderCommandSelector({ searchFilter: 'zen', originalQuery: '/zen' })

      const slashItem = await screen.findByText('app.gotoAnything.actions.zenDesc')
      await user.click(slashItem)

      expect(mockOnCommandSelect).toHaveBeenCalledWith('/zen')
    })
    expect(onSelect).toHaveBeenCalledWith('/zen')
  })

  it('should show all slash commands when no filter provided', () => {
    renderCommandSelector({ searchFilter: '', originalQuery: '/' })
    const actions = createActions()
    const onSelect = vi.fn()

    render(
      <Command>
        <CommandSelector
          actions={actions}
          onCommandSelect={onSelect}
          searchFilter=""
          originalQuery="/"
        />
      </Command>,
    )

    // Should show the zen command from mock
    expect(screen.getByText('/zen')).toBeInTheDocument()
  })

  it('should exclude slash scope when in @ mode', () => {
    const scopesWithSlash: ScopeDescriptor[] = [
      ...mockScopes,
      {
        id: 'slash',
  it('should exclude slash action when in @ mode', () => {
    const actions = {
      ...createActions(),
      slash: {
        key: '/',
        shortcut: '/',
        title: 'Slash',
        description: '',
        search: vi.fn(),
      },
    ]
        description: '',
      } as ActionItem,
    }
    const onSelect = vi.fn()

    renderCommandSelector({ scopes: scopesWithSlash, searchFilter: '', originalQuery: '@' })
    render(
      <Command>
        <CommandSelector
          actions={actions}
          onCommandSelect={onSelect}
          searchFilter=""
          originalQuery="@"
        />
      </Command>,
    )

    // Should show @ commands but not /
    expect(screen.getByText('@app')).toBeInTheDocument()
    expect(screen.queryByText('/')).not.toBeInTheDocument()
  })

  it('should show all scopes when no filter in @ mode', () => {
    renderCommandSelector({ searchFilter: '', originalQuery: '@' })
  it('should show all actions when no filter in @ mode', () => {
    const actions = createActions()
    const onSelect = vi.fn()

    render(
      <Command>
        <CommandSelector
          actions={actions}
          onCommandSelect={onSelect}
          searchFilter=""
          originalQuery="@"
        />
      </Command>,
    )

    expect(screen.getByText('@app')).toBeInTheDocument()
    expect(screen.getByText('@plugin')).toBeInTheDocument()
  })

  it('should set default command value when items exist but value does not', () => {
    renderCommandSelector({
      searchFilter: '',
      originalQuery: '@',
      commandValue: 'non-existent',
      onCommandValueChange: mockOnCommandValueChange,
    })
    const actions = createActions()
    const onSelect = vi.fn()
    const onCommandValueChange = vi.fn()

    expect(mockOnCommandValueChange).toHaveBeenCalledWith('@app')
    render(
      <Command>
        <CommandSelector
          actions={actions}
          onCommandSelect={onSelect}
          searchFilter=""
          originalQuery="@"
          commandValue="non-existent"
          onCommandValueChange={onCommandValueChange}
        />
      </Command>,
    )

    expect(onCommandValueChange).toHaveBeenCalledWith('@app')
  })

  it('should NOT set command value when value already exists in items', () => {
    renderCommandSelector({
      searchFilter: '',
      originalQuery: '@',
      commandValue: '@app',
      onCommandValueChange: mockOnCommandValueChange,
    })
    const actions = createActions()
    const onSelect = vi.fn()
    const onCommandValueChange = vi.fn()

    expect(mockOnCommandValueChange).not.toHaveBeenCalled()
    render(
      <Command>
        <CommandSelector
          actions={actions}
          onCommandSelect={onSelect}
          searchFilter=""
          originalQuery="@"
          commandValue="@app"
          onCommandValueChange={onCommandValueChange}
        />
      </Command>,
    )

    expect(onCommandValueChange).not.toHaveBeenCalled()
  })

  it('should show no matching commands message when filter has no results', () => {
    renderCommandSelector({ searchFilter: 'nonexistent', originalQuery: '@nonexistent' })
    const actions = createActions()
    const onSelect = vi.fn()

    render(
      <Command>
        <CommandSelector
          actions={actions}
          onCommandSelect={onSelect}
          searchFilter="nonexistent"
          originalQuery="@nonexistent"
        />
      </Command>,
    )

    expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
    expect(screen.getByText('app.gotoAnything.tryDifferentSearch')).toBeInTheDocument()
  })

  it('should show no matching commands for slash mode with no results', () => {
    renderCommandSelector({ searchFilter: 'nonexistentcommand', originalQuery: '/nonexistentcommand' })
    const actions = createActions()
    const onSelect = vi.fn()

    render(
      <Command>
        <CommandSelector
          actions={actions}
          onCommandSelect={onSelect}
          searchFilter="nonexistentcommand"
          originalQuery="/nonexistentcommand"
        />
      </Command>,
    )

    expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
  })

  it('should render description for @ commands', () => {
    renderCommandSelector({ searchFilter: '', originalQuery: '@' })
    const actions = createActions()
    const onSelect = vi.fn()

    render(
      <Command>
        <CommandSelector
          actions={actions}
          onCommandSelect={onSelect}
          searchFilter=""
          originalQuery="@"
        />
      </Command>,
    )

    expect(screen.getByText('app.gotoAnything.actions.searchApplicationsDesc')).toBeInTheDocument()
    expect(screen.getByText('app.gotoAnything.actions.searchPluginsDesc')).toBeInTheDocument()
  })

  it('should render group header for @ mode', () => {
    renderCommandSelector({ searchFilter: '', originalQuery: '@' })
    const actions = createActions()
    const onSelect = vi.fn()

    render(
      <Command>
        <CommandSelector
          actions={actions}
          onCommandSelect={onSelect}
          searchFilter=""
          originalQuery="@"
        />
      </Command>,
    )

    expect(screen.getByText('app.gotoAnything.selectSearchType')).toBeInTheDocument()
  })

  it('should render group header for slash mode', () => {
    renderCommandSelector({ searchFilter: '', originalQuery: '/' })
    const actions = createActions()
    const onSelect = vi.fn()

    render(
      <Command>
        <CommandSelector
          actions={actions}
          onCommandSelect={onSelect}
          searchFilter=""
          originalQuery="/"
        />
      </Command>,
    )

    expect(screen.getByText('app.gotoAnything.groups.commands')).toBeInTheDocument()
  })
@@ -1,14 +1,13 @@
import type { FC } from 'react'
import type { ScopeDescriptor } from './actions/scope-registry'
import type { ActionItem } from './actions/types'
import { Command } from 'cmdk'
import { usePathname } from 'next/navigation'
import { useEffect, useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { slashCommandRegistry } from './actions/commands/registry'
import { ACTION_KEYS } from './constants'

type Props = {
  scopes: ScopeDescriptor[]
  actions: Record<string, ActionItem>
  onCommandSelect: (commandKey: string) => void
  searchFilter?: string
  commandValue?: string
@@ -16,7 +15,7 @@ type Props = {
  originalQuery?: string
}

const CommandSelector: FC<Props> = ({ scopes, onCommandSelect, searchFilter, commandValue, onCommandValueChange, originalQuery }) => {
const CommandSelector: FC<Props> = ({ actions, onCommandSelect, searchFilter, commandValue, onCommandValueChange, originalQuery }) => {
  const { t } = useTranslation()
  const pathname = usePathname()

@@ -44,31 +43,22 @@ const CommandSelector: FC<Props> = ({ scopes, onCommandSelect, searchFilter, com
    }))
  }, [isSlashMode, searchFilter, pathname])

  const filteredScopes = useMemo(() => {
  const filteredActions = useMemo(() => {
    if (isSlashMode)
      return []

    return scopes.filter((scope) => {
    return Object.values(actions).filter((action) => {
      // Exclude slash action when in @ mode
      if (scope.id === 'slash' || scope.shortcut === ACTION_KEYS.SLASH)
      if (action.key === '/')
        return false
      if (!searchFilter)
        return true

      // Match against shortcut/aliases or title
      const filterLower = searchFilter.toLowerCase()
      const shortcuts = [scope.shortcut, ...(scope.aliases || [])]
      return shortcuts.some(shortcut => shortcut.toLowerCase().includes(filterLower))
        || scope.title.toLowerCase().includes(filterLower)
    }).map(scope => ({
      key: scope.shortcut, // Map to shortcut for UI display consistency
      shortcut: scope.shortcut,
      title: scope.title,
      description: scope.description,
    }))
  }, [scopes, searchFilter, isSlashMode])
      return action.shortcut.toLowerCase().includes(filterLower)
    })
  }, [actions, searchFilter, isSlashMode])

  const allItems = isSlashMode ? slashCommands : filteredScopes
  const allItems = isSlashMode ? slashCommands : filteredActions

  useEffect(() => {
    if (allItems.length > 0 && onCommandValueChange) {
@@ -126,7 +116,6 @@ const CommandSelector: FC<Props> = ({ scopes, onCommandSelect, searchFilter, com
        '/docs': 'gotoAnything.actions.docDesc',
        '/community': 'gotoAnything.actions.communityDesc',
        '/zen': 'gotoAnything.actions.zenDesc',
        '/banana': 'gotoAnything.actions.vibeDesc',
      } as const
      return t(slashKeyMap[item.key as keyof typeof slashKeyMap] || item.description, { ns: 'app' })
    })()
@@ -83,10 +83,10 @@ describe('EmptyState', () => {
  })

  it('should show specific search hint with shortcuts', () => {
    const Actions = [
      { id: 'app', shortcut: '@app', title: 'App', description: '', search: vi.fn() },
      { id: 'plugin', shortcut: '@plugin', title: 'Plugin', description: '', search: vi.fn() },
    ] as import('../actions/types').ScopeDescriptor[]
    const Actions = {
      app: { key: '@app', shortcut: '@app' },
      plugin: { key: '@plugin', shortcut: '@plugin' },
    } as unknown as Record<string, import('../actions/types').ActionItem>
    render(<EmptyState variant="no-results" searchMode="general" Actions={Actions} />)

    expect(screen.getByText('gotoAnything.emptyState.trySpecificSearch:@app, @plugin')).toBeInTheDocument()
@@ -1,7 +1,7 @@
'use client'

import type { FC } from 'react'
import type { ScopeDescriptor } from '../actions/types'
import type { ActionItem } from '../actions/types'
import { useTranslation } from 'react-i18next'

export type EmptyStateVariant = 'no-results' | 'error' | 'default' | 'loading'
@@ -10,14 +10,14 @@ export type EmptyStateProps = {
  variant: EmptyStateVariant
  searchMode?: string
  error?: Error | null
  Actions?: ScopeDescriptor[]
  Actions?: Record<string, ActionItem>
}

const EmptyState: FC<EmptyStateProps> = ({
  variant,
  searchMode = 'general',
  error,
  Actions = [],
  Actions = {},
}) => {
  const { t } = useTranslation()

@@ -88,7 +88,7 @@ const EmptyState: FC<EmptyStateProps> = ({
    return t('gotoAnything.emptyState.tryDifferentTerm', { ns: 'app' })
  }

  const shortcuts = Actions.map(scope => scope.shortcut).join(', ')
  const shortcuts = Object.values(Actions).map(action => action.shortcut).join(', ')
  return t('gotoAnything.emptyState.trySpecificSearch', { ns: 'app', shortcuts })
}
@@ -1,20 +0,0 @@
/**
 * Goto Anything Constants
 * Centralized constants for action keys
 */

/**
 * Action keys for scope-based searches
 */
export const ACTION_KEYS = {
  APP: '@app',
  KNOWLEDGE: '@knowledge',
  PLUGIN: '@plugin',
  NODE: '@node',
  SLASH: '/',
} as const

/**
 * Type-safe action key union type
 */
export type ActionKey = typeof ACTION_KEYS[keyof typeof ACTION_KEYS]
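The deleted constants module derives a string-literal union from the object's values via `as const` plus an indexed-access type, so the union stays in sync with the object automatically. In isolation, the pattern looks like this (standalone sketch, not code from the repository):

```typescript
const ACTION_KEYS = {
  APP: '@app',
  SLASH: '/',
} as const

// Equivalent to the union '@app' | '/'; adding a key to ACTION_KEYS
// widens the union without touching this line.
type ActionKey = typeof ACTION_KEYS[keyof typeof ACTION_KEYS]

const ok: ActionKey = '@app'   // compiles
// const bad: ActionKey = '@x' // type error
```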
@@ -32,17 +32,23 @@ vi.mock('../actions/commands/registry', () => ({
  },
}))

const mockExecuteCommand = vi.fn()

vi.mock('../actions/commands', () => ({
  executeCommand: (...args: unknown[]) => mockExecuteCommand(...args),
}))

vi.mock('@/app/components/workflow/constants', () => ({
  VIBE_COMMAND_EVENT: 'vibe-command',
}))
const createMockActionItem = (
  key: '@app' | '@knowledge' | '@plugin' | '@node' | '/',
  extra: Record<string, unknown> = {},
) => ({
  key,
  shortcut: key,
  title: `${key} title`,
  description: `${key} description`,
  search: vi.fn().mockResolvedValue([]),
  ...extra,
})

const createMockOptions = (overrides = {}) => ({
  Actions: {
    slash: createMockActionItem('/', { action: vi.fn() }),
    app: createMockActionItem('@app'),
  },
  setSearchQuery: vi.fn(),
  clearSelection: vi.fn(),
  inputRef: { current: { focus: vi.fn() } } as unknown as React.RefObject<HTMLInputElement>,
@@ -54,7 +60,6 @@ describe('useGotoAnythingNavigation', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    mockFindCommandResult = null
    mockExecuteCommand.mockReset()
    vi.useFakeTimers()
  })

@@ -216,8 +221,13 @@ describe('useGotoAnythingNavigation', () => {
    expect(mockRouterPush).not.toHaveBeenCalled()
  })

  it('should execute command via executeCommand for command type', () => {
    const options = createMockOptions()
  it('should execute slash command action for command type', () => {
    const actionMock = vi.fn()
    const options = createMockOptions({
      Actions: {
        slash: { key: '/', shortcut: '/', action: actionMock },
      },
    })

    const { result } = renderHook(() => useGotoAnythingNavigation(options))

@@ -232,7 +242,7 @@ describe('useGotoAnythingNavigation', () => {
      result.current.handleNavigate(commandResult)
    })

    expect(mockExecuteCommand).toHaveBeenCalledWith('theme.set', { theme: 'dark' })
    expect(actionMock).toHaveBeenCalledWith(commandResult)
  })

  it('should set activePlugin for plugin type', () => {
@@ -358,8 +368,10 @@ describe('useGotoAnythingNavigation', () => {
    // No error should occur
  })

  it('should handle command execution without error', () => {
    const options = createMockOptions()
  it('should handle missing slash action', () => {
    const options = createMockOptions({
      Actions: {},
    })

    const { result } = renderHook(() => useGotoAnythingNavigation(options))

@@ -373,7 +385,7 @@ describe('useGotoAnythingNavigation', () => {
      })
    })

    expect(mockExecuteCommand).toHaveBeenCalledWith('test-command', undefined)
    // No error should occur
  })
})
})
@@ -2,12 +2,10 @@

import type { RefObject } from 'react'
import type { Plugin } from '../../plugins/types'
import type { SearchResult } from '../actions/types'
import type { ActionItem, SearchResult } from '../actions/types'
import { useRouter } from 'next/navigation'
import { useCallback, useState } from 'react'
import { VIBE_COMMAND_EVENT } from '@/app/components/workflow/constants'
import { selectWorkflowNode } from '@/app/components/workflow/utils/node-navigation'
import { executeCommand } from '../actions/commands'
import { slashCommandRegistry } from '../actions/commands/registry'

export type UseGotoAnythingNavigationReturn = {
@@ -18,6 +16,7 @@ export type UseGotoAnythingNavigationReturn = {
}

export type UseGotoAnythingNavigationOptions = {
  Actions: Record<string, ActionItem>
  setSearchQuery: (query: string) => void
  clearSelection: () => void
  inputRef: RefObject<HTMLInputElement | null>
@@ -28,6 +27,7 @@ export const useGotoAnythingNavigation = (
  options: UseGotoAnythingNavigationOptions,
): UseGotoAnythingNavigationReturn => {
  const {
    Actions,
    setSearchQuery,
    clearSelection,
    inputRef,
@@ -67,16 +67,9 @@ export const useGotoAnythingNavigation = (

    switch (result.type) {
      case 'command': {
        if (result.data.command === 'workflow.vibe') {
          if (typeof document !== 'undefined') {
            document.dispatchEvent(new CustomEvent(VIBE_COMMAND_EVENT, { detail: { dsl: result.data.args?.dsl } }))
          }
          break
        }

        // Execute slash commands using the command bus
        const { command, args } = result.data
        executeCommand(command, args)
        // Execute slash commands
        const action = Actions.slash
        action?.action?.(result)
        break
      }
      case 'plugin':
@@ -86,12 +79,13 @@ export const useGotoAnythingNavigation = (
        // Handle workflow node selection and navigation
        if (result.metadata?.nodeId)
          selectWorkflowNode(result.metadata.nodeId, true)

        break
      default:
        if (result.path)
          router.push(result.path)
    }
  }, [router, onClose, setSearchQuery])
  }, [router, Actions, onClose, setSearchQuery])

  return {
    handleCommandSelect,
@@ -35,11 +35,11 @@ vi.mock('../actions', () => ({
searchAnything: (...args: unknown[]) => mockSearchAnything(...args),
}))

const createMockScopeDescriptor = (id: string, shortcut: string) => ({
id,
shortcut,
title: `${shortcut} title`,
description: `${shortcut} description`,
const createMockActionItem = (key: '@app' | '@knowledge' | '@plugin' | '@node' | '/') => ({
key,
shortcut: key,
title: `${key} title`,
description: `${key} description`,
search: vi.fn().mockResolvedValue([]),
})

@@ -47,7 +47,7 @@ const createMockOptions = (overrides = {}) => ({
searchQueryDebouncedValue: '',
searchMode: 'general',
isCommandsMode: false,
scopes: [createMockScopeDescriptor('app', '@app')],
Actions: { app: createMockActionItem('@app') },
isWorkflowPage: false,
isRagPipelinePage: false,
cmdVal: '_',
@@ -300,36 +300,36 @@ describe('useGotoAnythingResults', () => {

describe('queryFn execution', () => {
it('should call matchAction with lowercased query', async () => {
const mockScopes = [createMockScopeDescriptor('app', '@app')]
mockMatchAction.mockReturnValue(mockScopes[0])
const mockActions = { app: createMockActionItem('@app') }
mockMatchAction.mockReturnValue({ key: '@app' })
mockSearchAnything.mockResolvedValue([])

renderHook(() => useGotoAnythingResults(createMockOptions({
searchQueryDebouncedValue: 'TEST QUERY',
scopes: mockScopes,
Actions: mockActions,
})))

expect(capturedQueryFn).toBeDefined()
await capturedQueryFn!()

expect(mockMatchAction).toHaveBeenCalledWith('test query', mockScopes)
expect(mockMatchAction).toHaveBeenCalledWith('test query', mockActions)
})

it('should call searchAnything with correct parameters', async () => {
const mockScopes = [createMockScopeDescriptor('app', '@app')]
const mockAction = mockScopes[0]
const mockActions = { app: createMockActionItem('@app') }
const mockAction = { key: '@app' }
mockMatchAction.mockReturnValue(mockAction)
mockSearchAnything.mockResolvedValue([{ id: '1', type: 'app', title: 'Result' }])

renderHook(() => useGotoAnythingResults(createMockOptions({
searchQueryDebouncedValue: 'My Query',
scopes: mockScopes,
Actions: mockActions,
})))

expect(capturedQueryFn).toBeDefined()
const result = await capturedQueryFn!()

expect(mockSearchAnything).toHaveBeenCalledWith('en_US', 'my query', mockAction, mockScopes)
expect(mockSearchAnything).toHaveBeenCalledWith('en_US', 'my query', mockAction, mockActions)
expect(result).toEqual([{ id: '1', type: 'app', title: 'Result' }])
})

@@ -1,6 +1,6 @@
'use client'

import type { ScopeDescriptor, SearchResult } from '../actions/types'
import type { ActionItem, SearchResult } from '../actions/types'
import { useQuery } from '@tanstack/react-query'
import { useEffect, useMemo } from 'react'
import { useGetLanguage } from '@/context/i18n'
@@ -19,7 +19,7 @@ export type UseGotoAnythingResultsOptions = {
searchQueryDebouncedValue: string
searchMode: string
isCommandsMode: boolean
scopes: ScopeDescriptor[]
Actions: Record<string, ActionItem>
isWorkflowPage: boolean
isRagPipelinePage: boolean
cmdVal: string
@@ -33,7 +33,7 @@ export const useGotoAnythingResults = (
searchQueryDebouncedValue,
searchMode,
isCommandsMode,
scopes,
Actions,
isWorkflowPage,
isRagPipelinePage,
cmdVal,
@@ -42,9 +42,13 @@

const defaultLocale = useGetLanguage()

// Use action keys as stable cache key instead of the full Actions object
// (Actions contains functions which are not serializable)
const actionKeys = useMemo(() => Object.keys(Actions).sort(), [Actions])

const { data: searchResults = [], isLoading, isError, error } = useQuery(
{
// eslint-disable-next-line @tanstack/query/exhaustive-deps -- scopes intentionally excluded: contains non-serializable functions; scope IDs provide stable representation
// eslint-disable-next-line @tanstack/query/exhaustive-deps -- Actions intentionally excluded: contains non-serializable functions; actionKeys provides stable representation
queryKey: [
'goto-anything',
'search-result',
@@ -53,12 +57,12 @@
isWorkflowPage,
isRagPipelinePage,
defaultLocale,
scopes.map(s => s.id).sort().join(','),
actionKeys,
],
queryFn: async () => {
const query = searchQueryDebouncedValue.toLowerCase()
const scope = matchAction(query, scopes)
return await searchAnything(defaultLocale, query, scope, scopes)
const action = matchAction(query, Actions)
return await searchAnything(defaultLocale, query, action, Actions)
},
enabled: !!searchQueryDebouncedValue && !isCommandsMode,
staleTime: 30000,

@@ -1,25 +1,9 @@
import type { ScopeDescriptor } from '../actions/types'
import type { ActionItem } from '../actions/types'
import { act, renderHook } from '@testing-library/react'
import { useGotoAnythingSearch } from './use-goto-anything-search'

let mockContextValue = { isWorkflowPage: false, isRagPipelinePage: false }
let mockMatchActionResult: ScopeDescriptor | undefined

const baseScopesMock: ScopeDescriptor[] = [
{ id: 'slash', shortcut: '/', title: 'Slash', description: 'Slash commands', search: vi.fn() },
{ id: 'app', shortcut: '@app', title: 'App', description: 'Search apps', search: vi.fn() },
{ id: 'knowledge', shortcut: '@knowledge', title: 'Knowledge', description: 'Search KB', search: vi.fn() },
]

const workflowScopesMock: ScopeDescriptor[] = [
...baseScopesMock,
{ id: 'node', shortcut: '@node', title: 'Node', description: 'Search nodes', search: vi.fn() },
]

const ragScopesMock: ScopeDescriptor[] = [
...baseScopesMock,
{ id: 'ragNode', shortcut: '@node', title: 'RAG Node', description: 'Search RAG nodes', search: vi.fn() },
]
let mockMatchActionResult: Partial<ActionItem> | undefined

vi.mock('ahooks', () => ({
useDebounce: <T>(value: T) => value,
@@ -30,12 +14,19 @@ vi.mock('../context', () => ({
}))

vi.mock('../actions', () => ({
useGotoAnythingScopes: (context: { isWorkflowPage: boolean, isRagPipelinePage: boolean }) => {
if (context.isWorkflowPage)
return workflowScopesMock
if (context.isRagPipelinePage)
return ragScopesMock
return baseScopesMock
createActions: (isWorkflowPage: boolean, isRagPipelinePage: boolean) => {
const base = {
slash: { key: '/', shortcut: '/' },
app: { key: '@app', shortcut: '@app' },
knowledge: { key: '@knowledge', shortcut: '@kb' },
}
if (isWorkflowPage) {
return { ...base, node: { key: '@node', shortcut: '@node' } }
}
if (isRagPipelinePage) {
return { ...base, ragNode: { key: '@node', shortcut: '@node' } }
}
return base
},
matchAction: () => mockMatchActionResult,
}))
@@ -83,30 +74,30 @@ describe('useGotoAnythingSearch', () => {
})
})

describe('scopes', () => {
it('should provide scopes based on context', () => {
describe('Actions', () => {
it('should provide Actions based on context', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
expect(result.current.scopes).toBeDefined()
expect(Array.isArray(result.current.scopes)).toBe(true)
expect(result.current.Actions).toBeDefined()
expect(typeof result.current.Actions).toBe('object')
})

it('should include node scope when on workflow page', () => {
it('should include node action when on workflow page', () => {
mockContextValue = { isWorkflowPage: true, isRagPipelinePage: false }
const { result } = renderHook(() => useGotoAnythingSearch())
expect(result.current.scopes.find(s => s.id === 'node')).toBeDefined()
expect(result.current.Actions.node).toBeDefined()
})

it('should include ragNode scope when on RAG pipeline page', () => {
it('should include ragNode action when on RAG pipeline page', () => {
mockContextValue = { isWorkflowPage: false, isRagPipelinePage: true }
const { result } = renderHook(() => useGotoAnythingSearch())
expect(result.current.scopes.find(s => s.id === 'ragNode')).toBeDefined()
expect(result.current.Actions.ragNode).toBeDefined()
})

it('should not include node scopes when on regular page', () => {
it('should not include node actions when on regular page', () => {
mockContextValue = { isWorkflowPage: false, isRagPipelinePage: false }
const { result } = renderHook(() => useGotoAnythingSearch())
expect(result.current.scopes.find(s => s.id === 'node')).toBeUndefined()
expect(result.current.scopes.find(s => s.id === 'ragNode')).toBeUndefined()
expect(result.current.Actions.node).toBeUndefined()
expect(result.current.Actions.ragNode).toBeUndefined()
})
})

@@ -154,7 +145,7 @@ describe('useGotoAnythingSearch', () => {
})

it('should return false when query starts with "@" and action matches', () => {
mockMatchActionResult = baseScopesMock.find(s => s.id === 'app')
mockMatchActionResult = { key: '@app', shortcut: '@app' }
const { result } = renderHook(() => useGotoAnythingSearch())

act(() => {
@@ -215,8 +206,8 @@ describe('useGotoAnythingSearch', () => {
expect(result.current.searchMode).toBe('general')
})

it('should return action shortcut when action matches', () => {
mockMatchActionResult = baseScopesMock.find(s => s.id === 'app')
it('should return action key when action matches', () => {
mockMatchActionResult = { key: '@app', shortcut: '@app' }
const { result } = renderHook(() => useGotoAnythingSearch())

act(() => {
@@ -226,8 +217,8 @@ describe('useGotoAnythingSearch', () => {
expect(result.current.searchMode).toBe('@app')
})

it('should return "@command" when action is slash', () => {
mockMatchActionResult = baseScopesMock.find(s => s.id === 'slash')
it('should return "@command" when action key is "/"', () => {
mockMatchActionResult = { key: '/', shortcut: '/' }
const { result } = renderHook(() => useGotoAnythingSearch())

act(() => {

@@ -1,10 +1,9 @@
'use client'

import type { ScopeDescriptor } from '../actions/types'
import type { ActionItem } from '../actions/types'
import { useDebounce } from 'ahooks'
import { useCallback, useMemo, useState } from 'react'
import { matchAction, useGotoAnythingScopes } from '../actions'
import { ACTION_KEYS } from '../constants'
import { createActions, matchAction } from '../actions'
import { useGotoAnythingContext } from '../context'

export type UseGotoAnythingSearchReturn = {
@@ -16,7 +15,7 @@ export type UseGotoAnythingSearchReturn = {
cmdVal: string
setCmdVal: (val: string) => void
clearSelection: () => void
scopes: ScopeDescriptor[]
Actions: Record<string, ActionItem>
}

export const useGotoAnythingSearch = (): UseGotoAnythingSearchReturn => {
@@ -24,8 +23,10 @@ export const useGotoAnythingSearch = (): UseGotoAnythingSearchReturn => {
const [searchQuery, setSearchQuery] = useState<string>('')
const [cmdVal, setCmdVal] = useState<string>('_')

// Fetch scopes from registry based on context
const scopes = useGotoAnythingScopes({ isWorkflowPage, isRagPipelinePage })
// Filter actions based on context
const Actions = useMemo(() => {
return createActions(isWorkflowPage, isRagPipelinePage)
}, [isWorkflowPage, isRagPipelinePage])

const searchQueryDebouncedValue = useDebounce(searchQuery.trim(), {
wait: 300,
@@ -34,30 +35,28 @@ export const useGotoAnythingSearch = (): UseGotoAnythingSearchReturn => {
const isCommandsMode = useMemo(() => {
const trimmed = searchQuery.trim()
return trimmed === '@' || trimmed === '/'
|| (trimmed.startsWith('@') && !matchAction(trimmed, scopes))
|| (trimmed.startsWith('/') && !matchAction(trimmed, scopes))
}, [searchQuery, scopes])
|| (trimmed.startsWith('@') && !matchAction(trimmed, Actions))
|| (trimmed.startsWith('/') && !matchAction(trimmed, Actions))
}, [searchQuery, Actions])

const searchMode = useMemo(() => {
if (isCommandsMode) {
// Distinguish between @ (scopes) and / (commands) mode
if (searchQuery.trim().startsWith('@'))
return 'scopes'
else if (searchQuery.trim().startsWith('/'))
return 'commands'
return 'commands'
return 'commands' // default fallback
}

const query = searchQueryDebouncedValue.toLowerCase()
const action = matchAction(query, scopes)
const action = matchAction(query, Actions)

if (!action)
return 'general'

if (action.id === 'slash' || action.shortcut === ACTION_KEYS.SLASH)
return '@command'

return action.shortcut
}, [searchQueryDebouncedValue, scopes, isCommandsMode, searchQuery])
return action.key === '/' ? '@command' : action.key
}, [searchQueryDebouncedValue, Actions, isCommandsMode, searchQuery])

// Prevent automatic selection of the first option when cmdVal is not set
const clearSelection = useCallback(() => {
@@ -73,6 +72,6 @@ export const useGotoAnythingSearch = (): UseGotoAnythingSearchReturn => {
cmdVal,
setCmdVal,
clearSelection,
scopes,
Actions,
}
}

@@ -1,93 +0,0 @@
import { keepPreviousData, useQuery } from '@tanstack/react-query'
import { useDebounce } from 'ahooks'
import { useMemo } from 'react'
import { useGetLanguage } from '@/context/i18n'
import { matchAction, searchAnything, useGotoAnythingScopes } from '../actions'
import { ACTION_KEYS } from '../constants'
import { useGotoAnythingContext } from '../context'

export const useSearch = (searchQuery: string) => {
const defaultLocale = useGetLanguage()
const { isWorkflowPage, isRagPipelinePage } = useGotoAnythingContext()

// Fetch scopes from registry based on context
const scopes = useGotoAnythingScopes({ isWorkflowPage, isRagPipelinePage })

const searchQueryDebouncedValue = useDebounce(searchQuery.trim(), {
wait: 300,
})

const isCommandsMode = searchQuery.trim() === '@' || searchQuery.trim() === '/'
|| (searchQuery.trim().startsWith('@') && !matchAction(searchQuery.trim(), scopes))
|| (searchQuery.trim().startsWith('/') && !matchAction(searchQuery.trim(), scopes))

const searchMode = useMemo(() => {
if (isCommandsMode) {
// Distinguish between @ (scopes) and / (commands) mode
if (searchQuery.trim().startsWith('@'))
return 'scopes'
else if (searchQuery.trim().startsWith('/'))
return 'commands'
return 'commands' // default fallback
}

const query = searchQueryDebouncedValue.toLowerCase()
const action = matchAction(query, scopes)

if (!action)
return 'general'

if (action.id === 'slash' || action.shortcut === ACTION_KEYS.SLASH)
return '@command'

return action.shortcut
}, [searchQueryDebouncedValue, scopes, isCommandsMode, searchQuery])

const { data: searchResults = [], isLoading, isError, error } = useQuery(
{
queryKey: [
'goto-anything',
'search-result',
searchQueryDebouncedValue,
searchMode,
isWorkflowPage,
isRagPipelinePage,
defaultLocale,
scopes.map(s => s.id).sort().join(','),
],
queryFn: async () => {
const query = searchQueryDebouncedValue.toLowerCase()
const scope = matchAction(query, scopes)
return await searchAnything(defaultLocale, query, scope, scopes)
},
enabled: !!searchQueryDebouncedValue && !isCommandsMode,
staleTime: 30000,
gcTime: 300000,
placeholderData: keepPreviousData,
},
)

const dedupedResults = useMemo(() => {
if (!searchQuery.trim())
return []

const seen = new Set<string>()
return searchResults.filter((result) => {
const key = `${result.type}-${result.id}`
if (seen.has(key))
return false
seen.add(key)
return true
})
}, [searchResults, searchQuery])

return {
scopes,
searchResults: dedupedResults,
isLoading,
isError,
error,
searchMode,
isCommandsMode,
}
}
@@ -1,6 +1,5 @@
import type { ReactNode } from 'react'
import type { ScopeDescriptor } from './actions/scope-registry'
import type { SearchResult } from './actions/types'
import type { ActionItem, SearchResult } from './actions/types'
import { act, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
@@ -59,7 +58,6 @@ const triggerKeyPress = (combo: string) => {
let mockQueryResult = { data: [] as TestSearchResult[], isLoading: false, isError: false, error: null as Error | null }
vi.mock('@tanstack/react-query', () => ({
useQuery: () => mockQueryResult,
keepPreviousData: (data: unknown) => data,
}))

vi.mock('@/context/i18n', () => ({
@@ -72,30 +70,33 @@ vi.mock('./context', () => ({
GotoAnythingProvider: ({ children }: { children: React.ReactNode }) => <>{children}</>,
}))

type MatchAction = typeof import('./actions').matchAction
type SearchAnything = typeof import('./actions').searchAnything

const mockState = vi.hoisted(() => {
const state = {
scopes: [] as ScopeDescriptor[],
useGotoAnythingScopesMock: vi.fn(() => state.scopes),
matchActionMock: vi.fn<MatchAction>(() => undefined),
searchAnythingMock: vi.fn<SearchAnything>(async () => []),
}

return state
const createActionItem = (key: ActionItem['key'], shortcut: string): ActionItem => ({
key,
shortcut,
title: `${key} title`,
description: `${key} desc`,
action: vi.fn(),
search: vi.fn(),
})

const actionsMock = {
slash: createActionItem('/', '/'),
app: createActionItem('@app', '@app'),
plugin: createActionItem('@plugin', '@plugin'),
}

const createActionsMock = vi.fn(() => actionsMock)
const matchActionMock = vi.fn(() => undefined)
const searchAnythingMock = vi.fn(async () => mockQueryResult.data)

vi.mock('./actions', () => ({
__esModule: true,
matchAction: (...args: Parameters<MatchAction>) => mockState.matchActionMock(...args),
searchAnything: (...args: Parameters<SearchAnything>) => mockState.searchAnythingMock(...args),
useGotoAnythingScopes: () => mockState.useGotoAnythingScopesMock(),
createActions: () => createActionsMock(),
matchAction: () => matchActionMock(),
searchAnything: () => searchAnythingMock(),
}))

vi.mock('./actions/commands', () => ({
SlashCommandProvider: () => null,
executeCommand: vi.fn(),
}))

type MockSlashCommand = {
@@ -113,20 +114,6 @@ vi.mock('./actions/commands/registry', () => ({
},
}))

const createScope = (id: ScopeDescriptor['id'], shortcut: string): ScopeDescriptor => ({
id,
shortcut,
title: `${id} title`,
description: `${id} desc`,
search: vi.fn(),
})

const scopesMock = [
createScope('slash', '/'),
createScope('app', '@app'),
createScope('plugin', '@plugin'),
]

vi.mock('@/app/components/workflow/utils/common', () => ({
getKeyboardKeyCodeBySystem: () => 'ctrl',
getKeyboardKeyNameBySystem: (key: string) => key,
@@ -153,10 +140,8 @@ describe('GotoAnything', () => {
routerPush.mockClear()
Object.keys(keyPressHandlers).forEach(key => delete keyPressHandlers[key])
mockQueryResult = { data: [], isLoading: false, isError: false, error: null }
mockState.scopes = scopesMock
mockState.matchActionMock.mockReset()
mockState.searchAnythingMock.mockClear()
mockState.searchAnythingMock.mockImplementation(async () => mockQueryResult.data as SearchResult[])
matchActionMock.mockReset()
searchAnythingMock.mockClear()
mockFindCommand = null
})

@@ -39,7 +39,7 @@ const GotoAnything: FC<Props> = ({
cmdVal,
setCmdVal,
clearSelection,
scopes,
Actions,
} = useGotoAnythingSearch()

// Modal state management
@@ -76,7 +76,7 @@ const GotoAnything: FC<Props> = ({
searchQueryDebouncedValue,
searchMode,
isCommandsMode,
scopes,
Actions,
isWorkflowPage,
isRagPipelinePage,
cmdVal,
@@ -90,6 +90,7 @@ const GotoAnything: FC<Props> = ({
activePlugin,
setActivePlugin,
} = useGotoAnythingNavigation({
Actions,
setSearchQuery,
clearSelection,
inputRef,
@@ -178,7 +179,7 @@ const GotoAnything: FC<Props> = ({
{isCommandsMode
? (
<CommandSelector
scopes={scopes}
actions={Actions}
onCommandSelect={handleCommandSelect}
searchFilter={searchQuery.trim().substring(1)}
commandValue={cmdVal}
@@ -197,7 +198,7 @@ const GotoAnything: FC<Props> = ({
<EmptyState
variant="no-results"
searchMode={searchMode}
Actions={scopes}
Actions={Actions}
/>
)}

@@ -1,19 +1,27 @@
'use client'
import type { FileUpload } from '@/app/components/base/features/types'
import type { App } from '@/types/app'
import { useRef } from 'react'
import * as React from 'react'
import { useMemo, useRef } from 'react'
import { useTranslation } from 'react-i18next'
import Loading from '@/app/components/base/loading'
import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
import AppInputsForm from '@/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-form'
import { useAppInputsFormSchema } from '@/app/components/plugins/plugin-detail-panel/app-selector/hooks/use-app-inputs-form-schema'
import { BlockEnum, InputVarType, SupportUploadFileTypes } from '@/app/components/workflow/types'
import { useAppDetail } from '@/service/use-apps'
import { useFileUploadConfig } from '@/service/use-common'
import { useAppWorkflow } from '@/service/use-workflow'
import { AppModeEnum, Resolution } from '@/types/app'

import { cn } from '@/utils/classnames'

type Props = {
value?: {
app_id: string
inputs: Record<string, unknown>
inputs: Record<string, any>
}
appDetail: App
onFormChange: (value: Record<string, unknown>) => void
onFormChange: (value: Record<string, any>) => void
}

const AppInputsPanel = ({
@@ -22,33 +30,155 @@ const AppInputsPanel = ({
onFormChange,
}: Props) => {
const { t } = useTranslation()
const inputsRef = useRef<Record<string, unknown>>(value?.inputs || {})
const inputsRef = useRef<any>(value?.inputs || {})
const isBasicApp = appDetail.mode !== AppModeEnum.ADVANCED_CHAT && appDetail.mode !== AppModeEnum.WORKFLOW
const { data: fileUploadConfig } = useFileUploadConfig()
const { data: currentApp, isFetching: isAppLoading } = useAppDetail(appDetail.id)
const { data: currentWorkflow, isFetching: isWorkflowLoading } = useAppWorkflow(isBasicApp ? '' : appDetail.id)
const isLoading = isAppLoading || isWorkflowLoading

const { inputFormSchema, isLoading } = useAppInputsFormSchema({ appDetail })
const basicAppFileConfig = useMemo(() => {
let fileConfig: FileUpload
if (isBasicApp)
fileConfig = currentApp?.model_config?.file_upload as FileUpload
else
fileConfig = currentWorkflow?.features?.file_upload as FileUpload
return {
image: {
detail: fileConfig?.image?.detail || Resolution.high,
enabled: !!fileConfig?.image?.enabled,
number_limits: fileConfig?.image?.number_limits || 3,
transfer_methods: fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'],
},
enabled: !!(fileConfig?.enabled || fileConfig?.image?.enabled),
allowed_file_types: fileConfig?.allowed_file_types || [SupportUploadFileTypes.image],
allowed_file_extensions: fileConfig?.allowed_file_extensions || [...FILE_EXTS[SupportUploadFileTypes.image]].map(ext => `.${ext}`),
allowed_file_upload_methods: fileConfig?.allowed_file_upload_methods || fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'],
number_limits: fileConfig?.number_limits || fileConfig?.image?.number_limits || 3,
}
}, [currentApp?.model_config?.file_upload, currentWorkflow?.features?.file_upload, isBasicApp])

const handleFormChange = (newValue: Record<string, unknown>) => {
inputsRef.current = newValue
onFormChange(newValue)
const inputFormSchema = useMemo(() => {
if (!currentApp)
return []
let inputFormSchema = []
if (isBasicApp) {
inputFormSchema = currentApp.model_config?.user_input_form?.filter((item: any) => !item.external_data_tool).map((item: any) => {
if (item.paragraph) {
return {
...item.paragraph,
type: 'paragraph',
required: false,
}
}
if (item.number) {
return {
...item.number,
type: 'number',
required: false,
}
}
if (item.checkbox) {
return {
...item.checkbox,
type: 'checkbox',
required: false,
}
}
if (item.select) {
return {
...item.select,
type: 'select',
required: false,
}
}

if (item['file-list']) {
return {
...item['file-list'],
type: 'file-list',
required: false,
fileUploadConfig,
}
}

if (item.file) {
return {
...item.file,
type: 'file',
required: false,
fileUploadConfig,
}
}

if (item.json_object) {
return {
...item.json_object,
type: 'json_object',
}
}

return {
...item['text-input'],
type: 'text-input',
required: false,
}
}) || []
}
else {
const startNode = currentWorkflow?.graph?.nodes.find(node => node.data.type === BlockEnum.Start) as any
inputFormSchema = startNode?.data.variables.map((variable: any) => {
if (variable.type === InputVarType.multiFiles) {
return {
...variable,
required: false,
fileUploadConfig,
}
}

if (variable.type === InputVarType.singleFile) {
return {
...variable,
required: false,
fileUploadConfig,
}
}
return {
...variable,
required: false,
}
}) || []
}
if ((currentApp.mode === AppModeEnum.COMPLETION || currentApp.mode === AppModeEnum.WORKFLOW) && basicAppFileConfig.enabled) {
inputFormSchema.push({
label: 'Image Upload',
variable: '#image#',
type: InputVarType.singleFile,
required: false,
...basicAppFileConfig,
fileUploadConfig,
})
}
return inputFormSchema || []
}, [basicAppFileConfig, currentApp, currentWorkflow, fileUploadConfig, isBasicApp])

const handleFormChange = (value: Record<string, any>) => {
inputsRef.current = value
onFormChange(value)
}

const hasInputs = inputFormSchema.length > 0

return (
<div className={cn('flex max-h-[240px] flex-col rounded-b-2xl border-t border-divider-subtle pb-4')}>
{isLoading && <div className="pt-3"><Loading type="app" /></div>}
{!isLoading && (
<div className="system-sm-semibold mb-2 mt-3 flex h-6 shrink-0 items-center px-4 text-text-secondary">
{t('appSelector.params', { ns: 'app' })}
</div>
<div className="system-sm-semibold mb-2 mt-3 flex h-6 shrink-0 items-center px-4 text-text-secondary">{t('appSelector.params', { ns: 'app' })}</div>
)}
{!isLoading && !hasInputs && (
{!isLoading && !inputFormSchema.length && (
<div className="flex h-16 flex-col items-center justify-center">
<div className="system-sm-regular text-text-tertiary">
{t('appSelector.noParams', { ns: 'app' })}
</div>
<div className="system-sm-regular text-text-tertiary">{t('appSelector.noParams', { ns: 'app' })}</div>
</div>
)}
{!isLoading && hasInputs && (
{!isLoading && !!inputFormSchema.length && (
<div className="grow overflow-y-auto">
<AppInputsForm
inputs={value?.inputs || {}}

@@ -1,211 +0,0 @@
'use client'
import type { FileUpload } from '@/app/components/base/features/types'
import type { FileUploadConfigResponse } from '@/models/common'
import type { App } from '@/types/app'
import type { FetchWorkflowDraftResponse } from '@/types/workflow'
import { useMemo } from 'react'
import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
import { BlockEnum, InputVarType, SupportUploadFileTypes } from '@/app/components/workflow/types'
import { useAppDetail } from '@/service/use-apps'
import { useFileUploadConfig } from '@/service/use-common'
import { useAppWorkflow } from '@/service/use-workflow'
import { AppModeEnum, Resolution } from '@/types/app'

const BASIC_INPUT_TYPE_MAP: Record<string, string> = {
'paragraph': 'paragraph',
'number': 'number',
'checkbox': 'checkbox',
'select': 'select',
'file-list': 'file-list',
'file': 'file',
'json_object': 'json_object',
}

const FILE_INPUT_TYPES = new Set(['file-list', 'file'])

const WORKFLOW_FILE_VAR_TYPES = new Set([InputVarType.multiFiles, InputVarType.singleFile])

type InputSchemaItem = {
label?: string
variable?: string
type: string
required: boolean
fileUploadConfig?: FileUploadConfigResponse
[key: string]: unknown
}

function isBasicAppMode(mode: string): boolean {
return mode !== AppModeEnum.ADVANCED_CHAT && mode !== AppModeEnum.WORKFLOW
}

function supportsImageUpload(mode: string): boolean {
return mode === AppModeEnum.COMPLETION || mode === AppModeEnum.WORKFLOW
}

function buildFileConfig(fileConfig: FileUpload | undefined) {
return {
image: {
detail: fileConfig?.image?.detail || Resolution.high,
enabled: !!fileConfig?.image?.enabled,
number_limits: fileConfig?.image?.number_limits || 3,
transfer_methods: fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'],
},
enabled: !!(fileConfig?.enabled || fileConfig?.image?.enabled),
allowed_file_types: fileConfig?.allowed_file_types || [SupportUploadFileTypes.image],
allowed_file_extensions: fileConfig?.allowed_file_extensions
|| [...FILE_EXTS[SupportUploadFileTypes.image]].map(ext => `.${ext}`),
allowed_file_upload_methods: fileConfig?.allowed_file_upload_methods
|| fileConfig?.image?.transfer_methods
|| ['local_file', 'remote_url'],
number_limits: fileConfig?.number_limits || fileConfig?.image?.number_limits || 3,
}
}

function mapBasicAppInputItem(
item: Record<string, unknown>,
fileUploadConfig?: FileUploadConfigResponse,
): InputSchemaItem | null {
for (const [key, type] of Object.entries(BASIC_INPUT_TYPE_MAP)) {
if (!item[key])
continue

const inputData = item[key] as Record<string, unknown>
const needsFileConfig = FILE_INPUT_TYPES.has(key)

return {
...inputData,
type,
required: false,
...(needsFileConfig && { fileUploadConfig }),
}
}

const textInput = item['text-input'] as Record<string, unknown> | undefined
if (!textInput)
return null

return {
...textInput,
type: 'text-input',
required: false,
}
}

function mapWorkflowVariable(
variable: Record<string, unknown>,
fileUploadConfig?: FileUploadConfigResponse,
): InputSchemaItem {
const needsFileConfig = WORKFLOW_FILE_VAR_TYPES.has(variable.type as InputVarType)

return {
...variable,
type: variable.type as string,
required: false,
...(needsFileConfig && { fileUploadConfig }),
}
}

function createImageUploadSchema(
basicFileConfig: ReturnType<typeof buildFileConfig>,
fileUploadConfig?: FileUploadConfigResponse,
): InputSchemaItem {
return {
label: 'Image Upload',
variable: '#image#',
type: InputVarType.singleFile,
required: false,
...basicFileConfig,
fileUploadConfig,
}
}

function buildBasicAppSchema(
currentApp: App,
fileUploadConfig?: FileUploadConfigResponse,
): InputSchemaItem[] {
const userInputForm = currentApp.model_config?.user_input_form as Array<Record<string, unknown>> | undefined
if (!userInputForm)
return []

return userInputForm
.filter((item: Record<string, unknown>) => !item.external_data_tool)
.map((item: Record<string, unknown>) => mapBasicAppInputItem(item, fileUploadConfig))
.filter((item): item is InputSchemaItem => item !== null)
}

function buildWorkflowSchema(
workflow: FetchWorkflowDraftResponse,
fileUploadConfig?: FileUploadConfigResponse,
): InputSchemaItem[] {
const startNode = workflow.graph?.nodes.find(
node => node.data.type === BlockEnum.Start,
) as { data: { variables: Array<Record<string, unknown>> } } | undefined

if (!startNode?.data.variables)
return []

return startNode.data.variables.map(
variable => mapWorkflowVariable(variable, fileUploadConfig),
)
}

type UseAppInputsFormSchemaParams = {
appDetail: App
}

type UseAppInputsFormSchemaResult = {
inputFormSchema: InputSchemaItem[]
isLoading: boolean
fileUploadConfig?: FileUploadConfigResponse
}

export function useAppInputsFormSchema({
appDetail,
}: UseAppInputsFormSchemaParams): UseAppInputsFormSchemaResult {
const isBasicApp = isBasicAppMode(appDetail.mode)

const { data: fileUploadConfig } = useFileUploadConfig()
const { data: currentApp, isFetching: isAppLoading } = useAppDetail(appDetail.id)
const { data: currentWorkflow, isFetching: isWorkflowLoading } = useAppWorkflow(
isBasicApp ? '' : appDetail.id,
)

const isLoading = isAppLoading || isWorkflowLoading

const inputFormSchema = useMemo(() => {
if (!currentApp)
return []

if (!isBasicApp && !currentWorkflow)
return []

// Build base schema based on app type
// Note: currentWorkflow is guaranteed to be defined here due to the early return above
const baseSchema = isBasicApp
? buildBasicAppSchema(currentApp, fileUploadConfig)
: buildWorkflowSchema(currentWorkflow!, fileUploadConfig)

if (!supportsImageUpload(currentApp.mode))
return baseSchema

const rawFileConfig = isBasicApp
? currentApp.model_config?.file_upload as FileUpload
: currentWorkflow?.features?.file_upload as FileUpload

const basicFileConfig = buildFileConfig(rawFileConfig)

if (!basicFileConfig.enabled)
return baseSchema

return [
...baseSchema,
createImageUploadSchema(basicFileConfig, fileUploadConfig),
]
}, [currentApp, currentWorkflow, fileUploadConfig, isBasicApp])

return {
inputFormSchema,
isLoading,
fileUploadConfig,
}
}
@@ -6,6 +6,7 @@ import Toast from '@/app/components/base/toast'
import { PluginSource } from '../types'
import DetailHeader from './detail-header'

// Use vi.hoisted for mock functions used in vi.mock factories
const {
mockSetShowUpdatePluginModal,
mockRefreshModelProviders,

@@ -1,2 +1,416 @@
// Re-export from refactored module for backward compatibility
export { default } from './detail-header/index'
import type { PluginDetail } from '../types'
import {
RiArrowLeftRightLine,
RiBugLine,
RiCloseLine,
RiHardDrive3Line,
} from '@remixicon/react'
import { useBoolean } from 'ahooks'
import * as React from 'react'
import { useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import ActionButton from '@/app/components/base/action-button'
import { trackEvent } from '@/app/components/base/amplitude'
import Badge from '@/app/components/base/badge'
import Button from '@/app/components/base/button'
import Confirm from '@/app/components/base/confirm'
import { Github } from '@/app/components/base/icons/src/public/common'
import { BoxSparkleFill } from '@/app/components/base/icons/src/vender/plugin'
import Toast from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
import { AuthCategory, PluginAuth } from '@/app/components/plugins/plugin-auth'
import OperationDropdown from '@/app/components/plugins/plugin-detail-panel/operation-dropdown'
import PluginInfo from '@/app/components/plugins/plugin-page/plugin-info'
import UpdateFromMarketplace from '@/app/components/plugins/update-plugin/from-market-place'
import PluginVersionPicker from '@/app/components/plugins/update-plugin/plugin-version-picker'
import { API_PREFIX } from '@/config'
import { useAppContext } from '@/context/app-context'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useGetLanguage, useLocale } from '@/context/i18n'
import { useModalContext } from '@/context/modal-context'
import { useProviderContext } from '@/context/provider-context'
import useTheme from '@/hooks/use-theme'
import { uninstallPlugin } from '@/service/plugins'
import { useAllToolProviders, useInvalidateAllToolProviders } from '@/service/use-tools'
import { cn } from '@/utils/classnames'
import { getMarketplaceUrl } from '@/utils/var'
import { AutoUpdateLine } from '../../base/icons/src/vender/system'
import Verified from '../base/badges/verified'
import DeprecationNotice from '../base/deprecation-notice'
import Icon from '../card/base/card-icon'
import Description from '../card/base/description'
import OrgInfo from '../card/base/org-info'
import Title from '../card/base/title'
import { useGitHubReleases } from '../install-plugin/hooks'
import useReferenceSetting from '../plugin-page/use-reference-setting'
import { AUTO_UPDATE_MODE } from '../reference-setting-modal/auto-update-setting/types'
import { convertUTCDaySecondsToLocalSeconds, timeOfDayToDayjs } from '../reference-setting-modal/auto-update-setting/utils'
import { PluginCategoryEnum, PluginSource } from '../types'

const i18nPrefix = 'action'

type Props = {
detail: PluginDetail
isReadmeView?: boolean
onHide?: () => void
onUpdate?: (isDelete?: boolean) => void
}

const DetailHeader = ({
detail,
isReadmeView = false,
onHide,
onUpdate,
}: Props) => {
const { t } = useTranslation()
const { userProfile: { timezone } } = useAppContext()

const { theme } = useTheme()
const locale = useGetLanguage()
const currentLocale = useLocale()
const { checkForUpdates, fetchReleases } = useGitHubReleases()
const { setShowUpdatePluginModal } = useModalContext()
const { refreshModelProviders } = useProviderContext()
const invalidateAllToolProviders = useInvalidateAllToolProviders()
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)

const {
id,
source,
tenant_id,
version,
latest_unique_identifier,
latest_version,
meta,
plugin_id,
status,
deprecated_reason,
alternative_plugin_id,
} = detail

const { author, category, name, label, description, icon, icon_dark, verified, tool } = detail.declaration || detail
const isTool = category === PluginCategoryEnum.tool
const providerBriefInfo = tool?.identity
const providerKey = `${plugin_id}/${providerBriefInfo?.name}`
const { data: collectionList = [] } = useAllToolProviders(isTool)
const provider = useMemo(() => {
return collectionList.find(collection => collection.name === providerKey)
}, [collectionList, providerKey])
const isFromGitHub = source === PluginSource.github
const isFromMarketplace = source === PluginSource.marketplace

const [isShow, setIsShow] = useState(false)
const [targetVersion, setTargetVersion] = useState({
version: latest_version,
unique_identifier: latest_unique_identifier,
})
const hasNewVersion = useMemo(() => {
if (isFromMarketplace)
return !!latest_version && latest_version !== version

return false
}, [isFromMarketplace, latest_version, version])

const iconFileName = theme === 'dark' && icon_dark ? icon_dark : icon
const iconSrc = iconFileName
? (iconFileName.startsWith('http') ? iconFileName : `${API_PREFIX}/workspaces/current/plugin/icon?tenant_id=${tenant_id}&filename=${iconFileName}`)
: ''

const detailUrl = useMemo(() => {
if (isFromGitHub)
return `https://github.com/${meta!.repo}`
if (isFromMarketplace)
return getMarketplaceUrl(`/plugins/${author}/${name}`, { language: currentLocale, theme })
return ''
}, [author, isFromGitHub, isFromMarketplace, meta, name, theme])

const [isShowUpdateModal, {
setTrue: showUpdateModal,
setFalse: hideUpdateModal,
}] = useBoolean(false)

const { referenceSetting } = useReferenceSetting()
const { auto_upgrade: autoUpgradeInfo } = referenceSetting || {}
const isAutoUpgradeEnabled = useMemo(() => {
if (!enable_marketplace)
return false
if (!autoUpgradeInfo || !isFromMarketplace)
return false
if (autoUpgradeInfo.strategy_setting === 'disabled')
return false
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.update_all)
return true
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.partial && autoUpgradeInfo.include_plugins.includes(plugin_id))
return true
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.exclude && !autoUpgradeInfo.exclude_plugins.includes(plugin_id))
return true
return false
}, [autoUpgradeInfo, plugin_id, isFromMarketplace])

const [isDowngrade, setIsDowngrade] = useState(false)
const handleUpdate = async (isDowngrade?: boolean) => {
if (isFromMarketplace) {
setIsDowngrade(!!isDowngrade)
showUpdateModal()
return
}

const owner = meta!.repo.split('/')[0] || author
const repo = meta!.repo.split('/')[1] || name
const fetchedReleases = await fetchReleases(owner, repo)
if (fetchedReleases.length === 0)
return
const { needUpdate, toastProps } = checkForUpdates(fetchedReleases, meta!.version)
Toast.notify(toastProps)
if (needUpdate) {
setShowUpdatePluginModal({
onSaveCallback: () => {
onUpdate?.()
},
payload: {
type: PluginSource.github,
category: detail.declaration.category,
github: {
originalPackageInfo: {
id: detail.plugin_unique_identifier,
repo: meta!.repo,
version: meta!.version,
package: meta!.package,
releases: fetchedReleases,
},
},
},
})
}
}

const handleUpdatedFromMarketplace = () => {
onUpdate?.()
hideUpdateModal()
}

const [isShowPluginInfo, {
setTrue: showPluginInfo,
setFalse: hidePluginInfo,
}] = useBoolean(false)

const [isShowDeleteConfirm, {
setTrue: showDeleteConfirm,
setFalse: hideDeleteConfirm,
}] = useBoolean(false)

const [deleting, {
setTrue: showDeleting,
setFalse: hideDeleting,
}] = useBoolean(false)

const handleDelete = useCallback(async () => {
showDeleting()
const res = await uninstallPlugin(id)
hideDeleting()
if (res.success) {
hideDeleteConfirm()
onUpdate?.(true)
if (PluginCategoryEnum.model.includes(category))
refreshModelProviders()
if (PluginCategoryEnum.tool.includes(category))
invalidateAllToolProviders()
trackEvent('plugin_uninstalled', { plugin_id, plugin_name: name })
}
}, [showDeleting, id, hideDeleting, hideDeleteConfirm, onUpdate, category, refreshModelProviders, invalidateAllToolProviders, plugin_id, name])

return (
<div className={cn('shrink-0 border-b border-divider-subtle bg-components-panel-bg p-4 pb-3', isReadmeView && 'border-b-0 bg-transparent p-0')}>
<div className="flex">
<div className={cn('overflow-hidden rounded-xl border border-components-panel-border-subtle', isReadmeView && 'bg-components-panel-bg')}>
<Icon src={iconSrc} />
</div>
<div className="ml-3 w-0 grow">
<div className="flex h-5 items-center">
<Title title={label[locale]} />
{verified && !isReadmeView && <Verified className="ml-0.5 h-4 w-4" text={t('marketplace.verifiedTip', { ns: 'plugin' })} />}
{!!version && (
<PluginVersionPicker
disabled={!isFromMarketplace || isReadmeView}
isShow={isShow}
onShowChange={setIsShow}
pluginID={plugin_id}
currentVersion={version}
onSelect={(state) => {
setTargetVersion(state)
handleUpdate(state.isDowngrade)
}}
trigger={(
<Badge
className={cn(
'mx-1',
isShow && 'bg-state-base-hover',
(isShow || isFromMarketplace) && 'hover:bg-state-base-hover',
)}
uppercase={false}
text={(
<>
<div>{isFromGitHub ? meta!.version : version}</div>
{isFromMarketplace && !isReadmeView && <RiArrowLeftRightLine className="ml-1 h-3 w-3 text-text-tertiary" />}
</>
)}
hasRedCornerMark={hasNewVersion}
/>
)}
/>
)}
{/* Auto update info */}
{isAutoUpgradeEnabled && !isReadmeView && (
<Tooltip popupContent={t('autoUpdate.nextUpdateTime', { ns: 'plugin', time: timeOfDayToDayjs(convertUTCDaySecondsToLocalSeconds(autoUpgradeInfo?.upgrade_time_of_day || 0, timezone!)).format('hh:mm A') })}>
{/* add a div to fix tooltip hover not show problem */}
<div>
<Badge className="mr-1 cursor-pointer px-1">
<AutoUpdateLine className="size-3" />
</Badge>
</div>
</Tooltip>
)}

{(hasNewVersion || isFromGitHub) && (
<Button
variant="secondary-accent"
size="small"
className="!h-5"
onClick={() => {
if (isFromMarketplace) {
setTargetVersion({
version: latest_version,
unique_identifier: latest_unique_identifier,
})
}
handleUpdate()
}}
>
{t('detailPanel.operation.update', { ns: 'plugin' })}
</Button>
)}
</div>
<div className="mb-1 flex h-4 items-center justify-between">
<div className="mt-0.5 flex items-center">
<OrgInfo
packageNameClassName="w-auto"
orgName={author}
packageName={name?.includes('/') ? (name.split('/').pop() || '') : name}
/>
{!!source && (
<>
<div className="system-xs-regular ml-1 mr-0.5 text-text-quaternary">·</div>
{source === PluginSource.marketplace && (
<Tooltip popupContent={t('detailPanel.categoryTip.marketplace', { ns: 'plugin' })}>
<div><BoxSparkleFill className="h-3.5 w-3.5 text-text-tertiary hover:text-text-accent" /></div>
</Tooltip>
)}
{source === PluginSource.github && (
<Tooltip popupContent={t('detailPanel.categoryTip.github', { ns: 'plugin' })}>
<div><Github className="h-3.5 w-3.5 text-text-secondary hover:text-text-primary" /></div>
</Tooltip>
)}
{source === PluginSource.local && (
<Tooltip popupContent={t('detailPanel.categoryTip.local', { ns: 'plugin' })}>
<div><RiHardDrive3Line className="h-3.5 w-3.5 text-text-tertiary" /></div>
</Tooltip>
)}
{source === PluginSource.debugging && (
<Tooltip popupContent={t('detailPanel.categoryTip.debugging', { ns: 'plugin' })}>
<div><RiBugLine className="h-3.5 w-3.5 text-text-tertiary hover:text-text-warning" /></div>
</Tooltip>
)}
</>
)}
</div>
</div>
</div>
{!isReadmeView && (
<div className="flex gap-1">
<OperationDropdown
source={source}
onInfo={showPluginInfo}
onCheckVersion={handleUpdate}
onRemove={showDeleteConfirm}
detailUrl={detailUrl}
/>
<ActionButton onClick={onHide}>
<RiCloseLine className="h-4 w-4" />
</ActionButton>
</div>
)}
</div>
{isFromMarketplace && (
<DeprecationNotice
status={status}
deprecatedReason={deprecated_reason}
alternativePluginId={alternative_plugin_id}
alternativePluginURL={getMarketplaceUrl(`/plugins/${alternative_plugin_id}`, { language: currentLocale, theme })}
className="mt-3"
/>
)}
{!isReadmeView && <Description className="mb-2 mt-3 h-auto" text={description[locale]} descriptionLineRows={2}></Description>}
{
category === PluginCategoryEnum.tool && !isReadmeView && (
<PluginAuth
pluginPayload={{
provider: provider?.name || '',
category: AuthCategory.tool,
providerType: provider?.type || '',
detail,
}}
/>
)
}
{isShowPluginInfo && (
<PluginInfo
repository={isFromGitHub ? meta?.repo : ''}
release={version}
packageName={meta?.package || ''}
onHide={hidePluginInfo}
/>
)}
{isShowDeleteConfirm && (
<Confirm
isShow
title={t(`${i18nPrefix}.delete`, { ns: 'plugin' })}
content={(
<div>
{t(`${i18nPrefix}.deleteContentLeft`, { ns: 'plugin' })}
<span className="system-md-semibold">{label[locale]}</span>
{t(`${i18nPrefix}.deleteContentRight`, { ns: 'plugin' })}
<br />
</div>
)}
onCancel={hideDeleteConfirm}
onConfirm={handleDelete}
isLoading={deleting}
isDisabled={deleting}
/>
)}
{
isShowUpdateModal && (
<UpdateFromMarketplace
pluginId={plugin_id}
payload={{
category: detail.declaration.category,
originalPackageInfo: {
id: detail.plugin_unique_identifier,
payload: detail.declaration,
},
targetPackageInfo: {
id: targetVersion.unique_identifier,
version: targetVersion.version,
},
}}
onCancel={hideUpdateModal}
onSave={handleUpdatedFromMarketplace}
isShowDowngradeWarningModal={isDowngrade && isAutoUpgradeEnabled}
/>
)
}
</div>
)
}

export default DetailHeader

@@ -1,539 +0,0 @@
import type { PluginDetail } from '../../../types'
import type { ModalStates, VersionTarget } from '../hooks'
import { fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { PluginSource } from '../../../types'
import HeaderModals from './header-modals'

vi.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
  }),
}))

vi.mock('@/context/i18n', () => ({
  useGetLanguage: () => 'en_US',
}))

vi.mock('@/app/components/base/confirm', () => ({
  default: ({ isShow, title, onCancel, onConfirm, isLoading }: {
    isShow: boolean
    title: string
    onCancel: () => void
    onConfirm: () => void
    isLoading: boolean
  }) => isShow
    ? (
      <div data-testid="delete-confirm">
        <div data-testid="delete-title">{title}</div>
        <button data-testid="confirm-cancel" onClick={onCancel}>Cancel</button>
        <button data-testid="confirm-ok" onClick={onConfirm} disabled={isLoading}>Confirm</button>
      </div>
    )
    : null,
}))

vi.mock('@/app/components/plugins/plugin-page/plugin-info', () => ({
  default: ({ repository, release, packageName, onHide }: {
    repository: string
    release: string
    packageName: string
    onHide: () => void
  }) => (
    <div data-testid="plugin-info">
      <div data-testid="plugin-info-repo">{repository}</div>
      <div data-testid="plugin-info-release">{release}</div>
      <div data-testid="plugin-info-package">{packageName}</div>
      <button data-testid="plugin-info-close" onClick={onHide}>Close</button>
    </div>
  ),
}))

vi.mock('@/app/components/plugins/update-plugin/from-market-place', () => ({
  default: ({ pluginId, onSave, onCancel, isShowDowngradeWarningModal }: {
    pluginId: string
    onSave: () => void
    onCancel: () => void
    isShowDowngradeWarningModal: boolean
  }) => (
    <div data-testid="update-modal">
      <div data-testid="update-plugin-id">{pluginId}</div>
      <div data-testid="update-downgrade-warning">{String(isShowDowngradeWarningModal)}</div>
      <button data-testid="update-modal-save" onClick={onSave}>Save</button>
      <button data-testid="update-modal-cancel" onClick={onCancel}>Cancel</button>
    </div>
  ),
}))

const createPluginDetail = (overrides: Partial<PluginDetail> = {}): PluginDetail => ({
  id: 'test-id',
  created_at: '2024-01-01',
  updated_at: '2024-01-02',
  name: 'Test Plugin',
  plugin_id: 'test-plugin',
  plugin_unique_identifier: 'test-uid',
  declaration: {
    author: 'test-author',
    name: 'test-plugin-name',
    category: 'tool',
    label: { en_US: 'Test Plugin Label' },
    description: { en_US: 'Test description' },
    icon: 'icon.png',
    verified: true,
  } as unknown as PluginDetail['declaration'],
  installation_id: 'install-1',
  tenant_id: 'tenant-1',
  endpoints_setups: 0,
  endpoints_active: 0,
  version: '1.0.0',
  latest_version: '2.0.0',
  latest_unique_identifier: 'new-uid',
  source: PluginSource.marketplace,
  meta: undefined,
  status: 'active',
  deprecated_reason: '',
  alternative_plugin_id: '',
  ...overrides,
})

const createModalStatesMock = (overrides: Partial<ModalStates> = {}): ModalStates => ({
  isShowUpdateModal: false,
  showUpdateModal: vi.fn<() => void>(),
  hideUpdateModal: vi.fn<() => void>(),
  isShowPluginInfo: false,
  showPluginInfo: vi.fn<() => void>(),
  hidePluginInfo: vi.fn<() => void>(),
  isShowDeleteConfirm: false,
  showDeleteConfirm: vi.fn<() => void>(),
  hideDeleteConfirm: vi.fn<() => void>(),
  deleting: false,
  showDeleting: vi.fn<() => void>(),
  hideDeleting: vi.fn<() => void>(),
  ...overrides,
})

const createTargetVersion = (overrides: Partial<VersionTarget> = {}): VersionTarget => ({
  version: '2.0.0',
  unique_identifier: 'new-uid',
  ...overrides,
})

describe('HeaderModals', () => {
  let mockOnUpdatedFromMarketplace: () => void
  let mockOnDelete: () => void

  beforeEach(() => {
    vi.clearAllMocks()
    mockOnUpdatedFromMarketplace = vi.fn<() => void>()
    mockOnDelete = vi.fn<() => void>()
  })

  describe('Plugin Info Modal', () => {
    it('should not render plugin info modal when isShowPluginInfo is false', () => {
      const modalStates = createModalStatesMock({ isShowPluginInfo: false })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      expect(screen.queryByTestId('plugin-info')).not.toBeInTheDocument()
    })

    it('should render plugin info modal when isShowPluginInfo is true', () => {
      const modalStates = createModalStatesMock({ isShowPluginInfo: true })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      expect(screen.getByTestId('plugin-info')).toBeInTheDocument()
    })

    it('should pass GitHub repo to plugin info for GitHub source', () => {
      const modalStates = createModalStatesMock({ isShowPluginInfo: true })
      const detail = createPluginDetail({
        source: PluginSource.github,
        meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'test-pkg' },
      })
      render(
        <HeaderModals
          detail={detail}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      expect(screen.getByTestId('plugin-info-repo')).toHaveTextContent('owner/repo')
    })

    it('should pass empty string for repo for non-GitHub source', () => {
      const modalStates = createModalStatesMock({ isShowPluginInfo: true })
      render(
        <HeaderModals
          detail={createPluginDetail({ source: PluginSource.marketplace })}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      expect(screen.getByTestId('plugin-info-repo')).toHaveTextContent('')
    })

    it('should call hidePluginInfo when close button is clicked', () => {
      const modalStates = createModalStatesMock({ isShowPluginInfo: true })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      fireEvent.click(screen.getByTestId('plugin-info-close'))

      expect(modalStates.hidePluginInfo).toHaveBeenCalled()
    })
  })

  describe('Delete Confirm Modal', () => {
    it('should not render delete confirm when isShowDeleteConfirm is false', () => {
      const modalStates = createModalStatesMock({ isShowDeleteConfirm: false })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      expect(screen.queryByTestId('delete-confirm')).not.toBeInTheDocument()
    })

    it('should render delete confirm when isShowDeleteConfirm is true', () => {
      const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
    })

    it('should show correct delete title', () => {
      const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      expect(screen.getByTestId('delete-title')).toHaveTextContent('action.delete')
    })

    it('should call hideDeleteConfirm when cancel is clicked', () => {
      const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      fireEvent.click(screen.getByTestId('confirm-cancel'))

      expect(modalStates.hideDeleteConfirm).toHaveBeenCalled()
    })

    it('should call onDelete when confirm is clicked', () => {
      const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      fireEvent.click(screen.getByTestId('confirm-ok'))

      expect(mockOnDelete).toHaveBeenCalled()
    })

    it('should disable confirm button when deleting', () => {
      const modalStates = createModalStatesMock({ isShowDeleteConfirm: true, deleting: true })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      expect(screen.getByTestId('confirm-ok')).toBeDisabled()
    })
  })

  describe('Update Modal', () => {
    it('should not render update modal when isShowUpdateModal is false', () => {
      const modalStates = createModalStatesMock({ isShowUpdateModal: false })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      expect(screen.queryByTestId('update-modal')).not.toBeInTheDocument()
    })

    it('should render update modal when isShowUpdateModal is true', () => {
      const modalStates = createModalStatesMock({ isShowUpdateModal: true })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      expect(screen.getByTestId('update-modal')).toBeInTheDocument()
    })

    it('should pass plugin id to update modal', () => {
      const modalStates = createModalStatesMock({ isShowUpdateModal: true })
      render(
        <HeaderModals
          detail={createPluginDetail({ plugin_id: 'my-plugin-id' })}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      expect(screen.getByTestId('update-plugin-id')).toHaveTextContent('my-plugin-id')
    })

    it('should call onUpdatedFromMarketplace when save is clicked', () => {
      const modalStates = createModalStatesMock({ isShowUpdateModal: true })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      fireEvent.click(screen.getByTestId('update-modal-save'))

      expect(mockOnUpdatedFromMarketplace).toHaveBeenCalled()
    })

    it('should call hideUpdateModal when cancel is clicked', () => {
      const modalStates = createModalStatesMock({ isShowUpdateModal: true })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      fireEvent.click(screen.getByTestId('update-modal-cancel'))

      expect(modalStates.hideUpdateModal).toHaveBeenCalled()
    })

    it('should show downgrade warning when isDowngrade and isAutoUpgradeEnabled are true', () => {
      const modalStates = createModalStatesMock({ isShowUpdateModal: true })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={true}
          isAutoUpgradeEnabled={true}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('true')
    })

    it('should not show downgrade warning when only isDowngrade is true', () => {
      const modalStates = createModalStatesMock({ isShowUpdateModal: true })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={true}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('false')
    })

    it('should not show downgrade warning when only isAutoUpgradeEnabled is true', () => {
      const modalStates = createModalStatesMock({ isShowUpdateModal: true })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={true}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('false')
    })
  })

  describe('Multiple Modals', () => {
    it('should render multiple modals when multiple are open', () => {
      const modalStates = createModalStatesMock({
        isShowPluginInfo: true,
        isShowDeleteConfirm: true,
        isShowUpdateModal: true,
      })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      expect(screen.getByTestId('plugin-info')).toBeInTheDocument()
      expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
      expect(screen.getByTestId('update-modal')).toBeInTheDocument()
    })
  })

  describe('Edge Cases', () => {
    it('should handle undefined target version values', () => {
      const modalStates = createModalStatesMock({ isShowUpdateModal: true })
      render(
        <HeaderModals
          detail={createPluginDetail()}
          modalStates={modalStates}
          targetVersion={{ version: undefined, unique_identifier: undefined }}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      expect(screen.getByTestId('update-modal')).toBeInTheDocument()
    })

    it('should handle empty meta for GitHub source', () => {
      const modalStates = createModalStatesMock({ isShowPluginInfo: true })
      const detail = createPluginDetail({
        source: PluginSource.github,
        meta: undefined,
      })
      render(
        <HeaderModals
          detail={detail}
          modalStates={modalStates}
          targetVersion={createTargetVersion()}
          isDowngrade={false}
          isAutoUpgradeEnabled={false}
          onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
          onDelete={mockOnDelete}
        />,
      )

      expect(screen.getByTestId('plugin-info-repo')).toHaveTextContent('')
      expect(screen.getByTestId('plugin-info-package')).toHaveTextContent('')
    })
  })
})
@@ -1,107 +0,0 @@
'use client'

import type { FC } from 'react'
import type { PluginDetail } from '../../../types'
import type { ModalStates, VersionTarget } from '../hooks'
import { useTranslation } from 'react-i18next'
import Confirm from '@/app/components/base/confirm'
import PluginInfo from '@/app/components/plugins/plugin-page/plugin-info'
import UpdateFromMarketplace from '@/app/components/plugins/update-plugin/from-market-place'
import { useGetLanguage } from '@/context/i18n'
import { PluginSource } from '../../../types'

const i18nPrefix = 'action'

type HeaderModalsProps = {
  detail: PluginDetail
  modalStates: ModalStates
  targetVersion: VersionTarget
  isDowngrade: boolean
  isAutoUpgradeEnabled: boolean
  onUpdatedFromMarketplace: () => void
  onDelete: () => void
}

const HeaderModals: FC<HeaderModalsProps> = ({
  detail,
  modalStates,
  targetVersion,
  isDowngrade,
  isAutoUpgradeEnabled,
  onUpdatedFromMarketplace,
  onDelete,
}) => {
  const { t } = useTranslation()
  const locale = useGetLanguage()

  const { source, version, meta } = detail
  const { label } = detail.declaration || detail
  const isFromGitHub = source === PluginSource.github

  const {
    isShowUpdateModal,
    hideUpdateModal,
    isShowPluginInfo,
    hidePluginInfo,
    isShowDeleteConfirm,
    hideDeleteConfirm,
    deleting,
  } = modalStates

  return (
    <>
      {/* Plugin Info Modal */}
      {isShowPluginInfo && (
        <PluginInfo
          repository={isFromGitHub ? meta?.repo : ''}
          release={version}
          packageName={meta?.package || ''}
          onHide={hidePluginInfo}
        />
      )}

      {/* Delete Confirm Modal */}
      {isShowDeleteConfirm && (
        <Confirm
          isShow
          title={t(`${i18nPrefix}.delete`, { ns: 'plugin' })}
          content={(
            <div>
              {t(`${i18nPrefix}.deleteContentLeft`, { ns: 'plugin' })}
              <span className="system-md-semibold">{label[locale]}</span>
              {t(`${i18nPrefix}.deleteContentRight`, { ns: 'plugin' })}
              <br />
            </div>
          )}
          onCancel={hideDeleteConfirm}
          onConfirm={onDelete}
          isLoading={deleting}
          isDisabled={deleting}
        />
      )}

      {/* Update from Marketplace Modal */}
      {isShowUpdateModal && (
        <UpdateFromMarketplace
          pluginId={detail.plugin_id}
          payload={{
            category: detail.declaration?.category ?? '',
            originalPackageInfo: {
              id: detail.plugin_unique_identifier,
              payload: detail.declaration ?? undefined,
            },
            targetPackageInfo: {
              id: targetVersion.unique_identifier || '',
              version: targetVersion.version || '',
            },
          }}
          onCancel={hideUpdateModal}
          onSave={onUpdatedFromMarketplace}
          isShowDowngradeWarningModal={isDowngrade && isAutoUpgradeEnabled}
        />
      )}
    </>
  )
}

export default HeaderModals
@@ -1,2 +0,0 @@
export { default as HeaderModals } from './header-modals'
export { default as PluginSourceBadge } from './plugin-source-badge'
@@ -1,200 +0,0 @@
import { render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { PluginSource } from '../../../types'
import PluginSourceBadge from './plugin-source-badge'

vi.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
  }),
}))

vi.mock('@/app/components/base/tooltip', () => ({
  default: ({ children, popupContent }: { children: React.ReactNode, popupContent: string }) => (
    <div data-testid="tooltip" data-content={popupContent}>
      {children}
    </div>
  ),
}))

describe('PluginSourceBadge', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })

  describe('Source Icon Rendering', () => {
    it('should render marketplace source badge', () => {
      render(<PluginSourceBadge source={PluginSource.marketplace} />)

      const tooltip = screen.getByTestId('tooltip')
      expect(tooltip).toBeInTheDocument()
      expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.marketplace')
    })

    it('should render github source badge', () => {
      render(<PluginSourceBadge source={PluginSource.github} />)

      const tooltip = screen.getByTestId('tooltip')
      expect(tooltip).toBeInTheDocument()
      expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.github')
    })

    it('should render local source badge', () => {
      render(<PluginSourceBadge source={PluginSource.local} />)

      const tooltip = screen.getByTestId('tooltip')
      expect(tooltip).toBeInTheDocument()
      expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.local')
    })

    it('should render debugging source badge', () => {
      render(<PluginSourceBadge source={PluginSource.debugging} />)

      const tooltip = screen.getByTestId('tooltip')
      expect(tooltip).toBeInTheDocument()
      expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.debugging')
    })
  })

  describe('Separator Rendering', () => {
    it('should render separator dot before marketplace badge', () => {
      const { container } = render(<PluginSourceBadge source={PluginSource.marketplace} />)

      const separator = container.querySelector('.text-text-quaternary')
      expect(separator).toBeInTheDocument()
      expect(separator?.textContent).toBe('·')
    })

    it('should render separator dot before github badge', () => {
      const { container } = render(<PluginSourceBadge source={PluginSource.github} />)

      const separator = container.querySelector('.text-text-quaternary')
      expect(separator).toBeInTheDocument()
      expect(separator?.textContent).toBe('·')
    })

    it('should render separator dot before local badge', () => {
      const { container } = render(<PluginSourceBadge source={PluginSource.local} />)

      const separator = container.querySelector('.text-text-quaternary')
      expect(separator).toBeInTheDocument()
    })

    it('should render separator dot before debugging badge', () => {
      const { container } = render(<PluginSourceBadge source={PluginSource.debugging} />)

      const separator = container.querySelector('.text-text-quaternary')
      expect(separator).toBeInTheDocument()
    })
  })

  describe('Tooltip Content', () => {
    it('should show marketplace tooltip', () => {
      render(<PluginSourceBadge source={PluginSource.marketplace} />)

      expect(screen.getByTestId('tooltip')).toHaveAttribute(
        'data-content',
        'detailPanel.categoryTip.marketplace',
      )
    })

    it('should show github tooltip', () => {
      render(<PluginSourceBadge source={PluginSource.github} />)

      expect(screen.getByTestId('tooltip')).toHaveAttribute(
        'data-content',
        'detailPanel.categoryTip.github',
      )
    })

    it('should show local tooltip', () => {
      render(<PluginSourceBadge source={PluginSource.local} />)

      expect(screen.getByTestId('tooltip')).toHaveAttribute(
        'data-content',
        'detailPanel.categoryTip.local',
      )
    })

    it('should show debugging tooltip', () => {
      render(<PluginSourceBadge source={PluginSource.debugging} />)

      expect(screen.getByTestId('tooltip')).toHaveAttribute(
        'data-content',
        'detailPanel.categoryTip.debugging',
      )
    })
  })

  describe('Icon Element Structure', () => {
    it('should render icon inside tooltip for marketplace', () => {
      render(<PluginSourceBadge source={PluginSource.marketplace} />)

      const tooltip = screen.getByTestId('tooltip')
      const iconWrapper = tooltip.querySelector('div')
      expect(iconWrapper).toBeInTheDocument()
    })

    it('should render icon inside tooltip for github', () => {
      render(<PluginSourceBadge source={PluginSource.github} />)

      const tooltip = screen.getByTestId('tooltip')
      const iconWrapper = tooltip.querySelector('div')
      expect(iconWrapper).toBeInTheDocument()
    })

    it('should render icon inside tooltip for local', () => {
      render(<PluginSourceBadge source={PluginSource.local} />)

      const tooltip = screen.getByTestId('tooltip')
      const iconWrapper = tooltip.querySelector('div')
      expect(iconWrapper).toBeInTheDocument()
    })

    it('should render icon inside tooltip for debugging', () => {
      render(<PluginSourceBadge source={PluginSource.debugging} />)

      const tooltip = screen.getByTestId('tooltip')
      const iconWrapper = tooltip.querySelector('div')
      expect(iconWrapper).toBeInTheDocument()
    })
  })

  describe('Lookup Table Coverage', () => {
    it('should handle all PluginSource enum values', () => {
      const allSources = Object.values(PluginSource)

      allSources.forEach((source) => {
        const { container } = render(<PluginSourceBadge source={source} />)
        // Should render either tooltip or nothing
        expect(container).toBeTruthy()
      })
    })
  })

  describe('Invalid Source Handling', () => {
    it('should return null for unknown source type', () => {
      // Use type assertion to test invalid source value
      const invalidSource = 'unknown_source' as PluginSource
      const { container } = render(<PluginSourceBadge source={invalidSource} />)

      // Should render nothing (empty container)
      expect(container.firstChild).toBeNull()
    })

    it('should not render separator for invalid source', () => {
      const invalidSource = 'invalid' as PluginSource
      const { container } = render(<PluginSourceBadge source={invalidSource} />)

      const separator = container.querySelector('.text-text-quaternary')
      expect(separator).not.toBeInTheDocument()
    })

    it('should not render tooltip for invalid source', () => {
      const invalidSource = '' as PluginSource
      render(<PluginSourceBadge source={invalidSource} />)

      expect(screen.queryByTestId('tooltip')).not.toBeInTheDocument()
    })
  })
})
@@ -1,59 +0,0 @@
'use client'

import type { FC, ReactNode } from 'react'
import {
  RiBugLine,
  RiHardDrive3Line,
} from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { Github } from '@/app/components/base/icons/src/public/common'
import { BoxSparkleFill } from '@/app/components/base/icons/src/vender/plugin'
import Tooltip from '@/app/components/base/tooltip'
import { PluginSource } from '../../../types'

type SourceConfig = {
  icon: ReactNode
  tipKey: string
}

type PluginSourceBadgeProps = {
  source: PluginSource
}

const SOURCE_CONFIG_MAP: Record<PluginSource, SourceConfig | null> = {
  [PluginSource.marketplace]: {
    icon: <BoxSparkleFill className="h-3.5 w-3.5 text-text-tertiary hover:text-text-accent" />,
    tipKey: 'detailPanel.categoryTip.marketplace',
  },
  [PluginSource.github]: {
    icon: <Github className="h-3.5 w-3.5 text-text-secondary hover:text-text-primary" />,
    tipKey: 'detailPanel.categoryTip.github',
  },
  [PluginSource.local]: {
    icon: <RiHardDrive3Line className="h-3.5 w-3.5 text-text-tertiary" />,
    tipKey: 'detailPanel.categoryTip.local',
  },
  [PluginSource.debugging]: {
    icon: <RiBugLine className="h-3.5 w-3.5 text-text-tertiary hover:text-text-warning" />,
    tipKey: 'detailPanel.categoryTip.debugging',
  },
}

const PluginSourceBadge: FC<PluginSourceBadgeProps> = ({ source }) => {
  const { t } = useTranslation()

  const config = SOURCE_CONFIG_MAP[source]
  if (!config)
    return null

  return (
    <>
      <div className="system-xs-regular ml-1 mr-0.5 text-text-quaternary">·</div>
      <Tooltip popupContent={t(config.tipKey as never, { ns: 'plugin' })}>
        <div>{config.icon}</div>
      </Tooltip>
    </>
  )
}

export default PluginSourceBadge
@@ -1,3 +0,0 @@
export { useDetailHeaderState } from './use-detail-header-state'
export type { ModalStates, UseDetailHeaderStateReturn, VersionPickerState, VersionTarget } from './use-detail-header-state'
export { usePluginOperations } from './use-plugin-operations'
@@ -1,409 +0,0 @@
import type { PluginDetail } from '../../../types'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { PluginSource } from '../../../types'
import { useDetailHeaderState } from './use-detail-header-state'

let mockEnableMarketplace = true
vi.mock('@/context/global-public-context', () => ({
  useGlobalPublicStore: (selector: (state: { systemFeatures: { enable_marketplace: boolean } }) => unknown) =>
    selector({ systemFeatures: { enable_marketplace: mockEnableMarketplace } }),
}))

let mockAutoUpgradeInfo: {
  strategy_setting: string
  upgrade_mode: string
  include_plugins: string[]
  exclude_plugins: string[]
  upgrade_time_of_day: number
} | null = null

vi.mock('../../../plugin-page/use-reference-setting', () => ({
  default: () => ({
    referenceSetting: mockAutoUpgradeInfo ? { auto_upgrade: mockAutoUpgradeInfo } : null,
  }),
}))

vi.mock('../../../reference-setting-modal/auto-update-setting/types', () => ({
  AUTO_UPDATE_MODE: {
    update_all: 'update_all',
    partial: 'partial',
    exclude: 'exclude',
  },
}))

const createPluginDetail = (overrides: Partial<PluginDetail> = {}): PluginDetail => ({
  id: 'test-id',
  created_at: '2024-01-01',
  updated_at: '2024-01-02',
  name: 'Test Plugin',
  plugin_id: 'test-plugin',
  plugin_unique_identifier: 'test-uid',
  declaration: {
    author: 'test-author',
    name: 'test-plugin-name',
    category: 'tool',
    label: { en_US: 'Test Plugin Label' },
    description: { en_US: 'Test description' },
    icon: 'icon.png',
    verified: true,
  } as unknown as PluginDetail['declaration'],
  installation_id: 'install-1',
  tenant_id: 'tenant-1',
  endpoints_setups: 0,
  endpoints_active: 0,
  version: '1.0.0',
  latest_version: '1.0.0',
  latest_unique_identifier: 'test-uid',
  source: PluginSource.marketplace,
  meta: undefined,
  status: 'active',
  deprecated_reason: '',
  alternative_plugin_id: '',
  ...overrides,
})

describe('useDetailHeaderState', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    mockAutoUpgradeInfo = null
    mockEnableMarketplace = true
  })

  describe('Source Type Detection', () => {
    it('should detect marketplace source', () => {
      const detail = createPluginDetail({ source: PluginSource.marketplace })
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.isFromMarketplace).toBe(true)
      expect(result.current.isFromGitHub).toBe(false)
    })

    it('should detect GitHub source', () => {
      const detail = createPluginDetail({
        source: PluginSource.github,
        meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
      })
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.isFromGitHub).toBe(true)
      expect(result.current.isFromMarketplace).toBe(false)
    })

    it('should detect local source', () => {
      const detail = createPluginDetail({ source: PluginSource.local })
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.isFromGitHub).toBe(false)
      expect(result.current.isFromMarketplace).toBe(false)
    })
  })

  describe('Version State', () => {
    it('should detect new version available for marketplace plugin', () => {
      const detail = createPluginDetail({
        version: '1.0.0',
        latest_version: '2.0.0',
        source: PluginSource.marketplace,
      })
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.hasNewVersion).toBe(true)
    })

    it('should not detect new version when versions match', () => {
      const detail = createPluginDetail({
        version: '1.0.0',
        latest_version: '1.0.0',
        source: PluginSource.marketplace,
      })
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.hasNewVersion).toBe(false)
    })

    it('should not detect new version for non-marketplace source', () => {
      const detail = createPluginDetail({
        version: '1.0.0',
        latest_version: '2.0.0',
        source: PluginSource.github,
        meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
      })
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.hasNewVersion).toBe(false)
    })

    it('should not detect new version when latest_version is empty', () => {
      const detail = createPluginDetail({
        version: '1.0.0',
        latest_version: '',
        source: PluginSource.marketplace,
      })
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.hasNewVersion).toBe(false)
    })
  })

  describe('Version Picker State', () => {
    it('should initialize version picker as hidden', () => {
      const detail = createPluginDetail()
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.versionPicker.isShow).toBe(false)
    })

    it('should toggle version picker visibility', () => {
      const detail = createPluginDetail()
      const { result } = renderHook(() => useDetailHeaderState(detail))

      act(() => {
        result.current.versionPicker.setIsShow(true)
      })
      expect(result.current.versionPicker.isShow).toBe(true)

      act(() => {
        result.current.versionPicker.setIsShow(false)
      })
      expect(result.current.versionPicker.isShow).toBe(false)
    })

    it('should update target version', () => {
      const detail = createPluginDetail()
      const { result } = renderHook(() => useDetailHeaderState(detail))

      act(() => {
        result.current.versionPicker.setTargetVersion({
          version: '2.0.0',
          unique_identifier: 'new-uid',
        })
      })

      expect(result.current.versionPicker.targetVersion.version).toBe('2.0.0')
      expect(result.current.versionPicker.targetVersion.unique_identifier).toBe('new-uid')
    })

    it('should set isDowngrade when provided in target version', () => {
      const detail = createPluginDetail()
      const { result } = renderHook(() => useDetailHeaderState(detail))

      act(() => {
        result.current.versionPicker.setTargetVersion({
          version: '0.5.0',
          unique_identifier: 'old-uid',
          isDowngrade: true,
        })
      })

      expect(result.current.versionPicker.isDowngrade).toBe(true)
    })
  })

  describe('Modal States', () => {
    it('should initialize all modals as hidden', () => {
      const detail = createPluginDetail()
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.modalStates.isShowUpdateModal).toBe(false)
      expect(result.current.modalStates.isShowPluginInfo).toBe(false)
      expect(result.current.modalStates.isShowDeleteConfirm).toBe(false)
      expect(result.current.modalStates.deleting).toBe(false)
    })

    it('should toggle update modal', () => {
      const detail = createPluginDetail()
      const { result } = renderHook(() => useDetailHeaderState(detail))

      act(() => {
        result.current.modalStates.showUpdateModal()
      })
      expect(result.current.modalStates.isShowUpdateModal).toBe(true)

      act(() => {
        result.current.modalStates.hideUpdateModal()
      })
      expect(result.current.modalStates.isShowUpdateModal).toBe(false)
    })

    it('should toggle plugin info modal', () => {
      const detail = createPluginDetail()
      const { result } = renderHook(() => useDetailHeaderState(detail))

      act(() => {
        result.current.modalStates.showPluginInfo()
      })
      expect(result.current.modalStates.isShowPluginInfo).toBe(true)

      act(() => {
        result.current.modalStates.hidePluginInfo()
      })
      expect(result.current.modalStates.isShowPluginInfo).toBe(false)
    })

    it('should toggle delete confirm modal', () => {
      const detail = createPluginDetail()
      const { result } = renderHook(() => useDetailHeaderState(detail))

      act(() => {
        result.current.modalStates.showDeleteConfirm()
      })
      expect(result.current.modalStates.isShowDeleteConfirm).toBe(true)

      act(() => {
        result.current.modalStates.hideDeleteConfirm()
      })
      expect(result.current.modalStates.isShowDeleteConfirm).toBe(false)
    })

    it('should toggle deleting state', () => {
      const detail = createPluginDetail()
      const { result } = renderHook(() => useDetailHeaderState(detail))

      act(() => {
        result.current.modalStates.showDeleting()
      })
      expect(result.current.modalStates.deleting).toBe(true)

      act(() => {
        result.current.modalStates.hideDeleting()
      })
      expect(result.current.modalStates.deleting).toBe(false)
    })
  })

  describe('Auto Upgrade Detection', () => {
    it('should disable auto upgrade when marketplace is disabled', () => {
      mockEnableMarketplace = false
      mockAutoUpgradeInfo = {
        strategy_setting: 'enabled',
        upgrade_mode: 'update_all',
        include_plugins: [],
        exclude_plugins: [],
        upgrade_time_of_day: 36000,
      }

      const detail = createPluginDetail({ source: PluginSource.marketplace })
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.isAutoUpgradeEnabled).toBe(false)
    })

    it('should disable auto upgrade when strategy is disabled', () => {
      mockAutoUpgradeInfo = {
        strategy_setting: 'disabled',
        upgrade_mode: 'update_all',
        include_plugins: [],
        exclude_plugins: [],
        upgrade_time_of_day: 36000,
      }

      const detail = createPluginDetail({ source: PluginSource.marketplace })
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.isAutoUpgradeEnabled).toBe(false)
    })

    it('should enable auto upgrade for update_all mode', () => {
      mockAutoUpgradeInfo = {
        strategy_setting: 'enabled',
        upgrade_mode: 'update_all',
        include_plugins: [],
        exclude_plugins: [],
        upgrade_time_of_day: 36000,
      }

      const detail = createPluginDetail({ source: PluginSource.marketplace })
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.isAutoUpgradeEnabled).toBe(true)
    })

    it('should enable auto upgrade for partial mode when plugin is included', () => {
      mockAutoUpgradeInfo = {
        strategy_setting: 'enabled',
        upgrade_mode: 'partial',
        include_plugins: ['test-plugin'],
        exclude_plugins: [],
        upgrade_time_of_day: 36000,
      }

      const detail = createPluginDetail({ source: PluginSource.marketplace })
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.isAutoUpgradeEnabled).toBe(true)
    })

    it('should disable auto upgrade for partial mode when plugin is not included', () => {
      mockAutoUpgradeInfo = {
        strategy_setting: 'enabled',
        upgrade_mode: 'partial',
        include_plugins: ['other-plugin'],
        exclude_plugins: [],
        upgrade_time_of_day: 36000,
      }

      const detail = createPluginDetail({ source: PluginSource.marketplace })
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.isAutoUpgradeEnabled).toBe(false)
    })

    it('should enable auto upgrade for exclude mode when plugin is not excluded', () => {
      mockAutoUpgradeInfo = {
        strategy_setting: 'enabled',
        upgrade_mode: 'exclude',
        include_plugins: [],
        exclude_plugins: ['other-plugin'],
        upgrade_time_of_day: 36000,
      }

      const detail = createPluginDetail({ source: PluginSource.marketplace })
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.isAutoUpgradeEnabled).toBe(true)
    })

    it('should disable auto upgrade for exclude mode when plugin is excluded', () => {
      mockAutoUpgradeInfo = {
        strategy_setting: 'enabled',
        upgrade_mode: 'exclude',
        include_plugins: [],
        exclude_plugins: ['test-plugin'],
        upgrade_time_of_day: 36000,
      }

      const detail = createPluginDetail({ source: PluginSource.marketplace })
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.isAutoUpgradeEnabled).toBe(false)
    })

    it('should disable auto upgrade for non-marketplace source', () => {
      mockAutoUpgradeInfo = {
        strategy_setting: 'enabled',
        upgrade_mode: 'update_all',
        include_plugins: [],
        exclude_plugins: [],
        upgrade_time_of_day: 36000,
      }

      const detail = createPluginDetail({
        source: PluginSource.github,
        meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
      })
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.isAutoUpgradeEnabled).toBe(false)
    })

    it('should disable auto upgrade when no auto upgrade info', () => {
      mockAutoUpgradeInfo = null

      const detail = createPluginDetail({ source: PluginSource.marketplace })
      const { result } = renderHook(() => useDetailHeaderState(detail))

      expect(result.current.isAutoUpgradeEnabled).toBe(false)
    })
  })
})
@@ -1,132 +0,0 @@
'use client'

import type { PluginDetail } from '../../../types'
import { useBoolean } from 'ahooks'
import { useCallback, useMemo, useState } from 'react'
import { useGlobalPublicStore } from '@/context/global-public-context'
import useReferenceSetting from '../../../plugin-page/use-reference-setting'
import { AUTO_UPDATE_MODE } from '../../../reference-setting-modal/auto-update-setting/types'
import { PluginSource } from '../../../types'

export type VersionTarget = {
  version: string | undefined
  unique_identifier: string | undefined
  isDowngrade?: boolean
}

export type ModalStates = {
  isShowUpdateModal: boolean
  showUpdateModal: () => void
  hideUpdateModal: () => void
  isShowPluginInfo: boolean
  showPluginInfo: () => void
  hidePluginInfo: () => void
  isShowDeleteConfirm: boolean
  showDeleteConfirm: () => void
  hideDeleteConfirm: () => void
  deleting: boolean
  showDeleting: () => void
  hideDeleting: () => void
}

export type VersionPickerState = {
  isShow: boolean
  setIsShow: (show: boolean) => void
  targetVersion: VersionTarget
  setTargetVersion: (version: VersionTarget) => void
  isDowngrade: boolean
  setIsDowngrade: (downgrade: boolean) => void
}

export type UseDetailHeaderStateReturn = {
  modalStates: ModalStates
  versionPicker: VersionPickerState
  hasNewVersion: boolean
  isAutoUpgradeEnabled: boolean
  isFromGitHub: boolean
  isFromMarketplace: boolean
}

export const useDetailHeaderState = (detail: PluginDetail): UseDetailHeaderStateReturn => {
  const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
  const { referenceSetting } = useReferenceSetting()

  const {
    source,
    version,
    latest_version,
    latest_unique_identifier,
    plugin_id,
  } = detail

  const isFromGitHub = source === PluginSource.github
  const isFromMarketplace = source === PluginSource.marketplace
  const [isShow, setIsShow] = useState(false)
  const [targetVersion, setTargetVersion] = useState<VersionTarget>({
    version: latest_version,
    unique_identifier: latest_unique_identifier,
  })
  const [isDowngrade, setIsDowngrade] = useState(false)

  const [isShowUpdateModal, { setTrue: showUpdateModal, setFalse: hideUpdateModal }] = useBoolean(false)
  const [isShowPluginInfo, { setTrue: showPluginInfo, setFalse: hidePluginInfo }] = useBoolean(false)
  const [isShowDeleteConfirm, { setTrue: showDeleteConfirm, setFalse: hideDeleteConfirm }] = useBoolean(false)
  const [deleting, { setTrue: showDeleting, setFalse: hideDeleting }] = useBoolean(false)

  const hasNewVersion = useMemo(() => {
    if (isFromMarketplace)
      return !!latest_version && latest_version !== version
    return false
  }, [isFromMarketplace, latest_version, version])

  const { auto_upgrade: autoUpgradeInfo } = referenceSetting || {}

  const isAutoUpgradeEnabled = useMemo(() => {
    if (!enable_marketplace || !autoUpgradeInfo || !isFromMarketplace)
      return false
    if (autoUpgradeInfo.strategy_setting === 'disabled')
      return false
    if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.update_all)
      return true
    if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.partial && autoUpgradeInfo.include_plugins.includes(plugin_id))
      return true
    if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.exclude && !autoUpgradeInfo.exclude_plugins.includes(plugin_id))
      return true
    return false
  }, [autoUpgradeInfo, plugin_id, isFromMarketplace, enable_marketplace])

  const handleSetTargetVersion = useCallback((version: VersionTarget) => {
    setTargetVersion(version)
    if (version.isDowngrade !== undefined)
      setIsDowngrade(version.isDowngrade)
  }, [])

  return {
    modalStates: {
      isShowUpdateModal,
      showUpdateModal,
      hideUpdateModal,
      isShowPluginInfo,
      showPluginInfo,
      hidePluginInfo,
      isShowDeleteConfirm,
      showDeleteConfirm,
      hideDeleteConfirm,
      deleting,
      showDeleting,
      hideDeleting,
    },
    versionPicker: {
      isShow,
      setIsShow,
      targetVersion,
      setTargetVersion: handleSetTargetVersion,
      isDowngrade,
      setIsDowngrade,
    },
    hasNewVersion,
    isAutoUpgradeEnabled,
    isFromGitHub,
    isFromMarketplace,
  }
}
@@ -1,549 +0,0 @@
import type { PluginDetail } from '../../../types'
import type { ModalStates, VersionTarget } from './use-detail-header-state'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import * as amplitude from '@/app/components/base/amplitude'
import Toast from '@/app/components/base/toast'
import { PluginSource } from '../../../types'
import { usePluginOperations } from './use-plugin-operations'

type VersionPickerMock = {
  setTargetVersion: (version: VersionTarget) => void
  setIsDowngrade: (downgrade: boolean) => void
}

const {
  mockSetShowUpdatePluginModal,
  mockRefreshModelProviders,
  mockInvalidateAllToolProviders,
  mockUninstallPlugin,
  mockFetchReleases,
  mockCheckForUpdates,
} = vi.hoisted(() => {
  return {
    mockSetShowUpdatePluginModal: vi.fn(),
    mockRefreshModelProviders: vi.fn(),
    mockInvalidateAllToolProviders: vi.fn(),
    mockUninstallPlugin: vi.fn(() => Promise.resolve({ success: true })),
    mockFetchReleases: vi.fn(() => Promise.resolve([{ tag_name: 'v2.0.0' }])),
    mockCheckForUpdates: vi.fn(() => ({ needUpdate: true, toastProps: { type: 'success', message: 'Update available' } })),
  }
})

vi.mock('@/context/modal-context', () => ({
  useModalContext: () => ({
    setShowUpdatePluginModal: mockSetShowUpdatePluginModal,
  }),
}))

vi.mock('@/context/provider-context', () => ({
  useProviderContext: () => ({
    refreshModelProviders: mockRefreshModelProviders,
  }),
}))

vi.mock('@/service/plugins', () => ({
  uninstallPlugin: mockUninstallPlugin,
}))

vi.mock('@/service/use-tools', () => ({
  useInvalidateAllToolProviders: () => mockInvalidateAllToolProviders,
}))

vi.mock('../../../install-plugin/hooks', () => ({
  useGitHubReleases: () => ({
    checkForUpdates: mockCheckForUpdates,
    fetchReleases: mockFetchReleases,
  }),
}))

const createPluginDetail = (overrides: Partial<PluginDetail> = {}): PluginDetail => ({
  id: 'test-id',
  created_at: '2024-01-01',
  updated_at: '2024-01-02',
  name: 'Test Plugin',
  plugin_id: 'test-plugin',
  plugin_unique_identifier: 'test-uid',
  declaration: {
    author: 'test-author',
    name: 'test-plugin-name',
    category: 'tool',
    label: { en_US: 'Test Plugin Label' },
    description: { en_US: 'Test description' },
    icon: 'icon.png',
    verified: true,
  } as unknown as PluginDetail['declaration'],
  installation_id: 'install-1',
  tenant_id: 'tenant-1',
  endpoints_setups: 0,
  endpoints_active: 0,
  version: '1.0.0',
  latest_version: '2.0.0',
  latest_unique_identifier: 'new-uid',
  source: PluginSource.marketplace,
  meta: undefined,
  status: 'active',
  deprecated_reason: '',
  alternative_plugin_id: '',
  ...overrides,
})

const createModalStatesMock = (): ModalStates => ({
  isShowUpdateModal: false,
  showUpdateModal: vi.fn(),
  hideUpdateModal: vi.fn(),
  isShowPluginInfo: false,
  showPluginInfo: vi.fn(),
  hidePluginInfo: vi.fn(),
  isShowDeleteConfirm: false,
  showDeleteConfirm: vi.fn(),
  hideDeleteConfirm: vi.fn(),
  deleting: false,
  showDeleting: vi.fn(),
  hideDeleting: vi.fn(),
})

const createVersionPickerMock = (): VersionPickerMock => ({
  setTargetVersion: vi.fn<(version: VersionTarget) => void>(),
  setIsDowngrade: vi.fn<(downgrade: boolean) => void>(),
})

describe('usePluginOperations', () => {
  let modalStates: ModalStates
  let versionPicker: VersionPickerMock
  let mockOnUpdate: (isDelete?: boolean) => void

  beforeEach(() => {
    vi.clearAllMocks()
    modalStates = createModalStatesMock()
    versionPicker = createVersionPickerMock()
    mockOnUpdate = vi.fn()
    vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() }))
    vi.spyOn(amplitude, 'trackEvent').mockImplementation(() => {})
  })

  describe('Marketplace Update Flow', () => {
    it('should show update modal for marketplace plugin', async () => {
      const detail = createPluginDetail({ source: PluginSource.marketplace })
      const { result } = renderHook(() =>
        usePluginOperations({
          detail,
          modalStates,
          versionPicker,
          isFromMarketplace: true,
          onUpdate: mockOnUpdate,
        }),
      )

      await act(async () => {
        await result.current.handleUpdate()
      })

      expect(modalStates.showUpdateModal).toHaveBeenCalled()
    })

    it('should set isDowngrade when downgrading', async () => {
      const detail = createPluginDetail({ source: PluginSource.marketplace })
      const { result } = renderHook(() =>
        usePluginOperations({
          detail,
          modalStates,
          versionPicker,
          isFromMarketplace: true,
          onUpdate: mockOnUpdate,
        }),
      )

      await act(async () => {
        await result.current.handleUpdate(true)
      })

      expect(versionPicker.setIsDowngrade).toHaveBeenCalledWith(true)
      expect(modalStates.showUpdateModal).toHaveBeenCalled()
    })

    it('should call onUpdate and hide modal on successful marketplace update', () => {
      const detail = createPluginDetail({ source: PluginSource.marketplace })
      const { result } = renderHook(() =>
        usePluginOperations({
          detail,
          modalStates,
          versionPicker,
          isFromMarketplace: true,
          onUpdate: mockOnUpdate,
        }),
      )

      act(() => {
        result.current.handleUpdatedFromMarketplace()
      })

      expect(mockOnUpdate).toHaveBeenCalled()
      expect(modalStates.hideUpdateModal).toHaveBeenCalled()
    })
  })

  describe('GitHub Update Flow', () => {
    it('should fetch releases from GitHub', async () => {
      const detail = createPluginDetail({
        source: PluginSource.github,
        meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
      })
      const { result } = renderHook(() =>
        usePluginOperations({
          detail,
          modalStates,
          versionPicker,
          isFromMarketplace: false,
          onUpdate: mockOnUpdate,
        }),
      )

      await act(async () => {
        await result.current.handleUpdate()
      })

      expect(mockFetchReleases).toHaveBeenCalledWith('owner', 'repo')
    })

    it('should check for updates after fetching releases', async () => {
      const detail = createPluginDetail({
        source: PluginSource.github,
        meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
      })
      const { result } = renderHook(() =>
        usePluginOperations({
          detail,
          modalStates,
          versionPicker,
          isFromMarketplace: false,
          onUpdate: mockOnUpdate,
        }),
      )

      await act(async () => {
        await result.current.handleUpdate()
      })

      expect(mockCheckForUpdates).toHaveBeenCalled()
      expect(Toast.notify).toHaveBeenCalled()
    })

    it('should show update plugin modal when update is needed', async () => {
      const detail = createPluginDetail({
        source: PluginSource.github,
        meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
      })
      const { result } = renderHook(() =>
        usePluginOperations({
          detail,
          modalStates,
          versionPicker,
          isFromMarketplace: false,
          onUpdate: mockOnUpdate,
        }),
      )

      await act(async () => {
        await result.current.handleUpdate()
      })

      expect(mockSetShowUpdatePluginModal).toHaveBeenCalled()
    })

    it('should not show modal when no releases found', async () => {
      mockFetchReleases.mockResolvedValueOnce([])
      const detail = createPluginDetail({
        source: PluginSource.github,
        meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
      })
      const { result } = renderHook(() =>
        usePluginOperations({
          detail,
          modalStates,
          versionPicker,
          isFromMarketplace: false,
          onUpdate: mockOnUpdate,
        }),
      )

      await act(async () => {
        await result.current.handleUpdate()
      })

      expect(mockSetShowUpdatePluginModal).not.toHaveBeenCalled()
    })

    it('should not show modal when no update needed', async () => {
      mockCheckForUpdates.mockReturnValueOnce({
        needUpdate: false,
        toastProps: { type: 'info', message: 'Already up to date' },
      })
      const detail = createPluginDetail({
        source: PluginSource.github,
        meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
      })
      const { result } = renderHook(() =>
        usePluginOperations({
          detail,
          modalStates,
          versionPicker,
          isFromMarketplace: false,
          onUpdate: mockOnUpdate,
        }),
      )

      await act(async () => {
        await result.current.handleUpdate()
      })

      expect(mockSetShowUpdatePluginModal).not.toHaveBeenCalled()
    })

    it('should use author and name as fallback for repo parsing', async () => {
      const detail = createPluginDetail({
        source: PluginSource.github,
        meta: { repo: '/', version: 'v1.0.0', package: 'pkg' },
        declaration: {
          author: 'fallback-author',
          name: 'fallback-name',
          category: 'tool',
          label: { en_US: 'Test' },
          description: { en_US: 'Test' },
          icon: 'icon.png',
          verified: true,
        } as unknown as PluginDetail['declaration'],
      })
      const { result } = renderHook(() =>
        usePluginOperations({
          detail,
          modalStates,
          versionPicker,
          isFromMarketplace: false,
          onUpdate: mockOnUpdate,
        }),
      )

      await act(async () => {
        await result.current.handleUpdate()
      })

      expect(mockFetchReleases).toHaveBeenCalledWith('fallback-author', 'fallback-name')
    })
  })

  describe('Delete Flow', () => {
    it('should call uninstallPlugin with correct id', async () => {
      const detail = createPluginDetail({ id: 'plugin-to-delete' })
      const { result } = renderHook(() =>
        usePluginOperations({
          detail,
          modalStates,
          versionPicker,
          isFromMarketplace: true,
          onUpdate: mockOnUpdate,
        }),
      )

      await act(async () => {
        await result.current.handleDelete()
      })

      expect(mockUninstallPlugin).toHaveBeenCalledWith('plugin-to-delete')
|
||||
})
|
||||
|
||||
it('should show and hide deleting state during delete', async () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleDelete()
|
||||
})
|
||||
|
||||
expect(modalStates.showDeleting).toHaveBeenCalled()
|
||||
expect(modalStates.hideDeleting).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onUpdate with true after successful delete', async () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleDelete()
|
||||
})
|
||||
|
||||
expect(mockOnUpdate).toHaveBeenCalledWith(true)
|
||||
})
|
||||
|
||||
it('should hide delete confirm after successful delete', async () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleDelete()
|
||||
})
|
||||
|
||||
expect(modalStates.hideDeleteConfirm).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should refresh model providers when deleting model plugin', async () => {
|
||||
const detail = createPluginDetail({
|
||||
declaration: {
|
||||
author: 'test-author',
|
||||
name: 'test-plugin-name',
|
||||
category: 'model',
|
||||
label: { en_US: 'Test' },
|
||||
description: { en_US: 'Test' },
|
||||
icon: 'icon.png',
|
||||
verified: true,
|
||||
} as unknown as PluginDetail['declaration'],
|
||||
})
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleDelete()
|
||||
})
|
||||
|
||||
expect(mockRefreshModelProviders).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should invalidate tool providers when deleting tool plugin', async () => {
|
||||
const detail = createPluginDetail({
|
||||
declaration: {
|
||||
author: 'test-author',
|
||||
name: 'test-plugin-name',
|
||||
category: 'tool',
|
||||
label: { en_US: 'Test' },
|
||||
description: { en_US: 'Test' },
|
||||
icon: 'icon.png',
|
||||
verified: true,
|
||||
} as unknown as PluginDetail['declaration'],
|
||||
})
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleDelete()
|
||||
})
|
||||
|
||||
expect(mockInvalidateAllToolProviders).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should track plugin uninstalled event', async () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleDelete()
|
||||
})
|
||||
|
||||
expect(amplitude.trackEvent).toHaveBeenCalledWith('plugin_uninstalled', expect.objectContaining({
|
||||
plugin_id: 'test-plugin',
|
||||
plugin_name: 'test-plugin-name',
|
||||
}))
|
||||
})
|
||||
|
||||
it('should not call onUpdate when delete fails', async () => {
|
||||
mockUninstallPlugin.mockResolvedValueOnce({ success: false })
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleDelete()
|
||||
})
|
||||
|
||||
expect(mockOnUpdate).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Optional onUpdate Callback', () => {
|
||||
it('should not throw when onUpdate is not provided for marketplace update', () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
}),
|
||||
)
|
||||
|
||||
expect(() => {
|
||||
result.current.handleUpdatedFromMarketplace()
|
||||
}).not.toThrow()
|
||||
})
|
||||
|
||||
it('should not throw when onUpdate is not provided for delete', async () => {
|
||||
const detail = createPluginDetail()
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: true,
|
||||
}),
|
||||
)
|
||||
|
||||
await expect(
|
||||
act(async () => {
|
||||
await result.current.handleDelete()
|
||||
}),
|
||||
).resolves.not.toThrow()
|
||||
})
|
||||
})
|
||||
})
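For readers of this excerpt: the factory helpers used throughout these tests (createPluginDetail, createModalStatesMock, createVersionPickerMock) are defined earlier in the spec file, outside the lines shown here. A minimal sketch of plausible shapes, inferred only from the call sites above — the bodies below are illustrative assumptions, not the repository's actual definitions:

// Illustrative sketch only: shapes inferred from how the tests above use these mocks.
// The real factories live earlier in the spec file and may carry more fields.
const createModalStatesMock = () => ({
  showUpdateModal: vi.fn(),
  hideUpdateModal: vi.fn(),
  showDeleting: vi.fn(),
  hideDeleting: vi.fn(),
  showDeleteConfirm: vi.fn(),
  hideDeleteConfirm: vi.fn(),
  showPluginInfo: vi.fn(),
})

const createVersionPickerMock = () => ({
  setTargetVersion: vi.fn(),
  setIsDowngrade: vi.fn(),
})

// createPluginDetail(overrides) is assumed to merge overrides into a baseline
// PluginDetail whose plugin_id is 'test-plugin' and whose declaration.name is
// 'test-plugin-name', matching the trackEvent expectations above.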

@@ -1,143 +0,0 @@
'use client'

import type { PluginDetail } from '../../../types'
import type { ModalStates, VersionTarget } from './use-detail-header-state'
import { useCallback } from 'react'
import { trackEvent } from '@/app/components/base/amplitude'
import Toast from '@/app/components/base/toast'
import { useModalContext } from '@/context/modal-context'
import { useProviderContext } from '@/context/provider-context'
import { uninstallPlugin } from '@/service/plugins'
import { useInvalidateAllToolProviders } from '@/service/use-tools'
import { useGitHubReleases } from '../../../install-plugin/hooks'
import { PluginCategoryEnum, PluginSource } from '../../../types'

type UsePluginOperationsParams = {
  detail: PluginDetail
  modalStates: ModalStates
  versionPicker: {
    setTargetVersion: (version: VersionTarget) => void
    setIsDowngrade: (downgrade: boolean) => void
  }
  isFromMarketplace: boolean
  onUpdate?: (isDelete?: boolean) => void
}

type UsePluginOperationsReturn = {
  handleUpdate: (isDowngrade?: boolean) => Promise<void>
  handleUpdatedFromMarketplace: () => void
  handleDelete: () => Promise<void>
}

export const usePluginOperations = ({
  detail,
  modalStates,
  versionPicker,
  isFromMarketplace,
  onUpdate,
}: UsePluginOperationsParams): UsePluginOperationsReturn => {
  const { checkForUpdates, fetchReleases } = useGitHubReleases()
  const { setShowUpdatePluginModal } = useModalContext()
  const { refreshModelProviders } = useProviderContext()
  const invalidateAllToolProviders = useInvalidateAllToolProviders()

  const { id, meta, plugin_id } = detail
  const { author, category, name } = detail.declaration || detail

  const handleUpdate = useCallback(async (isDowngrade?: boolean) => {
    if (isFromMarketplace) {
      versionPicker.setIsDowngrade(!!isDowngrade)
      modalStates.showUpdateModal()
      return
    }

    if (!meta?.repo || !meta?.version || !meta?.package) {
      Toast.notify({
        type: 'error',
        message: 'Missing plugin metadata for GitHub update',
      })
      return
    }

    const owner = meta.repo.split('/')[0] || author
    const repo = meta.repo.split('/')[1] || name
    const fetchedReleases = await fetchReleases(owner, repo)
    if (fetchedReleases.length === 0)
      return

    const { needUpdate, toastProps } = checkForUpdates(fetchedReleases, meta.version)
    Toast.notify(toastProps)

    if (needUpdate) {
      setShowUpdatePluginModal({
        onSaveCallback: () => {
          onUpdate?.()
        },
        payload: {
          type: PluginSource.github,
          category,
          github: {
            originalPackageInfo: {
              id: detail.plugin_unique_identifier,
              repo: meta.repo,
              version: meta.version,
              package: meta.package,
              releases: fetchedReleases,
            },
          },
        },
      })
    }
  }, [
    isFromMarketplace,
    meta,
    author,
    name,
    fetchReleases,
    checkForUpdates,
    setShowUpdatePluginModal,
    detail,
    onUpdate,
    modalStates,
    versionPicker,
  ])

  const handleUpdatedFromMarketplace = useCallback(() => {
    onUpdate?.()
    modalStates.hideUpdateModal()
  }, [onUpdate, modalStates])

  const handleDelete = useCallback(async () => {
    modalStates.showDeleting()
    const res = await uninstallPlugin(id)
    modalStates.hideDeleting()

    if (res.success) {
      modalStates.hideDeleteConfirm()
      onUpdate?.(true)

      if (PluginCategoryEnum.model.includes(category))
        refreshModelProviders()

      if (PluginCategoryEnum.tool.includes(category))
        invalidateAllToolProviders()

      trackEvent('plugin_uninstalled', { plugin_id, plugin_name: name })
    }
  }, [
    id,
    category,
    plugin_id,
    name,
    modalStates,
    onUpdate,
    refreshModelProviders,
    invalidateAllToolProviders,
  ])

  return {
    handleUpdate,
    handleUpdatedFromMarketplace,
    handleDelete,
  }
}
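The hook above consumes useGitHubReleases without its definition appearing in this diff. For orientation, the contract it relies on can be sketched as below; the shape is inferred purely from the call sites here (fetchReleases(owner, repo), checkForUpdates(releases, version), Toast.notify(toastProps)) and is not copied from the hook's source:

// Inferred contract - illustrative only, not the actual exported types.
type GitHubRelease = Record<string, unknown> // release fields are not visible in this diff

type UseGitHubReleasesReturn = {
  fetchReleases: (owner: string, repo: string) => Promise<GitHubRelease[]>
  checkForUpdates: (releases: GitHubRelease[], currentVersion: string) => {
    needUpdate: boolean
    toastProps: { type: string, message: string } // e.g. { type: 'info', message: 'Already up to date' }
  }
}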

@@ -1,286 +0,0 @@
'use client'

import type { PluginDetail } from '../../types'
import {
  RiArrowLeftRightLine,
  RiCloseLine,
} from '@remixicon/react'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import ActionButton from '@/app/components/base/action-button'
import Badge from '@/app/components/base/badge'
import Button from '@/app/components/base/button'
import Tooltip from '@/app/components/base/tooltip'
import { AuthCategory, PluginAuth } from '@/app/components/plugins/plugin-auth'
import OperationDropdown from '@/app/components/plugins/plugin-detail-panel/operation-dropdown'
import PluginVersionPicker from '@/app/components/plugins/update-plugin/plugin-version-picker'
import { API_PREFIX } from '@/config'
import { useAppContext } from '@/context/app-context'
import { useGetLanguage, useLocale } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { useAllToolProviders } from '@/service/use-tools'
import { cn } from '@/utils/classnames'
import { getMarketplaceUrl } from '@/utils/var'
import { AutoUpdateLine } from '../../../base/icons/src/vender/system'
import Verified from '../../base/badges/verified'
import DeprecationNotice from '../../base/deprecation-notice'
import Icon from '../../card/base/card-icon'
import Description from '../../card/base/description'
import OrgInfo from '../../card/base/org-info'
import Title from '../../card/base/title'
import useReferenceSetting from '../../plugin-page/use-reference-setting'
import { convertUTCDaySecondsToLocalSeconds, timeOfDayToDayjs } from '../../reference-setting-modal/auto-update-setting/utils'
import { PluginCategoryEnum, PluginSource } from '../../types'
import { HeaderModals, PluginSourceBadge } from './components'
import { useDetailHeaderState, usePluginOperations } from './hooks'

type Props = {
  detail: PluginDetail
  isReadmeView?: boolean
  onHide?: () => void
  onUpdate?: (isDelete?: boolean) => void
}

const getIconSrc = (icon: string | undefined, iconDark: string | undefined, theme: string, tenantId: string): string => {
  const iconFileName = theme === 'dark' && iconDark ? iconDark : icon
  if (!iconFileName)
    return ''
  return iconFileName.startsWith('http')
    ? iconFileName
    : `${API_PREFIX}/workspaces/current/plugin/icon?tenant_id=${tenantId}&filename=${iconFileName}`
}

const getDetailUrl = (
  source: PluginSource,
  meta: PluginDetail['meta'],
  author: string,
  name: string,
  locale: string,
  theme: string,
): string => {
  if (source === PluginSource.github) {
    const repo = meta?.repo
    if (!repo)
      return ''
    return `https://github.com/${repo}`
  }
  if (source === PluginSource.marketplace)
    return getMarketplaceUrl(`/plugins/${author}/${name}`, { language: locale, theme })
  return ''
}

const DetailHeader = ({
  detail,
  isReadmeView = false,
  onHide,
  onUpdate,
}: Props) => {
  const { t } = useTranslation()
  const { userProfile: { timezone } } = useAppContext()
  const { theme } = useTheme()
  const locale = useGetLanguage()
  const currentLocale = useLocale()
  const { referenceSetting } = useReferenceSetting()

  const {
    source,
    tenant_id,
    version,
    latest_version,
    latest_unique_identifier,
    meta,
    plugin_id,
    status,
    deprecated_reason,
    alternative_plugin_id,
  } = detail

  const { author, category, name, label, description, icon, icon_dark, verified, tool } = detail.declaration || detail

  const {
    modalStates,
    versionPicker,
    hasNewVersion,
    isAutoUpgradeEnabled,
    isFromGitHub,
    isFromMarketplace,
  } = useDetailHeaderState(detail)

  const {
    handleUpdate,
    handleUpdatedFromMarketplace,
    handleDelete,
  } = usePluginOperations({
    detail,
    modalStates,
    versionPicker,
    isFromMarketplace,
    onUpdate,
  })

  const isTool = category === PluginCategoryEnum.tool
  const providerBriefInfo = tool?.identity
  const providerKey = `${plugin_id}/${providerBriefInfo?.name}`
  const { data: collectionList = [] } = useAllToolProviders(isTool)
  const provider = useMemo(() => {
    return collectionList.find(collection => collection.name === providerKey)
  }, [collectionList, providerKey])

  const iconSrc = getIconSrc(icon, icon_dark, theme, tenant_id)
  const detailUrl = getDetailUrl(source, meta, author, name, currentLocale, theme)
  const { auto_upgrade: autoUpgradeInfo } = referenceSetting || {}

  const handleVersionSelect = (state: { version: string, unique_identifier: string, isDowngrade?: boolean }) => {
    versionPicker.setTargetVersion(state)
    handleUpdate(state.isDowngrade)
  }

  const handleTriggerLatestUpdate = () => {
    if (isFromMarketplace) {
      versionPicker.setTargetVersion({
        version: latest_version,
        unique_identifier: latest_unique_identifier,
      })
    }
    handleUpdate()
  }

  return (
    <div className={cn('shrink-0 border-b border-divider-subtle bg-components-panel-bg p-4 pb-3', isReadmeView && 'border-b-0 bg-transparent p-0')}>
      <div className="flex">
        {/* Plugin Icon */}
        <div className={cn('overflow-hidden rounded-xl border border-components-panel-border-subtle', isReadmeView && 'bg-components-panel-bg')}>
          <Icon src={iconSrc} />
        </div>

        {/* Plugin Info */}
        <div className="ml-3 w-0 grow">
          {/* Title Row */}
          <div className="flex h-5 items-center">
            <Title title={label[locale]} />
            {verified && !isReadmeView && <Verified className="ml-0.5 h-4 w-4" text={t('marketplace.verifiedTip', { ns: 'plugin' })} />}

            {/* Version Picker */}
            {!!version && (
              <PluginVersionPicker
                disabled={!isFromMarketplace || isReadmeView}
                isShow={versionPicker.isShow}
                onShowChange={versionPicker.setIsShow}
                pluginID={plugin_id}
                currentVersion={version}
                onSelect={handleVersionSelect}
                trigger={(
                  <Badge
                    className={cn(
                      'mx-1',
                      versionPicker.isShow && 'bg-state-base-hover',
                      (versionPicker.isShow || isFromMarketplace) && 'hover:bg-state-base-hover',
                    )}
                    uppercase={false}
                    text={(
                      <>
                        <div>{isFromGitHub ? (meta?.version ?? version ?? '') : version}</div>
                        {isFromMarketplace && !isReadmeView && <RiArrowLeftRightLine className="ml-1 h-3 w-3 text-text-tertiary" />}
                      </>
                    )}
                    hasRedCornerMark={hasNewVersion}
                  />
                )}
              />
            )}

            {/* Auto Update Badge */}
            {isAutoUpgradeEnabled && !isReadmeView && (
              <Tooltip popupContent={t('autoUpdate.nextUpdateTime', { ns: 'plugin', time: timeOfDayToDayjs(convertUTCDaySecondsToLocalSeconds(autoUpgradeInfo?.upgrade_time_of_day || 0, timezone!)).format('hh:mm A') })}>
                <div>
                  <Badge className="mr-1 cursor-pointer px-1">
                    <AutoUpdateLine className="size-3" />
                  </Badge>
                </div>
              </Tooltip>
            )}

            {/* Update Button */}
            {(hasNewVersion || isFromGitHub) && (
              <Button
                variant="secondary-accent"
                size="small"
                className="!h-5"
                onClick={handleTriggerLatestUpdate}
              >
                {t('detailPanel.operation.update', { ns: 'plugin' })}
              </Button>
            )}
          </div>

          {/* Org Info Row */}
          <div className="mb-1 flex h-4 items-center justify-between">
            <div className="mt-0.5 flex items-center">
              <OrgInfo
                packageNameClassName="w-auto"
                orgName={author}
                packageName={name?.includes('/') ? (name.split('/').pop() || '') : name}
              />
              {!!source && <PluginSourceBadge source={source} />}
            </div>
          </div>
        </div>

        {/* Action Buttons */}
        {!isReadmeView && (
          <div className="flex gap-1">
            <OperationDropdown
              source={source}
              onInfo={modalStates.showPluginInfo}
              onCheckVersion={handleUpdate}
              onRemove={modalStates.showDeleteConfirm}
              detailUrl={detailUrl}
            />
            <ActionButton onClick={onHide}>
              <RiCloseLine className="h-4 w-4" />
            </ActionButton>
          </div>
        )}
      </div>

      {/* Deprecation Notice */}
      {isFromMarketplace && (
        <DeprecationNotice
          status={status}
          deprecatedReason={deprecated_reason}
          alternativePluginId={alternative_plugin_id}
          alternativePluginURL={getMarketplaceUrl(`/plugins/${alternative_plugin_id}`, { language: currentLocale, theme })}
          className="mt-3"
        />
      )}

      {/* Description */}
      {!isReadmeView && <Description className="mb-2 mt-3 h-auto" text={description[locale]} descriptionLineRows={2} />}

      {/* Plugin Auth for Tools */}
      {category === PluginCategoryEnum.tool && !isReadmeView && (
        <PluginAuth
          pluginPayload={{
            provider: provider?.name || '',
            category: AuthCategory.tool,
            providerType: provider?.type || '',
            detail,
          }}
        />
      )}

      {/* Modals */}
      <HeaderModals
        detail={detail}
        modalStates={modalStates}
        targetVersion={versionPicker.targetVersion}
        isDowngrade={versionPicker.isDowngrade}
        isAutoUpgradeEnabled={isAutoUpgradeEnabled}
        onUpdatedFromMarketplace={handleUpdatedFromMarketplace}
        onDelete={handleDelete}
      />
    </div>
  )
}

export default DetailHeader
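A quick trace of getIconSrc's three branches, using hypothetical inputs (the values below are made up for illustration; API_PREFIX comes from '@/config'):

// Hypothetical inputs, tracing the branches of getIconSrc above:
getIconSrc('https://cdn.example.com/a.png', undefined, 'light', 't1')
// -> 'https://cdn.example.com/a.png' (absolute URLs pass through unchanged)
getIconSrc('icon.png', 'icon-dark.png', 'dark', 't1')
// -> `${API_PREFIX}/workspaces/current/plugin/icon?tenant_id=t1&filename=icon-dark.png`
//    (the dark variant wins only when the theme is dark and icon_dark is set)
getIconSrc(undefined, undefined, 'dark', 't1')
// -> '' (no icon available)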

@@ -1,11 +1,16 @@
import type { TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
// Import after mocks
import { SupportedCreationMethods } from '@/app/components/plugins/types'
import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types'
import { CommonCreateModal } from './common-modal'

// ============================================================================
// Type Definitions
// ============================================================================

type PluginDetail = {
  plugin_id: string
  provider: string
@@ -28,6 +33,10 @@ type TriggerLogEntity = {
  level: 'info' | 'warn' | 'error'
}

// ============================================================================
// Mock Factory Functions
// ============================================================================

function createMockPluginDetail(overrides: Partial<PluginDetail> = {}): PluginDetail {
  return {
    plugin_id: 'test-plugin-id',
@@ -65,12 +74,18 @@ function createMockLogData(logs: TriggerLogEntity[] = []): { logs: TriggerLogEntity[] } {
  return { logs }
}

// ============================================================================
// Mock Setup
// ============================================================================

// Mock plugin store
const mockPluginDetail = createMockPluginDetail()
const mockUsePluginStore = vi.fn(() => mockPluginDetail)
vi.mock('../../store', () => ({
  usePluginStore: () => mockUsePluginStore(),
}))

// Mock subscription list hook
const mockRefetch = vi.fn()
vi.mock('../use-subscription-list', () => ({
  useSubscriptionList: () => ({
@@ -78,11 +93,13 @@ vi.mock('../use-subscription-list', () => ({
  }),
}))

// Mock service hooks
const mockVerifyCredentials = vi.fn()
const mockCreateBuilder = vi.fn()
const mockBuildSubscription = vi.fn()
const mockUpdateBuilder = vi.fn()

// Configurable pending states
let mockIsVerifyingCredentials = false
let mockIsBuilding = false
const setMockPendingStates = (verifying: boolean, building: boolean) => {
@@ -112,15 +129,18 @@ vi.mock('@/service/use-triggers', () => ({
  }),
}))

// Mock error parser
const mockParsePluginErrorMessage = vi.fn().mockResolvedValue(null)
vi.mock('@/utils/error-parser', () => ({
  parsePluginErrorMessage: (...args: unknown[]) => mockParsePluginErrorMessage(...args),
}))

// Mock URL validation
vi.mock('@/utils/urlValidation', () => ({
  isPrivateOrLocalAddress: vi.fn().mockReturnValue(false),
}))

// Mock toast
const mockToastNotify = vi.fn()
vi.mock('@/app/components/base/toast', () => ({
  default: {
@@ -128,6 +148,7 @@ vi.mock('@/app/components/base/toast', () => ({
  },
}))

// Mock Modal component
vi.mock('@/app/components/base/modal/modal', () => ({
  default: ({
    children,
@@ -158,6 +179,7 @@ vi.mock('@/app/components/base/modal/modal', () => ({
  ),
}))

// Configurable form mock values
type MockFormValuesConfig = {
  values: Record<string, unknown>
  isCheckValidated: boolean
@@ -168,6 +190,7 @@ let mockFormValuesConfig: MockFormValuesConfig = {
}
let mockGetFormReturnsNull = false

// Separate validation configs for different forms
let mockSubscriptionFormValidated = true
let mockAutoParamsFormValidated = true
let mockManualPropsFormValidated = true
@@ -184,6 +207,7 @@ const setMockFormValidation = (subscription: boolean, autoParams: boolean, manualProps: boolean) => {
  mockManualPropsFormValidated = manualProps
}

// Mock BaseForm component with ref support
vi.mock('@/app/components/base/form/components/base', async () => {
  const React = await import('react')

@@ -195,6 +219,7 @@ vi.mock('@/app/components/base/form/components/base', async () => {
  type MockBaseFormProps = { formSchemas: Array<{ name: string }>, onChange?: () => void }

  function MockBaseFormInner({ formSchemas, onChange }: MockBaseFormProps, ref: React.ForwardedRef<MockFormRef>) {
    // Determine which form this is based on schema
    const isSubscriptionForm = formSchemas.some((s: { name: string }) => s.name === 'subscription_name')
    const isAutoParamsForm = formSchemas.some((s: { name: string }) =>
      ['repo_name', 'branch', 'repo', 'text_field', 'dynamic_field', 'bool_field', 'text_input_field', 'unknown_field', 'count'].includes(s.name),
@@ -240,10 +265,12 @@ vi.mock('@/app/components/base/form/components/base', async () => {
  }
})

// Mock EncryptedBottom component
vi.mock('@/app/components/base/encrypted-bottom', () => ({
  EncryptedBottom: () => <div data-testid="encrypted-bottom">Encrypted</div>,
}))

// Mock LogViewer component
vi.mock('../log-viewer', () => ({
  default: ({ logs }: { logs: TriggerLogEntity[] }) => (
    <div data-testid="log-viewer">
@@ -254,6 +281,7 @@ vi.mock('../log-viewer', () => ({
  ),
}))

// Mock debounce
vi.mock('es-toolkit/compat', () => ({
  debounce: (fn: (...args: unknown[]) => unknown) => {
    const debouncedFn = (...args: unknown[]) => fn(...args)
@@ -262,6 +290,10 @@ vi.mock('es-toolkit/compat', () => ({
  },
}))

// ============================================================================
// Test Suites
// ============================================================================

describe('CommonCreateModal', () => {
  const defaultProps = {
    onClose: vi.fn(),
@@ -409,8 +441,7 @@ describe('CommonCreateModal', () => {
  })

  it('should call onConfirm handler when confirm button is clicked', () => {
    // Provide builder so the guard passes and credentials check is reached
    render(<CommonCreateModal {...defaultProps} builder={createMockSubscriptionBuilder()} />)
    render(<CommonCreateModal {...defaultProps} />)

    fireEvent.click(screen.getByTestId('modal-confirm'))

@@ -790,9 +821,6 @@ describe('CommonCreateModal', () => {
      expect(mockCreateBuilder).toHaveBeenCalled()
    })

    // Flush pending state updates from createBuilder promise resolution
    await act(async () => {})

    const input = screen.getByTestId('form-field-webhook_url')
    fireEvent.change(input, { target: { value: 'https://example.com/webhook' } })

@@ -1212,22 +1240,13 @@ describe('CommonCreateModal', () => {

    render(<CommonCreateModal {...defaultProps} createType={SupportedCreationMethods.MANUAL} />)

    // Wait for createBuilder to complete and state to update
    await waitFor(() => {
      expect(mockCreateBuilder).toHaveBeenCalled()
    })

    // Allow React to process the state update from createBuilder
    await act(async () => {})

    const input = screen.getByTestId('form-field-webhook_url')
    fireEvent.change(input, { target: { value: 'https://example.com/webhook' } })

    // Wait for updateBuilder to be called, then check the toast
    await waitFor(() => {
      expect(mockUpdateBuilder).toHaveBeenCalled()
    })

    await waitFor(() => {
      expect(mockToastNotify).toHaveBeenCalledWith({
        type: 'error',
@@ -1428,8 +1447,7 @@ describe('CommonCreateModal', () => {
    })
    mockUsePluginStore.mockReturnValue(detailWithCredentials)

    // Provide builder so the guard passes and credentials check is reached
    render(<CommonCreateModal {...defaultProps} builder={createMockSubscriptionBuilder()} />)
    render(<CommonCreateModal {...defaultProps} />)

    fireEvent.click(screen.getByTestId('modal-confirm'))

Some files were not shown because too many files have changed in this diff.