From 0ded89d67358aa880fdbfa4f40b2d64f477bbd46 Mon Sep 17 00:00:00 2001 From: FFXN Date: Thu, 4 Jun 2026 16:38:17 +0800 Subject: [PATCH] feat: add comprehensive tests for snippet import and workflow generation --- .../test_snippet_workflow_draft_variable.py | 80 ++++++ .../test_workflow_app_runner_single_node.py | 65 +++++ .../apps/workflow/test_app_generator_extra.py | 82 +++++++ .../services/test_snippet_dsl_service.py | 230 +++++++++++++++++- .../services/test_snippet_generate_service.py | 140 +++++++++++ .../services/test_snippet_service.py | 117 ++++++++- .../unit_tests/services/test_tag_service.py | 30 +++ 7 files changed, 741 insertions(+), 3 deletions(-) create mode 100644 api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py create mode 100644 api/tests/unit_tests/services/test_snippet_generate_service.py diff --git a/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py new file mode 100644 index 00000000000..6fb15ba7378 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py @@ -0,0 +1,80 @@ +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest +from flask import Flask + +from controllers.console.snippets import snippet_workflow_draft_variable as module +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from services.workflow_draft_variable_service import WorkflowDraftVariableList + + +def _unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask("test_snippet_workflow_draft_variable") + app.config["TESTING"] = True + return app + + +def test_ensure_snippet_draft_variable_row_allowed_rejects_system_variable(): + variable = SimpleNamespace(node_id=SYSTEM_VARIABLE_NODE_ID) + + with pytest.raises(module.NotFoundError, match="variable not found"): + module._ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id="var-1") + + +def test_ensure_snippet_draft_variable_row_allowed_rejects_conversation_variable(): + variable = SimpleNamespace(node_id=CONVERSATION_VARIABLE_NODE_ID) + + with pytest.raises(module.NotFoundError, match="variable not found"): + module._ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id="var-1") + + +def test_ensure_snippet_draft_variable_row_allowed_accepts_canvas_node_variable(): + variable = SimpleNamespace(node_id="llm-1") + + module._ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id="var-1") + + +def test_conversation_variables_returns_empty_list(app): + api = module.SnippetConversationVariableCollectionApi() + handler = _unwrap(api.get) + + with app.test_request_context("/"): + result = handler(api, snippet=SimpleNamespace(id="snippet-1")) + + assert result == WorkflowDraftVariableList(variables=[]) + + +def test_system_variables_returns_empty_list(app): + api = module.SnippetSystemVariableCollectionApi() + handler = _unwrap(api.get) + + with app.test_request_context("/"): + result = handler(api, snippet=SimpleNamespace(id="snippet-1")) + + assert result == WorkflowDraftVariableList(variables=[]) + + +def test_delete_variable_collection_deletes_current_user_variables(app, monkeypatch): + draft_var_service = SimpleNamespace(delete_user_workflow_variables=Mock()) + monkeypatch.setattr(module, "WorkflowDraftVariableService", Mock(return_value=draft_var_service)) + monkeypatch.setattr(module, "current_user", SimpleNamespace(id="user-1")) + db_session = Mock() + db_session.return_value = SimpleNamespace() + monkeypatch.setattr(module.db, "session", db_session) + api = module.SnippetWorkflowVariableCollectionApi() + handler = _unwrap(api.delete) + + with app.test_request_context("/", method="DELETE"): + response = handler(api, snippet=SimpleNamespace(id="snippet-1")) + + assert response.status_code == 204 + draft_var_service.delete_user_workflow_variables.assert_called_once_with("snippet-1", user_id="user-1") + db_session.commit.assert_called_once() diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 248fed53883..8f820add697 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -163,3 +163,68 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None: ) assert seen_configs == [workflow.graph_dict["nodes"][0]] + + +def test_run_adds_inputs_with_snippet_compatible_start_aliases() -> None: + app_config = MagicMock() + app_config.app_id = "app" + app_config.tenant_id = "tenant" + app_config.workflow_id = "workflow" + + app_generate_entity = MagicMock(spec=WorkflowAppGenerateEntity) + app_generate_entity.app_config = app_config + app_generate_entity.inputs = {"question": "hello"} + app_generate_entity.files = [] + app_generate_entity.user_id = "user" + app_generate_entity.invoke_from = InvokeFrom.SERVICE_API + app_generate_entity.workflow_execution_id = "execution-id" + app_generate_entity.task_id = "task-id" + app_generate_entity.call_depth = 0 + app_generate_entity.trace_manager = None + app_generate_entity.single_iteration_run = None + app_generate_entity.single_loop_run = None + + workflow = MagicMock(spec=Workflow) + workflow.tenant_id = "tenant" + workflow.app_id = "app" + workflow.id = "workflow" + workflow.type = "workflow" + workflow.version = "v1" + workflow.graph_dict = {"nodes": [], "edges": []} + workflow.environment_variables = [] + workflow.kind_or_standard = "snippet" + + runner = WorkflowAppRunner( + application_generate_entity=app_generate_entity, + queue_manager=MagicMock(spec=AppQueueManager), + variable_loader=MagicMock(), + workflow=workflow, + system_user_id="system-user", + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + mock_workflow_entry = MagicMock() + mock_workflow_entry.graph_engine = MagicMock() + mock_workflow_entry.graph_engine.layer = MagicMock() + mock_workflow_entry.run.return_value = iter([]) + + with ( + patch("core.app.apps.workflow.app_runner.RedisChannel"), + patch("core.app.apps.workflow.app_runner.redis_client"), + patch("core.app.apps.workflow.app_runner.WorkflowEntry", return_value=mock_workflow_entry), + patch("core.app.apps.workflow.app_runner.build_system_variables", return_value={}), + patch("core.app.apps.workflow.app_runner.build_bootstrap_variables", return_value=[]), + patch("core.app.apps.workflow.app_runner.add_variables_to_pool"), + patch("core.app.apps.workflow.app_runner.get_default_root_node_id", return_value="root-node"), + patch("core.app.apps.workflow.app_runner.get_compatible_start_aliases", return_value=("legacy-start",)) as aliases, + patch("core.app.apps.workflow.app_runner.add_node_inputs_to_pool") as add_inputs, + patch.object(runner, "_init_graph", return_value=MagicMock()), + ): + runner.run() + + aliases.assert_called_once_with(workflow_kind="snippet", root_node_id="root-node") + add_inputs.assert_called_once() + assert add_inputs.call_args.kwargs["node_id"] == "root-node" + assert add_inputs.call_args.kwargs["inputs"] == {"question": "hello"} + assert add_inputs.call_args.kwargs["aliases"] == ("legacy-start",) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py b/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py index 941a47b572e..a5703c83a97 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py @@ -1,6 +1,8 @@ from __future__ import annotations +import contextlib from types import SimpleNamespace +from unittest.mock import Mock import pytest @@ -297,3 +299,83 @@ class TestWorkflowAppGeneratorResume: assert result.ok is True assert captured_entity is not None assert captured_entity.trace_manager is existing_trace_manager + + +class TestWorkflowAppGeneratorWorker: + def test_generate_worker_uses_end_user_session_for_external_invocation(self, monkeypatch: pytest.MonkeyPatch): + generator = WorkflowAppGenerator() + + workflow = SimpleNamespace( + id="workflow-id", + tenant_id="tenant", + app_id="app", + graph_dict={}, + type="workflow", + version="1", + ) + end_user = SimpleNamespace(id="end-user-id", session_id="session-id") + session = SimpleNamespace(scalar=Mock(side_effect=[workflow, end_user])) + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + runner_kwargs = {} + + class _Runner: + def __init__(self, **kwargs): + runner_kwargs.update(kwargs) + + def run(self): + return None + + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.preserve_flask_contexts", + lambda flask_app, context_vars: contextlib.nullcontext(), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.session_factory.create_session", + lambda: _SessionContext(), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowAppGenerator._ensure_snippet_start_node_in_worker", + lambda self, *, session, workflow: workflow, + ) + monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppRunner", _Runner) + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + files=[], + user_id="end-user-id", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_execution_id="run-id", + call_depth=0, + ) + + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=SimpleNamespace(), + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + ) + + assert runner_kwargs["system_user_id"] == "session-id" diff --git a/api/tests/unit_tests/services/test_snippet_dsl_service.py b/api/tests/unit_tests/services/test_snippet_dsl_service.py index 97d996df55d..77c92fcc7d8 100644 --- a/api/tests/unit_tests/services/test_snippet_dsl_service.py +++ b/api/tests/unit_tests/services/test_snippet_dsl_service.py @@ -1,7 +1,123 @@ from types import SimpleNamespace from unittest.mock import Mock -from services.snippet_dsl_service import ImportStatus, SnippetDslService, SnippetPendingData +import pytest + +from graphon.nodes import BuiltinNodeTypes +from services.snippet_dsl_service import ( + ImportMode, + ImportStatus, + SnippetDslService, + SnippetPendingData, + _check_version_compatibility, +) + + +@pytest.mark.parametrize( + ("version", "expected"), + [ + ("not-a-version", ImportStatus.FAILED), + ("999.0.0", ImportStatus.PENDING), + ("0.1.0", ImportStatus.COMPLETED), + ], + ) +def test_check_version_compatibility_special_cases(version, expected): + assert _check_version_compatibility(version) == expected + + +def test_import_snippet_rejects_invalid_mode(): + service = SnippetDslService(session=SimpleNamespace()) + + with pytest.raises(ValueError, match="Invalid import_mode"): + service.import_snippet(account=SimpleNamespace(current_tenant_id="tenant-1"), import_mode="bad-mode") + + +def test_import_snippet_requires_yaml_content(): + service = SnippetDslService(session=SimpleNamespace()) + + result = service.import_snippet( + account=SimpleNamespace(current_tenant_id="tenant-1"), + import_mode=ImportMode.YAML_CONTENT.value, + ) + + assert result.status == ImportStatus.FAILED + assert result.error == "yaml_content is required when import_mode is yaml-content" + + +def test_import_snippet_rejects_forbidden_nodes(): + service = SnippetDslService(session=SimpleNamespace()) + yaml_content = """ +version: 0.3.0 +kind: snippet +snippet: + name: Bad Snippet +workflow: + graph: + nodes: + - id: start-1 + data: + type: start + edges: [] +""" + + result = service.import_snippet( + account=SimpleNamespace(current_tenant_id="tenant-1"), + import_mode=ImportMode.YAML_CONTENT.value, + yaml_content=yaml_content, + ) + + assert result.status == ImportStatus.FAILED + assert result.error == "Snippet cannot contain the following node types: start" + + +def test_import_snippet_stores_pending_data_for_newer_dsl(monkeypatch): + service = SnippetDslService(session=SimpleNamespace(scalar=Mock(return_value=None))) + setex = Mock() + monkeypatch.setattr("services.snippet_dsl_service.redis_client.setex", setex) + yaml_content = """ +version: 999.0.0 +kind: snippet +snippet: + name: Future Snippet +workflow: + graph: + nodes: [] + edges: [] +""" + + result = service.import_snippet( + account=SimpleNamespace(current_tenant_id="tenant-1"), + import_mode=ImportMode.YAML_CONTENT.value, + yaml_content=yaml_content, + name="Override", + description="Override description", + ) + + assert result.status == ImportStatus.PENDING + setex.assert_called_once() + pending = SnippetPendingData.model_validate_json(setex.call_args.args[2]) + assert pending.name == "Override" + assert pending.description == "Override description" + + +def test_confirm_import_returns_failed_when_pending_data_missing(monkeypatch): + service = SnippetDslService(session=SimpleNamespace()) + monkeypatch.setattr("services.snippet_dsl_service.redis_client.get", Mock(return_value=None)) + + result = service.confirm_import(import_id="missing", account=SimpleNamespace(current_tenant_id="tenant-1")) + + assert result.status == ImportStatus.FAILED + assert result.error == "Import information expired or does not exist" + + +def test_confirm_import_returns_failed_for_invalid_pending_payload(monkeypatch): + service = SnippetDslService(session=SimpleNamespace()) + monkeypatch.setattr("services.snippet_dsl_service.redis_client.get", Mock(return_value=object())) + + result = service.confirm_import(import_id="bad", account=SimpleNamespace(current_tenant_id="tenant-1")) + + assert result.status == ImportStatus.FAILED + assert result.error == "Invalid import information" def test_confirm_import_creates_snippet_from_pending_data(monkeypatch): @@ -44,3 +160,115 @@ workflow: assert kwargs["name"] == "Override name" assert kwargs["description"] == "Override description" redis_delete.assert_called_once_with("snippet_import_info:import-1") + + +def test_create_or_update_snippet_updates_existing_snippet_and_syncs_workflow(monkeypatch): + snippet = SimpleNamespace( + id="snippet-1", + name="Old", + description="Old", + type="node", + icon_info=None, + input_fields=None, + updated_by=None, + updated_at=None, + ) + session = SimpleNamespace(add=Mock(), flush=Mock(), commit=Mock()) + service = SnippetDslService(session=session) + draft_workflow = SimpleNamespace(unique_hash="hash-1") + snippet_service = SimpleNamespace( + get_draft_workflow=Mock(return_value=draft_workflow), + sync_draft_workflow=Mock(), + ) + monkeypatch.setattr("services.snippet_dsl_service.SnippetService", lambda: snippet_service) + + result = service._create_or_update_snippet( + snippet=snippet, + data={ + "snippet": { + "name": "New", + "description": "New description", + "type": "unknown-type", + "icon_info": {"icon": "x"}, + "input_fields": [{"variable": "query"}], + }, + "workflow": {"graph": {"nodes": [], "edges": []}}, + }, + account=SimpleNamespace(id="account-1", current_tenant_id="tenant-1"), + ) + + assert result is snippet + assert snippet.name == "New" + assert snippet.type == "node" + assert snippet.icon_info == {"icon": "x"} + snippet_service.sync_draft_workflow.assert_called_once() + session.commit.assert_called_once() + + +def test_export_snippet_dsl_raises_without_draft_workflow(monkeypatch): + service = SnippetDslService(session=SimpleNamespace()) + monkeypatch.setattr( + "services.snippet_dsl_service.SnippetService", + lambda: SimpleNamespace(get_draft_workflow=Mock(return_value=None)), + ) + + with pytest.raises(ValueError, match="Missing draft workflow"): + service.export_snippet_dsl(SimpleNamespace()) + + +def test_append_workflow_export_data_filters_credentials_and_extracts_dependencies(monkeypatch): + service = SnippetDslService(session=SimpleNamespace()) + workflow_dict = { + "graph": { + "nodes": [ + {"data": {}}, + { + "data": { + "type": BuiltinNodeTypes.TOOL, + "credential_id": "secret", + "tool_configurations": {"provider_type": "builtin", "provider": "langgenius/google"}, + } + }, + { + "data": { + "type": BuiltinNodeTypes.AGENT, + "agent_parameters": { + "tools": { + "value": [ + { + "provider_type": "builtin", + "provider": "langgenius/openai", + "credential_id": "agent-secret", + } + ] + } + }, + } + }, + ] + }, + "environment_variables": [{"name": "SECRET"}], + "conversation_variables": [{"name": "memory"}], + } + workflow = SimpleNamespace( + to_dict=Mock(return_value=workflow_dict), + graph_dict=workflow_dict["graph"], + ) + monkeypatch.setattr( + "services.snippet_dsl_service.DependenciesAnalysisService.generate_dependencies", + Mock(return_value=[]), + ) + export_data = {} + + service._append_workflow_export_data( + export_data=export_data, + snippet=SimpleNamespace(tenant_id="tenant-1"), + workflow=workflow, + include_secret=False, + ) + + nodes = export_data["workflow"]["graph"]["nodes"] + assert export_data["workflow"]["environment_variables"] == [] + assert export_data["workflow"]["conversation_variables"] == [] + assert "credential_id" not in nodes[1]["data"] + assert "credential_id" not in nodes[2]["data"]["agent_parameters"]["tools"]["value"][0] diff --git a/api/tests/unit_tests/services/test_snippet_generate_service.py b/api/tests/unit_tests/services/test_snippet_generate_service.py new file mode 100644 index 00000000000..500cf375101 --- /dev/null +++ b/api/tests/unit_tests/services/test_snippet_generate_service.py @@ -0,0 +1,140 @@ +import json +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from core.workflow.snippet_start import SNIPPET_VIRTUAL_START_NODE_ID +from models.workflow import Workflow, WorkflowKind, WorkflowType +from services.snippet_generate_service import SnippetGenerateService + + +def _workflow(graph: dict) -> Workflow: + return Workflow( + id="workflow-1", + tenant_id="tenant-1", + app_id="snippet-1", + type=WorkflowType.WORKFLOW, + kind=WorkflowKind.SNIPPET, + version=Workflow.VERSION_DRAFT, + graph=json.dumps(graph), + features="{}", + created_by="account-1", + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + + +def test_filter_virtual_start_events_keeps_blocking_response_unchanged(): + response = {"data": {"outputs": {"text": "ok"}}} + + assert SnippetGenerateService._filter_virtual_start_events(response) is response + + +def test_filter_virtual_start_events_removes_virtual_start_node_events(): + stream = iter( + [ + {"event": "node_started", "data": {"node_id": SNIPPET_VIRTUAL_START_NODE_ID}}, + {"event": "node_finished", "data": {"node_id": "llm-1"}}, + "raw-event", + ] + ) + + filtered = SnippetGenerateService._filter_virtual_start_events(stream) + + assert list(filtered) == [{"event": "node_finished", "data": {"node_id": "llm-1"}}, "raw-event"] + + +@pytest.mark.parametrize( + ("message", "expected"), + [ + ("raw-event", False), + ({"event": "message", "data": {"node_id": SNIPPET_VIRTUAL_START_NODE_ID}}, False), + ({"event": "node_started", "data": "not-a-dict"}, False), + ({"event": "node_started", "data": {"node_id": SNIPPET_VIRTUAL_START_NODE_ID}}, True), + ], +) +def test_is_virtual_start_event(message, expected): + assert SnippetGenerateService._is_virtual_start_event(message) is expected + + +def test_ensure_start_node_returns_workflow_when_start_already_exists(): + workflow = _workflow({"nodes": [{"id": "start", "data": {"type": "start"}}], "edges": []}) + snippet = SimpleNamespace(input_fields_list=[]) + + result = SnippetGenerateService._ensure_start_node(workflow, snippet) + + assert result is workflow + + +def test_ensure_start_node_injects_virtual_start_for_root_candidates(monkeypatch): + graph = { + "nodes": [ + {"id": "llm-1", "data": {"type": "llm"}}, + {"id": "answer-1", "data": {"type": "answer"}}, + ], + "edges": [{"source": "llm-1", "target": "answer-1"}], + } + workflow = _workflow(graph) + snippet = SimpleNamespace( + input_fields_list=[ + { + "variable": "query", + "label": "Query", + "type": "text-input", + "required": True, + "max_length": 128, + } + ] + ) + make_transient = Mock() + monkeypatch.setattr("services.snippet_generate_service.make_transient", make_transient) + + result = SnippetGenerateService._ensure_start_node(workflow, snippet) + + assert result is workflow + updated_graph = workflow.graph_dict + assert updated_graph["nodes"][0]["id"] == SNIPPET_VIRTUAL_START_NODE_ID + assert updated_graph["nodes"][0]["data"]["variables"][0]["max_length"] == 128 + assert updated_graph["edges"][-1]["source"] == SNIPPET_VIRTUAL_START_NODE_ID + assert updated_graph["edges"][-1]["target"] == "llm-1" + make_transient.assert_called_once_with(workflow) + + +def test_parse_files_returns_empty_when_upload_config_disabled(monkeypatch): + workflow = _workflow({"nodes": [], "edges": []}) + monkeypatch.setattr("services.snippet_generate_service.FileUploadConfigManager.convert", Mock(return_value=None)) + + assert SnippetGenerateService.parse_files(workflow, files=[{"id": "file-1"}]) == [] + + +def test_parse_files_delegates_to_file_factory(monkeypatch): + workflow = _workflow({"nodes": [], "edges": []}) + upload_config = SimpleNamespace(enabled=True) + files = [SimpleNamespace(id="file-1")] + monkeypatch.setattr( + "services.snippet_generate_service.FileUploadConfigManager.convert", Mock(return_value=upload_config) + ) + build_from_mappings = Mock(return_value=files) + monkeypatch.setattr("services.snippet_generate_service.file_factory.build_from_mappings", build_from_mappings) + + result = SnippetGenerateService.parse_files(workflow, files=[{"id": "file-1"}]) + + assert result == files + build_from_mappings.assert_called_once() + + +def test_generate_raises_when_draft_workflow_missing(monkeypatch): + monkeypatch.setattr( + "services.snippet_generate_service.SnippetService", + lambda: SimpleNamespace(get_draft_workflow=Mock(return_value=None)), + ) + + with pytest.raises(ValueError, match="Workflow not initialized"): + SnippetGenerateService.generate( + snippet=SimpleNamespace(id="snippet-1", tenant_id="tenant-1"), + user=SimpleNamespace(id="user-1"), + args={"inputs": {}}, + invoke_from="debugger", + ) diff --git a/api/tests/unit_tests/services/test_snippet_service.py b/api/tests/unit_tests/services/test_snippet_service.py index 8730062cf59..ac48ac5a4bf 100644 --- a/api/tests/unit_tests/services/test_snippet_service.py +++ b/api/tests/unit_tests/services/test_snippet_service.py @@ -2,13 +2,13 @@ from __future__ import annotations import json from types import SimpleNamespace -from unittest.mock import Mock +from unittest.mock import MagicMock, Mock import pytest from models.snippet import SnippetType from models.workflow import Workflow, WorkflowKind, WorkflowType -from services.errors.app import WorkflowNotFoundError +from services.errors.app import WorkflowHashNotEqualError, WorkflowNotFoundError from services.snippet_service import SnippetService @@ -59,6 +59,34 @@ def test_create_snippet_allows_duplicate_names(monkeypatch: pytest.MonkeyPatch) session.commit.assert_called_once() +def test_validate_snippet_graph_forbidden_nodes_ignores_malformed_nodes() -> None: + SnippetService.validate_snippet_graph_forbidden_nodes( + { + "nodes": [ + "not-a-node", + {"id": "empty-data", "data": {}}, + {"id": "bad-type", "data": {"type": 123}}, + {"id": "llm-1", "data": {"type": "llm"}}, + ] + } + ) + + +def test_validate_snippet_graph_forbidden_nodes_raises_with_node_details() -> None: + with pytest.raises(ValueError, match="start-1:start"): + SnippetService.validate_snippet_graph_forbidden_nodes( + {"nodes": [{"id": "start-1", "data": {"type": "start"}}]} + ) + + +def test_get_snippets_returns_empty_when_tag_filter_has_no_targets(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("services.snippet_service.TagService.get_target_ids_by_tag_ids", Mock(return_value=[])) + + result = SnippetService.get_snippets(tenant_id="tenant-1", tag_ids=["tag-1"]) + + assert result == ([], 0, False) + + def test_update_snippet_allows_duplicate_names() -> None: session = _SessionWithoutNameLookup() snippet = SimpleNamespace( @@ -81,6 +109,91 @@ def test_update_snippet_allows_duplicate_names() -> None: session.add.assert_called_once_with(snippet) +def test_sync_draft_workflow_creates_draft_and_updates_input_fields(monkeypatch: pytest.MonkeyPatch) -> None: + service = SnippetService.__new__(SnippetService) + monkeypatch.setattr(service, "get_draft_workflow", Mock(return_value=None)) + session = SimpleNamespace(add=Mock(), flush=Mock(), commit=Mock()) + monkeypatch.setattr("services.snippet_service.db.session", session) + snippet = SimpleNamespace( + id="snippet-1", + tenant_id="tenant-1", + input_fields=None, + updated_by=None, + updated_at=None, + ) + account = SimpleNamespace(id="account-1") + + workflow = service.sync_draft_workflow( + snippet=snippet, + graph={"nodes": [{"id": "llm-1", "data": {"type": "llm"}}], "edges": []}, + unique_hash=None, + account=account, + input_fields=[{"variable": "query"}], + ) + + assert workflow.app_id == snippet.id + assert workflow.kind == WorkflowKind.SNIPPET + assert json.loads(snippet.input_fields) == [{"variable": "query"}] + session.add.assert_called_once_with(workflow) + session.flush.assert_called_once() + session.commit.assert_called_once() + + +def test_sync_draft_workflow_raises_when_hash_mismatches() -> None: + service = SnippetService.__new__(SnippetService) + service.get_draft_workflow = Mock(return_value=SimpleNamespace(unique_hash="server-hash")) + + with pytest.raises(WorkflowHashNotEqualError): + service.sync_draft_workflow( + snippet=SimpleNamespace(id="snippet-1", tenant_id="tenant-1"), + graph={"nodes": [], "edges": []}, + unique_hash="client-hash", + account=SimpleNamespace(id="account-1"), + ) + + +def test_get_default_block_configs_skips_empty_defaults(monkeypatch: pytest.MonkeyPatch) -> None: + node_with_default = SimpleNamespace(get_default_config=Mock(return_value={"type": "llm"})) + node_without_default = SimpleNamespace(get_default_config=Mock(return_value=None)) + monkeypatch.setattr( + "services.snippet_service.NODE_TYPE_CLASSES_MAPPING", + { + "llm": {"1": node_with_default}, + "empty": {"1": node_without_default}, + }, + ) + monkeypatch.setattr("services.snippet_service.LATEST_VERSION", "1") + service = SnippetService.__new__(SnippetService) + + assert service.get_default_block_configs() == [{"type": "llm"}] + + +def test_get_default_block_config_returns_none_for_unknown_node(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("services.snippet_service.NODE_TYPE_CLASSES_MAPPING", {}) + service = SnippetService.__new__(SnippetService) + + assert service.get_default_block_config("missing") is None + + +def test_get_default_block_config_returns_node_default(monkeypatch: pytest.MonkeyPatch) -> None: + node_class = SimpleNamespace(get_default_config=Mock(return_value={"type": "llm"})) + monkeypatch.setattr("services.snippet_service.NODE_TYPE_CLASSES_MAPPING", {"llm": {"1": node_class}}) + monkeypatch.setattr("services.snippet_service.LATEST_VERSION", "1") + service = SnippetService.__new__(SnippetService) + + assert service.get_default_block_config("llm", filters={"k": "v"}) == {"type": "llm"} + node_class.get_default_config.assert_called_once_with(filters={"k": "v"}) + + +def test_get_default_block_config_returns_none_for_empty_default(monkeypatch: pytest.MonkeyPatch) -> None: + node_class = SimpleNamespace(get_default_config=Mock(return_value=None)) + monkeypatch.setattr("services.snippet_service.NODE_TYPE_CLASSES_MAPPING", {"llm": {"1": node_class}}) + monkeypatch.setattr("services.snippet_service.LATEST_VERSION", "1") + service = SnippetService.__new__(SnippetService) + + assert service.get_default_block_config("llm") is None + + def test_restore_published_snippet_workflow_to_draft_copies_source_snapshot( monkeypatch: pytest.MonkeyPatch, ) -> None: diff --git a/api/tests/unit_tests/services/test_tag_service.py b/api/tests/unit_tests/services/test_tag_service.py index 771924ea363..282b32a7e55 100644 --- a/api/tests/unit_tests/services/test_tag_service.py +++ b/api/tests/unit_tests/services/test_tag_service.py @@ -58,6 +58,29 @@ def test_delete_tag_binding_limits_deletion_to_valid_snippet_tags(mocker, curren db_session.commit.assert_called_once() +def test_delete_tag_binding_does_not_commit_when_no_rows_deleted(mocker, current_user, db_session): + mocker.patch("services.tag_service.TagService.check_target_exists") + db_session.execute.return_value = SimpleNamespace(rowcount=0) + + TagService.delete_tag_binding( + TagBindingDeletePayload( + tag_ids=["tag-1"], + target_id="snippet-1", + type=TagType.SNIPPET, + ) + ) + + db_session.execute.assert_called_once() + db_session.commit.assert_not_called() + + +def test_get_target_ids_by_tag_ids_returns_empty_without_query_for_empty_input(db_session): + result = TagService.get_target_ids_by_tag_ids(TagType.SNIPPET, "tenant-1", []) + + assert result == [] + db_session.scalars.assert_not_called() + + def test_check_target_exists_accepts_existing_snippet(current_user, db_session): db_session.scalar.return_value = SimpleNamespace(id="snippet-1") @@ -71,3 +94,10 @@ def test_check_target_exists_raises_when_snippet_missing(current_user, db_sessio with pytest.raises(NotFound, match="Snippet not found"): TagService.check_target_exists("snippet", "missing-snippet") + + +def test_check_target_exists_raises_for_invalid_binding_type(current_user, db_session): + with pytest.raises(NotFound, match="Invalid binding type"): + TagService.check_target_exists("unknown", "target-1") + + db_session.scalar.assert_not_called()