mirror of
https://github.com/langgenius/dify.git
synced 2026-05-30 05:37:48 +08:00
feat: dev snippet backend (#36804)
This commit is contained in:
@ -53,6 +53,7 @@ def _normalize_snippet_list_query_args(query_args: MultiDict[str, str]) -> dict[
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
# Register Pydantic models with Swagger
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
@ -104,7 +105,7 @@ class CustomizedSnippetsApi(Resource):
|
||||
@console_ns.doc("create_customized_snippet")
|
||||
@console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__))
|
||||
@console_ns.response(201, "Snippet created successfully", snippet_model)
|
||||
@console_ns.response(400, "Invalid request or name already exists")
|
||||
@console_ns.response(400, "Invalid request")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -161,7 +162,7 @@ class CustomizedSnippetDetailApi(Resource):
|
||||
@console_ns.doc("update_customized_snippet")
|
||||
@console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__))
|
||||
@console_ns.response(200, "Snippet updated successfully", snippet_model)
|
||||
@console_ns.response(400, "Invalid request or name already exists")
|
||||
@console_ns.response(400, "Invalid request")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@ -35,7 +35,6 @@ class CustomizedSnippet(Base):
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
|
||||
sa.Index("customized_snippet_tenant_idx", "tenant_id"),
|
||||
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
|
||||
@ -84,6 +84,8 @@ def _check_version_compatibility(imported_version: str) -> ImportStatus:
|
||||
class SnippetPendingData(BaseModel):
|
||||
import_mode: str
|
||||
yaml_content: str
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
snippet_id: str | None
|
||||
|
||||
|
||||
@ -255,6 +257,8 @@ class SnippetDslService:
|
||||
pending_data = SnippetPendingData(
|
||||
import_mode=import_mode,
|
||||
yaml_content=content,
|
||||
name=name,
|
||||
description=description,
|
||||
snippet_id=snippet_id,
|
||||
)
|
||||
redis_client.setex(
|
||||
@ -333,12 +337,37 @@ class SnippetDslService:
|
||||
pending_data_str = pending_data.decode("utf-8") if isinstance(pending_data, bytes) else pending_data
|
||||
pending = SnippetPendingData.model_validate_json(pending_data_str)
|
||||
|
||||
# Re-import with the pending data
|
||||
return self.import_snippet(
|
||||
data = yaml.safe_load(pending.yaml_content)
|
||||
if not isinstance(data, dict):
|
||||
return SnippetImportInfo(
|
||||
id=import_id,
|
||||
status=ImportStatus.FAILED,
|
||||
error="Invalid YAML format: expected a dictionary",
|
||||
)
|
||||
|
||||
snippet = None
|
||||
if pending.snippet_id:
|
||||
stmt = select(CustomizedSnippet).where(
|
||||
CustomizedSnippet.id == pending.snippet_id,
|
||||
CustomizedSnippet.tenant_id == account.current_tenant_id,
|
||||
)
|
||||
snippet = self._session.scalar(stmt)
|
||||
|
||||
snippet = self._create_or_update_snippet(
|
||||
snippet=snippet,
|
||||
data=data,
|
||||
account=account,
|
||||
import_mode=pending.import_mode,
|
||||
yaml_content=pending.yaml_content,
|
||||
snippet_id=pending.snippet_id,
|
||||
name=pending.name,
|
||||
description=pending.description,
|
||||
)
|
||||
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
return SnippetImportInfo(
|
||||
id=import_id,
|
||||
status=ImportStatus.COMPLETED,
|
||||
snippet_id=snippet.id,
|
||||
imported_dsl_version=data.get("version", "0.1.0"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@ -178,27 +178,14 @@ class SnippetService:
|
||||
Create a new snippet.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param name: Snippet name (must be unique per tenant)
|
||||
:param name: Snippet name
|
||||
:param description: Snippet description
|
||||
:param snippet_type: Type of snippet (node or group)
|
||||
:param icon_info: Icon information
|
||||
:param input_fields: Input field definitions
|
||||
:param account: Creator account
|
||||
:return: Created CustomizedSnippet
|
||||
:raises ValueError: If name already exists
|
||||
"""
|
||||
# Check if name already exists for this tenant
|
||||
existing = (
|
||||
db.session.query(CustomizedSnippet)
|
||||
.where(
|
||||
CustomizedSnippet.tenant_id == tenant_id,
|
||||
CustomizedSnippet.name == name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"Snippet with name '{name}' already exists")
|
||||
|
||||
snippet = CustomizedSnippet(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
@ -232,18 +219,6 @@ class SnippetService:
|
||||
:return: Updated CustomizedSnippet
|
||||
"""
|
||||
if "name" in data:
|
||||
# Check if new name already exists for this tenant
|
||||
existing = (
|
||||
session.query(CustomizedSnippet)
|
||||
.where(
|
||||
CustomizedSnippet.tenant_id == snippet.tenant_id,
|
||||
CustomizedSnippet.name == data["name"],
|
||||
CustomizedSnippet.id != snippet.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"Snippet with name '{data['name']}' already exists")
|
||||
snippet.name = data["name"]
|
||||
|
||||
if "description" in data:
|
||||
|
||||
@ -271,6 +271,34 @@ class TestTagBindingCollectionApi:
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_create_snippet_binding_success(self, app: Flask, admin_user, payload_patch):
|
||||
api = TagBindingCollectionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"tag_ids": ["tag-1"],
|
||||
"target_id": "snippet-1",
|
||||
"type": "snippet",
|
||||
}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
save_mock.assert_called_once()
|
||||
binding_payload = save_mock.call_args.args[0]
|
||||
assert binding_payload.type == TagType.SNIPPET
|
||||
assert binding_payload.target_id == "snippet-1"
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_create_forbidden(self, app: Flask, readonly_user, payload_patch):
|
||||
api = TagBindingCollectionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
46
api/tests/unit_tests/services/test_snippet_dsl_service.py
Normal file
46
api/tests/unit_tests/services/test_snippet_dsl_service.py
Normal file
@ -0,0 +1,46 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
from services.snippet_dsl_service import ImportStatus, SnippetDslService, SnippetPendingData
|
||||
|
||||
|
||||
def test_confirm_import_creates_snippet_from_pending_data(monkeypatch):
|
||||
service = SnippetDslService(session=SimpleNamespace(scalar=Mock(return_value=None)))
|
||||
account = SimpleNamespace(id="account-1", current_tenant_id="tenant-1")
|
||||
snippet = SimpleNamespace(id="snippet-new")
|
||||
yaml_content = """
|
||||
version: 9.0.0
|
||||
kind: snippet
|
||||
snippet:
|
||||
name: From DSL
|
||||
type: node
|
||||
workflow:
|
||||
graph:
|
||||
nodes: []
|
||||
edges: []
|
||||
"""
|
||||
pending = SnippetPendingData(
|
||||
import_mode="yaml-content",
|
||||
yaml_content=yaml_content,
|
||||
name="Override name",
|
||||
description="Override description",
|
||||
snippet_id=None,
|
||||
)
|
||||
create_or_update = Mock(return_value=snippet)
|
||||
monkeypatch.setattr(service, "_create_or_update_snippet", create_or_update)
|
||||
monkeypatch.setattr("services.snippet_dsl_service.redis_client.get", Mock(return_value=pending.model_dump_json()))
|
||||
redis_delete = Mock()
|
||||
monkeypatch.setattr("services.snippet_dsl_service.redis_client.delete", redis_delete)
|
||||
|
||||
result = service.confirm_import(import_id="import-1", account=account)
|
||||
|
||||
assert result.status == ImportStatus.COMPLETED
|
||||
assert result.snippet_id == "snippet-new"
|
||||
assert result.imported_dsl_version == "9.0.0"
|
||||
create_or_update.assert_called_once()
|
||||
_, kwargs = create_or_update.call_args
|
||||
assert kwargs["snippet"] is None
|
||||
assert kwargs["account"] is account
|
||||
assert kwargs["name"] == "Override name"
|
||||
assert kwargs["description"] == "Override description"
|
||||
redis_delete.assert_called_once_with("snippet_import_info:import-1")
|
||||
@ -6,11 +6,21 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from models.snippet import SnippetType
|
||||
from models.workflow import Workflow, WorkflowKind, WorkflowType
|
||||
from services.errors.app import WorkflowNotFoundError
|
||||
from services.snippet_service import SnippetService
|
||||
|
||||
|
||||
class _SessionWithoutNameLookup:
|
||||
def __init__(self) -> None:
|
||||
self.add = Mock()
|
||||
self.commit = Mock()
|
||||
|
||||
def query(self, *args, **kwargs):
|
||||
raise AssertionError("snippet name uniqueness lookup should not be used")
|
||||
|
||||
|
||||
def _create_workflow(*, workflow_id: str, version: str, graph: dict, features: dict) -> Workflow:
|
||||
return Workflow(
|
||||
id=workflow_id,
|
||||
@ -28,6 +38,49 @@ def _create_workflow(*, workflow_id: str, version: str, graph: dict, features: d
|
||||
)
|
||||
|
||||
|
||||
def test_create_snippet_allows_duplicate_names(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
session = _SessionWithoutNameLookup()
|
||||
account = SimpleNamespace(id="account-1")
|
||||
|
||||
monkeypatch.setattr("services.snippet_service.db.session", session)
|
||||
|
||||
snippet = SnippetService.create_snippet(
|
||||
tenant_id="tenant-1",
|
||||
name="shared name",
|
||||
description=None,
|
||||
snippet_type=SnippetType.NODE,
|
||||
icon_info=None,
|
||||
input_fields=None,
|
||||
account=account,
|
||||
)
|
||||
|
||||
assert snippet.name == "shared name"
|
||||
session.add.assert_called_once_with(snippet)
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_update_snippet_allows_duplicate_names() -> None:
|
||||
session = _SessionWithoutNameLookup()
|
||||
snippet = SimpleNamespace(
|
||||
id="snippet-1",
|
||||
tenant_id="tenant-1",
|
||||
name="old name",
|
||||
description="",
|
||||
icon_info=None,
|
||||
)
|
||||
|
||||
result = SnippetService.update_snippet(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
account_id="account-1",
|
||||
data={"name": "shared name"},
|
||||
)
|
||||
|
||||
assert result is snippet
|
||||
assert snippet.name == "shared name"
|
||||
session.add.assert_called_once_with(snippet)
|
||||
|
||||
|
||||
def test_restore_published_snippet_workflow_to_draft_copies_source_snapshot(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
|
||||
Reference in New Issue
Block a user