From 1e4b2995d420bb3cab0ebc3e36108896e3d955e1 Mon Sep 17 00:00:00 2001 From: FFXN <31929997+FFXN@users.noreply.github.com> Date: Thu, 28 May 2026 11:01:49 +0800 Subject: [PATCH] feat: dev snippet fronted (#36748) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: dependabot[bot] Signed-off-by: EvanYao826 <155432245+EvanYao826@users.noreply.github.com> Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> Co-authored-by: 盐粒 Yanli Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Tianle <40735546+Tianlel@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Yunlu Wen Co-authored-by: zyssyz123 <916125788@qq.com> Co-authored-by: Claude Opus 4.7 (1M context) Co-authored-by: chariri Co-authored-by: Asuka Minato Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Co-authored-by: Nian <11332799+Lillian68@users.noreply.github.com> Co-authored-by: 非法操作 Co-authored-by: Carmen Fernández Ruiz <279459669+zeus1959@users.noreply.github.com> Co-authored-by: wangxiaolei Co-authored-by: QuantumGhost Co-authored-by: L1nSn0w Co-authored-by: Evan <2869018789@qq.com> Co-authored-by: Escape0707 Co-authored-by: Jingyi Co-authored-by: Amr Sherif <140330826+amr-sheriff@users.noreply.github.com> Co-authored-by: ZHOU ZHICHEN <118870511+zhuiguangzhe2003@users.noreply.github.com> Co-authored-by: unknown Co-authored-by: JzoNg Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com> --- .github/dependabot.yml | 111 + .github/workflows/cli-tests.yml | 10 +- .gitignore | 2 +- api/app_factory.py | 3 +- api/clients/agent_backend/__init__.py | 6 +- api/clients/agent_backend/request_builder.py | 46 +- api/commands/__init__.py | 3 + api/commands/data_migrate.py | 179 ++ api/configs/middleware/vdb/milvus_config.py | 18 + api/controllers/console/__init__.py | 2 + 
api/controllers/console/agent/composer.py | 22 +- api/controllers/console/apikey.py | 92 +- api/controllers/console/app/agent.py | 4 +- api/controllers/console/app/app.py | 24 +- api/controllers/console/app/audio.py | 4 +- api/controllers/console/app/completion.py | 10 +- api/controllers/console/app/conversation.py | 14 +- .../console/app/conversation_variables.py | 4 +- api/controllers/console/app/mcp_server.py | 8 +- api/controllers/console/app/message.py | 14 +- api/controllers/console/app/model_config.py | 4 +- api/controllers/console/app/site.py | 5 +- api/controllers/console/app/statistic.py | 65 +- .../console/app/workflow_draft_variable.py | 13 +- .../app/workflow_node_output_inspector.py | 415 +++ .../console/app/workflow_statistic.py | 10 +- .../console/auth/data_source_bearer_auth.py | 17 +- api/controllers/console/auth/oauth_server.py | 12 +- .../console/datasets/datasets_segments.py | 3 + .../console/datasets/hit_testing.py | 98 +- .../console/datasets/hit_testing_base.py | 41 +- api/controllers/console/explore/audio.py | 9 +- api/controllers/console/explore/completion.py | 18 +- .../console/explore/conversation.py | 23 +- .../console/explore/installed_app.py | 4 +- api/controllers/console/explore/message.py | 30 +- .../console/explore/saved_message.py | 14 +- api/controllers/console/extension.py | 28 +- api/controllers/console/feature.py | 14 +- api/controllers/console/files.py | 15 +- api/controllers/console/remote_files.py | 10 +- api/controllers/console/tag/tags.py | 46 +- api/controllers/console/workspace/members.py | 2 +- api/controllers/openapi/__init__.py | 14 + api/controllers/openapi/_models.py | 60 +- api/controllers/openapi/account.py | 65 +- api/controllers/openapi/app_run.py | 13 +- api/controllers/openapi/apps.py | 53 +- .../openapi/apps_permitted_external.py | 28 +- api/controllers/openapi/auth/__init__.py | 4 +- api/controllers/openapi/auth/composition.py | 100 +- api/controllers/openapi/auth/conditions.py | 53 + 
api/controllers/openapi/auth/context.py | 68 - api/controllers/openapi/auth/data.py | 69 + api/controllers/openapi/auth/flow.py | 19 + api/controllers/openapi/auth/pipeline.py | 224 +- api/controllers/openapi/auth/prepare.py | 67 + api/controllers/openapi/auth/role_gate.py | 77 + api/controllers/openapi/auth/steps.py | 170 -- api/controllers/openapi/auth/strategies.py | 168 -- api/controllers/openapi/auth/verify.py | 82 + api/controllers/openapi/files.py | 9 +- api/controllers/openapi/human_input_form.py | 13 +- api/controllers/openapi/workflow_events.py | 10 +- api/controllers/openapi/workspaces.py | 307 +- .../service_api/dataset/hit_testing.py | 23 +- api/controllers/service_api/wraps.py | 1 + api/controllers/web/app.py | 6 +- api/controllers/web/completion.py | 10 +- api/controllers/web/conversation.py | 12 +- api/controllers/web/files.py | 3 +- api/controllers/web/message.py | 10 +- api/controllers/web/remote_files.py | 5 +- api/controllers/web/saved_message.py | 7 +- api/controllers/web/site.py | 6 +- api/core/agent/base_agent_runner.py | 75 +- api/core/app/apps/base_app_runner.py | 76 +- api/core/app/workflow/layers/persistence.py | 31 + api/core/tools/__base/tool.py | 93 +- .../nodes/agent_v2/plugin_tools_builder.py | 268 ++ .../agent_v2/runtime_feature_manifest.py | 17 +- .../nodes/agent_v2/runtime_request_builder.py | 37 +- .../workflow/nodes/agent_v2/validators.py | 21 + api/extensions/ext_commands.py | 2 + api/fields/hit_testing_fields.py | 146 +- api/libs/oauth_bearer.py | 25 +- api/models/agent_config_entities.py | 86 +- api/openapi/markdown/console-swagger.md | 303 +- api/openapi/markdown/openapi-swagger.md | 139 + api/openapi/markdown/service-swagger.md | 102 +- .../src/dify_vdb_milvus/milvus_vector.py | 25 +- .../tests/unit_tests/test_milvus.py | 29 + api/pyproject.toml | 7 +- api/services/account_service.py | 28 + api/services/legacy_model_type_migration.py | 2464 +++++++++++++++++ api/services/workflow/inspector_events.py | 194 ++ 
.../workflow/node_output_inspector_service.py | 712 +++++ .../app_generate/workflow_execute_task.py | 74 +- api/tests/helpers/__init__.py | 1 + .../helpers/legacy_model_type_migration.py | 366 +++ .../test_node_output_inspector_service.py | 475 ++++ .../seed_legacy_model_type_dirty_data.py | 82 + .../test_legacy_model_type_migration.py | 408 +++ .../auth/test_data_source_bearer_auth.py | 52 +- .../console/datasets/test_external.py | 116 + .../console/test_api_based_extension.py | 65 +- .../controllers/console/test_apikey.py | 48 +- .../controllers/console/test_feature.py | 65 + .../controllers/console/test_files.py | 101 + .../test_human_input_delivery_test_service.py | 10 +- .../tasks/test_delete_account_task.py | 84 + .../clients/agent_backend/test_client.py | 5 +- .../clients/agent_backend/test_fake_client.py | 5 +- .../agent_backend/test_request_builder.py | 56 +- .../test_legacy_model_type_migration.py | 2025 ++++++++++++++ .../test_workflow_node_output_inspector.py | 454 +++ .../auth/test_data_source_bearer_auth.py | 85 + .../console/auth/test_oauth_server.py | 52 + .../console/datasets/test_external.py | 10 +- .../console/datasets/test_hit_testing.py | 85 +- .../console/datasets/test_hit_testing_base.py | 100 +- .../console/explore/test_message.py | 81 +- .../controllers/console/tag/test_tags.py | 82 +- .../controllers/console/test_apikey.py | 141 + .../controllers/console/test_extension.py | 2 - .../controllers/console/test_feature.py | 14 +- .../controllers/console/test_files.py | 55 +- .../controllers/console/test_remote_files.py | 43 +- .../controllers/console/test_wraps.py | 93 +- .../openapi/auth/test_composition.py | 115 +- .../openapi/auth/test_conditions.py | 143 + .../controllers/openapi/auth/test_context.py | 21 - .../controllers/openapi/auth/test_data.py | 117 + .../controllers/openapi/auth/test_flow.py | 42 + .../controllers/openapi/auth/test_pipeline.py | 302 +- .../controllers/openapi/auth/test_prepare.py | 183 ++ 
.../openapi/auth/test_role_gate.py | 330 +++ .../openapi/auth/test_step_app_resolver.py | 64 - .../openapi/auth/test_step_authz.py | 76 - .../openapi/auth/test_step_bearer.py | 83 - .../openapi/auth/test_step_layer0.py | 157 -- .../openapi/auth/test_step_mount.py | 77 - .../openapi/auth/test_step_scope.py | 25 - .../openapi/auth/test_surface_gate.py | 239 -- .../controllers/openapi/auth/test_verify.py | 142 + .../controllers/openapi/conftest.py | 30 +- .../controllers/openapi/test_account.py | 8 +- .../openapi/test_app_run_streaming.py | 18 +- .../openapi/test_human_input_form.py | 43 +- .../openapi/test_workflow_events_openapi.py | 45 +- .../openapi/test_workspaces_members.py | 928 +++++++ .../dataset/test_dataset_segment.py | 13 + .../service_api/dataset/test_document.py | 12 +- .../service_api/dataset/test_hit_testing.py | 116 +- .../controllers/web/test_human_input_form.py | 4 +- .../core/agent/test_base_agent_runner.py | 217 +- .../core/app/apps/test_base_app_runner.py | 40 + .../test_persistence_inspector_publish.py | 192 ++ .../unit_tests/core/tools/test_base_tool.py | 113 +- .../agent_v2/test_plugin_tools_builder.py | 439 +++ .../agent_v2/test_runtime_request_builder.py | 100 +- .../test_oauth_bearer_rate_limit_ordering.py | 5 +- .../libs/test_oauth_bearer_require_scope.py | 3 +- .../libs/test_workspace_member_helper.py | 4 +- .../services/test_account_service.py | 48 +- .../test_dataset_service_lock_not_owned.py | 2 +- .../services/test_oauth_device_flow.py | 6 +- .../workflow/test_inspector_events.py | 224 ++ .../test_node_output_inspector_service.py | 499 ++++ .../tasks/test_delete_account_task.py | 115 - .../test_mail_human_input_delivery_task.py | 8 +- api/uv.lock | 14 +- cli/AGENTS.md | 2 +- cli/ARD.md | 2 +- cli/package.json | 2 + cli/src/api/account-sessions.ts | 11 +- cli/src/api/app-meta.test.ts | 10 +- cli/src/api/members.test.ts | 280 ++ cli/src/api/members.ts | 61 + cli/src/api/workspaces.ts | 17 +- cli/src/auth/file-backend.test.ts | 2 +- 
cli/src/auth/file-backend.ts | 2 +- cli/src/auth/hosts.test.ts | 2 +- cli/src/auth/hosts.ts | 2 +- cli/src/cache/app-info.test.ts | 43 +- cli/src/cache/app-info.ts | 68 +- cli/src/cache/nudge-store.test.ts | 41 +- cli/src/cache/nudge-store.ts | 67 +- cli/src/commands/_shared/authed-command.ts | 20 +- .../auth/devices/_shared/devices.test.ts | 52 +- .../commands/auth/devices/_shared/devices.ts | 54 +- cli/src/commands/auth/devices/list/index.ts | 12 +- cli/src/commands/auth/login/index.ts | 4 +- cli/src/commands/auth/login/login.test.ts | 2 +- cli/src/commands/auth/login/login.ts | 4 +- cli/src/commands/auth/logout/index.ts | 6 +- cli/src/commands/auth/logout/logout.test.ts | 2 +- cli/src/commands/auth/logout/logout.ts | 4 +- cli/src/commands/auth/status/index.ts | 4 +- cli/src/commands/auth/status/status.test.ts | 2 +- cli/src/commands/auth/status/status.ts | 2 +- cli/src/commands/auth/use/index.ts | 25 - cli/src/commands/auth/use/use.test.ts | 71 - cli/src/commands/auth/use/use.ts | 49 - cli/src/commands/auth/whoami/index.ts | 4 +- cli/src/commands/auth/whoami/whoami.test.ts | 2 +- cli/src/commands/auth/whoami/whoami.ts | 2 +- cli/src/commands/config/get/index.ts | 4 +- cli/src/commands/config/get/run.test.ts | 17 +- cli/src/commands/config/get/run.ts | 9 +- cli/src/commands/config/path/index.ts | 2 +- cli/src/commands/config/set/index.ts | 4 +- cli/src/commands/config/set/run.test.ts | 29 +- cli/src/commands/config/set/run.ts | 13 +- cli/src/commands/config/unset/index.ts | 4 +- cli/src/commands/config/unset/run.test.ts | 13 +- cli/src/commands/config/unset/run.ts | 13 +- cli/src/commands/config/view/index.ts | 4 +- cli/src/commands/config/view/run.test.ts | 23 +- cli/src/commands/config/view/run.ts | 9 +- cli/src/commands/create/member/handlers.ts | 23 + cli/src/commands/create/member/index.ts | 40 + cli/src/commands/create/member/run.test.ts | 102 + cli/src/commands/create/member/run.ts | 75 + cli/src/commands/delete/member/handlers.ts | 26 + 
cli/src/commands/delete/member/index.ts | 40 + cli/src/commands/delete/member/run.test.ts | 72 + cli/src/commands/delete/member/run.ts | 90 + cli/src/commands/describe/app/run.test.ts | 6 +- cli/src/commands/describe/app/run.ts | 9 +- cli/src/commands/env/list/run-list.ts | 7 +- cli/src/commands/get/app/run.ts | 9 +- cli/src/commands/get/member/handlers.ts | 89 + cli/src/commands/get/member/index.ts | 44 + cli/src/commands/get/member/run.test.ts | 153 + cli/src/commands/get/member/run.ts | 65 + cli/src/commands/get/workspace/run.ts | 6 +- cli/src/commands/resume/app/run.ts | 9 +- .../app/_strategies/streaming-structured.ts | 6 +- .../run/app/_strategies/streaming-text.ts | 6 +- cli/src/commands/run/app/handlers.ts | 2 +- cli/src/commands/run/app/hitl-render.ts | 2 +- cli/src/commands/run/app/run.test.ts | 44 +- cli/src/commands/run/app/run.ts | 7 +- cli/src/commands/run/app/stream-handlers.ts | 4 +- cli/src/commands/set/member/handlers.ts | 26 + cli/src/commands/set/member/index.ts | 43 + cli/src/commands/set/member/run.test.ts | 87 + cli/src/commands/set/member/run.ts | 78 + cli/src/commands/tree.generated.ts | 28 +- cli/src/commands/use/workspace/index.ts | 31 + cli/src/commands/use/workspace/use.test.ts | 199 ++ cli/src/commands/use/workspace/use.ts | 76 + cli/src/commands/version/index.ts | 6 +- .../{loader.test.ts => config-loader.test.ts} | 27 +- cli/src/config/config-loader.ts | 42 + cli/src/config/dir.test.ts | 71 - cli/src/config/dir.ts | 45 - cli/src/config/loader.ts | 58 - cli/src/config/writer.ts | 39 - cli/src/env/registry.ts | 5 +- cli/src/errors/format.ts | 2 +- .../config-writer.test.ts} | 38 +- cli/src/store/config-writer.ts | 8 + cli/src/store/dir.ts | 20 + cli/src/store/manager.ts | 28 + cli/src/store/store.test.ts | 193 ++ cli/src/store/store.ts | 165 ++ cli/src/sys/index.test.ts | 37 + cli/src/sys/index.ts | 122 + cli/src/{ => sys}/io/color.ts | 0 cli/src/{ => sys}/io/spinner.ts | 0 cli/src/{ => sys}/io/streams.ts | 7 +- cli/src/{ => 
sys}/io/think-filter.test.ts | 0 cli/src/{ => sys}/io/think-filter.ts | 0 cli/src/util/browser.ts | 3 +- cli/src/version/info.ts | 3 +- cli/src/version/nudge.test.ts | 8 +- cli/src/version/nudge.ts | 2 +- cli/src/version/probe.test.ts | 9 +- cli/src/version/probe.ts | 7 +- cli/src/version/render.ts | 2 +- cli/src/workspace/resolver.ts | 2 +- cli/tsconfig.json | 10 +- dify-agent/docs/agenton/index.md | 2 - .../docs/dify-agent/get-started/index.md | 14 +- dify-agent/docs/dify-agent/guide/index.md | 8 +- dify-agent/docs/dify-agent/index.md | 3 +- .../execution-context-layer/index.md | 67 + .../user-manual/plugin-layer/index.md | 59 - .../user-manual/plugin-llm-layer/index.md | 42 +- .../user-manual/plugin-tool-layer/index.md | 130 + .../run_server_consumer.py | 58 +- .../run_server_sync_client.py | 58 +- dify-agent/mkdocs.yml | 5 +- .../src/dify_agent/adapters/llm/provider.py | 70 +- .../dify_agent/layers/dify_plugin/__init__.py | 24 +- .../dify_agent/layers/dify_plugin/configs.py | 161 +- .../layers/dify_plugin/llm_layer.py | 27 +- .../layers/dify_plugin/plugin_layer.py | 69 - .../layers/dify_plugin/tool_client.py | 333 +++ .../layers/dify_plugin/tools_layer.py | 341 +++ .../layers/execution_context/__init__.py | 18 + .../layers/execution_context/configs.py | 50 + .../layers/execution_context/layer.py | 95 + .../src/dify_agent/plugin_daemon_transport.py | 72 + .../src/dify_agent/protocol/__init__.py | 4 - dify-agent/src/dify_agent/protocol/schemas.py | 32 +- .../dify_agent/runtime/compositor_factory.py | 30 +- .../src/dify_agent/runtime/run_scheduler.py | 84 +- dify-agent/src/dify_agent/runtime/runner.py | 31 +- .../runtime/user_prompt_validation.py | 6 +- .../src/dify_agent/server/routes/runs.py | 17 +- .../layers/dify_plugin/test_configs.py | 191 +- .../layers/dify_plugin/test_layers.py | 661 ++++- .../layers/execution_context/test_configs.py | 47 + .../layers/execution_context/test_layer.py | 107 + .../protocol/test_protocol_schemas.py | 174 +- 
.../dify_agent/runtime/test_run_scheduler.py | 400 +-- .../local/dify_agent/runtime/test_runner.py | 422 ++- .../tests/local/dify_agent/server/test_app.py | 18 +- .../dify_agent/server/test_runs_routes.py | 45 +- .../dify_agent/test_import_boundaries.py | 13 +- eslint-suppressions.json | 373 ++- package.json | 7 + .../generated/api/console/agents/types.gen.ts | 26 +- .../generated/api/console/agents/zod.gen.ts | 55 +- .../generated/api/console/apps/orpc.gen.ts | 308 ++- .../generated/api/console/apps/types.gen.ts | 260 +- .../generated/api/console/apps/zod.gen.ts | 201 +- .../api/console/datasets/types.gen.ts | 96 +- .../generated/api/console/datasets/zod.gen.ts | 99 +- .../generated/api/openapi/orpc.gen.ts | 111 +- .../generated/api/openapi/types.gen.ts | 129 + .../generated/api/openapi/zod.gen.ts | 123 + .../generated/api/service/types.gen.ts | 76 +- .../generated/api/service/zod.gen.ts | 93 +- packages/dify-ui/README.md | 27 +- packages/dify-ui/package.json | 9 + .../src/autocomplete/__tests__/index.spec.tsx | 31 + .../src/autocomplete/index.stories.tsx | 180 +- packages/dify-ui/src/autocomplete/index.tsx | 2 +- .../src/combobox/__tests__/index.spec.tsx | 88 +- .../dify-ui/src/combobox/index.stories.tsx | 351 ++- packages/dify-ui/src/combobox/index.tsx | 2 +- .../dify-ui/src/kbd/__tests__/index.spec.tsx | 59 + packages/dify-ui/src/kbd/index.stories.tsx | 230 ++ packages/dify-ui/src/kbd/index.tsx | 61 + .../dify-ui/src/popover/index.stories.tsx | 10 +- .../src/select/__tests__/index.spec.tsx | 37 +- packages/dify-ui/src/select/index.stories.tsx | 7 +- packages/dify-ui/src/select/index.tsx | 2 +- .../src/textarea/__tests__/index.spec.tsx | 187 ++ .../dify-ui/src/textarea/index.stories.tsx | 193 ++ packages/dify-ui/src/textarea/index.tsx | 103 + pnpm-lock.yaml | 199 +- pnpm-workspace.yaml | 3 +- .../app-sidebar/sidebar-shell-flow.test.tsx | 24 +- web/__tests__/app/app-publisher-flow.test.tsx | 9 - .../billing/billing-integration.test.tsx | 4 +- 
.../delete-account/components/feed-back.tsx | 7 +- .../app-sidebar/__tests__/index.spec.tsx | 12 +- .../__tests__/toggle-button.spec.tsx | 6 - web/app/components/app-sidebar/index.tsx | 22 +- .../components/app-sidebar/toggle-button.tsx | 11 +- .../add-annotation-modal/edit-item/index.tsx | 5 +- .../edit-annotation-modal/edit-item/index.tsx | 5 +- .../app-publisher/__tests__/index.spec.tsx | 22 +- .../app-publisher/__tests__/sections.spec.tsx | 8 +- .../components/app/app-publisher/index.tsx | 9 +- .../components/app/app-publisher/sections.tsx | 9 +- .../app/app-publisher/version-info-modal.tsx | 17 +- .../config-var/config-modal/form-fields.tsx | 5 +- .../config/automatic/idea-output.tsx | 5 +- .../config/automatic/instruction-editor.tsx | 3 +- .../dataset-config/settings-modal/index.tsx | 5 +- .../debug/__tests__/chat-user-input.spec.tsx | 29 +- .../configuration/debug/chat-user-input.tsx | 5 +- .../prompt-value-panel/index.tsx | 5 +- .../create-app-modal/__tests__/index.spec.tsx | 32 +- .../components/app/create-app-modal/index.tsx | 20 +- .../__tests__/index.spec.tsx | 59 +- .../app/create-from-dsl-modal/index.tsx | 26 +- .../duplicate-modal/__tests__/index.spec.tsx | 18 +- .../app/overview/settings/index.tsx | 8 +- .../overview/workflow-hidden-input-fields.tsx | 4 +- .../chat-with-history/inputs-form/content.tsx | 5 +- .../human-input-content/content-item.tsx | 5 +- .../base/chat/chat/answer/operation.tsx | 4 +- .../embedded-chatbot/inputs-form/content.tsx | 5 +- web/app/components/base/chip/index.tsx | 7 +- .../follow-up-setting-modal.tsx | 5 +- .../moderation/form-generation.tsx | 5 +- .../moderation/moderation-content.tsx | 11 +- .../moderation/moderation-setting-modal.tsx | 17 +- .../base/file-uploader/pdf-preview.tsx | 6 +- .../field/__tests__/text-area.spec.tsx | 17 + .../mixed-variable-text-input/placeholder.tsx | 3 +- .../base/form/components/field/text-area.tsx | 10 +- .../base/image-uploader/image-preview.tsx | 10 +- 
.../components/base/markdown-blocks/form.tsx | 5 +- .../plugins/hitl-input-block/input-field.tsx | 10 +- .../plugins/hitl-input-block/pre-populate.tsx | 6 +- .../base/search-input/index.stories.tsx | 5 +- .../base/textarea/__tests__/index.spec.tsx | 77 - .../base/textarea/index.stories.tsx | 562 ---- web/app/components/base/textarea/index.tsx | 60 - .../billing/apps-full-in-dialog/index.tsx | 5 +- .../usage-info/__tests__/index.spec.tsx | 4 +- .../components/billing/usage-info/index.tsx | 6 +- .../datasets/common/image-previewer/index.tsx | 11 +- .../__tests__/index.spec.tsx | 2 - .../create-from-dsl-modal/index.tsx | 8 +- .../list/template-card/edit-pipeline-info.tsx | 8 +- .../common/__tests__/action-buttons.spec.tsx | 49 +- .../completed/common/action-buttons.tsx | 19 +- .../datasets/list/__tests__/datasets.spec.tsx | 123 + .../datasets/list/dataset-card-skeleton.tsx | 40 + .../dataset-card/__tests__/index.spec.tsx | 9 + .../datasets/list/dataset-card/index.tsx | 2 +- web/app/components/datasets/list/datasets.tsx | 18 +- .../datasets/rename-modal/index.tsx | 4 +- .../form/components/basic-info-section.tsx | 5 +- .../settings/summary-index-setting.tsx | 16 +- .../create-app-modal/__tests__/index.spec.tsx | 44 +- .../explore/create-app-modal/index.tsx | 24 +- .../goto-anything/__tests__/index.spec.tsx | 52 +- .../__tests__/search-input.spec.tsx | 19 +- .../goto-anything/components/search-input.tsx | 9 +- .../__tests__/use-goto-anything-modal.spec.ts | 66 +- .../hooks/use-goto-anything-modal.ts | 24 +- .../plugins/card/base/placeholder.tsx | 2 +- .../plugins/install-plugin/base/installed.tsx | 2 +- .../steps/install.tsx | 2 +- .../steps/uploading.tsx | 2 +- .../app-selector/app-inputs-form.tsx | 5 +- .../tool-selector/__tests__/index.spec.tsx | 8 +- .../__tests__/tool-base-form.spec.tsx | 28 +- .../components/tool-base-form.tsx | 7 +- .../__tests__/use-tool-selector-state.spec.ts | 4 +- .../hooks/use-tool-selector-state.ts | 4 +- 
.../components/__tests__/index.spec.tsx | 2 - ...blish-as-knowledge-pipeline-modal.spec.tsx | 17 +- .../publish-as-knowledge-pipeline-modal.tsx | 5 +- .../__tests__/index.spec.tsx | 23 +- .../__tests__/run-mode.spec.tsx | 8 +- .../publisher/__tests__/index.spec.tsx | 40 +- .../publisher/__tests__/popup.spec.tsx | 23 +- .../rag-pipeline-header/publisher/popup.tsx | 18 +- .../rag-pipeline-header/run-mode.tsx | 9 +- .../share/text-generation/run-once/index.tsx | 5 +- .../__tests__/snippet-create-button.spec.tsx | 1 - .../components/snippet-create-button.tsx | 3 - .../__tests__/snippet-main.spec.tsx | 49 + .../__tests__/use-snippet-publish.spec.ts | 55 +- .../components/hooks/use-snippet-publish.ts | 12 +- .../components/snippet-header/run-mode.tsx | 6 +- .../snippets/components/snippet-main.tsx | 9 +- .../snippets/create-snippet-dialog.tsx | 8 +- .../__tests__/use-create-snippet.spec.tsx | 4 +- .../__tests__/use-nodes-sync-draft.spec.ts | 39 + .../hooks/__tests__/use-snippet-run.spec.ts | 175 ++ .../snippets/hooks/use-create-snippet.ts | 4 +- .../snippets/hooks/use-nodes-sync-draft.ts | 37 +- .../snippets/hooks/use-snippet-run.ts | 16 +- .../tools/__tests__/provider-list.spec.tsx | 40 +- .../config-credentials.tsx | 249 +- .../edit-custom-collection-modal/index.tsx | 22 +- .../tools/mcp/__tests__/index.spec.tsx | 38 +- web/app/components/tools/mcp/index.tsx | 47 +- .../components/tools/mcp/mcp-server-modal.tsx | 8 +- .../tools/mcp/mcp-server-param-item.tsx | 8 +- web/app/components/tools/provider-list.tsx | 65 +- .../tools/provider/tool-card-skeleton.tsx | 50 + .../components/tools/workflow-tool/index.tsx | 5 +- .../__tests__/reactflow-mock-state.ts | 1 - .../__tests__/shortcuts-name.spec.tsx | 51 - .../workflow/block-selector/blocks.tsx | 1 - .../workflow/block-selector/main.tsx | 115 +- .../snippets/__tests__/index.spec.tsx | 12 + .../block-selector/snippets/index.tsx | 42 +- .../snippets/snippet-tags-filter.tsx | 14 +- .../header/__tests__/header-layouts.spec.tsx | 
33 + .../header/__tests__/run-mode.spec.tsx | 4 - .../__tests__/test-run-menu-helpers.spec.tsx | 4 - .../header/__tests__/test-run-menu.spec.tsx | 4 - .../workflow/header/header-in-restoring.tsx | 2 + .../use-nodes-available-var-list.spec.ts | 57 + .../hooks/use-nodes-available-var-list.ts | 9 +- .../components/before-run-form/form-item.tsx | 5 +- .../mixed-variable-text-input/placeholder.tsx | 3 +- .../__tests__/use-available-var-list.spec.ts | 108 + .../_base/hooks/snippet-input-field-vars.ts | 12 + .../_base/hooks/use-available-var-list.ts | 6 +- .../assigner/components/var-list/index.tsx | 5 +- .../nodes/http/components/curl-panel.tsx | 5 +- .../__tests__/form-content.spec.tsx | 16 +- .../human-input/components/form-content.tsx | 15 +- .../json-schema-generator/prompt-editor.tsx | 9 +- .../edit-card/advanced-options.tsx | 11 +- .../components/loop-variables/form-item.tsx | 9 +- .../components/extract-parameter/update.tsx | 5 +- .../mixed-variable-text-input/placeholder.tsx | 3 +- .../components/workflow/shortcuts-name.tsx | 42 - .../shortcuts/__tests__/shortcut-kbd.spec.tsx | 4 +- .../workflow/shortcuts/shortcut-kbd.tsx | 21 +- .../shortcuts/use-workflow-hotkeys.ts | 13 +- .../workflow/utils/__tests__/common.spec.ts | 156 -- web/app/components/workflow/utils/common.ts | 36 - .../variable-inspect/display-content.tsx | 5 +- .../value-content-sections.tsx | 8 +- web/eslint.config.mjs | 16 + web/eslint.constants.mjs | 54 + .../__tests__/dataset-card-tags.spec.tsx | 5 +- .../__tests__/app-card-tags.spec.tsx | 10 + .../components/app-card-tags.tsx | 2 +- .../components/dataset-card-tags.tsx | 10 +- web/i18n/ar-TN/workflow.json | 2 + web/i18n/de-DE/workflow.json | 2 + web/i18n/en-US/snippet.json | 1 + web/i18n/en-US/workflow.json | 2 + web/i18n/es-ES/workflow.json | 2 + web/i18n/fa-IR/workflow.json | 2 + web/i18n/fr-FR/workflow.json | 2 + web/i18n/hi-IN/workflow.json | 2 + web/i18n/id-ID/workflow.json | 2 + web/i18n/it-IT/workflow.json | 2 + 
web/i18n/ja-JP/workflow.json | 2 + web/i18n/ko-KR/workflow.json | 2 + web/i18n/nl-NL/workflow.json | 2 + web/i18n/pl-PL/workflow.json | 2 + web/i18n/pt-BR/workflow.json | 2 + web/i18n/ro-RO/workflow.json | 2 + web/i18n/ru-RU/workflow.json | 2 + web/i18n/sl-SI/workflow.json | 2 + web/i18n/th-TH/workflow.json | 2 + web/i18n/tr-TR/workflow.json | 2 + web/i18n/uk-UA/workflow.json | 2 + web/i18n/vi-VN/workflow.json | 2 + web/i18n/zh-Hans/snippet.json | 1 + web/i18n/zh-Hans/workflow.json | 2 + web/i18n/zh-Hant/workflow.json | 2 + web/package.json | 1 - web/service/knowledge/use-dataset.ts | 1 + web/service/use-snippets.ts | 36 +- 545 files changed, 27174 insertions(+), 6886 deletions(-) create mode 100644 api/commands/data_migrate.py create mode 100644 api/controllers/console/app/workflow_node_output_inspector.py create mode 100644 api/controllers/openapi/auth/conditions.py delete mode 100644 api/controllers/openapi/auth/context.py create mode 100644 api/controllers/openapi/auth/data.py create mode 100644 api/controllers/openapi/auth/flow.py create mode 100644 api/controllers/openapi/auth/prepare.py create mode 100644 api/controllers/openapi/auth/role_gate.py delete mode 100644 api/controllers/openapi/auth/steps.py delete mode 100644 api/controllers/openapi/auth/strategies.py create mode 100644 api/controllers/openapi/auth/verify.py create mode 100644 api/core/workflow/nodes/agent_v2/plugin_tools_builder.py create mode 100644 api/services/legacy_model_type_migration.py create mode 100644 api/services/workflow/inspector_events.py create mode 100644 api/services/workflow/node_output_inspector_service.py create mode 100644 api/tests/helpers/__init__.py create mode 100644 api/tests/helpers/legacy_model_type_migration.py create mode 100644 api/tests/integration_tests/services/test_node_output_inspector_service.py create mode 100644 api/tests/seed_legacy_model_type_dirty_data.py create mode 100644 
api/tests/test_containers_integration_tests/commands/test_legacy_model_type_migration.py create mode 100644 api/tests/test_containers_integration_tests/controllers/console/datasets/test_external.py create mode 100644 api/tests/test_containers_integration_tests/controllers/console/test_feature.py create mode 100644 api/tests/test_containers_integration_tests/controllers/console/test_files.py create mode 100644 api/tests/test_containers_integration_tests/tasks/test_delete_account_task.py create mode 100644 api/tests/unit_tests/commands/test_legacy_model_type_migration.py create mode 100644 api/tests/unit_tests/controllers/console/app/test_workflow_node_output_inspector.py create mode 100644 api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py create mode 100644 api/tests/unit_tests/controllers/console/auth/test_oauth_server.py create mode 100644 api/tests/unit_tests/controllers/console/test_apikey.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_conditions.py delete mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_context.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_data.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_flow.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_prepare.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_role_gate.py delete mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_step_app_resolver.py delete mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py delete mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py delete mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_step_layer0.py delete mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_step_mount.py delete mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_step_scope.py delete mode 100644 
api/tests/unit_tests/controllers/openapi/auth/test_surface_gate.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_verify.py create mode 100644 api/tests/unit_tests/controllers/openapi/test_workspaces_members.py create mode 100644 api/tests/unit_tests/core/app/workflow/layers/test_persistence_inspector_publish.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/agent_v2/test_plugin_tools_builder.py create mode 100644 api/tests/unit_tests/services/workflow/test_inspector_events.py create mode 100644 api/tests/unit_tests/services/workflow/test_node_output_inspector_service.py delete mode 100644 api/tests/unit_tests/tasks/test_delete_account_task.py create mode 100644 cli/src/api/members.test.ts create mode 100644 cli/src/api/members.ts delete mode 100644 cli/src/commands/auth/use/index.ts delete mode 100644 cli/src/commands/auth/use/use.test.ts delete mode 100644 cli/src/commands/auth/use/use.ts create mode 100644 cli/src/commands/create/member/handlers.ts create mode 100644 cli/src/commands/create/member/index.ts create mode 100644 cli/src/commands/create/member/run.test.ts create mode 100644 cli/src/commands/create/member/run.ts create mode 100644 cli/src/commands/delete/member/handlers.ts create mode 100644 cli/src/commands/delete/member/index.ts create mode 100644 cli/src/commands/delete/member/run.test.ts create mode 100644 cli/src/commands/delete/member/run.ts create mode 100644 cli/src/commands/get/member/handlers.ts create mode 100644 cli/src/commands/get/member/index.ts create mode 100644 cli/src/commands/get/member/run.test.ts create mode 100644 cli/src/commands/get/member/run.ts create mode 100644 cli/src/commands/set/member/handlers.ts create mode 100644 cli/src/commands/set/member/index.ts create mode 100644 cli/src/commands/set/member/run.test.ts create mode 100644 cli/src/commands/set/member/run.ts create mode 100644 cli/src/commands/use/workspace/index.ts create mode 100644 cli/src/commands/use/workspace/use.test.ts 
create mode 100644 cli/src/commands/use/workspace/use.ts rename cli/src/config/{loader.test.ts => config-loader.test.ts} (80%) create mode 100644 cli/src/config/config-loader.ts delete mode 100644 cli/src/config/dir.test.ts delete mode 100644 cli/src/config/dir.ts delete mode 100644 cli/src/config/loader.ts delete mode 100644 cli/src/config/writer.ts rename cli/src/{config/writer.test.ts => store/config-writer.test.ts} (67%) create mode 100644 cli/src/store/config-writer.ts create mode 100644 cli/src/store/dir.ts create mode 100644 cli/src/store/manager.ts create mode 100644 cli/src/store/store.test.ts create mode 100644 cli/src/store/store.ts create mode 100644 cli/src/sys/index.test.ts create mode 100644 cli/src/sys/index.ts rename cli/src/{ => sys}/io/color.ts (100%) rename cli/src/{ => sys}/io/spinner.ts (100%) rename cli/src/{ => sys}/io/streams.ts (89%) rename cli/src/{ => sys}/io/think-filter.test.ts (100%) rename cli/src/{ => sys}/io/think-filter.ts (100%) create mode 100644 dify-agent/docs/dify-agent/user-manual/execution-context-layer/index.md delete mode 100644 dify-agent/docs/dify-agent/user-manual/plugin-layer/index.md create mode 100644 dify-agent/docs/dify-agent/user-manual/plugin-tool-layer/index.md delete mode 100644 dify-agent/src/dify_agent/layers/dify_plugin/plugin_layer.py create mode 100644 dify-agent/src/dify_agent/layers/dify_plugin/tool_client.py create mode 100644 dify-agent/src/dify_agent/layers/dify_plugin/tools_layer.py create mode 100644 dify-agent/src/dify_agent/layers/execution_context/__init__.py create mode 100644 dify-agent/src/dify_agent/layers/execution_context/configs.py create mode 100644 dify-agent/src/dify_agent/layers/execution_context/layer.py create mode 100644 dify-agent/src/dify_agent/plugin_daemon_transport.py create mode 100644 dify-agent/tests/local/dify_agent/layers/execution_context/test_configs.py create mode 100644 dify-agent/tests/local/dify_agent/layers/execution_context/test_layer.py create mode 100644 
packages/dify-ui/src/kbd/__tests__/index.spec.tsx create mode 100644 packages/dify-ui/src/kbd/index.stories.tsx create mode 100644 packages/dify-ui/src/kbd/index.tsx create mode 100644 packages/dify-ui/src/textarea/__tests__/index.spec.tsx create mode 100644 packages/dify-ui/src/textarea/index.stories.tsx create mode 100644 packages/dify-ui/src/textarea/index.tsx delete mode 100644 web/app/components/base/textarea/__tests__/index.spec.tsx delete mode 100644 web/app/components/base/textarea/index.stories.tsx delete mode 100644 web/app/components/base/textarea/index.tsx create mode 100644 web/app/components/datasets/list/dataset-card-skeleton.tsx create mode 100644 web/app/components/snippets/hooks/__tests__/use-snippet-run.spec.ts create mode 100644 web/app/components/tools/provider/tool-card-skeleton.tsx delete mode 100644 web/app/components/workflow/__tests__/shortcuts-name.spec.tsx create mode 100644 web/app/components/workflow/nodes/_base/hooks/__tests__/use-available-var-list.spec.ts delete mode 100644 web/app/components/workflow/shortcuts-name.tsx diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 266fa17c29..3c22088ffe 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -110,3 +110,114 @@ updates: github-actions-dependencies: patterns: - "*" + - package-ecosystem: "uv" + directory: "/api" + target-branch: "lts/1.13.x" + open-pull-requests-limit: 10 + schedule: + interval: "weekly" + groups: + flask: + patterns: + - "flask" + - "flask-*" + - "werkzeug" + - "gunicorn" + google: + patterns: + - "google-*" + - "googleapis-*" + opentelemetry: + patterns: + - "opentelemetry-*" + pydantic: + patterns: + - "pydantic" + - "pydantic-*" + llm: + patterns: + - "langfuse" + - "langsmith" + - "litellm" + - "mlflow*" + - "opik" + - "weave*" + - "arize*" + - "tiktoken" + - "transformers" + database: + patterns: + - "sqlalchemy" + - "psycopg2*" + - "psycogreen" + - "redis*" + - "alembic*" + storage: + patterns: + - "boto3*" + - "botocore*" + - 
"azure-*" + - "bce-*" + - "cos-python-*" + - "esdk-obs-*" + - "google-cloud-storage" + - "opendal" + - "oss2" + - "supabase*" + - "tos*" + vdb: + patterns: + - "alibabacloud*" + - "chromadb" + - "clickhouse-*" + - "clickzetta-*" + - "couchbase" + - "elasticsearch" + - "opensearch-py" + - "oracledb" + - "pgvect*" + - "pymilvus" + - "pymochow" + - "pyobvector" + - "qdrant-client" + - "intersystems-*" + - "tablestore" + - "tcvectordb" + - "tidb-vector" + - "upstash-*" + - "volcengine-*" + - "weaviate-*" + - "xinference-*" + - "mo-vector" + - "mysql-connector-*" + dev: + patterns: + - "coverage" + - "dotenv-linter" + - "faker" + - "lxml-stubs" + - "basedpyright" + - "ruff" + - "pytest*" + - "types-*" + - "boto3-stubs" + - "hypothesis" + - "pandas-stubs" + - "scipy-stubs" + - "import-linter" + - "celery-types" + - "mypy*" + - "pyrefly" + python-packages: + patterns: + - "*" + - package-ecosystem: "github-actions" + directory: "/" + target-branch: "lts/1.13.x" + open-pull-requests-limit: 5 + schedule: + interval: "weekly" + groups: + github-actions-dependencies: + patterns: + - "*" diff --git a/.github/workflows/cli-tests.yml b/.github/workflows/cli-tests.yml index 8cd053651a..4498afa416 100644 --- a/.github/workflows/cli-tests.yml +++ b/.github/workflows/cli-tests.yml @@ -15,8 +15,12 @@ concurrency: jobs: test: - name: CLI Tests - runs-on: depot-ubuntu-24.04 + name: CLI Tests (${{ matrix.os }}) + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [depot-ubuntu-24.04, windows-latest, macos-latest] env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} defaults: @@ -37,7 +41,7 @@ jobs: run: pnpm ci - name: Report coverage - if: ${{ env.CODECOV_TOKEN != '' }} + if: ${{ env.CODECOV_TOKEN != '' && matrix.os == 'depot-ubuntu-24.04' }} uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0 with: directory: cli/coverage diff --git a/.gitignore b/.gitignore index 207e2600e7..5b434ee4ec 100644 --- a/.gitignore +++ b/.gitignore @@ -257,5 
+257,5 @@ scripts/stress-test/reports/ # Code Agent Folder .qoder/* -.context/* +.context/ .eslintcache diff --git a/api/app_factory.py b/api/app_factory.py index e9094fd8ad..49be025731 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -223,10 +223,11 @@ def initialize_extensions(app: DifyApp): def create_migrations_app() -> DifyApp: app = create_flask_app_with_configs() - from extensions import ext_database, ext_migrate + from extensions import ext_commands, ext_database, ext_migrate # Initialize only required extensions ext_database.init_app(app) ext_migrate.init_app(app) + ext_commands.init_app(app) return app diff --git a/api/clients/agent_backend/__init__.py b/api/clients/agent_backend/__init__.py index 2e3777f61b..4d459d34a0 100644 --- a/api/clients/agent_backend/__init__.py +++ b/api/clients/agent_backend/__init__.py @@ -30,7 +30,8 @@ from clients.agent_backend.factory import create_agent_backend_run_client from clients.agent_backend.fake_client import FakeAgentBackendRunClient, FakeAgentBackendScenario from clients.agent_backend.request_builder import ( AGENT_SOUL_PROMPT_LAYER_ID, - DIFY_PLUGIN_CONTEXT_LAYER_ID, + DIFY_EXECUTION_CONTEXT_LAYER_ID, + DIFY_PLUGIN_TOOLS_LAYER_ID, WORKFLOW_NODE_JOB_PROMPT_LAYER_ID, WORKFLOW_USER_PROMPT_LAYER_ID, AgentBackendModelConfig, @@ -42,7 +43,8 @@ from clients.agent_backend.request_builder import ( __all__ = [ "AGENT_SOUL_PROMPT_LAYER_ID", - "DIFY_PLUGIN_CONTEXT_LAYER_ID", + "DIFY_EXECUTION_CONTEXT_LAYER_ID", + "DIFY_PLUGIN_TOOLS_LAYER_ID", "WORKFLOW_NODE_JOB_PROMPT_LAYER_ID", "WORKFLOW_USER_PROMPT_LAYER_ID", "AgentBackendError", diff --git a/api/clients/agent_backend/request_builder.py b/api/clients/agent_backend/request_builder.py index 392eee641b..74114469dd 100644 --- a/api/clients/agent_backend/request_builder.py +++ b/api/clients/agent_backend/request_builder.py @@ -4,7 +4,9 @@ This module is intentionally an adapter, not a wire DTO package. 
The emitted object is always ``dify_agent.protocol.CreateRunRequest`` so the Agent backend protocol has a single owner. API-only context such as Agent Soul vs workflow job prompt is preserved in layer names and metadata until the dedicated product -schemas land in later phases. +schemas land in later phases. Dify-owned execution identifiers are emitted as an +explicit ``dify.execution_context`` layer so the run request stays fully +composition-driven. """ from __future__ import annotations @@ -15,18 +17,21 @@ from agenton.compositor import CompositorSessionSnapshot from agenton.layers import ExitIntent from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig from dify_agent.layers.dify_plugin import ( - DIFY_PLUGIN_LAYER_TYPE_ID, DIFY_PLUGIN_LLM_LAYER_TYPE_ID, + DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID, DifyPluginCredentialValue, - DifyPluginLayerConfig, DifyPluginLLMLayerConfig, + DifyPluginToolsLayerConfig, +) +from dify_agent.layers.execution_context import ( + DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + DifyExecutionContextLayerConfig, ) from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig from dify_agent.protocol import ( DIFY_AGENT_MODEL_LAYER_ID, DIFY_AGENT_OUTPUT_LAYER_ID, CreateRunRequest, - ExecutionContext, LayerExitSignals, RunComposition, RunLayerSpec, @@ -37,17 +42,16 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt" WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt" WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt" -DIFY_PLUGIN_CONTEXT_LAYER_ID = "plugin" +DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context" +DIFY_PLUGIN_TOOLS_LAYER_ID = "tools" class AgentBackendModelConfig(BaseModel): """API-side model/plugin selection before it is converted to Dify Agent layers.""" - tenant_id: str plugin_id: str model_provider: str model: str - user_id: str | None = None credentials: dict[str, 
DifyPluginCredentialValue] = Field(default_factory=dict) model_settings: dict[str, JsonValue] = Field(default_factory=dict) @@ -73,13 +77,14 @@ class AgentBackendWorkflowNodeRunInput(BaseModel): """Inputs needed to build the first workflow-node-oriented Agent backend run request.""" model: AgentBackendModelConfig - execution_context: ExecutionContext + execution_context: DifyExecutionContextLayerConfig workflow_node_job_prompt: str user_prompt: str agent_soul_prompt: str | None = None purpose: RunPurpose = "workflow_node" idempotency_key: str | None = None output: AgentBackendOutputConfig | None = None + tools: DifyPluginToolsLayerConfig | None = None session_snapshot: CompositorSessionSnapshot | None = None suspend_on_exit: bool = False metadata: dict[str, JsonValue] = Field(default_factory=dict) @@ -125,21 +130,18 @@ class AgentBackendRunRequestBuilder: config=PromptLayerConfig(user=run_input.user_prompt), ), RunLayerSpec( - name=DIFY_PLUGIN_CONTEXT_LAYER_ID, - type=DIFY_PLUGIN_LAYER_TYPE_ID, + name=DIFY_EXECUTION_CONTEXT_LAYER_ID, + type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, metadata=run_input.metadata, - config=DifyPluginLayerConfig( - tenant_id=run_input.model.tenant_id, - plugin_id=run_input.model.plugin_id, - user_id=run_input.model.user_id, - ), + config=run_input.execution_context, ), RunLayerSpec( name=DIFY_AGENT_MODEL_LAYER_ID, type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID, - deps={"plugin": DIFY_PLUGIN_CONTEXT_LAYER_ID}, + deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID}, metadata=run_input.metadata, config=DifyPluginLLMLayerConfig( + plugin_id=run_input.model.plugin_id, model_provider=run_input.model.model_provider, model=run_input.model.model, credentials=run_input.model.credentials, @@ -149,6 +151,17 @@ class AgentBackendRunRequestBuilder: ] ) + if run_input.tools is not None and run_input.tools.tools: + layers.append( + RunLayerSpec( + name=DIFY_PLUGIN_TOOLS_LAYER_ID, + type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID, + deps={"execution_context": 
DIFY_EXECUTION_CONTEXT_LAYER_ID}, + metadata=run_input.metadata, + config=run_input.tools, + ) + ) + if run_input.output is not None: layers.append( RunLayerSpec( @@ -165,7 +178,6 @@ class AgentBackendRunRequestBuilder: return CreateRunRequest( composition=RunComposition(layers=layers), - execution_context=run_input.execution_context, purpose=run_input.purpose, idempotency_key=run_input.idempotency_key, metadata=run_input.metadata, diff --git a/api/commands/__init__.py b/api/commands/__init__.py index d62d0dbd7c..9d1bf7d0fe 100644 --- a/api/commands/__init__.py +++ b/api/commands/__init__.py @@ -3,6 +3,7 @@ CLI command modules extracted from `commands.py`. """ from .account import create_tenant, reset_email, reset_password +from .data_migrate import data_migrate, legacy_model_types from .plugin import ( extract_plugins, extract_unique_plugins, @@ -44,6 +45,7 @@ __all__ = [ "clear_orphaned_file_records", "convert_to_agent_apps", "create_tenant", + "data_migrate", "delete_archived_workflow_runs", "export_app_messages", "extract_plugins", @@ -52,6 +54,7 @@ __all__ = [ "fix_app_site_missing", "install_plugins", "install_rag_pipeline_plugins", + "legacy_model_types", "migrate_annotation_vector_database", "migrate_data_for_plugin", "migrate_knowledge_vector_database", diff --git a/api/commands/data_migrate.py b/api/commands/data_migrate.py new file mode 100644 index 0000000000..2b33f46cd8 --- /dev/null +++ b/api/commands/data_migrate.py @@ -0,0 +1,179 @@ +import io +import os +import sys +from contextlib import AbstractContextManager, nullcontext +from pathlib import Path +from typing import cast + +import click + +from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType +from services.legacy_model_type_migration import ( + VALID_TABLE_NAMES, + LegacyModelTypeMigrationService, + load_tenant_ids_from_file, +) + +_SUPPORTED_MODEL_TYPE_CHOICES = ( + ModelType.LLM.value, + ModelType.TEXT_EMBEDDING.value, + 
ModelType.RERANK.value, +) +_DEFAULT_CONCURRENCY = os.cpu_count() or 1 + + +def _normalize_multi_value_option( + values: tuple[str, ...], + *, + valid_values: tuple[str, ...], + option_name: str, +) -> tuple[str, ...]: + normalized_values: list[str] = [] + seen_values: set[str] = set() + + for value in values: + for item in value.split(","): + normalized_item = item.strip() + if not normalized_item: + continue + if normalized_item not in valid_values: + raise click.BadParameter( + f"invalid value '{normalized_item}'. valid values: {', '.join(valid_values)}", + param_hint=option_name, + ) + if normalized_item in seen_values: + continue + seen_values.add(normalized_item) + normalized_values.append(normalized_item) + + return tuple(normalized_values) + + +@click.group( + "data-migrate", + help="Online data migration commands.", +) +def data_migrate() -> None: + """Namespace for production data migration commands.""" + + +@click.command( + "legacy-model-types", + help=( + "Migrate legacy provider model_type values to canonical values. " + "Default is dry-run and emits JSON lines only. " + "If --tables includes provider_model_credentials, the command may also update " + "provider_models and load_balancing_model_configs references so merged credentials stay reachable." + ), +) +@click.option( + "--apply", + is_flag=True, + default=False, + help="Apply the migration. Default is dry-run.", +) +@click.option( + "--tables", + "tables", + multiple=True, + type=str, + help=( + "Limit migration to specific tables. Accepts comma-separated values or repeated flags.\n" + "\n" + "Options: load_balancing_model_configs, provider_model_credentials, " + "provider_model_settings, provider_models, tenant_default_models.\n\n" + "When provider_model_credentials is selected, provider_models and " + "load_balancing_model_configs may also be updated for credential reference rewrites.\n" + "\n" + "If unspecified, all relevant tables are migrated." 
+ ), +) +@click.option( + "--model-types", + "model_types", + multiple=True, + type=str, + help=( + "Canonical model types to migrate. Accepts comma-separated values or repeated flags.\n" + "\n" + "Options: llm,text-embedding,rerank\n" + "\n" + "If unspecified, all relevant legacy model types are migrated." + ), +) +@click.option( + "--tenant-id-file", + type=click.Path(exists=True, dir_okay=False, readable=True, resolve_path=True), + help="Optional file containing tenant ids, one per line.", +) +@click.option( + "--output", + type=click.Path(dir_okay=False, resolve_path=True, path_type=Path), + help=( + "Optional file path for JSON lines event logs. Defaults to stdout.\n" + "It's highly recommended to save the event logs to a file and preserve it for a period of time." + ), +) +@click.option( + "--concurrency", + type=click.IntRange(min=1), + default=_DEFAULT_CONCURRENCY, + show_default=True, + help="Number of tenant-level worker threads to run in parallel.", +) +def legacy_model_types( + apply: bool, + tables: tuple[str, ...], + model_types: tuple[str, ...], + tenant_id_file: str | None, + output: Path | None, + concurrency: int = _DEFAULT_CONCURRENCY, +) -> None: + """ + Migrate legacy provider-related model_type values and emit JSON lines events. 
+ """ + + normalized_tables = _normalize_multi_value_option( + tables, + valid_values=VALID_TABLE_NAMES, + option_name="--tables", + ) + normalized_model_types = _normalize_multi_value_option( + model_types, + valid_values=_SUPPORTED_MODEL_TYPE_CHOICES, + option_name="--model-types", + ) + selected_model_types = ( + tuple(ModelType.value_of(model_type) for model_type in normalized_model_types) + if normalized_model_types + else ( + ModelType.LLM, + ModelType.TEXT_EMBEDDING, + ModelType.RERANK, + ) + ) + tenant_ids = load_tenant_ids_from_file(tenant_id_file) if tenant_id_file else None + + output_context: AbstractContextManager[io.TextIOBase] + if output is None: + output_context = nullcontext(cast(io.TextIOBase, sys.stdout)) + else: + try: + output_context = output.open("w", encoding="utf-8") + except OSError as exc: + raise click.ClickException(f"failed to open output file '{output}': {exc.strerror or exc}") from exc + + with output_context as output_stream: + LegacyModelTypeMigrationService( + engine=db.engine, + apply=apply, + concurrency=concurrency, + output=cast(io.TextIOBase, output_stream), + tables=normalized_tables or None, + model_types=selected_model_types, + tenant_ids=tenant_ids, + ).migrate() + + +data_migrate.add_command(legacy_model_types) diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py index eb9b0ac2ab..2f3a3ed2bd 100644 --- a/api/configs/middleware/vdb/milvus_config.py +++ b/api/configs/middleware/vdb/milvus_config.py @@ -41,3 +41,21 @@ class MilvusConfig(BaseSettings): description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.', default=None, ) + + MILVUS_SECURE: bool = Field( + description="Enable TLS for the Milvus connection (one-way TLS). When True, the client uses gRPC over TLS " + "and verifies the server certificate. 
Equivalent to passing secure=True to pymilvus.", + default=False, + ) + + MILVUS_SERVER_PEM_PATH: str | None = Field( + description="Filesystem path inside the container to the Milvus server certificate (PEM). Mount this via " + "a Kubernetes secret. Used as pymilvus's server_pem_path when MILVUS_SECURE is True.", + default=None, + ) + + MILVUS_SERVER_NAME: str | None = Field( + description="Server name (TLS SNI / certificate CN or SAN) to verify against the Milvus server certificate. " + "Required when MILVUS_SERVER_PEM_PATH is set.", + default=None, + ) diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index dc416d47c5..bffc29e23a 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -68,6 +68,7 @@ from .app import ( workflow_app_log, workflow_comment, workflow_draft_variable, + workflow_node_output_inspector, workflow_run, workflow_statistic, workflow_trigger, @@ -223,6 +224,7 @@ __all__ = [ "workflow_app_log", "workflow_comment", "workflow_draft_variable", + "workflow_node_output_inspector", "workflow_run", "workflow_statistic", "workflow_trigger", diff --git a/api/controllers/console/agent/composer.py b/api/controllers/console/agent/composer.py index d716ab9fd2..85b54ce7cc 100644 --- a/api/controllers/console/agent/composer.py +++ b/api/controllers/console/agent/composer.py @@ -5,7 +5,7 @@ from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from libs.login import current_account_with_tenant, login_required -from models.model import AppMode +from models.model import App, AppMode from services.agent.composer_service import AgentComposerService from services.agent.composer_validator import ComposerConfigValidator from services.entities.agent_entities import ComposerSavePayload @@ -19,7 +19,7 @@ class 
WorkflowAgentComposerApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]) - def get(self, app_model, node_id: str): + def get(self, app_model: App, node_id: str): _, tenant_id = current_account_with_tenant() return AgentComposerService.load_workflow_composer( tenant_id=tenant_id, @@ -33,7 +33,7 @@ class WorkflowAgentComposerApi(Resource): @account_initialization_required @edit_permission_required @get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]) - def put(self, app_model, node_id: str): + def put(self, app_model: App, node_id: str): account, tenant_id = current_account_with_tenant() payload = ComposerSavePayload.model_validate(console_ns.payload or {}) return AgentComposerService.save_workflow_composer( @@ -52,7 +52,7 @@ class WorkflowAgentComposerValidateApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]) - def post(self, app_model, node_id: str): + def post(self, app_model: App, node_id: str): payload = ComposerSavePayload.model_validate(console_ns.payload or {}) ComposerConfigValidator.validate_save_payload(payload) return {"result": "success", "errors": []} @@ -64,7 +64,7 @@ class WorkflowAgentComposerCandidatesApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]) - def get(self, app_model, node_id: str): + def get(self, app_model: App, node_id: str): return AgentComposerService.get_workflow_candidates(app_id=app_model.id) @@ -74,7 +74,7 @@ class WorkflowAgentComposerImpactApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]) - def post(self, app_model, node_id: str): + def post(self, app_model: App, node_id: str): _, tenant_id = current_account_with_tenant() payload = ComposerSavePayload.model_validate(console_ns.payload or {}) current_snapshot_id = 
payload.binding.current_snapshot_id if payload.binding else None @@ -91,7 +91,7 @@ class WorkflowAgentComposerSaveToRosterApi(Resource): @account_initialization_required @edit_permission_required @get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]) - def post(self, app_model, node_id: str): + def post(self, app_model: App, node_id: str): account, tenant_id = current_account_with_tenant() payload = ComposerSavePayload.model_validate(console_ns.payload or {}) return AgentComposerService.save_workflow_composer( @@ -109,7 +109,7 @@ class AgentAppComposerApi(Resource): @login_required @account_initialization_required @get_app_model() - def get(self, app_model): + def get(self, app_model: App): _, tenant_id = current_account_with_tenant() return AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id) @@ -119,7 +119,7 @@ class AgentAppComposerApi(Resource): @account_initialization_required @edit_permission_required @get_app_model() - def put(self, app_model): + def put(self, app_model: App): account, tenant_id = current_account_with_tenant() payload = ComposerSavePayload.model_validate(console_ns.payload or {}) return AgentComposerService.save_agent_app_composer( @@ -137,7 +137,7 @@ class AgentAppComposerValidateApi(Resource): @login_required @account_initialization_required @get_app_model() - def post(self, app_model): + def post(self, app_model: App): payload = ComposerSavePayload.model_validate(console_ns.payload or {}) ComposerConfigValidator.validate_save_payload(payload) return {"result": "success", "errors": []} @@ -149,5 +149,5 @@ class AgentAppComposerCandidatesApi(Resource): @login_required @account_initialization_required @get_app_model() - def get(self, app_model): + def get(self, app_model: App): return AgentComposerService.get_agent_app_candidates(app_id=app_model.id) diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 133c57d34d..57470dc977 100644 --- 
a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -9,18 +9,25 @@ from sqlalchemy import delete, func, select from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden -from controllers.common.schema import register_schema_models +from controllers.common.schema import register_response_schema_models from extensions.ext_database import db from fields.base import ResponseModel -from libs.helper import to_timestamp -from libs.login import current_account_with_tenant, login_required +from libs.helper import dump_response, to_timestamp +from libs.login import login_required +from models import Account from models.dataset import Dataset from models.enums import ApiTokenType from models.model import ApiToken, App from services.api_token_service import ApiTokenCache from . import console_ns -from .wraps import account_initialization_required, edit_permission_required, setup_required +from .wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, + with_current_tenant_id, + with_current_user, +) class ApiKeyItem(ResponseModel): @@ -40,7 +47,7 @@ class ApiKeyList(ResponseModel): data: list[ApiKeyItem] -register_schema_models(console_ns, ApiKeyItem, ApiKeyList) +register_response_schema_models(console_ns, ApiKeyItem, ApiKeyList) def _get_resource(resource_id, tenant_id, resource_model): @@ -64,10 +71,11 @@ class BaseApiKeyListResource(Resource): token_prefix: str | None = None max_keys = 10 - def get(self, resource_id): + def get(self, resource_id: str, current_tenant_id: str) -> dict[str, object]: + return dump_response(ApiKeyList, self._get_api_key_list(resource_id, current_tenant_id)) + + def _get_api_key_list(self, resource_id: str, current_tenant_id: str) -> ApiKeyList: assert self.resource_id_field is not None, "resource_id_field must be set" - resource_id = str(resource_id) - _, current_tenant_id = current_account_with_tenant() _get_resource(resource_id, current_tenant_id, 
self.resource_model) keys = db.session.scalars( @@ -75,13 +83,14 @@ class BaseApiKeyListResource(Resource): ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id ) ).all() - return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json") + return ApiKeyList.model_validate({"data": keys}, from_attributes=True) @edit_permission_required - def post(self, resource_id): + def post(self, resource_id: str, current_tenant_id: str) -> tuple[dict[str, object], int]: + return dump_response(ApiKeyItem, self._create_api_key(resource_id, current_tenant_id)), 201 + + def _create_api_key(self, resource_id: str, current_tenant_id: str) -> ApiToken: assert self.resource_id_field is not None, "resource_id_field must be set" - resource_id = str(resource_id) - _, current_tenant_id = current_account_with_tenant() _get_resource(resource_id, current_tenant_id, self.resource_model) current_key_count: int = ( db.session.scalar( @@ -108,7 +117,7 @@ class BaseApiKeyListResource(Resource): api_token.type = self.resource_type db.session.add(api_token) db.session.commit() - return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201 + return api_token class BaseApiKeyResource(Resource): @@ -118,9 +127,20 @@ class BaseApiKeyResource(Resource): resource_model: type | None = None resource_id_field: str | None = None - def delete(self, resource_id: str, api_key_id: str): + def delete( + self, resource_id: str, api_key_id: str, current_tenant_id: str, current_user: Account + ) -> tuple[str, int]: + self._delete_api_key(resource_id, api_key_id, current_tenant_id, current_user) + return "", 204 + + def _delete_api_key( + self, + resource_id: str, + api_key_id: str, + current_tenant_id: str, + current_user: Account, + ) -> None: assert self.resource_id_field is not None, "resource_id_field must be set" - current_user, current_tenant_id = current_account_with_tenant() _get_resource(resource_id, 
current_tenant_id, self.resource_model) if not current_user.is_admin_or_owner: @@ -147,8 +167,6 @@ class BaseApiKeyResource(Resource): db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id)) db.session.commit() - return "", 204 - @console_ns.route("/apps//api-keys") class AppApiKeyListResource(BaseApiKeyListResource): @@ -156,18 +174,21 @@ class AppApiKeyListResource(BaseApiKeyListResource): @console_ns.doc(description="Get all API keys for an app") @console_ns.doc(params={"resource_id": "App ID"}) @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__]) - def get(self, resource_id: UUID): + @with_current_tenant_id + def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]: """Get all API keys for an app""" - return super().get(resource_id) + return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id)) @console_ns.doc("create_app_api_key") @console_ns.doc(description="Create a new API key for an app") @console_ns.doc(params={"resource_id": "App ID"}) @console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__]) @console_ns.response(400, "Maximum keys exceeded") - def post(self, resource_id: UUID): + @with_current_tenant_id + @edit_permission_required + def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]: """Create a new API key for an app""" - return super().post(resource_id) + return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201 resource_type = ApiTokenType.APP resource_model = App @@ -181,9 +202,14 @@ class AppApiKeyResource(BaseApiKeyResource): @console_ns.doc(description="Delete an API key for an app") @console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"}) @console_ns.response(204, "API key deleted successfully") - def delete(self, resource_id: UUID, api_key_id: UUID): + @with_current_user + 
@with_current_tenant_id + def delete( + self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID + ) -> tuple[str, int]: """Delete an API key for an app""" - return super().delete(str(resource_id), str(api_key_id)) + self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user) + return "", 204 resource_type = ApiTokenType.APP resource_model = App @@ -196,18 +222,21 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): @console_ns.doc(description="Get all API keys for a dataset") @console_ns.doc(params={"resource_id": "Dataset ID"}) @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__]) - def get(self, resource_id: UUID): + @with_current_tenant_id + def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]: """Get all API keys for a dataset""" - return super().get(resource_id) + return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id)) @console_ns.doc("create_dataset_api_key") @console_ns.doc(description="Create a new API key for a dataset") @console_ns.doc(params={"resource_id": "Dataset ID"}) @console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__]) @console_ns.response(400, "Maximum keys exceeded") - def post(self, resource_id: UUID): + @with_current_tenant_id + @edit_permission_required + def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]: """Create a new API key for a dataset""" - return super().post(resource_id) + return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201 resource_type = ApiTokenType.DATASET resource_model = Dataset @@ -221,9 +250,14 @@ class DatasetApiKeyResource(BaseApiKeyResource): @console_ns.doc(description="Delete an API key for a dataset") @console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"}) @console_ns.response(204, "API key 
deleted successfully") - def delete(self, resource_id: UUID, api_key_id: UUID): + @with_current_user + @with_current_tenant_id + def delete( + self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID + ) -> tuple[str, int]: """Delete an API key for a dataset""" - return super().delete(str(resource_id), str(api_key_id)) + self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user) + return "", 204 resource_type = ApiTokenType.DATASET resource_model = Dataset diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index c05600ced5..277c86ced3 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -8,7 +8,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from libs.helper import uuid_value from libs.login import login_required -from models.model import AppMode +from models.model import App, AppMode from services.agent_service import AgentService @@ -39,7 +39,7 @@ class AgentLogApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.AGENT_CHAT]) - def get(self, app_model): + def get(self, app_model: App): """Get agent logs""" args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 5a965208c6..06555e5842 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -573,7 +573,7 @@ class AppApi(Resource): @account_initialization_required @enterprise_license_required @get_app_model(mode=None) - def get(self, app_model): + def get(self, app_model: App): """Get app detail""" app_service = AppService() @@ -581,7 +581,7 @@ class AppApi(Resource): if FeatureService.get_system_features().webapp_auth.enabled: app_setting = 
EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id)) - app_model.access_mode = app_setting.access_mode + app_model.access_mode = app_setting.access_mode # type: ignore[attr-defined] response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True) return response_model.model_dump(mode="json") @@ -598,7 +598,7 @@ class AppApi(Resource): @account_initialization_required @get_app_model(mode=None) @edit_permission_required - def put(self, app_model): + def put(self, app_model: App): """Update app""" args = UpdateAppPayload.model_validate(console_ns.payload) @@ -627,7 +627,7 @@ class AppApi(Resource): @login_required @account_initialization_required @edit_permission_required - def delete(self, app_model): + def delete(self, app_model: App): """Delete app""" app_service = AppService() app_service.delete_app(app_model) @@ -648,7 +648,7 @@ class AppCopyApi(Resource): @account_initialization_required @get_app_model(mode=None) @edit_permission_required - def post(self, app_model): + def post(self, app_model: App): """Copy app""" # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() @@ -709,7 +709,7 @@ class AppExportApi(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, app_model): + def get(self, app_model: App): """Export app""" args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) @@ -731,7 +731,7 @@ class AppPublishToCreatorsPlatformApi(Resource): @account_initialization_required @get_app_model(mode=None) @edit_permission_required - def post(self, app_model): + def post(self, app_model: App): """Publish app to Creators Platform""" from configs import dify_config from core.helper.creators import get_redirect_url, upload_dsl @@ -762,7 +762,7 @@ class AppNameApi(Resource): @account_initialization_required @get_app_model(mode=None) @edit_permission_required - def post(self, app_model): + def 
post(self, app_model: App): args = AppNamePayload.model_validate(console_ns.payload) app_service = AppService() @@ -784,7 +784,7 @@ class AppIconApi(Resource): @account_initialization_required @get_app_model(mode=None) @edit_permission_required - def post(self, app_model): + def post(self, app_model: App): args = AppIconPayload.model_validate(console_ns.payload or {}) app_service = AppService() @@ -811,7 +811,7 @@ class AppSiteStatus(Resource): @account_initialization_required @get_app_model(mode=None) @edit_permission_required - def post(self, app_model): + def post(self, app_model: App): args = AppSiteStatusPayload.model_validate(console_ns.payload) app_service = AppService() @@ -833,7 +833,7 @@ class AppApiStatus(Resource): @is_admin_or_owner_required @account_initialization_required @get_app_model(mode=None) - def post(self, app_model): + def post(self, app_model: App): args = AppApiStatusPayload.model_validate(console_ns.payload) app_service = AppService() @@ -874,7 +874,7 @@ class AppTraceApi(Resource): @account_initialization_required @edit_permission_required @get_app_model - def post(self, app_model): + def post(self, app_model: App): # add app trace args = AppTracePayload.model_validate(console_ns.payload) diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 5b673f3394..acf2215e45 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -70,7 +70,7 @@ class ChatMessageAudioApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) - def post(self, app_model): + def post(self, app_model: App): file = request.files["file"] try: @@ -171,7 +171,7 @@ class TextModesApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model): + def get(self, app_model: App): try: args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) diff --git 
a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index fddfe2f4bc..8983a33d16 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -33,7 +33,7 @@ from libs import helper from libs.helper import uuid_value from libs.login import current_user, login_required from models import Account -from models.model import AppMode +from models.model import App, AppMode from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService from services.errors.llm import InvokeRateLimitError @@ -84,7 +84,7 @@ class CompletionMessageApi(Resource): @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) - def post(self, app_model): + def post(self, app_model: App): args_model = CompletionMessagePayload.model_validate(console_ns.payload) args = args_model.model_dump(exclude_none=True, by_alias=True) @@ -131,7 +131,7 @@ class CompletionMessageStopApi(Resource): @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) - def post(self, app_model, task_id: str): + def post(self, app_model: App, task_id: str): if not isinstance(current_user, Account): raise ValueError("current_user must be an Account instance") @@ -159,7 +159,7 @@ class ChatMessageApi(Resource): @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) @edit_permission_required - def post(self, app_model): + def post(self, app_model: App): args_model = ChatMessagePayload.model_validate(console_ns.payload) args = args_model.model_dump(exclude_none=True, by_alias=True) @@ -212,7 +212,7 @@ class ChatMessageStopApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) - def post(self, app_model, task_id: str): + def post(self, app_model: App, task_id: str): if not isinstance(current_user, Account): raise 
ValueError("current_user must be an Account instance") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 0ca7a08286..a5b1a8c77d 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -33,7 +33,7 @@ from fields.conversation_fields import ( from libs.datetime_utils import naive_utc_now, parse_time_range from libs.login import current_account_with_tenant, login_required from models import Conversation, EndUser, Message, MessageAnnotation -from models.model import AppMode +from models.model import App, AppMode from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError @@ -93,7 +93,7 @@ class CompletionConversationApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) @edit_permission_required - def get(self, app_model): + def get(self, app_model: App): current_user, _ = current_account_with_tenant() args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) @@ -165,7 +165,7 @@ class CompletionConversationDetailApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) @edit_permission_required - def get(self, app_model, conversation_id: UUID): + def get(self, app_model: App, conversation_id: UUID): conversation_id_str = str(conversation_id) return ConversationMessageDetailResponse.model_validate( _get_conversation(app_model, conversation_id_str), from_attributes=True @@ -182,7 +182,7 @@ class CompletionConversationDetailApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) @edit_permission_required - def delete(self, app_model, conversation_id: UUID): + def delete(self, app_model: App, conversation_id: UUID): current_user, _ = current_account_with_tenant() conversation_id_str = str(conversation_id) @@ -207,7 +207,7 @@ class ChatConversationApi(Resource): 
@account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @edit_permission_required - def get(self, app_model): + def get(self, app_model: App): current_user, _ = current_account_with_tenant() args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) @@ -318,7 +318,7 @@ class ChatConversationDetailApi(Resource): @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @edit_permission_required - def get(self, app_model, conversation_id: UUID): + def get(self, app_model: App, conversation_id: UUID): conversation_id_str = str(conversation_id) return ConversationDetailResponse.model_validate( _get_conversation(app_model, conversation_id_str), from_attributes=True @@ -335,7 +335,7 @@ class ChatConversationDetailApi(Resource): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @account_initialization_required @edit_permission_required - def delete(self, app_model, conversation_id: UUID): + def delete(self, app_model: App, conversation_id: UUID): current_user, _ = current_account_with_tenant() conversation_id_str = str(conversation_id) diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index 5951f7405a..beaef48275 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -19,7 +19,7 @@ from fields.base import ResponseModel from libs.helper import to_timestamp from libs.login import login_required from models import ConversationVariable -from models.model import AppMode +from models.model import App, AppMode class ConversationVariablesQuery(BaseModel): @@ -94,7 +94,7 @@ class ConversationVariablesApi(Resource): @login_required @account_initialization_required @get_app_model(mode=AppMode.ADVANCED_CHAT) - def get(self, app_model): + def get(self, app_model: App): args = 
ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) stmt = ( diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index a5259527ea..e5ef15c3e1 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -17,7 +17,7 @@ from fields.base import ResponseModel from libs.helper import to_timestamp from libs.login import current_account_with_tenant, login_required from models.enums import AppMCPServerStatus -from models.model import AppMCPServer +from models.model import App, AppMCPServer class MCPServerCreatePayload(BaseModel): @@ -73,7 +73,7 @@ class AppMCPServerController(Resource): @account_initialization_required @setup_required @get_app_model - def get(self, app_model): + def get(self, app_model: App): server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1)) if server is None: return {} @@ -92,7 +92,7 @@ class AppMCPServerController(Resource): @login_required @setup_required @edit_permission_required - def post(self, app_model): + def post(self, app_model: App): _, current_tenant_id = current_account_with_tenant() payload = MCPServerCreatePayload.model_validate(console_ns.payload or {}) @@ -127,7 +127,7 @@ class AppMCPServerController(Resource): @setup_required @account_initialization_required @edit_permission_required - def put(self, app_model): + def put(self, app_model: App): payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {}) server = db.session.get(AppMCPServer, payload.id) if not server: diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 7445fed86c..15b3437bf9 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -45,7 +45,7 @@ from libs.helper import to_timestamp, uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import 
current_account_with_tenant, login_required from models.enums import FeedbackFromSource, FeedbackRating -from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback +from models.model import App, AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService, attach_message_extra_contents @@ -180,7 +180,7 @@ class ChatMessageListApi(Resource): @setup_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @edit_permission_required - def get(self, app_model): + def get(self, app_model: App): args = ChatMessagesQuery.model_validate(request.args.to_dict()) conversation = db.session.scalar( @@ -257,7 +257,7 @@ class MessageFeedbackApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_model): + def post(self, app_model: App): current_user, _ = current_account_with_tenant() args = MessageFeedbackPayload.model_validate(console_ns.payload) @@ -314,7 +314,7 @@ class MessageAnnotationCountApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model): + def get(self, app_model: App): count = db.session.scalar( select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id) ) @@ -337,7 +337,7 @@ class MessageSuggestedQuestionApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) - def get(self, app_model, message_id: UUID): + def get(self, app_model: App, message_id: UUID): current_user, _ = current_account_with_tenant() message_id_str = str(message_id) @@ -379,7 +379,7 @@ class MessageFeedbackExportApi(Resource): @setup_required @login_required @account_initialization_required - def 
get(self, app_model): + def get(self, app_model: App): args = FeedbackExportQuery.model_validate(request.args.to_dict()) # Import the service function @@ -417,7 +417,7 @@ class MessageApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model, message_id: UUID): + def get(self, app_model: App, message_id: UUID): message_id_str = str(message_id) message = db.session.scalar( diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 1869cbf5f6..a893b66911 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -16,7 +16,7 @@ from events.app_event import app_model_config_was_updated from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required -from models.model import AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig from services.app_model_config_service import AppModelConfigService @@ -52,7 +52,7 @@ class ModelConfigResource(Resource): @edit_permission_required @account_initialization_required @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) - def post(self, app_model): + def post(self, app_model: App): """Modify app model config""" current_user, current_tenant_id = current_account_with_tenant() # validate config diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 9991d78d94..ca7f194a35 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -20,6 +20,7 @@ from fields.base import ResponseModel from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import Site +from models.model import App class AppSiteUpdatePayload(BaseModel): @@ -84,7 +85,7 @@ class AppSite(Resource): @edit_permission_required @account_initialization_required 
@get_app_model - def post(self, app_model): + def post(self, app_model: App): args = AppSiteUpdatePayload.model_validate(console_ns.payload or {}) current_user, _ = current_account_with_tenant() site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) @@ -133,7 +134,7 @@ class AppSiteAccessTokenReset(Resource): @is_admin_or_owner_required @account_initialization_required @get_app_model - def post(self, app_model): + def post(self, app_model: App): current_user, _ = current_account_with_tenant() site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index d23b2837c9..dd3cf273a2 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -15,6 +15,7 @@ from libs.datetime_utils import parse_time_range from libs.helper import convert_datetime_to_date from libs.login import current_account_with_tenant, login_required from models import AppMode +from models.model import App class StatisticTimeRangeQuery(BaseModel): @@ -47,7 +48,7 @@ class DailyMessageStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model): + def get(self, app_model: App): account, _ = current_account_with_tenant() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) @@ -61,8 +62,12 @@ FROM WHERE app_id = :app_id AND invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} assert account.timezone is not None + arg_dict: dict[str, object] = { + "tz": account.timezone, + "app_id": app_model.id, + "invoke_from": InvokeFrom.DEBUGGER, + } try: start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) @@ -104,7 +109,7 @@ class DailyConversationStatistic(Resource): @setup_required @login_required @account_initialization_required - 
def get(self, app_model): + def get(self, app_model: App): account, _ = current_account_with_tenant() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) @@ -118,8 +123,12 @@ FROM WHERE app_id = :app_id AND invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} assert account.timezone is not None + arg_dict: dict[str, object] = { + "tz": account.timezone, + "app_id": app_model.id, + "invoke_from": InvokeFrom.DEBUGGER, + } try: start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) @@ -160,7 +169,7 @@ class DailyTerminalsStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model): + def get(self, app_model: App): account, _ = current_account_with_tenant() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) @@ -174,8 +183,12 @@ FROM WHERE app_id = :app_id AND invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} assert account.timezone is not None + arg_dict: dict[str, object] = { + "tz": account.timezone, + "app_id": app_model.id, + "invoke_from": InvokeFrom.DEBUGGER, + } try: start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) @@ -217,7 +230,7 @@ class DailyTokenCostStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model): + def get(self, app_model: App): account, _ = current_account_with_tenant() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) @@ -232,8 +245,12 @@ FROM WHERE app_id = :app_id AND invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} assert account.timezone is not None + arg_dict: dict[str, object] = { + "tz": account.timezone, + "app_id": app_model.id, + 
"invoke_from": InvokeFrom.DEBUGGER, + } try: start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) @@ -277,7 +294,7 @@ class AverageSessionInteractionStatistic(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) - def get(self, app_model): + def get(self, app_model: App): account, _ = current_account_with_tenant() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) @@ -299,8 +316,12 @@ FROM WHERE c.app_id = :app_id AND m.invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} assert account.timezone is not None + arg_dict: dict[str, object] = { + "tz": account.timezone, + "app_id": app_model.id, + "invoke_from": InvokeFrom.DEBUGGER, + } try: start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) @@ -353,7 +374,7 @@ class UserSatisfactionRateStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model): + def get(self, app_model: App): account, _ = current_account_with_tenant() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) @@ -371,8 +392,12 @@ LEFT JOIN WHERE m.app_id = :app_id AND m.invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} assert account.timezone is not None + arg_dict: dict[str, object] = { + "tz": account.timezone, + "app_id": app_model.id, + "invoke_from": InvokeFrom.DEBUGGER, + } try: start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) @@ -419,7 +444,7 @@ class AverageResponseTimeStatistic(Resource): @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) - def get(self, app_model): + def get(self, app_model: App): account, _ = 
current_account_with_tenant() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) @@ -433,8 +458,12 @@ FROM WHERE app_id = :app_id AND invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} assert account.timezone is not None + arg_dict: dict[str, object] = { + "tz": account.timezone, + "app_id": app_model.id, + "invoke_from": InvokeFrom.DEBUGGER, + } try: start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) @@ -476,7 +505,7 @@ class TokensPerSecondStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model): + def get(self, app_model: App): account, _ = current_account_with_tenant() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) @@ -492,8 +521,12 @@ FROM WHERE app_id = :app_id AND invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} assert account.timezone is not None + arg_dict: dict[str, object] = { + "tz": account.timezone, + "app_id": app_model.id, + "invoke_from": InvokeFrom.DEBUGGER, + } try: start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 83f0d1dde6..5aa243597a 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -83,13 +83,14 @@ def _serialize_var_value(variable: WorkflowDraftVariable): # create a copy of the value to avoid affecting the model cache. value = value.model_copy(deep=True) # Refresh the url signature before returning it to client. 
- if isinstance(value, FileSegment): - file = value.value - file.remote_url = file.generate_url() - elif isinstance(value, ArrayFileSegment): - files = value.value - for file in files: + match value: + case FileSegment(): + file = value.value file.remote_url = file.generate_url() + case ArrayFileSegment(): + files = value.value + for file in files: + file.remote_url = file.generate_url() return _convert_values_to_json_serializable_object(value) diff --git a/api/controllers/console/app/workflow_node_output_inspector.py b/api/controllers/console/app/workflow_node_output_inspector.py new file mode 100644 index 0000000000..7da3ebe32b --- /dev/null +++ b/api/controllers/console/app/workflow_node_output_inspector.py @@ -0,0 +1,415 @@ +"""Console REST endpoints for the Node Output Inspector (Stage 4 §8 / §10.3). + +PRD §Node Output Inspector replaces the consumer-organized Variable Inspector +with a producer-organized view of each node's declared outputs and their +per-run status. This module exposes two parallel sets of three read-only +endpoints — one for ``/workflows/draft/runs/...`` (Composer test runs) and one +for ``/workflows/published/runs/...`` (real App API / webapp / webhook / +schedule / plugin triggers). Both sets share the same service code, the same +response shapes, and the same error codes; the URL is the *only* difference, +so the frontend can pick the right prefix based on which run-detail page the +user is on. + +Decision D-1 (published Inspector deferred) was lifted 2026-05-26 — the +``published_run_inspector_not_implemented`` 404 code is therefore no longer +produced. + +URLs follow the design doc and reuse the existing +``/apps//workflows/draft/...`` prefix from +:mod:`controllers.console.app.workflow_draft_variable`. The +``published`` prefix mirrors it shape-for-shape. 
+""" + +from __future__ import annotations + +import json +import logging +from collections.abc import Iterator +from uuid import UUID + +from flask import Response +from flask_restx import Resource + +from controllers.console import console_ns +from controllers.console.app.wraps import get_app_model +from controllers.console.wraps import account_initialization_required, setup_required +from libs.exception import BaseHTTPException +from libs.login import login_required +from models import App, AppMode +from services.workflow import inspector_events +from services.workflow.node_output_inspector_service import ( + NodeOutputInspectorError, + NodeOutputInspectorService, +) + +logger = logging.getLogger(__name__) + + +# Heartbeat cadence — every N empty subscribe ticks emit a SSE comment so +# intervening proxies (nginx, ingress) don't reap the idle connection. +# ``inspector_events.subscribe`` ticks at 1s, so 15 → 15s heartbeat. +_HEARTBEAT_EVERY_TICKS = 15 +# Hard ceiling on a single stream — if we never see a terminal workflow +# event (engine crashed, redis dropped the message), force-close after this +# many ticks (= seconds). +_STREAM_HARD_TIMEOUT_TICKS = 1800 # 30 min + + +def _service() -> NodeOutputInspectorService: + """One-line factory so tests can monkeypatch a stub if needed.""" + return NodeOutputInspectorService() + + +def _serve_snapshot(app_model: App, run_id: UUID) -> dict: + """Resource-body shared by draft + published snapshot endpoints. + + Pulled out so the 6 REST routes don't duplicate the same 6-line try/except + + ``model_dump`` ritual — the routes shrink to one-liners and the actual + behaviour lives here, where unit tests can hit it without spinning up + Flask request context. 
+ """ + try: + snapshot = _service().snapshot_workflow_run(app_model=app_model, workflow_run_id=str(run_id)) + except NodeOutputInspectorError as error: + raise _InspectorNotFound(error) from error + return snapshot.model_dump(mode="json") + + +def _serve_node_detail(app_model: App, run_id: UUID, node_id: str) -> dict: + """Resource-body shared by draft + published node-detail endpoints.""" + try: + view = _service().node_detail( + app_model=app_model, + workflow_run_id=str(run_id), + node_id=node_id, + ) + except NodeOutputInspectorError as error: + raise _InspectorNotFound(error) from error + return view.model_dump(mode="json") + + +def _serve_output_preview(app_model: App, run_id: UUID, node_id: str, output_name: str) -> dict: + """Resource-body shared by draft + published output-preview endpoints.""" + try: + preview = _service().output_preview( + app_model=app_model, + workflow_run_id=str(run_id), + node_id=node_id, + output_name=output_name, + ) + except NodeOutputInspectorError as error: + raise _InspectorNotFound(error) from error + return preview.model_dump(mode="json") + + +class _InspectorNotFound(BaseHTTPException): + """404 that preserves the inspector's specific error code. + + Without this the response body collapses to a generic ``not_found`` code + and clients lose the ability to distinguish, e.g., + ``workflow_run_not_found`` from ``published_run_inspector_not_implemented``. 
+ """ + + code = 404 + + def __init__(self, error: NodeOutputInspectorError) -> None: + self.error_code = error.code + super().__init__(description=str(error)) + + +@console_ns.route("/apps//workflows/draft/runs//node-outputs") +class WorkflowDraftRunNodeOutputsApi(Resource): + """Whole-run snapshot organized by producer node.""" + + @console_ns.doc("get_workflow_draft_run_node_outputs") + @console_ns.doc(description="Snapshot of every node's declared outputs for a draft workflow run.") + @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) + @console_ns.response(404, "Workflow run not found") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App, run_id: UUID): + return _serve_snapshot(app_model, run_id) + + +@console_ns.route("/apps//workflows/draft/runs//node-outputs/") +class WorkflowDraftRunNodeOutputDetailApi(Resource): + """One node's declared outputs + per-output status.""" + + @console_ns.doc("get_workflow_draft_run_node_output_detail") + @console_ns.doc(description="One node's declared outputs for a draft workflow run.") + @console_ns.doc( + params={ + "app_id": "Application ID", + "run_id": "Workflow run ID", + "node_id": "Node ID inside the workflow graph", + } + ) + @console_ns.response(404, "Workflow run / node not found") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App, run_id: UUID, node_id: str): + return _serve_node_detail(app_model, run_id, node_id) + + +@console_ns.route( + "/apps//workflows/draft/runs//node-outputs///preview" +) +class WorkflowDraftRunNodeOutputPreviewApi(Resource): + """Full value for one declared output (with signed URL for file refs).""" + + @console_ns.doc("get_workflow_draft_run_node_output_preview") + @console_ns.doc(description="Full value for one declared output, 
including signed download URL for files.") + @console_ns.doc( + params={ + "app_id": "Application ID", + "run_id": "Workflow run ID", + "node_id": "Node ID inside the workflow graph", + "output_name": "Declared output name as exposed by Composer", + } + ) + @console_ns.response(404, "Workflow run / node / output not found") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App, run_id: UUID, node_id: str, output_name: str): + return _serve_output_preview(app_model, run_id, node_id, output_name) + + +# ────────────────────────────────────────────────────────────────────────────── +# SSE event stream — shared generator used by draft + published variants +# ────────────────────────────────────────────────────────────────────────────── + + +def _sse_envelope(event: str, data: dict | str, event_id: int) -> str: + """Format one SSE record per D-5 ``{event, data, id}`` envelope. + + ``data`` is JSON-serialized when given as a dict; raw strings are + forwarded unchanged so we can also emit ``:keepalive`` comment lines. + """ + payload = data if isinstance(data, str) else json.dumps(data, ensure_ascii=False) + return f"event: {event}\nid: {event_id}\ndata: {payload}\n\n" + + +def _stream_inspector_events(app_model: App, run_id: UUID) -> Iterator[str]: + """Yield SSE-framed strings for one workflow run. + + The stream begins with a full ``snapshot`` event so the client has a + starting state without needing a separate REST GET. Then for every + ``node_changed`` message from the pub/sub channel we re-read that node + from DB and push a fresh ``node_changed`` event. When the workflow run + reaches a terminal state we push one final ``workflow_run_completed`` + event and close the stream. + + Failures inside the loop are caught and surfaced as ``error`` events so + the frontend can show a banner rather than seeing the connection drop + silently. 
The Inspector never raises across the SSE boundary. + """ + service = _service() + run_id_str = str(run_id) + + # Initial snapshot — also flushes a 404 back at the client right away + # if the run is gone (raised before yielding any bytes, so Flask turns it + # into the normal HTTP 404 path). + try: + snapshot = service.snapshot_workflow_run(app_model=app_model, workflow_run_id=run_id_str) + except NodeOutputInspectorError as error: + raise _InspectorNotFound(error) from error + + event_id = 0 + yield _sse_envelope("snapshot", snapshot.model_dump(mode="json"), event_id) + + # If the run already finished by the time the client connected, emit + # the terminal envelope synchronously and close — no point subscribing. + # The enum value for partial success is the hyphenated ``partial-succeeded`` + # (graphon.enums.WorkflowExecutionStatus), not ``partial_succeeded``. + if snapshot.workflow_run_status.value in {"succeeded", "failed", "stopped", "partial-succeeded"}: + event_id += 1 + yield _sse_envelope( + "workflow_run_completed", + {"workflow_run_id": run_id_str, "workflow_run_status": snapshot.workflow_run_status.value}, + event_id, + ) + return + + # Live subscription + ticks_since_heartbeat = 0 + total_ticks = 0 + for message in inspector_events.subscribe(run_id_str, timeout_seconds=1.0): + total_ticks += 1 + if total_ticks > _STREAM_HARD_TIMEOUT_TICKS: + logger.warning( + "Inspector SSE: forcing close after %ds without terminal event for run %s", + _STREAM_HARD_TIMEOUT_TICKS, + run_id_str, + ) + return + + # Heartbeat sentinel — ``inspector_events.subscribe`` synthesizes a + # ``node_changed`` message with both fields ``None`` on every redis + # timeout. Real ``workflow_completed`` messages keep their kind even + # when status couldn't be resolved (publisher race), so checking kind + # first makes the heartbeat branch safe. 
+ if message.kind == "node_changed" and message.node_id is None and message.status is None: + ticks_since_heartbeat += 1 + if ticks_since_heartbeat >= _HEARTBEAT_EVERY_TICKS: + yield ":keepalive\n\n" + ticks_since_heartbeat = 0 + continue + ticks_since_heartbeat = 0 + + if message.kind == "workflow_completed": + event_id += 1 + yield _sse_envelope( + "workflow_run_completed", + {"workflow_run_id": run_id_str, "workflow_run_status": message.status or "unknown"}, + event_id, + ) + return + + # node_changed: recompute the node slice from DB + if not message.node_id: + continue + try: + node_view = service.node_detail( + app_model=app_model, + workflow_run_id=run_id_str, + node_id=message.node_id, + ) + except NodeOutputInspectorError: + # Node may not appear in the graph yet (race with persistence); skip. + continue + except Exception: + logger.warning( + "Inspector SSE: node_detail failed for run %s node %s", + run_id_str, + message.node_id, + exc_info=True, + ) + event_id += 1 + yield _sse_envelope( + "error", + {"node_id": message.node_id, "message": "failed to refresh node detail"}, + event_id, + ) + continue + + event_id += 1 + yield _sse_envelope("node_changed", node_view.model_dump(mode="json"), event_id) + + +@console_ns.route("/apps//workflows/draft/runs//node-outputs/events") +class WorkflowDraftRunNodeOutputEventsApi(Resource): + """SSE stream of inspector deltas for a draft run.""" + + @console_ns.doc("stream_workflow_draft_run_node_output_events") + @console_ns.doc(description="Server-Sent Events stream of inspector deltas for a draft workflow run.") + @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) + @console_ns.response(404, "Workflow run not found") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App, run_id: UUID): + return Response( + _stream_inspector_events(app_model, run_id), + 
mimetype="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# Published-run endpoints — symmetric to the draft trio above +# ────────────────────────────────────────────────────────────────────────────── + + +@console_ns.route("/apps//workflows/published/runs//node-outputs") +class WorkflowPublishedRunNodeOutputsApi(Resource): + """Whole-run snapshot for a *published* workflow run. + + Same response shape as the ``/draft/`` variant — frontend can multiplex + based on which page (Composer test-run vs. Run History) is mounted. + """ + + @console_ns.doc("get_workflow_published_run_node_outputs") + @console_ns.doc(description="Snapshot of every node's declared outputs for a published workflow run.") + @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) + @console_ns.response(404, "Workflow run not found") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App, run_id: UUID): + return _serve_snapshot(app_model, run_id) + + +@console_ns.route("/apps//workflows/published/runs//node-outputs/") +class WorkflowPublishedRunNodeOutputDetailApi(Resource): + """One node's declared outputs + per-output status (published run).""" + + @console_ns.doc("get_workflow_published_run_node_output_detail") + @console_ns.doc(description="One node's declared outputs for a published workflow run.") + @console_ns.doc( + params={ + "app_id": "Application ID", + "run_id": "Workflow run ID", + "node_id": "Node ID inside the workflow graph", + } + ) + @console_ns.response(404, "Workflow run / node not found") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App, run_id: UUID, node_id: str): + return 
_serve_node_detail(app_model, run_id, node_id) + + +@console_ns.route( + "/apps//workflows/published/runs/" + "/node-outputs///preview" +) +class WorkflowPublishedRunNodeOutputPreviewApi(Resource): + """Full value for one declared output of a published run.""" + + @console_ns.doc("get_workflow_published_run_node_output_preview") + @console_ns.doc(description="Full value for one declared output of a published run.") + @console_ns.doc( + params={ + "app_id": "Application ID", + "run_id": "Workflow run ID", + "node_id": "Node ID inside the workflow graph", + "output_name": "Declared output name as exposed by Composer", + } + ) + @console_ns.response(404, "Workflow run / node / output not found") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App, run_id: UUID, node_id: str, output_name: str): + return _serve_output_preview(app_model, run_id, node_id, output_name) + + +@console_ns.route("/apps//workflows/published/runs//node-outputs/events") +class WorkflowPublishedRunNodeOutputEventsApi(Resource): + """SSE stream of inspector deltas for a published run.""" + + @console_ns.doc("stream_workflow_published_run_node_output_events") + @console_ns.doc(description="Server-Sent Events stream of inspector deltas for a published workflow run.") + @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) + @console_ns.response(404, "Workflow run not found") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App, run_id: UUID): + return Response( + _stream_inspector_events(app_model, run_id), + mimetype="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 
ca899d8784..7b5a628561 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -11,7 +11,7 @@ from extensions.ext_database import db from libs.datetime_utils import parse_time_range from libs.login import current_account_with_tenant, login_required from models.enums import WorkflowRunTriggeredFrom -from models.model import AppMode +from models.model import App, AppMode from repositories.factory import DifyAPIRepositoryFactory @@ -46,7 +46,7 @@ class WorkflowDailyRunsStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model): + def get(self, app_model: App): account, _ = current_account_with_tenant() args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) @@ -86,7 +86,7 @@ class WorkflowDailyTerminalsStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model): + def get(self, app_model: App): account, _ = current_account_with_tenant() args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) @@ -126,7 +126,7 @@ class WorkflowDailyTokenCostStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model): + def get(self, app_model: App): account, _ = current_account_with_tenant() args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) @@ -166,7 +166,7 @@ class WorkflowAverageAppInteractionStatistic(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW]) - def get(self, app_model): + def get(self, app_model: App): account, _ = current_account_with_tenant() args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 7544c4dbdc..33f6fb14ae 100644 --- 
a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -5,12 +5,12 @@ from pydantic import BaseModel, Field from controllers.common.schema import register_response_schema_models, register_schema_models from fields.base import ResponseModel -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from services.auth.api_key_auth_service import ApiKeyAuthService from .. import console_ns from ..auth.error import ApiKeyAuthFailedError -from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required +from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required, with_current_tenant_id class ApiKeyAuthBindingPayload(BaseModel): @@ -42,8 +42,8 @@ class ApiKeyAuthDataSource(Resource): @setup_required @login_required @account_initialization_required - def get(self): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, current_tenant_id: str): data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id) if data_source_api_key_bindings: return { @@ -69,9 +69,9 @@ class ApiKeyAuthDataSourceBinding(Resource): @account_initialization_required @is_admin_or_owner_required @console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__]) - def post(self): + @with_current_tenant_id + def post(self, current_tenant_id: str): # The role of the current user in the table must be admin or owner - _, current_tenant_id = current_account_with_tenant() payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload) data = payload.model_dump() ApiKeyAuthService.validate_api_key_auth_args(data) @@ -89,10 +89,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource): @account_initialization_required @is_admin_or_owner_required @console_ns.response(204, "Binding deleted successfully") - def delete(self, binding_id: UUID): + 
@with_current_tenant_id + def delete(self, current_tenant_id: str, binding_id: UUID): # The role of the current user in the table must be admin or owner - _, current_tenant_id = current_account_with_tenant() - ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id)) return "", 204 diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 727428c8e7..7e48558977 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -8,9 +8,9 @@ from flask_restx import Resource from pydantic import BaseModel from werkzeug.exceptions import BadRequest, NotFound -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, setup_required, with_current_user from graphon.model_runtime.utils.encoders import jsonable_encoder -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from models import Account from models.model import OAuthProviderApp from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService @@ -133,12 +133,10 @@ class OAuthServerUserAuthorizeApi(Resource): @setup_required @login_required @account_initialization_required + @with_current_user @oauth_server_client_id_required - def post(self, oauth_provider_app: OAuthProviderApp): - current_user, _ = current_account_with_tenant() - account = current_user - user_account_id = account.id - + def post(self, oauth_provider_app: OAuthProviderApp, current_user: Account): + user_account_id = current_user.id code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id) return jsonable_encoder( { diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 38ad7dfdd1..3b01ef7558 100644 --- 
a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -169,9 +169,12 @@ class DatasetDocumentSegmentListApi(Resource): # Use database-specific methods for JSON array search if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql": # PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text + # Guard with jsonb_typeof to avoid "cannot extract elements from a scalar" error + # when keywords is null or a non-array JSON value. keywords_condition = func.array_to_string( func.array( select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB))) + .where(func.jsonb_typeof(cast(DocumentSegment.keywords, JSONB)) == "array") .correlate(DocumentSegment) .scalar_subquery() ), diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 110a2e16f5..37640138eb 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,15 +1,12 @@ from __future__ import annotations -from datetime import datetime -from typing import Any from uuid import UUID from flask_restx import Resource -from pydantic import Field, field_validator -from controllers.common.schema import register_schema_models -from fields.base import ResponseModel -from libs.helper import to_timestamp +from controllers.common.schema import register_response_schema_models, register_schema_models +from fields.hit_testing_fields import HitTestingResponse +from libs.helper import dump_response from libs.login import login_required from .. 
import console_ns @@ -20,86 +17,8 @@ from ..wraps import ( setup_required, ) - -class HitTestingDocument(ResponseModel): - id: str | None = None - data_source_type: str | None = None - name: str | None = None - doc_type: str | None = None - doc_metadata: Any | None = None - - -class HitTestingSegment(ResponseModel): - id: str | None = None - position: int | None = None - document_id: str | None = None - content: str | None = None - sign_content: str | None = None - answer: str | None = None - word_count: int | None = None - tokens: int | None = None - keywords: list[str] = Field(default_factory=list) - index_node_id: str | None = None - index_node_hash: str | None = None - hit_count: int | None = None - enabled: bool | None = None - disabled_at: int | None = None - disabled_by: str | None = None - status: str | None = None - created_by: str | None = None - created_at: int | None = None - indexing_at: int | None = None - completed_at: int | None = None - error: str | None = None - stopped_at: int | None = None - document: HitTestingDocument | None = None - - @field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before") - @classmethod - def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return to_timestamp(value) - - -class HitTestingChildChunk(ResponseModel): - id: str | None = None - content: str | None = None - position: int | None = None - score: float | None = None - - -class HitTestingFile(ResponseModel): - id: str | None = None - name: str | None = None - size: int | None = None - extension: str | None = None - mime_type: str | None = None - source_url: str | None = None - - -class HitTestingRecord(ResponseModel): - segment: HitTestingSegment | None = None - child_chunks: list[HitTestingChildChunk] = Field(default_factory=list) - score: float | None = None - tsne_position: Any | None = None - files: list[HitTestingFile] = Field(default_factory=list) - summary: str | None = None - - -class 
HitTestingResponse(ResponseModel): - query: str - records: list[HitTestingRecord] = Field(default_factory=list) - - -register_schema_models( - console_ns, - HitTestingPayload, - HitTestingDocument, - HitTestingSegment, - HitTestingChildChunk, - HitTestingFile, - HitTestingRecord, - HitTestingResponse, -) +register_schema_models(console_ns, HitTestingPayload) +register_response_schema_models(console_ns, HitTestingResponse) @console_ns.route("/datasets//hit-testing") @@ -119,12 +38,11 @@ class HitTestingApi(Resource, DatasetsHitTestingBase): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def post(self, dataset_id: UUID): + def post(self, dataset_id: UUID) -> dict[str, object]: dataset_id_str = str(dataset_id) dataset = self.get_and_validate_dataset(dataset_id_str) - payload = HitTestingPayload.model_validate(console_ns.payload or {}) - args = payload.model_dump(exclude_none=True) + args = self.parse_args(console_ns.payload) self.hit_testing_args_check(args) - return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json") + return dump_response(HitTestingResponse, self.perform_hit_testing(dataset, args)) diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index bb725a5f6c..4be91e0e54 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,6 @@ import logging -from typing import Any +from typing import Any, cast -from flask_restx import marshal from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -19,10 +18,10 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from fields.hit_testing_fields import hit_testing_record_fields from graphon.model_runtime.errors.invoke import InvokeError from libs.login import current_user from 
models.account import Account +from models.dataset import Dataset from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import RetrievalModel from services.hit_testing_service import HitTestingService @@ -38,16 +37,6 @@ class HitTestingPayload(BaseModel): class DatasetsHitTestingBase: - @staticmethod - def _extract_hit_testing_query(query: Any) -> str: - """Return the query string from the service response shape.""" - if isinstance(query, dict): - content = query.get("content") - if isinstance(content, str): - return content - - raise ValueError("Invalid hit testing query response") - @staticmethod def _prepare_hit_testing_records(records: Any) -> list[dict[str, Any]]: """Ensure collection fields match the API schema before response validation.""" @@ -63,6 +52,7 @@ class DatasetsHitTestingBase: segment = normalized_record.get("segment") if isinstance(segment, dict): normalized_segment = dict(segment) + normalized_segment.setdefault("sign_content", None) if normalized_segment.get("keywords") is None: normalized_segment["keywords"] = [] normalized_record["segment"] = normalized_segment @@ -73,12 +63,15 @@ class DatasetsHitTestingBase: if normalized_record.get("files") is None: normalized_record["files"] = [] + normalized_record.setdefault("tsne_position", None) + normalized_record.setdefault("summary", None) + normalized_records.append(normalized_record) return normalized_records @staticmethod - def get_and_validate_dataset(dataset_id: str): + def get_and_validate_dataset(dataset_id: str) -> Dataset: assert isinstance(current_user, Account) dataset = DatasetService.get_dataset(dataset_id) if dataset is None: @@ -92,33 +85,35 @@ class DatasetsHitTestingBase: return dataset @staticmethod - def hit_testing_args_check(args: dict[str, Any]): + def hit_testing_args_check(args: dict[str, Any]) -> None: HitTestingService.hit_testing_args_check(args) @staticmethod - def parse_args(payload: dict[str, Any]) -> dict[str, 
Any]: + def parse_args(payload: dict[str, Any] | None) -> dict[str, Any]: """Validate and return hit-testing arguments from an incoming payload.""" hit_testing_payload = HitTestingPayload.model_validate(payload or {}) return hit_testing_payload.model_dump(exclude_none=True) @staticmethod - def perform_hit_testing(dataset, args): + def perform_hit_testing(dataset: Dataset, args: dict[str, Any]) -> dict[str, Any]: assert isinstance(current_user, Account) try: response = HitTestingService.retrieve( dataset=dataset, - query=args.get("query"), + query=cast(str, args.get("query")), account=current_user, retrieval_model=args.get("retrieval_model"), - external_retrieval_model=args.get("external_retrieval_model"), + external_retrieval_model=cast(dict[str, Any], args.get("external_retrieval_model")), attachment_ids=args.get("attachment_ids"), limit=10, ) + query = response.get("query") + if not isinstance(query, dict) or not isinstance(query.get("content"), str): + raise ValueError("Invalid hit testing query response") + return { - "query": DatasetsHitTestingBase._extract_hit_testing_query(response.get("query")), - "records": DatasetsHitTestingBase._prepare_hit_testing_records( - marshal(response.get("records", []), hit_testing_record_fields) - ), + "query": {"content": query["content"]}, + "records": DatasetsHitTestingBase._prepare_hit_testing_records(response.get("records", [])), } except services.errors.index.IndexNotInitializedError: raise DatasetNotInitializedError() diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index ab660d9dc3..42b611cafd 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -20,6 +20,7 @@ from controllers.console.app.error import ( from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from graphon.model_runtime.errors.invoke import 
InvokeError +from models.model import InstalledApp from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, @@ -40,8 +41,10 @@ register_schema_model(console_ns, TextToAudioPayload) endpoint="installed_app_audio", ) class ChatAudioApi(InstalledAppResource): - def post(self, installed_app): + def post(self, installed_app: InstalledApp): app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() file = request.files["file"] @@ -81,8 +84,10 @@ class ChatAudioApi(InstalledAppResource): ) class ChatTextApi(InstalledAppResource): @console_ns.expect(console_ns.models[TextToAudioPayload.__name__]) - def post(self, installed_app): + def post(self, installed_app: InstalledApp): app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() try: payload = TextToAudioPayload.model_validate(console_ns.payload or {}) diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 72e2d923da..c08e8690a8 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -31,7 +31,7 @@ from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user from models import Account -from models.model import AppMode +from models.model import AppMode, InstalledApp from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService from services.errors.llm import InvokeRateLimitError @@ -83,8 +83,10 @@ register_response_schema_models(console_ns, SimpleResultResponse) ) class CompletionApi(InstalledAppResource): @console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__]) - def post(self, installed_app): + def post(self, installed_app: InstalledApp): app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() if app_model.mode != AppMode.COMPLETION: raise 
NotCompletionAppError() @@ -133,8 +135,10 @@ class CompletionApi(InstalledAppResource): ) class CompletionStopApi(InstalledAppResource): @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) - def post(self, installed_app, task_id: str): + def post(self, installed_app: InstalledApp, task_id: str): app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() @@ -157,8 +161,10 @@ class CompletionStopApi(InstalledAppResource): ) class ChatApi(InstalledAppResource): @console_ns.expect(console_ns.models[ChatMessagePayload.__name__]) - def post(self, installed_app): + def post(self, installed_app: InstalledApp): app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -209,8 +215,10 @@ class ChatApi(InstalledAppResource): ) class ChatStopApi(InstalledAppResource): @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) - def post(self, installed_app, task_id: str): + def post(self, installed_app: InstalledApp, task_id: str): app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index a3ae59aaf7..68e18a0207 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -8,6 +8,7 @@ from werkzeug.exceptions import NotFound from controllers.common.controller_schemas import ConversationRenamePayload from controllers.common.schema import register_response_schema_models, register_schema_models +from 
controllers.console.app.error import AppUnavailableError from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom @@ -20,7 +21,7 @@ from fields.conversation_fields import ( from libs.helper import UUIDStrOrEmpty from libs.login import current_user from models import Account -from models.model import AppMode +from models.model import AppMode, InstalledApp from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.web_conversation_service import WebConversationService @@ -44,8 +45,10 @@ register_response_schema_models(console_ns, ResultResponse) ) class ConversationListApi(InstalledAppResource): @console_ns.expect(console_ns.models[ConversationListQuery.__name__]) - def get(self, installed_app): + def get(self, installed_app: InstalledApp): app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -92,8 +95,10 @@ class ConversationListApi(InstalledAppResource): ) class ConversationApi(InstalledAppResource): @console_ns.response(204, "Conversation deleted successfully") - def delete(self, installed_app, c_id: UUID): + def delete(self, installed_app: InstalledApp, c_id: UUID): app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -115,8 +120,10 @@ class ConversationApi(InstalledAppResource): ) class ConversationRenameApi(InstalledAppResource): @console_ns.expect(console_ns.models[ConversationRenamePayload.__name__]) - def post(self, installed_app, c_id: UUID): + def post(self, 
installed_app: InstalledApp, c_id: UUID): app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -146,8 +153,10 @@ class ConversationRenameApi(InstalledAppResource): ) class ConversationPinApi(InstalledAppResource): @console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__]) - def patch(self, installed_app, c_id: UUID): + def patch(self, installed_app: InstalledApp, c_id: UUID): app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -170,8 +179,10 @@ class ConversationPinApi(InstalledAppResource): ) class ConversationUnPinApi(InstalledAppResource): @console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__]) - def patch(self, installed_app, c_id: UUID): + def patch(self, installed_app: InstalledApp, c_id: UUID): app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 4ad3dbc85f..bd4d1ef49f 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -262,7 +262,7 @@ class InstalledAppApi(InstalledAppResource): """ @console_ns.response(204, "App uninstalled successfully") - def delete(self, installed_app): + def delete(self, installed_app: InstalledApp): _, current_tenant_id = current_account_with_tenant() if installed_app.app_owner_tenant_id == current_tenant_id: raise BadRequest("You can't uninstall an app owned by the current 
tenant") @@ -273,7 +273,7 @@ class InstalledAppApi(InstalledAppResource): return "", 204 @console_ns.response(200, "Success", console_ns.models[SimpleResultMessageResponse.__name__]) - def patch(self, installed_app): + def patch(self, installed_app: InstalledApp): payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {}) commit_args = False diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index c6930a76cb..a19355f90b 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -10,6 +10,7 @@ from controllers.common.controller_schemas import MessageFeedbackPayload, Messag from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console.app.error import ( AppMoreLikeThisDisabledError, + AppUnavailableError, CompletionRequestError, ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, @@ -21,15 +22,16 @@ from controllers.console.explore.error import ( NotCompletionAppError, ) from controllers.console.explore.wraps import InstalledAppResource +from controllers.console.wraps import with_current_user from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse from graphon.model_runtime.errors.invoke import InvokeError from libs import helper -from libs.login import current_account_with_tenant +from models import Account from models.enums import FeedbackRating -from models.model import AppMode +from models.model import AppMode, InstalledApp from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import 
ConversationNotExistsError @@ -59,9 +61,11 @@ register_response_schema_models(console_ns, ResultResponse, SuggestedQuestionsRe ) class MessageListApi(InstalledAppResource): @console_ns.expect(console_ns.models[MessageListQuery.__name__]) - def get(self, installed_app): - current_user, _ = current_account_with_tenant() + @with_current_user + def get(self, current_user: Account, installed_app: InstalledApp): app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -96,9 +100,11 @@ class MessageListApi(InstalledAppResource): class MessageFeedbackApi(InstalledAppResource): @console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__]) @console_ns.response(200, "Feedback submitted successfully", console_ns.models[ResultResponse.__name__]) - def post(self, installed_app, message_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account, installed_app: InstalledApp, message_id: UUID): app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() message_id_str = str(message_id) @@ -124,9 +130,11 @@ class MessageFeedbackApi(InstalledAppResource): ) class MessageMoreLikeThisApi(InstalledAppResource): @console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__]) - def get(self, installed_app, message_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def get(self, current_user: Account, installed_app: InstalledApp, message_id: UUID): app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() if app_model.mode != "completion": raise NotCompletionAppError() @@ -170,9 +178,11 @@ class MessageMoreLikeThisApi(InstalledAppResource): ) class MessageSuggestedQuestionApi(InstalledAppResource): @console_ns.response(200, "Success", 
console_ns.models[SuggestedQuestionsResponse.__name__]) - def get(self, installed_app, message_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def get(self, current_user: Account, installed_app: InstalledApp, message_id: UUID): app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index fc863b78d7..cf48eeea72 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -7,12 +7,14 @@ from werkzeug.exceptions import NotFound from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns +from controllers.console.app.error import AppUnavailableError from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from controllers.console.wraps import with_current_user from fields.conversation_fields import ResultResponse from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem from models import Account +from models.model import InstalledApp from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService @@ -24,8 +26,10 @@ register_response_schema_models(console_ns, ResultResponse) class SavedMessageListApi(InstalledAppResource): @console_ns.expect(console_ns.models[SavedMessageListQuery.__name__]) @with_current_user - def get(self, current_user: Account, installed_app): + def get(self, current_user: Account, installed_app: InstalledApp): app_model = 
installed_app.app + if app_model is None: + raise AppUnavailableError() if app_model.mode != "completion": raise NotCompletionAppError() @@ -48,8 +52,10 @@ class SavedMessageListApi(InstalledAppResource): @console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__]) @console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__]) @with_current_user - def post(self, current_user: Account, installed_app): + def post(self, current_user: Account, installed_app: InstalledApp): app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() if app_model.mode != "completion": raise NotCompletionAppError() @@ -69,8 +75,10 @@ class SavedMessageListApi(InstalledAppResource): class SavedMessageApi(InstalledAppResource): @console_ns.response(204, "Saved message deleted successfully") @with_current_user - def delete(self, current_user: Account, installed_app, message_id: UUID): + def delete(self, current_user: Account, installed_app: InstalledApp, message_id: UUID): app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() message_id_str = str(message_id) diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 029e2e7f0d..26a348da40 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -9,14 +9,14 @@ from pydantic import BaseModel, Field, TypeAdapter, field_validator from constants import HIDDEN_VALUE from fields.base import ResponseModel from libs.helper import to_timestamp -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from models.api_based_extension import APIBasedExtension from services.api_based_extension_service import APIBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models from . 
import console_ns -from .wraps import account_initialization_required, setup_required +from .wraps import account_initialization_required, setup_required, with_current_tenant_id class CodeBasedExtensionQuery(BaseModel): @@ -116,11 +116,11 @@ class APIBasedExtensionAPI(Resource): @setup_required @login_required @account_initialization_required - def get(self): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, current_tenant_id: str): return [ _serialize_api_based_extension(extension) - for extension in APIBasedExtensionService.get_all_by_tenant_id(tenant_id) + for extension in APIBasedExtensionService.get_all_by_tenant_id(current_tenant_id) ] @console_ns.doc("create_api_based_extension") @@ -130,9 +130,9 @@ class APIBasedExtensionAPI(Resource): @setup_required @login_required @account_initialization_required - def post(self): + @with_current_tenant_id + def post(self, current_tenant_id: str): payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {}) - _, current_tenant_id = current_account_with_tenant() extension_data = APIBasedExtension( tenant_id=current_tenant_id, @@ -153,12 +153,12 @@ class APIBasedExtensionDetailAPI(Resource): @setup_required @login_required @account_initialization_required - def get(self, id: UUID): + @with_current_tenant_id + def get(self, current_tenant_id: str, id: UUID): api_based_extension_id = str(id) - _, tenant_id = current_account_with_tenant() return _serialize_api_based_extension( - APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) ) @console_ns.doc("update_api_based_extension") @@ -169,9 +169,9 @@ class APIBasedExtensionDetailAPI(Resource): @setup_required @login_required @account_initialization_required - def post(self, id: UUID): + @with_current_tenant_id + def post(self, current_tenant_id: str, id: UUID): api_based_extension_id = str(id) - _, 
current_tenant_id = current_account_with_tenant() extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) @@ -197,9 +197,9 @@ class APIBasedExtensionDetailAPI(Resource): @setup_required @login_required @account_initialization_required - def delete(self, id: UUID): + @with_current_tenant_id + def delete(self, current_tenant_id: str, id: UUID): api_based_extension_id = str(id) - _, current_tenant_id = current_account_with_tenant() extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 654991900d..86ef961948 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -2,11 +2,11 @@ from flask_restx import Resource from werkzeug.exceptions import Unauthorized from controllers.common.schema import register_response_schema_models -from libs.login import current_account_with_tenant, current_user, login_required +from libs.login import current_user, login_required from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel from . 
import console_ns -from .wraps import account_initialization_required, cloud_utm_record, setup_required +from .wraps import account_initialization_required, cloud_utm_record, setup_required, with_current_tenant_id register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel) @@ -24,10 +24,9 @@ class FeatureApi(Resource): @login_required @account_initialization_required @cloud_utm_record - def get(self): + @with_current_tenant_id + def get(self, current_tenant_id: str): """Get feature configuration for current tenant""" - _, current_tenant_id = current_account_with_tenant() - payload = FeatureService.get_features( current_tenant_id, exclude_vector_space=True, @@ -49,10 +48,9 @@ class FeatureVectorSpaceApi(Resource): @login_required @account_initialization_required @cloud_utm_record - def get(self): + @with_current_tenant_id + def get(self, current_tenant_id: str): """Get vector-space usage and limit for current tenant""" - _, current_tenant_id = current_account_with_tenant() - return FeatureService.get_vector_space(current_tenant_id).model_dump() diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 499a623872..5197120c13 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -22,10 +22,13 @@ from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, setup_required, + with_current_tenant_id, + with_current_user, ) from extensions.ext_database import db from fields.file_fields import FileResponse, UploadConfig -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models import Account from services.file_service import FileService from . 
import console_ns @@ -62,8 +65,8 @@ class FileApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("documents") @console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__]) - def post(self): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account): source_str = request.form.get("source") source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None @@ -107,10 +110,10 @@ class FilePreviewApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__]) - def get(self, file_id: UUID): + @with_current_tenant_id + def get(self, current_tenant_id: str, file_id: UUID): file_id_str = str(file_id) - _, tenant_id = current_account_with_tenant() - text = FileService(db.engine).get_file_preview(file_id_str, tenant_id) + text = FileService(db.engine).get_file_preview(file_id_str, current_tenant_id) return {"content": text} diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 19f1fd8aab..9f7fe6379c 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -12,11 +12,13 @@ from controllers.common.errors import ( ) from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns +from controllers.console.wraps import with_current_user from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from graphon.file import helpers as file_helpers -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models import Account from services.file_service import FileService @@ -49,7 +51,8 @@ class RemoteFileUpload(Resource): 
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__]) @console_ns.response(201, "File uploaded successfully", console_ns.models[FileWithSignedUrl.__name__]) @login_required - def post(self): + @with_current_user + def post(self, current_user: Account): payload = RemoteFileUploadPayload.model_validate(console_ns.payload) url = payload.url @@ -74,12 +77,11 @@ class RemoteFileUpload(Resource): content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content try: - user, _ = current_account_with_tenant() upload_file = FileService(db.engine).upload_file( filename=file_info.filename, content=content, mimetype=file_info.mimetype, - user=user, + user=current_user, source_url=url, ) except services.errors.file.FileTooLargeError as file_too_large_error: diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index ece21f81f8..d2ec06c062 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -9,9 +9,16 @@ from werkzeug.exceptions import Forbidden from controllers.common.fields import SimpleResultResponse from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, + with_current_tenant_id, + with_current_user, +) from fields.base import ResponseModel -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models import Account from models.enums import TagType from services.tag_service import ( SaveTagPayload, @@ -95,8 +102,8 @@ class TagListApi(Resource): } ) @console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])}) - def get(self): - _, current_tenant_id = 
current_account_with_tenant() + @with_current_tenant_id + def get(self, current_tenant_id: str): raw_args = request.args.to_dict() param = TagListQueryParam.model_validate(raw_args) tags = TagService.get_tags(param.type, current_tenant_id, param.keyword) @@ -112,9 +119,9 @@ class TagListApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): - current_user, _ = current_account_with_tenant() - # The role of the current user in the ta table must be admin, owner, or editor + @with_current_user + def post(self, current_user: Account): + # Allow users with edit permission, or dataset editors (including dataset operators). if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() @@ -135,8 +142,8 @@ class TagUpdateDeleteApi(Resource): @setup_required @login_required @account_initialization_required - def patch(self, tag_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def patch(self, current_user: Account, tag_id: UUID): tag_id_str = str(tag_id) # The role of the current user in the ta table must be admin, owner, or editor if not (current_user.has_edit_permission or current_user.is_dataset_editor): @@ -166,20 +173,19 @@ class TagUpdateDeleteApi(Resource): return "", 204 -def _require_tag_binding_edit_permission() -> None: +def _require_tag_binding_edit_permission(current_user: Account) -> None: """ Ensure the current account can edit tag bindings. Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant. 
""" - current_user, _ = current_account_with_tenant() # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() -def _create_tag_bindings() -> tuple[dict[str, str], int]: - _require_tag_binding_edit_permission() +def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]: + _require_tag_binding_edit_permission(current_user) payload = TagBindingPayload.model_validate(console_ns.payload or {}) TagService.save_tag_binding( @@ -192,8 +198,8 @@ def _create_tag_bindings() -> tuple[dict[str, str], int]: return {"result": "success"}, 200 -def _remove_tag_bindings() -> tuple[dict[str, str], int]: - _require_tag_binding_edit_permission() +def _remove_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]: + _require_tag_binding_edit_permission(current_user) payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) TagService.delete_tag_binding( @@ -216,8 +222,9 @@ class TagBindingCollectionApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): - return _create_tag_bindings() + @with_current_user + def post(self, current_user: Account): + return _create_tag_bindings(current_user) @console_ns.route("/tag-bindings/remove") @@ -231,5 +238,6 @@ class TagBindingRemoveApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): - return _remove_tag_bindings() + @with_current_user + def post(self, current_user: Account): + return _remove_tag_bindings(current_user) diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 279633b008..1ff788016e 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -161,7 +161,7 @@ def _check_member_invite_limits(tenant_id: str, new_member_count: int) -> None: if new_member_count <= 0: 
return - features = FeatureService.get_features(tenant_id=tenant_id) + features = FeatureService.get_features(tenant_id=tenant_id, exclude_vector_space=True) if dify_config.ENTERPRISE_ENABLED: workspace_members = features.workspace_members diff --git a/api/controllers/openapi/__init__.py b/api/controllers/openapi/__init__.py index 998829098a..f89ef0111c 100644 --- a/api/controllers/openapi/__init__.py +++ b/api/controllers/openapi/__init__.py @@ -37,6 +37,13 @@ from controllers.openapi._models import ( DeviceMutateRequest, DeviceMutateResponse, DevicePollRequest, + MemberActionResponse, + MemberInvitePayload, + MemberInviteResponse, + MemberListQuery, + MemberListResponse, + MemberResponse, + MemberRoleUpdatePayload, MessageMetadata, PermittedExternalAppsListQuery, PermittedExternalAppsListResponse, @@ -63,6 +70,9 @@ register_schema_models( DevicePollRequest, DeviceLookupQuery, DeviceMutateRequest, + MemberInvitePayload, + MemberListQuery, + MemberRoleUpdatePayload, PermittedExternalAppsListQuery, ) register_response_schema_models( @@ -86,6 +96,10 @@ register_response_schema_models( WorkspaceSummaryResponse, WorkspaceListResponse, WorkspaceDetailResponse, + MemberResponse, + MemberListResponse, + MemberInviteResponse, + MemberActionResponse, DeviceCodeResponse, DeviceLookupResponse, DeviceMutateResponse, diff --git a/api/controllers/openapi/_models.py b/api/controllers/openapi/_models.py index 128a937549..59b2e5176e 100644 --- a/api/controllers/openapi/_models.py +++ b/api/controllers/openapi/_models.py @@ -6,7 +6,7 @@ from typing import Any, Literal from pydantic import BaseModel, ConfigDict, Field, field_validator -from libs.helper import UUIDStrOrEmpty, uuid_value +from libs.helper import EmailStr, UUIDStrOrEmpty, uuid_value from models.model import AppMode # Server-side cap on `limit` query param for /openapi/v1/* list endpoints. 
@@ -342,3 +342,61 @@ class ApprovalGrantClaimsPayload(BaseModel): user_code: str = Field(min_length=1, max_length=32) nonce: str = Field(min_length=1, max_length=128) csrf_token: str = Field(min_length=1, max_length=128) + + +# Closed enum for invite/update-role payloads. Owner is intentionally not +# assignable through these endpoints — ownership transfer goes through the +# console's three-step email-verification flow. +MemberAssignableRole = Literal["normal", "admin"] + + +class MemberResponse(BaseModel): + id: str + name: str + email: str + role: str + status: str + avatar: str | None = None + + +class MemberListResponse(BaseModel): + page: int + limit: int + total: int + has_more: bool + data: list[MemberResponse] + + +class MemberListQuery(BaseModel): + """Strict (extra='forbid').""" + + model_config = ConfigDict(extra="forbid") + + page: int = Field(1, ge=1) + limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT) + + +class MemberInvitePayload(BaseModel): + model_config = ConfigDict(extra="forbid") + + email: EmailStr + role: MemberAssignableRole + + +class MemberRoleUpdatePayload(BaseModel): + model_config = ConfigDict(extra="forbid") + + role: MemberAssignableRole + + +class MemberInviteResponse(BaseModel): + result: Literal["success"] = "success" + email: str + role: str + member_id: str + invite_url: str + tenant_id: str + + +class MemberActionResponse(BaseModel): + result: Literal["success"] = "success" diff --git a/api/controllers/openapi/account.py b/api/controllers/openapi/account.py index 602d7e7ab4..256a822dcb 100644 --- a/api/controllers/openapi/account.py +++ b/api/controllers/openapi/account.py @@ -4,7 +4,7 @@ from datetime import UTC, datetime from flask import request from flask_restx import Resource -from werkzeug.exceptions import BadRequest, NotFound +from werkzeug.exceptions import NotFound from controllers.openapi import openapi_ns from controllers.openapi._models import ( @@ -17,18 +17,17 @@ from controllers.openapi._models import ( 
SessionRow, WorkspacePayload, ) +from controllers.openapi.auth.composition import auth_router +from controllers.openapi.auth.data import AuthData from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.oauth_bearer import ( - ACCEPT_USER_ANY, - AuthContext, - SubjectType, + Scope, + TokenType, get_auth_ctx, - validate_bearer, ) from libs.rate_limit import ( LIMIT_ME_PER_ACCOUNT, - LIMIT_ME_PER_EMAIL, enforce, ) from services.account_service import AccountService, TenantService @@ -42,32 +41,18 @@ from services.oauth_device_flow import ( @openapi_ns.route("/account") class AccountApi(Resource): @openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__]) - @validate_bearer(accept=ACCEPT_USER_ANY) - def get(self): - ctx = get_auth_ctx() + @auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def get(self, *, auth_data: AuthData): + enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{auth_data.account_id}") - if ctx.subject_type == SubjectType.EXTERNAL_SSO: - enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}") - else: - enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}") - - if ctx.subject_type == SubjectType.EXTERNAL_SSO: - return AccountResponse( - subject_type=ctx.subject_type, - subject_email=ctx.subject_email, - subject_issuer=ctx.subject_issuer, - account=None, - workspaces=[], - default_workspace_id=None, - ).model_dump(mode="json") - - account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if ctx.account_id else None - memberships = TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else [] + account_id_str = str(auth_data.account_id) if auth_data.account_id else None + account = AccountService.get_account_by_id(db.session, account_id_str) if account_id_str else None + memberships = TenantService.get_account_memberships(db.session, account_id_str) if account_id_str else [] 
default_ws_id = _pick_default_workspace(memberships) return AccountResponse( - subject_type=ctx.subject_type, - subject_email=ctx.subject_email or (account.email if account else None), + subject_type="account", + subject_email=account.email if account else None, account=_account_payload(account) if account else None, workspaces=[_workspace_payload(m) for m in memberships], default_workspace_id=default_ws_id, @@ -77,19 +62,17 @@ class AccountApi(Resource): @openapi_ns.route("/account/sessions/self") class AccountSessionsSelfApi(Resource): @openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__]) - @validate_bearer(accept=ACCEPT_USER_ANY) - def delete(self): - ctx = get_auth_ctx() - _require_oauth_subject(ctx) - revoke_oauth_token(db.session, redis_client, str(ctx.token_id)) + @auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def delete(self, *, auth_data: AuthData): + revoke_oauth_token(db.session, redis_client, str(auth_data.token_id)) return RevokeResponse(status="revoked").model_dump(mode="json"), 200 @openapi_ns.route("/account/sessions") class AccountSessionsApi(Resource): @openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__]) - @validate_bearer(accept=ACCEPT_USER_ANY) - def get(self): + @auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def get(self, *, auth_data: AuthData): ctx = get_auth_ctx() now = datetime.now(UTC) page = int(request.args.get("page", "1")) @@ -122,10 +105,9 @@ class AccountSessionsApi(Resource): @openapi_ns.route("/account/sessions/") class AccountSessionByIdApi(Resource): @openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__]) - @validate_bearer(accept=ACCEPT_USER_ANY) - def delete(self, session_id: str): + @auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def delete(self, session_id: str, *, auth_data: 
AuthData): ctx = get_auth_ctx() - _require_oauth_subject(ctx) # 404 (not 403) on cross-subject so the endpoint doesn't leak # token IDs that belong to other subjects. @@ -136,13 +118,6 @@ class AccountSessionByIdApi(Resource): return RevokeResponse(status="revoked").model_dump(mode="json"), 200 -def _require_oauth_subject(ctx: AuthContext) -> None: - if not ctx.source.startswith("oauth"): - raise BadRequest( - "this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs" - ) - - def _iso(dt: datetime | None) -> str | None: if dt is None: return None diff --git a/api/controllers/openapi/app_run.py b/api/controllers/openapi/app_run.py index 95a26d50fa..8ef94740c9 100644 --- a/api/controllers/openapi/app_run.py +++ b/api/controllers/openapi/app_run.py @@ -16,7 +16,8 @@ import services from controllers.openapi import openapi_ns from controllers.openapi._audit import emit_app_run from controllers.openapi._models import AppRunRequest -from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE +from controllers.openapi.auth.composition import auth_router +from controllers.openapi.auth.data import AuthData from controllers.service_api.app.error import ( AppUnavailableError, CompletionRequestError, @@ -124,8 +125,9 @@ _DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = { class AppRunApi(Resource): @openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__]) @openapi_ns.response(200, "Run result (SSE stream)") - @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN) - def post(self, app_id: str, app_model: App, caller, caller_kind: str): + @auth_router.guard(scope=Scope.APPS_RUN) + def post(self, app_id: str, *, auth_data: AuthData): + app_model, caller, caller_kind = auth_data.require_app_context() body = request.get_json(silent=True) or {} try: payload = AppRunRequest.model_validate(body) @@ -158,8 +160,9 @@ class AppRunApi(Resource): @openapi_ns.route("/apps//tasks//stop") class AppRunTaskStopApi(Resource): 
@openapi_ns.response(200, "Task stopped") - @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN) - def post(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str): + @auth_router.guard(scope=Scope.APPS_RUN) + def post(self, app_id: str, task_id: str, *, auth_data: AuthData): + app_model, caller, caller_kind = auth_data.require_app_context() AppQueueManager.set_stop_flag_no_user_check(task_id) GraphEngineManager(redis_client).send_stop_command(task_id) return {"result": "success"} diff --git a/api/controllers/openapi/apps.py b/api/controllers/openapi/apps.py index 8a3fc81809..d3bc4e4680 100644 --- a/api/controllers/openapi/apps.py +++ b/api/controllers/openapi/apps.py @@ -1,9 +1,4 @@ -"""GET /openapi/v1/apps and per-app reads. - -Decorator order: `method_decorators` is innermost-first. `validate_bearer` -is last → outermost → publishes the auth ContextVar before `require_scope` -reads it. -""" +"""GET /openapi/v1/apps and per-app reads.""" from __future__ import annotations @@ -28,31 +23,17 @@ from controllers.openapi._models import ( AppListRow, TagItem, ) -from controllers.openapi.auth.surface_gate import accept_subjects +from controllers.openapi.auth.composition import auth_router +from controllers.openapi.auth.data import AuthData from controllers.service_api.app.error import AppUnavailableError from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from extensions.ext_database import db -from libs.oauth_bearer import ( - ACCEPT_USER_ANY, - AuthContext, - Scope, - SubjectType, - get_auth_ctx, - require_scope, - require_workspace_member, - validate_bearer, -) +from libs.oauth_bearer import Scope, TokenType from models import App from services.account_service import TenantService from services.app_service import AppListParams, AppService from services.tag_service import TagService -_APPS_READ_DECORATORS = [ - require_scope(Scope.APPS_READ), - accept_subjects(SubjectType.ACCOUNT), - 
validate_bearer(accept=ACCEPT_USER_ANY), -] - _ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"}) @@ -66,13 +47,9 @@ _EMPTY_PARAMETERS: dict[str, Any] = { class AppReadResource(Resource): - """Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks.""" - - method_decorators = _APPS_READ_DECORATORS - - def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]: - ctx: AuthContext = get_auth_ctx() + """Base for per-app read endpoints; subclasses call `_load()` for membership/exists checks.""" + def _load(self, app_id: str, workspace_id: str | None = None) -> App: try: parsed_uuid = _uuid.UUID(app_id) is_uuid = True @@ -99,8 +76,7 @@ class AppReadResource(Resource): raise Conflict("".join(lines)) app = matches[0] - require_workspace_member(ctx, str(app.tenant_id)) - return app, ctx + return app def parameters_payload(app: App) -> dict: @@ -114,13 +90,14 @@ def parameters_payload(app: App) -> dict: class AppDescribeApi(AppReadResource): @openapi_ns.doc(params=query_params_from_model(AppDescribeQuery)) @openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__]) - def get(self, app_id: str): + @auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def get(self, app_id: str, *, auth_data: AuthData): try: query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True)) except ValidationError as exc: raise UnprocessableEntity(exc.json()) - app, _ = self._load(app_id, workspace_id=query.workspace_id) + app = self._load(app_id, workspace_id=query.workspace_id) requested = query.fields want_info = requested is None or "info" in requested @@ -168,20 +145,16 @@ class AppDescribeApi(AppReadResource): @openapi_ns.route("/apps") class AppListApi(Resource): - method_decorators = _APPS_READ_DECORATORS - @openapi_ns.doc(params=query_params_from_model(AppListQuery)) @openapi_ns.response(200, 
"App list", openapi_ns.models[AppListResponse.__name__]) - def get(self): - ctx: AuthContext = get_auth_ctx() - + @auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def get(self, *, auth_data: AuthData): try: query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True)) except ValidationError as exc: raise UnprocessableEntity(exc.json()) workspace_id = query.workspace_id - require_workspace_member(ctx, workspace_id) empty = ( AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump( @@ -237,7 +210,7 @@ class AppListApi(Resource): openapi_visible=True, ) - pagination = AppService().get_paginate_apps(str(ctx.account_id), workspace_id, params) + pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params) if pagination is None: return empty diff --git a/api/controllers/openapi/apps_permitted_external.py b/api/controllers/openapi/apps_permitted_external.py index 9359dca228..f86fd34a19 100644 --- a/api/controllers/openapi/apps_permitted_external.py +++ b/api/controllers/openapi/apps_permitted_external.py @@ -18,37 +18,27 @@ from controllers.openapi._models import ( PermittedExternalAppsListQuery, PermittedExternalAppsListResponse, ) -from controllers.openapi.auth.surface_gate import accept_subjects +from controllers.openapi.auth.composition import auth_router +from controllers.openapi.auth.data import AuthData, Edition from extensions.ext_database import db -from libs.device_flow_security import enterprise_only -from libs.oauth_bearer import ( - ACCEPT_USER_ANY, - Scope, - SubjectType, - require_scope, - validate_bearer, -) +from libs.oauth_bearer import Scope, TokenType from models import App from services.account_service import TenantService from services.app_service import AppService from services.enterprise.app_permitted_service import list_permitted_apps -from services.openapi.license_gate import license_required 
@openapi_ns.route("/permitted-external-apps") class PermittedExternalAppsListApi(Resource): - method_decorators = [ - require_scope(Scope.APPS_READ_PERMITTED_EXTERNAL), - license_required, - accept_subjects(SubjectType.EXTERNAL_SSO), - validate_bearer(accept=ACCEPT_USER_ANY), - enterprise_only, - ] - @openapi_ns.response( 200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__] ) - def get(self): + @auth_router.guard( + scope=Scope.APPS_READ_PERMITTED_EXTERNAL, + allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}), + edition=frozenset({Edition.EE}), + ) + def get(self, *, auth_data: AuthData): try: query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True)) except ValidationError as exc: diff --git a/api/controllers/openapi/auth/__init__.py b/api/controllers/openapi/auth/__init__.py index 17ac5493d0..0460788c18 100644 --- a/api/controllers/openapi/auth/__init__.py +++ b/api/controllers/openapi/auth/__init__.py @@ -1,3 +1,3 @@ -from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE +from controllers.openapi.auth.composition import auth_router -__all__ = ["OAUTH_BEARER_PIPELINE"] +__all__ = ["auth_router"] diff --git a/api/controllers/openapi/auth/composition.py b/api/controllers/openapi/auth/composition.py index 973ddd75a2..c2c3e12873 100644 --- a/api/controllers/openapi/auth/composition.py +++ b/api/controllers/openapi/auth/composition.py @@ -1,46 +1,64 @@ -"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints. - -Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative -paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip -the pipeline and use `validate_bearer + require_scope + require_workspace_member` -inline — they don't need `AppAuthzCheck`/`CallerMount`. 
-""" - from __future__ import annotations -from controllers.openapi.auth.pipeline import Pipeline -from controllers.openapi.auth.steps import ( - AppAuthzCheck, - AppResolver, - BearerCheck, - CallerMount, - ScopeCheck, - SurfaceCheck, - WorkspaceMembershipCheck, +from controllers.openapi.auth.conditions import ( + EDITION_CE, + EDITION_EE, + LOADED_APP_IS_PRIVATE, + PATH_HAS_APP_ID, + WEBAPP_AUTH_ENABLED, ) -from controllers.openapi.auth.strategies import ( - AccountMounter, - AclStrategy, - AppAuthzStrategy, - EndUserMounter, - MembershipStrategy, +from controllers.openapi.auth.data import Edition +from controllers.openapi.auth.flow import When +from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter +from controllers.openapi.auth.prepare import ( + load_account, + load_app, + load_app_access_mode, + load_tenant, + resolve_external_user, ) -from libs.oauth_bearer import SubjectType -from services.feature_service import FeatureService - - -def _resolve_app_authz_strategy() -> AppAuthzStrategy: - if FeatureService.get_system_features().webapp_auth.enabled: - return AclStrategy() - return MembershipStrategy() - - -OAUTH_BEARER_PIPELINE = Pipeline( - BearerCheck(), - SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})), - ScopeCheck(), - AppResolver(), - WorkspaceMembershipCheck(), - AppAuthzCheck(_resolve_app_authz_strategy), - CallerMount(AccountMounter(), EndUserMounter()), +from controllers.openapi.auth.verify import ( + check_acl, + check_app_access, + check_membership, + check_private_app_permission, + check_scope, +) +from libs.oauth_bearer import TokenType + +account_pipeline = AuthPipeline( + prepare=[ + When(PATH_HAS_APP_ID, then=load_app), + When(PATH_HAS_APP_ID, then=load_tenant), + load_account, # all tokens here are account tokens + When(PATH_HAS_APP_ID & EDITION_EE, then=load_app_access_mode), + ], + auth=[ + check_scope, + When(EDITION_CE & PATH_HAS_APP_ID, then=check_membership), + When(EDITION_EE & 
PATH_HAS_APP_ID & ~WEBAPP_AUTH_ENABLED, then=check_app_access), + When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl), + When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission), + ], +) + +external_sso_pipeline = AuthPipeline( + prepare=[ + When(PATH_HAS_APP_ID, then=load_app), + When(PATH_HAS_APP_ID, then=load_tenant), + When(PATH_HAS_APP_ID, then=resolve_external_user), + When(PATH_HAS_APP_ID, then=load_app_access_mode), + ], + auth=[ + check_scope, + When(PATH_HAS_APP_ID & WEBAPP_AUTH_ENABLED, then=check_acl), + When(LOADED_APP_IS_PRIVATE, then=check_private_app_permission), + ], +) + +auth_router = PipelineRouter( + { + TokenType.OAUTH_ACCOUNT: PipelineRoute(account_pipeline), + TokenType.OAUTH_EXTERNAL_SSO: PipelineRoute(external_sso_pipeline, required_edition=frozenset({Edition.EE})), + } ) diff --git a/api/controllers/openapi/auth/conditions.py b/api/controllers/openapi/auth/conditions.py new file mode 100644 index 0000000000..2399fc04f1 --- /dev/null +++ b/api/controllers/openapi/auth/conditions.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from collections.abc import Callable + +from controllers.openapi.auth.data import AuthData, Edition, RequestContext, current_edition +from libs.oauth_bearer import TokenType +from services.enterprise.enterprise_service import WebAppAccessMode +from services.feature_service import FeatureService + +CondFn = Callable[[RequestContext, AuthData | None], bool] + + +class Cond: + def __init__(self, fn: CondFn) -> None: + self._fn = fn + + def __call__(self, ctx: RequestContext, data: AuthData | None = None) -> bool: + return self._fn(ctx, data) + + def __and__(self, other: Cond) -> Cond: + return Cond(lambda ctx, data: self(ctx, data) and other(ctx, data)) + + def __or__(self, other: Cond) -> Cond: + return Cond(lambda ctx, data: self(ctx, data) or other(ctx, data)) + + def __invert__(self) -> Cond: + return Cond(lambda ctx, data: not self(ctx, data)) + + +def 
request_cond(fn: Callable[[RequestContext], bool]) -> Cond: + return Cond(lambda ctx, _: fn(ctx)) + + +def data_cond(fn: Callable[[AuthData], bool]) -> Cond: + return Cond(lambda _, data: data is not None and fn(data)) + + +def config_cond(fn: Callable[[], bool]) -> Cond: + return Cond(lambda _, __: fn()) + + +TOKEN_IS_OAUTH_ACCOUNT = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_ACCOUNT) +TOKEN_IS_OAUTH_EXTERNAL_SSO = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_EXTERNAL_SSO) + +PATH_HAS_APP_ID = request_cond(lambda ctx: "app_id" in ctx.path_params) + +EDITION_CE = config_cond(lambda: current_edition() == Edition.CE) +EDITION_EE = config_cond(lambda: current_edition() == Edition.EE) +EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS) + +WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled) + +LOADED_APP_IS_PRIVATE = data_cond(lambda data: data.app_access_mode == WebAppAccessMode.PRIVATE) diff --git a/api/controllers/openapi/auth/context.py b/api/controllers/openapi/auth/context.py deleted file mode 100644 index 95013627f0..0000000000 --- a/api/controllers/openapi/auth/context.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Mutable per-request context for the openapi auth pipeline. - -Every field starts None / empty and is filled in by a step. The pipeline -is the only thing that should construct or mutate Context — handlers -read populated values via the decorator's kwargs unpacking. - -Context is intentionally decoupled from Flask's ``Request``: the pipeline -guard extracts whatever transport-level inputs the steps need (bearer -token, path params) at the boundary and writes them into Context fields, -so steps stay testable without a request object and won't leak coupling -to a specific framework. 
-""" - -from __future__ import annotations - -import uuid -from collections.abc import Mapping -from contextvars import Token -from dataclasses import dataclass, field -from datetime import datetime -from typing import TYPE_CHECKING, Literal, Protocol - -from werkzeug.exceptions import Unauthorized - -from libs.oauth_bearer import AuthContext, Scope, SubjectType - -if TYPE_CHECKING: - from models import App, Tenant - - -@dataclass -class Context: - required_scope: Scope - bearer_token: str | None = None - path_params: Mapping[str, str] = field(default_factory=dict) - subject_type: SubjectType | None = None - subject_email: str | None = None - subject_issuer: str | None = None - account_id: uuid.UUID | None = None - scopes: frozenset[Scope] = field(default_factory=frozenset) - token_id: uuid.UUID | None = None - token_hash: str | None = None - cached_verified_tenants: dict[str, bool] | None = None - source: str | None = None - expires_at: datetime | None = None - app: App | None = None - tenant: Tenant | None = None - caller: object | None = None - caller_kind: Literal["account", "end_user"] | None = None - auth_ctx_reset_token: Token[AuthContext] | None = None - - @property - def must_tenant(self) -> Tenant: - if not self.tenant: - raise Unauthorized("tenant is not associated") - return self.tenant - - @property - def must_subject_type(self) -> SubjectType: - if not self.subject_type: - raise Unauthorized("subject_type unset — BearerCheck did not run") - return self.subject_type - - -class Step(Protocol): - """One responsibility. Mutate ctx or raise to short-circuit.""" - - def __call__(self, ctx: Context) -> None: ... 
diff --git a/api/controllers/openapi/auth/data.py b/api/controllers/openapi/auth/data.py new file mode 100644 index 0000000000..30973b5e9b --- /dev/null +++ b/api/controllers/openapi/auth/data.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import uuid +from enum import StrEnum +from typing import Literal + +from pydantic import BaseModel, ConfigDict, Field +from werkzeug.exceptions import InternalServerError + +from configs import dify_config +from libs.oauth_bearer import Scope, TokenType +from models.account import Account, Tenant +from models.model import App, EndUser +from services.enterprise.enterprise_service import WebAppAccessMode + + +class Edition(StrEnum): + CE = "ce" + EE = "ee" + SAAS = "saas" + + +def current_edition() -> Edition: + if dify_config.EDITION == "CLOUD": + return Edition.SAAS + if dify_config.ENTERPRISE_ENABLED: + return Edition.EE + return Edition.CE + + +class ExternalIdentity(BaseModel): + model_config = ConfigDict(frozen=True) + + email: str + issuer: str | None = None + + +class RequestContext(BaseModel): + model_config = ConfigDict(frozen=True) + + token_type: TokenType + scope: Scope | None = None + path_params: dict[str, str] + + +class AuthData(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + required_scope: Scope | None = None + token_type: TokenType + account_id: uuid.UUID | None = None + token_hash: str + token_id: uuid.UUID | None = None + scopes: frozenset[Scope] + tenants: dict[str, bool] = Field(default_factory=dict) + external_identity: ExternalIdentity | None = None + path_params: dict[str, str] = Field(default_factory=dict) + + app: App | None = None + tenant: Tenant | None = None + app_access_mode: WebAppAccessMode | None = None + + caller: Account | EndUser | None = None + caller_kind: Literal["account", "end_user"] | None = None + + def require_app_context(self) -> tuple[App, Account | EndUser, Literal["account", "end_user"]]: + if self.app is None or self.caller is None or 
self.caller_kind is None: + raise InternalServerError("pipeline_invariant_violated: app context missing") + return self.app, self.caller, self.caller_kind diff --git a/api/controllers/openapi/auth/flow.py b/api/controllers/openapi/auth/flow.py new file mode 100644 index 0000000000..eee1378cf4 --- /dev/null +++ b/api/controllers/openapi/auth/flow.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from controllers.openapi.auth.conditions import Cond +from controllers.openapi.auth.data import AuthData, RequestContext + + +class When: + def __init__(self, condition: Cond, *, then: Callable[[Any], None]) -> None: + self.condition = condition + self._step = then + + def applies(self, ctx: RequestContext, data: AuthData | None = None) -> bool: + return self.condition(ctx, data) + + def __call__(self, arg: Any) -> None: + self._step(arg) diff --git a/api/controllers/openapi/auth/pipeline.py b/api/controllers/openapi/auth/pipeline.py index 096b1b7ea3..e992e5e5ab 100644 --- a/api/controllers/openapi/auth/pipeline.py +++ b/api/controllers/openapi/auth/pipeline.py @@ -1,51 +1,209 @@ -"""Pipeline IS the auth scheme. +"""Auth pipeline — entry point for all openapi auth. -`Pipeline.guard(scope=…)` is the only attachment point for endpoints — -that is the design lock-in: forgetting an auth layer is structurally -impossible because there is no "sometimes wrap, sometimes don't" choice. +`PipelineRouter.guard()` is the only attachment point for endpoints. +`AuthPipeline` is a pure step-runner with no routing concerns. +`PipelineRoute` binds a pipeline to optional edition requirements. 
""" from __future__ import annotations +from collections.abc import Callable +from dataclasses import dataclass from functools import wraps +from typing import Any -from flask import request +from flask import current_app, request +from flask_login import user_logged_in +from werkzeug.exceptions import Forbidden, NotFound, Unauthorized -from controllers.openapi.auth.context import Context, Step -from libs.oauth_bearer import Scope, extract_bearer, reset_auth_ctx +from controllers.openapi._audit import emit_wrong_surface +from controllers.openapi.auth.data import ( + AuthData, + Edition, + ExternalIdentity, + RequestContext, + current_edition, +) +from controllers.openapi.auth.flow import When +from libs.oauth_bearer import ( + AuthContext, + Scope, + TokenType, + extract_bearer, + get_authenticator, + reset_auth_ctx, + set_auth_ctx, +) +from services.feature_service import FeatureService, LicenseStatus -class Pipeline: - def __init__(self, *steps: Step) -> None: - self._steps = steps +class AuthPipeline: + """Pure step-runner — no routing, no guard. - def run(self, ctx: Context) -> None: - for step in self._steps: - step(ctx) + Both `prepare` and `auth` steps receive the same `AuthData` instance. + `prepare` steps populate it; `auth` steps validate it. 
+ """ - def guard(self, *, scope: Scope): - def decorator(view): + def __init__(self, prepare: list, auth: list) -> None: + self._prepare = prepare + self._auth = auth + + def _run( + self, + identity: AuthContext, + args: tuple, + kwargs: dict, + view: Callable, + *, + scope: Scope | None, + ) -> Any: + req_ctx = RequestContext( + token_type=identity.token_type, + scope=scope, + path_params=dict(request.view_args or {}), + ) + + data = AuthData( + token_type=identity.token_type, + account_id=identity.account_id, + token_hash=identity.token_hash, + token_id=identity.token_id, + scopes=frozenset(identity.scopes), + tenants=dict(identity.verified_tenants), + required_scope=scope, + path_params=dict(req_ctx.path_params), + external_identity=( + ExternalIdentity(email=identity.subject_email, issuer=identity.subject_issuer) + if identity.subject_email + else None + ), + ) + + for step in self._prepare: + if _should_run(step, req_ctx, data=None): + step(data) + + for step in self._auth: + if _should_run(step, req_ctx, data=data): + step(data) + + reset_token = set_auth_ctx(identity) + if data.caller: + _mount_flask_login(data.caller) + + try: + kwargs["auth_data"] = data + return view(*args, **kwargs) + finally: + reset_auth_ctx(reset_token) + + +@dataclass(frozen=True) +class PipelineRoute: + pipeline: AuthPipeline + required_edition: frozenset[Edition] | None = None + + +class PipelineRouter: + """Entry point for openapi auth. + + `guard()` is the decorator that endpoints attach to. It applies + global gates (edition, token type) then dispatches to the matching + `PipelineRoute` for the token type. 
+ """ + + def __init__(self, routes: dict[TokenType, PipelineRoute]) -> None: + self._routes = routes + + def guard( + self, + *, + scope: Scope | None = None, + allowed_token_types: frozenset[TokenType] | None = None, + edition: frozenset[Edition] | None = None, + ) -> Callable: + def decorator(view: Callable) -> Callable: @wraps(view) - def decorated(*args, **kwargs): - # Extract transport-level inputs at the boundary so steps - # stay decoupled from Flask's request object. - ctx = Context( - required_scope=scope, - bearer_token=extract_bearer(request), - path_params=dict(request.view_args or {}), + def decorated(*args: Any, **kwargs: Any) -> Any: + return self._execute( + args, + kwargs, + view, + scope=scope, + allowed_token_types=allowed_token_types, + edition=edition, ) - try: - self.run(ctx) - kwargs.update( - app_model=ctx.app, - caller=ctx.caller, - caller_kind=ctx.caller_kind, - ) - return view(*args, **kwargs) - finally: - if ctx.auth_ctx_reset_token is not None: - reset_auth_ctx(ctx.auth_ctx_reset_token) return decorated return decorator + + def _execute( + self, + args: tuple, + kwargs: dict, + view: Callable, + *, + scope: Scope | None, + allowed_token_types: frozenset[TokenType] | None, + edition: frozenset[Edition] | None, + ) -> Any: + # 404 not 403 — this edition doesn't expose the feature at all + if edition is not None and current_edition() not in edition: + raise NotFound() + + license_checked = False + if edition is not None and Edition.EE in edition: + _check_license() + license_checked = True + + token = extract_bearer(request) + if not token: + raise Unauthorized("bearer required") + + identity = get_authenticator().authenticate(token) + + if allowed_token_types is not None and identity.token_type not in allowed_token_types: + emit_wrong_surface( + subject_type=_subject_type_str(identity), + attempted_path=request.path, + client_id=getattr(identity, "client_id", None), + token_id=str(identity.token_id) if identity.token_id else None, + ) + 
raise Forbidden("unsupported_token_type") + + route = self._routes.get(identity.token_type) + if route is None: + raise Forbidden("unsupported_token_type") + + if route.required_edition is not None: + if current_edition() not in route.required_edition: + raise Forbidden("external_sso_requires_ee") + if not license_checked and Edition.EE in route.required_edition: + _check_license() + + return route.pipeline._run(identity, args, kwargs, view, scope=scope) + + +def _should_run(step: Any, req_ctx: RequestContext, data: AuthData | None) -> bool: + if isinstance(step, When): + return step.applies(req_ctx, data) + return True + + +def _subject_type_str(identity: Any) -> str | None: + subject = getattr(identity, "subject_type", None) + if subject is None: + return None + return subject.value if hasattr(subject, "value") else str(subject) + + +def _check_license() -> None: + settings = FeatureService.get_system_features() + if settings.license.status in {LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST}: + raise Forbidden("license_invalid") + + +def _mount_flask_login(user: Any) -> None: + current_app.login_manager._update_request_context_with_user(user) # type: ignore[attr-defined] + user_logged_in.send(current_app._get_current_object(), user=user) # type: ignore[attr-defined] diff --git a/api/controllers/openapi/auth/prepare.py b/api/controllers/openapi/auth/prepare.py new file mode 100644 index 0000000000..fe6e031b50 --- /dev/null +++ b/api/controllers/openapi/auth/prepare.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound, Unauthorized + +from controllers.openapi.auth.data import AuthData +from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db +from models.account import TenantStatus +from services.account_service import AccountService, TenantService +from services.app_service import AppService +from services.end_user_service 
import EndUserService +from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode + + +def load_app(data: AuthData) -> None: + app_id = data.path_params["app_id"] + app = AppService.get_app_by_id(db.session, app_id) + if not app or app.status != "normal": + raise NotFound("app not found") + if not app.enable_api: + raise Forbidden("service_api_disabled") + data.app = app + + +def load_tenant(data: AuthData) -> None: + if data.app is None: + raise InternalServerError("pipeline_invariant_violated: app not loaded before load_tenant") + tenant = TenantService.get_tenant_by_id(db.session, str(data.app.tenant_id)) + if tenant is None or tenant.status == TenantStatus.ARCHIVE: + raise Forbidden("workspace unavailable") + data.tenant = tenant + + +def load_account(data: AuthData) -> None: + account = AccountService.get_account_by_id(db.session, str(data.account_id)) + if account is None: + raise Unauthorized("account not found") + if data.tenant: + account.current_tenant = data.tenant + data.caller = account + data.caller_kind = "account" + + +def resolve_external_user(data: AuthData) -> None: + if data.tenant is None or data.app is None or data.external_identity is None: + raise Unauthorized("missing context for external user resolution") + end_user = EndUserService.get_or_create_end_user_by_type( + InvokeFrom.OPENAPI, + tenant_id=str(data.tenant.id), + app_id=str(data.app.id), + user_id=data.external_identity.email, + ) + data.caller = end_user + data.caller_kind = "end_user" + + +def load_app_access_mode(data: AuthData) -> None: + if data.app is None: + return + try: + settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(data.app.id)) + if settings is None: + data.app_access_mode = None + return + data.app_access_mode = WebAppAccessMode(settings.access_mode) + except ValueError: + data.app_access_mode = None diff --git a/api/controllers/openapi/auth/role_gate.py b/api/controllers/openapi/auth/role_gate.py new file mode 
100644 index 0000000000..c5e266f63d --- /dev/null +++ b/api/controllers/openapi/auth/role_gate.py @@ -0,0 +1,77 @@ +"""Workspace role gate. + +Layered on top of `validate_bearer` + `accept_subjects(SubjectType.ACCOUNT)` +for routes whose access depends on the caller's `TenantAccountJoin.role` +in the workspace named by the `workspace_id` path parameter. + +Usage:: + + @openapi_ns.route("/workspaces//members") + class Members(Resource): + @validate_bearer(accept=ACCEPT_USER_ANY) + @accept_subjects(SubjectType.ACCOUNT) + @require_workspace_role() # any member + def get(self, workspace_id: str): ... + + @validate_bearer(accept=ACCEPT_USER_ANY) + @accept_subjects(SubjectType.ACCOUNT) + @require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN) + def post(self, workspace_id: str): ... + +Non-member callers get 404 (matching `GET /openapi/v1/workspaces/`) +so workspace IDs do not leak across tenants. A member without one of the +allowed roles gets 403. +""" + +from __future__ import annotations + +from collections.abc import Callable +from functools import wraps +from typing import TypeVar + +from werkzeug.exceptions import Forbidden, NotFound + +from extensions.ext_database import db +from libs.oauth_bearer import try_get_auth_ctx +from models.account import TenantAccountRole +from services.account_service import TenantService + +F = TypeVar("F", bound=Callable[..., object]) + + +def require_workspace_role(*allowed_roles: TenantAccountRole) -> Callable[[F], F]: + """Gate a route on the caller's role in ``workspace_id``. + + Pass no roles to require only membership. Pass one or more roles to + require the caller's role be in that set. 
+ """ + + allowed = frozenset(allowed_roles) + + def deco(fn: F) -> F: + @wraps(fn) + def wrapper(*args: object, **kwargs: object) -> object: + ctx = try_get_auth_ctx() + if ctx is None or ctx.account_id is None: + raise RuntimeError( + "require_workspace_role called without account-bearer context; " + "stack validate_bearer + accept_subjects(SubjectType.ACCOUNT) above it" + ) + + workspace_id = kwargs.get("workspace_id") + if not workspace_id: + raise RuntimeError("require_workspace_role expects a 'workspace_id' route parameter") + + role = TenantService.get_account_role_in_tenant(db.session, str(ctx.account_id), str(workspace_id)) + + if role is None: + raise NotFound("workspace not found") + + if allowed and role not in allowed: + raise Forbidden("insufficient workspace role") + + return fn(*args, **kwargs) + + return wrapper # type: ignore[return-value] + + return deco diff --git a/api/controllers/openapi/auth/steps.py b/api/controllers/openapi/auth/steps.py deleted file mode 100644 index 40a168b489..0000000000 --- a/api/controllers/openapi/auth/steps.py +++ /dev/null @@ -1,170 +0,0 @@ -"""Pipeline steps. Each is one responsibility. - -`BearerCheck` is the only step that touches the token registry; downstream -steps see only the populated `Context`. `BearerCheck` also publishes the -resolved identity to the openapi auth ``ContextVar`` (the same one the -decorator-level :func:`libs.oauth_bearer.validate_bearer` writes to) so the -surface gate and any handler reading the request-scoped context has a single -source of truth across both auth-attach paths. The reset token is stashed -on `ctx.auth_ctx_reset_token`; `Pipeline.guard` resets the ContextVar in -its `finally` so worker-thread reuse can't leak identity across requests. 
-""" - -from __future__ import annotations - -from collections.abc import Callable - -from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized - -from configs import dify_config -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter -from controllers.openapi.auth.surface_gate import check_surface -from extensions.ext_database import db -from libs.oauth_bearer import ( - AuthContext, - InvalidBearerError, - Scope, - SubjectType, - check_workspace_membership, - get_authenticator, - set_auth_ctx, -) -from models import TenantStatus -from services.account_service import TenantService -from services.app_service import AppService - - -class BearerCheck: - """Resolve bearer → populate identity fields. Rate-limit is enforced - inside `BearerAuthenticator.authenticate`, so no separate step here. - Also publishes the resolved `AuthContext` via - :func:`libs.oauth_bearer.set_auth_ctx` — same shape the decorator-level - ``validate_bearer`` writes — so the surface gate + downstream readers - don't see two different identity sources. 
The reset token is parked on - ``ctx.auth_ctx_reset_token`` for `Pipeline.guard` to consume.""" - - def __call__(self, ctx: Context) -> None: - if not ctx.bearer_token: - raise Unauthorized("bearer required") - - try: - authn = get_authenticator().authenticate(ctx.bearer_token) - except InvalidBearerError as e: - raise Unauthorized(str(e)) - - ctx.subject_type = authn.subject_type - ctx.subject_email = authn.subject_email - ctx.subject_issuer = authn.subject_issuer - ctx.account_id = authn.account_id - ctx.scopes = frozenset(authn.scopes) - ctx.source = authn.source - ctx.token_id = authn.token_id - ctx.expires_at = authn.expires_at - ctx.token_hash = authn.token_hash - ctx.cached_verified_tenants = dict(authn.verified_tenants) - ctx.auth_ctx_reset_token = set_auth_ctx(authn) - - -class ScopeCheck: - """Verify ctx.scopes (already populated by BearerCheck) covers required.""" - - def __call__(self, ctx: Context) -> None: - if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes: - return - raise Forbidden("insufficient_scope") - - -class SurfaceCheck: - """Reject the request if the resolved subject is not in `accepted`.""" - - def __init__(self, *, accepted: frozenset[SubjectType]) -> None: - self._accepted = accepted - - def __call__(self, ctx: Context) -> None: - check_surface(self._accepted) - - -class AppResolver: - """Read ``app_id`` from ``ctx.path_params``; populate ctx.app + ctx.tenant. - - Every endpoint using the OAuth bearer pipeline must declare - ```` in its route — that is the design lock-in (no body / - header coupling). ``Pipeline.guard`` lifts ``request.view_args`` into - ``ctx.path_params`` at the boundary so this step doesn't need to know - about the request object. 
- """ - - def __call__(self, ctx: Context) -> None: - app_id = ctx.path_params.get("app_id") - if not app_id: - raise BadRequest("app_id is required in path") - app = AppService.get_app_by_id(db.session, app_id) - if not app or app.status != "normal": - raise NotFound("app not found") - if not app.enable_api: - raise Forbidden("service_api_disabled") - tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id)) - if tenant is None or tenant.status == TenantStatus.ARCHIVE: - raise Forbidden("workspace unavailable") - ctx.app, ctx.tenant = app, tenant - - -class WorkspaceMembershipCheck: - """Layer 0 — workspace membership gate. - - CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers - (dfoa_) only — SSO subjects skip. - """ - - def __call__(self, ctx: Context) -> None: - if dify_config.ENTERPRISE_ENABLED: - return - if ctx.subject_type != SubjectType.ACCOUNT: - return - if ctx.account_id is None or ctx.tenant is None: - raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run") - if ctx.token_hash is None: - raise Unauthorized("token_hash unset — BearerCheck did not run") - - check_workspace_membership( - account_id=ctx.account_id, - tenant_id=ctx.must_tenant.id, - token_hash=ctx.token_hash, - cached_verdicts=ctx.cached_verified_tenants or {}, - ) - - -class AppAuthzCheck: - def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None: - self._resolve = resolve_strategy - - def __call__(self, ctx: Context) -> None: - if not self._resolve().authorize(ctx): - raise Forbidden("subject_no_app_access") - - -class CallerMount: - def __init__(self, *mounters: CallerMounter) -> None: - self._mounters = mounters - - def __call__(self, ctx: Context) -> None: - if ctx.subject_type is None: - raise Unauthorized("subject_type unset — BearerCheck did not run") - for m in self._mounters: - if m.applies_to(ctx.must_subject_type): - m.mount(ctx) - return - raise Unauthorized("no caller mounter for subject 
type") - - -__all__ = [ - "AppAuthzCheck", - "AppResolver", - "AuthContext", - "BearerCheck", - "CallerMount", - "ScopeCheck", - "SurfaceCheck", - "WorkspaceMembershipCheck", -] diff --git a/api/controllers/openapi/auth/strategies.py b/api/controllers/openapi/auth/strategies.py deleted file mode 100644 index aaaaadd948..0000000000 --- a/api/controllers/openapi/auth/strategies.py +++ /dev/null @@ -1,168 +0,0 @@ -"""Strategy classes for the openapi auth pipeline. - -App authorization (Acl/Membership) and caller mounting (Account/EndUser) -vary along independent axes; each strategy is one class so the pipeline -composition stays a flat list. -""" - -from __future__ import annotations - -from typing import Protocol - -from flask import current_app -from flask_login import user_logged_in - -from controllers.openapi.auth.context import Context -from core.app.entities.app_invoke_entities import InvokeFrom -from extensions.ext_database import db -from libs.oauth_bearer import SubjectType -from services.account_service import AccountService, TenantService -from services.end_user_service import EndUserService -from services.enterprise.enterprise_service import ( - EnterpriseService, - WebAppAccessMode, -) - - -class AppAuthzStrategy(Protocol): - def authorize(self, ctx: Context) -> bool: ... - - -class AclStrategy: - """Per-app ACL, evaluated in two stages. - - The EE gateway has already enforced tenancy and workspace membership - by the time this strategy runs, so AclStrategy only owns per-app ACL: - - 1. Subject vs access-mode compatibility (pure rule table). External-SSO - bearers belong to public-facing apps only; account bearers cover the - full set. A mismatch is an immediate deny — no IO. - 2. For modes that pair with the subject, decide whether the inner - permission API must run. Only `PRIVATE` (per-app selected-user list) - requires it; the remaining modes are pass-through. 
- """ - - _ALLOWED_MODES_BY_SUBJECT: dict[SubjectType, frozenset[WebAppAccessMode]] = { - SubjectType.ACCOUNT: frozenset( - { - WebAppAccessMode.PUBLIC, - WebAppAccessMode.SSO_VERIFIED, - WebAppAccessMode.PRIVATE_ALL, - WebAppAccessMode.PRIVATE, - } - ), - SubjectType.EXTERNAL_SSO: frozenset( - { - WebAppAccessMode.PUBLIC, - WebAppAccessMode.SSO_VERIFIED, - } - ), - } - - _MODES_REQUIRING_INNER_CHECK: frozenset[WebAppAccessMode] = frozenset({WebAppAccessMode.PRIVATE}) - - def authorize(self, ctx: Context) -> bool: - if ctx.app is None: - return False - access_mode = self._fetch_access_mode(ctx.app.id) - if access_mode is None: - return False - if not self._subject_allowed_for_mode(ctx.must_subject_type, access_mode): - return False - if access_mode not in self._MODES_REQUIRING_INNER_CHECK: - return True - return self._inner_permission_check(ctx) - - @staticmethod - def _fetch_access_mode(app_id: str) -> WebAppAccessMode | None: - settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id) - if settings is None: - return None - try: - return WebAppAccessMode(settings.access_mode) - except ValueError: - return None - - @classmethod - def _subject_allowed_for_mode(cls, subject_type: SubjectType, access_mode: WebAppAccessMode) -> bool: - return access_mode in cls._ALLOWED_MODES_BY_SUBJECT.get(subject_type, frozenset()) - - def _inner_permission_check(self, ctx: Context) -> bool: - if ctx.app is None: - return False - user_id = self._resolve_user_id(ctx) - if user_id is None: - return False - return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( - user_id=user_id, - app_id=ctx.app.id, - ) - - @staticmethod - def _resolve_user_id(ctx: Context) -> str | None: - if ctx.subject_type == SubjectType.ACCOUNT: - return str(ctx.account_id) if ctx.account_id is not None else None - if ctx.subject_email is None: - return None - account = AccountService.get_account_by_email(db.session, ctx.subject_email) - return str(account.id) if account is 
not None else None - - -class MembershipStrategy: - """Tenant-membership fallback. - - Used when webapp-auth is disabled (CE deployment). Account-bearing - subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is - denied (it requires the webapp-auth surface). - """ - - def authorize(self, ctx: Context) -> bool: - if ctx.subject_type == SubjectType.EXTERNAL_SSO: - return False - if ctx.tenant is None: - return False - return TenantService.account_belongs_to_tenant(db.session, ctx.account_id, ctx.tenant.id) - - -def _login_as(user) -> None: - """Set Flask-Login request user so downstream services see the caller.""" - current_app.login_manager._update_request_context_with_user(user) # type:ignore - user_logged_in.send(current_app._get_current_object(), user=user) # type:ignore - - -class CallerMounter(Protocol): - def applies_to(self, subject_type: SubjectType) -> bool: ... - - def mount(self, ctx: Context) -> None: ... - - -class AccountMounter: - def applies_to(self, subject_type: SubjectType) -> bool: - return subject_type == SubjectType.ACCOUNT - - def mount(self, ctx: Context) -> None: - if ctx.account_id is None: - raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run") - account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) - if account is None: - raise RuntimeError("AccountMounter: account row missing for resolved bearer") - account.current_tenant = ctx.must_tenant - _login_as(account) - ctx.caller, ctx.caller_kind = account, "account" - - -class EndUserMounter: - def applies_to(self, subject_type: SubjectType) -> bool: - return subject_type == SubjectType.EXTERNAL_SSO - - def mount(self, ctx: Context) -> None: - if ctx.tenant is None or ctx.app is None or ctx.subject_email is None: - raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run") - end_user = EndUserService.get_or_create_end_user_by_type( - InvokeFrom.OPENAPI, - tenant_id=ctx.tenant.id, - 
app_id=ctx.app.id, - user_id=ctx.subject_email, - ) - _login_as(end_user) - ctx.caller, ctx.caller_kind = end_user, "end_user" diff --git a/api/controllers/openapi/auth/verify.py b/api/controllers/openapi/auth/verify.py new file mode 100644 index 0000000000..22410b3374 --- /dev/null +++ b/api/controllers/openapi/auth/verify.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from werkzeug.exceptions import Forbidden, Unauthorized + +from controllers.openapi.auth.data import AuthData +from extensions.ext_database import db +from libs.oauth_bearer import Scope, TokenType, check_workspace_membership +from services.account_service import AccountService, TenantService +from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode + + +def check_scope(data: AuthData) -> None: + if data.required_scope is None: + return + if Scope.FULL in data.scopes or data.required_scope in data.scopes: + return + raise Forbidden("insufficient_scope") + + +def check_membership(data: AuthData) -> None: + if data.tenant is None: + raise Unauthorized("tenant unset") + if data.account_id is None: + raise Unauthorized("account_id unset") + check_workspace_membership( + account_id=data.account_id, + tenant_id=data.tenant.id, + token_hash=data.token_hash, + membership_cache=data.tenants, + ) + + +def check_app_access(data: AuthData) -> None: + if data.tenant is None: + return + if not TenantService.account_belongs_to_tenant(db.session, data.account_id, data.tenant.id): + raise Forbidden("subject_no_app_access") + + +_ALLOWED_MODES_BY_TOKEN_TYPE: dict[TokenType, frozenset[WebAppAccessMode]] = { + TokenType.OAUTH_ACCOUNT: frozenset( + { + WebAppAccessMode.PUBLIC, + WebAppAccessMode.SSO_VERIFIED, + WebAppAccessMode.PRIVATE_ALL, + WebAppAccessMode.PRIVATE, + } + ), + TokenType.OAUTH_EXTERNAL_SSO: frozenset( + { + WebAppAccessMode.PUBLIC, + WebAppAccessMode.SSO_VERIFIED, + } + ), +} + + +def check_acl(data: AuthData) -> None: + if data.app is None or 
data.app_access_mode is None: + raise Forbidden("app or access mode not loaded") + allowed_modes = _ALLOWED_MODES_BY_TOKEN_TYPE.get(data.token_type, frozenset()) + if data.app_access_mode not in allowed_modes: + raise Forbidden("subject_not_allowed_for_access_mode") + + +def check_private_app_permission(data: AuthData) -> None: + if data.app is None: + raise Forbidden("app not loaded") + user_id = _resolve_user_id(data) + if user_id is None: + raise Forbidden("cannot resolve user for private app check") + if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id=user_id, app_id=data.app.id): + raise Forbidden("user_not_allowed_for_private_app") + + +def _resolve_user_id(data: AuthData) -> str | None: + if data.token_type == TokenType.OAUTH_ACCOUNT: + return str(data.account_id) if data.account_id is not None else None + if data.external_identity is None: + return None + account = AccountService.get_account_by_email(db.session, data.external_identity.email) + return str(account.id) if account is not None else None diff --git a/api/controllers/openapi/files.py b/api/controllers/openapi/files.py index eb16015821..1a2c16abf9 100644 --- a/api/controllers/openapi/files.py +++ b/api/controllers/openapi/files.py @@ -17,11 +17,11 @@ from controllers.common.errors import ( UnsupportedFileTypeError, ) from controllers.openapi import openapi_ns -from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE +from controllers.openapi.auth.composition import auth_router +from controllers.openapi.auth.data import AuthData from extensions.ext_database import db from fields.file_fields import FileResponse from libs.oauth_bearer import Scope -from models import Account, App from services.file_service import FileService @@ -39,8 +39,9 @@ class AppFileUploadApi(Resource): } ) @openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__]) - @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN) - def post(self, app_id: str, 
app_model: App, caller: Account, caller_kind: str): + @auth_router.guard(scope=Scope.APPS_RUN) + def post(self, app_id: str, *, auth_data: AuthData): + app_model, caller, _ = auth_data.require_app_context() if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: diff --git a/api/controllers/openapi/human_input_form.py b/api/controllers/openapi/human_input_form.py index 7d54140efd..3c359406be 100644 --- a/api/controllers/openapi/human_input_form.py +++ b/api/controllers/openapi/human_input_form.py @@ -17,7 +17,8 @@ from werkzeug.exceptions import BadRequest, NotFound from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values from controllers.common.schema import register_schema_models from controllers.openapi import openapi_ns -from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE +from controllers.openapi.auth.composition import auth_router +from controllers.openapi.auth.data import AuthData from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface from extensions.ext_database import db from libs.helper import to_timestamp @@ -55,8 +56,9 @@ def _ensure_form_is_allowed_for_openapi(form) -> None: @openapi_ns.route("/apps//form/human_input/") class OpenApiWorkflowHumanInputFormApi(Resource): @openapi_ns.response(200, "Form definition") - @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN) - def get(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str): + @auth_router.guard(scope=Scope.APPS_RUN) + def get(self, app_id: str, form_token: str, *, auth_data: AuthData): + app_model, caller, caller_kind = auth_data.require_app_context() service = HumanInputService(db.engine) form = service.get_form_by_token(form_token) if form is None: @@ -69,8 +71,9 @@ class OpenApiWorkflowHumanInputFormApi(Resource): @openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__]) @openapi_ns.response(200, "Form submitted") 
- @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN) - def post(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str): + @auth_router.guard(scope=Scope.APPS_RUN) + def post(self, app_id: str, form_token: str, *, auth_data: AuthData): + app_model, caller, caller_kind = auth_data.require_app_context() payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {}) service = HumanInputService(db.engine) diff --git a/api/controllers/openapi/workflow_events.py b/api/controllers/openapi/workflow_events.py index b14b2d400f..f21306e491 100644 --- a/api/controllers/openapi/workflow_events.py +++ b/api/controllers/openapi/workflow_events.py @@ -17,7 +17,8 @@ from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import NotFound, UnprocessableEntity from controllers.openapi import openapi_ns -from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE +from controllers.openapi.auth.composition import auth_router +from controllers.openapi.auth.data import AuthData from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -28,7 +29,7 @@ from core.workflow.human_input_policy import HumanInputSurface from extensions.ext_database import db from libs.oauth_bearer import Scope from models.enums import CreatorUserRole -from models.model import App, AppMode +from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory from services.workflow_event_snapshot_service import build_workflow_event_stream @@ -36,8 +37,9 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream @openapi_ns.route("/apps//tasks//events") class OpenApiWorkflowEventsApi(Resource): @openapi_ns.response(200, "SSE event stream") - @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN) - def get(self, app_id: str, task_id: str, 
app_model: App, caller, caller_kind: str): + @auth_router.guard(scope=Scope.APPS_RUN) + def get(self, app_id: str, task_id: str, *, auth_data: AuthData): + app_model, caller, caller_kind = auth_data.require_app_context() app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}: raise UnprocessableEntity("mode_not_supported_for_event_reconnect") diff --git a/api/controllers/openapi/workspaces.py b/api/controllers/openapi/workspaces.py index 5fc1e1178d..b23012a810 100644 --- a/api/controllers/openapi/workspaces.py +++ b/api/controllers/openapi/workspaces.py @@ -1,41 +1,129 @@ -"""User-scoped workspace reads under /openapi/v1/workspaces. Bearer-authed -counterparts to the cookie-authed /console/api/workspaces endpoints. +"""User-scoped workspace reads and member management under /openapi/v1/workspaces. -Account bearers (dfoa_) see every tenant they're a member of. External -SSO bearers (dfoe_) have no account_id and so see an empty list — that -matches /openapi/v1/account. +Bearer-authed counterparts to the cookie-authed /console/api/workspaces +endpoints. Account bearers (dfoa_) see every tenant they're a member of. +External SSO bearers (dfoe_) have no account_id and so see an empty list — +that matches /openapi/v1/account. + +Member-management endpoints are gated by both `accept_subjects` (SSO out) +and `require_workspace_role` (membership / role lookup against the path's +``workspace_id``). 
""" from __future__ import annotations from itertools import starmap +from urllib import parse +from flask import jsonify, make_response, request from flask_restx import Resource -from werkzeug.exceptions import NotFound +from pydantic import BaseModel, ValidationError +from werkzeug.exceptions import BadRequest, Forbidden, NotFound +from configs import dify_config +from controllers.common.schema import query_params_from_model from controllers.openapi import openapi_ns -from controllers.openapi._models import WorkspaceDetailResponse, WorkspaceListResponse, WorkspaceSummaryResponse -from controllers.openapi.auth.surface_gate import accept_subjects -from extensions.ext_database import db -from libs.oauth_bearer import ( - ACCEPT_USER_ANY, - SubjectType, - get_auth_ctx, - validate_bearer, +from controllers.openapi._models import ( + MemberActionResponse, + MemberInvitePayload, + MemberInviteResponse, + MemberListQuery, + MemberListResponse, + MemberResponse, + MemberRoleUpdatePayload, + WorkspaceDetailResponse, + WorkspaceListResponse, + WorkspaceSummaryResponse, ) -from models import Tenant, TenantAccountJoin -from services.account_service import TenantService +from controllers.openapi.auth.composition import auth_router +from controllers.openapi.auth.data import AuthData +from controllers.openapi.auth.role_gate import require_workspace_role +from extensions.ext_database import db +from libs.oauth_bearer import Scope, TokenType +from models import Account, Tenant, TenantAccountJoin +from models.account import TenantAccountRole, TenantStatus +from services.account_service import AccountService, RegisterService, TenantService +from services.errors.account import ( + AccountAlreadyInTenantError, + AccountNotLinkTenantError, + AccountRegisterError, + CannotOperateSelfError, + MemberNotInTenantError, + NoPermissionError, + RoleAlreadyAssignedError, +) +from services.feature_service import FeatureService + + +def _validate_body[M: BaseModel](model: type[M]) -> M: + body = 
request.get_json(silent=True) or {} + try: + return model.model_validate(body) + except ValidationError as exc: + raise BadRequest(str(exc)) + + +def _member_response(account: Account) -> MemberResponse: + return MemberResponse( + id=str(account.id), + name=account.name, + email=account.email, + role=account.role.value if account.role else "", + status=account.status.value if account.status else "", + avatar=account.avatar, + ) + + +def _load_tenant(workspace_id: str) -> Tenant: + tenant = TenantService.get_tenant_by_id(db.session, workspace_id) + if tenant is None or tenant.status != TenantStatus.NORMAL: + raise NotFound("workspace not found") + return tenant + + +def _load_account(account_id: object) -> Account: + account = AccountService.get_account_by_id(db.session, str(account_id)) if account_id else None + if account is None: + raise RuntimeError("authenticated account_id has no Account row") + return account + + +def _quota_error(*, code: str, message: str, hint: str) -> Forbidden: + err = Forbidden(message) + err.response = make_response( + jsonify({"code": code, "message": message, "hint": hint}), + 403, + ) + return err + + +def _check_member_invite_quota(tenant_id: str) -> None: + features = FeatureService.get_features(tenant_id) + + if features.billing.enabled: + members = features.members + if 0 < members.limit <= members.size: + raise _quota_error( + code="members.limit_exceeded", + message="Subscription member limit reached.", + hint="Upgrade your plan to invite more members or remove an existing member first.", + ) + + if features.workspace_members.enabled: + if not features.workspace_members.is_available(1): + raise _quota_error( + code="workspace_members.license_exceeded", + message="Workspace member license capacity reached.", + hint="Contact your workspace administrator to expand the license seat count.", + ) @openapi_ns.route("/workspaces") class WorkspacesApi(Resource): @openapi_ns.response(200, "Workspace list", 
openapi_ns.models[WorkspaceListResponse.__name__]) - @validate_bearer(accept=ACCEPT_USER_ANY) - @accept_subjects(SubjectType.ACCOUNT) - def get(self): - ctx = get_auth_ctx() - - rows = TenantService.get_workspaces_for_account(db.session, str(ctx.account_id)) + @auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def get(self, *, auth_data: AuthData): + rows = TenantService.get_workspaces_for_account(db.session, str(auth_data.account_id)) return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200 @@ -43,12 +131,9 @@ class WorkspacesApi(Resource): @openapi_ns.route("/workspaces/") class WorkspaceByIdApi(Resource): @openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__]) - @validate_bearer(accept=ACCEPT_USER_ANY) - @accept_subjects(SubjectType.ACCOUNT) - def get(self, workspace_id: str): - ctx = get_auth_ctx() - - row = TenantService.find_workspace_for_account(db.session, str(ctx.account_id), workspace_id) + @auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def get(self, workspace_id: str, *, auth_data: AuthData): + row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id) # 404 (not 403) on non-member so workspace IDs don't leak across tenants. if row is None: raise NotFound("workspace not found") @@ -57,6 +142,172 @@ class WorkspaceByIdApi(Resource): return _workspace_detail(tenant, membership).model_dump(mode="json"), 200 +@openapi_ns.route("/workspaces//switch") +class WorkspaceSwitchApi(Resource): + """Server-side switch — equivalent to the console's POST /workspaces/switch. + + CLI `difyctl use workspace ` calls this; it does NOT mutate + ``hosts.yml`` on its own. Failure here must abort the local write so + that ``hosts.yml`` never diverges from the server's ``current`` state. 
+ """ + + @openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__]) + @auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + @require_workspace_role() + def post(self, workspace_id: str, *, auth_data: AuthData): + account = _load_account(auth_data.account_id) + + try: + TenantService.switch_tenant(account, workspace_id) + except AccountNotLinkTenantError: + raise NotFound("workspace not found") + + row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id) + if row is None: + raise NotFound("workspace not found") + tenant, membership = row + return _workspace_detail(tenant, membership).model_dump(mode="json"), 200 + + +@openapi_ns.route("/workspaces//members") +class WorkspaceMembersApi(Resource): + """List + invite members. + + GET is any-member. POST requires admin/owner — owner can never be + assigned through invite (ownership transfer is console-only). 
+ """ + + @openapi_ns.doc(params=query_params_from_model(MemberListQuery)) + @openapi_ns.response(200, "Member list", openapi_ns.models[MemberListResponse.__name__]) + @auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + @require_workspace_role() + def get(self, workspace_id: str, *, auth_data: AuthData): + try: + query = MemberListQuery.model_validate(request.args.to_dict(flat=True)) + except ValidationError as exc: + raise BadRequest(str(exc)) + + tenant = _load_tenant(workspace_id) + members = TenantService.get_tenant_members(tenant) + total = len(members) + start = (query.page - 1) * query.limit + page_items = members[start : start + query.limit] + return MemberListResponse( + page=query.page, + limit=query.limit, + total=total, + has_more=query.page * query.limit < total, + data=[_member_response(m) for m in page_items], + ).model_dump(mode="json"), 200 + + @openapi_ns.expect(openapi_ns.models[MemberInvitePayload.__name__]) + @openapi_ns.response(201, "Member invited", openapi_ns.models[MemberInviteResponse.__name__]) + @auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + @require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN) + def post(self, workspace_id: str, *, auth_data: AuthData): + payload = _validate_body(MemberInvitePayload) + inviter = _load_account(auth_data.account_id) + tenant = _load_tenant(workspace_id) + + _check_member_invite_quota(str(tenant.id)) + + try: + token = RegisterService.invite_new_member( + tenant=tenant, + email=payload.email, + language=None, + role=payload.role, + inviter=inviter, + ) + except AccountAlreadyInTenantError as exc: + raise BadRequest(str(exc)) + except NoPermissionError as exc: + raise BadRequest(str(exc)) + except AccountRegisterError as exc: + raise BadRequest(str(exc)) + + normalized_email = payload.email.lower() + member = 
AccountService.get_account_by_email_with_case_fallback(normalized_email) + if member is None: + # invite_new_member just created or fetched this account. + raise RuntimeError("invited member missing from DB after invite") + + encoded_email = parse.quote(normalized_email) + invite_url = f"{dify_config.CONSOLE_WEB_URL}/activate?email={encoded_email}&token={token}" + return MemberInviteResponse( + email=normalized_email, + role=payload.role, + member_id=str(member.id), + invite_url=invite_url, + tenant_id=str(tenant.id), + ).model_dump(mode="json"), 201 + + +@openapi_ns.route("/workspaces//members/") +class WorkspaceMemberApi(Resource): + """Remove a member. + + Self-removal and owner-removal are explicitly rejected by the service + layer (CannotOperateSelfError, NoPermissionError) — both surface as + 400 per the spec, with the service's message preserved. + """ + + @openapi_ns.response(200, "Member removed", openapi_ns.models[MemberActionResponse.__name__]) + @auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + @require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN) + def delete(self, workspace_id: str, member_id: str, *, auth_data: AuthData): + operator = _load_account(auth_data.account_id) + tenant = _load_tenant(workspace_id) + member = AccountService.get_account_by_id(db.session, member_id) + if member is None: + raise NotFound("member not found") + + try: + TenantService.remove_member_from_tenant(tenant, member, operator) + except CannotOperateSelfError as exc: + raise BadRequest(str(exc)) + except NoPermissionError as exc: + raise BadRequest(str(exc)) + except MemberNotInTenantError as exc: + raise NotFound(str(exc)) + + return MemberActionResponse().model_dump(mode="json"), 200 + + +@openapi_ns.route("/workspaces//members//role") +class WorkspaceMemberRoleApi(Resource): + """Change a member's role. + + Owner cannot be assigned here (closed enum). 
Admin cannot demote the + standing owner (service NoPermissionError → 400, per spec). + """ + + @openapi_ns.expect(openapi_ns.models[MemberRoleUpdatePayload.__name__]) + @openapi_ns.response(200, "Role updated", openapi_ns.models[MemberActionResponse.__name__]) + @auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + @require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN) + def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData): + payload = _validate_body(MemberRoleUpdatePayload) + operator = _load_account(auth_data.account_id) + tenant = _load_tenant(workspace_id) + member = AccountService.get_account_by_id(db.session, member_id) + if member is None: + raise NotFound("member not found") + + try: + TenantService.update_member_role(tenant, member, payload.role, operator) + except CannotOperateSelfError as exc: + raise BadRequest(str(exc)) + except NoPermissionError as exc: + raise BadRequest(str(exc)) + except MemberNotInTenantError as exc: + raise NotFound(str(exc)) + except RoleAlreadyAssignedError as exc: + raise BadRequest(str(exc)) + + return MemberActionResponse().model_dump(mode="json"), 200 + + def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceSummaryResponse: return WorkspaceSummaryResponse( id=str(tenant.id), diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index ba914c4dd4..55a1c47c42 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -1,11 +1,14 @@ from uuid import UUID -from controllers.common.schema import register_schema_model +from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload from controllers.service_api import service_api_ns from 
controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check +from fields.hit_testing_fields import HitTestingResponse +from libs.helper import dump_response -register_schema_model(service_api_ns, HitTestingPayload) +register_schema_models(service_api_ns, HitTestingPayload) +register_response_schema_models(service_api_ns, HitTestingResponse) @service_api_ns.route("/datasets//hit-testing", "/datasets//retrieve") @@ -13,16 +16,16 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): @service_api_ns.doc("dataset_hit_testing") @service_api_ns.doc(description="Perform hit testing on a dataset") @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) - @service_api_ns.doc( - responses={ - 200: "Hit testing results", - 401: "Unauthorized - invalid API token", - 404: "Dataset not found", - } + @service_api_ns.response( + 200, + "Hit testing results", + model=service_api_ns.models[HitTestingResponse.__name__], ) + @service_api_ns.response(401, "Unauthorized - invalid API token") + @service_api_ns.response(404, "Dataset not found") @service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__]) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id: UUID): + def post(self, tenant_id: str, dataset_id: UUID) -> dict[str, object]: """Perform hit testing on a dataset. Tests retrieval performance for the specified dataset. 
@@ -33,4 +36,4 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): args = self.parse_args(service_api_ns.payload) self.hit_testing_args_check(args) - return self.perform_hit_testing(dataset, args) + return dump_response(HitTestingResponse, self.perform_hit_testing(dataset, args)) diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index a019c99a28..66d19fba44 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -13,6 +13,7 @@ from pydantic import BaseModel from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound, Unauthorized +from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_redis import redis_client diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index b4b5c5134d..e73f631000 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -12,7 +12,7 @@ from controllers.common.schema import register_response_schema_models, register_ from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from libs.passport import PassportService from libs.token import extract_webapp_passport -from models.model import App, AppMode +from models.model import App, AppMode, EndUser from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService @@ -56,7 +56,7 @@ class AppParameterApi(WebApiResource): 500: "Internal Server Error", } ) - def get(self, app_model: App, end_user): + def get(self, app_model: App, end_user: EndUser): """Retrieve app parameters.""" if not app_model.enable_site: raise BadRequest("Site is disabled.") @@ -95,7 +95,7 @@ class AppMeta(WebApiResource): 500: "Internal Server Error", } ) - def get(self, app_model: App, end_user): + def get(self, app_model: App, end_user: EndUser): """Get app meta""" 
return AppService().get_app_meta(app_model) diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index f08d08ab7d..258493303f 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -29,7 +29,7 @@ from core.errors.error import ( from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value -from models.model import AppMode +from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService from services.errors.llm import InvokeRateLimitError @@ -86,7 +86,7 @@ class CompletionApi(WebApiResource): 500: "Internal Server Error", } ) - def post(self, app_model, end_user): + def post(self, app_model: App, end_user: EndUser): if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() @@ -140,7 +140,7 @@ class CompletionStopApi(WebApiResource): } ) @web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__]) - def post(self, app_model, end_user, task_id: str): + def post(self, app_model: App, end_user: EndUser, task_id: str): if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() @@ -169,7 +169,7 @@ class ChatApi(WebApiResource): 500: "Internal Server Error", } ) - def post(self, app_model, end_user): + def post(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -226,7 +226,7 @@ class ChatStopApi(WebApiResource): } ) @web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__]) - def post(self, app_model, end_user, task_id: str): + def post(self, app_model: App, end_user: EndUser, task_id: str): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() diff --git 
a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 00db29a606..7803b11f4e 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -19,7 +19,7 @@ from fields.conversation_fields import ( SimpleConversation, ) from libs.helper import uuid_value -from models.model import AppMode +from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.web_conversation_service import WebConversationService @@ -81,7 +81,7 @@ class ConversationListApi(WebApiResource): 500: "Internal Server Error", } ) - def get(self, app_model, end_user): + def get(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -127,7 +127,7 @@ class ConversationApi(WebApiResource): 500: "Internal Server Error", } ) - def delete(self, app_model, end_user, c_id: UUID): + def delete(self, app_model: App, end_user: EndUser, c_id: UUID): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -166,7 +166,7 @@ class ConversationRenameApi(WebApiResource): 500: "Internal Server Error", } ) - def post(self, app_model, end_user, c_id: UUID): + def post(self, app_model: App, end_user: EndUser, c_id: UUID): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -204,7 +204,7 @@ class ConversationPinApi(WebApiResource): } ) @web_ns.response(200, "Conversation pinned successfully", web_ns.models[ResultResponse.__name__]) - def patch(self, app_model, end_user, c_id: UUID): + def patch(self, app_model: App, end_user: EndUser, c_id: UUID): app_mode = 
AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -235,7 +235,7 @@ class ConversationUnPinApi(WebApiResource): } ) @web_ns.response(200, "Conversation unpinned successfully", web_ns.models[ResultResponse.__name__]) - def patch(self, app_model, end_user, c_id: UUID): + def patch(self, app_model: App, end_user: EndUser, c_id: UUID): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index 6128490104..e08a337364 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -13,6 +13,7 @@ from controllers.web import web_ns from controllers.web.wraps import WebApiResource from extensions.ext_database import db from fields.file_fields import FileResponse +from models.model import App, EndUser from services.file_service import FileService register_schema_models(web_ns, FileResponse) @@ -31,7 +32,7 @@ class FileApi(WebApiResource): } ) @web_ns.response(201, "File uploaded successfully", web_ns.models[FileResponse.__name__]) - def post(self, app_model, end_user): + def post(self, app_model: App, end_user: EndUser): """Upload a file for use in web applications. 
Accepts file uploads for use within web applications, supporting diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index cf0363b66e..ee58433679 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -27,7 +27,7 @@ from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfinite from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from models.enums import FeedbackRating -from models.model import AppMode +from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError @@ -81,7 +81,7 @@ class MessageListApi(WebApiResource): 500: "Internal Server Error", } ) - def get(self, app_model, end_user): + def get(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -133,7 +133,7 @@ class MessageFeedbackApi(WebApiResource): } ) @web_ns.response(200, "Feedback submitted successfully", web_ns.models[ResultResponse.__name__]) - def post(self, app_model, end_user, message_id: UUID): + def post(self, app_model: App, end_user: EndUser, message_id: UUID): message_id_str = str(message_id) payload = MessageFeedbackPayload.model_validate(web_ns.payload or {}) @@ -167,7 +167,7 @@ class MessageMoreLikeThisApi(WebApiResource): 500: "Internal Server Error", } ) - def get(self, app_model, end_user, message_id: UUID): + def get(self, app_model: App, end_user: EndUser, message_id: UUID): if app_model.mode != "completion": raise NotCompletionAppError() @@ -223,7 +223,7 @@ class MessageSuggestedQuestionApi(WebApiResource): 500: "Internal Server Error", } ) - def get(self, app_model, end_user, message_id: UUID): + def get(self, app_model: App, end_user: EndUser, message_id: UUID): app_mode = 
AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index e9f727097b..c18c05d3e9 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -13,6 +13,7 @@ from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from graphon.file import helpers as file_helpers +from models.model import App, EndUser from services.file_service import FileService from ..common.schema import register_response_schema_models, register_schema_models @@ -41,7 +42,7 @@ class RemoteFileInfoApi(WebApiResource): } ) @web_ns.response(200, "Remote file info", web_ns.models[RemoteFileInfo.__name__]) - def get(self, app_model, end_user, url): + def get(self, app_model: App, end_user: EndUser, url: str): """Get information about a remote file. Retrieves basic information about a file located at a remote URL, @@ -85,7 +86,7 @@ class RemoteFileUploadApi(WebApiResource): } ) @web_ns.response(201, "Remote file uploaded", web_ns.models[FileWithSignedUrl.__name__]) - def post(self, app_model, end_user): + def post(self, app_model: App, end_user: EndUser): """Upload a file from a remote URL. 
Downloads a file from the provided remote URL and uploads it diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 766cfc6c60..7ce72e56ab 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -11,6 +11,7 @@ from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource from fields.conversation_fields import ResultResponse from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem +from models.model import App, EndUser from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService @@ -43,7 +44,7 @@ class SavedMessageListApi(WebApiResource): 500: "Internal Server Error", } ) - def get(self, app_model, end_user): + def get(self, app_model: App, end_user: EndUser): if app_model.mode != "completion": raise NotCompletionAppError() @@ -77,7 +78,7 @@ class SavedMessageListApi(WebApiResource): } ) @web_ns.response(200, "Message saved successfully", web_ns.models[ResultResponse.__name__]) - def post(self, app_model, end_user): + def post(self, app_model: App, end_user: EndUser): if app_model.mode != "completion": raise NotCompletionAppError() @@ -106,7 +107,7 @@ class SavedMessageApi(WebApiResource): 500: "Internal Server Error", } ) - def delete(self, app_model, end_user, message_id: UUID): + def delete(self, app_model: App, end_user: EndUser, message_id: UUID): message_id_str = str(message_id) if app_model.mode != "completion": diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index bd21632b05..19b04b7acc 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -10,7 +10,7 @@ from controllers.web.wraps import WebApiResource from extensions.ext_database import db from libs.helper import AppIconUrlField from models.account import TenantStatus -from models.model import App, Site +from models.model import App, EndUser, Site 
from services.feature_service import FeatureService @@ -70,7 +70,7 @@ class AppSiteApi(WebApiResource): } ) @marshal_with(app_fields) - def get(self, app_model, end_user): + def get(self, app_model: App, end_user: EndUser): """Retrieve app site info.""" # get site site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) @@ -78,7 +78,7 @@ class AppSiteApi(WebApiResource): if not site: raise Forbidden() - if app_model.tenant.status == TenantStatus.ARCHIVE: + if app_model.tenant and app_model.tenant.status == TenantStatus.ARCHIVE: raise Forbidden() can_replace_logo = FeatureService.get_features(app_model.tenant_id, exclude_vector_space=True).can_replace_logo diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index cba4659483..694d633148 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -22,9 +22,6 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.utils.extract_thread_messages import extract_thread_messages from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ( - ToolParameter, -) from core.tools.tool_manager import ToolManager from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db @@ -150,44 +147,9 @@ class BaseAgentRunner(AppRunner): message_tool = PromptMessageTool( name=tool.tool_name, description=tool_entity.entity.description.llm, - parameters={ - "type": "object", - "properties": {}, - "required": [], - }, + parameters=tool_entity.get_llm_parameters_json_schema(), ) - parameters = tool_entity.get_merged_runtime_parameters() - for parameter in parameters: - if parameter.form != ToolParameter.ToolParameterForm.LLM: - continue - - parameter_type = parameter.type.as_normal_type() - if parameter.type in { - ToolParameter.ToolParameterType.SYSTEM_FILES, - ToolParameter.ToolParameterType.FILE, - 
ToolParameter.ToolParameterType.FILES, - }: - continue - enum = [] - if parameter.type == ToolParameter.ToolParameterType.SELECT: - enum = [option.value for option in parameter.options] if parameter.options else [] - - message_tool.parameters["properties"][parameter.name] = ( - { - "type": parameter_type, - "description": parameter.llm_description or "", - } - if parameter.input_schema is None - else parameter.input_schema - ) - - if len(enum) > 0: - message_tool.parameters["properties"][parameter.name]["enum"] = enum - - if parameter.required: - message_tool.parameters["required"].append(parameter.name) - return message_tool, tool_entity def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool: @@ -252,40 +214,7 @@ class BaseAgentRunner(AppRunner): """ update prompt message tool """ - # try to get tool runtime parameters - tool_runtime_parameters = tool.get_runtime_parameters() - - for parameter in tool_runtime_parameters: - if parameter.form != ToolParameter.ToolParameterForm.LLM: - continue - - parameter_type = parameter.type.as_normal_type() - if parameter.type in { - ToolParameter.ToolParameterType.SYSTEM_FILES, - ToolParameter.ToolParameterType.FILE, - ToolParameter.ToolParameterType.FILES, - }: - continue - enum = [] - if parameter.type == ToolParameter.ToolParameterType.SELECT: - enum = [option.value for option in parameter.options] if parameter.options else [] - - prompt_tool.parameters["properties"][parameter.name] = ( - { - "type": parameter_type, - "description": parameter.llm_description or "", - } - if parameter.input_schema is None - else parameter.input_schema - ) - - if len(enum) > 0: - prompt_tool.parameters["properties"][parameter.name]["enum"] = enum - - if parameter.required: - if parameter.name not in prompt_tool.parameters["required"]: - prompt_tool.parameters["required"].append(parameter.name) - + prompt_tool.parameters = tool.get_llm_parameters_json_schema() return prompt_tool def 
create_agent_thought( diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 1251b397e2..0ca682e87a 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import ( AppGenerateEntity, EasyUIBasedAppGenerateEntity, @@ -292,46 +293,51 @@ class AppRunner: prompt_messages: list[PromptMessage] = [] text = "" usage = None - for result in invoke_result: - if not agent: - queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) - else: - queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) + try: + for result in invoke_result: + if not agent: + queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) + else: + queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) - message = result.delta.message - if isinstance(message.content, str): - text += message.content - elif isinstance(message.content, list): - for content in message.content: - if isinstance(content, str): - text += content - elif isinstance(content, TextPromptMessageContent): - text += content.data - elif isinstance(content, ImagePromptMessageContent): - if message_id and user_id and tenant_id: - try: - self._handle_multimodal_image_content( - content=content, - message_id=message_id, - user_id=user_id, - tenant_id=tenant_id, - queue_manager=queue_manager, - ) - except Exception: - _logger.exception("Failed to handle multimodal image output") + message = result.delta.message + if isinstance(message.content, str): + text += message.content + elif isinstance(message.content, 
list): + for content in message.content: + if isinstance(content, str): + text += content + elif isinstance(content, TextPromptMessageContent): + text += content.data + elif isinstance(content, ImagePromptMessageContent): + if message_id and user_id and tenant_id: + try: + self._handle_multimodal_image_content( + content=content, + message_id=message_id, + user_id=user_id, + tenant_id=tenant_id, + queue_manager=queue_manager, + ) + except Exception: + _logger.exception("Failed to handle multimodal image output") + else: + _logger.warning("Received multimodal output but missing required parameters") else: - _logger.warning("Received multimodal output but missing required parameters") - else: - text += content.data if hasattr(content, "data") else str(content) + text += content.data if hasattr(content, "data") else str(content) - if not model: - model = result.model + if not model: + model = result.model - if not prompt_messages: - prompt_messages = list(result.prompt_messages) + if not prompt_messages: + prompt_messages = list(result.prompt_messages) - if result.delta.usage: - usage = result.delta.usage + if result.delta.usage: + usage = result.delta.usage + except GenerateTaskStoppedError: + # Explicitly close provider stream to stop in-flight token generation ASAP. 
+ invoke_result.close() + raise if usage is None: usage = LLMUsage.empty_usage() diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index 619590c81e..4561388f8b 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -47,6 +47,12 @@ from graphon.graph_events import ( ) from graphon.node_events import NodeRunResult from libs.datetime_utils import naive_utc_now +from services.workflow.inspector_events import ( + publish_node_changed as _inspector_publish_node_changed, +) +from services.workflow.inspector_events import ( + publish_workflow_completed as _inspector_publish_workflow_completed, +) @dataclass(slots=True) @@ -163,6 +169,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): self._workflow_execution_repository.save(execution) self._enqueue_trace_task(execution) + _inspector_publish_workflow_completed(workflow_run_id=execution.id_, status=str(execution.status.value)) def _handle_graph_run_partial_succeeded(self, event: GraphRunPartialSucceededEvent) -> None: execution = self._get_workflow_execution() @@ -173,6 +180,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): self._workflow_execution_repository.save(execution) self._enqueue_trace_task(execution) + _inspector_publish_workflow_completed(workflow_run_id=execution.id_, status=str(execution.status.value)) def _handle_graph_run_failed(self, event: GraphRunFailedEvent) -> None: execution = self._get_workflow_execution() @@ -184,6 +192,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): self._fail_running_node_executions(error_message=event.error) self._workflow_execution_repository.save(execution) self._enqueue_trace_task(execution) + _inspector_publish_workflow_completed(workflow_run_id=execution.id_, status=str(execution.status.value)) def _handle_graph_run_aborted(self, event: GraphRunAbortedEvent) -> None: execution = self._get_workflow_execution() @@ -194,6 +203,7 @@ class 
WorkflowPersistenceLayer(GraphEngineLayer): self._fail_running_node_executions(error_message=execution.error_message or "") self._workflow_execution_repository.save(execution) self._enqueue_trace_task(execution) + _inspector_publish_workflow_completed(workflow_run_id=execution.id_, status=str(execution.status.value)) def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None: execution = self._get_workflow_execution() @@ -241,6 +251,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): created_at=event.start_at, ) self._node_snapshots[event.id] = snapshot + _inspector_publish_node_changed(workflow_run_id=execution.id_, node_id=event.node_id, status="running") def _handle_node_retry(self, event: NodeRunRetryEvent) -> None: domain_execution = self._get_node_execution(event.id) @@ -248,6 +259,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer): domain_execution.error = event.error self._workflow_node_execution_repository.save(domain_execution) self._workflow_node_execution_repository.save_execution_data(domain_execution) + _inspector_publish_node_changed( + workflow_run_id=self._get_workflow_execution().id_, + node_id=domain_execution.node_id, + status="retry", + ) def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None: domain_execution = self._get_node_execution(event.id) @@ -257,6 +273,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer): WorkflowNodeExecutionStatus.SUCCEEDED, finished_at=event.finished_at, ) + _inspector_publish_node_changed( + workflow_run_id=self._get_workflow_execution().id_, + node_id=domain_execution.node_id, + status="succeeded", + ) def _handle_node_failed(self, event: NodeRunFailedEvent) -> None: domain_execution = self._get_node_execution(event.id) @@ -267,6 +288,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer): error=event.error, finished_at=event.finished_at, ) + _inspector_publish_node_changed( + workflow_run_id=self._get_workflow_execution().id_, + node_id=domain_execution.node_id, + 
status="failed", + ) def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None: domain_execution = self._get_node_execution(event.id) @@ -277,6 +303,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer): error=event.error, finished_at=event.finished_at, ) + _inspector_publish_node_changed( + workflow_run_id=self._get_workflow_execution().id_, + node_id=domain_execution.node_id, + status="exception", + ) def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None: domain_execution = self._get_node_execution(event.id) diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index ab0f73a9a2..4d784b5f23 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -126,34 +126,89 @@ class Tool(ABC): message_id: str | None = None, ) -> list[ToolParameter]: """ - get merged runtime parameters + Get the effective parameter declarations for this tool. + + Runtime parameters override declared parameters by name and append new + parameters, but the returned list is always detached from the tool's + cached declarations so callers can safely mutate it while building + downstream schemas. 
:return: merged runtime parameters """ - parameters = self.entity.parameters - parameters = parameters.copy() - user_parameters = self.get_runtime_parameters() or [] - user_parameters = user_parameters.copy() + parameters = [deepcopy(parameter) for parameter in self.entity.parameters or []] + user_parameters = [ + deepcopy(parameter) + for parameter in self.get_runtime_parameters( + conversation_id=conversation_id, + app_id=app_id, + message_id=message_id, + ) + or [] + ] + + parameter_indexes = {parameter.name: index for index, parameter in enumerate(parameters)} - # override parameters for parameter in user_parameters: - # check if parameter in tool parameters - for tool_parameter in parameters: - if tool_parameter.name == parameter.name: - # override parameter - tool_parameter.type = parameter.type - tool_parameter.form = parameter.form - tool_parameter.required = parameter.required - tool_parameter.default = parameter.default - tool_parameter.options = parameter.options - tool_parameter.llm_description = parameter.llm_description - break - else: - # add new parameter + existing_index = parameter_indexes.get(parameter.name) + if existing_index is None: + parameter_indexes[parameter.name] = len(parameters) parameters.append(parameter) + continue + parameters[existing_index] = parameter return parameters + def get_llm_parameters_json_schema( + self, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> dict[str, Any]: + """Build the model-visible JSON schema from effective tool parameters. + + Hidden/manual parameters stay available for invocation preparation on the + API side, but are intentionally omitted from the LLM-facing schema. 
+ """ + schema: dict[str, Any] = { + "type": "object", + "properties": {}, + "required": [], + } + + for parameter in self.get_merged_runtime_parameters( + conversation_id=conversation_id, + app_id=app_id, + message_id=message_id, + ): + if parameter.form != ToolParameter.ToolParameterForm.LLM: + continue + + if parameter.type in { + ToolParameter.ToolParameterType.SYSTEM_FILES, + ToolParameter.ToolParameterType.FILE, + ToolParameter.ToolParameterType.FILES, + }: + continue + + parameter_schema: dict[str, Any] = ( + { + "type": parameter.type.as_normal_type(), + "description": parameter.llm_description or "", + } + if parameter.input_schema is None + else deepcopy(parameter.input_schema) + ) + parameter_schema.setdefault("description", parameter.llm_description or "") + + if parameter.type == ToolParameter.ToolParameterType.SELECT and parameter.options: + parameter_schema["enum"] = [option.value for option in parameter.options] + + schema["properties"][parameter.name] = parameter_schema + if parameter.required: + schema["required"].append(parameter.name) + + return schema + def create_image_message( self, image: str, diff --git a/api/core/workflow/nodes/agent_v2/plugin_tools_builder.py b/api/core/workflow/nodes/agent_v2/plugin_tools_builder.py new file mode 100644 index 0000000000..0dd98f65ed --- /dev/null +++ b/api/core/workflow/nodes/agent_v2/plugin_tools_builder.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, Protocol, cast + +from dify_agent.layers.dify_plugin import ( + DifyPluginCredentialValue, + DifyPluginToolConfig, + DifyPluginToolCredentialType, + DifyPluginToolParameter, + DifyPluginToolParameterForm, + DifyPluginToolsLayerConfig, +) + +from core.agent.entities import AgentToolEntity +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool import Tool +from core.tools.entities.tool_entities import ToolProviderType +from core.tools.errors import ( + 
ToolProviderCredentialValidationError, + ToolProviderNotFoundError, +) +from core.tools.tool_manager import ToolManager +from models.agent_config_entities import AgentSoulDifyToolConfig, AgentSoulToolsConfig +from models.provider_ids import ToolProviderID + + +class WorkflowAgentPluginToolsBuildError(ValueError): + """Raised when Agent Soul tools cannot be prepared for Agent backend.""" + + def __init__(self, error_code: str, message: str) -> None: + self.error_code = error_code + super().__init__(message) + + +class AgentToolRuntimeProvider(Protocol): + def get_agent_tool_runtime( + self, + tenant_id: str, + app_id: str, + agent_tool: AgentToolEntity, + user_id: str | None = None, + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + variable_pool: Any | None = None, + ) -> Tool: ... + + +class WorkflowAgentPluginToolsBuilder: + """Prepare Agent Soul Dify Plugin Tools for the public Agent backend DTO.""" + + def __init__(self, *, tool_runtime_provider: AgentToolRuntimeProvider | None = None) -> None: + self._tool_runtime_provider = tool_runtime_provider or ToolManager + + def build( + self, + *, + tenant_id: str, + app_id: str, + user_id: str | None, + tools: AgentSoulToolsConfig, + invoke_from: InvokeFrom, + ) -> DifyPluginToolsLayerConfig | None: + """Resolve user-selected Dify Plugin Tools into the Agent backend DTO. + + ``invoke_from`` is the *real* runtime caller category (DEBUGGER for a + Composer test run, SERVICE_API / WEB_APP for a published run). It must + be threaded through to :class:`ToolManager` so credential quotas, rate + limits, and audit tags match the actual call site. 
+ """ + enabled_tools = [tool for tool in tools.dify_tools if tool.enabled] + if not enabled_tools: + return None + + prepared: list[DifyPluginToolConfig] = [] + seen_names: set[str] = set() + for tool_config in enabled_tools: + agent_tool = self._to_agent_tool_entity(tool_config) + tool_runtime = self._fetch_tool_runtime( + tenant_id=tenant_id, + app_id=app_id, + user_id=user_id, + agent_tool=agent_tool, + invoke_from=invoke_from, + tool_config=tool_config, + ) + + exposed_name = self._exposed_tool_name(tool_config) + if exposed_name in seen_names: + raise WorkflowAgentPluginToolsBuildError( + "agent_tool_name_duplicated", + f"Duplicate Dify Plugin Tool name {exposed_name!r}.", + ) + seen_names.add(exposed_name) + + prepared.append(self._to_backend_tool_config(tool_config, tool_runtime, exposed_name)) + + return DifyPluginToolsLayerConfig(tools=prepared) + + def _fetch_tool_runtime( + self, + *, + tenant_id: str, + app_id: str, + user_id: str | None, + agent_tool: AgentToolEntity, + invoke_from: InvokeFrom, + tool_config: AgentSoulDifyToolConfig, + ) -> Tool: + """Resolve the API-side ``Tool`` runtime, mapping fetch errors to + Inspector-friendly error codes so callers can render distinct UX for + "tool definition gone" vs "credential failed". 
+ """ + try: + return self._tool_runtime_provider.get_agent_tool_runtime( + tenant_id=tenant_id, + app_id=app_id, + agent_tool=agent_tool, + user_id=user_id, + invoke_from=invoke_from, + variable_pool=None, + ) + except ToolProviderNotFoundError as exc: + raise WorkflowAgentPluginToolsBuildError( + "agent_tool_declaration_not_found", + f"Dify Plugin Tool {tool_config.tool_name!r} declaration not found: {exc}", + ) from exc + except ToolProviderCredentialValidationError as exc: + raise WorkflowAgentPluginToolsBuildError( + "agent_tool_credential_invalid", + f"Dify Plugin Tool {tool_config.tool_name!r} credential validation failed: {exc}", + ) from exc + except ValueError as exc: + # ToolManager raises bare ValueError when the agent tool's + # ``runtime`` / runtime parameters are missing. Surface it under a + # narrower error code than a generic "declaration not found" so + # frontend can render an actionable hint. + raise WorkflowAgentPluginToolsBuildError( + "agent_tool_config_invalid", + f"Dify Plugin Tool {tool_config.tool_name!r} runtime construction failed: {exc}", + ) from exc + + @staticmethod + def _to_agent_tool_entity(tool_config: AgentSoulDifyToolConfig) -> AgentToolEntity: + return AgentToolEntity( + provider_type=ToolProviderType.value_of(tool_config.provider_type), + provider_id=WorkflowAgentPluginToolsBuilder._provider_id(tool_config), + tool_name=tool_config.tool_name, + tool_parameters=dict(tool_config.runtime_parameters), + credential_id=tool_config.credential_ref.id if tool_config.credential_ref else None, + ) + + @staticmethod + def _provider_id(tool_config: AgentSoulDifyToolConfig) -> str: + if tool_config.provider_id: + return tool_config.provider_id + assert tool_config.plugin_id is not None + assert tool_config.provider is not None + return f"{tool_config.plugin_id}/{tool_config.provider}" + + @staticmethod + def _exposed_tool_name(tool_config: AgentSoulDifyToolConfig) -> str: + # Stage 3.1 decision: no user rename yet. 
Keep the model-visible tool + # name aligned with the plugin declaration identity. + return tool_config.tool_name + + def _to_backend_tool_config( + self, + tool_config: AgentSoulDifyToolConfig, + tool_runtime: Tool, + exposed_name: str, + ) -> DifyPluginToolConfig: + runtime = tool_runtime.runtime + if runtime is None: + raise WorkflowAgentPluginToolsBuildError( + "agent_tool_config_invalid", + f"Dify Plugin Tool {tool_config.tool_name!r} has no runtime.", + ) + + provider_id = self._provider_id(tool_config) + plugin_id, provider = self._plugin_provider(tool_config, provider_id) + parameters = [ + DifyPluginToolParameter.model_validate(parameter.model_dump(mode="json")) + for parameter in tool_runtime.get_merged_runtime_parameters() + ] + runtime_parameters = self._runtime_parameters(tool_runtime, parameters) + description = tool_config.description + if description is None and tool_runtime.entity.description is not None: + description = tool_runtime.entity.description.llm + + return DifyPluginToolConfig( + plugin_id=plugin_id, + provider=provider, + tool_name=tool_config.tool_name, + credential_type=self._credential_type(tool_config, runtime.credentials), + name=exposed_name, + description=description, + credentials=self._normalize_credentials(runtime.credentials, tool_name=tool_config.tool_name), + runtime_parameters=runtime_parameters, + parameters=parameters, + parameters_json_schema=cast(dict[str, Any], tool_runtime.get_llm_parameters_json_schema()), + ) + + @staticmethod + def _plugin_provider(tool_config: AgentSoulDifyToolConfig, provider_id: str) -> tuple[str, str]: + if tool_config.plugin_id and tool_config.provider: + return tool_config.plugin_id, tool_config.provider + provider_id_entity = ToolProviderID(provider_id) + return provider_id_entity.plugin_id, provider_id_entity.provider_name + + @staticmethod + def _credential_type( + tool_config: AgentSoulDifyToolConfig, + credentials: Mapping[str, Any], + ) -> DifyPluginToolCredentialType: + if not 
credentials and tool_config.credential_type == "unauthorized": + return "unauthorized" + return tool_config.credential_type + + @staticmethod + def _runtime_parameters( + tool_runtime: Tool, + parameters: list[DifyPluginToolParameter], + ) -> dict[str, Any]: + runtime = tool_runtime.runtime + runtime_parameters = dict(runtime.runtime_parameters if runtime is not None else {}) + missing = [ + parameter.name + for parameter in parameters + if parameter.form is not DifyPluginToolParameterForm.LLM + and parameter.required + and parameter.default is None + and parameter.name not in runtime_parameters + ] + if missing: + names = ", ".join(sorted(missing)) + raise WorkflowAgentPluginToolsBuildError( + "agent_tool_runtime_parameter_missing", + f"Dify Plugin Tool {tool_runtime.entity.identity.name!r} is missing runtime parameters: {names}.", + ) + return runtime_parameters + + @staticmethod + def _normalize_credentials( + credentials: Mapping[str, Any], + *, + tool_name: str, + ) -> dict[str, DifyPluginCredentialValue]: + """Forward only scalar credential values to the Agent backend. + + ``DifyPluginCredentialValue`` is ``str | int | float | bool | None``. + Refusing non-scalar values (lists, dicts, custom objects) is safer than + ``str(value)`` — stringifying a nested OAuth token blob produces a + Python ``repr`` that the plugin daemon cannot use, and we'd rather + surface a clear ``agent_tool_credential_shape_invalid`` than send junk. + """ + normalized: dict[str, DifyPluginCredentialValue] = {} + for key, value in credentials.items(): + if isinstance(value, str | int | float | bool) or value is None: + normalized[key] = value + continue + raise WorkflowAgentPluginToolsBuildError( + "agent_tool_credential_shape_invalid", + ( + f"Dify Plugin Tool {tool_name!r} credential {key!r} has a non-scalar value " + f"({type(value).__name__}); only str/int/float/bool/None are forwarded to the daemon." 
+ ), + ) + return normalized diff --git a/api/core/workflow/nodes/agent_v2/runtime_feature_manifest.py b/api/core/workflow/nodes/agent_v2/runtime_feature_manifest.py index afd730f652..b08dc36fb7 100644 --- a/api/core/workflow/nodes/agent_v2/runtime_feature_manifest.py +++ b/api/core/workflow/nodes/agent_v2/runtime_feature_manifest.py @@ -11,13 +11,14 @@ SUPPORTED_AGENT_BACKEND_FEATURES = frozenset( "workflow_context", "model", "structured_output", + "tools.dify_tools", } ) RESERVED_AGENT_BACKEND_FEATURES = frozenset( { "skills_files", - "tools", + "tools.cli_tools", "knowledge", "human", "env", @@ -32,7 +33,7 @@ def build_runtime_feature_manifest(agent_soul: AgentSoulConfig) -> dict[str, Any warnings: list[dict[str, str]] = [] soul_dump = agent_soul.model_dump(mode="json") for section in sorted(RESERVED_AGENT_BACKEND_FEATURES): - value = soul_dump.get(section) + value = _get_nested(soul_dump, section) has_value = bool(value) if isinstance(value, dict): has_value = any(bool(item) for item in value.values()) @@ -41,11 +42,12 @@ def build_runtime_feature_manifest(agent_soul: AgentSoulConfig) -> dict[str, Any { "section": f"agent_soul.{section}", "code": "agent_backend_layer_not_available", - "message": f"{section} is saved in Agent Soul but is not executed by Agent backend in phase 3.", + "message": f"{section} is saved in Agent Soul but is not executed by Agent backend.", } ) reserved_status = dict.fromkeys(sorted(RESERVED_AGENT_BACKEND_FEATURES), "reserved_not_executed") + reserved_status["tools.dify_tools"] = "supported_when_config_valid" return { "supported": sorted(SUPPORTED_AGENT_BACKEND_FEATURES), @@ -53,3 +55,12 @@ def build_runtime_feature_manifest(agent_soul: AgentSoulConfig) -> dict[str, Any "reserved_status": reserved_status, "unsupported_runtime_warnings": warnings, } + + +def _get_nested(value: dict[str, Any], path: str) -> Any: + current: Any = value + for part in path.split("."): + if not isinstance(current, dict): + return None + current = 
current.get(part) + return current diff --git a/api/core/workflow/nodes/agent_v2/runtime_request_builder.py b/api/core/workflow/nodes/agent_v2/runtime_request_builder.py index 431f658e33..0a0960d493 100644 --- a/api/core/workflow/nodes/agent_v2/runtime_request_builder.py +++ b/api/core/workflow/nodes/agent_v2/runtime_request_builder.py @@ -4,7 +4,8 @@ from collections.abc import Mapping, Sequence from dataclasses import dataclass from typing import Any, Literal, Protocol, cast -from dify_agent.protocol import CreateRunRequest, ExecutionContext +from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig +from dify_agent.protocol import CreateRunRequest from clients.agent_backend import ( AgentBackendModelConfig, @@ -29,6 +30,7 @@ from models.agent_config_entities import ( ) from .output_failure_orchestrator import retry_idempotency_key +from .plugin_tools_builder import WorkflowAgentPluginToolsBuilder, WorkflowAgentPluginToolsBuildError from .runtime_feature_manifest import build_runtime_feature_manifest @@ -83,9 +85,11 @@ class WorkflowAgentRuntimeRequestBuilder: *, credentials_provider: CredentialsProvider, request_builder: AgentBackendRunRequestBuilder | None = None, + plugin_tools_builder: WorkflowAgentPluginToolsBuilder | None = None, ) -> None: self._credentials_provider = credentials_provider self._request_builder = request_builder or AgentBackendRunRequestBuilder() + self._plugin_tools_builder = plugin_tools_builder or WorkflowAgentPluginToolsBuilder() def build(self, context: WorkflowAgentRuntimeBuildContext) -> WorkflowAgentRuntimeRequest: agent_soul = AgentSoulConfig.model_validate(context.snapshot.config_snapshot_dict) @@ -101,20 +105,44 @@ class WorkflowAgentRuntimeRequestBuilder: workflow_job_prompt = node_job.workflow_prompt.strip() or "Run this workflow Agent Node for the current run." user_prompt = workflow_context_prompt.strip() or "Use the current workflow context." 
credentials = self._credentials_provider.fetch(agent_soul.model.model_provider, agent_soul.model.model) + try: + tools_layer = self._plugin_tools_builder.build( + tenant_id=context.dify_context.tenant_id, + app_id=context.dify_context.app_id, + user_id=context.dify_context.user_id, + tools=agent_soul.tools, + # Thread the *real* runtime invocation source through to + # ToolManager so credential quotas, rate limits, and audit + # trails match the actual call site (DEBUGGER for draft test + # run, SERVICE_API / WEB_APP for published run). + invoke_from=context.dify_context.invoke_from, + ) + except WorkflowAgentPluginToolsBuildError as error: + raise WorkflowAgentRuntimeRequestBuildError(error.error_code, str(error)) from error + if tools_layer is not None: + metadata["agent_tools"] = { + "dify_tool_count": len(tools_layer.tools), + "dify_tool_names": [tool.name or tool.tool_name for tool in tools_layer.tools], + "cli_tool_count": len(agent_soul.tools.cli_tools), + } request = self._request_builder.build_for_workflow_node( AgentBackendWorkflowNodeRunInput( model=AgentBackendModelConfig( - tenant_id=context.dify_context.tenant_id, plugin_id=agent_soul.model.plugin_id, model_provider=agent_soul.model.model_provider, model=agent_soul.model.model, - user_id=context.dify_context.user_id, credentials=self._normalize_credentials(credentials), model_settings=cast(dict[str, Any], agent_soul.model.model_settings), ), - execution_context=ExecutionContext( + # The execution-context layer is now the only public protocol + # carrier for Dify tenant/user/run identifiers. ``user_id`` must + # be forwarded here because downstream plugin-daemon provider and + # tool clients read it from this layer rather than from any + # parallel top-level request field. 
+ execution_context=DifyExecutionContextLayerConfig( tenant_id=context.dify_context.tenant_id, + user_id=context.dify_context.user_id, app_id=context.dify_context.app_id, workflow_id=context.workflow_id, workflow_run_id=context.workflow_run_id, @@ -129,6 +157,7 @@ class WorkflowAgentRuntimeRequestBuilder: workflow_node_job_prompt=workflow_job_prompt, user_prompt=user_prompt, output=self._build_output_config(node_job.declared_outputs), + tools=tools_layer, idempotency_key=self._idempotency_key(context), metadata=metadata, ) diff --git a/api/core/workflow/nodes/agent_v2/validators.py b/api/core/workflow/nodes/agent_v2/validators.py index f8df0506e8..768fcdadeb 100644 --- a/api/core/workflow/nodes/agent_v2/validators.py +++ b/api/core/workflow/nodes/agent_v2/validators.py @@ -126,6 +126,7 @@ class WorkflowAgentNodeValidator: raise WorkflowAgentNodeValidationError( f"Workflow Agent node {binding.node_id} requires Agent Soul model config." ) + cls._validate_agent_soul_tools(binding=binding, agent_soul=agent_soul) node_job = WorkflowNodeJobConfig.model_validate(binding.node_job_config_dict) cls.validate_node_job(session=session, binding=binding, node_job=node_job, topology=topology) @@ -280,6 +281,26 @@ class WorkflowAgentNodeValidator: f"Workflow Agent node {binding.node_id} references unsupported human contact channel {channel}." ) + @classmethod + def _validate_agent_soul_tools( + cls, + *, + binding: WorkflowAgentNodeBinding, + agent_soul: AgentSoulConfig, + ) -> None: + exposed_names: set[str] = set() + for tool in agent_soul.tools.dify_tools: + if not tool.enabled: + continue + exposed_name = tool.tool_name + if exposed_name in exposed_names: + raise WorkflowAgentNodeValidationError( + f"Workflow Agent node {binding.node_id} has duplicate Dify Plugin Tool name {exposed_name}." + ) + exposed_names.add(exposed_name) + # CLI tools remain saved-but-not-executed. They are allowed at publish + # time so existing Agent Soul drafts are not blocked by a reserved field. 
+ @staticmethod def _validate_file_ref( *, diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index fe95cc5816..18a0c75aca 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -12,6 +12,7 @@ def init_app(app: DifyApp): clear_orphaned_file_records, convert_to_agent_apps, create_tenant, + data_migrate, delete_archived_workflow_runs, export_app_messages, extract_plugins, @@ -44,6 +45,7 @@ def init_app(app: DifyApp): convert_to_agent_apps, add_qdrant_index, create_tenant, + data_migrate, upgrade_db, fix_app_site_missing, migrate_data_for_plugin, diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index 0b54992835..dd7865afed 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -1,62 +1,96 @@ -from flask_restx import fields +from datetime import datetime +from typing import Any -from libs.helper import TimestampField +from pydantic import field_validator -document_fields = { - "id": fields.String, - "data_source_type": fields.String, - "name": fields.String, - "doc_type": fields.String, - "doc_metadata": fields.Raw, -} +from fields.base import ResponseModel +from libs.helper import to_timestamp -segment_fields = { - "id": fields.String, - "position": fields.Integer, - "document_id": fields.String, - "content": fields.String, - "sign_content": fields.String, - "answer": fields.String, - "word_count": fields.Integer, - "tokens": fields.Integer, - "keywords": fields.List(fields.String), - "index_node_id": fields.String, - "index_node_hash": fields.String, - "hit_count": fields.Integer, - "enabled": fields.Boolean, - "disabled_at": TimestampField, - "disabled_by": fields.String, - "status": fields.String, - "created_by": fields.String, - "created_at": TimestampField, - "indexing_at": TimestampField, - "completed_at": TimestampField, - "error": fields.String, - "stopped_at": TimestampField, - "document": fields.Nested(document_fields), -} 
-child_chunk_fields = { - "id": fields.String, - "content": fields.String, - "position": fields.Integer, - "score": fields.Float, -} +class HitTestingQuery(ResponseModel): + content: str -files_fields = { - "id": fields.String, - "name": fields.String, - "size": fields.Integer, - "extension": fields.String, - "mime_type": fields.String, - "source_url": fields.String, -} -hit_testing_record_fields = { - "segment": fields.Nested(segment_fields), - "child_chunks": fields.List(fields.Nested(child_chunk_fields)), - "score": fields.Float, - "tsne_position": fields.Raw, - "files": fields.List(fields.Nested(files_fields)), - "summary": fields.String, # Summary content if retrieved via summary index -} +class HitTestingDocument(ResponseModel): + id: str + data_source_type: str + name: str + doc_type: str | None + doc_metadata: Any | None + + @field_validator("data_source_type", "doc_type", mode="before") + @classmethod + def _normalize_enum_fields(cls, value: Any) -> Any: + return _normalize_enum(value) + + +class HitTestingSegment(ResponseModel): + id: str + position: int + document_id: str + content: str + sign_content: str | None + answer: str | None + word_count: int + tokens: int + keywords: list[str] + index_node_id: str | None + index_node_hash: str | None + hit_count: int + enabled: bool + disabled_at: int | None + disabled_by: str | None + status: str + created_by: str + created_at: int + indexing_at: int | None + completed_at: int | None + error: str | None + stopped_at: int | None + document: HitTestingDocument + + @field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return to_timestamp(value) + + @field_validator("status", mode="before") + @classmethod + def _normalize_enum_fields(cls, value: Any) -> Any: + return _normalize_enum(value) + + +class HitTestingChildChunk(ResponseModel): + id: str + content: str + 
position: int + score: float + + +class HitTestingFile(ResponseModel): + id: str + name: str + size: int + extension: str + mime_type: str + source_url: str + + +class HitTestingRecord(ResponseModel): + segment: HitTestingSegment + child_chunks: list[HitTestingChildChunk] + score: float | None + tsne_position: Any | None + files: list[HitTestingFile] + summary: str | None + + +class HitTestingResponse(ResponseModel): + query: HitTestingQuery + records: list[HitTestingRecord] + + +def _normalize_enum(value: Any) -> Any: + if isinstance(value, str) or value is None: + return value + return getattr(value, "value", value) diff --git a/api/libs/oauth_bearer.py b/api/libs/oauth_bearer.py index 6e8678eca0..7433c6c177 100644 --- a/api/libs/oauth_bearer.py +++ b/api/libs/oauth_bearer.py @@ -43,6 +43,11 @@ class SubjectType(StrEnum): EXTERNAL_SSO = "external_sso" +class TokenType(StrEnum): + OAUTH_ACCOUNT = "oauth_account" + OAUTH_EXTERNAL_SSO = "oauth_external_sso" + + class Scope(StrEnum): """Catalog of bearer scopes recognised by the openapi surface. @@ -55,6 +60,8 @@ class Scope(StrEnum): APPS_READ = "apps:read" APPS_READ_PERMITTED_EXTERNAL = "apps:read:permitted-external" APPS_RUN = "apps:run" + WORKSPACE_READ = "workspace:read" + WORKSPACE_WRITE = "workspace:write" class Accepts(StrEnum): @@ -77,7 +84,7 @@ _SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = { class AuthContext: """Per-request identity published via :data:`_auth_ctx_var` (see :func:`set_auth_ctx` / :func:`get_auth_ctx`). ``scopes`` / - ``subject_type`` / ``source`` come from the TokenKind, not the DB — + ``subject_type`` / ``token_type`` come from the TokenKind, not the DB — corrupt rows can't elevate scope. 
`verified_tenants` is a snapshot of the Layer-0 verdict cache at @@ -92,7 +99,7 @@ class AuthContext: client_id: str | None scopes: frozenset[Scope] token_id: uuid.UUID - source: str + token_type: TokenType expires_at: datetime | None token_hash: str verified_tenants: dict[str, bool] = field(default_factory=dict) @@ -180,7 +187,7 @@ class TokenKind: prefix: str subject_type: SubjectType scopes: frozenset[Scope] - source: str + token_type: TokenType resolver: Resolver def matches(self, token: str) -> bool: @@ -291,7 +298,7 @@ class BearerAuthenticator: client_id=row.client_id, scopes=kind.scopes, token_id=row.token_id, - source=kind.source, + token_type=kind.token_type, expires_at=row.expires_at, token_hash=token_hash, verified_tenants=dict(row.verified_tenants), @@ -483,7 +490,7 @@ def check_workspace_membership( account_id: uuid.UUID | str, tenant_id: str, token_hash: str, - cached_verdicts: dict[str, bool], + membership_cache: dict[str, bool], ) -> None: """Layer-0 enforcement core. Raises `Forbidden` on deny, returns on allow. @@ -492,7 +499,7 @@ def check_workspace_membership( short-circuiting on EE / SSO subjects before invoking — this function runs the membership + active-status checks unconditionally. 
""" - cached = cached_verdicts.get(tenant_id) + cached = membership_cache.get(tenant_id) if cached is True: return if cached is False: @@ -530,7 +537,7 @@ def require_workspace_member(ctx: AuthContext, tenant_id: str) -> None: account_id=ctx.account_id, tenant_id=tenant_id, token_hash=ctx.token_hash, - cached_verdicts=ctx.verified_tenants, + membership_cache=ctx.verified_tenants, ) @@ -664,14 +671,14 @@ def build_registry(session_factory, redis_client) -> TokenKindRegistry: prefix=account.prefix, subject_type=account.subject_type, scopes=account.scopes, - source="oauth_account", + token_type=TokenType.OAUTH_ACCOUNT, resolver=oauth.for_account(), ), TokenKind( prefix=external.prefix, subject_type=external.subject_type, scopes=external.scopes, - source="oauth_external_sso", + token_type=TokenType.OAUTH_EXTERNAL_SSO, resolver=oauth.for_external_sso(), ), ] diff --git a/api/models/agent_config_entities.py b/api/models/agent_config_entities.py index 9524d22d7f..ec604115de 100644 --- a/api/models/agent_config_entities.py +++ b/api/models/agent_config_entities.py @@ -1,6 +1,6 @@ import re from enum import StrEnum -from typing import Any, Final +from typing import Any, Final, Literal from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator @@ -50,8 +50,90 @@ class AgentSoulSkillsFilesConfig(BaseModel): skills: list[dict[str, Any]] = Field(default_factory=list) +class AgentSoulDifyToolCredentialRef(BaseModel): + """Reference to a stored Dify Plugin Tool credential. + + Secret values are resolved only at runtime. The legacy ``credential_id`` + field is accepted by :class:`AgentSoulDifyToolConfig` and normalized here so + old Agent tool payloads can be read while new payloads stay explicit. 
+ """ + + model_config = ConfigDict(extra="ignore") + + type: Literal["provider", "tool"] = "tool" + id: str | None = Field(default=None, max_length=255) + provider: str | None = Field(default=None, max_length=255) + + +class AgentSoulDifyToolConfig(BaseModel): + """One Dify Plugin Tool configured on Agent Soul. + + The API backend prepares this persisted product shape into + ``DifyPluginToolConfig`` before sending a run request to Agent backend. + ``provider_id`` keeps compatibility with existing Agent tool config payloads; + new callers should send ``plugin_id`` + ``provider`` when available. + """ + + # ``extra="ignore"`` (not ``"allow"``) so historical Agent Soul payloads + # with unknown fields still load — but the extra keys are dropped instead + # of silently riding along into ``model_dump``. New callers should send the + # explicit schema fields below. + model_config = ConfigDict(extra="ignore") + + enabled: bool = True + # Dify Plugin Tools live behind the ``PLUGIN`` provider type. ``BUILT_IN`` / + # ``WORKFLOW`` / ``API`` providers are not exposed to the Agent backend in + # this layer — keep the default narrow so a missing field surfaces as + # ``agent_tool_declaration_not_found`` against the correct provider table. + provider_type: str = "plugin" + provider_id: str | None = Field(default=None, max_length=255) + plugin_id: str | None = Field(default=None, max_length=255) + provider: str | None = Field(default=None, max_length=255) + tool_name: str = Field(min_length=1, max_length=255) + credential_type: Literal["api-key", "oauth2", "unauthorized"] = "api-key" + credential_ref: AgentSoulDifyToolCredentialRef | None = None + # Reserved for a future user-rename UX. Accepted but currently rejected at + # validation time so frontend cannot silently believe a rename took effect + # (see :meth:`_validate_provider_and_credentials`). 
+ name: str | None = Field(default=None, max_length=255) + description: str | None = None + runtime_parameters: dict[str, Any] = Field(default_factory=dict) + + @model_validator(mode="before") + @classmethod + def _normalize_legacy_payload(cls, value: Any) -> Any: + if not isinstance(value, dict): + return value + normalized = dict(value) + if normalized.get("provider_id") is None and isinstance(normalized.get("provider_name"), str): + normalized["provider_id"] = normalized["provider_name"] + if normalized.get("runtime_parameters") is None and isinstance(normalized.get("tool_parameters"), dict): + normalized["runtime_parameters"] = normalized["tool_parameters"] + if normalized.get("credential_ref") is None and normalized.get("credential_id"): + normalized["credential_ref"] = { + "type": "tool", + "id": normalized.get("credential_id"), + "provider": normalized.get("provider_id") or normalized.get("provider"), + } + return normalized + + @model_validator(mode="after") + def _validate_provider_and_credentials(self) -> "AgentSoulDifyToolConfig": + if not self.provider_id and not (self.plugin_id and self.provider): + raise ValueError("Dify tool requires provider_id or plugin_id + provider") + if self.credential_type != "unauthorized" and (self.credential_ref is None or not self.credential_ref.id): + raise ValueError("credential_ref.id is required for credentialed Dify tools") + # ``name`` is reserved for a future user-rename UX. Until that lands + # the model-visible name is forced to match ``tool_name``; reject + # explicit values so a frontend bug surfaces immediately instead of + # producing a silently-ignored override. 
+ if self.name is not None and self.name != self.tool_name: + raise ValueError("name override is not yet supported; omit ``name`` or set it equal to ``tool_name``.") + return self + + class AgentSoulToolsConfig(BaseModel): - dify_tools: list[dict[str, Any]] = Field(default_factory=list) + dify_tools: list[AgentSoulDifyToolConfig] = Field(default_factory=list) cli_tools: list[dict[str, Any]] = Field(default_factory=list) diff --git a/api/openapi/markdown/console-swagger.md b/api/openapi/markdown/console-swagger.md index 5ea5bbe008..a2302be32c 100644 --- a/api/openapi/markdown/console-swagger.md +++ b/api/openapi/markdown/console-swagger.md @@ -3447,6 +3447,89 @@ Run draft workflow | 200 | Draft workflow run started successfully | | 403 | Permission denied | +### /apps/{app_id}/workflows/draft/runs/{run_id}/node-outputs + +#### GET +##### Description + +Snapshot of every node's declared outputs for a draft workflow run. + +##### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| app_id | path | Application ID | Yes | string | +| run_id | path | Workflow run ID | Yes | string | + +##### Responses + +| Code | Description | +| ---- | ----------- | +| 404 | Workflow run not found | + +### /apps/{app_id}/workflows/draft/runs/{run_id}/node-outputs/events + +#### GET +##### Description + +Server-Sent Events stream of inspector deltas for a draft workflow run. + +##### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| app_id | path | Application ID | Yes | string | +| run_id | path | Workflow run ID | Yes | string | + +##### Responses + +| Code | Description | +| ---- | ----------- | +| 404 | Workflow run not found | + +### /apps/{app_id}/workflows/draft/runs/{run_id}/node-outputs/{node_id} + +#### GET +##### Description + +One node's declared outputs for a draft workflow run. 
+ +##### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| app_id | path | Application ID | Yes | string | +| node_id | path | Node ID inside the workflow graph | Yes | string | +| run_id | path | Workflow run ID | Yes | string | + +##### Responses + +| Code | Description | +| ---- | ----------- | +| 404 | Workflow run / node not found | + +### /apps/{app_id}/workflows/draft/runs/{run_id}/node-outputs/{node_id}/{output_name}/preview + +#### GET +##### Description + +Full value for one declared output, including signed download URL for files. + +##### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| app_id | path | Application ID | Yes | string | +| node_id | path | Node ID inside the workflow graph | Yes | string | +| output_name | path | Declared output name as exposed by Composer | Yes | string | +| run_id | path | Workflow run ID | Yes | string | + +##### Responses + +| Code | Description | +| ---- | ----------- | +| 404 | Workflow run / node / output not found | + ### /apps/{app_id}/workflows/draft/system-variables #### GET @@ -3684,6 +3767,89 @@ Publish workflow | ---- | ----------- | | 200 | Success | +### /apps/{app_id}/workflows/published/runs/{run_id}/node-outputs + +#### GET +##### Description + +Snapshot of every node's declared outputs for a published workflow run. + +##### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| app_id | path | Application ID | Yes | string | +| run_id | path | Workflow run ID | Yes | string | + +##### Responses + +| Code | Description | +| ---- | ----------- | +| 404 | Workflow run not found | + +### /apps/{app_id}/workflows/published/runs/{run_id}/node-outputs/events + +#### GET +##### Description + +Server-Sent Events stream of inspector deltas for a published workflow run. 
+ +##### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| app_id | path | Application ID | Yes | string | +| run_id | path | Workflow run ID | Yes | string | + +##### Responses + +| Code | Description | +| ---- | ----------- | +| 404 | Workflow run not found | + +### /apps/{app_id}/workflows/published/runs/{run_id}/node-outputs/{node_id} + +#### GET +##### Description + +One node's declared outputs for a published workflow run. + +##### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| app_id | path | Application ID | Yes | string | +| node_id | path | Node ID inside the workflow graph | Yes | string | +| run_id | path | Workflow run ID | Yes | string | + +##### Responses + +| Code | Description | +| ---- | ----------- | +| 404 | Workflow run / node not found | + +### /apps/{app_id}/workflows/published/runs/{run_id}/node-outputs/{node_id}/{output_name}/preview + +#### GET +##### Description + +Full value for one declared output of a published run. + +##### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| app_id | path | Application ID | Yes | string | +| node_id | path | Node ID inside the workflow graph | Yes | string | +| output_name | path | Declared output name as exposed by Composer | Yes | string | +| run_id | path | Workflow run ID | Yes | string | + +##### Responses + +| Code | Description | +| ---- | ----------- | +| 404 | Workflow run / node / output not found | + ### /apps/{app_id}/workflows/triggers/webhook #### GET @@ -10539,6 +10705,43 @@ Supported icon storage formats for Agent roster entries. 
| skills_files | [AgentSoulSkillsFilesConfig](#agentsoulskillsfilesconfig) | | No | | tools | [AgentSoulToolsConfig](#agentsoultoolsconfig) | | No | +#### AgentSoulDifyToolConfig + +One Dify Plugin Tool configured on Agent Soul. + +The API backend prepares this persisted product shape into +``DifyPluginToolConfig`` before sending a run request to Agent backend. +``provider_id`` keeps compatibility with existing Agent tool config payloads; +new callers should send ``plugin_id`` + ``provider`` when available. + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| credential_ref | [AgentSoulDifyToolCredentialRef](#agentsouldifytoolcredentialref) | | No | +| credential_type | string | *Enum:* `"api-key"`, `"oauth2"`, `"unauthorized"` | No | +| description | string | | No | +| enabled | boolean | | No | +| name | string | | No | +| plugin_id | string | | No | +| provider | string | | No | +| provider_id | string | | No | +| provider_type | string | | No | +| runtime_parameters | object | | No | +| tool_name | string | | Yes | + +#### AgentSoulDifyToolCredentialRef + +Reference to a stored Dify Plugin Tool credential. + +Secret values are resolved only at runtime. The legacy ``credential_id`` +field is accepted by :class:`AgentSoulDifyToolConfig` and normalized here so +old Agent tool payloads can be read while new payloads stay explicit. + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| id | string | | No | +| provider | string | | No | +| type | string | *Enum:* `"provider"`, `"tool"` | No | + #### AgentSoulEnvConfig | Name | Type | Description | Required | @@ -10616,7 +10819,7 @@ Reference to model credentials resolved only at runtime. 
| Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | | cli_tools | [ object ] | | No | -| dify_tools | [ object ] | | No | +| dify_tools | [ [AgentSoulDifyToolConfig](#agentsouldifytoolconfig) ] | | No | #### AgentThought @@ -12848,31 +13051,31 @@ Request payload for bulk downloading documents as a zip archive. | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| content | string | | No | -| id | string | | No | -| position | integer | | No | -| score | number | | No | +| content | string | | Yes | +| id | string | | Yes | +| position | integer | | Yes | +| score | number | | Yes | #### HitTestingDocument | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| data_source_type | string | | No | -| doc_metadata | | | No | -| doc_type | string | | No | -| id | string | | No | -| name | string | | No | +| data_source_type | string | | Yes | +| doc_metadata | | | Yes | +| doc_type | string | | Yes | +| id | string | | Yes | +| name | string | | Yes | #### HitTestingFile | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| extension | string | | No | -| id | string | | No | -| mime_type | string | | No | -| name | string | | No | -| size | integer | | No | -| source_url | string | | No | +| extension | string | | Yes | +| id | string | | Yes | +| mime_type | string | | Yes | +| name | string | | Yes | +| size | integer | | Yes | +| source_url | string | | Yes | #### HitTestingPayload @@ -12883,51 +13086,57 @@ Request payload for bulk downloading documents as a zip archive. 
| query | string | | Yes | | retrieval_model | [RetrievalModel](#retrievalmodel) | | No | +#### HitTestingQuery + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| content | string | | Yes | + #### HitTestingRecord | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| child_chunks | [ [HitTestingChildChunk](#hittestingchildchunk) ] | | No | -| files | [ [HitTestingFile](#hittestingfile) ] | | No | -| score | number | | No | -| segment | [HitTestingSegment](#hittestingsegment) | | No | -| summary | string | | No | -| tsne_position | | | No | +| child_chunks | [ [HitTestingChildChunk](#hittestingchildchunk) ] | | Yes | +| files | [ [HitTestingFile](#hittestingfile) ] | | Yes | +| score | number | | Yes | +| segment | [HitTestingSegment](#hittestingsegment) | | Yes | +| summary | string | | Yes | +| tsne_position | | | Yes | #### HitTestingResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| query | string | | Yes | -| records | [ [HitTestingRecord](#hittestingrecord) ] | | No | +| query | [HitTestingQuery](#hittestingquery) | | Yes | +| records | [ [HitTestingRecord](#hittestingrecord) ] | | Yes | #### HitTestingSegment | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| answer | string | | No | -| completed_at | integer | | No | -| content | string | | No | -| created_at | integer | | No | -| created_by | string | | No | -| disabled_at | integer | | No | -| disabled_by | string | | No | -| document | [HitTestingDocument](#hittestingdocument) | | No | -| document_id | string | | No | -| enabled | boolean | | No | -| error | string | | No | -| hit_count | integer | | No | -| id | string | | No | -| index_node_hash | string | | No | -| index_node_id | string | | No | -| indexing_at | integer | | No | -| keywords | [ string ] | | No | -| position | integer | | No | -| sign_content | string | | No | -| status | string | | No | -| 
stopped_at | integer | | No | -| tokens | integer | | No | -| word_count | integer | | No | +| answer | string | | Yes | +| completed_at | integer | | Yes | +| content | string | | Yes | +| created_at | integer | | Yes | +| created_by | string | | Yes | +| disabled_at | integer | | Yes | +| disabled_by | string | | Yes | +| document | [HitTestingDocument](#hittestingdocument) | | Yes | +| document_id | string | | Yes | +| enabled | boolean | | Yes | +| error | string | | Yes | +| hit_count | integer | | Yes | +| id | string | | Yes | +| index_node_hash | string | | Yes | +| index_node_id | string | | Yes | +| indexing_at | integer | | Yes | +| keywords | [ string ] | | Yes | +| position | integer | | Yes | +| sign_content | string | | Yes | +| status | string | | Yes | +| stopped_at | integer | | Yes | +| tokens | integer | | Yes | +| word_count | integer | | Yes | #### HumanInputContent diff --git a/api/openapi/markdown/openapi-swagger.md b/api/openapi/markdown/openapi-swagger.md index 419acdca24..899e09ff4a 100644 --- a/api/openapi/markdown/openapi-swagger.md +++ b/api/openapi/markdown/openapi-swagger.md @@ -323,6 +323,85 @@ Upload a file to use as an input variable when running the app | ---- | ----------- | ------ | | 200 | Workspace detail | [WorkspaceDetailResponse](#workspacedetailresponse) | +### /workspaces/{workspace_id}/members + +#### GET +##### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| workspace_id | path | | Yes | string | +| limit | query | | No | integer | +| page | query | | No | integer | + +##### Responses + +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Member list | [MemberListResponse](#memberlistresponse) | + +#### POST +##### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| workspace_id | path | | Yes | string | +| payload | body | | Yes | 
[MemberInvitePayload](#memberinvitepayload) | + +##### Responses + +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 201 | Member invited | [MemberInviteResponse](#memberinviteresponse) | + +### /workspaces/{workspace_id}/members/{member_id} + +#### DELETE +##### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| member_id | path | | Yes | string | +| workspace_id | path | | Yes | string | + +##### Responses + +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Member removed | [MemberActionResponse](#memberactionresponse) | + +### /workspaces/{workspace_id}/members/{member_id}/role + +#### PUT +##### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| member_id | path | | Yes | string | +| workspace_id | path | | Yes | string | +| payload | body | | Yes | [MemberRoleUpdatePayload](#memberroleupdatepayload) | + +##### Responses + +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Role updated | [MemberActionResponse](#memberactionresponse) | + +### /workspaces/{workspace_id}/switch + +#### POST +##### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| workspace_id | path | | Yes | string | + +##### Responses + +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Workspace detail | [WorkspaceDetailResponse](#workspacedetailresponse) | + --- ### Models @@ -526,6 +605,66 @@ mode is a closed enum. 
| ---- | ---- | ----------- | -------- | | JsonValue | | | | +#### MemberActionResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| result | string | | No | + +#### MemberInvitePayload + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| email | string | | Yes | +| role | string | *Enum:* `"admin"`, `"normal"` | Yes | + +#### MemberInviteResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| email | string | | Yes | +| invite_url | string | | Yes | +| member_id | string | | Yes | +| result | string | | No | +| role | string | | Yes | +| tenant_id | string | | Yes | + +#### MemberListQuery + +Strict (extra='forbid'). + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| limit | integer | | No | +| page | integer | | No | + +#### MemberListResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| data | [ [MemberResponse](#memberresponse) ] | | Yes | +| has_more | boolean | | Yes | +| limit | integer | | Yes | +| page | integer | | Yes | +| total | integer | | Yes | + +#### MemberResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| avatar | string | | No | +| email | string | | Yes | +| id | string | | Yes | +| name | string | | Yes | +| role | string | | Yes | +| status | string | | Yes | + +#### MemberRoleUpdatePayload + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| role | string | *Enum:* `"admin"`, `"normal"` | Yes | + #### MessageMetadata | Name | Type | Description | Required | diff --git a/api/openapi/markdown/service-swagger.md b/api/openapi/markdown/service-swagger.md index ee801b1b8e..071b1b526c 100644 --- a/api/openapi/markdown/service-swagger.md +++ b/api/openapi/markdown/service-swagger.md @@ -1363,11 +1363,11 @@ Tests retrieval performance for the specified dataset. 
##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Hit testing results | -| 401 | Unauthorized - invalid API token | -| 404 | Dataset not found | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Hit testing results | [HitTestingResponse](#hittestingresponse) | +| 401 | Unauthorized - invalid API token | | +| 404 | Dataset not found | | ### /datasets/{dataset_id}/metadata @@ -1614,11 +1614,11 @@ Tests retrieval performance for the specified dataset. ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Hit testing results | -| 401 | Unauthorized - invalid API token | -| 404 | Dataset not found | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Hit testing results | [HitTestingResponse](#hittestingresponse) | +| 401 | Unauthorized - invalid API token | | +| 404 | Dataset not found | | ### /datasets/{dataset_id}/tags @@ -2691,6 +2691,36 @@ Note: The SQLAlchemy model defines an `is_anonymous` property for Flask-Login se | tenant_id | string | | No | | user_id | string | | No | +#### HitTestingChildChunk + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| content | string | | Yes | +| id | string | | Yes | +| position | integer | | Yes | +| score | number | | Yes | + +#### HitTestingDocument + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| data_source_type | string | | Yes | +| doc_metadata | | | Yes | +| doc_type | string | | Yes | +| id | string | | Yes | +| name | string | | Yes | + +#### HitTestingFile + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| extension | string | | Yes | +| id | string | | Yes | +| mime_type | string | | Yes | +| name | string | | Yes | +| size | integer | | Yes | +| source_url | string | | Yes | + #### HitTestingPayload | Name | Type | Description | Required | @@ -2700,6 +2730,58 @@ Note: The SQLAlchemy model defines an `is_anonymous` 
property for Flask-Login se | query | string | | Yes | | retrieval_model | [RetrievalModel](#retrievalmodel) | | No | +#### HitTestingQuery + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| content | string | | Yes | + +#### HitTestingRecord + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| child_chunks | [ [HitTestingChildChunk](#hittestingchildchunk) ] | | Yes | +| files | [ [HitTestingFile](#hittestingfile) ] | | Yes | +| score | number | | Yes | +| segment | [HitTestingSegment](#hittestingsegment) | | Yes | +| summary | string | | Yes | +| tsne_position | | | Yes | + +#### HitTestingResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| query | [HitTestingQuery](#hittestingquery) | | Yes | +| records | [ [HitTestingRecord](#hittestingrecord) ] | | Yes | + +#### HitTestingSegment + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| answer | string | | Yes | +| completed_at | integer | | Yes | +| content | string | | Yes | +| created_at | integer | | Yes | +| created_by | string | | Yes | +| disabled_at | integer | | Yes | +| disabled_by | string | | Yes | +| document | [HitTestingDocument](#hittestingdocument) | | Yes | +| document_id | string | | Yes | +| enabled | boolean | | Yes | +| error | string | | Yes | +| hit_count | integer | | Yes | +| id | string | | Yes | +| index_node_hash | string | | Yes | +| index_node_id | string | | Yes | +| indexing_at | integer | | Yes | +| keywords | [ string ] | | Yes | +| position | integer | | Yes | +| sign_content | string | | Yes | +| status | string | | Yes | +| stopped_at | integer | | Yes | +| tokens | integer | | Yes | +| word_count | integer | | Yes | + #### HumanInputFormSubmitPayload | Name | Type | Description | Required | diff --git a/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py 
b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py index 823b877707..ac47be7a37 100644 --- a/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py +++ b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py @@ -42,6 +42,9 @@ class MilvusConfig(BaseModel): database: str = "default" # Database name enable_hybrid_search: bool = False # Flag to enable hybrid search analyzer_params: str | None = None # Analyzer params + secure: bool = False # Enable one-way TLS to Milvus + server_pem_path: str | None = None # Path to server certificate (PEM) for TLS verification + server_name: str | None = None # Server name to verify against the certificate (SNI / CN) @model_validator(mode="before") @classmethod @@ -388,16 +391,19 @@ class MilvusVector(BaseVector): """ Initialize and return a Milvus client. """ + kwargs: dict[str, Any] = {"uri": config.uri, "db_name": config.database} if config.token: - client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database) + kwargs["token"] = config.token else: - client = MilvusClient( - uri=config.uri, - user=config.user or "", - password=config.password or "", - db_name=config.database, - ) - return client + kwargs["user"] = config.user or "" + kwargs["password"] = config.password or "" + if config.secure: + kwargs["secure"] = True + if config.server_pem_path: + kwargs["server_pem_path"] = config.server_pem_path + if config.server_name: + kwargs["server_name"] = config.server_name + return MilvusClient(**kwargs) class MilvusVectorFactory(AbstractVectorFactory): @@ -427,5 +433,8 @@ class MilvusVectorFactory(AbstractVectorFactory): database=dify_config.MILVUS_DATABASE or "", enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False, analyzer_params=dify_config.MILVUS_ANALYZER_PARAMS or "", + secure=dify_config.MILVUS_SECURE, + server_pem_path=dify_config.MILVUS_SERVER_PEM_PATH, + server_name=dify_config.MILVUS_SERVER_NAME, ), ) diff --git 
a/api/providers/vdb/vdb-milvus/tests/unit_tests/test_milvus.py b/api/providers/vdb/vdb-milvus/tests/unit_tests/test_milvus.py index 730ff9f296..028842a7d6 100644 --- a/api/providers/vdb/vdb-milvus/tests/unit_tests/test_milvus.py +++ b/api/providers/vdb/vdb-milvus/tests/unit_tests/test_milvus.py @@ -163,6 +163,35 @@ def test_init_client_supports_token_and_user_password(milvus_module): assert user_client.init_kwargs["password"] == "Milvus" +def test_init_client_passes_tls_kwargs_when_secure(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + client = vector._init_client( + milvus_module.MilvusConfig.model_validate( + { + "uri": "https://milvus.example.com:19530", + "token": "abc", + "database": "db", + "secure": True, + "server_pem_path": "/etc/milvus/certs/server.pem", + "server_name": "milvus.example.com", + } + ) + ) + assert client.init_kwargs["secure"] is True + assert client.init_kwargs["server_pem_path"] == "/etc/milvus/certs/server.pem" + assert client.init_kwargs["server_name"] == "milvus.example.com" + + +def test_init_client_omits_tls_kwargs_when_not_secure(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + client = vector._init_client( + milvus_module.MilvusConfig.model_validate({"uri": "http://localhost:19530", "token": "abc", "database": "db"}) + ) + assert "secure" not in client.init_kwargs + assert "server_pem_path" not in client.init_kwargs + assert "server_name" not in client.init_kwargs + + def test_init_loads_fields_when_collection_exists(milvus_module): client = milvus_module.MilvusClient(uri="http://localhost:19530") client.has_collection.return_value = True diff --git a/api/pyproject.toml b/api/pyproject.toml index 1920a9f4de..95f764aef7 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -6,7 +6,7 @@ requires-python = "~=3.12.0" dependencies = [ # Legacy: mature and widely deployed "bleach>=6.3.0,<7.0.0", - "boto3>=1.43.10,<2.0.0", + 
"boto3>=1.43.14,<2.0.0", "celery>=5.6.3,<6.0.0", "croniter>=6.2.2,<7.0.0", "dify-agent", @@ -102,10 +102,7 @@ dify-trace-weave = { workspace = true } [tool.uv] default-groups = ["storage", "tools", "vdb-all", "trace-all"] package = false -override-dependencies = [ - "litellm>=1.83.10,<2.0.0", - "pyarrow>=23.0.1,<24.0.0", -] +override-dependencies = ["litellm>=1.83.10,<2.0.0", "pyarrow>=23.0.1,<24.0.0"] [dependency-groups] diff --git a/api/services/account_service.py b/api/services/account_service.py index 344b3619f2..6705bdc4e6 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1287,6 +1287,34 @@ class TenantService: ).scalar_one_or_none() return row is not None + @staticmethod + def get_account_role_in_tenant( + session: Session | scoped_session, + account_id: uuid.UUID | str | None, + tenant_id: str, + ) -> TenantAccountRole | None: + """Return the caller's role in ``tenant_id``, or ``None`` if not a member. + + Backs ``controllers.openapi.auth.role_gate.require_workspace_role``: + the gate maps ``None`` to 404 (non-member — no cross-tenant ID leak) + and an out-of-set role to 403, so it never touches the ORM itself. + + ``None``/empty ``account_id`` short-circuits to ``None`` so SSO + bearers (no account) collapse to the non-member path. Mirrors the + session-injection style of :meth:`account_belongs_to_tenant` rather + than :meth:`get_user_role`, which loads full ``Account``/``Tenant`` + objects against the Flask-scoped ``db.session``. + """ + if not account_id: + return None + role = session.execute( + select(TenantAccountJoin.role).where( + TenantAccountJoin.tenant_id == tenant_id, + TenantAccountJoin.account_id == account_id, + ) + ).scalar_one_or_none() + return TenantAccountRole(role) if role is not None else None + @staticmethod def get_tenant_by_id(session: Session | scoped_session, tenant_id: str) -> Tenant | None: """Plain ``session.get(Tenant, tenant_id)`` — no status filter. 
diff --git a/api/services/legacy_model_type_migration.py b/api/services/legacy_model_type_migration.py new file mode 100644 index 0000000000..2de5e7f7f3 --- /dev/null +++ b/api/services/legacy_model_type_migration.py @@ -0,0 +1,2464 @@ +""" +Migrate legacy provider-related model_type values to canonical values. + +The grouped tables scan legacy candidates in id order, then reload the full business-key +group inside a transaction before deciding a winner row and the loser rows to delete. +Those grouped flows share the same dry-run/apply handling for group reloads, winner-loser +decisions, row updates, row deletes, and structured logging. Only some grouped tables +also add cache cleanup; that includes `provider_models` and +`provider_model_credentials`. Provider-model-credential groups extend that flow by +rewriting credential references in provider models and load-balancing configs before +removing loser credential rows. `load_balancing_model_configs` stays mostly row-level, +but it first deduplicates `name="__inherit__"` rows by business key before it +canonicalizes the remaining legacy rows independently with row-level cache cleanup. + +Tenant scheduling has two modes. When callers provide an explicit tenant list, the +service preserves the original tenant-scoped execution model and runs all selected tables +for each tenant. When callers omit `tenant_ids`, the service discovers tenant +ids per table and then runs only that table for the discovered tenants. Most +tables keep the active `model_types` filter in the discovery query, while +`load_balancing_model_configs` deliberately uses a whole-table tenant scan so +that query stays easy to understand. 
+""" + +from __future__ import annotations + +import io +import json +import sys +import threading +import traceback +import uuid +from collections.abc import Iterable, Sequence +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import asdict, dataclass +from datetime import datetime +from enum import IntEnum, StrEnum +from typing import Protocol, cast + +import sqlalchemy as sa +from sqlalchemy.exc import OperationalError +from sqlalchemy.orm import Session +from sqlalchemy.sql import select + +from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType +from graphon.model_runtime.entities.model_entities import ModelType +from libs.datetime_utils import naive_utc_now +from models import LoadBalancingModelConfig, ProviderModel, ProviderModelSetting, TenantDefaultModel +from models.base import TypeBase +from models.provider import ProviderModelCredential + +type ORMModel = type[TypeBase] + + +def _json_default(value: object) -> object: + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, (IntEnum, StrEnum)): + return value.value + return value + + +def _normalize_log_value(field_name: str, value: object) -> object: + if field_name == "encrypted_config" and isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError: + return value + return value + + +def _normalize_log_mapping(values: dict[str, object]) -> dict[str, object]: + return {key: _normalize_log_value(key, value) for key, value in values.items()} + + +def _normalize_log_payload(value: object) -> object: + if value is None or isinstance(value, bool | int | float | str): + return value + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, (IntEnum, StrEnum)): + return value.value + if isinstance(value, dict): + return {str(key): _normalize_log_payload(item) for key, item in value.items()} + if isinstance(value, (list, tuple)): + return 
[_normalize_log_payload(item) for item in value] + if isinstance(value, (set, frozenset)): + normalized_items = [_normalize_log_payload(item) for item in value] + return sorted(normalized_items, key=lambda item: json.dumps(item, sort_keys=True)) + + table_name = getattr(value, "__tablename__", None) + if isinstance(table_name, str): + return table_name + + name = getattr(value, "name", None) + if isinstance(name, str): + return name + + table = getattr(value, "table", None) + if table is not None: + referenced_table_name = getattr(table, "name", None) + if isinstance(referenced_table_name, str): + return referenced_table_name + + return f"<{type(value).__module__}.{type(value).__qualname__}>" + + +def _format_exception_stacktrace(exc: BaseException) -> str: + return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) + + +@dataclass(frozen=True, slots=True) +class _RowWithRawModelType[T: TypeBase]: + row: T + raw_model_type: str + canonical_model_type: ModelType + + +@dataclass(frozen=True, slots=True) +class _CacheDeletePlan: + tenant_id: str + identity_id: str + cache_type: ProviderCredentialsCacheType + table_name: str + row_id: str + tx_id: str + business_key: _BusinessKey + + +@dataclass(frozen=True, slots=True) +class _BusinessKey: + """Marker base type for structured migration business keys.""" + + +class _HasRowId(Protocol): + id: object + + +class _HasRowIdAndUpdatedAt(_HasRowId, Protocol): + updated_at: datetime + + +def _normalize_error_code_string(value: object) -> str | None: + if isinstance(value, str): + normalized_value = value.strip().upper() + return normalized_value or None + return None + + +def _normalize_error_code_int(value: object) -> int | None: + if isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, str): + normalized_value = value.strip() + if normalized_value.isdigit(): + return int(normalized_value) + return None + + +@dataclass(frozen=True, slots=True) +class 
_ProviderModelBusinessKey(_BusinessKey): + """unique index: unique_provider_model_name""" + + tenant_id: str + provider_name: str + model_name: str + model_type: ModelType + + +@dataclass(frozen=True, slots=True) +class _TenantDefaultModelBusinessKey(_BusinessKey): + """unique index: unique_tenant_default_model_type""" + + tenant_id: str + model_type: ModelType + + +@dataclass(frozen=True, slots=True) +class _ProviderModelSettingBusinessKey(_BusinessKey): + """Although `ProviderModelSetting` does not have the unique index + (tenant_id, provider_name, model_name, model_type). The actual business logic + relies on this uniqueness property heavily. + """ + + tenant_id: str + provider_name: str + model_name: str + model_type: ModelType + + +@dataclass(frozen=True, slots=True) +class _LoadBalancingModelConfigInheritBusinessKey(_BusinessKey): + """Business key for `name="__inherit__"` load-balancing configs.""" + + tenant_id: str + provider_name: str + model_name: str + model_type: ModelType + + +@dataclass(frozen=True, slots=True) +class _ProviderModelCredentialBusinessKey(_BusinessKey): + """Although `ProviderModelCredential` does not have the unique index + (tenant_id, provider_name, model_name, model_type, credential_name). 
+ The actual business logic implies it.""" + + tenant_id: str + provider_name: str + model_name: str + credential_name: str + model_type: ModelType + + +@dataclass(frozen=True, slots=True) +class _ProviderModelGroupPlan: + group_row_ids: list[str] + winner: _RowWithRawModelType[ProviderModel] | None + loser_rows: list[_RowWithRawModelType[ProviderModel]] + + +@dataclass(frozen=True, slots=True) +class _TenantDefaultModelGroupPlan: + group_row_ids: list[str] + winner: _RowWithRawModelType[TenantDefaultModel] | None + loser_rows: list[_RowWithRawModelType[TenantDefaultModel]] + + +@dataclass(frozen=True, slots=True) +class _ProviderModelSettingGroupPlan: + group_row_ids: list[str] + winner: _RowWithRawModelType[ProviderModelSetting] | None + loser_rows: list[_RowWithRawModelType[ProviderModelSetting]] + + +@dataclass(frozen=True, slots=True) +class _LoadBalancingModelConfigInheritGroupPlan: + group_row_ids: list[str] + winner: _RowWithRawModelType[LoadBalancingModelConfig] | None + loser_rows: list[_RowWithRawModelType[LoadBalancingModelConfig]] + + +@dataclass(frozen=True, slots=True) +class _ProviderModelReferenceRewritePlan: + row_id: str + old_credential_id: str + new_credential_id: str + + +@dataclass(frozen=True, slots=True) +class _LoadBalancingCredentialRewritePlan: + row_id: str + old_credential_id: str | None + old_name: str + old_encrypted_config: str | None + new_credential_id: str + new_name: str + new_encrypted_config: str | None + + +@dataclass(frozen=True, slots=True) +class _ProviderModelCredentialGroupPlan: + group_row_ids: list[str] + winner: _RowWithRawModelType[ProviderModelCredential] | None + loser_rows: list[_RowWithRawModelType[ProviderModelCredential]] + provider_model_rewrites: list[_ProviderModelReferenceRewritePlan] + load_balancing_rewrites: list[_LoadBalancingCredentialRewritePlan] + + +VALID_TABLE_NAMES: tuple[str, ...] 
= ( + ProviderModel.__tablename__, + TenantDefaultModel.__tablename__, + ProviderModelSetting.__tablename__, + LoadBalancingModelConfig.__tablename__, + ProviderModelCredential.__tablename__, +) + +_SUPPORTED_MODEL_TYPES: tuple[ModelType, ...] = ( + ModelType.LLM, + ModelType.TEXT_EMBEDDING, + ModelType.RERANK, +) +_CANONICAL_TO_LEGACY: dict[ModelType, tuple[str, ...]] = { + ModelType.LLM: ("text-generation",), + ModelType.TEXT_EMBEDDING: ("embeddings",), + ModelType.RERANK: ("reranking",), +} +_LEGACY_TO_CANONICAL: dict[str, ModelType] = { + legacy_value: canonical_model_type + for canonical_model_type, legacy_values in _CANONICAL_TO_LEGACY.items() + for legacy_value in legacy_values +} +_POSTGRES_LOCK_TIMEOUT_SQLSTATES: frozenset[str] = frozenset({"55P03"}) +_MYSQL_LOCK_TIMEOUT_ERRNOS: frozenset[int] = frozenset({1205}) +_LOCK_TIMEOUT_FALLBACK_MESSAGES: tuple[str, ...] = ( + "canceling statement due to lock timeout", + "lock wait timeout exceeded", +) +_RAW_MODEL_TYPE_COLUMN = "_raw_model_type" + + +def _selected_legacy_values(model_types: Sequence[ModelType]) -> list[str]: + legacy_values: list[str] = [] + for model_type in model_types: + legacy_values.extend(_CANONICAL_TO_LEGACY[model_type]) + return legacy_values + + +def _selected_model_type_values(model_types: Sequence[ModelType]) -> list[str]: + model_type_values: list[str] = [] + for model_type in model_types: + model_type_values.append(model_type.value) + model_type_values.extend(_CANONICAL_TO_LEGACY[model_type]) + return list(dict.fromkeys(model_type_values)) + + +def _session_factory(engine: sa.Engine) -> Session: + return Session(bind=engine, expire_on_commit=False) + + +class _ThreadSafeLineWriter(io.TextIOBase): + """ + Serialize line-oriented writes to a shared text stream across tenant workers. + + `Migration._log_event` writes one JSON document per `print(..., flush=True)` call. 
The + wrapper buffers fragments per thread until a newline arrives, then emits the full line + while holding a process-local lock so concurrent tenants cannot interleave bytes. + """ + + _stream: io.TextIOBase + _lock: threading.Lock + _local: threading.local + + def __init__(self, stream: io.TextIOBase) -> None: + super().__init__() + self._stream = stream + self._lock = threading.Lock() + self._local = threading.local() + + def writable(self) -> bool: + return True + + def write(self, text: str) -> int: + if not text: + return 0 + + buffered_text = self._buffer + text + lines = buffered_text.splitlines(keepends=True) + remainder = "" + if lines and not lines[-1].endswith(("\n", "\r")): + remainder = lines.pop() + + for line in lines: + self._write_line(line) + + self._buffer = remainder + return len(text) + + def flush(self) -> None: + buffered_text = self._buffer + if buffered_text: + self._write_line(buffered_text) + self._buffer = "" + + with self._lock: + self._stream.flush() + + @property + def _buffer(self) -> str: + return cast(str, getattr(self._local, "buffer", "")) + + @_buffer.setter + def _buffer(self, value: str) -> None: + self._local.buffer = value + + def _write_line(self, text: str) -> None: + with self._lock: + self._stream.write(text) + + +class LegacyModelTypeMigrationService: + """ + Migrate legacy provider-related model_type values to canonical values. + + The command can scope the migration by table, tenant, and canonical model type. When + `provider_model_credentials` is selected, that migration also rewrites references in + `provider_models` and `load_balancing_model_configs`. Tenant migrations can run in a + thread pool; JSONL output remains line-safe through a shared synchronized writer. + + If `tenant_ids` is omitted, tenant discovery becomes table-scoped: each selected ORM + model loads its own tenant ids, then only that table is dispatched for those tenants. 
+ Most tables keep the active model-type filter in discovery, while + `load_balancing_model_configs` intentionally uses the whole table so the tenant query + stays simple. This still avoids merging tenant ids across unrelated tables. + """ + + _engine: sa.Engine + _apply: bool + _concurrency: int + _output: io.TextIOBase + _model_types: tuple[ModelType, ...] + _orm_models: tuple[ORMModel, ...] + _tenant_ids: tuple[str, ...] | None + + def __init__( + self, + engine: sa.Engine, + *, + apply: bool = False, + concurrency: int = 1, + output: io.TextIOBase | None = None, + tables: Sequence[str] | None = None, + model_types: Sequence[ModelType] = _SUPPORTED_MODEL_TYPES, + tenant_ids: Sequence[str] | None = None, + ) -> None: + if concurrency < 1: + raise ValueError("concurrency must be greater than or equal to 1") + + self._engine = engine + self._apply = apply + self._concurrency = concurrency + self._output = cast(io.TextIOBase, sys.stdout if output is None else output) + self._model_types = tuple(dict.fromkeys(model_types)) + self._orm_models = self._resolve_models(tables) + self._tenant_ids = tuple(dict.fromkeys(tenant_ids)) if tenant_ids is not None else None + + def _resolve_models(self, tables: Sequence[str] | None) -> tuple[ORMModel, ...]: + if tables is None: + return ( + ProviderModel, + TenantDefaultModel, + ProviderModelSetting, + LoadBalancingModelConfig, + ProviderModelCredential, + ) + + ordered_models: list[ORMModel] = [] + seen_tables: set[str] = set() + for table_name in tables: + if table_name in seen_tables: + continue + seen_tables.add(table_name) + if table_name == ProviderModel.__tablename__: + ordered_models.append(ProviderModel) + elif table_name == TenantDefaultModel.__tablename__: + ordered_models.append(TenantDefaultModel) + elif table_name == ProviderModelSetting.__tablename__: + ordered_models.append(ProviderModelSetting) + elif table_name == LoadBalancingModelConfig.__tablename__: + ordered_models.append(LoadBalancingModelConfig) + elif 
def _run_migrations_for_tenants(
    self,
    tenant_ids: Sequence[str],
    orm_models: Sequence[ORMModel],
    output: io.TextIOBase,
) -> None:
    """
    Run the tenant migration for every tenant id, sequentially or in a thread pool.

    The sequential path is used when concurrency is 1 or when there is at most one
    tenant. Otherwise one task per tenant is submitted to a pool capped at
    `min(concurrency, len(tenant_ids))` workers.

    Args:
        tenant_ids: Tenants to migrate; an empty sequence is a no-op.
        orm_models: Tables to migrate for each tenant.
        output: Shared output stream handed to every tenant migration.

    Raises:
        Exception: Whatever a tenant migration raised; in the pooled path the
            error is re-raised from the first failed future that completes.
    """
    # `<= 1` also covers the empty sequence: ThreadPoolExecutor rejects
    # max_workers=0 with ValueError, so an empty input must never reach the pool.
    if self._concurrency == 1 or len(tenant_ids) <= 1:
        for tenant_id in tenant_ids:
            self._run_tenant_migration(tenant_id, orm_models, output)
        return

    with ThreadPoolExecutor(max_workers=min(self._concurrency, len(tenant_ids))) as executor:
        futures = [
            executor.submit(self._run_tenant_migration, tenant_id, orm_models, output) for tenant_id in tenant_ids
        ]
        for future in as_completed(futures):
            # result() re-raises the worker's exception in the calling thread.
            future.result()
def _load_tenant_ids_for_model(self, orm_model: ORMModel) -> tuple[str, ...]:
    """
    Return the distinct tenant ids, in ascending order, that hold rows in one table.

    For `provider_models`, `tenant_default_models`, `provider_model_settings` and
    `provider_model_credentials` only rows whose raw `model_type` is one of the
    selected legacy values are considered. `load_balancing_model_configs` is
    queried without a model-type filter, so its tenant set is intentionally wider.

    Raises:
        ValueError: If `orm_model` is not one of the five supported ORM models.
    """
    models_filtered_by_legacy_type = (
        ProviderModel,
        TenantDefaultModel,
        ProviderModelSetting,
        ProviderModelCredential,
    )
    legacy_values = _selected_legacy_values(self._model_types)

    with _session_factory(self._engine) as session:
        if orm_model is LoadBalancingModelConfig:
            # Whole-table discovery: deliberately no model_type predicate here.
            query = select(LoadBalancingModelConfig.tenant_id)
        elif any(orm_model is known for known in models_filtered_by_legacy_type):
            # type_coerce compares the stored string, bypassing the column's own type.
            query = select(orm_model.tenant_id).where(
                sa.type_coerce(orm_model.model_type, sa.String()).in_(legacy_values)
            )
        else:
            raise ValueError(f"unsupported orm model: {orm_model}")

        query = query.distinct().order_by(orm_model.tenant_id.asc())
        discovered = session.execute(query).scalars().all()

    return tuple(discovered)
def run(self) -> None:
    """
    Migrate every selected table for this tenant, bracketed by start/end events.

    Emits `tenant_started`, then calls the table-specific migration for each ORM
    model in `_orm_models` in order, then emits `tenant_completed`. Models that
    match none of the five known tables are skipped silently.
    """

    def _handler_for(orm_model):
        # Identity comparison on purpose: the ORM classes themselves are the keys.
        if orm_model is ProviderModel:
            return self._migrate_provider_models
        if orm_model is TenantDefaultModel:
            return self._migrate_tenant_default_models
        if orm_model is ProviderModelSetting:
            return self._migrate_provider_model_settings
        if orm_model is LoadBalancingModelConfig:
            return self._migrate_load_balancing_model_configs
        if orm_model is ProviderModelCredential:
            return self._migrate_provider_model_credentials
        return None

    table_names = [model.__tablename__ for model in self._orm_models]
    model_type_values = [model_type.value for model_type in self._model_types]
    self._log_event(
        "tenant_started",
        "Started tenant migration.",
        {
            "tenant_id": self._tenant_id,
            "apply": self._apply,
            "tables": table_names,
            "model_types": model_type_values,
        },
    )

    for orm_model in self._orm_models:
        migrate_table = _handler_for(orm_model)
        if migrate_table is not None:
            migrate_table()

    self._log_event(
        "tenant_completed",
        "Completed tenant migration.",
        {"tenant_id": self._tenant_id, "apply": self._apply},
    )
def _migrate_provider_models(self) -> None:
    """
    Migrate this tenant's `provider_models` rows, one business-key group at a time.

    Candidates are paged by ascending id. The first candidate seen for a business
    key (tenant, provider, model name, canonical model type) triggers processing of
    its whole group; later candidates with the same key are skipped. The number of
    processed groups is reported in the `table_completed` event.
    """
    table_name = ProviderModel.__tablename__
    self._log_event(
        "table_started",
        "Started table migration.",
        {"tenant_id": self._tenant_id, "apply": self._apply, "table_name": table_name},
    )

    handled_groups: dict[_ProviderModelBusinessKey, list[str]] = {}
    cursor: str | None = None

    while batch := self._load_provider_model_candidates(cursor):
        for candidate in batch:
            provider_model = candidate.row
            # Advance the keyset cursor for every row, including skipped ones.
            cursor = str(provider_model.id)
            group_key = _ProviderModelBusinessKey(
                tenant_id=provider_model.tenant_id,
                provider_name=provider_model.provider_name,
                model_name=provider_model.model_name,
                model_type=candidate.canonical_model_type,
            )
            if group_key not in handled_groups:
                handled_groups[group_key] = self._process_provider_model_group(candidate, group_key)

    self._log_event(
        "table_completed",
        "Completed table migration.",
        {
            "tenant_id": self._tenant_id,
            "apply": self._apply,
            "table_name": table_name,
            # Exactly one entry is stored per processed group.
            "processed_groups": len(handled_groups),
        },
    )
def _load_provider_model_group(
    self,
    session: Session,
    candidate: _RowWithRawModelType[ProviderModel],
    *,
    lock_rows: bool,
) -> list[_RowWithRawModelType[ProviderModel]]:
    """
    Load every `provider_models` row sharing the candidate's business key.

    The group covers rows with the same tenant, provider name and model name whose
    raw `model_type` is the candidate's canonical value or one of its legacy
    aliases, ordered by ascending id.

    Args:
        session: Session used to run the query.
        candidate: Row that identifies the group.
        lock_rows: When true, the select is issued with `FOR UPDATE`.

    Returns:
        The group rows paired with their raw model-type string. Raw values that are
        not legacy aliases take the candidate's canonical model type.
    """
    anchor = candidate.row
    fallback_model_type = candidate.canonical_model_type
    accepted_values = self._allowed_values_for_canonical_model_type(fallback_model_type)

    query = (
        select(
            ProviderModel,
            sa.type_coerce(ProviderModel.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN),
        )
        .where(
            ProviderModel.tenant_id == anchor.tenant_id,
            ProviderModel.provider_name == anchor.provider_name,
            ProviderModel.model_name == anchor.model_name,
            sa.type_coerce(ProviderModel.model_type, sa.String()).in_(accepted_values),
        )
        .order_by(ProviderModel.id.asc())
    )
    if lock_rows:
        query = query.with_for_update()

    group: list[_RowWithRawModelType[ProviderModel]] = []
    for record, raw_value in session.execute(query).all():
        raw_text = str(raw_value)
        group.append(
            _RowWithRawModelType(
                row=record,
                raw_model_type=raw_text,
                canonical_model_type=_LEGACY_TO_CANONICAL.get(raw_text, fallback_model_type),
            )
        )
    return group
def _process_provider_model_group(
    self,
    candidate: _RowWithRawModelType[ProviderModel],
    business_key: _ProviderModelBusinessKey,
) -> list[str]:
    """
    Plan and emit one `provider_models` group inside its own transaction.

    The group rows are loaded with row locks after the lock timeout is configured.
    A lock-timeout `OperationalError` is logged and the group is given up; any
    other `OperationalError` propagates.

    Returns:
        The ids of the rows in the group, or just the candidate's id when the
        group could not be loaded.
    """
    tx_id = self._new_tx_id()
    candidate_id = str(candidate.row.id)
    row_ids = [candidate_id]

    try:
        with _session_factory(self._engine) as session:
            with session.begin():
                self._configure_lock_timeout(session)
                plan = self._build_provider_model_group_plan(session, candidate, lock_rows=True)
                if plan.group_row_ids:
                    row_ids = plan.group_row_ids
                self._emit_provider_model_group_plan(
                    plan,
                    session=session,
                    tx_id=tx_id,
                    business_key=business_key,
                )
    except OperationalError as exc:
        if not self._is_lock_timeout_error(exc):
            raise
        # Lock contention is expected on a live table: record it and move on.
        self._log_lock_timeout(
            ProviderModel.__tablename__,
            candidate_id,
            tx_id,
            business_key,
            exc,
        )

    return row_ids
def _load_tenant_default_model_candidates(
    self, last_id: str | None
) -> list[_RowWithRawModelType[TenantDefaultModel]]:
    """
    Fetch the next page of this tenant's `tenant_default_models` legacy rows.

    Args:
        last_id: Id of the last row already seen; only rows with a greater id are
            returned. `None` starts from the beginning.

    Returns:
        At most `_batch_size` rows in ascending id order, each paired with its raw
        and canonical model type. Rows whose raw value has no canonical mapping
        are reported through an `invalid_model_type` event and left out.
    """
    with _session_factory(self._engine) as session:
        filters = [
            TenantDefaultModel.tenant_id == self._tenant_id,
            sa.type_coerce(TenantDefaultModel.model_type, sa.String()).in_(self._selected_legacy_values()),
        ]
        if last_id is not None:
            # Keyset pagination: resume strictly after the previous page.
            filters.append(TenantDefaultModel.id > last_id)

        page_query = (
            select(
                TenantDefaultModel,
                sa.type_coerce(TenantDefaultModel.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN),
            )
            .where(*filters)
            .order_by(TenantDefaultModel.id.asc())
            .limit(self._batch_size)
        )
        fetched = session.execute(page_query).all()

    candidates: list[_RowWithRawModelType[TenantDefaultModel]] = []
    for record, raw_value in fetched:
        raw_text = str(raw_value)
        canonical = _LEGACY_TO_CANONICAL.get(raw_text)
        if canonical is not None:
            candidates.append(
                _RowWithRawModelType(
                    row=record,
                    raw_model_type=raw_text,
                    canonical_model_type=canonical,
                )
            )
            continue
        self._log_event(
            event="invalid_model_type",
            message=f"invalid model type: {raw_value}",
            attrs={"id": record.id, "table_name": record.__tablename__},
        )
    return candidates
def _migrate_provider_model_settings(self) -> None:
    """
    Migrate this tenant's `provider_model_settings` rows group by group.

    Pages through legacy candidates by ascending id and processes each business key
    (tenant, provider, model name, canonical model type) once, on the first
    candidate that carries it. Emits `table_started` and `table_completed` events,
    the latter with the number of processed groups.
    """
    table_name = ProviderModelSetting.__tablename__
    self._log_event(
        "table_started",
        "Started table migration.",
        {"tenant_id": self._tenant_id, "apply": self._apply, "table_name": table_name},
    )

    handled_groups: dict[_ProviderModelSettingBusinessKey, list[str]] = {}
    cursor: str | None = None

    batch = self._load_provider_model_setting_candidates(cursor)
    while batch:
        for candidate in batch:
            setting = candidate.row
            cursor = str(setting.id)
            group_key = _ProviderModelSettingBusinessKey(
                tenant_id=setting.tenant_id,
                provider_name=setting.provider_name,
                model_name=setting.model_name,
                model_type=candidate.canonical_model_type,
            )
            if group_key in handled_groups:
                # The whole group was handled when its first candidate appeared.
                continue
            handled_groups[group_key] = self._process_provider_model_setting_group(candidate, group_key)
        batch = self._load_provider_model_setting_candidates(cursor)

    self._log_event(
        "table_completed",
        "Completed table migration.",
        {
            "tenant_id": self._tenant_id,
            "apply": self._apply,
            "table_name": table_name,
            "processed_groups": len(handled_groups),
        },
    )
def _build_provider_model_setting_group_plan(
    self,
    session: Session,
    candidate: _RowWithRawModelType[ProviderModelSetting],
    *,
    lock_rows: bool,
) -> _ProviderModelSettingGroupPlan:
    """
    Decide which row of a `provider_model_settings` group survives.

    Args:
        session: Session used to load the group.
        candidate: Row that identifies the group.
        lock_rows: Forwarded to the group loader.

    Returns:
        A plan listing every group row id. When the group holds no legacy row the
        plan has no winner and no losers. Otherwise the winner is the row chosen
        by `_select_winner` and all other rows are losers.
    """
    members = self._load_provider_model_setting_group(session, candidate, lock_rows=lock_rows)
    member_ids = [str(member.row.id) for member in members]

    needs_migration = self._has_legacy_rows(members)
    survivor = self._select_winner(members) if needs_migration else None
    discarded = (
        [member for member in members if member.row.id != survivor.row.id] if needs_migration else []
    )

    return _ProviderModelSettingGroupPlan(
        group_row_ids=member_ids,
        winner=survivor,
        loser_rows=discarded,
    )
def _migrate_load_balancing_model_configs(self) -> None:
    """
    Migrate this tenant's load-balancing configs in two passes.

    Pass one deduplicates `name="__inherit__"` rows per normalized
    `(tenant_id, provider_name, model_name, model_type)` business key. Pass two
    pages through the remaining candidates by ascending id and processes each row
    on its own. The deduplication has to come first: otherwise a legacy/canonical
    `__inherit__` pair would be canonicalized onto the same key instead of being
    reduced to its newest row.
    """
    scope = {
        "tenant_id": self._tenant_id,
        "apply": self._apply,
        "table_name": LoadBalancingModelConfig.__tablename__,
    }
    self._log_event("table_started", "Started table migration.", {**scope})

    inherit_group_count = self._deduplicate_inherit_load_balancing_model_configs()

    row_count = 0
    cursor: str | None = None
    while batch := self._load_load_balancing_model_config_candidates(cursor):
        for candidate in batch:
            cursor = str(candidate.row.id)
            row_count += 1
            self._process_load_balancing_model_config_row(candidate)

    self._log_event(
        "table_completed",
        "Completed table migration.",
        {
            **scope,
            "processed_inherit_groups": inherit_group_count,
            "processed_rows": row_count,
        },
    )
def _load_load_balancing_inherit_candidates(
    self, last_id: str | None
) -> list[_RowWithRawModelType[LoadBalancingModelConfig]]:
    """
    Fetch the next page of this tenant's `__inherit__` load-balancing rows.

    Unlike the legacy-only candidate loaders, the model-type filter here uses
    `_selected_model_type_values()`.

    Args:
        last_id: Id of the last row already seen; only rows with a greater id are
            returned. `None` starts from the beginning.

    Returns:
        At most `_batch_size` rows in ascending id order with their raw and
        normalized model type. Rows whose raw value cannot be normalized are
        reported through an `invalid_model_type` event and left out.
    """
    with _session_factory(self._engine) as session:
        filters = [
            LoadBalancingModelConfig.tenant_id == self._tenant_id,
            LoadBalancingModelConfig.name == "__inherit__",
            sa.type_coerce(LoadBalancingModelConfig.model_type, sa.String()).in_(
                self._selected_model_type_values()
            ),
        ]
        if last_id is not None:
            # Keyset pagination: resume strictly after the previous page.
            filters.append(LoadBalancingModelConfig.id > last_id)

        page_query = (
            select(
                LoadBalancingModelConfig,
                sa.type_coerce(LoadBalancingModelConfig.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN),
            )
            .where(*filters)
            .order_by(LoadBalancingModelConfig.id.asc())
            .limit(self._batch_size)
        )
        fetched = session.execute(page_query).all()

    candidates: list[_RowWithRawModelType[LoadBalancingModelConfig]] = []
    for record, raw_value in fetched:
        raw_text = str(raw_value)
        normalized = self._normalize_selected_model_type(raw_text)
        if normalized is not None:
            candidates.append(
                _RowWithRawModelType(
                    row=record,
                    raw_model_type=raw_text,
                    canonical_model_type=normalized,
                )
            )
            continue
        self._log_event(
            event="invalid_model_type",
            message=f"invalid model type: {raw_value}",
            attrs={
                "id": record.id,
                "table_name": record.__tablename__,
            },
        )
    return candidates
row in rows] + if len(rows) <= 1: + return _LoadBalancingModelConfigInheritGroupPlan(group_row_ids=group_row_ids, winner=None, loser_rows=[]) + + winner = self._select_winner(rows) + return _LoadBalancingModelConfigInheritGroupPlan( + group_row_ids=group_row_ids, + winner=winner, + loser_rows=[row for row in rows if row.row.id != winner.row.id], + ) + + def _emit_load_balancing_inherit_group_plan( + self, + plan: _LoadBalancingModelConfigInheritGroupPlan, + *, + session: Session, + tx_id: str, + business_key: _LoadBalancingModelConfigInheritBusinessKey, + ) -> None: + if plan.winner is None: + return + + cache_plans: list[_CacheDeletePlan] = [] + for loser in plan.loser_rows: + loser_row_id = str(loser.row.id) + if self._apply: + session.execute(sa.delete(LoadBalancingModelConfig).where(LoadBalancingModelConfig.id == loser_row_id)) + self._log_row_deleted( + LoadBalancingModelConfig.__tablename__, + loser, + tx_id=tx_id, + business_key=business_key, + related_winner_id=str(plan.winner.row.id), + ) + cache_plans.append( + _CacheDeletePlan( + tenant_id=self._tenant_id, + identity_id=loser_row_id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, + table_name=LoadBalancingModelConfig.__tablename__, + row_id=loser_row_id, + tx_id=tx_id, + business_key=business_key, + ) + ) + + self._log_cache_plans(cache_plans, apply=self._apply) + self._log_group_processed( + LoadBalancingModelConfig.__tablename__, + business_key, + plan.group_row_ids, + tx_id=tx_id, + ) + + def _process_load_balancing_inherit_group( + self, + candidate: _RowWithRawModelType[LoadBalancingModelConfig], + business_key: _LoadBalancingModelConfigInheritBusinessKey, + ) -> list[str]: + tx_id = self._new_tx_id() + group_row_ids = [str(candidate.row.id)] + + try: + with _session_factory(self._engine) as session, session.begin(): + self._configure_lock_timeout(session) + plan = self._build_load_balancing_inherit_group_plan(session, candidate, lock_rows=True) + group_row_ids = plan.group_row_ids 
or group_row_ids + self._emit_load_balancing_inherit_group_plan( + plan, + session=session, + tx_id=tx_id, + business_key=business_key, + ) + except OperationalError as exc: + if self._is_lock_timeout_error(exc): + self._log_lock_timeout( + LoadBalancingModelConfig.__tablename__, + str(candidate.row.id), + tx_id, + business_key, + exc, + ) + return group_row_ids + raise + + return group_row_ids + + def _load_load_balancing_model_config_candidates( + self, last_id: str | None + ) -> list[_RowWithRawModelType[LoadBalancingModelConfig]]: + raw_model_type = sa.type_coerce(LoadBalancingModelConfig.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN) + with _session_factory(self._engine) as session: + stmt = ( + select(LoadBalancingModelConfig, raw_model_type) + .where( + LoadBalancingModelConfig.tenant_id == self._tenant_id, + sa.type_coerce(LoadBalancingModelConfig.model_type, sa.String()).in_( + self._selected_legacy_values() + ), + ) + .order_by(LoadBalancingModelConfig.id.asc()) + .limit(self._batch_size) + ) + if last_id is not None: + stmt = stmt.where(LoadBalancingModelConfig.id > last_id) + rows = session.execute(stmt).all() + + wrapped_rows: list[_RowWithRawModelType[LoadBalancingModelConfig]] = [] + for load_balancing_model_config, raw_value in rows: + canonical_model_type = _LEGACY_TO_CANONICAL.get(str(raw_value)) + if canonical_model_type is None: + self._log_event( + event="invalid_model_type", + message=f"invalid model type: {raw_value}", + attrs={ + "id": load_balancing_model_config.id, + "table_name": load_balancing_model_config.__tablename__, + }, + ) + continue + wrapped_rows.append( + _RowWithRawModelType( + row=load_balancing_model_config, + raw_model_type=str(raw_value), + canonical_model_type=canonical_model_type, + ) + ) + return wrapped_rows + + def _reload_load_balancing_model_config_candidate( + self, + session: Session, + candidate: _RowWithRawModelType[LoadBalancingModelConfig], + *, + lock_rows: bool, + ) -> 
_RowWithRawModelType[LoadBalancingModelConfig] | None: + raw_model_type = sa.type_coerce(LoadBalancingModelConfig.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN) + stmt = select(LoadBalancingModelConfig, raw_model_type).where( + LoadBalancingModelConfig.id == candidate.row.id, + LoadBalancingModelConfig.tenant_id == self._tenant_id, + ) + if lock_rows: + stmt = stmt.with_for_update() + + row = session.execute(stmt).first() + if row is None: + return None + + load_balancing_model_config, raw_value = row + raw_model_type_value = str(raw_value) + canonical_model_type = _LEGACY_TO_CANONICAL.get(raw_model_type_value) + if canonical_model_type is None: + return None + + return _RowWithRawModelType( + row=load_balancing_model_config, + raw_model_type=raw_model_type_value, + canonical_model_type=canonical_model_type, + ) + + def _log_load_balancing_model_config_cache_cleanup( + self, + *, + row_id: str, + tx_id: str, + ) -> None: + attrs = { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": LoadBalancingModelConfig.__tablename__, + "id": row_id, + "cache_type": ProviderCredentialsCacheType.LOAD_BALANCING_MODEL.value, + "tx_id": tx_id, + } + if not self._apply: + self._log_event( + "cache_delete_planned", + "Would delete related cache entry in apply mode.", + attrs, + ) + return + + try: + ProviderCredentialsCache( + tenant_id=self._tenant_id, + identity_id=row_id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, + ).delete() + self._log_event("cache_deleted", "Deleted related cache entry.", attrs) + except Exception as exc: + self._log_exception_event( + "cache_delete_failed", + "Failed to delete related cache entry.", + attrs, + exc, + ) + + def _process_load_balancing_model_config_row( + self, candidate: _RowWithRawModelType[LoadBalancingModelConfig] + ) -> None: + tx_id = self._new_tx_id() + processed_row_id: str | None = None + + try: + with _session_factory(self._engine) as session, session.begin(): + 
self._configure_lock_timeout(session) + current_row = self._reload_load_balancing_model_config_candidate(session, candidate, lock_rows=True) + if current_row is None: + return + processed_row_id = str(current_row.row.id) + + if self._apply: + session.execute( + sa.update(LoadBalancingModelConfig) + .where(LoadBalancingModelConfig.id == processed_row_id) + .values(model_type=current_row.canonical_model_type.value) + ) + self._log_row_updated( + LoadBalancingModelConfig.__tablename__, + processed_row_id, + {"model_type": current_row.raw_model_type}, + {"model_type": current_row.canonical_model_type.value}, + tx_id=tx_id, + ) + except OperationalError as exc: + if self._is_lock_timeout_error(exc): + self._log_lock_timeout( + LoadBalancingModelConfig.__tablename__, + str(candidate.row.id), + tx_id, + None, + exc, + ) + return + raise + + if processed_row_id is not None: + self._log_load_balancing_model_config_cache_cleanup(row_id=processed_row_id, tx_id=tx_id) + + def _migrate_provider_model_credentials(self) -> None: + self._log_event( + "table_started", + "Started table migration.", + { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": ProviderModelCredential.__tablename__, + }, + ) + + seen_business_keys: dict[_ProviderModelCredentialBusinessKey, list[str]] = {} + processed_groups = 0 + last_id: str | None = None + + while True: + candidates = self._load_provider_model_credential_candidates(last_id) + if not candidates: + break + + for candidate in candidates: + last_id = str(candidate.row.id) + business_key = _ProviderModelCredentialBusinessKey( + tenant_id=candidate.row.tenant_id, + provider_name=candidate.row.provider_name, + model_name=candidate.row.model_name, + credential_name=candidate.row.credential_name, + model_type=candidate.canonical_model_type, + ) + if business_key in seen_business_keys: + continue + + seen_business_keys[business_key] = self._process_provider_model_credential_group( + candidate, + business_key, + ) + 
processed_groups += 1 + + self._log_event( + "table_completed", + "Completed table migration.", + { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": ProviderModelCredential.__tablename__, + "processed_groups": processed_groups, + }, + ) + + def _load_provider_model_credential_candidates( + self, last_id: str | None + ) -> list[_RowWithRawModelType[ProviderModelCredential]]: + raw_model_type = sa.type_coerce(ProviderModelCredential.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN) + with _session_factory(self._engine) as session: + stmt = ( + select(ProviderModelCredential, raw_model_type) + .where( + ProviderModelCredential.tenant_id == self._tenant_id, + sa.type_coerce(ProviderModelCredential.model_type, sa.String()).in_(self._selected_legacy_values()), + ) + .order_by(ProviderModelCredential.id.asc()) + .limit(self._batch_size) + ) + if last_id is not None: + stmt = stmt.where(ProviderModelCredential.id > last_id) + rows = session.execute(stmt).all() + + wrapped_rows: list[_RowWithRawModelType[ProviderModelCredential]] = [] + for provider_model_credential, raw_value in rows: + canonical_model_type = _LEGACY_TO_CANONICAL.get(str(raw_value)) + if canonical_model_type is None: + self._log_event( + event="invalid_model_type", + message=f"invalid model type: {raw_value}", + attrs={"id": provider_model_credential.id, "table_name": provider_model_credential.__tablename__}, + ) + continue + wrapped_rows.append( + _RowWithRawModelType( + row=provider_model_credential, + raw_model_type=str(raw_value), + canonical_model_type=canonical_model_type, + ) + ) + return wrapped_rows + + def _load_provider_model_credential_group( + self, + session: Session, + candidate: _RowWithRawModelType[ProviderModelCredential], + *, + lock_rows: bool, + ) -> list[_RowWithRawModelType[ProviderModelCredential]]: + raw_model_type = sa.type_coerce(ProviderModelCredential.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN) + stmt = ( + select(ProviderModelCredential, 
raw_model_type) + .where( + ProviderModelCredential.tenant_id == candidate.row.tenant_id, + ProviderModelCredential.provider_name == candidate.row.provider_name, + ProviderModelCredential.model_name == candidate.row.model_name, + ProviderModelCredential.credential_name == candidate.row.credential_name, + sa.type_coerce(ProviderModelCredential.model_type, sa.String()).in_( + self._allowed_values_for_canonical_model_type(candidate.canonical_model_type) + ), + ) + .order_by(ProviderModelCredential.id.asc()) + ) + if lock_rows: + stmt = stmt.with_for_update() + + rows = session.execute(stmt).all() + wrapped_rows: list[_RowWithRawModelType[ProviderModelCredential]] = [] + for provider_model_credential, raw_value in rows: + raw_model_type_value = str(raw_value) + wrapped_rows.append( + _RowWithRawModelType( + row=provider_model_credential, + raw_model_type=raw_model_type_value, + canonical_model_type=_LEGACY_TO_CANONICAL.get( + raw_model_type_value, + candidate.canonical_model_type, + ), + ) + ) + return wrapped_rows + + def _build_provider_model_credential_group_plan( + self, + session: Session, + candidate: _RowWithRawModelType[ProviderModelCredential], + *, + lock_rows: bool, + ) -> _ProviderModelCredentialGroupPlan: + rows = self._load_provider_model_credential_group(session, candidate, lock_rows=lock_rows) + group_row_ids = [str(row.row.id) for row in rows] + if not self._has_legacy_rows(rows): + return _ProviderModelCredentialGroupPlan( + group_row_ids=group_row_ids, + winner=None, + loser_rows=[], + provider_model_rewrites=[], + load_balancing_rewrites=[], + ) + + winner = self._select_winner(rows) + loser_rows = [row for row in rows if row.row.id != winner.row.id] + return _ProviderModelCredentialGroupPlan( + group_row_ids=group_row_ids, + winner=winner, + loser_rows=loser_rows, + provider_model_rewrites=self._plan_provider_model_reference_rewrites( + session, + winner, + loser_rows, + lock_rows=lock_rows, + ), + 
load_balancing_rewrites=self._plan_load_balancing_reference_rewrites( + session, + winner, + loser_rows, + lock_rows=lock_rows, + ), + ) + + def _emit_provider_model_reference_rewrites( + self, + session: Session, + rewrites: Sequence[_ProviderModelReferenceRewritePlan], + *, + winner_credential_id: str, + loser_credential_ids: Sequence[str], + tx_id: str, + business_key: _BusinessKey, + ) -> list[_CacheDeletePlan]: + cache_plans: list[_CacheDeletePlan] = [] + for rewrite in rewrites: + if self._apply: + session.execute( + sa.update(ProviderModel) + .where(ProviderModel.id == rewrite.row_id) + .values(credential_id=rewrite.new_credential_id) + ) + self._log_row_updated( + ProviderModel.__tablename__, + rewrite.row_id, + {"credential_id": rewrite.old_credential_id}, + {"credential_id": rewrite.new_credential_id}, + tx_id=tx_id, + business_key=business_key, + rewrite_source={ + "rewrite_kind": "credential_reference", + "winner_credential_id": winner_credential_id, + "loser_credential_ids": list(loser_credential_ids), + }, + ) + + cache_plans.append( + _CacheDeletePlan( + tenant_id=self._tenant_id, + identity_id=rewrite.row_id, + cache_type=ProviderCredentialsCacheType.MODEL, + table_name=ProviderModel.__tablename__, + row_id=rewrite.row_id, + tx_id=tx_id, + business_key=business_key, + ) + ) + return cache_plans + + def _emit_load_balancing_reference_rewrites( + self, + session: Session, + rewrites: Sequence[_LoadBalancingCredentialRewritePlan], + *, + winner_credential_id: str, + loser_credential_ids: Sequence[str], + tx_id: str, + business_key: _BusinessKey, + ) -> list[_CacheDeletePlan]: + cache_plans: list[_CacheDeletePlan] = [] + for rewrite in rewrites: + if self._apply: + session.execute( + sa.update(LoadBalancingModelConfig) + .where(LoadBalancingModelConfig.id == rewrite.row_id) + .values( + credential_id=rewrite.new_credential_id, + name=rewrite.new_name, + encrypted_config=rewrite.new_encrypted_config, + ) + ) + + self._log_row_updated( + 
LoadBalancingModelConfig.__tablename__, + rewrite.row_id, + { + "credential_id": rewrite.old_credential_id, + "encrypted_config": rewrite.old_encrypted_config, + "name": rewrite.old_name, + }, + { + "credential_id": rewrite.new_credential_id, + "encrypted_config": rewrite.new_encrypted_config, + "name": rewrite.new_name, + }, + tx_id=tx_id, + business_key=business_key, + rewrite_source={ + "rewrite_kind": "credential_reference", + "winner_credential_id": winner_credential_id, + "loser_credential_ids": list(loser_credential_ids), + }, + ) + cache_plans.append( + _CacheDeletePlan( + tenant_id=self._tenant_id, + identity_id=rewrite.row_id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, + table_name=LoadBalancingModelConfig.__tablename__, + row_id=rewrite.row_id, + tx_id=tx_id, + business_key=business_key, + ) + ) + return cache_plans + + def _emit_provider_model_credential_group_plan( + self, + plan: _ProviderModelCredentialGroupPlan, + *, + session: Session, + tx_id: str, + business_key: _BusinessKey, + ) -> None: + if plan.winner is None: + return + + loser_credential_ids = [str(row.row.id) for row in plan.loser_rows] + winner_credential_id = str(plan.winner.row.id) + cache_plans: list[_CacheDeletePlan] = [] + cache_plans.extend( + self._emit_provider_model_reference_rewrites( + session, + plan.provider_model_rewrites, + winner_credential_id=winner_credential_id, + loser_credential_ids=loser_credential_ids, + tx_id=tx_id, + business_key=business_key, + ) + ) + cache_plans.extend( + self._emit_load_balancing_reference_rewrites( + session, + plan.load_balancing_rewrites, + winner_credential_id=winner_credential_id, + loser_credential_ids=loser_credential_ids, + tx_id=tx_id, + business_key=business_key, + ) + ) + + for loser in plan.loser_rows: + if self._apply: + session.execute( + sa.delete(ProviderModelCredential).where(ProviderModelCredential.id == str(loser.row.id)) + ) + self._log_row_deleted( + ProviderModelCredential.__tablename__, + loser, + 
tx_id=tx_id, + business_key=business_key, + related_winner_id=winner_credential_id, + ) + + if plan.winner.raw_model_type != plan.winner.canonical_model_type.value: + if self._apply: + session.execute( + sa.update(ProviderModelCredential) + .where(ProviderModelCredential.id == winner_credential_id) + .values(model_type=plan.winner.canonical_model_type.value) + ) + self._log_row_updated( + ProviderModelCredential.__tablename__, + winner_credential_id, + {"model_type": plan.winner.raw_model_type}, + {"model_type": plan.winner.canonical_model_type.value}, + tx_id=tx_id, + business_key=business_key, + ) + + self._log_cache_plans(cache_plans, apply=self._apply) + self._log_group_processed( + ProviderModelCredential.__tablename__, + business_key, + plan.group_row_ids, + tx_id=tx_id, + ) + + def _process_provider_model_credential_group( + self, + candidate: _RowWithRawModelType[ProviderModelCredential], + business_key: _ProviderModelCredentialBusinessKey, + ) -> list[str]: + tx_id = self._new_tx_id() + group_row_ids = [str(candidate.row.id)] + + try: + with _session_factory(self._engine) as session, session.begin(): + self._configure_lock_timeout(session) + plan = self._build_provider_model_credential_group_plan(session, candidate, lock_rows=True) + group_row_ids = plan.group_row_ids or group_row_ids + self._emit_provider_model_credential_group_plan( + plan, + session=session, + tx_id=tx_id, + business_key=business_key, + ) + except OperationalError as exc: + if self._is_lock_timeout_error(exc): + self._log_lock_timeout( + ProviderModelCredential.__tablename__, + str(candidate.row.id), + tx_id, + business_key, + exc, + ) + return group_row_ids + raise + + return group_row_ids + + def _plan_provider_model_reference_rewrites( + self, + session: Session, + winner: _RowWithRawModelType[ProviderModelCredential], + loser_rows: Sequence[_RowWithRawModelType[ProviderModelCredential]], + *, + lock_rows: bool, + ) -> list[_ProviderModelReferenceRewritePlan]: + loser_ids = 
[str(row.row.id) for row in loser_rows] + if not loser_ids: + return [] + + stmt = ( + select(ProviderModel) + .where( + ProviderModel.tenant_id == self._tenant_id, + ProviderModel.credential_id.in_(loser_ids), + ) + .order_by(ProviderModel.id.asc()) + ) + if lock_rows: + stmt = stmt.with_for_update() + + rewrite_plans: list[_ProviderModelReferenceRewritePlan] = [] + provider_models = session.execute(stmt).scalars().all() + for provider_model in provider_models: + rewrite_plans.append( + _ProviderModelReferenceRewritePlan( + row_id=str(provider_model.id), + old_credential_id=str(provider_model.credential_id), + new_credential_id=str(winner.row.id), + ) + ) + return rewrite_plans + + def _plan_load_balancing_reference_rewrites( + self, + session: Session, + winner: _RowWithRawModelType[ProviderModelCredential], + loser_rows: Sequence[_RowWithRawModelType[ProviderModelCredential]], + *, + lock_rows: bool, + ) -> list[_LoadBalancingCredentialRewritePlan]: + loser_ids = [str(row.row.id) for row in loser_rows] + if not loser_ids: + return [] + + stmt = ( + select(LoadBalancingModelConfig) + .where( + LoadBalancingModelConfig.tenant_id == self._tenant_id, + LoadBalancingModelConfig.credential_id.in_(loser_ids), + ) + .order_by(LoadBalancingModelConfig.id.asc()) + ) + if lock_rows: + stmt = stmt.with_for_update() + + winner_credential = winner.row + winner_credential_id = str(winner_credential.id) + winner_credential_name = winner_credential.credential_name + winner_encrypted_config = winner_credential.encrypted_config + + rewrite_plans: list[_LoadBalancingCredentialRewritePlan] = [] + load_balancing_model_configs = session.execute(stmt).scalars().all() + for load_balancing_model_config in load_balancing_model_configs: + rewrite_plans.append( + _LoadBalancingCredentialRewritePlan( + row_id=str(load_balancing_model_config.id), + old_credential_id=load_balancing_model_config.credential_id, + old_name=load_balancing_model_config.name, + 
old_encrypted_config=load_balancing_model_config.encrypted_config, + new_credential_id=winner_credential_id, + new_name=winner_credential_name, + new_encrypted_config=winner_encrypted_config, + ) + ) + return rewrite_plans + + def _configure_lock_timeout(self, session: Session) -> None: + dialect_name = session.get_bind().dialect.name + if dialect_name == "postgresql": + session.execute(sa.text("SET LOCAL lock_timeout = :timeout"), {"timeout": f"{self._lock_timeout_seconds}s"}) + return + if dialect_name == "mysql": + session.execute( + sa.text("SET SESSION innodb_lock_wait_timeout = :timeout"), + {"timeout": self._lock_timeout_seconds}, + ) + session.execute( + sa.text("SET SESSION lock_wait_timeout = :timeout"), + {"timeout": self._lock_timeout_seconds}, + ) + + def _is_lock_timeout_error(self, exc: OperationalError) -> bool: + orig = exc.orig + structured_string_codes: set[str] = set() + structured_int_codes: set[int] = set() + + if orig is not None: + for raw_code in ( + getattr(orig, "sqlstate", None), + getattr(orig, "pgcode", None), + getattr(orig, "code", None), + getattr(orig, "errno", None), + ): + normalized_string_code = _normalize_error_code_string(raw_code) + if normalized_string_code is not None: + structured_string_codes.add(normalized_string_code) + + normalized_int_code = _normalize_error_code_int(raw_code) + if normalized_int_code is not None: + structured_int_codes.add(normalized_int_code) + + raw_args = getattr(orig, "args", None) + if isinstance(raw_args, tuple | list) and raw_args: + first_arg = raw_args[0] + normalized_string_code = _normalize_error_code_string(first_arg) + if normalized_string_code is not None: + structured_string_codes.add(normalized_string_code) + + normalized_int_code = _normalize_error_code_int(first_arg) + if normalized_int_code is not None: + structured_int_codes.add(normalized_int_code) + + if structured_string_codes & _POSTGRES_LOCK_TIMEOUT_SQLSTATES: + return True + if structured_int_codes & 
_MYSQL_LOCK_TIMEOUT_ERRNOS: + return True + + error_message = str(orig if orig is not None else exc).lower() + return any(message in error_message for message in _LOCK_TIMEOUT_FALLBACK_MESSAGES) + + def _log_lock_timeout( + self, + table_name: str, + row_id: str, + tx_id: str, + business_key: _BusinessKey | None, + exc: OperationalError, + ) -> None: + attrs: dict[str, object] = { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": table_name, + "id": row_id, + "tx_id": tx_id, + } + if business_key is not None: + attrs["business_key"] = self._business_key_to_dict(business_key) + self._log_exception_event( + "lock_timeout_skipped", + "Skipped transaction because row lock timed out.", + attrs, + exc, + ) + + def _business_key_to_dict(self, business_key: _BusinessKey) -> dict[str, object]: + return cast(dict[str, object], asdict(business_key)) + + def _row_to_dict(self, row: TypeBase, *, raw_model_type: str | None = None) -> dict[str, object]: + mapper = sa.inspect(row).mapper + row_dict = {column.key: row.__dict__[column.key] for column in mapper.column_attrs} + if raw_model_type is not None and "model_type" in row_dict: + row_dict["model_type"] = raw_model_type + return _normalize_log_mapping(row_dict) + + def _log_row_deleted[T: TypeBase]( + self, + table_name: str, + row: _RowWithRawModelType[T], + *, + tx_id: str, + business_key: _BusinessKey, + related_winner_id: str, + ) -> None: + self._log_event( + "row_deleted", + "Deleted loser row during canonicalization.", + { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": table_name, + "id": self._row_id(row.row), + "tx_id": tx_id, + "business_key": self._business_key_to_dict(business_key), + "merge_winner_id": related_winner_id, + "row": self._row_to_dict(row.row, raw_model_type=row.raw_model_type), + }, + ) + + def _log_row_updated( + self, + table_name: str, + row_id: str, + old_values: dict[str, object], + new_values: dict[str, object], + *, + tx_id: str, + business_key: 
_BusinessKey | None = None, + rewrite_source: dict[str, object] | None = None, + ) -> None: + attrs: dict[str, object] = { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": table_name, + "id": row_id, + "tx_id": tx_id, + "old_values": _normalize_log_mapping(old_values), + "new_values": _normalize_log_mapping(new_values), + } + if business_key is not None: + attrs["business_key"] = self._business_key_to_dict(business_key) + if rewrite_source is not None: + attrs["rewrite_source"] = rewrite_source + self._log_event("row_updated", "Updated row values during canonicalization.", attrs) + + def _log_group_processed( + self, + table_name: str, + business_key: _BusinessKey, + group_row_ids: Sequence[str], + *, + tx_id: str, + ) -> None: + self._log_event( + "group_processed", + "Processed business-key group during canonicalization.", + { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": table_name, + "business_key": self._business_key_to_dict(business_key), + "group_row_ids": list(group_row_ids), + "tx_id": tx_id, + }, + ) + + def _log_cache_plans(self, cache_plans: Iterable[_CacheDeletePlan], *, apply: bool) -> None: + for cache_plan in cache_plans: + if apply: + try: + ProviderCredentialsCache( + tenant_id=cache_plan.tenant_id, + identity_id=cache_plan.identity_id, + cache_type=cache_plan.cache_type, + ).delete() + self._log_event( + "cache_deleted", + "Deleted related cache entry.", + { + "tenant_id": cache_plan.tenant_id, + "apply": apply, + "table_name": cache_plan.table_name, + "id": cache_plan.row_id, + "cache_type": cache_plan.cache_type.value, + "tx_id": cache_plan.tx_id, + "business_key": self._business_key_to_dict(cache_plan.business_key), + }, + ) + except Exception as exc: + self._log_exception_event( + "cache_delete_failed", + "Failed to delete related cache entry.", + { + "tenant_id": cache_plan.tenant_id, + "apply": apply, + "table_name": cache_plan.table_name, + "id": cache_plan.row_id, + "cache_type": 
cache_plan.cache_type.value, + "tx_id": cache_plan.tx_id, + "business_key": self._business_key_to_dict(cache_plan.business_key), + }, + exc, + ) + else: + self._log_event( + "cache_delete_planned", + "Would delete related cache entry in apply mode.", + { + "tenant_id": cache_plan.tenant_id, + "apply": apply, + "table_name": cache_plan.table_name, + "id": cache_plan.row_id, + "cache_type": cache_plan.cache_type.value, + "tx_id": cache_plan.tx_id, + "business_key": self._business_key_to_dict(cache_plan.business_key), + }, + ) + + def _log_exception_event( + self, + event: str, + message: str, + attrs: dict[str, object], + exc: BaseException, + ) -> None: + self._log_event( + event, + message, + { + **attrs, + "error": str(exc), + "stacktrace": _format_exception_stacktrace(exc), + }, + ) + + def _log_event(self, event: str, message: str, attrs: dict[str, object]) -> None: + record = { + "event": event, + "message": message, + "attrs": _normalize_log_payload(attrs), + "ts": naive_utc_now().isoformat(), + } + print(json.dumps(record, default=_json_default), file=self._output, flush=True) + + +def load_tenant_ids_from_file(path: str) -> list[str]: + """ + Load tenant ids from a plain-text file, one tenant id per line. + """ + + tenant_ids: list[str] = [] + seen_tenant_ids: set[str] = set() + with open(path, encoding="utf-8") as file: + for raw_line in file: + tenant_id = raw_line.strip() + if not tenant_id or tenant_id in seen_tenant_ids: + continue + seen_tenant_ids.add(tenant_id) + tenant_ids.append(tenant_id) + return tenant_ids diff --git a/api/services/workflow/inspector_events.py b/api/services/workflow/inspector_events.py new file mode 100644 index 0000000000..3a7c09fb7e --- /dev/null +++ b/api/services/workflow/inspector_events.py @@ -0,0 +1,194 @@ +"""Inspector pub/sub fanout for live workflow run updates (Stage 4 §8.5). 
+ +The Node Output Inspector exposes a Server-Sent Events stream alongside its +three REST endpoints so the frontend can render per-output progress without +DB polling. This module owns the redis pub/sub channel that connects the two +sides: + +* :func:`publish_node_changed` / :func:`publish_workflow_completed` — + invoked by :class:`core.app.workflow.layers.persistence.WorkflowPersistenceLayer` + at the very end of each handler, after the DB write has already + succeeded. Publish failures are swallowed so the engine never trips on a + flaky redis connection. +* :func:`subscribe` — synchronous (blocking) generator the SSE endpoint consumes. + +Channel layout +-------------- +``dify:inspector:workflow_run:{workflow_run_id}`` + +One channel per workflow run; the SSE endpoint subscribes for the lifetime of +the run and unsubscribes on the terminal event. Multiple clients can attach +to the same run safely — redis pub/sub fans every message out to every +listener. + +The message envelope intentionally carries only the *delta* needed to invalidate +a slice of the inspector view; the SSE handler re-reads the canonical +``WorkflowNodeExecutionModel`` row from the DB so we never serialize stale +state across the wire. This means messages stay tiny (~150 bytes) and the +inspector view stays consistent even if a publisher races persistence. + +Decision D-5: the on-wire SSE envelope ``{event, data, id}`` is shared with +the babysit chat stream; this module only emits the *internal* pub/sub +message — the SSE controller turns it into the public envelope. 
+""" + +from __future__ import annotations + +import json +import logging +from collections.abc import Iterator +from dataclasses import asdict, dataclass +from typing import Final, Literal + +from extensions.ext_redis import redis_client + +logger = logging.getLogger(__name__) + + +# ────────────────────────────────────────────────────────────────────────────── +# Channel naming +# ────────────────────────────────────────────────────────────────────────────── + + +_CHANNEL_PREFIX: Final = "dify:inspector:workflow_run" + + +def channel_for(workflow_run_id: str) -> str: + """Return the pub/sub channel name for ``workflow_run_id``. + + Kept as a module-level helper so tests can pin the channel without + reaching into the publish/subscribe code paths. + """ + return f"{_CHANNEL_PREFIX}:{workflow_run_id}" + + +# ────────────────────────────────────────────────────────────────────────────── +# Message envelope +# ────────────────────────────────────────────────────────────────────────────── + +#: Tags discriminating the wire-level message kinds. Kept narrow so the SSE +#: controller can pattern-match exhaustively. +InspectorMessageKind = Literal["node_changed", "workflow_completed"] + + +@dataclass(frozen=True, slots=True) +class InspectorMessage: + """Minimal delta carried across the pub/sub channel. + + ``node_id`` is set only for ``node_changed`` messages; ``status`` is the + coarse string status straight from the persistence layer (``"running"`` / + ``"succeeded"`` / ``"failed"`` for nodes, plus ``"succeeded"`` / + ``"failed"`` / ``"partial_succeeded"`` / ``"stopped"`` for workflow runs). 
+ """ + + kind: InspectorMessageKind + workflow_run_id: str + node_id: str | None = None + status: str | None = None + + def to_json(self) -> str: + return json.dumps(asdict(self), ensure_ascii=False) + + @classmethod + def from_json(cls, blob: str) -> InspectorMessage | None: + """Decode a payload, returning ``None`` for any shape we can't trust.""" + try: + decoded = json.loads(blob) + except (json.JSONDecodeError, TypeError): + return None + if not isinstance(decoded, dict): + return None + kind = decoded.get("kind") + if kind not in ("node_changed", "workflow_completed"): + return None + workflow_run_id = decoded.get("workflow_run_id") + if not isinstance(workflow_run_id, str) or not workflow_run_id: + return None + node_id = decoded.get("node_id") + if node_id is not None and not isinstance(node_id, str): + return None + status = decoded.get("status") + if status is not None and not isinstance(status, str): + return None + return cls(kind=kind, workflow_run_id=workflow_run_id, node_id=node_id, status=status) + + +# ────────────────────────────────────────────────────────────────────────────── +# Publisher (called from the persistence layer) +# ────────────────────────────────────────────────────────────────────────────── + + +def _publish(message: InspectorMessage) -> None: + """Best-effort fire-and-forget publish. + + Persistence runs inside the workflow engine thread; we never want a redis + glitch to break the workflow. Any exception is logged at debug level so + operators still see them when they grep, but the engine keeps running. + """ + try: + redis_client.publish(channel_for(message.workflow_run_id), message.to_json()) + except Exception: + logger.debug("InspectorEventPublisher: publish failed for %s", message.workflow_run_id, exc_info=True) + + +def publish_node_changed(*, workflow_run_id: str, node_id: str, status: str) -> None: + """Announce that one node's execution row just changed. 
def subscribe(workflow_run_id: str, *, timeout_seconds: float = 1.0) -> Iterator[InspectorMessage]:
    """Yield ``InspectorMessage`` instances for one run until the consumer stops iterating.

    Each poll of the pub/sub connection waits at most ``timeout_seconds``.
    When a poll returns nothing, a synthetic heartbeat is yielded: a
    ``node_changed`` message whose ``node_id`` and ``status`` are both
    ``None``. This generator never yields ``None`` itself. The heartbeat
    hands control back to the caller so it can send a keepalive or notice
    that the client went away. Payloads that are not text, or that
    ``InspectorMessage.from_json`` rejects, are skipped silently.

    The ``finally`` block runs when the generator is closed or
    garbage-collected and releases the pub/sub connection. Teardown errors
    are logged at debug level and never propagate.
    """
    channel = channel_for(workflow_run_id)
    pubsub = redis_client.pubsub()
    try:
        # Subscribing inside the ``try`` means a failed subscribe still closes
        # the connection object created above.
        pubsub.subscribe(channel)
        while True:
            raw = pubsub.get_message(ignore_subscribe_messages=True, timeout=timeout_seconds)
            if raw is None:
                # Heartbeat tick: nothing arrived within the timeout window.
                yield InspectorMessage(kind="node_changed", workflow_run_id=workflow_run_id, node_id=None, status=None)
                continue
            data = raw.get("data") if isinstance(raw, dict) else None
            if isinstance(data, bytes):
                data = data.decode("utf-8", errors="replace")
            if not isinstance(data, str):
                continue
            message = InspectorMessage.from_json(data)
            if message is None:
                continue
            yield message
    finally:
        # Run the two teardown steps independently: a failing ``unsubscribe``
        # must not prevent ``close`` from releasing the connection.
        for teardown_step in (lambda: pubsub.unsubscribe(channel), pubsub.close):
            try:
                teardown_step()
            except Exception:
                logger.debug(
                    "InspectorEventPublisher: pubsub teardown failed for %s",
                    workflow_run_id,
                    exc_info=True,
                )
The Inspector + accepts ``WorkflowRunTriggeredFrom.DEBUGGING`` (draft test runs) as well as + ``APP_RUN`` / ``WEBHOOK`` / ``SCHEDULE`` / ``PLUGIN`` / ``RAG_PIPELINE_*`` + triggers; the underlying graph + executions are uniform across all of them. + Cross-tenant / cross-app rows still 404 via the standard tenant/app scope. +3. **Declared outputs by node kind**: + * Agent v2 nodes resolve their declared list via + :class:`WorkflowAgentBindingResolver` (the binding owns the canonical + ``DeclaredOutputConfig`` list and falls back to PRD defaults when empty). + * Other node kinds don't have a declared-output schema yet; we surface the + keys present in the execution payload as a best-effort list typed + ``unknown`` so the panel can still render them. +4. **Per-output status** is derived from the metadata the agent_v2 stack + already records (``output_type_check`` and ``output_check`` blobs); falling + back to the node's overall status when those signals aren't present. +5. **SSE stream** (design §8.5) lives in + :mod:`controllers.console.app.workflow_node_output_inspector` alongside the + REST endpoints. The Inspector and the babysit chat SSE share the + ``{event, data, id}`` envelope per decision D-5. 
+""" + +from __future__ import annotations + +import json +import logging +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from datetime import datetime +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field +from sqlalchemy import select + +from core.db.session_factory import session_factory +from core.workflow.nodes.agent_v2.binding_resolver import ( + WorkflowAgentBindingError, + WorkflowAgentBindingResolver, +) +from core.workflow.nodes.agent_v2.runtime_request_builder import ( + WorkflowAgentRuntimeRequestBuilder, +) +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowExecutionStatus, + WorkflowNodeExecutionStatus, +) +from graphon.file import helpers as file_helpers +from models import App +from models.agent_config_entities import DeclaredOutputConfig, DeclaredOutputType +from models.workflow import WorkflowNodeExecutionModel, WorkflowRun + +logger = logging.getLogger(__name__) + + +# ────────────────────────────────────────────────────────────────────────────── +# Public dataclasses / enums (Pydantic — these go straight on the wire) +# ────────────────────────────────────────────────────────────────────────────── + + +class NodeOutputStatus(StrEnum): + """Lifecycle status of a single declared output within a run.""" + + PENDING = "pending" # node not started + RUNNING = "running" # node running, output not ready yet + READY = "ready" + TYPE_CHECK_FAILED = "type_check_failed" + OUTPUT_CHECK_FAILED = "output_check_failed" + NOT_PRODUCED = "not_produced" # node succeeded but did not produce this declared output + FAILED = "failed" # node itself failed/exception/stopped + + +class NodeStatus(StrEnum): + """Coarse node-level status used by Inspector to pick a banner.""" + + IDLE = "idle" + RUNNING = "running" + READY = "ready" + FAILED = "failed" + + +class CheckResultView(BaseModel): + """``type_check`` / ``output_check`` per-output summary block.""" + + model_config = 
class NodeOutputInspectorError(Exception):
    """Signals that an Inspector request cannot be served (404-level conditions).

    ``code`` is a short identifier stored on the instance; ``message`` becomes
    the exception text.
    """

    def __init__(self, code: str, message: str) -> None:
        self.code = code
        super().__init__(message)
def _is_agent_v2_node(node: Mapping[str, Any]) -> bool:
    """Tell whether a graph node is an Agent v2 node.

    That holds when ``data.type`` equals ``BuiltinNodeTypes.AGENT`` and
    ``data.version``, stringified, equals ``"2"``. ``BuiltinNodeTypes.AGENT``
    is compared directly, without ``.value``.
    """
    data = node.get("data") or {}
    return (
        isinstance(data, Mapping)
        and data.get("type") == BuiltinNodeTypes.AGENT
        and str(data.get("version", "")) == "2"
    )


def _graph_nodes(workflow_run: WorkflowRun) -> list[Mapping[str, Any]]:
    """Extract the node list from the JSON text stored in ``WorkflowRun.graph``.

    A missing or unparseable blob, or one without a ``nodes`` list, yields an
    empty list so the panel can show an empty state instead of failing. Only
    mapping entries that carry an ``id`` key are kept.
    """
    raw_graph = workflow_run.graph
    if not raw_graph:
        return []
    try:
        graph = json.loads(raw_graph)
    except (json.JSONDecodeError, TypeError):
        logger.warning("NodeOutputInspector: workflow_run %s has unparseable graph blob", workflow_run.id)
        return []
    candidates = graph.get("nodes") if isinstance(graph, Mapping) else None
    if not isinstance(candidates, list):
        return []
    return [node for node in candidates if isinstance(node, Mapping) and "id" in node]


# ──────────────────────────────────────────────────────────────────────────────
# Internal helpers — per-output status derivation
# ──────────────────────────────────────────────────────────────────────────────


def _decode_json_blob(blob: str | None) -> Mapping[str, Any] | None:
    """Decode a JSON text column, returning ``None`` unless it is a JSON object."""
    if not blob:
        return None
    try:
        decoded = json.loads(blob)
    except (json.JSONDecodeError, TypeError):
        return None
    return decoded if isinstance(decoded, Mapping) else None
def _check_results_by_name(metadata: Mapping[str, Any] | None, block_key: str) -> dict[str, Mapping[str, Any]]:
    """Index the ``results`` rows of one metadata check block by output name.

    ``metadata[block_key]`` must be a mapping holding a ``results`` list;
    any other shape gives an empty index. Rows that are not mappings or
    lack a string ``name`` are ignored. On duplicate names the later row
    wins.
    """
    if not metadata:
        return {}
    block = metadata.get(block_key)
    if not isinstance(block, Mapping):
        return {}
    results = block.get("results") or []
    if not isinstance(results, list):
        return {}
    return {row["name"]: row for row in results if isinstance(row, Mapping) and isinstance(row.get("name"), str)}


def _type_check_by_name(metadata: Mapping[str, Any] | None) -> dict[str, Mapping[str, Any]]:
    """Per-output rows recorded under ``output_type_check``, keyed by output name."""
    return _check_results_by_name(metadata, "output_type_check")


def _output_check_by_name(metadata: Mapping[str, Any] | None) -> dict[str, Mapping[str, Any]]:
    """Per-output rows recorded under ``output_check``, keyed by output name."""
    return _check_results_by_name(metadata, "output_check")


def _retried_attempt_count(metadata: Mapping[str, Any] | None) -> int:
    """Return the retry count derived from the ``attempt`` metadata entry.

    The agent_v2 runtime records the final attempt index in metadata.
    ``attempt`` is 0-indexed, so an execution with no retry has
    ``attempt == 0`` and the Inspector reports ``retried == 0``. A missing,
    non-integer or non-positive value also yields ``0``.
    """
    if not metadata:
        return 0
    attempt = metadata.get("attempt")
    return attempt if isinstance(attempt, int) and attempt > 0 else 0
_PREVIEW_TEXT_LIMIT = 500
_FILE_ID_KEYS: tuple[str, ...] = ("file_id", "upload_file_id", "tool_file_id")


def _looks_like_file_ref(value: Any) -> str | None:
    """Return the file id when ``value`` is a file-shaped mapping, else ``None``.

    The keys in ``_FILE_ID_KEYS`` are probed in order; the first one holding
    a non-empty string wins.
    """
    if not isinstance(value, Mapping):
        return None
    for key in _FILE_ID_KEYS:
        candidate = value.get(key)
        if isinstance(candidate, str) and candidate:
            return candidate
    return None


def _with_signed_preview_url(value: Mapping[str, Any], file_id: str) -> dict[str, Any]:
    """Copy ``value`` and add a ``preview_url`` entry for ``file_id``.

    Signing is best-effort: any failure is logged and ``preview_url`` is set
    to ``None`` so the rest of the payload still reaches the panel.
    """
    # NOTE(review): the id is always passed as ``upload_file_id`` even when it
    # was read from ``tool_file_id`` — confirm the signer accepts both.
    try:
        preview_url = file_helpers.get_signed_file_url(upload_file_id=file_id)
    except Exception:
        logger.warning("NodeOutputInspector: signed URL failed for file_id=%s", file_id, exc_info=True)
        preview_url = None
    return {**dict(value), "preview_url": preview_url}


def _value_preview(value: Any) -> Any:
    """Compact preview of one output value for the snapshot endpoint.

    File references gain a ``preview_url`` entry; strings longer than
    ``_PREVIEW_TEXT_LIMIT`` characters are cut to that length with a
    trailing ellipsis; every other value is returned unchanged.
    """
    file_id = _looks_like_file_ref(value)
    if file_id:
        return _with_signed_preview_url(value, file_id)
    if isinstance(value, str) and len(value) > _PREVIEW_TEXT_LIMIT:
        return value[:_PREVIEW_TEXT_LIMIT] + "…"
    return value


def _full_value(value: Any) -> Any:
    """Same shape as :func:`_value_preview` but without string truncation."""
    file_id = _looks_like_file_ref(value)
    if file_id:
        return _with_signed_preview_url(value, file_id)
    return value
def snapshot_workflow_run(self, *, app_model: App, workflow_run_id: str) -> WorkflowRunSnapshotView:
    """Build the per-node snapshot for one workflow run.

    Loads the run and its executions under the app's tenant/app scope, then
    emits one ``NodeOutputsView`` per node in the run's graph, in graph
    order. Each node is paired with its latest execution row, if any.
    """
    workflow_run, executions = self._load_run_and_executions(app_model=app_model, workflow_run_id=workflow_run_id)
    latest_by_node = self._index_executions_by_node(executions)

    node_views: list[NodeOutputsView] = [
        self._build_node_view(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            workflow_id=workflow_run.workflow_id,
            raw_node=graph_node,
            execution=latest_by_node.get(str(graph_node["id"])),
        )
        for graph_node in _graph_nodes(workflow_run)
    ]

    return WorkflowRunSnapshotView(
        workflow_run_id=workflow_run.id,
        workflow_run_status=workflow_run.status,
        node_outputs=node_views,
    )
def output_preview(
    self,
    *,
    app_model: App,
    workflow_run_id: str,
    node_id: str,
    output_name: str,
) -> OutputPreviewView:
    """Full payload for one declared output (with signed file URL).

    The run and its executions are loaded once, so the reported status and
    the returned value both come from the same set of rows.

    Raises:
        NodeOutputInspectorError: ``workflow_run_not_found`` when the run is
            outside the app's scope, ``node_not_in_workflow_run`` when the
            node is absent from the run's graph, and
            ``node_output_not_declared`` when the node has no output named
            ``output_name``.
    """
    workflow_run, executions = self._load_run_and_executions(app_model=app_model, workflow_run_id=workflow_run_id)
    raw_node = next((n for n in _graph_nodes(workflow_run) if str(n.get("id")) == node_id), None)
    if raw_node is None:
        raise NodeOutputInspectorError(
            "node_not_in_workflow_run",
            f"Node {node_id!r} does not appear in workflow run {workflow_run_id!r}.",
        )

    execution = self._index_executions_by_node(executions).get(node_id)
    detail = self._build_node_view(
        tenant_id=app_model.tenant_id,
        app_id=app_model.id,
        workflow_id=workflow_run.workflow_id,
        raw_node=raw_node,
        execution=execution,
    )
    view = next((o for o in detail.outputs if o.name == output_name), None)
    if view is None:
        raise NodeOutputInspectorError(
            "node_output_not_declared",
            f"Output {output_name!r} is not declared on node {node_id!r}.",
        )

    # The node view only carries a truncated ``value_preview``; read the raw
    # value from the execution payload so the full thing is returned (still
    # passed through ``_full_value`` for signed file URLs).
    full_value: Any = None
    if execution is not None:
        outputs = _decode_json_blob(execution.outputs) or {}
        if output_name in outputs:
            full_value = _full_value(outputs[output_name])

    return OutputPreviewView(
        node_id=node_id,
        output_name=output_name,
        type=view.type,
        status=view.status,
        value=full_value,
    )
+ + Enforces: + * row exists, + * row belongs to the app's tenant + app. + + The trigger source (DEBUGGING vs. APP_RUN / WEBHOOK / SCHEDULE / ...) is + deliberately not checked here — D-1 was lifted 2026-05-26 and the + Inspector now serves both draft and published runs. + """ + with session_factory.create_session() as session: + workflow_run = session.scalar( + select(WorkflowRun).where( + WorkflowRun.id == workflow_run_id, + WorkflowRun.app_id == app_model.id, + WorkflowRun.tenant_id == app_model.tenant_id, + ) + ) + if workflow_run is None: + raise NodeOutputInspectorError("workflow_run_not_found", "Workflow run not found.") + + executions = session.scalars( + select(WorkflowNodeExecutionModel).where( + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + WorkflowNodeExecutionModel.tenant_id == app_model.tenant_id, + WorkflowNodeExecutionModel.app_id == app_model.id, + ) + ).all() + + return workflow_run, executions + + @staticmethod + def _index_executions_by_node( + executions: Sequence[WorkflowNodeExecutionModel], + ) -> dict[str, WorkflowNodeExecutionModel]: + """Keep the latest execution per ``node_id``. + + A given node may have multiple rows when retries or iterations occur; + ``index`` is the per-run sequence counter, so we keep the one with + the highest index as the canonical "current" view. 
+ """ + latest: dict[str, WorkflowNodeExecutionModel] = {} + for execution in executions: + existing = latest.get(execution.node_id) + if existing is None or execution.index > existing.index: + latest[execution.node_id] = execution + return latest + + # ── Per-node view construction ──────────────────────────────────────── + + def _build_node_view( + self, + *, + tenant_id: str, + app_id: str, + workflow_id: str, + raw_node: Mapping[str, Any], + execution: WorkflowNodeExecutionModel | None, + ) -> NodeOutputsView: + node_id = str(raw_node["id"]) + data = raw_node.get("data") or {} + if not isinstance(data, Mapping): + data = {} + + node_kind = str(data.get("type") or (execution.node_type if execution else "") or "unknown") + display_name = str(data.get("title") or (execution.title if execution else node_id)) + node_status = _node_status_for(execution) + + declarations = self._resolve_declared_outputs( + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + node_id=node_id, + raw_node=raw_node, + execution=execution, + ) + + outputs_dict = _decode_json_blob(execution.outputs) if execution else None + metadata_dict = _decode_json_blob(execution.execution_metadata) if execution else None + type_check_by_name = _type_check_by_name(metadata_dict) + output_check_by_name = _output_check_by_name(metadata_dict) + retried = _retried_attempt_count(metadata_dict) + + output_views: list[NodeOutputView] = [] + for declaration in declarations: + output_views.append( + self._build_output_view( + declaration=declaration, + node_status=node_status, + outputs_dict=outputs_dict, + type_check_by_name=type_check_by_name, + output_check_by_name=output_check_by_name, + retried=retried, + ) + ) + + return NodeOutputsView( + node_id=node_id, + node_kind=node_kind, + node_display_name=display_name, + node_status=node_status, + node_started_at=execution.created_at if execution else None, + node_completed_at=execution.finished_at if execution else None, + outputs=output_views, + ) 
def _build_output_view(
    self,
    *,
    declaration: _ResolvedDeclaration,
    node_status: NodeStatus,
    outputs_dict: Mapping[str, Any] | None,
    type_check_by_name: Mapping[str, Mapping[str, Any]],
    output_check_by_name: Mapping[str, Mapping[str, Any]],
    retried: int,
) -> NodeOutputView:
    """Derive the view for one declared output of one node.

    For a node that is idle, running or failed, the output status mirrors
    the node status and no value or check data is attached. For a succeeded
    node the precedence is: failed type check, then failed output check,
    then ``NOT_PRODUCED`` when the payload lacks the output, else ``READY``.
    A succeeded node with no decodable outputs payload at all
    (``outputs_dict is None``) also counts as ``NOT_PRODUCED``.
    """
    name = declaration.name
    declared_type = declaration.declared_type

    # Non-success node states map one-to-one onto an output status.
    mirrored_status = {
        NodeStatus.IDLE: NodeOutputStatus.PENDING,
        NodeStatus.RUNNING: NodeOutputStatus.RUNNING,
        NodeStatus.FAILED: NodeOutputStatus.FAILED,
    }.get(node_status)
    if mirrored_status is not None:
        return NodeOutputView(
            name=name,
            type=declared_type,
            status=mirrored_status,
            retried=retried,
        )

    # ── node succeeded ────────────────────────────────────────────
    type_check_result = type_check_by_name.get(name)
    output_check_result = output_check_by_name.get(name)
    produced = outputs_dict is not None and name in outputs_dict

    # type check loses first; output check next; otherwise ready.
    status: NodeOutputStatus
    if type_check_result and not _is_passing(type_check_result):
        status = NodeOutputStatus.TYPE_CHECK_FAILED
    elif output_check_result and not _is_passing(output_check_result):
        status = NodeOutputStatus.OUTPUT_CHECK_FAILED
    elif not produced:
        # Covers both "key missing" and "no payload at all".
        status = NodeOutputStatus.NOT_PRODUCED
    else:
        status = NodeOutputStatus.READY

    return NodeOutputView(
        name=name,
        type=declared_type,
        status=status,
        value_preview=_value_preview(outputs_dict[name]) if produced else None,
        type_check=self._coerce_check_view(type_check_result),
        output_check=self._coerce_check_view(output_check_result),
        retried=retried,
    )
def _is_passing(result: Mapping[str, Any]) -> bool:
    """Tell whether a check-result row counts as passing.

    A row passes when its ``status`` is the string ``"ready"``, ``"passed"``
    or ``"not_produced"``. Any other value, including a missing or
    non-string ``status``, is treated as failing.
    """
    status = result.get("status")
    # Rows come from decoded JSON, so ``status`` may be a list or dict; the
    # isinstance guard avoids a TypeError from hashing it for the set lookup.
    return isinstance(status, str) and status in {"ready", "passed", "not_produced"}
user_params = _EndUser(end_user_id=user.id) - else: - raise AssertionError("this statement should be unreachable.") + match user: + case Account(): + user_params = _Account(user_id=user.id) + case EndUser(): + user_params = _EndUser(end_user_id=user.id) + case _: + raise AssertionError("this statement should be unreachable.") return cls( app_id=app_model.id, workflow_id=workflow.id, @@ -365,36 +366,37 @@ def _resume_app_execution(payload: dict[str, Any]) -> None: state_owner_user_id=workflow.created_by, ) - if isinstance(generate_entity, AdvancedChatAppGenerateEntity): - assert conversation is not None - assert message is not None - _resume_advanced_chat( - app_model=app_model, - workflow=workflow, - user=user, - conversation=conversation, - message=message, - generate_entity=generate_entity, - graph_runtime_state=graph_runtime_state, - session_factory=session_factory, - pause_state_config=pause_config, - workflow_run_id=workflow_run_id, - workflow_run=workflow_run, - ) - elif isinstance(generate_entity, WorkflowAppGenerateEntity): - _resume_workflow( - app_model=app_model, - workflow=workflow, - user=user, - generate_entity=generate_entity, - graph_runtime_state=graph_runtime_state, - session_factory=session_factory, - pause_state_config=pause_config, - workflow_run_id=workflow_run_id, - workflow_run=workflow_run, - workflow_run_repo=workflow_run_repo, - pause_entity=pause_entity, - ) + match generate_entity: + case AdvancedChatAppGenerateEntity(): + assert conversation is not None + assert message is not None + _resume_advanced_chat( + app_model=app_model, + workflow=workflow, + user=user, + conversation=conversation, + message=message, + generate_entity=generate_entity, + graph_runtime_state=graph_runtime_state, + session_factory=session_factory, + pause_state_config=pause_config, + workflow_run_id=workflow_run_id, + workflow_run=workflow_run, + ) + case WorkflowAppGenerateEntity(): + _resume_workflow( + app_model=app_model, + workflow=workflow, + user=user, + 
generate_entity=generate_entity, + graph_runtime_state=graph_runtime_state, + session_factory=session_factory, + pause_state_config=pause_config, + workflow_run_id=workflow_run_id, + workflow_run=workflow_run, + workflow_run_repo=workflow_run_repo, + pause_entity=pause_entity, + ) def _resume_advanced_chat( diff --git a/api/tests/helpers/__init__.py b/api/tests/helpers/__init__.py new file mode 100644 index 0000000000..5183591f40 --- /dev/null +++ b/api/tests/helpers/__init__.py @@ -0,0 +1 @@ +"""Shared test helpers for backend migration tests.""" diff --git a/api/tests/helpers/legacy_model_type_migration.py b/api/tests/helpers/legacy_model_type_migration.py new file mode 100644 index 0000000000..12f092a0fe --- /dev/null +++ b/api/tests/helpers/legacy_model_type_migration.py @@ -0,0 +1,366 @@ +from __future__ import annotations + +import json +import uuid +from dataclasses import dataclass +from datetime import datetime, timedelta +from uuid import uuid4 + +import sqlalchemy as sa +from sqlalchemy.engine import Engine + +from models.account import Tenant +from models.enums import CredentialSourceType +from models.provider import ( + LoadBalancingModelConfig, + ProviderModel, + ProviderModelCredential, + ProviderModelSetting, + TenantDefaultModel, +) + +LEGACY_TO_CANONICAL: dict[str, str] = { + "text-generation": "llm", + "embeddings": "text-embedding", + "reranking": "rerank", +} +UNCHANGED_MODEL_TYPES: tuple[str, ...] = ("speech2text", "moderation", "tts") +ALL_TABLE_NAMES: tuple[str, ...] 
= ( + ProviderModel.__tablename__, + TenantDefaultModel.__tablename__, + ProviderModelSetting.__tablename__, + LoadBalancingModelConfig.__tablename__, + ProviderModelCredential.__tablename__, +) +DEFAULT_PRIMARY_TENANT_ID = "00000000-0000-0000-0000-000000000101" +DEFAULT_SECONDARY_TENANT_ID = "00000000-0000-0000-0000-000000000202" + + +@dataclass(frozen=True, slots=True) +class DirtyTenantFixture: + tenant_id: str + winner_credential_id: str + loser_credential_id: str + distinct_credential_id: str + provider_model_id: str + load_balancing_config_id: str + provider_model_setting_id: str + tenant_default_model_id: str + embedding_provider_model_id: str + embedding_setting_id: str + loser_credential_name: str + distinct_credential_name: str + loser_encrypted_config: str + winner_encrypted_config: str + + +@dataclass(frozen=True, slots=True) +class DirtyDataFixture: + primary: DirtyTenantFixture + secondary: DirtyTenantFixture + + +def create_minimal_legacy_model_type_schema(engine: Engine) -> None: + metadata = Tenant.__table__.metadata + metadata.create_all( + engine, + tables=[ + Tenant.__table__, + ProviderModel.__table__, + TenantDefaultModel.__table__, + ProviderModelSetting.__table__, + LoadBalancingModelConfig.__table__, + ProviderModelCredential.__table__, + ], + checkfirst=True, + ) + + +def drop_minimal_legacy_model_type_schema(engine: Engine) -> None: + metadata = Tenant.__table__.metadata + metadata.drop_all( + engine, + tables=[ + LoadBalancingModelConfig.__table__, + ProviderModelSetting.__table__, + TenantDefaultModel.__table__, + ProviderModel.__table__, + ProviderModelCredential.__table__, + Tenant.__table__, + ], + checkfirst=True, + ) + + +def seed_legacy_model_type_dirty_data( + engine: Engine, + *, + primary_tenant_id: str = DEFAULT_PRIMARY_TENANT_ID, + secondary_tenant_id: str = DEFAULT_SECONDARY_TENANT_ID, +) -> DirtyDataFixture: + create_minimal_legacy_model_type_schema(engine) + primary = _seed_tenant(engine, tenant_id=primary_tenant_id, 
provider_name="openai") + secondary = _seed_tenant(engine, tenant_id=secondary_tenant_id, provider_name="openai") + return DirtyDataFixture(primary=primary, secondary=secondary) + + +def snapshot_legacy_model_type_state(engine: Engine) -> dict[str, list[dict[str, object]]]: + snapshots: dict[str, list[dict[str, object]]] = {} + for table_name in ALL_TABLE_NAMES: + snapshots[table_name] = fetch_table_rows(engine, table_name) + return snapshots + + +def fetch_table_rows( + engine: Engine, + table_name: str, + *, + tenant_id: str | None = None, +) -> list[dict[str, object]]: + sql = f"SELECT * FROM {table_name}" + params: dict[str, object] = {} + if tenant_id is not None: + sql += " WHERE tenant_id = :tenant_id" + params["tenant_id"] = tenant_id + sql += " ORDER BY id ASC" + + with engine.begin() as conn: + rows = conn.execute(sa.text(sql), params).mappings().all() + + result: list[dict[str, object]] = [] + for row in rows: + normalized = dict(row) + for key, value in normalized.items(): + if isinstance(value, datetime): + normalized[key] = value.isoformat() + elif isinstance(value, uuid.UUID): + normalized[key] = str(value) + result.append(normalized) + return result + + +def fetch_model_types_for_tenant(engine: Engine, table_name: str, tenant_id: str) -> list[str]: + rows = fetch_table_rows(engine, table_name, tenant_id=tenant_id) + return [str(row["model_type"]) for row in rows] + + +def assert_tenant_rows_use_only_canonical_model_types(engine: Engine, tenant_id: str) -> None: + for table_name in ALL_TABLE_NAMES: + model_types = fetch_model_types_for_tenant(engine, table_name, tenant_id) + assert set(model_types) <= set(LEGACY_TO_CANONICAL.values()) | set(UNCHANGED_MODEL_TYPES), ( + table_name, + model_types, + ) + + +def count_rows(engine: Engine, table_name: str, *, tenant_id: str) -> int: + with engine.begin() as conn: + stmt = sa.text(f"SELECT COUNT(*) FROM {table_name} WHERE tenant_id = :tenant_id") + return int(conn.execute(stmt, {"tenant_id": 
tenant_id}).scalar_one()) + + +def _seed_tenant(engine: Engine, *, tenant_id: str, provider_name: str) -> DirtyTenantFixture: + now = datetime(2025, 1, 1, 12, 0, 0) + winner_credential_id = str(uuid4()) + loser_credential_id = str(uuid4()) + distinct_credential_id = str(uuid4()) + provider_model_id = str(uuid4()) + load_balancing_config_id = str(uuid4()) + provider_model_setting_id = str(uuid4()) + tenant_default_model_id = str(uuid4()) + embedding_provider_model_id = str(uuid4()) + embedding_setting_id = str(uuid4()) + + loser_credential_name = f"{tenant_id}-shared" + distinct_credential_name = f"{tenant_id}-distinct" + winner_encrypted_config = json.dumps({"api_key": f"{tenant_id}-winner"}) + loser_encrypted_config = json.dumps({"api_key": f"{tenant_id}-loser"}) + distinct_encrypted_config = json.dumps({"api_key": f"{tenant_id}-distinct"}) + + with engine.begin() as conn: + conn.execute( + Tenant.__table__.insert().values( + id=tenant_id, + name=f"Tenant {tenant_id}", + plan="basic", + status="normal", + ) + ) + conn.execute( + sa.text( + """ + INSERT INTO provider_model_credentials + ( + id, tenant_id, provider_name, model_name, + model_type, credential_name, encrypted_config, + created_at, updated_at + ) + VALUES + ( + :winner_id, :tenant_id, :provider_name, 'gpt-4o-mini', + 'llm', :shared_name, :winner_config, + :created_at, :winner_updated_at + ), + ( + :loser_id, :tenant_id, :provider_name, 'gpt-4o-mini', + 'text-generation', :shared_name, :loser_config, + :created_at, :loser_updated_at + ), + ( + :distinct_id, :tenant_id, :provider_name, 'gpt-4o-mini', + 'text-generation', :distinct_name, :distinct_config, + :created_at, :distinct_updated_at + ) + """ + ), + { + "winner_id": winner_credential_id, + "loser_id": loser_credential_id, + "distinct_id": distinct_credential_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "shared_name": loser_credential_name, + "distinct_name": distinct_credential_name, + "winner_config": winner_encrypted_config, 
+ "loser_config": loser_encrypted_config, + "distinct_config": distinct_encrypted_config, + "created_at": now - timedelta(days=2), + "winner_updated_at": now, + "loser_updated_at": now - timedelta(days=1), + "distinct_updated_at": now - timedelta(hours=12), + }, + ) + conn.execute( + sa.text( + """ + INSERT INTO provider_models + ( + id, tenant_id, provider_name, model_name, + model_type, credential_id, is_valid, + created_at, updated_at + ) + VALUES + ( + :provider_model_id, :tenant_id, :provider_name, 'gpt-4o-mini', + 'text-generation', :loser_id, :is_valid, + :created_at, :updated_at + ), + ( + :embedding_provider_model_id, :tenant_id, :provider_name, 'text-embedding-3-large', + 'embeddings', NULL, :is_valid, + :created_at, :updated_at + ) + """ + ), + { + "provider_model_id": provider_model_id, + "embedding_provider_model_id": embedding_provider_model_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "loser_id": loser_credential_id, + "is_valid": True, + "created_at": now - timedelta(days=2), + "updated_at": now - timedelta(hours=6), + }, + ) + conn.execute( + sa.text( + """ + INSERT INTO tenant_default_models + (id, tenant_id, provider_name, model_name, model_type, created_at, updated_at) + VALUES + ( + :tenant_default_model_id, :tenant_id, :provider_name, 'gpt-4o-mini', + 'text-generation', :created_at, :updated_at + ) + """ + ), + { + "tenant_default_model_id": tenant_default_model_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "created_at": now - timedelta(days=2), + "updated_at": now - timedelta(hours=4), + }, + ) + conn.execute( + sa.text( + """ + INSERT INTO provider_model_settings + ( + id, tenant_id, provider_name, model_name, + model_type, enabled, load_balancing_enabled, + created_at, updated_at + ) + VALUES + ( + :provider_model_setting_id, :tenant_id, :provider_name, 'gpt-4o-mini', + 'text-generation', :enabled, :load_balancing_enabled, + :created_at, :updated_at + ), + ( + :embedding_setting_id, :tenant_id, 
:provider_name, 'text-embedding-3-large', + 'embeddings', :enabled, :embedding_load_balancing_enabled, + :created_at, :updated_at + ) + """ + ), + { + "provider_model_setting_id": provider_model_setting_id, + "embedding_setting_id": embedding_setting_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "enabled": True, + "load_balancing_enabled": True, + "embedding_load_balancing_enabled": False, + "created_at": now - timedelta(days=2), + "updated_at": now - timedelta(hours=3), + }, + ) + conn.execute( + sa.text( + """ + INSERT INTO load_balancing_model_configs + ( + id, tenant_id, provider_name, model_name, model_type, + name, encrypted_config, credential_id, credential_source_type, + enabled, created_at, updated_at + ) + VALUES + ( + :load_balancing_config_id, :tenant_id, :provider_name, 'gpt-4o-mini', 'text-generation', + :lb_name, :loser_config, :loser_id, :credential_source_type, + :enabled, :created_at, :updated_at + ) + """ + ), + { + "load_balancing_config_id": load_balancing_config_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "lb_name": loser_credential_name, + "loser_config": loser_encrypted_config, + "loser_id": loser_credential_id, + "credential_source_type": CredentialSourceType.CUSTOM_MODEL.value, + "enabled": True, + "created_at": now - timedelta(days=2), + "updated_at": now - timedelta(hours=2), + }, + ) + + return DirtyTenantFixture( + tenant_id=tenant_id, + winner_credential_id=winner_credential_id, + loser_credential_id=loser_credential_id, + distinct_credential_id=distinct_credential_id, + provider_model_id=provider_model_id, + load_balancing_config_id=load_balancing_config_id, + provider_model_setting_id=provider_model_setting_id, + tenant_default_model_id=tenant_default_model_id, + embedding_provider_model_id=embedding_provider_model_id, + embedding_setting_id=embedding_setting_id, + loser_credential_name=loser_credential_name, + distinct_credential_name=distinct_credential_name, + 
loser_encrypted_config=loser_encrypted_config, + winner_encrypted_config=winner_encrypted_config, + ) diff --git a/api/tests/integration_tests/services/test_node_output_inspector_service.py b/api/tests/integration_tests/services/test_node_output_inspector_service.py new file mode 100644 index 0000000000..5a8c07e043 --- /dev/null +++ b/api/tests/integration_tests/services/test_node_output_inspector_service.py @@ -0,0 +1,475 @@ +"""End-to-end tests for ``NodeOutputInspectorService`` (Stage 4 §8 / ENG-373). + +These integration tests exercise the service against a real Postgres +(``dify-db-1``) — same pattern as :mod:`test_remove_app_and_related_data_task`: +seed rows via ``session_factory.create_session()`` with explicit commits, +exercise the service, clean up by ID at teardown. + +Coverage: +1. Snapshot for a draft run with one agent v2 node + one tool node +2. Type-check failure visible in snapshot +3. Output-check failure visible in snapshot +4. Published run returns ``published_run_inspector_not_implemented`` +5. Cross-tenant access returns ``workflow_run_not_found`` +6. File output preview endpoint returns full value with signed URL +7. 
``node_detail`` path serves a single node view +""" + +from __future__ import annotations + +import json +import uuid +from collections.abc import Generator +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any +from unittest.mock import patch + +import pytest +from sqlalchemy import delete + +from core.db.session_factory import session_factory +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.workflow import ( + WorkflowNodeExecutionModel, + WorkflowNodeExecutionTriggeredFrom, + WorkflowRun, + WorkflowType, +) +from services.workflow.node_output_inspector_service import ( + NodeOutputInspectorError, + NodeOutputInspectorService, + NodeOutputStatus, + NodeStatus, +) + + +@pytest.fixture +def fake_app_model() -> SimpleNamespace: + """Lightweight stand-in for the ``App`` model that the service consumes. + + ``App`` is only read for ``id`` and ``tenant_id``; the service does not + poke at any ORM relationship so a SimpleNamespace is enough — and it + keeps us free of needing the ``apps`` row to actually exist (which would + drag in Account / Tenant setup). 
+ """ + return SimpleNamespace( + id=str(uuid.uuid4()), + tenant_id=str(uuid.uuid4()), + ) + + +def _make_workflow_run( + *, + app_id: str, + tenant_id: str, + triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING, + status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING, + graph: dict[str, Any] | None = None, +) -> WorkflowRun: + """Build a ``WorkflowRun`` row with all required fields populated.""" + return WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=tenant_id, + app_id=app_id, + workflow_id=str(uuid.uuid4()), + type=WorkflowType.WORKFLOW, + triggered_from=triggered_from, + version="draft", + graph=json.dumps(graph or {"nodes": []}), + status=status, + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + ) + + +def _make_execution( + *, + app_id: str, + tenant_id: str, + workflow_id: str, + workflow_run_id: str, + node_id: str, + node_type: str = "agent", + title: str = "", + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.SUCCEEDED, + outputs: dict[str, Any] | None = None, + execution_metadata: dict[str, Any] | None = None, + index: int = 1, +) -> WorkflowNodeExecutionModel: + """Build a ``WorkflowNodeExecutionModel`` row with all required fields.""" + return WorkflowNodeExecutionModel( + id=str(uuid.uuid4()), + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + workflow_run_id=workflow_run_id, + index=index, + node_id=node_id, + node_type=node_type, + title=title or node_id, + status=status, + outputs=json.dumps(outputs) if outputs is not None else None, + execution_metadata=json.dumps(execution_metadata) if execution_metadata is not None else None, + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + created_at=datetime.now(UTC), + finished_at=datetime.now(UTC), + ) + + +@pytest.fixture +def seeded_run( + flask_req_ctx, fake_app_model: SimpleNamespace +) -> 
Generator[tuple[SimpleNamespace, WorkflowRun, list[WorkflowNodeExecutionModel]], None, None]: + """Seed one debug ``WorkflowRun`` + 2 node executions in real Postgres. + + Yields ``(app_model, workflow_run, executions)``. Cleans both rows up at + teardown via direct ``DELETE`` so a failed test never leaves debris. + """ + graph = { + "nodes": [ + { + "id": "agent-node-1", + "data": {"type": "agent", "version": "2", "title": "My Agent"}, + }, + { + "id": "tool-node-1", + "data": {"type": "tool", "title": "Slack"}, + }, + ] + } + workflow_run = _make_workflow_run( + app_id=fake_app_model.id, + tenant_id=fake_app_model.tenant_id, + graph=graph, + ) + agent_execution = _make_execution( + app_id=fake_app_model.id, + tenant_id=fake_app_model.tenant_id, + workflow_id=workflow_run.workflow_id, + workflow_run_id=workflow_run.id, + node_id="agent-node-1", + node_type="agent", + outputs={"text": "hello world"}, + execution_metadata={ + "output_type_check": { + "passed": True, + "results": [{"name": "text", "type": "string", "status": "ready"}], + }, + "attempt": 0, + }, + index=1, + ) + tool_execution = _make_execution( + app_id=fake_app_model.id, + tenant_id=fake_app_model.tenant_id, + workflow_id=workflow_run.workflow_id, + workflow_run_id=workflow_run.id, + node_id="tool-node-1", + node_type="tool", + outputs={"message": "sent", "ok": True}, + index=2, + ) + + with session_factory.create_session() as session: + session.add(workflow_run) + session.add(agent_execution) + session.add(tool_execution) + session.commit() + run_id = workflow_run.id + execution_ids = [agent_execution.id, tool_execution.id] + + try: + yield fake_app_model, workflow_run, [agent_execution, tool_execution] + finally: + with session_factory.create_session() as session: + session.execute(delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))) + session.execute(delete(WorkflowRun).where(WorkflowRun.id == run_id)) + session.commit() + + +# 
────────────────────────────────────────────────────────────────────────────── +# Stub binding resolver — declared outputs for the agent v2 node +# ────────────────────────────────────────────────────────────────────────────── + + +def _stub_resolver(declared_outputs_payload: list[dict[str, Any]]): + """Return a stand-in binding resolver whose ``.resolve`` always returns + one bundle with the supplied declared_outputs. + + The real resolver hits ``workflow_agent_node_bindings``; we skip that + table here so the Inspector can be tested without binding-row setup. + """ + binding = SimpleNamespace( + id="binding-1", + node_job_config_dict={ + "workflow_prompt": "stub", + "declared_outputs": declared_outputs_payload, + }, + ) + bundle = SimpleNamespace(binding=binding, agent=None, snapshot=None) + + class _Resolver: + def resolve(self, **_: Any): + return bundle + + return _Resolver() + + +# ────────────────────────────────────────────────────────────────────────────── +# Tests +# ────────────────────────────────────────────────────────────────────────────── + + +def test_snapshot_returns_agent_v2_declared_outputs_with_status_ready(seeded_run): + """Happy path: agent v2 node + tool node both render, statuses come from + real ``WorkflowRun`` + ``WorkflowNodeExecutionModel`` rows.""" + app_model, workflow_run, _ = seeded_run + service = NodeOutputInspectorService(binding_resolver=_stub_resolver([{"name": "text", "type": "string"}])) + snapshot = service.snapshot_workflow_run( + app_model=app_model, + workflow_run_id=workflow_run.id, + ) + + assert snapshot.workflow_run_id == workflow_run.id + assert snapshot.workflow_run_status == WorkflowExecutionStatus.RUNNING + + by_node = {n.node_id: n for n in snapshot.node_outputs} + + agent_view = by_node["agent-node-1"] + assert agent_view.node_status == NodeStatus.READY + assert agent_view.outputs[0].name == "text" + assert agent_view.outputs[0].status == NodeOutputStatus.READY + assert agent_view.outputs[0].value_preview == 
"hello world" + + tool_view = by_node["tool-node-1"] + # Tool node's declared outputs are *inferred* from the produced payload. + output_names = sorted(o.name for o in tool_view.outputs) + assert output_names == ["message", "ok"] + assert all(o.type is None for o in tool_view.outputs) + + +def test_snapshot_404s_for_missing_run(fake_app_model): + """Service raises ``workflow_run_not_found`` when the row doesn't exist.""" + service = NodeOutputInspectorService(binding_resolver=_stub_resolver([])) + with pytest.raises(NodeOutputInspectorError) as exc: + service.snapshot_workflow_run(app_model=fake_app_model, workflow_run_id=str(uuid.uuid4())) + assert exc.value.code == "workflow_run_not_found" + + +def test_snapshot_404s_for_cross_tenant_access(seeded_run): + """A wrong-tenant app_model must not see another tenant's run.""" + _, workflow_run, _ = seeded_run + intruder = SimpleNamespace(id=str(uuid.uuid4()), tenant_id=str(uuid.uuid4())) + service = NodeOutputInspectorService(binding_resolver=_stub_resolver([])) + with pytest.raises(NodeOutputInspectorError) as exc: + service.snapshot_workflow_run(app_model=intruder, workflow_run_id=workflow_run.id) + assert exc.value.code == "workflow_run_not_found" + + +def test_snapshot_404s_for_published_run_per_decision_d1(flask_req_ctx, fake_app_model): + """Decision D-1: published / app-run Inspector deferred to stage 4.1.""" + workflow_run = _make_workflow_run( + app_id=fake_app_model.id, + tenant_id=fake_app_model.tenant_id, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + graph={"nodes": []}, + ) + with session_factory.create_session() as session: + session.add(workflow_run) + session.commit() + run_id = workflow_run.id + + try: + service = NodeOutputInspectorService(binding_resolver=_stub_resolver([])) + with pytest.raises(NodeOutputInspectorError) as exc: + service.snapshot_workflow_run(app_model=fake_app_model, workflow_run_id=run_id) + assert exc.value.code == "published_run_inspector_not_implemented" + finally: + 
with session_factory.create_session() as session: + session.execute(delete(WorkflowRun).where(WorkflowRun.id == run_id)) + session.commit() + + +def test_snapshot_surfaces_type_check_failure_from_metadata(flask_req_ctx, fake_app_model): + """Per-output ``TYPE_CHECK_FAILED`` derived from the metadata blob the + Stage 4 §5 stack records on the execution row.""" + graph = {"nodes": [{"id": "agent-1", "data": {"type": "agent", "version": "2"}}]} + workflow_run = _make_workflow_run(app_id=fake_app_model.id, tenant_id=fake_app_model.tenant_id, graph=graph) + execution = _make_execution( + app_id=fake_app_model.id, + tenant_id=fake_app_model.tenant_id, + workflow_id=workflow_run.workflow_id, + workflow_run_id=workflow_run.id, + node_id="agent-1", + outputs={"summary": 123}, # int despite declared string + execution_metadata={ + "output_type_check": { + "passed": False, + "results": [ + { + "name": "summary", + "type": "string", + "status": "type_check_failed", + "reason": "expected string, got int", + } + ], + } + }, + ) + with session_factory.create_session() as session: + session.add(workflow_run) + session.add(execution) + session.commit() + run_id, execution_id = workflow_run.id, execution.id + + try: + service = NodeOutputInspectorService(binding_resolver=_stub_resolver([{"name": "summary", "type": "string"}])) + snapshot = service.snapshot_workflow_run(app_model=fake_app_model, workflow_run_id=run_id) + output = snapshot.node_outputs[0].outputs[0] + assert output.status == NodeOutputStatus.TYPE_CHECK_FAILED + assert output.type_check is not None + assert output.type_check.passed is False + assert output.type_check.reason == "expected string, got int" + finally: + with session_factory.create_session() as session: + session.execute(delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution_id)) + session.execute(delete(WorkflowRun).where(WorkflowRun.id == run_id)) + session.commit() + + +def 
test_snapshot_surfaces_output_check_failure_from_metadata(flask_req_ctx, fake_app_model): + """When ``output_type_check.passed`` but ``output_check.passed=False``, the + output is flagged ``OUTPUT_CHECK_FAILED``.""" + graph = {"nodes": [{"id": "agent-1", "data": {"type": "agent", "version": "2"}}]} + workflow_run = _make_workflow_run(app_id=fake_app_model.id, tenant_id=fake_app_model.tenant_id, graph=graph) + execution = _make_execution( + app_id=fake_app_model.id, + tenant_id=fake_app_model.tenant_id, + workflow_id=workflow_run.workflow_id, + workflow_run_id=workflow_run.id, + node_id="agent-1", + outputs={"report": {"file_id": "550e8400-e29b-41d4-a716-446655440000"}}, + execution_metadata={ + "output_type_check": {"passed": True, "results": [{"name": "report", "status": "ready"}]}, + "output_check": { + "passed": False, + "results": [{"name": "report", "status": "failed", "reason": "benchmark mismatch"}], + }, + }, + ) + with session_factory.create_session() as session: + session.add(workflow_run) + session.add(execution) + session.commit() + run_id, execution_id = workflow_run.id, execution.id + + try: + service = NodeOutputInspectorService(binding_resolver=_stub_resolver([{"name": "report", "type": "file"}])) + # Stub signed-URL so we don't depend on the workflow file runtime being + # bound (it isn't, in this minimal flask_req_ctx). 
+ with patch( + "services.workflow.node_output_inspector_service.file_helpers.get_signed_file_url", + return_value="https://signed.example/report", + ): + snapshot = service.snapshot_workflow_run(app_model=fake_app_model, workflow_run_id=run_id) + output = snapshot.node_outputs[0].outputs[0] + assert output.status == NodeOutputStatus.OUTPUT_CHECK_FAILED + assert output.output_check is not None + assert output.output_check.passed is False + assert output.output_check.reason == "benchmark mismatch" + finally: + with session_factory.create_session() as session: + session.execute(delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution_id)) + session.execute(delete(WorkflowRun).where(WorkflowRun.id == run_id)) + session.commit() + + +def test_node_detail_serves_one_node(seeded_run): + app_model, workflow_run, _ = seeded_run + service = NodeOutputInspectorService(binding_resolver=_stub_resolver([{"name": "text", "type": "string"}])) + view = service.node_detail( + app_model=app_model, + workflow_run_id=workflow_run.id, + node_id="agent-node-1", + ) + assert view.node_id == "agent-node-1" + assert view.outputs[0].name == "text" + + +def test_output_preview_for_file_renders_signed_url(seeded_run, fake_app_model): + """``preview`` returns the full value with signed_url for file refs.""" + # Replace the seeded agent execution's output with a file ref. + _, workflow_run, executions = seeded_run + agent_execution = executions[0] + with session_factory.create_session() as session: + # Re-bind the persisted row so we can mutate + commit. 
+ from sqlalchemy import select + + row = session.scalar( + select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == agent_execution.id) + ) + assert row is not None + row.outputs = json.dumps({"text": {"file_id": "550e8400-e29b-41d4-a716-446655440000", "filename": "x.pdf"}}) + session.commit() + + service = NodeOutputInspectorService(binding_resolver=_stub_resolver([{"name": "text", "type": "file"}])) + with patch( + "services.workflow.node_output_inspector_service.file_helpers.get_signed_file_url", + return_value="https://signed.example/x.pdf", + ): + preview = service.output_preview( + app_model=fake_app_model, + workflow_run_id=workflow_run.id, + node_id="agent-node-1", + output_name="text", + ) + assert preview.output_name == "text" + assert isinstance(preview.value, dict) + assert preview.value["preview_url"] == "https://signed.example/x.pdf" + assert preview.value["filename"] == "x.pdf" + + +def test_keeps_latest_execution_per_node_by_index(flask_req_ctx, fake_app_model): + """Multiple executions for the same node_id → service keeps the highest + ``index`` (matches the agent_v2 retry pattern that re-emits node + executions).""" + graph = {"nodes": [{"id": "agent-1", "data": {"type": "agent", "version": "2"}}]} + workflow_run = _make_workflow_run(app_id=fake_app_model.id, tenant_id=fake_app_model.tenant_id, graph=graph) + older = _make_execution( + app_id=fake_app_model.id, + tenant_id=fake_app_model.tenant_id, + workflow_id=workflow_run.workflow_id, + workflow_run_id=workflow_run.id, + node_id="agent-1", + outputs={"text": "first attempt"}, + index=1, + ) + newer = _make_execution( + app_id=fake_app_model.id, + tenant_id=fake_app_model.tenant_id, + workflow_id=workflow_run.workflow_id, + workflow_run_id=workflow_run.id, + node_id="agent-1", + outputs={"text": "second attempt"}, + index=5, + ) + with session_factory.create_session() as session: + session.add(workflow_run) + session.add(older) + session.add(newer) + session.commit() + run_id, 
ex_ids = workflow_run.id, [older.id, newer.id] + + try: + service = NodeOutputInspectorService(binding_resolver=_stub_resolver([{"name": "text", "type": "string"}])) + snapshot = service.snapshot_workflow_run(app_model=fake_app_model, workflow_run_id=run_id) + assert snapshot.node_outputs[0].outputs[0].value_preview == "second attempt" + finally: + with session_factory.create_session() as session: + session.execute(delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(ex_ids))) + session.execute(delete(WorkflowRun).where(WorkflowRun.id == run_id)) + session.commit() diff --git a/api/tests/seed_legacy_model_type_dirty_data.py b/api/tests/seed_legacy_model_type_dirty_data.py new file mode 100644 index 0000000000..c860cea956 --- /dev/null +++ b/api/tests/seed_legacy_model_type_dirty_data.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + +API_PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(API_PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(API_PROJECT_ROOT)) + +import sqlalchemy as sa + +from tests.helpers.legacy_model_type_migration import ( + DEFAULT_PRIMARY_TENANT_ID, + DEFAULT_SECONDARY_TENANT_ID, + create_minimal_legacy_model_type_schema, + seed_legacy_model_type_dirty_data, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Seed dirty legacy model_type rows for manual migration experiments. 
" + "Example: uv run --project api python api/tests/seed_legacy_model_type_dirty_data.py " + "--db-url postgresql://postgres:postgres@127.0.0.1:5432/dify" + ) + ) + parser.add_argument("--db-url", required=True, help="SQLAlchemy database URL for the target database.") + parser.add_argument( + "--primary-tenant-id", + default=DEFAULT_PRIMARY_TENANT_ID, + help="Tenant that will contain the main conflict scenario.", + ) + parser.add_argument( + "--secondary-tenant-id", + default=DEFAULT_SECONDARY_TENANT_ID, + help="Tenant used to verify tenant filtering behavior.", + ) + parser.add_argument( + "--create-minimal-schema", + action="store_true", + help="Create the minimal tables needed for the seed when running against an empty scratch database.", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + engine = sa.create_engine(args.db_url) + try: + if args.create_minimal_schema: + create_minimal_legacy_model_type_schema(engine) + + fixture = seed_legacy_model_type_dirty_data( + engine, + primary_tenant_id=args.primary_tenant_id, + secondary_tenant_id=args.secondary_tenant_id, + ) + finally: + engine.dispose() + + print( + json.dumps( + { + "primary_tenant_id": fixture.primary.tenant_id, + "secondary_tenant_id": fixture.secondary.tenant_id, + "winner_credential_id": fixture.primary.winner_credential_id, + "loser_credential_id": fixture.primary.loser_credential_id, + "provider_model_id": fixture.primary.provider_model_id, + "load_balancing_config_id": fixture.primary.load_balancing_config_id, + }, + indent=2, + sort_keys=True, + ) + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/api/tests/test_containers_integration_tests/commands/test_legacy_model_type_migration.py b/api/tests/test_containers_integration_tests/commands/test_legacy_model_type_migration.py new file mode 100644 index 0000000000..401696d5ca --- /dev/null +++ 
b/api/tests/test_containers_integration_tests/commands/test_legacy_model_type_migration.py @@ -0,0 +1,408 @@ +from __future__ import annotations + +import importlib +import io +import json +from collections.abc import Generator +from datetime import datetime, timedelta + +import pytest +import sqlalchemy as sa + +from tests.helpers.legacy_model_type_migration import ( + assert_tenant_rows_use_only_canonical_model_types, + count_rows, + fetch_table_rows, + seed_legacy_model_type_dirty_data, +) + + +def _parse_json_lines(output: io.StringIO) -> list[dict[str, object]]: + return [json.loads(line) for line in output.getvalue().splitlines() if line.strip()] + + +def _json_key(value: object) -> str: + return json.dumps(value, sort_keys=True) + + +def _lb_processing_signatures(lines: list[dict[str, object]]) -> set[tuple[object, ...]]: + signatures: set[tuple[object, ...]] = set() + for line in lines: + attrs = line.get("attrs") + if not isinstance(attrs, dict): + continue + if attrs.get("table_name") != "load_balancing_model_configs": + continue + event = line.get("event") + if event == "row_updated": + signatures.add( + ( + event, + attrs.get("id"), + _json_key(attrs.get("old_values")), + _json_key(attrs.get("new_values")), + ) + ) + elif event == "row_deleted": + signatures.add( + ( + event, + attrs.get("id"), + attrs.get("merge_winner_id"), + ) + ) + elif event == "group_processed": + signatures.add( + ( + event, + attrs.get("table_name"), + _json_key(attrs.get("business_key")), + tuple(attrs.get("group_row_ids", [])), + ) + ) + return signatures + + +def _insert_load_balancing_model_config( + engine: sa.Engine, + *, + row_id: str, + tenant_id: str, + provider_name: str, + model_name: str, + model_type: str, + name: str, + encrypted_config: str, + credential_id: str, + enabled: bool, + created_at: datetime, + updated_at: datetime, +) -> None: + with engine.begin() as conn: + conn.execute( + sa.text( + """ + INSERT INTO load_balancing_model_configs + ( + id, tenant_id, 
provider_name, model_name, model_type, name, + encrypted_config, credential_id, credential_source_type, enabled, created_at, updated_at + ) + VALUES + ( + :id, :tenant_id, :provider_name, :model_name, :model_type, :name, + :encrypted_config, :credential_id, :credential_source_type, :enabled, :created_at, :updated_at + ) + """ + ), + { + "id": row_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "model_name": model_name, + "model_type": model_type, + "name": name, + "encrypted_config": encrypted_config, + "credential_id": credential_id, + "credential_source_type": "custom_model", + "enabled": enabled, + "created_at": created_at, + "updated_at": updated_at, + }, + ) + + +@pytest.fixture(scope="session") +def migration_module(): + try: + return importlib.import_module("services.legacy_model_type_migration") + except ModuleNotFoundError as exc: # pragma: no cover - explicit TDD failure path + pytest.fail( + "services.legacy_model_type_migration is missing. " + "Implement LegacyModelTypeMigrationService before running these tests." 
+ ) + + +@pytest.fixture(params=("postgresql", "mysql"), scope="session") +def container_engine(request: pytest.FixtureRequest) -> Generator[tuple[str, sa.Engine], None, None]: + backend_name = request.param + if backend_name == "postgresql": + testcontainers_postgres = pytest.importorskip("testcontainers.postgres") + container = testcontainers_postgres.PostgresContainer("postgres:15-alpine") + else: + testcontainers_mysql = pytest.importorskip("testcontainers.mysql") + container = testcontainers_mysql.MySqlContainer("mysql:8.0") + + container.start() + raw_url = container.get_connection_url() + engine_url = raw_url.replace("mysql://", "mysql+pymysql://", 1) + engine = sa.create_engine(engine_url) + + try: + yield backend_name, engine + finally: + engine.dispose() + container.stop() + + +def test_legacy_model_type_migration_end_to_end_across_supported_backends( + migration_module, + container_engine: tuple[str, sa.Engine], + monkeypatch: pytest.MonkeyPatch, +) -> None: + backend_name, engine = container_engine + helper_module = importlib.import_module("tests.helpers.legacy_model_type_migration") + helper_module.drop_minimal_legacy_model_type_schema(engine) + fixture = seed_legacy_model_type_dirty_data(engine) + + deleted_cache_keys: list[str] = [] + + def _record_delete(self) -> None: + deleted_cache_keys.append(self.cache_key) + + monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete) + + dry_run_output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=engine, + apply=False, + output=dry_run_output, + tenant_ids=(fixture.primary.tenant_id,), + ).migrate() + + assert count_rows(engine, "provider_model_credentials", tenant_id=fixture.primary.tenant_id) == 3 + assert deleted_cache_keys == [] + + apply_output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=engine, + apply=True, + output=apply_output, + tenant_ids=(fixture.primary.tenant_id,), + ).migrate() + first_apply_state = { 
+ table_name: fetch_table_rows(engine, table_name, tenant_id=fixture.primary.tenant_id) + for table_name in ( + "provider_models", + "tenant_default_models", + "provider_model_settings", + "load_balancing_model_configs", + "provider_model_credentials", + ) + } + + assert_tenant_rows_use_only_canonical_model_types(engine, fixture.primary.tenant_id) + assert count_rows(engine, "provider_model_credentials", tenant_id=fixture.primary.tenant_id) == 2 + provider_model_row = next( + row for row in first_apply_state["provider_models"] if row["id"] == fixture.primary.provider_model_id + ) + assert provider_model_row["credential_id"] == fixture.primary.winner_credential_id + credential_ids = {str(row["id"]) for row in first_apply_state["provider_model_credentials"]} + assert credential_ids == { + fixture.primary.winner_credential_id, + fixture.primary.distinct_credential_id, + } + lb_row = next( + row + for row in first_apply_state["load_balancing_model_configs"] + if row["id"] == fixture.primary.load_balancing_config_id + ) + assert lb_row["credential_id"] == fixture.primary.winner_credential_id + assert lb_row["encrypted_config"] == fixture.primary.winner_encrypted_config + assert deleted_cache_keys, f"{backend_name} apply run should clear cache keys" + + migration_module.LegacyModelTypeMigrationService( + engine=engine, + apply=True, + output=io.StringIO(), + tenant_ids=(fixture.primary.tenant_id,), + ).migrate() + second_apply_state = { + table_name: fetch_table_rows(engine, table_name, tenant_id=fixture.primary.tenant_id) + for table_name in first_apply_state + } + assert second_apply_state == first_apply_state + + +def test_load_balancing_inherit_deduplication_is_applied_consistently_across_supported_backends( + migration_module, + container_engine: tuple[str, sa.Engine], + monkeypatch: pytest.MonkeyPatch, +) -> None: + _, engine = container_engine + helper_module = importlib.import_module("tests.helpers.legacy_model_type_migration") + 
helper_module.drop_minimal_legacy_model_type_schema(engine) + fixture = seed_legacy_model_type_dirty_data(engine) + + tenant_id = fixture.primary.tenant_id + older_inherit_row_id = "00000000-0000-0000-0000-00000000ee01" + newer_inherit_row_id = "00000000-0000-0000-0000-00000000ee02" + canonical_non_inherit_row_id = "00000000-0000-0000-0000-00000000ee03" + created_at = datetime(2025, 1, 1, 8, 0, 0) + + _insert_load_balancing_model_config( + engine, + row_id=older_inherit_row_id, + tenant_id=tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="llm", + name="__inherit__", + encrypted_config='{"api_key":"older-inherit"}', + credential_id=fixture.primary.winner_credential_id, + enabled=True, + created_at=created_at, + updated_at=created_at + timedelta(minutes=15), + ) + _insert_load_balancing_model_config( + engine, + row_id=newer_inherit_row_id, + tenant_id=tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="text-generation", + name="__inherit__", + encrypted_config='{"api_key":"newer-inherit"}', + credential_id=fixture.primary.distinct_credential_id, + enabled=True, + created_at=created_at, + updated_at=created_at + timedelta(minutes=30), + ) + _insert_load_balancing_model_config( + engine, + row_id=canonical_non_inherit_row_id, + tenant_id=tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="llm", + name=f"{tenant_id}-second-shared", + encrypted_config='{"api_key":"non-inherit-canonical"}', + credential_id=fixture.primary.distinct_credential_id, + enabled=True, + created_at=created_at, + updated_at=created_at + timedelta(minutes=45), + ) + + before_dry_run = fetch_table_rows(engine, "load_balancing_model_configs", tenant_id=tenant_id) + deleted_cache_keys: list[str] = [] + + def _record_delete(self) -> None: + deleted_cache_keys.append(self.cache_key) + + monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete) + + dry_run_output = io.StringIO() + 
migration_module.LegacyModelTypeMigrationService( + engine=engine, + apply=False, + output=dry_run_output, + tables=("load_balancing_model_configs",), + model_types=(migration_module.ModelType.LLM,), + tenant_ids=(tenant_id,), + ).migrate() + + after_dry_run = fetch_table_rows(engine, "load_balancing_model_configs", tenant_id=tenant_id) + dry_run_lines = _parse_json_lines(dry_run_output) + dry_run_cache_events = [line["event"] for line in dry_run_lines if str(line.get("event")).startswith("cache_")] + dry_run_row_updates = { + str(attrs["id"]) + for line in dry_run_lines + if line.get("event") == "row_updated" + and isinstance((attrs := line.get("attrs")), dict) + and attrs.get("table_name") == "load_balancing_model_configs" + } + dry_run_row_deletes = { + str(attrs["id"]) + for line in dry_run_lines + if line.get("event") == "row_deleted" + and isinstance((attrs := line.get("attrs")), dict) + and attrs.get("table_name") == "load_balancing_model_configs" + } + dry_run_group_processed = [ + attrs + for line in dry_run_lines + if line.get("event") == "group_processed" + and isinstance((attrs := line.get("attrs")), dict) + and attrs.get("table_name") == "load_balancing_model_configs" + ] + + assert after_dry_run == before_dry_run + assert deleted_cache_keys == [] + assert dry_run_row_deletes == {older_inherit_row_id} + assert dry_run_row_updates == { + fixture.primary.load_balancing_config_id, + newer_inherit_row_id, + } + assert canonical_non_inherit_row_id not in dry_run_row_updates + assert "cache_delete_planned" in dry_run_cache_events + assert "cache_deleted" not in dry_run_cache_events + assert len(dry_run_group_processed) == 1 + assert dry_run_group_processed[0]["table_name"] == "load_balancing_model_configs" + assert dry_run_group_processed[0]["business_key"] == { + "tenant_id": tenant_id, + "provider_name": "openai", + "model_name": "gpt-4o-mini", + "model_type": "llm", + } + assert set(dry_run_group_processed[0]["group_row_ids"]) == { + older_inherit_row_id, 
+ newer_inherit_row_id, + } + + apply_output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=engine, + apply=True, + output=apply_output, + tables=("load_balancing_model_configs",), + model_types=(migration_module.ModelType.LLM,), + tenant_ids=(tenant_id,), + ).migrate() + + apply_lines = _parse_json_lines(apply_output) + apply_cache_events = [line["event"] for line in apply_lines if str(line.get("event")).startswith("cache_")] + apply_group_processed = [ + attrs + for line in apply_lines + if line.get("event") == "group_processed" + and isinstance((attrs := line.get("attrs")), dict) + and attrs.get("table_name") == "load_balancing_model_configs" + ] + assert _lb_processing_signatures(apply_lines) == _lb_processing_signatures(dry_run_lines) + assert "cache_deleted" in apply_cache_events + assert deleted_cache_keys + assert len(apply_group_processed) == len(dry_run_group_processed) + assert [ + ( + attrs["table_name"], + _json_key(attrs["business_key"]), + tuple(attrs["group_row_ids"]), + ) + for attrs in apply_group_processed + ] == [ + ( + attrs["table_name"], + _json_key(attrs["business_key"]), + tuple(attrs["group_row_ids"]), + ) + for attrs in dry_run_group_processed + ] + + lb_rows = fetch_table_rows(engine, "load_balancing_model_configs", tenant_id=tenant_id) + surviving_inherit_rows = [row for row in lb_rows if row["name"] == "__inherit__"] + surviving_non_inherit_rows = [row for row in lb_rows if row["name"] != "__inherit__"] + + assert {str(row["id"]) for row in surviving_inherit_rows} == {newer_inherit_row_id} + assert surviving_inherit_rows[0]["model_type"] == "llm" + assert surviving_inherit_rows[0]["credential_id"] == fixture.primary.distinct_credential_id + + assert { + str(row["id"]) + for row in surviving_non_inherit_rows + if str(row["id"]) in {fixture.primary.load_balancing_config_id, canonical_non_inherit_row_id} + } == {fixture.primary.load_balancing_config_id, canonical_non_inherit_row_id} + assert all( + 
row["model_type"] == "llm" + for row in surviving_non_inherit_rows + if str(row["id"]) in {fixture.primary.load_balancing_config_id, canonical_non_inherit_row_id} + ) + assert count_rows(engine, "load_balancing_model_configs", tenant_id=tenant_id) == len(before_dry_run) - 1 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py index 00309c25d6..5eb9f71e69 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -19,6 +19,7 @@ def test_get_api_key_auth_data_source( test_client_with_containers: FlaskClient, ) -> None: account, tenant = create_console_account_and_tenant(db_session_with_containers) + foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers) binding = DataSourceApiKeyAuthBinding( tenant_id=tenant.id, category="api_key", @@ -26,8 +27,16 @@ def test_get_api_key_auth_data_source( credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), disabled=False, ) - db_session_with_containers.add(binding) + foreign_binding = DataSourceApiKeyAuthBinding( + tenant_id=foreign_tenant.id, + category="api_key", + provider="foreign_provider", + credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), + disabled=False, + ) + db_session_with_containers.add_all([binding, foreign_binding]) db_session_with_containers.commit() + authenticate_console_client(test_client_with_containers, foreign_account) response = test_client_with_containers.get( "/console/api/api-key-auth/data-source", @@ -60,20 +69,23 @@ def test_create_binding_successful( db_session_with_containers: Session, test_client_with_containers: FlaskClient, ) -> None: - account, _tenant = 
create_console_account_and_tenant(db_session_with_containers) + account, tenant = create_console_account_and_tenant(db_session_with_containers) + tenant_id = tenant.id + payload = {"category": "api_key", "provider": "custom", "credentials": {"key": "value"}} with ( patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), - patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth"), + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") as create_auth, ): response = test_client_with_containers.post( "/console/api/api-key-auth/data-source/binding", - json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + json=payload, headers=authenticate_console_client(test_client_with_containers, account), ) assert response.status_code == 200 assert response.get_json() == {"result": "success"} + create_auth.assert_called_once_with(tenant_id, payload) def test_create_binding_failure( @@ -129,3 +141,35 @@ def test_delete_binding_successful( ) is None ) + + +def test_delete_binding_scopes_to_authenticated_tenant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers) + foreign_binding = DataSourceApiKeyAuthBinding( + tenant_id=foreign_tenant.id, + category="api_key", + provider="custom_provider", + credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), + disabled=False, + ) + db_session_with_containers.add(foreign_binding) + db_session_with_containers.commit() + foreign_binding_id = foreign_binding.id + authenticate_console_client(test_client_with_containers, foreign_account) + + response = test_client_with_containers.delete( + 
f"/console/api/api-key-auth/data-source/{foreign_binding_id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar( + select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.id == foreign_binding_id) + ) + is not None + ) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_external.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_external.py new file mode 100644 index 0000000000..d6b7e9e636 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_external.py @@ -0,0 +1,116 @@ +"""Integration tests for console external knowledge API endpoints.""" + +from __future__ import annotations + +import json + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from models.dataset import ExternalKnowledgeApis +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) + + +def _create_external_api( + db_session: Session, + *, + tenant_id: str, + account_id: str, + name: str, +) -> ExternalKnowledgeApis: + external_api = ExternalKnowledgeApis( + tenant_id=tenant_id, + created_by=account_id, + updated_by=account_id, + name=name, + description=f"{name} description", + settings=json.dumps( + { + "endpoint": "https://example.com", + "api_key": "test-api-key", + } + ), + ) + db_session.add(external_api) + db_session.commit() + return external_api + + +def test_external_api_template_list_filters_paginates_and_scopes_to_authenticated_tenant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + """Exercise the real list route, including query parsing, DB lookup, and tenant isolation.""" + account, tenant = create_console_account_and_tenant(db_session_with_containers) + foreign_account, 
foreign_tenant = create_console_account_and_tenant(db_session_with_containers) + account_id = account.id + tenant_id = tenant.id + foreign_account_id = foreign_account.id + foreign_tenant_id = foreign_tenant.id + headers = authenticate_console_client(test_client_with_containers, account) + + _create_external_api( + db_session_with_containers, + tenant_id=tenant_id, + account_id=account_id, + name="Alpha Primary", + ) + _create_external_api( + db_session_with_containers, + tenant_id=tenant_id, + account_id=account_id, + name="Alpha Secondary", + ) + _create_external_api( + db_session_with_containers, + tenant_id=tenant_id, + account_id=account_id, + name="Beta Unmatched", + ) + _create_external_api( + db_session_with_containers, + tenant_id=foreign_tenant_id, + account_id=foreign_account_id, + name="Alpha Foreign", + ) + + response = test_client_with_containers.get( + "/console/api/datasets/external-knowledge-api?page=1&limit=1&keyword=Alpha", + headers=headers, + ) + + assert response.status_code == 200 + assert response.json is not None + assert response.json["page"] == 1 + assert response.json["limit"] == 1 + assert response.json["total"] == 2 + assert response.json["has_more"] is True + assert len(response.json["data"]) == 1 + + first_page_item = response.json["data"][0] + assert first_page_item["tenant_id"] == tenant_id + assert first_page_item["name"] in {"Alpha Primary", "Alpha Secondary"} + assert first_page_item["settings"] == { + "endpoint": "https://example.com", + "api_key": "test-api-key", + } + assert first_page_item["dataset_bindings"] == [] + + second_response = test_client_with_containers.get( + "/console/api/datasets/external-knowledge-api?page=2&limit=1&keyword=Alpha", + headers=headers, + ) + + assert second_response.status_code == 200 + assert second_response.json is not None + assert second_response.json["page"] == 2 + assert second_response.json["limit"] == 1 + assert second_response.json["total"] == 2 + assert 
len(second_response.json["data"]) == 1 + + second_page_item = second_response.json["data"][0] + assert second_page_item["name"] in {"Alpha Primary", "Alpha Secondary"} + assert second_response.json["data"][0]["tenant_id"] == tenant_id diff --git a/api/tests/test_containers_integration_tests/controllers/console/test_api_based_extension.py b/api/tests/test_containers_integration_tests/controllers/console/test_api_based_extension.py index e7852b8fe1..058f4e5fa3 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/test_api_based_extension.py +++ b/api/tests/test_containers_integration_tests/controllers/console/test_api_based_extension.py @@ -6,11 +6,13 @@ from unittest.mock import patch import pytest from flask.testing import FlaskClient +from sqlalchemy import select from sqlalchemy.orm import Session from constants import HIDDEN_VALUE from libs.rsa import generate_key_pair -from models import Tenant +from models.api_based_extension import APIBasedExtension +from services.api_based_extension_service import APIBasedExtensionService from tests.test_containers_integration_tests.controllers.console.helpers import ( authenticate_console_client, create_console_account_and_tenant, @@ -27,13 +29,14 @@ def _masked_api_key(api_key: str) -> str: def api_extension_client( db_session_with_containers: Session, test_client_with_containers: FlaskClient, -) -> tuple[FlaskClient, dict[str, str], Tenant]: +) -> tuple[FlaskClient, dict[str, str], str]: account, tenant = create_console_account_and_tenant(db_session_with_containers) + tenant_id = tenant.id tenant.encrypt_public_key = generate_key_pair(tenant.id) db_session_with_containers.commit() headers = authenticate_console_client(test_client_with_containers, account) - return test_client_with_containers, headers, tenant + return test_client_with_containers, headers, tenant_id @pytest.fixture(autouse=True) @@ -44,9 +47,10 @@ def mock_api_based_extension_ping(): def test_create_response_masks_plaintext_api_key( - 
api_extension_client: tuple[FlaskClient, dict[str, str], Tenant], + api_extension_client: tuple[FlaskClient, dict[str, str], str], + db_session_with_containers: Session, ) -> None: - client, headers, _ = api_extension_client + client, headers, tenant_id = api_extension_client api_key = "plain-secret-12345" response = client.post( @@ -62,10 +66,57 @@ def test_create_response_masks_plaintext_api_key( assert response.status_code == 201 assert response.json is not None assert response.json["api_key"] == _masked_api_key(api_key) + extension = db_session_with_containers.scalar( + select(APIBasedExtension).where(APIBasedExtension.id == response.json["id"]).limit(1) + ) + assert extension is not None + assert extension.tenant_id == tenant_id + + +def test_list_scopes_api_based_extensions_to_authenticated_tenant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + account_headers = authenticate_console_client(test_client_with_containers, account) + tenant.encrypt_public_key = generate_key_pair(tenant.id) + _foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers) + foreign_tenant_id = foreign_tenant.id + foreign_tenant.encrypt_public_key = generate_key_pair(foreign_tenant.id) + db_session_with_containers.commit() + + account_create_response = test_client_with_containers.post( + "/console/api/api-based-extension", + headers=account_headers, + json={ + "name": "Tenant API", + "api_endpoint": "https://tenant.example.com/hook", + "api_key": "tenant-secret-12345", + }, + ) + assert account_create_response.status_code == 201 + + APIBasedExtensionService.save( + APIBasedExtension( + tenant_id=foreign_tenant_id, + name="Foreign API", + api_endpoint="https://foreign.example.com/hook", + api_key="foreign-secret-12345", + ) + ) + + response = test_client_with_containers.get( + "/console/api/api-based-extension", + 
headers=account_headers, + ) + + assert response.status_code == 200 + assert response.json is not None + assert [item["name"] for item in response.json] == ["Tenant API"] def test_update_response_masks_new_plaintext_api_key( - api_extension_client: tuple[FlaskClient, dict[str, str], Tenant], + api_extension_client: tuple[FlaskClient, dict[str, str], str], ) -> None: client, headers, _ = api_extension_client new_api_key = "new-secret-67890" @@ -96,7 +147,7 @@ def test_update_response_masks_new_plaintext_api_key( def test_update_response_masks_existing_plaintext_api_key_when_hidden_value_is_submitted( - api_extension_client: tuple[FlaskClient, dict[str, str], Tenant], + api_extension_client: tuple[FlaskClient, dict[str, str], str], ) -> None: client, headers, _ = api_extension_client existing_api_key = "old-secret-12345" diff --git a/api/tests/test_containers_integration_tests/controllers/console/test_apikey.py b/api/tests/test_containers_integration_tests/controllers/console/test_apikey.py index b55eaa1d58..8559501c89 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/test_apikey.py +++ b/api/tests/test_containers_integration_tests/controllers/console/test_apikey.py @@ -7,9 +7,11 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask from flask.testing import FlaskClient -from sqlalchemy import delete +from sqlalchemy import delete, select from sqlalchemy.orm import Session +from models import Account +from models.account import AccountStatus, TenantAccountRole from models.enums import ApiTokenType from models.model import ApiToken, App, AppMode from tests.test_containers_integration_tests.controllers.console.helpers import ( @@ -58,6 +60,24 @@ class TestAppApiKeyListResource: assert data["token"].startswith("app-") assert data["id"] is not None + def test_create_api_key_persists_authenticated_tenant( + self, + setup_app: tuple[FlaskClient, dict[str, str], App], + db_session_with_containers: Session, + ) -> 
None: + client, headers, app = setup_app + tenant_id = app.tenant_id + + resp = client.post(f"/console/api/apps/{app.id}/api-keys", headers=headers) + + assert resp.status_code == 201 + assert resp.json is not None + api_token = db_session_with_containers.scalar(select(ApiToken).where(ApiToken.id == resp.json["id"])) + assert api_token is not None + assert api_token.tenant_id == tenant_id + assert api_token.app_id == app.id + assert api_token.type == ApiTokenType.APP + def test_get_keys_after_create(self, setup_app: tuple[FlaskClient, dict[str, str], App]) -> None: client, headers, app = setup_app client.post(f"/console/api/apps/{app.id}/api-keys", headers=headers) @@ -93,6 +113,21 @@ class TestAppApiKeyListResource: ) assert resp.status_code == 404 + def test_get_foreign_app_keys_not_found( + self, + setup_app: tuple[FlaskClient, dict[str, str], App], + db_session_with_containers: Session, + ) -> None: + client, headers, _ = setup_app + foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers) + foreign_app = create_console_app( + db_session_with_containers, foreign_tenant.id, foreign_account.id, AppMode.CHAT + ) + + resp = client.get(f"/console/api/apps/{foreign_app.id}/api-keys", headers=headers) + + assert resp.status_code == 404 + class TestAppApiKeyResource: """Tests for DELETE /apps//api-keys/.""" @@ -139,16 +174,13 @@ class TestAppApiKeyResource: resource.resource_model = MagicMock() resource.resource_id_field = "app_id" - non_admin = MagicMock() - non_admin.is_admin_or_owner = False + non_admin = Account(name="Normal User", email="normal@example.com", status=AccountStatus.ACTIVE) + non_admin.id = "normal-user" + non_admin.role = TenantAccountRole.NORMAL with ( flask_app_with_containers.test_request_context("/"), - patch( - "controllers.console.apikey.current_account_with_tenant", - return_value=(non_admin, "tenant-id"), - ), patch("controllers.console.apikey._get_resource"), ): with pytest.raises(Forbidden): - 
BaseApiKeyResource.delete(resource, "rid", "kid") + BaseApiKeyResource.delete(resource, "rid", "kid", "tenant-id", non_admin) diff --git a/api/tests/test_containers_integration_tests/controllers/console/test_feature.py b/api/tests/test_containers_integration_tests/controllers/console/test_feature.py new file mode 100644 index 0000000000..9eb76c8152 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/test_feature.py @@ -0,0 +1,65 @@ +"""Integration tests for console feature endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import patch + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from services.feature_service import FeatureModel, FeatureService, LimitationModel +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) + + +def test_feature_list_returns_current_tenant_configuration_without_vector_space( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + """Exercise auth, tenant injection, and the feature response shaping contract.""" + account, tenant = create_console_account_and_tenant(db_session_with_containers) + tenant_id = tenant.id + headers = authenticate_console_client(test_client_with_containers, account) + feature_model = FeatureModel( + members=LimitationModel(size=1, limit=2), + apps=LimitationModel(size=3, limit=4), + vector_space=LimitationModel(size=5, limit=6), + ) + + with patch.object(FeatureService, "get_features", return_value=feature_model) as get_features: + response = test_client_with_containers.get( + "/console/api/features", + headers=headers, + ) + + assert response.status_code == 200 + assert response.json is not None + assert response.json["members"] == {"size": 1, "limit": 2} + assert response.json["apps"] == {"size": 3, "limit": 4} + assert "vector_space" not in 
response.json + get_features.assert_called_once_with(tenant_id, exclude_vector_space=True) + + +def test_feature_vector_space_returns_current_tenant_usage( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + """Exercise tenant injection and vector-space response serialization through the registered route.""" + account, tenant = create_console_account_and_tenant(db_session_with_containers) + tenant_id = tenant.id + headers = authenticate_console_client(test_client_with_containers, account) + + vector_space = SimpleNamespace(model_dump=lambda: {"size": 0, "limit": 100}) + + with patch.object(FeatureService, "get_vector_space", return_value=vector_space) as get_vector_space: + response = test_client_with_containers.get( + "/console/api/features/vector-space", + headers=headers, + ) + + assert response.status_code == 200 + assert response.json == {"size": 0, "limit": 100} + get_vector_space.assert_called_once_with(tenant_id) diff --git a/api/tests/test_containers_integration_tests/controllers/console/test_files.py b/api/tests/test_containers_integration_tests/controllers/console/test_files.py new file mode 100644 index 0000000000..8985c1ba66 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/test_files.py @@ -0,0 +1,101 @@ +"""Integration tests for console file endpoints.""" + +from __future__ import annotations + +from io import BytesIO + +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from configs import dify_config +from models.model import UploadFile +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) + + +def test_file_upload_config_returns_console_limits( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + """Exercise the authenticated upload-config route and response contract.""" + 
account, _tenant = create_console_account_and_tenant(db_session_with_containers) + headers = authenticate_console_client(test_client_with_containers, account) + + response = test_client_with_containers.get( + "/console/api/files/upload", + headers=headers, + ) + + assert response.status_code == 200 + assert response.json == { + "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, + "batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT, + "file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT, + "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, + "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, + "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, + "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, + "image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT, + "single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT, + "attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT, + } + + +def test_file_upload_persists_file_for_authenticated_current_user( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + """Exercise real upload behavior plus current-user and tenant propagation.""" + account, tenant = create_console_account_and_tenant(db_session_with_containers) + account_id = account.id + tenant_id = tenant.id + headers = authenticate_console_client(test_client_with_containers, account) + content = b"hello from console integration" + + response = test_client_with_containers.post( + "/console/api/files/upload", + headers=headers, + data={"file": (BytesIO(content), "tenant-owned.txt")}, + content_type="multipart/form-data", + ) + + assert response.status_code == 201 + assert response.json is not None + assert response.json["name"] == "tenant-owned.txt" + assert response.json["size"] == len(content) + assert response.json["extension"] == "txt" + assert response.json["mime_type"] == "text/plain" + assert 
response.json["created_by"] == account_id + + upload_file = db_session_with_containers.scalar( + select(UploadFile).where(UploadFile.id == response.json["id"]).limit(1) + ) + assert upload_file is not None + assert upload_file.tenant_id == tenant_id + assert upload_file.created_by == account_id + assert upload_file.name == "tenant-owned.txt" + assert upload_file.size == len(content) + assert f"/{tenant_id}/" in upload_file.key + + +def test_file_upload_rejects_missing_file_after_authentication( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + """Exercise the route's validation path with a real authenticated account.""" + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + headers = authenticate_console_client(test_client_with_containers, account) + + response = test_client_with_containers.post( + "/console/api/files/upload", + headers=headers, + data={}, + content_type="multipart/form-data", + ) + + assert response.status_code == 400 + assert response.json is not None + assert response.json["code"] == "no_file_uploaded" diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py index 454b8096d1..718ff05d22 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py @@ -127,7 +127,7 @@ class TestEmailDeliveryTestHandler: monkeypatch.setattr( service_module.FeatureService, "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False), + lambda _tenant_id, **_kwargs: SimpleNamespace(human_input_email_delivery_enabled=False), ) handler = EmailDeliveryTestHandler(session_factory=MagicMock()) context = DeliveryTestContext( @@ -142,7 +142,7 @@ class 
TestEmailDeliveryTestHandler: monkeypatch.setattr( service_module.FeatureService, "get_features", - lambda _id: SimpleNamespace(human_input_email_delivery_enabled=True), + lambda _id, **_kwargs: SimpleNamespace(human_input_email_delivery_enabled=True), ) monkeypatch.setattr(service_module.mail, "is_inited", lambda: False) @@ -159,7 +159,7 @@ class TestEmailDeliveryTestHandler: monkeypatch.setattr( service_module.FeatureService, "get_features", - lambda _id: SimpleNamespace(human_input_email_delivery_enabled=True), + lambda _id, **_kwargs: SimpleNamespace(human_input_email_delivery_enabled=True), ) monkeypatch.setattr(service_module.mail, "is_inited", lambda: True) @@ -178,7 +178,7 @@ class TestEmailDeliveryTestHandler: monkeypatch.setattr( service_module.FeatureService, "get_features", - lambda _id: SimpleNamespace(human_input_email_delivery_enabled=True), + lambda _id, **_kwargs: SimpleNamespace(human_input_email_delivery_enabled=True), ) monkeypatch.setattr(service_module.mail, "is_inited", lambda: True) mock_mail_send = MagicMock() @@ -214,7 +214,7 @@ class TestEmailDeliveryTestHandler: monkeypatch.setattr( service_module.FeatureService, "get_features", - lambda _id: SimpleNamespace(human_input_email_delivery_enabled=True), + lambda _id, **_kwargs: SimpleNamespace(human_input_email_delivery_enabled=True), ) monkeypatch.setattr(service_module.mail, "is_inited", lambda: True) mock_mail_send = MagicMock() diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_account_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_account_task.py new file mode 100644 index 0000000000..68737a4ef6 --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_account_task.py @@ -0,0 +1,84 @@ +""" +Integration tests for delete_account_task. 
+ +These tests keep billing and email dispatch mocked, but exercise the account +lookup through the real Testcontainers PostgreSQL session factory instead of a +patched session_factory mock. +""" + +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from models.account import Account +from tasks.delete_account_task import delete_account_task + + +def _create_account(db_session: Session, *, email: str = "user@example.com") -> Account: + account = Account( + name=f"account-{uuid4()}", + email=email, + ) + db_session.add(account) + db_session.commit() + return account + + +@pytest.fixture +def mock_external_dependencies(mocker): + billing_service = mocker.patch("tasks.delete_account_task.BillingService") + mail_task = mocker.patch("tasks.delete_account_task.send_deletion_success_task") + return billing_service, mail_task + + +def test_billing_enabled_account_exists_calls_billing_and_sends_email( + db_session_with_containers: Session, mock_external_dependencies, mocker +) -> None: + billing_service, mail_task = mock_external_dependencies + account = _create_account(db_session_with_containers, email="a@b.com") + mocker.patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True) + + delete_account_task(account.id) + + billing_service.delete_account.assert_called_once_with(account.id) + mail_task.delay.assert_called_once_with(account.email) + + +def test_billing_disabled_account_exists_sends_email_only( + db_session_with_containers: Session, mock_external_dependencies, mocker +) -> None: + billing_service, mail_task = mock_external_dependencies + account = _create_account(db_session_with_containers, email="x@y.com") + mocker.patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", False) + + delete_account_task(account.id) + + billing_service.delete_account.assert_not_called() + mail_task.delay.assert_called_once_with(account.email) + + +def test_billing_enabled_account_not_found_calls_billing_no_email(mock_external_dependencies, 
mocker, caplog) -> None: + billing_service, mail_task = mock_external_dependencies + account_id = str(uuid4()) + mocker.patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True) + + delete_account_task(account_id) + + billing_service.delete_account.assert_called_once_with(account_id) + mail_task.delay.assert_not_called() + assert any("not found" in record.getMessage().lower() for record in caplog.records) + + +def test_billing_delete_raises_propagates_and_no_email( + db_session_with_containers: Session, mock_external_dependencies, mocker +) -> None: + billing_service, mail_task = mock_external_dependencies + account = _create_account(db_session_with_containers, email="err@example.com") + billing_service.delete_account.side_effect = RuntimeError("billing down") + mocker.patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True) + + with pytest.raises(RuntimeError, match="billing down"): + delete_account_task(account.id) + + mail_task.delay.assert_not_called() diff --git a/api/tests/unit_tests/clients/agent_backend/test_client.py b/api/tests/unit_tests/clients/agent_backend/test_client.py index 7e3be42551..407372d29d 100644 --- a/api/tests/unit_tests/clients/agent_backend/test_client.py +++ b/api/tests/unit_tests/clients/agent_backend/test_client.py @@ -2,12 +2,12 @@ from collections.abc import Iterator import pytest from dify_agent.client import DifyAgentHTTPError, DifyAgentStreamError, DifyAgentTimeoutError, DifyAgentValidationError +from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig from dify_agent.protocol import ( CancelRunRequest, CancelRunResponse, CreateRunRequest, CreateRunResponse, - ExecutionContext, RunEvent, RunStartedEvent, RunStatusResponse, @@ -29,12 +29,11 @@ def _request(): return AgentBackendRunRequestBuilder().build_for_workflow_node( AgentBackendWorkflowNodeRunInput( model=AgentBackendModelConfig( - tenant_id="tenant-1", plugin_id="langgenius/openai", model_provider="openai", model="gpt-test", ), 
- execution_context=ExecutionContext(tenant_id="tenant-1", invoke_from="workflow_run"), + execution_context=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"), workflow_node_job_prompt="Do the task.", user_prompt="hello", ) diff --git a/api/tests/unit_tests/clients/agent_backend/test_fake_client.py b/api/tests/unit_tests/clients/agent_backend/test_fake_client.py index 80b398988a..087cffef81 100644 --- a/api/tests/unit_tests/clients/agent_backend/test_fake_client.py +++ b/api/tests/unit_tests/clients/agent_backend/test_fake_client.py @@ -1,4 +1,4 @@ -from dify_agent.protocol import ExecutionContext +from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig from clients.agent_backend import ( AgentBackendModelConfig, @@ -13,12 +13,11 @@ def _request(): return AgentBackendRunRequestBuilder().build_for_workflow_node( AgentBackendWorkflowNodeRunInput( model=AgentBackendModelConfig( - tenant_id="tenant-1", plugin_id="langgenius/openai", model_provider="openai", model="gpt-test", ), - execution_context=ExecutionContext(tenant_id="tenant-1", invoke_from="workflow_run"), + execution_context=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"), workflow_node_job_prompt="Do the task.", user_prompt="hello", ) diff --git a/api/tests/unit_tests/clients/agent_backend/test_request_builder.py b/api/tests/unit_tests/clients/agent_backend/test_request_builder.py index 44c795d70d..0df3940af8 100644 --- a/api/tests/unit_tests/clients/agent_backend/test_request_builder.py +++ b/api/tests/unit_tests/clients/agent_backend/test_request_builder.py @@ -1,18 +1,25 @@ import pytest from agenton.layers import ExitIntent from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID -from dify_agent.layers.dify_plugin import DIFY_PLUGIN_LAYER_TYPE_ID, DIFY_PLUGIN_LLM_LAYER_TYPE_ID +from dify_agent.layers.dify_plugin import ( + DIFY_PLUGIN_LLM_LAYER_TYPE_ID, + DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID, + 
DifyPluginToolConfig, + DifyPluginToolsLayerConfig, +) +from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID from dify_agent.protocol import ( DIFY_AGENT_MODEL_LAYER_ID, DIFY_AGENT_OUTPUT_LAYER_ID, CreateRunRequest, - ExecutionContext, ) from pydantic import ValidationError from clients.agent_backend import ( AGENT_SOUL_PROMPT_LAYER_ID, + DIFY_EXECUTION_CONTEXT_LAYER_ID, + DIFY_PLUGIN_TOOLS_LAYER_ID, WORKFLOW_NODE_JOB_PROMPT_LAYER_ID, WORKFLOW_USER_PROMPT_LAYER_ID, AgentBackendModelConfig, @@ -26,15 +33,14 @@ from clients.agent_backend import ( def _run_input() -> AgentBackendWorkflowNodeRunInput: return AgentBackendWorkflowNodeRunInput( model=AgentBackendModelConfig( - tenant_id="tenant-1", plugin_id="langgenius/openai", - user_id="user-1", model_provider="openai", model="gpt-test", credentials={"api_key": "secret-key"}, ), - execution_context=ExecutionContext( + execution_context=DifyExecutionContextLayerConfig( tenant_id="tenant-1", + user_id="user-1", workflow_id="workflow-1", workflow_run_id="workflow-run-1", node_id="node-1", @@ -64,13 +70,11 @@ def test_request_builder_outputs_dify_agent_create_run_request(): AGENT_SOUL_PROMPT_LAYER_ID, WORKFLOW_NODE_JOB_PROMPT_LAYER_ID, WORKFLOW_USER_PROMPT_LAYER_ID, - "plugin", + DIFY_EXECUTION_CONTEXT_LAYER_ID, DIFY_AGENT_MODEL_LAYER_ID, DIFY_AGENT_OUTPUT_LAYER_ID, ] assert request.on_exit.default is ExitIntent.DELETE - assert request.execution_context is not None - assert request.execution_context.node_execution_id == "node-execution-1" assert request.idempotency_key == "workflow-run-1:node-execution-1" assert request.metadata == {"workflow_id": "workflow-1", "node_id": "node-1"} @@ -94,12 +98,41 @@ def test_request_builder_sets_model_and_output_layer_contract_ids(): request = AgentBackendRunRequestBuilder().build_for_workflow_node(_run_input()) layers = {layer.name: layer for layer in 
request.composition.layers} - assert layers["plugin"].type == DIFY_PLUGIN_LAYER_TYPE_ID + assert layers[DIFY_EXECUTION_CONTEXT_LAYER_ID].type == DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID + assert layers[DIFY_EXECUTION_CONTEXT_LAYER_ID].config.user_id == "user-1" assert layers[DIFY_AGENT_MODEL_LAYER_ID].type == DIFY_PLUGIN_LLM_LAYER_TYPE_ID - assert layers[DIFY_AGENT_MODEL_LAYER_ID].deps == {"plugin": "plugin"} + assert layers[DIFY_AGENT_MODEL_LAYER_ID].config.plugin_id == "langgenius/openai" + assert layers[DIFY_AGENT_MODEL_LAYER_ID].deps == {"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID} assert layers[DIFY_AGENT_OUTPUT_LAYER_ID].type == DIFY_OUTPUT_LAYER_TYPE_ID +def test_request_builder_adds_dify_plugin_tools_layer_when_configured(): + run_input = _run_input() + run_input.tools = DifyPluginToolsLayerConfig( + tools=[ + DifyPluginToolConfig( + plugin_id="langgenius/time", + provider="time", + tool_name="current_time", + credential_type="unauthorized", + name="current_time", + description="Get current time.", + credentials={}, + runtime_parameters={}, + parameters=[], + parameters_json_schema={"type": "object", "properties": {}, "required": []}, + ) + ] + ) + + request = AgentBackendRunRequestBuilder().build_for_workflow_node(run_input) + layers = {layer.name: layer for layer in request.composition.layers} + + assert layers[DIFY_PLUGIN_TOOLS_LAYER_ID].type == DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID + assert layers[DIFY_PLUGIN_TOOLS_LAYER_ID].deps == {"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID} + assert layers[DIFY_PLUGIN_TOOLS_LAYER_ID].config.tools[0].tool_name == "current_time" + + def test_request_builder_can_suspend_on_exit_for_resume_or_babysit_paths(): run_input = _run_input() run_input.suspend_on_exit = True @@ -113,12 +146,11 @@ def test_request_builder_rejects_blank_prompts(): with pytest.raises(ValidationError): AgentBackendWorkflowNodeRunInput( model=AgentBackendModelConfig( - tenant_id="tenant-1", plugin_id="langgenius/openai", 
model_provider="openai", model="gpt-test", ), - execution_context=ExecutionContext(tenant_id="tenant-1", invoke_from="workflow_run"), + execution_context=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"), workflow_node_job_prompt=" ", user_prompt="hello", ) diff --git a/api/tests/unit_tests/commands/test_legacy_model_type_migration.py b/api/tests/unit_tests/commands/test_legacy_model_type_migration.py new file mode 100644 index 0000000000..7eead948c1 --- /dev/null +++ b/api/tests/unit_tests/commands/test_legacy_model_type_migration.py @@ -0,0 +1,2025 @@ +from __future__ import annotations + +import importlib +import io +import json +import os +import threading +import time +from datetime import datetime, timedelta +from pathlib import Path +from types import SimpleNamespace +from typing import cast + +import pytest +import sqlalchemy as sa +from click.testing import CliRunner +from sqlalchemy.exc import OperationalError + +from graphon.model_runtime.entities.model_entities import ModelType +from models.account import Tenant +from models.enums import CredentialSourceType +from models.provider import ProviderModel +from tests.helpers.legacy_model_type_migration import ( + ALL_TABLE_NAMES, + LEGACY_TO_CANONICAL, + assert_tenant_rows_use_only_canonical_model_types, + count_rows, + create_minimal_legacy_model_type_schema, + fetch_table_rows, + seed_legacy_model_type_dirty_data, + snapshot_legacy_model_type_state, +) + + +@pytest.fixture +def sqlite_engine(tmp_path: Path) -> sa.Engine: + engine = sa.create_engine(f"sqlite:///{tmp_path / 'legacy_model_type_migration.sqlite'}") + try: + yield engine + finally: + engine.dispose() + + +@pytest.fixture +def dirty_fixture(sqlite_engine: sa.Engine): + return seed_legacy_model_type_dirty_data(sqlite_engine) + + +@pytest.fixture +def migration_module(): + try: + return importlib.import_module("services.legacy_model_type_migration") + except ModuleNotFoundError as exc: # pragma: no cover - explicit 
TDD failure path + pytest.fail( + "services.legacy_model_type_migration is missing. " + "Implement LegacyModelTypeMigrationService before running these tests." + ) + + +@pytest.fixture +def command_module(): + try: + return importlib.import_module("commands.data_migrate") + except ModuleNotFoundError as exc: # pragma: no cover - explicit TDD failure path + pytest.fail( + "commands.data_migrate is missing. " + "Implement the `flask data-migrate legacy-model-types` command group before running these tests." + ) + + +def _parse_json_lines(output: io.StringIO) -> list[dict[str, object]]: + return [json.loads(line) for line in output.getvalue().splitlines() if line.strip()] + + +def _json_key(value: object) -> str: + return json.dumps(value, sort_keys=True) + + +def _event_signature(line: dict[str, object]) -> tuple[object, ...] | None: + event = line.get("event") + attrs = line.get("attrs") + if not isinstance(attrs, dict): + return None + + if event == "row_updated": + return ( + event, + attrs.get("table_name"), + attrs.get("id"), + _json_key(attrs.get("business_key")), + _json_key(attrs.get("old_values")), + _json_key(attrs.get("new_values")), + _json_key(attrs.get("rewrite_source")), + ) + if event == "row_deleted": + return ( + event, + attrs.get("table_name"), + attrs.get("id"), + _json_key(attrs.get("business_key")), + attrs.get("merge_winner_id"), + ) + if event == "group_processed": + return ( + event, + attrs.get("table_name"), + _json_key(attrs.get("business_key")), + tuple(cast(list[str], attrs.get("group_row_ids", []))), + ) + return None + + +def _collect_processing_signatures(lines: list[dict[str, object]]) -> set[tuple[object, ...]]: + signatures: set[tuple[object, ...]] = set() + for line in lines: + signature = _event_signature(line) + if signature is not None: + signatures.add(signature) + return signatures + + +def _cache_event_row_ids( + lines: list[dict[str, object]], + *, + table_name: str, + row_ids: set[str], + event_name: str, +) -> set[str]: 
+ matching_row_ids: set[str] = set() + for line in lines: + if line.get("event") != event_name: + continue + attrs = line.get("attrs") + if not isinstance(attrs, dict): + continue + if attrs.get("table_name") != table_name: + continue + row_id = str(attrs.get("id")) + if row_id in row_ids: + matching_row_ids.add(row_id) + return matching_row_ids + + +def _patch_batch_size( + monkeypatch: pytest.MonkeyPatch, + migration_module, + *, + batch_size: int, +) -> None: + original_init = migration_module.Migration.__init__ + + def _patched_init(self, *args, **kwargs) -> None: + original_init(self, *args, **kwargs) + self._batch_size = batch_size + + monkeypatch.setattr(migration_module.Migration, "__init__", _patched_init) + + +def _insert_provider_model( + engine: sa.Engine, + *, + row_id: str, + tenant_id: str, + provider_name: str, + model_name: str, + model_type: str, + credential_id: str | None, + created_at: datetime, + updated_at: datetime, +) -> None: + with engine.begin() as conn: + conn.execute( + sa.text( + """ + INSERT INTO provider_models + ( + id, tenant_id, provider_name, model_name, model_type, + credential_id, is_valid, created_at, updated_at + ) + VALUES + ( + :id, :tenant_id, :provider_name, :model_name, :model_type, + :credential_id, :is_valid, :created_at, :updated_at + ) + """ + ), + { + "id": row_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "model_name": model_name, + "model_type": model_type, + "credential_id": credential_id, + "is_valid": True, + "created_at": created_at, + "updated_at": updated_at, + }, + ) + + +def _insert_tenant(engine: sa.Engine, *, tenant_id: str) -> None: + with engine.begin() as conn: + conn.execute( + Tenant.__table__.insert().values( + id=tenant_id, + name=f"Tenant {tenant_id}", + plan="basic", + status="normal", + ) + ) + + +def _insert_tenant_default_model( + engine: sa.Engine, + *, + row_id: str, + tenant_id: str, + provider_name: str, + model_name: str, + model_type: str, + created_at: datetime, + 
updated_at: datetime, +) -> None: + with engine.begin() as conn: + conn.execute( + sa.text( + """ + INSERT INTO tenant_default_models + (id, tenant_id, provider_name, model_name, model_type, created_at, updated_at) + VALUES + (:id, :tenant_id, :provider_name, :model_name, :model_type, :created_at, :updated_at) + """ + ), + { + "id": row_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "model_name": model_name, + "model_type": model_type, + "created_at": created_at, + "updated_at": updated_at, + }, + ) + + +def _insert_provider_model_setting( + engine: sa.Engine, + *, + row_id: str, + tenant_id: str, + provider_name: str, + model_name: str, + model_type: str, + enabled: bool, + load_balancing_enabled: bool, + created_at: datetime, + updated_at: datetime, +) -> None: + with engine.begin() as conn: + conn.execute( + sa.text( + """ + INSERT INTO provider_model_settings + ( + id, tenant_id, provider_name, model_name, model_type, + enabled, load_balancing_enabled, + created_at, updated_at + ) + VALUES + ( + :id, :tenant_id, :provider_name, :model_name, :model_type, + :enabled, :load_balancing_enabled, + :created_at, :updated_at + ) + """ + ), + { + "id": row_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "model_name": model_name, + "model_type": model_type, + "enabled": enabled, + "load_balancing_enabled": load_balancing_enabled, + "created_at": created_at, + "updated_at": updated_at, + }, + ) + + +def _insert_load_balancing_model_config( + engine: sa.Engine, + *, + row_id: str, + tenant_id: str, + provider_name: str, + model_name: str, + model_type: str, + name: str, + encrypted_config: str, + credential_id: str, + enabled: bool, + created_at: datetime, + updated_at: datetime, +) -> None: + with engine.begin() as conn: + conn.execute( + sa.text( + """ + INSERT INTO load_balancing_model_configs + ( + id, tenant_id, provider_name, model_name, model_type, name, + encrypted_config, credential_id, credential_source_type, enabled, 
created_at, updated_at + ) + VALUES + ( + :id, :tenant_id, :provider_name, :model_name, :model_type, :name, + :encrypted_config, :credential_id, :credential_source_type, :enabled, :created_at, :updated_at + ) + """ + ), + { + "id": row_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "model_name": model_name, + "model_type": model_type, + "name": name, + "encrypted_config": encrypted_config, + "credential_id": credential_id, + "credential_source_type": CredentialSourceType.CUSTOM_MODEL.value, + "enabled": enabled, + "created_at": created_at, + "updated_at": updated_at, + }, + ) + + +def test_data_migrate_command_defaults_output_to_stdout_stream( + command_module, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + service_calls: list[dict[str, object]] = [] + fake_stdout = io.StringIO() + + class FakeService: + def __init__( + self, + *, + engine: sa.Engine, + apply: bool, + concurrency: int, + output: io.TextIOBase | None = None, + tables: tuple[str, ...] | None, + model_types: tuple[ModelType, ...], + tenant_ids: tuple[str, ...] 
| None, + ) -> None: + service_calls.append( + { + "engine": engine, + "apply": apply, + "concurrency": concurrency, + "output": output, + "tables": tables, + "model_types": model_types, + "tenant_ids": tenant_ids, + } + ) + + def migrate(self) -> None: + service_calls.append({"migrated": True}) + + monkeypatch.setattr(command_module, "LegacyModelTypeMigrationService", FakeService) + monkeypatch.setattr(command_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(command_module.sys, "stdout", fake_stdout) + tenant_id_file = tmp_path / "tenant_ids.txt" + tenant_id_file.write_text("tenant-alpha\n", encoding="utf-8") + + data_migrate = command_module.data_migrate + legacy_model_types = cast(object, data_migrate.commands["legacy-model-types"]) + + legacy_model_types.callback( + apply=True, + tables=("provider_models",), + model_types=("llm", "text-embedding"), + tenant_id_file=str(tenant_id_file), + output=None, + concurrency=7, + ) + + assert service_calls[0]["apply"] is True + assert service_calls[0]["concurrency"] == 7 + assert service_calls[0]["output"] is fake_stdout + assert service_calls[0]["tables"] == ("provider_models",) + assert tuple(cast(list[str], service_calls[0]["tenant_ids"])) == ("tenant-alpha",) + assert service_calls[0]["model_types"] == (ModelType.LLM, ModelType.TEXT_EMBEDDING) + assert service_calls[1] == {"migrated": True} + + +def test_data_migrate_command_opens_output_file_and_closes_stream( + command_module, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + service_calls: list[dict[str, object]] = [] + + class FakeService: + def __init__( + self, + *, + engine: sa.Engine, + apply: bool, + concurrency: int, + output: io.TextIOBase | None = None, + tables: tuple[str, ...] | None, + model_types: tuple[ModelType, ...], + tenant_ids: tuple[str, ...] 
| None, + ) -> None: + service_calls.append( + { + "engine": engine, + "apply": apply, + "concurrency": concurrency, + "output": output, + "tables": tables, + "model_types": model_types, + "tenant_ids": tenant_ids, + } + ) + + def migrate(self) -> None: + output = cast(io.TextIOBase, service_calls[0]["output"]) + output.write('{"event":"test"}\n') + service_calls.append({"migrated": True}) + + monkeypatch.setattr(command_module, "LegacyModelTypeMigrationService", FakeService) + monkeypatch.setattr(command_module, "db", SimpleNamespace(engine=object())) + output_path = tmp_path / "migration.jsonl" + + data_migrate = command_module.data_migrate + legacy_model_types = cast(object, data_migrate.commands["legacy-model-types"]) + + legacy_model_types.callback( + apply=False, + tables=(), + model_types=(), + tenant_id_file=None, + output=output_path, + concurrency=3, + ) + + output_stream = cast(io.TextIOBase, service_calls[0]["output"]) + assert service_calls[0]["concurrency"] == 3 + assert output_stream is not output_path + assert isinstance(output_stream, io.TextIOBase) + assert Path(output_stream.name) == output_path + assert output_stream.closed is True + assert output_path.read_text(encoding="utf-8") == '{"event":"test"}\n' + assert service_calls[1] == {"migrated": True} + + +@pytest.mark.parametrize( + ("cpu_count", "expected_concurrency"), + [ + (8, 8), + (None, 1), + ], +) +def test_data_migrate_command_defaults_concurrency_from_cpu_count_or_falls_back_to_one( + monkeypatch: pytest.MonkeyPatch, + cpu_count: int | None, + expected_concurrency: int, +) -> None: + service_calls: list[dict[str, object]] = [] + command_module = importlib.import_module("commands.data_migrate") + + class FakeService: + def __init__( + self, + *, + engine: sa.Engine, + apply: bool, + concurrency: int, + output: io.TextIOBase | None = None, + tables: tuple[str, ...] | None, + model_types: tuple[ModelType, ...], + tenant_ids: tuple[str, ...] 
| None, + ) -> None: + service_calls.append( + { + "engine": engine, + "apply": apply, + "concurrency": concurrency, + "output": output, + "tables": tables, + "model_types": model_types, + "tenant_ids": tenant_ids, + } + ) + + def migrate(self) -> None: + service_calls.append({"migrated": True}) + + monkeypatch.setattr(os, "cpu_count", lambda: cpu_count) + importlib.reload(command_module) + try: + monkeypatch.setattr(command_module, "LegacyModelTypeMigrationService", FakeService) + monkeypatch.setattr(command_module, "db", SimpleNamespace(engine=object())) + + result = CliRunner().invoke(command_module.data_migrate, ["legacy-model-types"]) + + assert result.exit_code == 0, result.output + assert expected_concurrency == command_module._DEFAULT_CONCURRENCY + assert service_calls[0]["concurrency"] == expected_concurrency + assert service_calls[1] == {"migrated": True} + finally: + monkeypatch.undo() + importlib.reload(command_module) + + +def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reverse_dependency_expansion( + migration_module, + sqlite_engine: sa.Engine, +) -> None: + seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = [] + + class FakeMigration: + def __init__( + self, + *, + tenant_id: str, + engine: sa.Engine, + apply: bool, + output: io.TextIOBase, + model_types: tuple[ModelType, ...], + orm_models: tuple[type[object], ...], + ) -> None: + assert engine is sqlite_engine + assert apply is False + seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types)) + + def run(self) -> None: + return None + + monkeypatch = pytest.MonkeyPatch() + try: + monkeypatch.setattr(migration_module, "Migration", FakeMigration) + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + concurrency=1, + tables=("provider_models", "tenant_default_models"), + model_types=(ModelType.LLM,), + tenant_ids=("tenant-alpha", "tenant-beta"), + ) + + 
service.migrate() + finally: + monkeypatch.undo() + + assert seen_runs == [ + ("tenant-alpha", ("provider_models", "tenant_default_models"), (ModelType.LLM,)), + ("tenant-beta", ("provider_models", "tenant_default_models"), (ModelType.LLM,)), + ] + + +def test_service_migrate_without_tenant_ids_discovers_tenants_per_selected_table_without_querying_tenants( + migration_module, + sqlite_engine: sa.Engine, + monkeypatch: pytest.MonkeyPatch, +) -> None: + create_minimal_legacy_model_type_schema(sqlite_engine) + provider_tenant_id = "00000000-0000-0000-0000-000000000111" + default_tenant_id = "00000000-0000-0000-0000-000000000222" + empty_tenant_id = "00000000-0000-0000-0000-000000000333" + for tenant_id in (provider_tenant_id, default_tenant_id, empty_tenant_id): + _insert_tenant(sqlite_engine, tenant_id=tenant_id) + + created_at = datetime(2025, 1, 1, 12, 0, 0) + updated_at = created_at + timedelta(minutes=1) + _insert_provider_model( + sqlite_engine, + row_id="10000000-0000-0000-0000-000000000111", + tenant_id=provider_tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="text-generation", + credential_id=None, + created_at=created_at, + updated_at=updated_at, + ) + _insert_tenant_default_model( + sqlite_engine, + row_id="20000000-0000-0000-0000-000000000222", + tenant_id=default_tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="text-generation", + created_at=created_at, + updated_at=updated_at, + ) + + seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = [] + executed_sql: list[str] = [] + + class FakeMigration: + def __init__( + self, + *, + tenant_id: str, + engine: sa.Engine, + apply: bool, + output: io.TextIOBase, + model_types: tuple[ModelType, ...], + orm_models: tuple[type[object], ...], + ) -> None: + assert engine is sqlite_engine + assert apply is False + seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types)) + + def run(self) -> None: + return 
None + + def _record_sql( + conn: sa.engine.Connection, + cursor: object, + statement: str, + parameters: object, + context: object, + executemany: bool, + ) -> None: + del conn, cursor, parameters, context, executemany + executed_sql.append(statement) + + sa.event.listen(sqlite_engine, "before_cursor_execute", _record_sql) + try: + monkeypatch.setattr(migration_module, "Migration", FakeMigration) + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + tables=("provider_models", "tenant_default_models"), + model_types=(ModelType.LLM,), + ) + + service.migrate() + finally: + sa.event.remove(sqlite_engine, "before_cursor_execute", _record_sql) + + assert seen_runs == [ + (provider_tenant_id, ("provider_models",), (ModelType.LLM,)), + (default_tenant_id, ("tenant_default_models",), (ModelType.LLM,)), + ] + normalized_statements = [" ".join(statement.lower().split()) for statement in executed_sql] + discovery_statements = [statement for statement in normalized_statements if statement.startswith("select")] + table_names = ("provider_models", "tenant_default_models") + table_discovery_statements = [ + statement + for statement in discovery_statements + if any(f" from {table_name} " in f" {statement} " for table_name in table_names) + ] + + assert [statement for statement in discovery_statements if " from tenants " in f" {statement} "] == [] + assert [statement for statement in discovery_statements if " union " in f" {statement} "] == [] + assert [ + next(table_name for table_name in table_names if f" from {table_name} " in f" {statement} ") + for statement in table_discovery_statements + ] == list(table_names) + + +def test_service_migrate_without_tenant_ids_filters_provider_model_tenants_by_selected_model_types( + migration_module, + sqlite_engine: sa.Engine, + monkeypatch: pytest.MonkeyPatch, +) -> None: + create_minimal_legacy_model_type_schema(sqlite_engine) + llm_tenant_id = "00000000-0000-0000-0000-000000000411" + 
embedding_tenant_id = "00000000-0000-0000-0000-000000000422" + empty_tenant_id = "00000000-0000-0000-0000-000000000433" + for tenant_id in (llm_tenant_id, embedding_tenant_id, empty_tenant_id): + _insert_tenant(sqlite_engine, tenant_id=tenant_id) + + created_at = datetime(2025, 1, 2, 12, 0, 0) + updated_at = created_at + timedelta(minutes=1) + _insert_provider_model( + sqlite_engine, + row_id="30000000-0000-0000-0000-000000000411", + tenant_id=llm_tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="text-generation", + credential_id=None, + created_at=created_at, + updated_at=updated_at, + ) + _insert_provider_model( + sqlite_engine, + row_id="30000000-0000-0000-0000-000000000422", + tenant_id=embedding_tenant_id, + provider_name="openai", + model_name="text-embedding-3-large", + model_type="embeddings", + credential_id=None, + created_at=created_at, + updated_at=updated_at, + ) + + seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = [] + + class FakeMigration: + def __init__( + self, + *, + tenant_id: str, + engine: sa.Engine, + apply: bool, + output: io.TextIOBase, + model_types: tuple[ModelType, ...], + orm_models: tuple[type[object], ...], + ) -> None: + assert engine is sqlite_engine + assert apply is False + seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types)) + + def run(self) -> None: + return None + + monkeypatch.setattr(migration_module, "Migration", FakeMigration) + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + tables=("provider_models",), + model_types=(ModelType.LLM,), + ) + + service.migrate() + + assert seen_runs == [ + (llm_tenant_id, ("provider_models",), (ModelType.LLM,)), + ] + + +def test_service_migrate_without_tenant_ids_discovers_all_load_balancing_tenants_for_simpler_table_scoped_query( + migration_module, + sqlite_engine: sa.Engine, + monkeypatch: pytest.MonkeyPatch, +) -> None: + 
create_minimal_legacy_model_type_schema(sqlite_engine) + inherit_llm_tenant_id = "00000000-0000-0000-0000-000000000511" + inherit_embedding_tenant_id = "00000000-0000-0000-0000-000000000522" + empty_tenant_id = "00000000-0000-0000-0000-000000000533" + for tenant_id in (inherit_llm_tenant_id, inherit_embedding_tenant_id, empty_tenant_id): + _insert_tenant(sqlite_engine, tenant_id=tenant_id) + + created_at = datetime(2025, 1, 3, 12, 0, 0) + updated_at = created_at + timedelta(minutes=1) + _insert_load_balancing_model_config( + sqlite_engine, + row_id="40000000-0000-0000-0000-000000000511", + tenant_id=inherit_llm_tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type=ModelType.LLM.value, + name="__inherit__", + encrypted_config=json.dumps({"api_key": "inherit-llm"}), + credential_id="50000000-0000-0000-0000-000000000511", + enabled=True, + created_at=created_at, + updated_at=updated_at, + ) + _insert_load_balancing_model_config( + sqlite_engine, + row_id="40000000-0000-0000-0000-000000000522", + tenant_id=inherit_embedding_tenant_id, + provider_name="openai", + model_name="text-embedding-3-large", + model_type=ModelType.TEXT_EMBEDDING.value, + name="__inherit__", + encrypted_config=json.dumps({"api_key": "inherit-embedding"}), + credential_id="50000000-0000-0000-0000-000000000522", + enabled=True, + created_at=created_at, + updated_at=updated_at, + ) + + seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = [] + + class FakeMigration: + def __init__( + self, + *, + tenant_id: str, + engine: sa.Engine, + apply: bool, + output: io.TextIOBase, + model_types: tuple[ModelType, ...], + orm_models: tuple[type[object], ...], + ) -> None: + assert engine is sqlite_engine + assert apply is False + seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types)) + + def run(self) -> None: + return None + + monkeypatch.setattr(migration_module, "Migration", FakeMigration) + # Load-balancing tenant discovery is 
a deliberate exception: it scans the + # whole table so the discovery query stays easy to understand, even when + # the scheduled tenant set is wider than the selected model types. + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + tables=("load_balancing_model_configs",), + model_types=(ModelType.LLM,), + ) + + service.migrate() + + assert seen_runs == [ + (inherit_llm_tenant_id, ("load_balancing_model_configs",), (ModelType.LLM,)), + (inherit_embedding_tenant_id, ("load_balancing_model_configs",), (ModelType.LLM,)), + ] + + +def test_service_migrate_with_concurrency_greater_than_one_runs_tenants_in_parallel_without_changing_migration_scope( + migration_module, + sqlite_engine: sa.Engine, +) -> None: + init_calls: list[dict[str, object]] = [] + started_tenants: list[str] = [] + worker_errors: list[BaseException] = [] + release_runs = threading.Event() + all_started = threading.Event() + active_runs = 0 + max_active_runs = 0 + state_lock = threading.Lock() + + class FakeMigration: + def __init__( + self, + *, + tenant_id: str, + engine: sa.Engine, + apply: bool, + output: io.TextIOBase, + model_types: tuple[ModelType, ...], + orm_models: tuple[type[object], ...], + ) -> None: + self._tenant_id = tenant_id + init_calls.append( + { + "tenant_id": tenant_id, + "engine": engine, + "apply": apply, + "model_types": model_types, + "table_names": tuple(model.__table__.name for model in orm_models), + } + ) + + def run(self) -> None: + nonlocal active_runs, max_active_runs + with state_lock: + active_runs += 1 + max_active_runs = max(max_active_runs, active_runs) + started_tenants.append(self._tenant_id) + if len(started_tenants) == 2: + all_started.set() + + release_runs.wait(timeout=1) + + with state_lock: + active_runs -= 1 + + monkeypatch = pytest.MonkeyPatch() + try: + monkeypatch.setattr(migration_module, "Migration", FakeMigration) + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, 
+ apply=False, + concurrency=2, + tables=("provider_models",), + model_types=(ModelType.LLM,), + tenant_ids=("tenant-alpha", "tenant-beta"), + ) + + def _run_service() -> None: + try: + service.migrate() + except BaseException as exc: # pragma: no cover - test harness + worker_errors.append(exc) + + worker = threading.Thread(target=_run_service) + worker.start() + started_in_parallel = all_started.wait(timeout=0.5) + release_runs.set() + worker.join(timeout=1) + finally: + monkeypatch.undo() + + assert worker_errors == [] + assert started_in_parallel is True + assert worker.is_alive() is False + assert max_active_runs == 2 + assert {call["tenant_id"] for call in init_calls} == {"tenant-alpha", "tenant-beta"} + for call in init_calls: + assert tuple(cast(tuple[str, ...], call["table_names"])) == ("provider_models",) + assert call["model_types"] == (ModelType.LLM,) + + +def test_service_parallel_migrate_serializes_shared_output_by_line( + migration_module, + sqlite_engine: sa.Engine, +) -> None: + worker_errors: list[BaseException] = [] + start_barrier = threading.Barrier(2) + + class SlowLineOutput(io.StringIO): + def __init__(self) -> None: + super().__init__() + self.overlap_count = 0 + self._in_write = False + self._state_lock = threading.Lock() + + def write(self, s: str) -> int: + with self._state_lock: + if self._in_write: + self.overlap_count += 1 + self._in_write = True + try: + time.sleep(0.01) + return super().write(s) + finally: + with self._state_lock: + self._in_write = False + + class FakeMigration: + def __init__( + self, + *, + tenant_id: str, + engine: sa.Engine, + apply: bool, + output: io.TextIOBase, + model_types: tuple[ModelType, ...], + orm_models: tuple[type[object], ...], + ) -> None: + self._tenant_id = tenant_id + self._output = output + + def run(self) -> None: + try: + start_barrier.wait(timeout=1) + except threading.BrokenBarrierError as exc: + raise AssertionError("parallel migrate should schedule both tenant runs together") from exc + 
+ for index in range(3): + self._output.write(f"{self._tenant_id}:line-{index}\n") + + monkeypatch = pytest.MonkeyPatch() + output = SlowLineOutput() + try: + monkeypatch.setattr(migration_module, "Migration", FakeMigration) + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + concurrency=2, + output=output, + tables=("provider_models",), + model_types=(ModelType.LLM,), + tenant_ids=("tenant-alpha", "tenant-beta"), + ) + + def _run_service() -> None: + try: + service.migrate() + except BaseException as exc: # pragma: no cover - test harness + worker_errors.append(exc) + + worker = threading.Thread(target=_run_service) + worker.start() + worker.join(timeout=2) + finally: + monkeypatch.undo() + + assert worker.is_alive() is False + assert worker_errors == [] + assert output.overlap_count == 0 + assert sorted(output.getvalue().splitlines()) == sorted( + [ + "tenant-alpha:line-0", + "tenant-alpha:line-1", + "tenant-alpha:line-2", + "tenant-beta:line-0", + "tenant-beta:line-1", + "tenant-beta:line-2", + ] + ) + + +def test_migration_dry_run_emits_json_lines_without_db_or_cache_mutation( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + before = snapshot_legacy_model_type_state(sqlite_engine) + deleted_cache_keys: list[str] = [] + + def _record_delete(self) -> None: + deleted_cache_keys.append(self.cache_key) + + monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete) + output = io.StringIO() + + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + output=output, + tenant_ids=(dirty_fixture.primary.tenant_id,), + ) + + service.migrate() + + after = snapshot_legacy_model_type_state(sqlite_engine) + assert after == before + assert deleted_cache_keys == [] + + lines = [json.loads(line) for line in output.getvalue().splitlines() if line.strip()] + assert lines, "dry-run should emit 
JSON lines" + assert all({"event", "message", "attrs", "ts"} <= set(line) for line in lines) + rendered_output = output.getvalue() + assert dirty_fixture.primary.loser_credential_id in rendered_output + assert dirty_fixture.primary.loser_credential_name in rendered_output + assert dirty_fixture.primary.loser_encrypted_config in rendered_output + + +def test_dry_run_and_apply_share_processing_scope_and_differ_only_on_side_effects( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + before = snapshot_legacy_model_type_state(sqlite_engine) + deleted_cache_keys: list[str] = [] + + def _record_delete(self) -> None: + deleted_cache_keys.append(self.cache_key) + + monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete) + + dry_run_output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + output=dry_run_output, + tenant_ids=(dirty_fixture.primary.tenant_id,), + ).migrate() + after_dry_run = snapshot_legacy_model_type_state(sqlite_engine) + dry_run_lines = _parse_json_lines(dry_run_output) + + apply_output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=True, + output=apply_output, + tenant_ids=(dirty_fixture.primary.tenant_id,), + ).migrate() + after_apply = snapshot_legacy_model_type_state(sqlite_engine) + apply_lines = _parse_json_lines(apply_output) + + assert after_dry_run == before + assert after_apply != before + + dry_run_signatures = _collect_processing_signatures(dry_run_lines) + apply_signatures = _collect_processing_signatures(apply_lines) + assert apply_signatures == dry_run_signatures + + dry_run_cache_events = [line["event"] for line in dry_run_lines if str(line.get("event")).startswith("cache_")] + apply_cache_events = [line["event"] for line in apply_lines if str(line.get("event")).startswith("cache_")] + assert "cache_deleted" not in dry_run_cache_events 
+ assert "cache_delete_planned" in dry_run_cache_events + assert "cache_deleted" in apply_cache_events + assert deleted_cache_keys + + dry_run_rewrite_signatures = { + signature + for signature in dry_run_signatures + if signature[0] == "row_updated" + and signature[1] in {"provider_models", "load_balancing_model_configs"} + and signature[-1] != _json_key(None) + } + apply_rewrite_signatures = { + signature + for signature in apply_signatures + if signature[0] == "row_updated" + and signature[1] in {"provider_models", "load_balancing_model_configs"} + and signature[-1] != _json_key(None) + } + assert apply_rewrite_signatures == dry_run_rewrite_signatures + + dry_run_lb_signatures = { + signature + for signature in dry_run_signatures + if signature[0] == "row_updated" and signature[1] == "load_balancing_model_configs" + } + apply_lb_signatures = { + signature + for signature in apply_signatures + if signature[0] == "row_updated" and signature[1] == "load_balancing_model_configs" + } + assert apply_lb_signatures == dry_run_lb_signatures + + +def test_provider_models_processing_uses_same_plan_locking_and_transaction_entry_for_dry_run_and_apply( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + dry_migration = migration_module.Migration( + tenant_id=dirty_fixture.primary.tenant_id, + engine=sqlite_engine, + apply=False, + output=io.StringIO(), + model_types=(ModelType.LLM,), + orm_models=(ProviderModel,), + ) + candidate = dry_migration._load_provider_model_candidates(None)[0] + business_key = migration_module._ProviderModelBusinessKey( + tenant_id=candidate.row.tenant_id, + provider_name=candidate.row.provider_name, + model_name=candidate.row.model_name, + model_type=candidate.canonical_model_type, + ) + + apply_migration = migration_module.Migration( + tenant_id=dirty_fixture.primary.tenant_id, + engine=sqlite_engine, + apply=True, + output=io.StringIO(), + model_types=(ModelType.LLM,), + 
orm_models=(ProviderModel,), + ) + + current_phase = {"name": "dry"} + lock_rows_seen: list[tuple[str, bool]] = [] + begin_calls: list[str] = [] + configure_calls: list[str] = [] + + class _FakeBeginContext: + def __init__(self, phase: str) -> None: + self._phase = phase + + def __enter__(self) -> None: + begin_calls.append(self._phase) + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + class _FakeSession: + def __init__(self, phase: str) -> None: + self._phase = phase + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def begin(self) -> _FakeBeginContext: + return _FakeBeginContext(self._phase) + + def _fake_session_factory(engine: sa.Engine) -> _FakeSession: + return _FakeSession(current_phase["name"]) + + def _fake_build_plan(self, session, candidate, *, lock_rows: bool): + lock_rows_seen.append((current_phase["name"], lock_rows)) + return SimpleNamespace(group_row_ids=[str(candidate.row.id)], winner=None, loser_rows=[]) + + def _fake_emit_plan(self, plan, *, session, tx_id: str, business_key: dict[str, object]) -> None: + return None + + def _fake_configure(self, session) -> None: + configure_calls.append(current_phase["name"]) + + monkeypatch.setattr(migration_module, "_session_factory", _fake_session_factory) + monkeypatch.setattr(migration_module.Migration, "_build_provider_model_group_plan", _fake_build_plan) + monkeypatch.setattr(migration_module.Migration, "_emit_provider_model_group_plan", _fake_emit_plan) + monkeypatch.setattr(migration_module.Migration, "_configure_lock_timeout", _fake_configure) + + dry_migration._process_provider_model_group(candidate, business_key) + current_phase["name"] = "apply" + apply_migration._process_provider_model_group(candidate, business_key) + + assert [phase for phase, _ in lock_rows_seen] == ["dry", "apply"] + assert lock_rows_seen[0][1] == lock_rows_seen[1][1] + assert begin_calls == ["dry", "apply"] + assert configure_calls == ["dry", 
"apply"] + + +@pytest.mark.parametrize( + ("orig", "expected"), + [ + (SimpleNamespace(pgcode="55P03"), True), + (SimpleNamespace(sqlstate="55P03"), True), + (SimpleNamespace(errno=1205), True), + (RuntimeError("canceling statement due to lock timeout"), True), + (SimpleNamespace(pgcode="23505"), False), + (SimpleNamespace(errno=1213), False), + ], +) +def test_is_lock_timeout_error_prefers_structured_backend_codes( + migration_module, + sqlite_engine: sa.Engine, + orig: object, + expected: bool, +) -> None: + migration = migration_module.Migration( + tenant_id="tenant-1", + engine=sqlite_engine, + apply=True, + output=io.StringIO(), + model_types=(ModelType.LLM,), + orm_models=(), + ) + exc = OperationalError("SELECT 1", {}, orig) + + assert migration._is_lock_timeout_error(exc) is expected + + +def test_process_load_balancing_model_config_row_logs_stacktrace_for_lock_timeout( + migration_module, + sqlite_engine: sa.Engine, + monkeypatch: pytest.MonkeyPatch, +) -> None: + output = io.StringIO() + migration = migration_module.Migration( + tenant_id="tenant-1", + engine=sqlite_engine, + apply=True, + output=output, + model_types=(ModelType.LLM,), + orm_models=(migration_module.LoadBalancingModelConfig,), + ) + candidate = migration_module._RowWithRawModelType( + row=SimpleNamespace(id="lb-row-1"), + raw_model_type="text-generation", + canonical_model_type=ModelType.LLM, + ) + lock_timeout_exc = OperationalError("SELECT 1", {}, SimpleNamespace(pgcode="55P03")) + + class _FakeBeginContext: + def __enter__(self) -> None: + return None + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + class _FakeSession: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def begin(self) -> _FakeBeginContext: + return _FakeBeginContext() + + def _fake_session_factory(engine: sa.Engine) -> _FakeSession: + return _FakeSession() + + def _fake_reload(self, session, original_candidate, *, lock_rows: bool): + raise 
lock_timeout_exc + + monkeypatch.setattr(migration_module, "_session_factory", _fake_session_factory) + monkeypatch.setattr(migration_module.Migration, "_configure_lock_timeout", lambda self, session: None) + monkeypatch.setattr( + migration_module.Migration, + "_reload_load_balancing_model_config_candidate", + _fake_reload, + ) + + migration._process_load_balancing_model_config_row(candidate) + + lines = _parse_json_lines(output) + assert len(lines) == 1 + assert lines[0]["event"] == "lock_timeout_skipped" + attrs = cast(dict[str, object], lines[0]["attrs"]) + assert attrs["table_name"] == "load_balancing_model_configs" + assert attrs["id"] == "lb-row-1" + assert attrs["error"] == str(lock_timeout_exc) + assert isinstance(attrs["stacktrace"], str) + assert "OperationalError" in attrs["stacktrace"] + + +def test_process_load_balancing_model_config_row_logs_update_after_sql_execution( + migration_module, + sqlite_engine: sa.Engine, + monkeypatch: pytest.MonkeyPatch, +) -> None: + migration = migration_module.Migration( + tenant_id="tenant-1", + engine=sqlite_engine, + apply=True, + output=io.StringIO(), + model_types=(ModelType.LLM,), + orm_models=(migration_module.LoadBalancingModelConfig,), + ) + candidate = migration_module._RowWithRawModelType( + row=SimpleNamespace(id="lb-row-1"), + raw_model_type="text-generation", + canonical_model_type=ModelType.LLM, + ) + action_log: list[str] = [] + + class _FakeBeginContext: + def __enter__(self) -> None: + action_log.append("begin") + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + class _FakeSession: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def begin(self) -> _FakeBeginContext: + return _FakeBeginContext() + + def execute(self, stmt) -> None: + action_log.append("sql_execute") + + def _fake_session_factory(engine: sa.Engine) -> _FakeSession: + return _FakeSession() + + def _fake_configure(self, session) -> None: + 
action_log.append("configure_lock_timeout") + + def _fake_reload(self, session, original_candidate, *, lock_rows: bool): + action_log.append(f"reload_candidate:{lock_rows}") + return candidate + + def _fake_log_row_updated(self, *args, **kwargs) -> None: + action_log.append("log_row_updated") + + def _fake_cache_cleanup(self, *, row_id: str, tx_id: str) -> None: + action_log.append("cache_cleanup") + + monkeypatch.setattr(migration_module, "_session_factory", _fake_session_factory) + monkeypatch.setattr(migration_module.Migration, "_configure_lock_timeout", _fake_configure) + monkeypatch.setattr( + migration_module.Migration, + "_reload_load_balancing_model_config_candidate", + _fake_reload, + ) + monkeypatch.setattr(migration_module.Migration, "_log_row_updated", _fake_log_row_updated) + monkeypatch.setattr( + migration_module.Migration, + "_log_load_balancing_model_config_cache_cleanup", + _fake_cache_cleanup, + ) + + migration._process_load_balancing_model_config_row(candidate) + + assert action_log == [ + "begin", + "configure_lock_timeout", + "reload_candidate:True", + "sql_execute", + "log_row_updated", + "cache_cleanup", + ] + + +def test_load_balancing_model_config_cache_delete_failure_logs_stacktrace( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + def _raise_delete_failure(self) -> None: + raise RuntimeError("cache delete boom") + + monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _raise_delete_failure) + + output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=True, + output=output, + tables=("load_balancing_model_configs",), + model_types=(ModelType.LLM,), + tenant_ids=(dirty_fixture.primary.tenant_id,), + ).migrate() + + failed_events = [ + cast(dict[str, object], line["attrs"]) + for line in _parse_json_lines(output) + if line.get("event") == "cache_delete_failed" + and isinstance(line.get("attrs"), dict) + and 
cast(dict[str, object], line["attrs"]).get("table_name") == "load_balancing_model_configs" + ] + + assert len(failed_events) == 1 + assert failed_events[0]["error"] == "cache delete boom" + assert isinstance(failed_events[0]["stacktrace"], str) + assert "RuntimeError: cache delete boom" in cast(str, failed_events[0]["stacktrace"]) + + +def test_group_completed_logs_exist_for_all_grouped_tables_and_use_canonical_model_type( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, +) -> None: + output = io.StringIO() + + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + output=output, + tenant_ids=(dirty_fixture.primary.tenant_id,), + ) + + service.migrate() + + lines = _parse_json_lines(output) + group_completed_records = [ + line + for line in lines + if isinstance(line.get("attrs"), dict) and "group_row_ids" in cast(dict[str, object], line["attrs"]) + ] + grouped_table_names = { + cast(dict[str, object], record["attrs"]).get("table_name") for record in group_completed_records + } + + assert grouped_table_names >= { + "provider_models", + "tenant_default_models", + "provider_model_settings", + "provider_model_credentials", + } + + for record in group_completed_records: + attrs = cast(dict[str, object], record["attrs"]) + business_key = cast(dict[str, object], attrs["business_key"]) + assert isinstance(attrs["group_row_ids"], list) + assert attrs["group_row_ids"] + if "model_type" in business_key: + assert business_key["model_type"] in { + ModelType.LLM.value, + ModelType.TEXT_EMBEDDING.value, + ModelType.RERANK.value, + } + assert business_key["model_type"] not in LEGACY_TO_CANONICAL + + +def test_provider_models_group_completed_log_includes_related_canonical_row_ids( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + _patch_batch_size(monkeypatch, migration_module, batch_size=1) + inserted_row_id = "00000000-0000-0000-0000-00000000aa01" + 
created_at = datetime(2025, 1, 1, 10, 0, 0) + updated_at = created_at + timedelta(minutes=5) + _insert_provider_model( + sqlite_engine, + row_id=inserted_row_id, + tenant_id=dirty_fixture.primary.tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type=ModelType.LLM.value, + credential_id=dirty_fixture.primary.distinct_credential_id, + created_at=created_at, + updated_at=updated_at, + ) + + output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + output=output, + tables=("provider_models",), + model_types=(ModelType.LLM,), + tenant_ids=(dirty_fixture.primary.tenant_id,), + ).migrate() + + lines = _parse_json_lines(output) + matching_records = [] + for line in lines: + attrs = line.get("attrs") + if not isinstance(attrs, dict): + continue + business_key = attrs.get("business_key") + if not isinstance(business_key, dict): + continue + if ( + attrs.get("table_name") == "provider_models" + and business_key.get("tenant_id") == dirty_fixture.primary.tenant_id + and business_key.get("provider_name") == "openai" + and business_key.get("model_name") == "gpt-4o-mini" + and business_key.get("model_type") == ModelType.LLM.value + and "group_row_ids" in attrs + ): + matching_records.append(attrs) + + assert len(matching_records) == 1 + assert set(cast(list[str], matching_records[0]["group_row_ids"])) == { + dirty_fixture.primary.provider_model_id, + inserted_row_id, + } + + +def test_provider_model_settings_group_crossing_batches_is_completed_once_with_all_group_row_ids( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + _patch_batch_size(monkeypatch, migration_module, batch_size=1) + inserted_row_id = "00000000-0000-0000-0000-00000000cc01" + created_at = datetime(2025, 1, 1, 9, 0, 0) + updated_at = created_at + timedelta(minutes=10) + _insert_provider_model_setting( + sqlite_engine, + row_id=inserted_row_id, + 
tenant_id=dirty_fixture.primary.tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="text-generation", + enabled=True, + load_balancing_enabled=False, + created_at=created_at, + updated_at=updated_at, + ) + + output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + output=output, + tables=("provider_model_settings",), + model_types=(ModelType.LLM,), + tenant_ids=(dirty_fixture.primary.tenant_id,), + ).migrate() + + lines = _parse_json_lines(output) + matching_records = [] + for line in lines: + attrs = line.get("attrs") + if not isinstance(attrs, dict): + continue + business_key = attrs.get("business_key") + if not isinstance(business_key, dict): + continue + if ( + attrs.get("table_name") == "provider_model_settings" + and business_key.get("tenant_id") == dirty_fixture.primary.tenant_id + and business_key.get("provider_name") == "openai" + and business_key.get("model_name") == "gpt-4o-mini" + and business_key.get("model_type") == ModelType.LLM.value + and "group_row_ids" in attrs + ): + matching_records.append(attrs) + + assert len(matching_records) == 1 + assert set(cast(list[str], matching_records[0]["group_row_ids"])) == { + dirty_fixture.primary.provider_model_setting_id, + inserted_row_id, + } + + +def test_load_balancing_inherit_rows_are_deduplicated_by_normalized_model_type_before_canonicalization( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + older_canonical_row_id = "00000000-0000-0000-0000-00000000dd01" + newer_legacy_row_id = "00000000-0000-0000-0000-00000000dd02" + created_at = datetime(2025, 1, 1, 8, 0, 0) + older_updated_at = created_at + timedelta(minutes=15) + newer_updated_at = created_at + timedelta(minutes=30) + _insert_load_balancing_model_config( + sqlite_engine, + row_id=older_canonical_row_id, + tenant_id=dirty_fixture.primary.tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", 
+ model_type=ModelType.LLM.value, + name="__inherit__", + encrypted_config='{"api_key":"older-inherit"}', + credential_id=dirty_fixture.primary.winner_credential_id, + enabled=True, + created_at=created_at, + updated_at=older_updated_at, + ) + _insert_load_balancing_model_config( + sqlite_engine, + row_id=newer_legacy_row_id, + tenant_id=dirty_fixture.primary.tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="text-generation", + name="__inherit__", + encrypted_config='{"api_key":"newer-inherit"}', + credential_id=dirty_fixture.primary.distinct_credential_id, + enabled=True, + created_at=created_at, + updated_at=newer_updated_at, + ) + + deleted_cache_keys: list[str] = [] + + def _record_delete(self) -> None: + deleted_cache_keys.append(self.cache_key) + + monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete) + + tenant_id = dirty_fixture.primary.tenant_id + table_name = "load_balancing_model_configs" + expected_row_ids = {older_canonical_row_id, newer_legacy_row_id} + + dry_run_output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + output=dry_run_output, + tables=(table_name,), + model_types=(ModelType.LLM,), + tenant_ids=(tenant_id,), + ).migrate() + + dry_run_lines = _parse_json_lines(dry_run_output) + dry_run_signatures = { + signature + for signature in _collect_processing_signatures(dry_run_lines) + if signature[1] == table_name and signature[2] in expected_row_ids + } + dry_run_row_updates = [ + cast(dict[str, object], line["attrs"]) + for line in dry_run_lines + if line.get("event") == "row_updated" + and isinstance(line.get("attrs"), dict) + and cast(dict[str, object], line["attrs"]).get("table_name") == table_name + and str(cast(dict[str, object], line["attrs"]).get("id")) in expected_row_ids + ] + assert len(dry_run_row_updates) == 1 + assert str(dry_run_row_updates[0]["id"]) == newer_legacy_row_id + assert 
dry_run_row_updates[0]["old_values"] == {"model_type": "text-generation"} + assert dry_run_row_updates[0]["new_values"] == {"model_type": ModelType.LLM.value} + assert all("rewrite_source" not in attrs for attrs in dry_run_row_updates) + + dry_run_row_deletes = [ + cast(dict[str, object], line["attrs"]) + for line in dry_run_lines + if line.get("event") == "row_deleted" + and isinstance(line.get("attrs"), dict) + and cast(dict[str, object], line["attrs"]).get("table_name") == table_name + and str(cast(dict[str, object], line["attrs"]).get("id")) in expected_row_ids + ] + assert len(dry_run_row_deletes) == 1 + assert dry_run_row_deletes[0]["business_key"] == { + "tenant_id": tenant_id, + "provider_name": "openai", + "model_name": "gpt-4o-mini", + "model_type": ModelType.LLM.value, + } + assert dry_run_row_deletes[0]["merge_winner_id"] == newer_legacy_row_id + assert dry_run_row_deletes[0]["row"] == { + "id": older_canonical_row_id, + "tenant_id": tenant_id, + "provider_name": "openai", + "model_name": "gpt-4o-mini", + "model_type": ModelType.LLM.value, + "name": "__inherit__", + "encrypted_config": {"api_key": "older-inherit"}, + "credential_id": dirty_fixture.primary.winner_credential_id, + "credential_source_type": CredentialSourceType.CUSTOM_MODEL.value, + "enabled": True, + "created_at": created_at.isoformat(), + "updated_at": older_updated_at.isoformat(), + } + + dry_run_deleted_index = next( + index + for index, line in enumerate(dry_run_lines) + if line.get("event") == "row_deleted" + and isinstance(line.get("attrs"), dict) + and cast(dict[str, object], line["attrs"]).get("id") == older_canonical_row_id + ) + dry_run_updated_index = next( + index + for index, line in enumerate(dry_run_lines) + if line.get("event") == "row_updated" + and isinstance(line.get("attrs"), dict) + and cast(dict[str, object], line["attrs"]).get("id") == newer_legacy_row_id + ) + assert dry_run_deleted_index < dry_run_updated_index + + dry_run_cache_plan_ids = _cache_event_row_ids( + 
dry_run_lines, + table_name=table_name, + row_ids=expected_row_ids, + event_name="cache_delete_planned", + ) + assert newer_legacy_row_id in dry_run_cache_plan_ids + + apply_output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=True, + output=apply_output, + tables=(table_name,), + model_types=(ModelType.LLM,), + tenant_ids=(tenant_id,), + ).migrate() + + apply_lines = _parse_json_lines(apply_output) + apply_signatures = { + signature + for signature in _collect_processing_signatures(apply_lines) + if signature[1] == table_name and signature[2] in expected_row_ids + } + apply_row_updates = [ + cast(dict[str, object], line["attrs"]) + for line in apply_lines + if line.get("event") == "row_updated" + and isinstance(line.get("attrs"), dict) + and cast(dict[str, object], line["attrs"]).get("table_name") == table_name + and str(cast(dict[str, object], line["attrs"]).get("id")) in expected_row_ids + ] + assert len(apply_row_updates) == 1 + assert str(apply_row_updates[0]["id"]) == newer_legacy_row_id + assert apply_signatures == dry_run_signatures + + apply_cache_delete_ids = _cache_event_row_ids( + apply_lines, + table_name=table_name, + row_ids=expected_row_ids, + event_name="cache_deleted", + ) + assert apply_cache_delete_ids == dry_run_cache_plan_ids + assert deleted_cache_keys + + lb_rows = fetch_table_rows(sqlite_engine, table_name, tenant_id=tenant_id) + surviving_rows = [row for row in lb_rows if str(row["id"]) in expected_row_ids] + assert len(surviving_rows) == 1 + surviving_row = surviving_rows[0] + assert surviving_row["id"] == newer_legacy_row_id + assert surviving_row["tenant_id"] == tenant_id + assert surviving_row["provider_name"] == "openai" + assert surviving_row["model_name"] == "gpt-4o-mini" + assert surviving_row["model_type"] == ModelType.LLM.value + assert surviving_row["name"] == "__inherit__" + assert surviving_row["encrypted_config"] == '{"api_key":"newer-inherit"}' + assert 
surviving_row["credential_id"] == dirty_fixture.primary.distinct_credential_id + assert surviving_row["credential_source_type"] == CredentialSourceType.CUSTOM_MODEL.value + + +def test_load_balancing_non_inherit_rows_do_not_participate_in_normalized_model_type_deduplication( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, +) -> None: + inserted_row_id = "00000000-0000-0000-0000-00000000dd03" + created_at = datetime(2025, 1, 1, 8, 0, 0) + updated_at = created_at + timedelta(minutes=15) + _insert_load_balancing_model_config( + sqlite_engine, + row_id=inserted_row_id, + tenant_id=dirty_fixture.primary.tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type=ModelType.LLM.value, + name=dirty_fixture.primary.loser_credential_name, + encrypted_config='{"api_key":"second-lb"}', + credential_id=dirty_fixture.primary.distinct_credential_id, + enabled=True, + created_at=created_at, + updated_at=updated_at, + ) + + output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=True, + output=output, + tables=("load_balancing_model_configs",), + model_types=(ModelType.LLM,), + tenant_ids=(dirty_fixture.primary.tenant_id,), + ).migrate() + + lines = _parse_json_lines(output) + row_deleted_events = [ + cast(dict[str, object], line["attrs"]) + for line in lines + if line.get("event") == "row_deleted" + and isinstance(line.get("attrs"), dict) + and cast(dict[str, object], line["attrs"]).get("table_name") == "load_balancing_model_configs" + ] + assert row_deleted_events == [] + + lb_rows = fetch_table_rows( + sqlite_engine, + "load_balancing_model_configs", + tenant_id=dirty_fixture.primary.tenant_id, + ) + matching_rows = [ + row for row in lb_rows if str(row["id"]) in {dirty_fixture.primary.load_balancing_config_id, inserted_row_id} + ] + assert len(matching_rows) == 2 + assert all(row["model_type"] == ModelType.LLM.value for row in matching_rows) + + +def 
test_migration_apply_updates_all_five_tables_and_rewrites_credential_references( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + deleted_cache_keys: list[str] = [] + + def _record_delete(self) -> None: + deleted_cache_keys.append(self.cache_key) + + monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete) + output = io.StringIO() + + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=True, + output=output, + tenant_ids=(dirty_fixture.primary.tenant_id,), + ) + + service.migrate() + + assert_tenant_rows_use_only_canonical_model_types(sqlite_engine, dirty_fixture.primary.tenant_id) + + provider_model_rows = fetch_table_rows(sqlite_engine, "provider_models", tenant_id=dirty_fixture.primary.tenant_id) + provider_model_row = next( + row for row in provider_model_rows if row["id"] == dirty_fixture.primary.provider_model_id + ) + assert provider_model_row["model_type"] == LEGACY_TO_CANONICAL["text-generation"] + assert provider_model_row["credential_id"] == dirty_fixture.primary.winner_credential_id + + lb_rows = fetch_table_rows(sqlite_engine, "load_balancing_model_configs", tenant_id=dirty_fixture.primary.tenant_id) + lb_row = next(row for row in lb_rows if row["id"] == dirty_fixture.primary.load_balancing_config_id) + assert lb_row["model_type"] == LEGACY_TO_CANONICAL["text-generation"] + assert lb_row["credential_id"] == dirty_fixture.primary.winner_credential_id + assert lb_row["encrypted_config"] == dirty_fixture.primary.winner_encrypted_config + + credential_rows = fetch_table_rows( + sqlite_engine, "provider_model_credentials", tenant_id=dirty_fixture.primary.tenant_id + ) + assert ( + count_rows( + sqlite_engine, + "provider_model_credentials", + tenant_id=dirty_fixture.primary.tenant_id, + ) + == 2 + ) + credential_ids = {str(row["id"]) for row in credential_rows} + assert credential_ids == { + 
dirty_fixture.primary.winner_credential_id, + dirty_fixture.primary.distinct_credential_id, + } + distinct_row = next(row for row in credential_rows if row["id"] == dirty_fixture.primary.distinct_credential_id) + assert distinct_row["credential_name"] == dirty_fixture.primary.distinct_credential_name + assert distinct_row["model_type"] == LEGACY_TO_CANONICAL["text-generation"] + + rendered_output = output.getvalue() + assert dirty_fixture.primary.loser_credential_id in rendered_output + assert dirty_fixture.primary.loser_encrypted_config in rendered_output + assert any("load_balancing_provider_model_credentials" in key for key in deleted_cache_keys) or any( + "load_balancing_provider_model" in key for key in deleted_cache_keys + ) + + +def test_migration_filters_by_tenant_model_types_and_tables( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, +) -> None: + before_primary_credentials = fetch_table_rows( + sqlite_engine, + "provider_model_credentials", + tenant_id=dirty_fixture.primary.tenant_id, + ) + before_secondary = { + table_name: fetch_table_rows(sqlite_engine, table_name, tenant_id=dirty_fixture.secondary.tenant_id) + for table_name in ALL_TABLE_NAMES + } + + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=True, + output=io.StringIO(), + tables=("provider_models",), + model_types=(ModelType.LLM,), + tenant_ids=(dirty_fixture.primary.tenant_id,), + ) + + service.migrate() + + assert ( + count_rows( + sqlite_engine, + "provider_model_credentials", + tenant_id=dirty_fixture.primary.tenant_id, + ) + == 3 + ) + credential_rows = fetch_table_rows( + sqlite_engine, + "provider_model_credentials", + tenant_id=dirty_fixture.primary.tenant_id, + ) + assert credential_rows == before_primary_credentials + provider_model_rows = fetch_table_rows( + sqlite_engine, + "provider_models", + tenant_id=dirty_fixture.primary.tenant_id, + ) + provider_model_row = next( + row for row in provider_model_rows if row["id"] == 
dirty_fixture.primary.provider_model_id + ) + embedding_provider_model_row = next( + row for row in provider_model_rows if row["id"] == dirty_fixture.primary.embedding_provider_model_id + ) + assert provider_model_row["model_type"] == LEGACY_TO_CANONICAL["text-generation"] + assert embedding_provider_model_row["model_type"] == "embeddings" + + tenant_default_row = fetch_table_rows( + sqlite_engine, + "tenant_default_models", + tenant_id=dirty_fixture.primary.tenant_id, + )[0] + assert tenant_default_row["model_type"] == "text-generation" + + provider_model_setting_rows = fetch_table_rows( + sqlite_engine, + "provider_model_settings", + tenant_id=dirty_fixture.primary.tenant_id, + ) + llm_setting_row = next( + row for row in provider_model_setting_rows if row["id"] == dirty_fixture.primary.provider_model_setting_id + ) + embedding_setting_row = next( + row for row in provider_model_setting_rows if row["id"] == dirty_fixture.primary.embedding_setting_id + ) + assert llm_setting_row["model_type"] == "text-generation" + assert embedding_setting_row["model_type"] == "embeddings" + + lb_row = fetch_table_rows( + sqlite_engine, + "load_balancing_model_configs", + tenant_id=dirty_fixture.primary.tenant_id, + )[0] + assert lb_row["model_type"] == "text-generation" + + after_secondary = { + table_name: fetch_table_rows(sqlite_engine, table_name, tenant_id=dirty_fixture.secondary.tenant_id) + for table_name in ALL_TABLE_NAMES + } + assert after_secondary == before_secondary + + +def test_migration_does_not_merge_credentials_with_different_credential_name( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, +) -> None: + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=True, + output=io.StringIO(), + tenant_ids=(dirty_fixture.primary.tenant_id,), + ) + + service.migrate() + + credential_rows = fetch_table_rows( + sqlite_engine, + "provider_model_credentials", + tenant_id=dirty_fixture.primary.tenant_id, + ) + distinct_row 
= next(row for row in credential_rows if row["id"] == dirty_fixture.primary.distinct_credential_id) + assert distinct_row["credential_name"] == dirty_fixture.primary.distinct_credential_name + assert distinct_row["model_type"] == LEGACY_TO_CANONICAL["text-generation"] + assert ( + count_rows( + sqlite_engine, + "provider_model_credentials", + tenant_id=dirty_fixture.primary.tenant_id, + ) + == 2 + ) + + +def test_migration_is_idempotent_on_second_apply( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, +) -> None: + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=True, + output=io.StringIO(), + tenant_ids=(dirty_fixture.primary.tenant_id,), + ) + + service.migrate() + after_first = snapshot_legacy_model_type_state(sqlite_engine) + + service.migrate() + after_second = snapshot_legacy_model_type_state(sqlite_engine) + + assert after_second == after_first diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_node_output_inspector.py b/api/tests/unit_tests/controllers/console/app/test_workflow_node_output_inspector.py new file mode 100644 index 0000000000..e66ae5246b --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_node_output_inspector.py @@ -0,0 +1,454 @@ +"""Unit tests for the Node Output Inspector controller (Stage 4 §8). + +The controller has two non-trivial moving parts: + +1. :func:`_sse_envelope` — wire-format builder for the SSE ``{event, data, id}`` + records (decision D-5). +2. :func:`_stream_inspector_events` — the SSE generator that fans the redis + pub/sub stream out as snapshot / node_changed / workflow_run_completed / + error events. + +We exercise both as plain functions with mocked dependencies (service + +``inspector_events.subscribe``) — going through Flask routes would multiply +the test scaffolding without buying additional confidence in the core +behaviour. 
+ +The Resource classes themselves are trivial wrappers (``_service().method()`` ++ ``_InspectorNotFound`` translation), and are touched here only by import so +codecov sees them as exercised; their detailed behaviour is locked down by +the service-level tests in +``tests/unit_tests/services/workflow/test_node_output_inspector_service.py``. +""" + +from __future__ import annotations + +import json +from collections.abc import Iterator +from typing import Any +from unittest.mock import MagicMock +from uuid import UUID + +import pytest + +from controllers.console.app import workflow_node_output_inspector as ctrl +from services.workflow.inspector_events import InspectorMessage +from services.workflow.node_output_inspector_service import ( + NodeOutputInspectorError, + NodeOutputStatus, + NodeOutputsView, + NodeStatus, + WorkflowRunSnapshotView, +) + +# ────────────────────────────────────────────────────────────────────────────── +# Fixtures +# ────────────────────────────────────────────────────────────────────────────── + + +@pytest.fixture +def app_model() -> Any: + """A minimal ``App`` stub the controller passes through to the service. + + The SSE generator never reads its attributes — just forwards it — so a + sentinel object is enough. 
+ """ + return MagicMock(name="App", tenant_id="tenant-1", id="app-1") + + +@pytest.fixture +def run_id() -> UUID: + return UUID("00000000-0000-0000-0000-0000000000aa") + + +def _snapshot_view(*, status: str, node_id: str = "agent-1") -> WorkflowRunSnapshotView: + from graphon.enums import WorkflowExecutionStatus + + return WorkflowRunSnapshotView( + workflow_run_id="00000000-0000-0000-0000-0000000000aa", + workflow_run_status=WorkflowExecutionStatus(status), + node_outputs=[ + NodeOutputsView( + node_id=node_id, + node_kind="agent", + node_display_name="Greeter", + node_status=NodeStatus.RUNNING if status == "running" else NodeStatus.READY, + outputs=[], + ) + ], + ) + + +def _node_view(*, node_id: str = "agent-1", node_status: NodeStatus = NodeStatus.READY) -> NodeOutputsView: + return NodeOutputsView( + node_id=node_id, + node_kind="agent", + node_display_name="Greeter", + node_status=node_status, + outputs=[], + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# _sse_envelope +# ────────────────────────────────────────────────────────────────────────────── + + +def test_sse_envelope_serializes_dict_payload(): + out = ctrl._sse_envelope("snapshot", {"foo": "bar"}, 7) + lines = out.rstrip("\n").split("\n") + assert lines[0] == "event: snapshot" + assert lines[1] == "id: 7" + assert lines[2] == 'data: {"foo": "bar"}' + assert out.endswith("\n\n") # SSE record separator + + +def test_sse_envelope_passes_strings_through_unmodified(): + """A raw string payload (e.g. 
``:keepalive``) is emitted as-is.""" + out = ctrl._sse_envelope("snapshot", ":keepalive", 1) + assert "data: :keepalive\n" in out + + +def test_sse_envelope_handles_unicode_payload(): + out = ctrl._sse_envelope("node_changed", {"name": "你好"}, 3) + assert "你好" in out # ensure_ascii=False + + +# ────────────────────────────────────────────────────────────────────────────── +# _stream_inspector_events — fast path (already-terminal run) +# ────────────────────────────────────────────────────────────────────────────── + + +def _drain(stream: Iterator[str]) -> list[str]: + return list(stream) + + +def _parse(record: str) -> tuple[str, dict | None]: + """Pull ``event`` + ``data`` (json-decoded) out of one SSE record.""" + event = None + data: dict | None = None + for line in record.rstrip("\n").split("\n"): + if line.startswith("event: "): + event = line[len("event: ") :] + elif line.startswith("data: "): + try: + data = json.loads(line[len("data: ") :]) + except json.JSONDecodeError: + data = None + assert event is not None + return event, data + + +@pytest.fixture +def patch_service(monkeypatch: pytest.MonkeyPatch): + """Replace ``_service()`` with a MagicMock per-test.""" + + fake = MagicMock() + monkeypatch.setattr(ctrl, "_service", lambda: fake) + return fake + + +@pytest.fixture +def patch_subscribe(monkeypatch: pytest.MonkeyPatch): + """Patch the pub/sub subscribe iterator.""" + + def _make(messages: list[InspectorMessage | None]): + def _subscribe(workflow_run_id: str, *, timeout_seconds: float = 1.0): + for m in messages: + if m is None: + # heartbeat sentinel + yield InspectorMessage( + kind="node_changed", + workflow_run_id=workflow_run_id, + node_id=None, + status=None, + ) + else: + yield m + + monkeypatch.setattr(ctrl.inspector_events, "subscribe", _subscribe) + + return _make + + +def test_stream_fast_path_when_run_already_terminal(patch_service, app_model, run_id): + """A run that's already ``succeeded`` should produce ``snapshot`` → + 
``workflow_run_completed`` and close without subscribing to pub/sub.""" + patch_service.snapshot_workflow_run.return_value = _snapshot_view(status="succeeded") + records = _drain(ctrl._stream_inspector_events(app_model, run_id)) + assert len(records) == 2 + e0, d0 = _parse(records[0]) + e1, d1 = _parse(records[1]) + assert e0 == "snapshot" + assert d0 is not None + assert d0["workflow_run_status"] == "succeeded" + assert e1 == "workflow_run_completed" + assert d1 is not None + assert d1["workflow_run_status"] == "succeeded" + + +def test_stream_fast_path_each_terminal_status(patch_service, app_model, run_id): + """All four terminal statuses take the fast-path. Note the enum value for + partial success is the hyphenated ``partial-succeeded``.""" + for terminal in ("succeeded", "failed", "stopped", "partial-succeeded"): + patch_service.snapshot_workflow_run.return_value = _snapshot_view(status=terminal) + records = _drain(ctrl._stream_inspector_events(app_model, run_id)) + events = [_parse(r)[0] for r in records] + assert events == ["snapshot", "workflow_run_completed"], terminal + + +def test_stream_initial_404_propagates_before_any_bytes(patch_service, app_model, run_id): + """``NodeOutputInspectorError`` on the initial snapshot must surface as the + controller's ``_InspectorNotFound`` exception so Flask returns HTTP 404 + — not a half-streamed SSE body.""" + patch_service.snapshot_workflow_run.side_effect = NodeOutputInspectorError( + "workflow_run_not_found", "Workflow run not found." 
+ ) + gen = ctrl._stream_inspector_events(app_model, run_id) + with pytest.raises(ctrl._InspectorNotFound) as exc: + next(gen) + assert exc.value.error_code == "workflow_run_not_found" + + +# ────────────────────────────────────────────────────────────────────────────── +# _stream_inspector_events — live path (run is running) +# ────────────────────────────────────────────────────────────────────────────── + + +def test_stream_live_emits_snapshot_then_node_changed_then_completion( + patch_service, patch_subscribe, app_model, run_id +): + """Happy path: snapshot → 2× node_changed → workflow_run_completed.""" + patch_service.snapshot_workflow_run.return_value = _snapshot_view(status="running") + patch_service.node_detail.return_value = _node_view(node_id="agent-1") + + msgs = [ + InspectorMessage(kind="node_changed", workflow_run_id=str(run_id), node_id="agent-1", status="running"), + InspectorMessage(kind="node_changed", workflow_run_id=str(run_id), node_id="agent-1", status="succeeded"), + InspectorMessage(kind="workflow_completed", workflow_run_id=str(run_id), node_id=None, status="succeeded"), + ] + patch_subscribe(msgs) + + events = [_parse(r)[0] for r in _drain(ctrl._stream_inspector_events(app_model, run_id))] + assert events == ["snapshot", "node_changed", "node_changed", "workflow_run_completed"] + # node_detail should be called once per delta (not once per heartbeat) + assert patch_service.node_detail.call_count == 2 + + +def test_stream_emits_heartbeat_after_n_idle_ticks( + patch_service, patch_subscribe, monkeypatch: pytest.MonkeyPatch, app_model, run_id +): + """When pub/sub returns the heartbeat sentinel ``_HEARTBEAT_EVERY_TICKS`` + times in a row, the generator emits a ``:keepalive`` SSE comment.""" + monkeypatch.setattr(ctrl, "_HEARTBEAT_EVERY_TICKS", 2) + patch_service.snapshot_workflow_run.return_value = _snapshot_view(status="running") + patch_service.node_detail.return_value = _node_view() + + # 2 heartbeats → keepalive, then real message + 
completion. + patch_subscribe( + [ + None, + None, + InspectorMessage(kind="workflow_completed", workflow_run_id=str(run_id), node_id=None, status="failed"), + ] + ) + records = _drain(ctrl._stream_inspector_events(app_model, run_id)) + raw = "".join(records) + assert ":keepalive\n\n" in raw + assert "workflow_run_completed" in raw + + +def test_stream_hard_timeout_force_closes_without_terminal( + patch_service, patch_subscribe, monkeypatch: pytest.MonkeyPatch, app_model, run_id +): + """If the engine crashes / drops the terminal event, the generator force- + closes after ``_STREAM_HARD_TIMEOUT_TICKS`` ticks rather than hanging.""" + monkeypatch.setattr(ctrl, "_STREAM_HARD_TIMEOUT_TICKS", 3) + monkeypatch.setattr(ctrl, "_HEARTBEAT_EVERY_TICKS", 100) # avoid keepalive noise + patch_service.snapshot_workflow_run.return_value = _snapshot_view(status="running") + + # 5 heartbeats, no terminal → generator should bail after 3 ticks. + patch_subscribe([None] * 10) + records = _drain(ctrl._stream_inspector_events(app_model, run_id)) + events = [_parse(r)[0] for r in records] + assert events == ["snapshot"] # only snapshot, then forced close + + +def test_stream_skips_messages_with_missing_node_id(patch_service, patch_subscribe, app_model, run_id): + """Defensive: malformed node_changed without node_id is silently dropped.""" + patch_service.snapshot_workflow_run.return_value = _snapshot_view(status="running") + patch_subscribe( + [ + InspectorMessage(kind="node_changed", workflow_run_id=str(run_id), node_id="", status="running"), + InspectorMessage(kind="workflow_completed", workflow_run_id=str(run_id), node_id=None, status="succeeded"), + ] + ) + events = [_parse(r)[0] for r in _drain(ctrl._stream_inspector_events(app_model, run_id))] + assert events == ["snapshot", "workflow_run_completed"] + assert patch_service.node_detail.call_count == 0 + + +def test_stream_skips_node_detail_404_without_breaking_stream(patch_service, patch_subscribe, app_model, run_id): + """When 
node_detail 404s mid-stream (node still being persisted), the + generator just drops that delta and keeps streaming.""" + patch_service.snapshot_workflow_run.return_value = _snapshot_view(status="running") + patch_service.node_detail.side_effect = NodeOutputInspectorError("node_not_in_workflow_run", "transient") + patch_subscribe( + [ + InspectorMessage(kind="node_changed", workflow_run_id=str(run_id), node_id="agent-1", status="running"), + InspectorMessage(kind="workflow_completed", workflow_run_id=str(run_id), node_id=None, status="succeeded"), + ] + ) + events = [_parse(r)[0] for r in _drain(ctrl._stream_inspector_events(app_model, run_id))] + assert events == ["snapshot", "workflow_run_completed"] + + +def test_stream_emits_error_event_on_node_detail_unexpected_exception( + patch_service, patch_subscribe, app_model, run_id +): + """Any non-Inspector exception (DB outage, JSON decode error) becomes a + user-visible ``error`` SSE record; the stream keeps running.""" + patch_service.snapshot_workflow_run.return_value = _snapshot_view(status="running") + patch_service.node_detail.side_effect = RuntimeError("db gone") + patch_subscribe( + [ + InspectorMessage(kind="node_changed", workflow_run_id=str(run_id), node_id="agent-1", status="running"), + InspectorMessage(kind="workflow_completed", workflow_run_id=str(run_id), node_id=None, status="succeeded"), + ] + ) + records = _drain(ctrl._stream_inspector_events(app_model, run_id)) + events = [_parse(r) for r in records] + kinds = [e for e, _ in events] + assert kinds == ["snapshot", "error", "workflow_run_completed"] + err_event, err_data = events[1] + assert err_data is not None + assert err_data["node_id"] == "agent-1" + assert "failed" in err_data["message"] + + +def test_stream_workflow_completed_status_falls_back_to_unknown(patch_service, patch_subscribe, app_model, run_id): + """If the pub/sub message arrives with status=None (publish race), the SSE + payload still carries ``workflow_run_status`` with the 
``unknown`` + sentinel so the frontend never sees a missing field.""" + patch_service.snapshot_workflow_run.return_value = _snapshot_view(status="running") + patch_subscribe( + [InspectorMessage(kind="workflow_completed", workflow_run_id=str(run_id), node_id=None, status=None)] + ) + records = _drain(ctrl._stream_inspector_events(app_model, run_id)) + e, d = _parse(records[-1]) + assert e == "workflow_run_completed" + assert d is not None + assert d["workflow_run_status"] == "unknown" + + +# ────────────────────────────────────────────────────────────────────────────── +# Resource classes — import-level smoke + service-method delegation +# ────────────────────────────────────────────────────────────────────────────── + + +def test_resource_classes_are_registered(): + """All 8 Inspector Resource classes must be importable from the module so + flask-restx can discover them via the namespace decorators.""" + for name in ( + "WorkflowDraftRunNodeOutputsApi", + "WorkflowDraftRunNodeOutputDetailApi", + "WorkflowDraftRunNodeOutputPreviewApi", + "WorkflowDraftRunNodeOutputEventsApi", + "WorkflowPublishedRunNodeOutputsApi", + "WorkflowPublishedRunNodeOutputDetailApi", + "WorkflowPublishedRunNodeOutputPreviewApi", + "WorkflowPublishedRunNodeOutputEventsApi", + ): + assert hasattr(ctrl, name), name + + +def test_inspector_not_found_preserves_error_code(): + """Sanity: the controller's bespoke 404 wrapper hangs onto the + Inspector's specific error code rather than collapsing to a generic + ``not_found``.""" + err = NodeOutputInspectorError("node_not_in_workflow_run", "boom") + wrapped = ctrl._InspectorNotFound(err) + assert wrapped.error_code == "node_not_in_workflow_run" + assert wrapped.code == 404 + + +# ────────────────────────────────────────────────────────────────────────────── +# _serve_* — shared REST handler bodies (covered by both draft + published) +# ────────────────────────────────────────────────────────────────────────────── + + +def 
test_serve_snapshot_happy_path(patch_service, app_model, run_id): + """Returns the snapshot view as JSON-serialisable dict.""" + patch_service.snapshot_workflow_run.return_value = _snapshot_view(status="running") + result = ctrl._serve_snapshot(app_model, run_id) + assert isinstance(result, dict) + assert result["workflow_run_id"] == "00000000-0000-0000-0000-0000000000aa" + patch_service.snapshot_workflow_run.assert_called_once_with(app_model=app_model, workflow_run_id=str(run_id)) + + +def test_serve_snapshot_translates_inspector_error_to_404(patch_service, app_model, run_id): + """``NodeOutputInspectorError`` becomes the controller's 404 wrapper with + the specific ``error_code`` preserved.""" + patch_service.snapshot_workflow_run.side_effect = NodeOutputInspectorError("workflow_run_not_found", "no such run") + with pytest.raises(ctrl._InspectorNotFound) as exc: + ctrl._serve_snapshot(app_model, run_id) + assert exc.value.error_code == "workflow_run_not_found" + + +def test_serve_node_detail_happy_path(patch_service, app_model, run_id): + patch_service.node_detail.return_value = _node_view(node_id="agent-1") + result = ctrl._serve_node_detail(app_model, run_id, "agent-1") + assert result["node_id"] == "agent-1" + patch_service.node_detail.assert_called_once_with( + app_model=app_model, workflow_run_id=str(run_id), node_id="agent-1" + ) + + +def test_serve_node_detail_translates_inspector_error(patch_service, app_model, run_id): + patch_service.node_detail.side_effect = NodeOutputInspectorError("node_not_in_workflow_run", "missing") + with pytest.raises(ctrl._InspectorNotFound) as exc: + ctrl._serve_node_detail(app_model, run_id, "ghost") + assert exc.value.error_code == "node_not_in_workflow_run" + + +def test_serve_output_preview_happy_path(patch_service, app_model, run_id): + from services.workflow.node_output_inspector_service import ( + DeclaredOutputType, + OutputPreviewView, + ) + + patch_service.output_preview.return_value = OutputPreviewView( + 
node_id="agent-1", + output_name="text", + type=DeclaredOutputType.STRING, + status=NodeOutputStatus.READY, + value="Hello", + ) + result = ctrl._serve_output_preview(app_model, run_id, "agent-1", "text") + assert result["value"] == "Hello" + assert result["status"] == "ready" + patch_service.output_preview.assert_called_once_with( + app_model=app_model, + workflow_run_id=str(run_id), + node_id="agent-1", + output_name="text", + ) + + +def test_serve_output_preview_translates_inspector_error(patch_service, app_model, run_id): + patch_service.output_preview.side_effect = NodeOutputInspectorError("node_output_not_declared", "no such output") + with pytest.raises(ctrl._InspectorNotFound) as exc: + ctrl._serve_output_preview(app_model, run_id, "agent-1", "phantom") + assert exc.value.error_code == "node_output_not_declared" + + +# ────────────────────────────────────────────────────────────────────────────── +# Note: the Resource ``.get`` methods themselves (6 REST + 2 SSE) are +# 1-line delegators to the helpers above. They can't be called directly in a +# unit test because their decorator stack (``@setup_required`` / +# ``@login_required`` / ``@account_initialization_required`` / +# ``@get_app_model``) needs a real Flask request context + DB-backed account. +# The integration test in +# ``tests/integration_tests/services/test_node_output_inspector_service.py`` +# (and the E2E driver in /tmp/e2e_inspector_sse_published.py) exercise them +# through the HTTP stack. 
+# ────────────────────────────────────────────────────────────────────────────── diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py new file mode 100644 index 0000000000..51bbc33079 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import PropertyMock, patch + +from controllers.console import console_ns +from controllers.console.auth.data_source_bearer_auth import ( + ApiKeyAuthDataSource, + ApiKeyAuthDataSourceBinding, + ApiKeyAuthDataSourceBindingDelete, +) + + +def _unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def _payload_patch(payload: dict): + return patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ) + + +def test_list_data_source_auth_uses_injected_tenant_id() -> None: + api = ApiKeyAuthDataSource() + method = _unwrap(api.get) + binding = SimpleNamespace( + id="binding-1", + category="api_key", + provider="custom", + disabled=False, + created_at=datetime(2026, 1, 1, tzinfo=UTC), + updated_at=datetime(2026, 1, 2, tzinfo=UTC), + ) + + with patch( + "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list", + return_value=[binding], + ) as get_provider_auth_list: + result = method(api, "tenant-1") + + get_provider_auth_list.assert_called_once_with("tenant-1") + assert result["sources"][0]["id"] == "binding-1" + assert result["sources"][0]["provider"] == "custom" + + +def test_create_data_source_auth_binding_uses_injected_tenant_id() -> None: + api = ApiKeyAuthDataSourceBinding() + method = _unwrap(api.post) + payload = { + "category": "api_key", + "provider": "custom", + "credentials": {"auth_type": "api_key", 
"config": {"api_key": "secret"}}, + } + + with ( + _payload_patch(payload), + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") as create_auth, + ): + result, status = method(api, "tenant-1") + + create_auth.assert_called_once_with("tenant-1", payload) + assert result == {"result": "success"} + assert status == 200 + + +def test_delete_data_source_auth_binding_uses_injected_tenant_id() -> None: + api = ApiKeyAuthDataSourceBindingDelete() + method = _unwrap(api.delete) + + with patch( + "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth" + ) as delete_provider_auth: + result, status = method(api, "tenant-1", "binding-1") + + delete_provider_auth.assert_called_once_with("tenant-1", "binding-1") + assert result == "" + assert status == 204 diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py b/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py new file mode 100644 index 0000000000..1508d7b50e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from unittest.mock import patch + +from controllers.console.auth.oauth_server import OAuthServerUserAuthorizeApi +from models import Account +from models.account import AccountStatus, TenantAccountRole +from models.model import OAuthProviderApp + + +def _unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def _make_account() -> Account: + account = Account( + name="Test User", + email="test@example.com", + status=AccountStatus.ACTIVE, + ) + account.id = "account-1" + account.role = TenantAccountRole.OWNER + return account + + +def _make_oauth_provider_app() -> OAuthProviderApp: + return OAuthProviderApp( + app_icon="icon", + client_id="client-1", + 
client_secret="secret", + app_label={"en-US": "Test App"}, + redirect_uris=["https://example.com/callback"], + scope="read", + ) + + +def test_oauth_authorize_uses_injected_current_user() -> None: + api = OAuthServerUserAuthorizeApi() + method = _unwrap(api.post) + account = _make_account() + oauth_provider_app = _make_oauth_provider_app() + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code", + return_value="authorization-code", + ) as sign_oauth_authorization_code: + response = method(api, oauth_provider_app, account) + + sign_oauth_authorization_code.assert_called_once_with("client-1", "account-1") + assert response == {"code": "authorization-code"} diff --git a/api/tests/unit_tests/controllers/console/datasets/test_external.py b/api/tests/unit_tests/controllers/console/datasets/test_external.py index 3ed65b1ffb..3e76e6c21a 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_external.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_external.py @@ -70,13 +70,14 @@ class TestExternalApiTemplateListApi: ExternalDatasetService, "get_external_knowledge_apis", return_value=([api_item], 1), - ), + ) as get_external_knowledge_apis, ): - resp, status = method(api, "id") + resp, status = method(api, "tenant-1") assert status == 200 assert resp["total"] == 1 assert resp["data"][0]["id"] == "1" + get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None) def test_post_forbidden(self, app: Flask, current_user): current_user.is_dataset_editor = False @@ -321,13 +322,14 @@ class TestExternalApiTemplateListApiAdvanced: patch( "controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis", return_value=(templates, 25), - ), + ) as get_external_knowledge_apis, ): - resp, status = method(api, "id") + resp, status = method(api, "tenant-1") assert status == 200 assert resp["total"] == 25 assert len(resp["data"]) == 3 + 
get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None) class TestExternalDatasetCreateApiAdvanced: diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py index 4fa5d21493..faedd4d7e1 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py @@ -1,5 +1,5 @@ import uuid -from unittest.mock import MagicMock, PropertyMock, patch +from unittest.mock import PropertyMock, patch import pytest from flask import Flask @@ -8,7 +8,7 @@ from werkzeug.exceptions import NotFound from controllers.console import console_ns from controllers.console.datasets.hit_testing import HitTestingApi -from controllers.console.datasets.hit_testing_base import HitTestingPayload +from models.dataset import Dataset def unwrap(func): @@ -32,7 +32,48 @@ def dataset_id(): @pytest.fixture def dataset(): - return MagicMock(id="dataset-1") + return Dataset(id="dataset-1", tenant_id="tenant-1", name="Dataset", created_by="account-1") + + +def hit_testing_record() -> dict[str, object]: + return { + "segment": { + "id": "segment-1", + "position": 1, + "document_id": "document-1", + "content": "Chunk text", + "sign_content": "Chunk text", + "answer": None, + "word_count": 2, + "tokens": 3, + "keywords": [], + "index_node_id": None, + "index_node_hash": None, + "hit_count": 0, + "enabled": True, + "disabled_at": None, + "disabled_by": None, + "status": "completed", + "created_by": "account-1", + "created_at": 1_700_000_000, + "indexing_at": None, + "completed_at": None, + "error": None, + "stopped_at": None, + "document": { + "id": "document-1", + "data_source_type": "upload_file", + "name": "guide.md", + "doc_type": None, + "doc_metadata": None, + }, + }, + "child_chunks": [], + "score": None, + "tsne_position": None, + "files": [], + "summary": None, + } @pytest.fixture(autouse=True) @@ 
-63,7 +104,6 @@ class TestHitTestingApi: payload = { "query": "what is vector search", - "top_k": 3, } with ( @@ -74,11 +114,6 @@ class TestHitTestingApi: new_callable=PropertyMock, return_value=payload, ), - patch.object( - HitTestingPayload, - "model_validate", - return_value=MagicMock(model_dump=lambda **_: payload), - ), patch.object( HitTestingApi, "get_and_validate_dataset", @@ -91,7 +126,7 @@ class TestHitTestingApi: patch.object( HitTestingApi, "perform_hit_testing", - return_value={"query": "what is vector search", "records": []}, + return_value={"query": {"content": "what is vector search"}, "records": []}, ), ): result = method(api, dataset_id) @@ -107,16 +142,7 @@ class TestHitTestingApi: payload = { "query": "what is vector search", } - records = [ - { - "segment": None, - "child_chunks": [], - "score": None, - "tsne_position": None, - "files": [], - "summary": None, - } - ] + records = [hit_testing_record()] with ( app.test_request_context("/"), @@ -126,11 +152,6 @@ class TestHitTestingApi: new_callable=PropertyMock, return_value=payload, ), - patch.object( - HitTestingPayload, - "model_validate", - return_value=MagicMock(model_dump=lambda **_: payload), - ), patch.object( HitTestingApi, "get_and_validate_dataset", @@ -143,13 +164,16 @@ class TestHitTestingApi: patch.object( HitTestingApi, "perform_hit_testing", - return_value={"query": payload["query"], "records": records}, + return_value={"query": {"content": payload["query"]}, "records": records}, ), ): result = method(api, dataset_id) - assert result["query"] == payload["query"] - assert result["records"] == records + assert result["query"] == {"content": payload["query"]} + assert result["records"][0]["segment"]["keywords"] == [] + assert result["records"][0]["child_chunks"] == [] + assert result["records"][0]["files"] == [] + assert result["records"][0]["score"] is None def test_hit_testing_dataset_not_found(self, app: Flask, dataset_id): api = HitTestingApi() @@ -192,11 +216,6 @@ class 
TestHitTestingApi: new_callable=PropertyMock, return_value=payload, ), - patch.object( - HitTestingPayload, - "model_validate", - return_value=MagicMock(model_dump=lambda **_: payload), - ), patch.object( HitTestingApi, "get_and_validate_dataset", diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py index 77e9cfeb5b..072aa559df 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -22,6 +22,7 @@ from core.errors.error import ( ) from graphon.model_runtime.errors.invoke import InvokeError from models.account import Account +from models.dataset import Dataset from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService @@ -43,7 +44,45 @@ def patch_current_user(mocker, account): @pytest.fixture def dataset(): - return MagicMock(id="dataset-1") + return Dataset(id="dataset-1", tenant_id="tenant-1", name="Dataset", created_by="account-1") + + +def hit_testing_record() -> dict[str, object]: + return { + "segment": { + "id": "segment-1", + "position": 1, + "document_id": "document-1", + "content": "Chunk text", + "answer": None, + "word_count": 2, + "tokens": 3, + "keywords": None, + "index_node_id": None, + "index_node_hash": None, + "hit_count": 0, + "enabled": True, + "disabled_at": None, + "disabled_by": None, + "status": "completed", + "created_by": "account-1", + "created_at": 1_700_000_000, + "indexing_at": None, + "completed_at": None, + "error": None, + "stopped_at": None, + "document": { + "id": "document-1", + "data_source_type": "upload_file", + "name": "guide.md", + "doc_type": None, + "doc_metadata": None, + }, + }, + "child_chunks": None, + "files": None, + "score": 0.8, + } class TestGetAndValidateDataset: @@ -116,6 +155,13 @@ class TestParseArgs: with pytest.raises(ValueError): 
DatasetsHitTestingBase.parse_args(payload) + def test_parse_args_ignores_unknown_fields_for_compatibility(self): + payload = {"query": "hello", "top_k": 3} + + result = DatasetsHitTestingBase.parse_args(payload) + + assert result == {"query": "hello"} + class TestPerformHitTesting: def test_success(self, dataset): @@ -131,48 +177,42 @@ class TestPerformHitTesting: ): result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) - assert result["query"] == "hello" + assert result["query"] == {"content": "hello"} assert result["records"] == [] def test_success_prepares_nullable_list_fields(self, dataset): response = { "query": {"content": "hello"}, - "records": [ - { - "segment": {"id": "segment-1", "keywords": None}, - "child_chunks": None, - "files": None, - "score": 0.8, - } - ], + "records": [hit_testing_record()], } + with patch.object( + HitTestingService, + "retrieve", + return_value=response, + ): + result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + assert result["query"] == {"content": "hello"} + record = result["records"][0] + assert record["segment"]["keywords"] == [] + assert record["segment"]["sign_content"] is None + assert record["child_chunks"] == [] + assert record["files"] == [] + assert record["score"] == 0.8 + assert record["tsne_position"] is None + assert record["summary"] is None + + def test_invalid_query_response_raises_value_error(self, dataset): with ( patch.object( HitTestingService, "retrieve", - return_value=response, - ), - patch( - "controllers.console.datasets.hit_testing_base.marshal", - return_value=response["records"], + return_value={"query": "hello", "records": []}, ), + pytest.raises(ValueError, match="Invalid hit testing query response"), ): - result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) - - assert result["query"] == "hello" - assert result["records"] == [ - { - "segment": {"id": "segment-1", "keywords": []}, - "child_chunks": [], - 
"files": [], - "score": 0.8, - } - ] - - def test_invalid_query_response_raises_value_error(self): - with pytest.raises(ValueError, match="Invalid hit testing query response"): - DatasetsHitTestingBase._extract_hit_testing_query("hello") + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) def test_invalid_records_response_raises_value_error(self): with pytest.raises(ValueError, match="Invalid hit testing records response"): diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py index 3d41489435..cb63a52075 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -73,14 +73,13 @@ class TestMessageListApi: "/", query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "pagination_by_first_id", return_value=pagination, ), ): - result = method(installed_app) + result = method(MagicMock(), installed_app) assert result["limit"] == 20 assert result["has_more"] is False @@ -93,9 +92,8 @@ class TestMessageListApi: installed_app = MagicMock() installed_app.app = MagicMock(mode="completion") - with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): - with pytest.raises(NotChatAppError): - method(installed_app) + with pytest.raises(NotChatAppError): + method(MagicMock(), installed_app) def test_conversation_not_exists(self, app: Flask): api = module.MessageListApi() @@ -109,7 +107,6 @@ class TestMessageListApi: "/", query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "pagination_by_first_id", @@ -117,7 +114,7 @@ class TestMessageListApi: ), ): with 
pytest.raises(NotFound): - method(installed_app) + method(MagicMock(), installed_app) def test_first_message_not_exists(self, app: Flask): api = module.MessageListApi() @@ -131,7 +128,6 @@ class TestMessageListApi: "/", query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "pagination_by_first_id", @@ -139,7 +135,7 @@ class TestMessageListApi: ), ): with pytest.raises(NotFound): - method(installed_app) + method(MagicMock(), installed_app) class TestMessageFeedbackApi: @@ -152,13 +148,12 @@ class TestMessageFeedbackApi: with ( app.test_request_context("/", json={"rating": "like"}), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "create_feedback", ), ): - result = method(installed_app, "mid") + result = method(MagicMock(), installed_app, "mid") assert result["result"] == "success" @@ -171,7 +166,6 @@ class TestMessageFeedbackApi: with ( app.test_request_context("/", json={}), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "create_feedback", @@ -179,7 +173,7 @@ class TestMessageFeedbackApi: ), ): with pytest.raises(NotFound): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") class TestMessageMoreLikeThisApi: @@ -195,7 +189,6 @@ class TestMessageMoreLikeThisApi: "/", query_string={"response_mode": "blocking"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.AppGenerateService, "generate_more_like_this", @@ -207,7 +200,7 @@ class TestMessageMoreLikeThisApi: return_value=("ok", 200), ), ): - resp = method(installed_app, "mid") + resp = method(MagicMock(), installed_app, "mid") assert resp == ("ok", 200) @@ -218,9 +211,8 @@ class TestMessageMoreLikeThisApi: installed_app = 
MagicMock() installed_app.app = MagicMock(mode="chat") - with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): - with pytest.raises(NotCompletionAppError): - method(installed_app, "mid") + with pytest.raises(NotCompletionAppError): + method(MagicMock(), installed_app, "mid") def test_more_like_this_disabled(self, app: Flask): api = module.MessageMoreLikeThisApi() @@ -234,7 +226,6 @@ class TestMessageMoreLikeThisApi: "/", query_string={"response_mode": "blocking"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.AppGenerateService, "generate_more_like_this", @@ -242,7 +233,7 @@ class TestMessageMoreLikeThisApi: ), ): with pytest.raises(AppMoreLikeThisDisabledError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_message_not_exists_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() @@ -256,7 +247,6 @@ class TestMessageMoreLikeThisApi: "/", query_string={"response_mode": "blocking"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.AppGenerateService, "generate_more_like_this", @@ -264,7 +254,7 @@ class TestMessageMoreLikeThisApi: ), ): with pytest.raises(NotFound): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_provider_not_init_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() @@ -278,7 +268,6 @@ class TestMessageMoreLikeThisApi: "/", query_string={"response_mode": "blocking"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.AppGenerateService, "generate_more_like_this", @@ -286,7 +275,7 @@ class TestMessageMoreLikeThisApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_quota_exceeded_more_like_this(self, app: Flask): api = 
module.MessageMoreLikeThisApi() @@ -300,7 +289,6 @@ class TestMessageMoreLikeThisApi: "/", query_string={"response_mode": "blocking"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.AppGenerateService, "generate_more_like_this", @@ -308,7 +296,7 @@ class TestMessageMoreLikeThisApi: ), ): with pytest.raises(ProviderQuotaExceededError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_model_not_support_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() @@ -322,7 +310,6 @@ class TestMessageMoreLikeThisApi: "/", query_string={"response_mode": "blocking"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.AppGenerateService, "generate_more_like_this", @@ -330,7 +317,7 @@ class TestMessageMoreLikeThisApi: ), ): with pytest.raises(ProviderModelCurrentlyNotSupportError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_invoke_error_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() @@ -344,7 +331,6 @@ class TestMessageMoreLikeThisApi: "/", query_string={"response_mode": "blocking"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.AppGenerateService, "generate_more_like_this", @@ -352,7 +338,7 @@ class TestMessageMoreLikeThisApi: ), ): with pytest.raises(CompletionRequestError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_unexpected_error_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() @@ -366,7 +352,6 @@ class TestMessageMoreLikeThisApi: "/", query_string={"response_mode": "blocking"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.AppGenerateService, "generate_more_like_this", @@ -374,7 +359,7 @@ class TestMessageMoreLikeThisApi: ), ): with 
pytest.raises(InternalServerError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") class TestMessageSuggestedQuestionApi: @@ -386,14 +371,13 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", return_value=["q1", "q2"], ), ): - result = method(installed_app, "mid") + result = method(MagicMock(), installed_app, "mid") assert result["data"] == ["q1", "q2"] @@ -404,9 +388,8 @@ class TestMessageSuggestedQuestionApi: installed_app = MagicMock() installed_app.app = MagicMock(mode="completion") - with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): - with pytest.raises(NotChatAppError): - method(installed_app, "mid") + with pytest.raises(NotChatAppError): + method(MagicMock(), installed_app, "mid") def test_disabled(self): api = module.MessageSuggestedQuestionApi() @@ -416,7 +399,6 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", @@ -424,7 +406,7 @@ class TestMessageSuggestedQuestionApi: ), ): with pytest.raises(AppSuggestedQuestionsAfterAnswerDisabledError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_message_not_exists_suggested_question(self): api = module.MessageSuggestedQuestionApi() @@ -434,7 +416,6 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", @@ -442,7 +423,7 @@ class TestMessageSuggestedQuestionApi: ), ): with pytest.raises(NotFound): - 
method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_conversation_not_exists_suggested_question(self): api = module.MessageSuggestedQuestionApi() @@ -452,7 +433,6 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", @@ -460,7 +440,7 @@ class TestMessageSuggestedQuestionApi: ), ): with pytest.raises(NotFound): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_provider_not_init_suggested_question(self): api = module.MessageSuggestedQuestionApi() @@ -470,7 +450,6 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", @@ -478,7 +457,7 @@ class TestMessageSuggestedQuestionApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_quota_exceeded_suggested_question(self): api = module.MessageSuggestedQuestionApi() @@ -488,7 +467,6 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", @@ -496,7 +474,7 @@ class TestMessageSuggestedQuestionApi: ), ): with pytest.raises(ProviderQuotaExceededError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_model_not_support_suggested_question(self): api = module.MessageSuggestedQuestionApi() @@ -506,7 +484,6 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", 
return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", @@ -514,7 +491,7 @@ class TestMessageSuggestedQuestionApi: ), ): with pytest.raises(ProviderModelCurrentlyNotSupportError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_invoke_error_suggested_question(self): api = module.MessageSuggestedQuestionApi() @@ -524,7 +501,6 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", @@ -532,7 +508,7 @@ class TestMessageSuggestedQuestionApi: ), ): with pytest.raises(CompletionRequestError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_unexpected_error_suggested_question(self): api = module.MessageSuggestedQuestionApi() @@ -542,7 +518,6 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", @@ -550,4 +525,4 @@ class TestMessageSuggestedQuestionApi: ), ): with pytest.raises(InternalServerError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index 47866b92b2..69cff80f3a 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -13,6 +13,8 @@ from controllers.console.tag.tags import ( TagListApi, TagUpdateDeleteApi, ) +from models import Account +from models.account import AccountStatus, TenantAccountRole from models.enums import TagType from services.tag_service import UpdateTagPayload @@ -35,20 +37,26 @@ def 
app(): @pytest.fixture def admin_user(): - return MagicMock( - id="user-1", - has_edit_permission=True, - is_dataset_editor=True, + account = Account( + name="Admin User", + email="admin@example.com", + status=AccountStatus.ACTIVE, ) + account.id = "user-1" + account.role = TenantAccountRole.OWNER + return account @pytest.fixture def readonly_user(): - return MagicMock( - id="user-2", - has_edit_permission=False, - is_dataset_editor=False, + account = Account( + name="Readonly User", + email="readonly@example.com", + status=AccountStatus.ACTIVE, ) + account.id = "user-2" + account.role = TenantAccountRole.NORMAL + return account @pytest.fixture @@ -80,10 +88,6 @@ class TestTagListApi: with app.test_request_context("/?type=knowledge"): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.tag.tags.TagService.get_tags", return_value=[ @@ -96,7 +100,7 @@ class TestTagListApi: ], ), ): - result, status = method(api) + result, status = method(api, "tenant-1") assert status == 200 assert result == [{"id": "1", "name": "tag", "type": "knowledge", "binding_count": "1"}] @@ -137,17 +141,13 @@ class TestTagListApi: with app.test_request_context("/", json=payload): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(admin_user, None), - ), payload_patch(payload), patch( "controllers.console.tag.tags.TagService.save_tags", return_value=tag, ), ): - result, status = method(api) + result, status = method(api, admin_user) assert status == 200 assert result["name"] == "test-tag" @@ -161,14 +161,10 @@ class TestTagListApi: with app.test_request_context("/", json=payload): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(readonly_user, None), - ), payload_patch(payload), ): with pytest.raises(Forbidden): - method(api) + method(api, readonly_user) class TestTagUpdateDeleteApi: @@ -180,10 +176,6 @@ 
class TestTagUpdateDeleteApi: with app.test_request_context("/", json=payload): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(admin_user, None), - ), payload_patch(payload), patch( "controllers.console.tag.tags.TagService.update_tags", @@ -194,7 +186,7 @@ class TestTagUpdateDeleteApi: return_value=3, ), ): - result, status = method(api, "tag-1") + result, status = method(api, admin_user, "tag-1") assert status == 200 update_payload, tag_id = update_tags_mock.call_args.args @@ -210,14 +202,10 @@ class TestTagUpdateDeleteApi: with app.test_request_context("/", json=payload): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(readonly_user, None), - ), payload_patch(payload), ): with pytest.raises(Forbidden): - method(api, "tag-1") + method(api, readonly_user, "tag-1") def test_delete_success(self, app: Flask, admin_user): api = TagUpdateDeleteApi() @@ -225,10 +213,6 @@ class TestTagUpdateDeleteApi: with ( app.test_request_context("/"), - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(admin_user, "tenant-1"), - ), patch("controllers.console.tag.tags.TagService.delete_tag") as delete_mock, ): result, status = method(api, "tag-1") @@ -250,14 +234,10 @@ class TestTagBindingCollectionApi: with app.test_request_context("/", json=payload): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(admin_user, None), - ), payload_patch(payload), patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock, ): - result, status = method(api) + result, status = method(api, admin_user) save_mock.assert_called_once() assert status == 200 @@ -297,14 +277,10 @@ class TestTagBindingCollectionApi: with app.test_request_context("/", json={}): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(readonly_user, None), - ), payload_patch({}), ): with 
pytest.raises(Forbidden): - method(api) + method(api, readonly_user) class TestTagBindingRemoveApi: @@ -320,14 +296,10 @@ class TestTagBindingRemoveApi: with app.test_request_context("/", json=payload): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(admin_user, None), - ), payload_patch(payload), patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock, ): - result, status = method(api) + result, status = method(api, admin_user) delete_mock.assert_called_once() delete_payload = delete_mock.call_args.args[0] @@ -341,14 +313,10 @@ class TestTagBindingRemoveApi: with app.test_request_context("/", json={}): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(readonly_user, None), - ), payload_patch({}), ): with pytest.raises(Forbidden): - method(api) + method(api, readonly_user) class TestTagResponseModel: diff --git a/api/tests/unit_tests/controllers/console/test_apikey.py b/api/tests/unit_tests/controllers/console/test_apikey.py new file mode 100644 index 0000000000..1517ff5ed8 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_apikey.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import inspect +from collections.abc import Callable +from types import SimpleNamespace +from typing import cast +from unittest.mock import patch + +import pytest +from werkzeug.exceptions import Forbidden + +from controllers.console.apikey import BaseApiKeyListResource, BaseApiKeyResource +from models import Account +from models.account import AccountStatus, TenantAccountRole +from models.enums import ApiTokenType +from models.model import ApiToken, App + + +def _make_list_resource() -> BaseApiKeyListResource: + resource = BaseApiKeyListResource() + resource.resource_type = ApiTokenType.APP + resource.resource_model = App + resource.resource_id_field = "app_id" + resource.token_prefix = "app-" + return resource + + +def _make_key_resource() -> 
BaseApiKeyResource: + resource = BaseApiKeyResource() + resource.resource_type = ApiTokenType.APP + resource.resource_model = App + resource.resource_id_field = "app_id" + return resource + + +def _make_account(role: TenantAccountRole) -> Account: + account = Account( + name="Test User", + email=f"{role.value}@example.com", + status=AccountStatus.ACTIVE, + ) + account.id = f"{role.value}-user" + account.role = role + return account + + +def test_list_api_keys_uses_injected_tenant_id() -> None: + resource = _make_list_resource() + api_key = SimpleNamespace( + id="key-1", + type=ApiTokenType.APP, + token="app-token", + last_used_at=None, + created_at=None, + ) + + with ( + patch("controllers.console.apikey._get_resource") as get_resource, + patch("controllers.console.apikey.db") as db_mock, + ): + db_mock.session.scalars.return_value.all.return_value = [api_key] + + result = resource.get("app-1", "tenant-1") + + get_resource.assert_called_once_with("app-1", "tenant-1", App) + assert result == { + "data": [ + { + "id": "key-1", + "type": "app", + "token": "app-token", + "last_used_at": None, + "created_at": None, + } + ] + } + + +def test_create_api_key_uses_injected_tenant_id() -> None: + resource = _make_list_resource() + raw_post = cast( + Callable[[BaseApiKeyListResource, str, str], tuple[dict[str, object], int]], + inspect.unwrap(BaseApiKeyListResource.post), + ) + + def add_api_token(api_token: ApiToken) -> None: + api_token.id = "key-1" + + with ( + patch("controllers.console.apikey._get_resource") as get_resource, + patch("controllers.console.apikey.db") as db_mock, + patch("controllers.console.apikey.ApiToken.generate_api_key", return_value="app-generated-token"), + ): + db_mock.session.scalar.return_value = 0 + db_mock.session.add.side_effect = add_api_token + + result, status = raw_post(resource, "app-1", "tenant-1") + + get_resource.assert_called_once_with("app-1", "tenant-1", App) + assert status == 201 + assert result["token"] == "app-generated-token" + 
api_token = db_mock.session.add.call_args.args[0] + assert api_token.app_id == "app-1" + assert api_token.tenant_id == "tenant-1" + assert api_token.type == ApiTokenType.APP + db_mock.session.commit.assert_called_once() + + +def test_delete_api_key_rejects_non_admin_account() -> None: + resource = _make_key_resource() + + with ( + patch("controllers.console.apikey._get_resource") as get_resource, + patch("controllers.console.apikey.db") as db_mock, + ): + with pytest.raises(Forbidden): + resource.delete("app-1", "key-1", "tenant-1", _make_account(TenantAccountRole.NORMAL)) + + get_resource.assert_called_once_with("app-1", "tenant-1", App) + db_mock.session.scalar.assert_not_called() + + +def test_delete_api_key_uses_injected_user_and_tenant() -> None: + resource = _make_key_resource() + api_key = SimpleNamespace(token="app-token", type=ApiTokenType.APP) + + with ( + patch("controllers.console.apikey._get_resource") as get_resource, + patch("controllers.console.apikey.db") as db_mock, + patch("controllers.console.apikey.ApiTokenCache.delete") as delete_cache, + ): + db_mock.session.scalar.return_value = api_key + + result, status = resource.delete("app-1", "key-1", "tenant-1", _make_account(TenantAccountRole.OWNER)) + + get_resource.assert_called_once_with("app-1", "tenant-1", App) + delete_cache.assert_called_once_with("app-token", ApiTokenType.APP) + db_mock.session.execute.assert_called_once() + db_mock.session.commit.assert_called_once() + assert result == "" + assert status == 204 diff --git a/api/tests/unit_tests/controllers/console/test_extension.py b/api/tests/unit_tests/controllers/console/test_extension.py index 20fc62073b..487cf8f54f 100644 --- a/api/tests/unit_tests/controllers/console/test_extension.py +++ b/api/tests/unit_tests/controllers/console/test_extension.py @@ -54,7 +54,6 @@ def _masked_api_key(api_key: str) -> str: def _mock_console_guards(monkeypatch: pytest.MonkeyPatch) -> MagicMock: """Bypass console decorators so handlers can run in 
isolation.""" - import controllers.console.extension as extension_module from controllers.console import wraps as wraps_module account = MagicMock() @@ -66,7 +65,6 @@ def _mock_console_guards(monkeypatch: pytest.MonkeyPatch) -> MagicMock: monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD") monkeypatch.setattr("libs.login.dify_config.LOGIN_DISABLED", True) monkeypatch.delenv("INIT_PASSWORD", raising=False) - monkeypatch.setattr(extension_module, "current_account_with_tenant", lambda: (account, "tenant-123")) monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (account, "tenant-123")) # The login_required decorator consults the shared LocalProxy in libs.login. diff --git a/api/tests/unit_tests/controllers/console/test_feature.py b/api/tests/unit_tests/controllers/console/test_feature.py index 0339c50777..58ba3f08cb 100644 --- a/api/tests/unit_tests/controllers/console/test_feature.py +++ b/api/tests/unit_tests/controllers/console/test_feature.py @@ -15,11 +15,6 @@ class TestFeatureApi: def test_get_tenant_features_success(self, mocker: MockerFixture): from controllers.console.feature import FeatureApi - mocker.patch( - "controllers.console.feature.current_account_with_tenant", - return_value=("account_id", "tenant_123"), - ) - get_features = mocker.patch("controllers.console.feature.FeatureService.get_features") get_features.return_value.model_dump.return_value = { "features": {"feature_a": True}, @@ -29,7 +24,7 @@ class TestFeatureApi: api = FeatureApi() raw_get = unwrap(FeatureApi.get) - result = raw_get(api) + result = raw_get(api, "tenant_123") assert result == {"features": {"feature_a": True}} get_features.assert_called_once_with("tenant_123", exclude_vector_space=True) @@ -39,18 +34,13 @@ class TestFeatureVectorSpaceApi: def test_get_vector_space_success(self, mocker: MockerFixture): from controllers.console.feature import FeatureVectorSpaceApi - mocker.patch( - "controllers.console.feature.current_account_with_tenant", - 
return_value=("account_id", "tenant_123"), - ) - get_vector_space = mocker.patch("controllers.console.feature.FeatureService.get_vector_space") get_vector_space.return_value.model_dump.return_value = {"size": 5120, "limit": 20480} api = FeatureVectorSpaceApi() raw_get = unwrap(FeatureVectorSpaceApi.get) - result = raw_get(api) + result = raw_get(api, "tenant_123") assert result == {"size": 5120, "limit": 20480} get_vector_space.assert_called_once_with("tenant_123") diff --git a/api/tests/unit_tests/controllers/console/test_files.py b/api/tests/unit_tests/controllers/console/test_files.py index 9274f6cf61..f6ef1cb824 100644 --- a/api/tests/unit_tests/controllers/console/test_files.py +++ b/api/tests/unit_tests/controllers/console/test_files.py @@ -19,6 +19,8 @@ from controllers.console.files import ( FilePreviewApi, FileSupportTypeApi, ) +from models import Account +from models.account import AccountStatus, TenantAccountRole def unwrap(func): @@ -53,18 +55,15 @@ def mock_decorators(): @pytest.fixture def mock_current_user(): - user = MagicMock() - user.is_dataset_editor = True + user = Account(name="Test User", email="user-1@example.com", status=AccountStatus.ACTIVE) + user.id = "user-1" + user.role = TenantAccountRole.OWNER return user @pytest.fixture def mock_account_context(mock_current_user): - with patch( - "controllers.console.files.current_account_with_tenant", - return_value=(mock_current_user, None), - ): - yield + return mock_current_user @pytest.fixture @@ -101,7 +100,7 @@ class TestFileApiPost: with app.test_request_context(method="POST", data={}): with pytest.raises(NoFileUploadedError): - post_method(api) + post_method(api, mock_account_context) def test_too_many_files(self, app: Flask, mock_account_context): api = FileApi() @@ -118,7 +117,7 @@ class TestFileApiPost: mock_request.form.get.return_value = None with pytest.raises(TooManyFilesError): - post_method(api) + post_method(api, mock_account_context) def test_filename_missing(self, app: Flask, 
mock_account_context): api = FileApi() @@ -130,26 +129,22 @@ class TestFileApiPost: with app.test_request_context(method="POST", data=data): with pytest.raises(FilenameNotExistsError): - post_method(api) + post_method(api, mock_account_context) def test_dataset_upload_without_permission(self, app: Flask, mock_current_user): - mock_current_user.is_dataset_editor = False + mock_current_user.role = TenantAccountRole.NORMAL - with patch( - "controllers.console.files.current_account_with_tenant", - return_value=(mock_current_user, None), - ): - api = FileApi() - post_method = unwrap(api.post) + api = FileApi() + post_method = unwrap(api.post) - data = { - "file": (io.BytesIO(b"abc"), "test.txt"), - "source": "datasets", - } + data = { + "file": (io.BytesIO(b"abc"), "test.txt"), + "source": "datasets", + } - with app.test_request_context(method="POST", data=data): - with pytest.raises(Forbidden): - post_method(api) + with app.test_request_context(method="POST", data=data): + with pytest.raises(Forbidden): + post_method(api, mock_current_user) def test_successful_upload(self, app: Flask, mock_account_context, mock_file_service): api = FileApi() @@ -179,7 +174,7 @@ class TestFileApiPost: } with app.test_request_context(method="POST", data=data): - response, status = post_method(api) + response, status = post_method(api, mock_account_context) assert status == 201 assert response["id"] == "file-id-123" @@ -216,7 +211,7 @@ class TestFileApiPost: } with app.test_request_context(method="POST", data=data): - response, status = post_method(api) + response, status = post_method(api, mock_account_context) assert status == 201 assert response["id"] == "file-id-456" @@ -240,7 +235,7 @@ class TestFileApiPost: with app.test_request_context(method="POST", data=data): with pytest.raises(FileTooLargeError): - post_method(api) + post_method(api, mock_account_context) def test_unsupported_file_type(self, app: Flask, mock_account_context, mock_file_service): api = FileApi() @@ -257,7 +252,7 
@@ class TestFileApiPost: with app.test_request_context(method="POST", data=data): with pytest.raises(UnsupportedFileTypeError): - post_method(api) + post_method(api, mock_account_context) def test_blocked_extension(self, app: Flask, mock_account_context, mock_file_service): api = FileApi() @@ -274,7 +269,7 @@ class TestFileApiPost: with app.test_request_context(method="POST", data=data): with pytest.raises(BlockedFileExtensionError): - post_method(api) + post_method(api, mock_account_context) class TestFilePreviewApi: @@ -284,7 +279,7 @@ class TestFilePreviewApi: mock_file_service.get_file_preview.return_value = "preview text" with app.test_request_context(): - result = get_method(api, "1234") + result = get_method(api, "tenant-123", "1234") assert result == {"content": "preview text"} diff --git a/api/tests/unit_tests/controllers/console/test_remote_files.py b/api/tests/unit_tests/controllers/console/test_remote_files.py index 8e86709b66..7c7abdcf2d 100644 --- a/api/tests/unit_tests/controllers/console/test_remote_files.py +++ b/api/tests/unit_tests/controllers/console/test_remote_files.py @@ -10,6 +10,8 @@ import pytest from controllers.common.errors import FileTooLargeError, RemoteFileUploadError, UnsupportedFileTypeError from controllers.console import remote_files as remote_files_module +from models import Account +from models.account import AccountStatus, TenantAccountRole from services.errors.file import FileTooLargeError as ServiceFileTooLargeError from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError @@ -20,6 +22,17 @@ def _unwrap(func): return func +def _make_account(account_id: str = "u1") -> Account: + account = Account( + name="Test User", + email=f"{account_id}@example.com", + status=AccountStatus.ACTIVE, + ) + account.id = account_id + account.role = TenantAccountRole.OWNER + return account + + class _FakeResponse: def __init__( self, @@ -63,7 +76,7 @@ def _mock_upload_dependencies( file_service_cls = 
MagicMock() file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit monkeypatch.setattr(remote_files_module, "FileService", file_service_cls) - monkeypatch.setattr(remote_files_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), None)) + current_user = _make_account() monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object())) monkeypatch.setattr( remote_files_module.file_helpers, @@ -71,7 +84,7 @@ def _mock_upload_dependencies( lambda upload_file_id: f"https://signed.example/{upload_file_id}", ) - return file_service_cls + return file_service_cls, current_user def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -147,7 +160,7 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc get_mock = MagicMock(return_value=get_resp) monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock) - file_service_cls = _mock_upload_dependencies(monkeypatch) + file_service_cls, current_user = _mock_upload_dependencies(monkeypatch) upload_file = SimpleNamespace( id="file-1", name="report.txt", @@ -160,7 +173,7 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc file_service_cls.return_value.upload_file.return_value = upload_file with app.test_request_context(method="POST", json={"url": url}): - payload, status = handler(api) + payload, status = handler(api, current_user) assert status == 201 assert payload["id"] == "file-1" @@ -170,7 +183,7 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc filename="report.txt", content=b"fallback-content", mimetype="text/plain", - user=SimpleNamespace(id="u1"), + user=current_user, source_url=url, ) @@ -191,7 +204,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds( get_mock = MagicMock(return_value=extra_get_resp) monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock) - 
file_service_cls = _mock_upload_dependencies(monkeypatch) + file_service_cls, current_user = _mock_upload_dependencies(monkeypatch) upload_file = SimpleNamespace( id="file-2", name="photo.jpg", @@ -204,7 +217,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds( file_service_cls.return_value.upload_file.return_value = upload_file with app.test_request_context(method="POST", json={"url": url}): - payload, status = handler(api) + payload, status = handler(api, current_user) assert status == 201 assert payload["id"] == "file-2" @@ -226,7 +239,7 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat with app.test_request_context(method="POST", json={"url": url}): with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"): - handler(api) + handler(api, _make_account()) def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -243,7 +256,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte with app.test_request_context(method="POST", json={"url": url}): with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"): - handler(api) + handler(api, _make_account()) def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -258,11 +271,11 @@ def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.Monk ) monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock()) - _mock_upload_dependencies(monkeypatch, file_size_within_limit=False) + _, current_user = _mock_upload_dependencies(monkeypatch, file_size_within_limit=False) with app.test_request_context(method="POST", json={"url": url}): with pytest.raises(FileTooLargeError): - handler(api) + handler(api, current_user) def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -276,12 
+289,12 @@ def test_remote_file_upload_translates_service_file_too_large_error(app, monkeyp MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")), ) monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock()) - file_service_cls = _mock_upload_dependencies(monkeypatch) + file_service_cls, current_user = _mock_upload_dependencies(monkeypatch) file_service_cls.return_value.upload_file.side_effect = ServiceFileTooLargeError("size exceeded") with app.test_request_context(method="POST", json={"url": url}): with pytest.raises(FileTooLargeError, match="size exceeded"): - handler(api) + handler(api, current_user) def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -295,9 +308,9 @@ def test_remote_file_upload_translates_service_unsupported_type_error(app, monke MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")), ) monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock()) - file_service_cls = _mock_upload_dependencies(monkeypatch) + file_service_cls, current_user = _mock_upload_dependencies(monkeypatch) file_service_cls.return_value.upload_file.side_effect = ServiceUnsupportedFileTypeError() with app.test_request_context(method="POST", json={"url": url}): with pytest.raises(UnsupportedFileTypeError): - handler(api) + handler(api, current_user) diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py index 580941867c..c2575eae0d 100644 --- a/api/tests/unit_tests/controllers/console/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/test_wraps.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask from flask_login import LoginManager, UserMixin +from werkzeug.exceptions import HTTPException from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout 
from controllers.console.workspace.error import AccountNotInitializedError @@ -17,8 +18,11 @@ from controllers.console.wraps import ( only_edition_enterprise, only_edition_self_hosted, setup_required, + with_current_tenant_id, + with_current_user, ) -from models.account import AccountStatus +from models import Account +from models.account import AccountStatus, TenantAccountRole from services.feature_service import LicenseStatus @@ -33,6 +37,17 @@ class MockUser(UserMixin): return self.id +def make_account(account_id: str = "account-1") -> Account: + account = Account( + name="Test Account", + email=f"{account_id}@example.com", + status=AccountStatus.ACTIVE, + ) + account.id = account_id + account.role = TenantAccountRole.OWNER + return account + + def create_app_with_login(): """Create a Flask app with LoginManager configured.""" app = Flask(__name__) @@ -84,6 +99,42 @@ class TestAccountInitialization: protected_view() +class TestCurrentContextInjection: + """Test request context injection decorators.""" + + def test_with_current_tenant_id_injects_tenant_id(self): + class Handler: + @with_current_tenant_id + def get(self, current_tenant_id: str): + return current_tenant_id + + with patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(), "tenant-123")): + assert Handler().get() == "tenant-123" + + def test_with_current_user_injects_account(self): + current_user = make_account() + + class Handler: + @with_current_user + def get(self, injected_user): + return injected_user + + with patch("controllers.console.wraps.current_account_with_tenant", return_value=(current_user, "tenant-123")): + assert Handler().get() is current_user + + def test_stacked_current_context_injectors_preserve_argument_order(self): + current_user = make_account() + + class Handler: + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, injected_user): + return current_tenant_id, injected_user + + with 
patch("controllers.console.wraps.current_account_with_tenant", return_value=(current_user, "tenant-123")): + assert Handler().get() == ("tenant-123", current_user) + + class TestEditionChecks: """Test edition-specific decorators""" @@ -114,7 +165,7 @@ class TestEditionChecks: # Act & Assert with app.test_request_context(): with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"): - with pytest.raises(Exception) as exc_info: + with pytest.raises(HTTPException) as exc_info: cloud_view() assert exc_info.value.code == 404 @@ -177,7 +228,7 @@ class TestBillingEnabled: with app.test_request_context(): with patch("controllers.console.wraps.dify_config.BILLING_ENABLED", False): with patch("controllers.console.wraps.FeatureService.get_features") as get_features: - with pytest.raises(Exception) as exc_info: + with pytest.raises(HTTPException) as exc_info: billing_view() assert exc_info.value.code == 403 @@ -237,6 +288,36 @@ class TestBillingResourceLimits: # Assert assert result == "segment_added" get_features.assert_called_once_with("tenant123", exclude_vector_space=False) + get_features.assert_called_once_with("tenant123", exclude_vector_space=True) + + def test_should_load_vector_space_from_dedicated_quota_api(self): + """Test vector-space limit checks avoid loading the full feature payload.""" + # Arrange + mock_vector_space = MagicMock() + mock_vector_space.limit = 10 + mock_vector_space.size = 5 + + @cloud_edition_billing_resource_check("vector_space") + def add_segment(): + return "segment_added" + + # Act + with patch( + "controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123") + ): + with ( + patch("controllers.console.wraps.dify_config.BILLING_ENABLED", True), + patch( + "controllers.console.wraps.FeatureService.get_vector_space", return_value=mock_vector_space + ) as get_vector_space, + patch("controllers.console.wraps.FeatureService.get_features") as get_features, + ): + result = add_segment() + + 
# Assert + assert result == "segment_added" + get_vector_space.assert_called_once_with("tenant123") + get_features.assert_not_called() def test_should_reject_when_over_resource_limit(self): """Test that requests are rejected when over resource limits""" @@ -258,7 +339,7 @@ class TestBillingResourceLimits: return_value=(MockUser("test_user"), "tenant123"), ): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): - with pytest.raises(Exception) as exc_info: + with pytest.raises(HTTPException) as exc_info: add_member() assert exc_info.value.code == 403 assert "members has reached the limit" in str(exc_info.value.description) @@ -283,7 +364,7 @@ class TestBillingResourceLimits: return_value=(MockUser("test_user"), "tenant123"), ): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): - with pytest.raises(Exception) as exc_info: + with pytest.raises(HTTPException) as exc_info: upload_document() assert exc_info.value.code == 403 @@ -357,7 +438,7 @@ class TestRateLimiting: with patch( "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit ): - with pytest.raises(Exception) as exc_info: + with pytest.raises(HTTPException) as exc_info: knowledge_request() # Verify error diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_composition.py b/api/tests/unit_tests/controllers/openapi/auth/test_composition.py index aa6478dd97..028d32009d 100644 --- a/api/tests/unit_tests/controllers/openapi/auth/test_composition.py +++ b/api/tests/unit_tests/controllers/openapi/auth/test_composition.py @@ -1,66 +1,73 @@ -from unittest.mock import patch - -from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE, _resolve_app_authz_strategy -from controllers.openapi.auth.pipeline import Pipeline -from controllers.openapi.auth.steps import ( - AppAuthzCheck, - AppResolver, - BearerCheck, - CallerMount, - ScopeCheck, - SurfaceCheck, - 
WorkspaceMembershipCheck, -) -from controllers.openapi.auth.strategies import ( - AccountMounter, - AclStrategy, - EndUserMounter, - MembershipStrategy, -) -from libs.oauth_bearer import SubjectType +from controllers.openapi.auth.composition import account_pipeline, auth_router, external_sso_pipeline +from controllers.openapi.auth.flow import When +from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter +from libs.oauth_bearer import TokenType -def test_pipeline_is_composed(): - assert isinstance(OAUTH_BEARER_PIPELINE, Pipeline) +def test_account_pipeline_is_auth_pipeline(): + assert isinstance(account_pipeline, AuthPipeline) -def test_pipeline_step_order(): - """BearerCheck → SurfaceCheck → ScopeCheck → AppResolver → - WorkspaceMembershipCheck → AppAuthzCheck → CallerMount. - SurfaceCheck enforces the dfoa_/dfoe_ surface split + emits - `openapi.wrong_surface_denied`. Rate-limit is enforced inside - `BearerAuthenticator.authenticate`, not as a separate pipeline step.""" - steps = OAUTH_BEARER_PIPELINE._steps - assert isinstance(steps[0], BearerCheck) - assert isinstance(steps[1], SurfaceCheck) - assert isinstance(steps[2], ScopeCheck) - assert isinstance(steps[3], AppResolver) - assert isinstance(steps[4], WorkspaceMembershipCheck) - assert isinstance(steps[5], AppAuthzCheck) - assert isinstance(steps[6], CallerMount) +def test_external_sso_pipeline_is_auth_pipeline(): + assert isinstance(external_sso_pipeline, AuthPipeline) -def test_pipeline_surface_check_accepts_account_only(): - """Current pipeline serves /apps//run — account surface only.""" - surface = OAUTH_BEARER_PIPELINE._steps[1] - assert isinstance(surface, SurfaceCheck) - assert surface._accepted == frozenset({SubjectType.ACCOUNT}) +def test_auth_router_is_pipeline_router(): + assert isinstance(auth_router, PipelineRouter) -def test_caller_mount_has_both_mounters(): - cm = OAUTH_BEARER_PIPELINE._steps[6] - kinds = {type(m) for m in cm._mounters} - assert AccountMounter 
in kinds - assert EndUserMounter in kinds +def test_account_pipeline_prepare_has_four_entries(): + assert len(account_pipeline._prepare) == 4 -@patch("controllers.openapi.auth.composition.FeatureService") -def test_strategy_resolver_picks_acl_when_enabled(fs): - fs.get_system_features.return_value.webapp_auth.enabled = True - assert isinstance(_resolve_app_authz_strategy(), AclStrategy) +def test_account_auth_list_has_five_entries(): + assert len(account_pipeline._auth) == 5 -@patch("controllers.openapi.auth.composition.FeatureService") -def test_strategy_resolver_picks_membership_when_disabled(fs): - fs.get_system_features.return_value.webapp_auth.enabled = False - assert isinstance(_resolve_app_authz_strategy(), MembershipStrategy) +def test_external_sso_pipeline_prepare_has_four_entries(): + assert len(external_sso_pipeline._prepare) == 4 + + +def test_external_sso_auth_list_has_three_entries(): + assert len(external_sso_pipeline._auth) == 3 + + +def test_account_pipeline_has_unconditional_load_account(): + non_when = [s for s in account_pipeline._prepare if not isinstance(s, When)] + assert len(non_when) == 1 + + +def test_external_sso_pipeline_all_prepare_entries_are_when(): + assert all(isinstance(s, When) for s in external_sso_pipeline._prepare) + + +def test_first_auth_entry_is_check_scope_in_both_pipelines(): + assert not isinstance(account_pipeline._auth[0], When) + assert not isinstance(external_sso_pipeline._auth[0], When) + + +def test_remaining_auth_entries_are_when_for_account(): + assert all(isinstance(s, When) for s in account_pipeline._auth[1:]) + + +def test_remaining_auth_entries_are_when_for_external_sso(): + assert all(isinstance(s, When) for s in external_sso_pipeline._auth[1:]) + + +def test_router_routes_contain_both_token_types(): + assert TokenType.OAUTH_ACCOUNT in auth_router._routes + assert TokenType.OAUTH_EXTERNAL_SSO in auth_router._routes + + +def test_external_sso_route_has_ee_required_edition(): + route = 
auth_router._routes[TokenType.OAUTH_EXTERNAL_SSO] + assert isinstance(route, PipelineRoute) + from controllers.openapi.auth.data import Edition + + assert route.required_edition == frozenset({Edition.EE}) + + +def test_account_route_has_no_required_edition(): + route = auth_router._routes[TokenType.OAUTH_ACCOUNT] + assert isinstance(route, PipelineRoute) + assert route.required_edition is None diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_conditions.py b/api/tests/unit_tests/controllers/openapi/auth/test_conditions.py new file mode 100644 index 0000000000..8367933984 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_conditions.py @@ -0,0 +1,143 @@ +from unittest.mock import MagicMock, patch + +from controllers.openapi.auth.conditions import ( + EDITION_CE, + EDITION_EE, + EDITION_SAAS, + LOADED_APP_IS_PRIVATE, + PATH_HAS_APP_ID, + TOKEN_IS_OAUTH_ACCOUNT, + TOKEN_IS_OAUTH_EXTERNAL_SSO, + WEBAPP_AUTH_ENABLED, + Cond, + config_cond, + data_cond, + request_cond, +) +from controllers.openapi.auth.data import AuthData, Edition, RequestContext +from libs.oauth_bearer import TokenType +from services.enterprise.enterprise_service import WebAppAccessMode + + +def _ctx(token_type=TokenType.OAUTH_ACCOUNT, path_params=None): + return RequestContext( + token_type=token_type, + path_params=path_params or {}, + ) + + +def _data(**kwargs): + defaults: dict = {"token_type": TokenType.OAUTH_ACCOUNT, "token_hash": "x", "scopes": frozenset()} + defaults.update(kwargs) + return AuthData(**defaults) + + +def test_and_both_true(): + a = Cond(lambda ctx, _: True) + b = Cond(lambda ctx, _: True) + assert (a & b)(_ctx()) is True + + +def test_and_one_false(): + a = Cond(lambda ctx, _: True) + b = Cond(lambda ctx, _: False) + assert (a & b)(_ctx()) is False + + +def test_or_one_true(): + a = Cond(lambda ctx, _: False) + b = Cond(lambda ctx, _: True) + assert (a | b)(_ctx()) is True + + +def test_or_both_false(): + a = Cond(lambda ctx, _: False) + b = 
Cond(lambda ctx, _: False) + assert (a | b)(_ctx()) is False + + +def test_invert(): + a = Cond(lambda ctx, _: True) + assert (~a)(_ctx()) is False + + +def test_chain_and_or(): + always_true = Cond(lambda ctx, _: True) + always_false = Cond(lambda ctx, _: False) + assert ((always_true | always_false) & always_true)(_ctx()) is True + + +def test_request_cond_ignores_data(): + c = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_ACCOUNT) + assert c(_ctx(TokenType.OAUTH_ACCOUNT)) is True + assert c(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is False + + +def test_data_cond_returns_false_when_data_none(): + c = data_cond(lambda data: True) + assert c(_ctx(), None) is False + + +def test_data_cond_evaluates_when_data_present(): + c = data_cond(lambda data: data.token_hash == "secret") + assert c(_ctx(), _data(token_hash="secret")) is True + assert c(_ctx(), _data(token_hash="other")) is False + + +def test_config_cond_ignores_ctx_and_data(): + c = config_cond(lambda: True) + assert c(_ctx()) is True + c2 = config_cond(lambda: False) + assert c2(_ctx(), _data()) is False + + +def test_token_is_oauth_account(): + assert TOKEN_IS_OAUTH_ACCOUNT(_ctx(TokenType.OAUTH_ACCOUNT)) is True + assert TOKEN_IS_OAUTH_ACCOUNT(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is False + + +def test_token_is_oauth_external_sso(): + assert TOKEN_IS_OAUTH_EXTERNAL_SSO(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is True + + +def test_path_has_app_id_true(): + assert PATH_HAS_APP_ID(_ctx(path_params={"app_id": "abc"})) is True + + +def test_path_has_app_id_false(): + assert PATH_HAS_APP_ID(_ctx(path_params={})) is False + + +def test_edition_ce(): + with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.CE): + assert EDITION_CE(_ctx()) is True + assert EDITION_EE(_ctx()) is False + assert EDITION_SAAS(_ctx()) is False + + +def test_edition_ee(): + with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.EE): + assert EDITION_EE(_ctx()) is True + 
assert EDITION_CE(_ctx()) is False + + +def test_edition_saas(): + with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.SAAS): + assert EDITION_SAAS(_ctx()) is True + + +def test_webapp_auth_enabled(): + mock_features = MagicMock() + mock_features.webapp_auth.enabled = True + with patch("controllers.openapi.auth.conditions.FeatureService.get_system_features", return_value=mock_features): + assert WEBAPP_AUTH_ENABLED(_ctx()) is True + + +def test_loaded_app_is_private(): + data_private = _data(app_access_mode=WebAppAccessMode.PRIVATE) + data_public = _data(app_access_mode=WebAppAccessMode.PUBLIC) + data_none = _data(app_access_mode=None) + assert LOADED_APP_IS_PRIVATE(_ctx(), data_private) is True + assert LOADED_APP_IS_PRIVATE(_ctx(), data_public) is False + assert LOADED_APP_IS_PRIVATE(_ctx(), data_none) is False + assert LOADED_APP_IS_PRIVATE(_ctx(), None) is False diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_context.py b/api/tests/unit_tests/controllers/openapi/auth/test_context.py deleted file mode 100644 index cc9c011342..0000000000 --- a/api/tests/unit_tests/controllers/openapi/auth/test_context.py +++ /dev/null @@ -1,21 +0,0 @@ -from controllers.openapi.auth.context import Context - - -def test_context_starts_unpopulated(): - ctx = Context(required_scope="apps:run") - assert ctx.bearer_token is None - assert ctx.path_params == {} - assert ctx.subject_type is None - assert ctx.subject_email is None - assert ctx.account_id is None - assert ctx.scopes == frozenset() - assert ctx.app is None - assert ctx.tenant is None - assert ctx.caller is None - assert ctx.caller_kind is None - - -def test_context_fields_are_mutable(): - ctx = Context(required_scope="apps:run") - ctx.scopes = frozenset({"full"}) - assert "full" in ctx.scopes diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_data.py b/api/tests/unit_tests/controllers/openapi/auth/test_data.py new file mode 100644 index 0000000000..c39ed9c6d0 
--- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_data.py @@ -0,0 +1,117 @@ +import uuid +from unittest.mock import patch + +import pytest +from pydantic import ValidationError + +from controllers.openapi.auth.data import ( + AuthData, + Edition, + ExternalIdentity, + RequestContext, + current_edition, +) +from libs.oauth_bearer import Scope, TokenType + + +def test_current_edition_saas(): + with patch("controllers.openapi.auth.data.dify_config") as cfg: + cfg.EDITION = "CLOUD" + cfg.ENTERPRISE_ENABLED = True + assert current_edition() == Edition.SAAS + + +def test_current_edition_ee(): + with patch("controllers.openapi.auth.data.dify_config") as cfg: + cfg.EDITION = "SELF_HOSTED" + cfg.ENTERPRISE_ENABLED = True + assert current_edition() == Edition.EE + + +def test_current_edition_ce(): + with patch("controllers.openapi.auth.data.dify_config") as cfg: + cfg.EDITION = "SELF_HOSTED" + cfg.ENTERPRISE_ENABLED = False + assert current_edition() == Edition.CE + + +def test_external_identity_frozen(): + ei = ExternalIdentity(email="a@b.com", issuer="idp") + with pytest.raises(ValidationError): + ei.email = "other@b.com" # type: ignore[misc] + + +def test_external_identity_issuer_optional(): + ei = ExternalIdentity(email="a@b.com") + assert ei.issuer is None + + +def test_request_context_frozen(): + ctx = RequestContext( + token_type=TokenType.OAUTH_ACCOUNT, + path_params={"app_id": "123"}, + ) + with pytest.raises(ValidationError): + ctx.token_type = TokenType.OAUTH_EXTERNAL_SSO # type: ignore[misc] + + +def test_request_context_scope_optional(): + ctx = RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={}) + assert ctx.scope is None + + +def test_auth_data_is_mutable(): + data = AuthData( + token_type=TokenType.OAUTH_ACCOUNT, + token_hash="abc", + scopes=frozenset({Scope.FULL}), + ) + data.token_type = TokenType.OAUTH_EXTERNAL_SSO + assert data.token_type == TokenType.OAUTH_EXTERNAL_SSO + + +def 
test_auth_data_path_params_defaults_empty(): + data = AuthData( + token_type=TokenType.OAUTH_ACCOUNT, + token_hash="abc", + scopes=frozenset(), + ) + assert data.path_params == {} + + +def test_auth_data_account_id_optional(): + data = AuthData( + token_type=TokenType.OAUTH_EXTERNAL_SSO, + token_hash="abc", + scopes=frozenset({Scope.APPS_RUN}), + external_identity=ExternalIdentity(email="u@sso.com"), + ) + assert data.account_id is None + + +def test_auth_data_external_identity_none_for_account(): + data = AuthData( + token_type=TokenType.OAUTH_ACCOUNT, + account_id=uuid.uuid4(), + token_hash="abc", + scopes=frozenset({Scope.FULL}), + ) + assert data.external_identity is None + + +def test_auth_data_tenants_default_empty(): + data = AuthData( + token_type=TokenType.OAUTH_ACCOUNT, + token_hash="abc", + scopes=frozenset(), + ) + assert data.tenants == {} + + +def test_auth_data_token_id_optional(): + data = AuthData( + token_type=TokenType.OAUTH_ACCOUNT, + token_hash="abc", + scopes=frozenset(), + ) + assert data.token_id is None diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_flow.py b/api/tests/unit_tests/controllers/openapi/auth/test_flow.py new file mode 100644 index 0000000000..3ea7ac2b12 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_flow.py @@ -0,0 +1,42 @@ +import inspect + +from controllers.openapi.auth.conditions import Cond +from controllers.openapi.auth.data import AuthData, RequestContext +from controllers.openapi.auth.flow import When +from libs.oauth_bearer import TokenType + + +def _ctx(): + return RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={}) + + +def _data(): + return AuthData(token_type=TokenType.OAUTH_ACCOUNT, token_hash="x", scopes=frozenset()) + + +def test_applies_returns_true_when_condition_true(): + w = When(Cond(lambda ctx, _: True), then=lambda b: None) + assert w.applies(_ctx()) is True + + +def test_applies_returns_false_when_condition_false(): + w = When(Cond(lambda ctx, _: 
False), then=lambda b: None) + assert w.applies(_ctx()) is False + + +def test_applies_with_data(): + w = When(Cond(lambda ctx, data: data is not None), then=lambda b: None) + assert w.applies(_ctx(), _data()) is True + assert w.applies(_ctx(), None) is False + + +def test_call_invokes_step(): + calls = [] + w = When(Cond(lambda ctx, _: True), then=lambda arg: calls.append(arg)) + w("payload") + assert calls == ["payload"] + + +def test_then_is_keyword_only(): + sig = inspect.signature(When.__init__) + assert sig.parameters["then"].kind.name == "KEYWORD_ONLY" diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_pipeline.py b/api/tests/unit_tests/controllers/openapi/auth/test_pipeline.py index 15538275f5..a92f90112f 100644 --- a/api/tests/unit_tests/controllers/openapi/auth/test_pipeline.py +++ b/api/tests/unit_tests/controllers/openapi/auth/test_pipeline.py @@ -1,59 +1,269 @@ +import uuid +from unittest.mock import MagicMock, patch + import pytest from flask import Flask +from werkzeug.exceptions import Forbidden, NotFound, Unauthorized -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.pipeline import Pipeline +from controllers.openapi.auth.data import AuthData, Edition +from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter +from libs.oauth_bearer import Scope, TokenType -def test_run_invokes_each_step_in_order(): - calls = [] - - class S: - def __init__(self, tag): - self.tag = tag - - def __call__(self, ctx): - calls.append(self.tag) - - Pipeline(S("a"), S("b"), S("c")).run(Context(required_scope="x")) - assert calls == ["a", "b", "c"] +def _make_identity( + token_type=TokenType.OAUTH_ACCOUNT, + account_id=None, + scopes=None, + token_hash="testhash", + subject_email=None, + subject_issuer=None, + verified_tenants=None, + token_id=None, +): + identity = MagicMock() + identity.token_type = token_type + identity.account_id = account_id or uuid.uuid4() + identity.scopes = scopes or 
frozenset({Scope.FULL}) + identity.token_hash = token_hash + identity.subject_email = subject_email + identity.subject_issuer = subject_issuer + identity.verified_tenants = verified_tenants or {} + identity.token_id = token_id or uuid.uuid4() + return identity -def test_run_short_circuits_on_raise(): - calls = [] - - class Boom: - def __call__(self, ctx): - raise RuntimeError("boom") - - class Tail: - def __call__(self, ctx): - calls.append("ran") - - with pytest.raises(RuntimeError): - Pipeline(Boom(), Tail()).run(Context(required_scope="x")) - assert calls == [] +@pytest.fixture +def app(): + return Flask(__name__) -def test_guard_decorator_runs_pipeline_and_unpacks_handler_kwargs(): - seen = {} +def _make_router(token_type=TokenType.OAUTH_ACCOUNT, prepare=None, auth=None): + pipeline = AuthPipeline(prepare=prepare or [], auth=auth or []) + return PipelineRouter({token_type: PipelineRoute(pipeline)}) - class FakeStep: - def __call__(self, ctx): - ctx.app = "APP" - ctx.caller = "CALLER" - ctx.caller_kind = "account" - pipeline = Pipeline(FakeStep()) +def _fake_identity(): + return _make_identity() - @pipeline.guard(scope="apps:run") - def handler(app_model, caller, caller_kind): - seen["app_model"] = app_model - seen["caller"] = caller - seen["caller_kind"] = caller_kind - return "ok" - app = Flask(__name__) - with app.test_request_context("/x", method="POST"): - assert handler() == "ok" - assert seen == {"app_model": "APP", "caller": "CALLER", "caller_kind": "account"} +# --- PipelineRouter.guard --- + + +def test_guard_passes_auth_data_to_view(app): + router = _make_router() + received = {} + + with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}): + with ( + patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), + patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, + patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()), + 
patch("controllers.openapi.auth.pipeline.reset_auth_ctx"), + ): + mock_auth.return_value.authenticate.return_value = _fake_identity() + + @router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def view(*, auth_data): + received["data"] = auth_data + + view() + + assert isinstance(received["data"], AuthData) + + +def test_guard_edition_gate_returns_404(app): + router = _make_router() + + with app.test_request_context("/test"): + with patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE): + + @router.guard(scope=Scope.FULL, edition=frozenset({Edition.EE})) + def view(*, auth_data): + pass + + with pytest.raises(NotFound): + view() + + +def test_guard_token_type_gate_returns_403(app): + router = _make_router() + + with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}): + with ( + patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), + patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, + patch("controllers.openapi.auth.pipeline.emit_wrong_surface"), + patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE), + ): + identity = _fake_identity() + identity.token_type = TokenType.OAUTH_EXTERNAL_SSO + mock_auth.return_value.authenticate.return_value = identity + + @router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def view(*, auth_data): + pass + + with pytest.raises(Forbidden): + view() + + +def test_guard_unregistered_token_type_returns_403(app): + router = _make_router(token_type=TokenType.OAUTH_ACCOUNT) + + with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}): + with ( + patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), + patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, + patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE), + ): + identity = 
_fake_identity() + identity.token_type = TokenType.OAUTH_EXTERNAL_SSO + mock_auth.return_value.authenticate.return_value = identity + + @router.guard(scope=Scope.FULL) + def view(*, auth_data): + pass + + with pytest.raises(Forbidden): + view() + + +def test_guard_no_bearer_returns_401(app): + router = _make_router() + + with app.test_request_context("/test"): + with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value=None): + + @router.guard(scope=Scope.FULL) + def view(*, auth_data): + pass + + with pytest.raises(Unauthorized): + view() + + +def test_guard_runs_prepare_steps_in_order(app): + order = [] + + def p1(b): + order.append("p1") + + def p2(b): + order.append("p2") + + router = _make_router(prepare=[p1, p2]) + + with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}): + with ( + patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), + patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, + patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()), + patch("controllers.openapi.auth.pipeline.reset_auth_ctx"), + ): + mock_auth.return_value.authenticate.return_value = _fake_identity() + + @router.guard(scope=Scope.FULL) + def view(*, auth_data): + pass + + view() + + assert order == ["p1", "p2"] + + +def test_guard_resets_auth_ctx_on_exception(app): + router = _make_router() + reset_called = [] + + with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}): + with ( + patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), + patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, + patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value="tok"), + patch("controllers.openapi.auth.pipeline.reset_auth_ctx", side_effect=lambda t: reset_called.append(t)), + ): + mock_auth.return_value.authenticate.return_value = _fake_identity() + + @router.guard(scope=Scope.FULL) + def view(*, 
auth_data): + raise RuntimeError("boom") + + with pytest.raises(RuntimeError): + view() + + assert reset_called == ["tok"] + + +def test_router_rejects_token_type_on_wrong_edition(app): + pipeline = AuthPipeline(prepare=[], auth=[]) + route = PipelineRoute(pipeline, required_edition=frozenset({Edition.EE})) + router = PipelineRouter({TokenType.OAUTH_EXTERNAL_SSO: route}) + + with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}): + with ( + patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), + patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, + patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE), + ): + identity = _make_identity(token_type=TokenType.OAUTH_EXTERNAL_SSO) + mock_auth.return_value.authenticate.return_value = identity + + @router.guard(scope=Scope.APPS_RUN) + def view(*, auth_data): + pass + + with pytest.raises(Forbidden): + view() + + +def test_guard_populates_external_identity_from_subject_email(app): + from controllers.openapi.auth.data import ExternalIdentity + + router = _make_router(token_type=TokenType.OAUTH_EXTERNAL_SSO) + received = {} + + with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}): + with ( + patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), + patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, + patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()), + patch("controllers.openapi.auth.pipeline.reset_auth_ctx"), + ): + identity = _make_identity( + token_type=TokenType.OAUTH_EXTERNAL_SSO, + subject_email="user@sso.com", + subject_issuer="https://idp.example.com", + ) + mock_auth.return_value.authenticate.return_value = identity + + @router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO})) + def view(*, auth_data): + received["data"] = auth_data + + view() + + assert 
isinstance(received["data"].external_identity, ExternalIdentity) + assert received["data"].external_identity.email == "user@sso.com" + assert received["data"].external_identity.issuer == "https://idp.example.com" + + +def test_guard_no_external_identity_when_subject_email_absent(app): + router = _make_router() + received = {} + + with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}): + with ( + patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), + patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, + patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()), + patch("controllers.openapi.auth.pipeline.reset_auth_ctx"), + ): + mock_auth.return_value.authenticate.return_value = _make_identity(subject_email=None) + + @router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def view(*, auth_data): + received["data"] = auth_data + + view() + + assert received["data"].external_identity is None diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_prepare.py b/api/tests/unit_tests/controllers/openapi/auth/test_prepare.py new file mode 100644 index 0000000000..39d8aafa0e --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_prepare.py @@ -0,0 +1,183 @@ +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, NotFound, Unauthorized + +from controllers.openapi.auth.data import AuthData, ExternalIdentity +from controllers.openapi.auth.prepare import ( + load_account, + load_app, + load_app_access_mode, + load_tenant, + resolve_external_user, +) +from libs.oauth_bearer import TokenType + + +def _make_auth_data(**kwargs) -> AuthData: + mock_fields = {k: kwargs.pop(k) for k in ("app", "tenant", "caller") if k in kwargs} + data = AuthData( + token_type=kwargs.pop("token_type", TokenType.OAUTH_ACCOUNT), + token_hash=kwargs.pop("token_hash", "testhash"), + 
scopes=kwargs.pop("scopes", frozenset()), + **kwargs, + ) + for k, v in mock_fields.items(): + setattr(data, k, v) + return data + + +def test_load_app_writes_app_to_data(): + app = MagicMock() + app.status = "normal" + app.enable_api = True + data = _make_auth_data(path_params={"app_id": "abc"}) + with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app): + load_app(data) + assert data.app is app + + +def test_load_app_raises_not_found_when_missing(): + data = _make_auth_data(path_params={"app_id": "missing"}) + with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=None): + with pytest.raises(NotFound): + load_app(data) + + +def test_load_app_raises_not_found_when_not_normal(): + app = MagicMock() + app.status = "archived" + data = _make_auth_data(path_params={"app_id": "abc"}) + with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app): + with pytest.raises(NotFound): + load_app(data) + + +def test_load_app_raises_forbidden_when_api_disabled(): + app = MagicMock() + app.status = "normal" + app.enable_api = False + data = _make_auth_data(path_params={"app_id": "abc"}) + with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app): + with pytest.raises(Forbidden): + load_app(data) + + +def test_load_tenant_writes_tenant(): + app = MagicMock() + app.tenant_id = uuid.uuid4() + tenant = MagicMock() + tenant.status = "normal" + data = _make_auth_data(app=app) + with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=tenant): + load_tenant(data) + assert data.tenant is tenant + + +def test_load_tenant_raises_forbidden_when_archived(): + from models.account import TenantStatus + + app = MagicMock() + app.tenant_id = uuid.uuid4() + tenant = MagicMock() + tenant.status = TenantStatus.ARCHIVE + data = _make_auth_data(app=app) + with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", 
return_value=tenant): + with pytest.raises(Forbidden): + load_tenant(data) + + +def test_load_tenant_raises_forbidden_when_missing(): + app = MagicMock() + app.tenant_id = uuid.uuid4() + data = _make_auth_data(app=app) + with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=None): + with pytest.raises(Forbidden): + load_tenant(data) + + +def test_load_tenant_raises_500_when_app_not_loaded(): + from werkzeug.exceptions import InternalServerError + + data = _make_auth_data() + with pytest.raises(InternalServerError): + load_tenant(data) + + +def test_load_account_writes_caller(): + account = MagicMock() + account_id = uuid.uuid4() + data = _make_auth_data(account_id=account_id) + with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=account): + load_account(data) + assert data.caller is account + assert data.caller_kind == "account" + + +def test_load_account_sets_current_tenant_when_tenant_present(): + account = MagicMock() + tenant = MagicMock() + data = _make_auth_data(account_id=uuid.uuid4(), tenant=tenant) + with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=account): + load_account(data) + assert account.current_tenant is tenant + + +def test_load_account_raises_unauthorized_when_not_found(): + data = _make_auth_data(account_id=uuid.uuid4()) + with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=None): + with pytest.raises(Unauthorized): + load_account(data) + + +def test_resolve_external_user_writes_caller(): + tenant = MagicMock() + app = MagicMock() + end_user = MagicMock() + ext = ExternalIdentity(email="user@sso.com") + data = _make_auth_data(tenant=tenant, app=app, external_identity=ext) + with patch("controllers.openapi.auth.prepare.EndUserService.get_or_create_end_user_by_type", return_value=end_user): + resolve_external_user(data) + assert data.caller is end_user + assert data.caller_kind == 
"end_user" + + +def test_resolve_external_user_raises_unauthorized_when_context_missing(): + data = _make_auth_data(tenant=None, app=MagicMock(), external_identity=ExternalIdentity(email="u@s.com")) + with pytest.raises(Unauthorized): + resolve_external_user(data) + + +def test_load_app_access_mode_writes_mode(): + from services.enterprise.enterprise_service import WebAppAccessMode + + app = MagicMock() + app.id = "app-1" + settings = MagicMock() + settings.access_mode = "public" + data = _make_auth_data(app=app) + with patch( + "controllers.openapi.auth.prepare.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", + return_value=settings, + ): + load_app_access_mode(data) + assert data.app_access_mode == WebAppAccessMode.PUBLIC + + +def test_load_app_access_mode_writes_none_when_value_error(): + app = MagicMock() + app.id = "app-1" + data = _make_auth_data(app=app) + with patch( + "controllers.openapi.auth.prepare.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", + side_effect=ValueError("No data found."), + ): + load_app_access_mode(data) + assert data.app_access_mode is None + + +def test_load_app_access_mode_no_op_when_app_missing(): + data = _make_auth_data() + load_app_access_mode(data) + assert data.app_access_mode is None diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_role_gate.py b/api/tests/unit_tests/controllers/openapi/auth/test_role_gate.py new file mode 100644 index 0000000000..68b436e824 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_role_gate.py @@ -0,0 +1,330 @@ +"""Role-gate tests. 
+ +The decorator wraps `validate_bearer` + `accept_subjects` and must: +- 404 when caller is not a member of ``workspace_id`` (parity with + `GET /openapi/v1/workspaces/`; prevents tenant-id existence leak) +- 403 when caller IS a member but their role is not in the allowed set +- pass through when role matches (or when no role restriction given) +- raise RuntimeError on missing auth context / account_id / workspace_id — + those are wiring bugs, not user-driven failures + +Identity is read from the openapi auth ContextVar — the slot +`validate_bearer` publishes — so these tests seed it via `_seed` +(``set_auth_ctx``), NOT ``flask.g``. `test_seeding_only_flask_g_*` +locks in that ``g`` is *not* a valid identity source. +""" + +from __future__ import annotations + +import uuid +from contextlib import contextmanager +from datetime import UTC, datetime +from unittest.mock import patch + +import pytest +from flask import Flask +from werkzeug.exceptions import Forbidden, NotFound + +from controllers.openapi.auth.role_gate import require_workspace_role +from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, reset_auth_ctx, set_auth_ctx +from models.account import TenantAccountRole + +# Tokens from `_seed`'s `set_auth_ctx` calls, drained after each test so a +# published identity can't leak into the next (the ContextVar is module-global +# and worker threads are reused). Seed via `_seed(...)`, never `flask.g`. 
+_seed_tokens: list = [] + + +def _seed(ctx: AuthContext) -> None: + _seed_tokens.append(set_auth_ctx(ctx)) + + +@pytest.fixture(autouse=True) +def _reset_auth_ctx(): + yield + while _seed_tokens: + reset_auth_ctx(_seed_tokens.pop()) + + +def _account_ctx(account_id: uuid.UUID | None = None) -> AuthContext: + return AuthContext( + subject_type=SubjectType.ACCOUNT, + subject_email="user@example.com", + subject_issuer="dify:account", + account_id=account_id or uuid.uuid4(), + client_id="difyctl", + scopes=frozenset({Scope.FULL}), + token_id=uuid.uuid4(), + token_type=TokenType.OAUTH_ACCOUNT, + expires_at=datetime.now(UTC), + token_hash="h1", + verified_tenants={}, + ) + + +def _sso_ctx() -> AuthContext: + return AuthContext( + subject_type=SubjectType.EXTERNAL_SSO, + subject_email="sso@partner.com", + subject_issuer="https://idp.partner.com", + account_id=None, + client_id="difyctl", + scopes=frozenset({Scope.APPS_RUN}), + token_id=uuid.uuid4(), + token_type=TokenType.OAUTH_EXTERNAL_SSO, + expires_at=datetime.now(UTC), + token_hash="h2", + verified_tenants={}, + ) + + +@contextmanager +def _stub_role(role: TenantAccountRole | None): + """Stub the service-layer membership lookup the gate delegates to. + + The gate no longer issues SQL itself — it calls + ``TenantService.get_account_role_in_tenant`` and acts purely on the + returned role (``None`` → non-member). These tests pin that behaviour; + the query itself is covered in ``TestTenantService``. 
+ """ + with patch( + "controllers.openapi.auth.role_gate.TenantService.get_account_role_in_tenant", + return_value=role, + ) as mocked: + yield mocked + + +# --------------------------------------------------------------------------- +# Non-member → 404 +# --------------------------------------------------------------------------- + + +def test_non_member_gets_404(): + app = Flask(__name__) + workspace_id = str(uuid.uuid4()) + + @require_workspace_role() + def view(workspace_id: str) -> str: + return "ok" + + with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"): + _seed(_account_ctx()) + with _stub_role(None): + with pytest.raises(NotFound): + view(workspace_id=workspace_id) + + +# --------------------------------------------------------------------------- +# Member with insufficient role → 403 +# --------------------------------------------------------------------------- + + +def test_normal_member_blocked_when_admin_required(): + app = Flask(__name__) + workspace_id = str(uuid.uuid4()) + + @require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN) + def view(workspace_id: str) -> str: + return "ok" + + with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/members"): + _seed(_account_ctx()) + with _stub_role(TenantAccountRole.NORMAL): + with pytest.raises(Forbidden): + view(workspace_id=workspace_id) + + +def test_editor_blocked_when_admin_required(): + app = Flask(__name__) + workspace_id = str(uuid.uuid4()) + + @require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN) + def view(workspace_id: str) -> str: + return "ok" + + with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/members"): + _seed(_account_ctx()) + with _stub_role(TenantAccountRole.EDITOR): + with pytest.raises(Forbidden): + view(workspace_id=workspace_id) + + +# --------------------------------------------------------------------------- +# Member with allowed role → pass +# 
--------------------------------------------------------------------------- + + +def test_admin_passes_when_admin_required(): + app = Flask(__name__) + workspace_id = str(uuid.uuid4()) + + @require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN) + def view(workspace_id: str) -> str: + return "ok" + + with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/members"): + _seed(_account_ctx()) + with _stub_role(TenantAccountRole.ADMIN): + assert view(workspace_id=workspace_id) == "ok" + + +def test_owner_passes_when_admin_required(): + app = Flask(__name__) + workspace_id = str(uuid.uuid4()) + + @require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN) + def view(workspace_id: str) -> str: + return "ok" + + with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/members"): + _seed(_account_ctx()) + with _stub_role(TenantAccountRole.OWNER): + assert view(workspace_id=workspace_id) == "ok" + + +# --------------------------------------------------------------------------- +# Membership-only (no role restriction) +# --------------------------------------------------------------------------- + + +def test_membership_only_passes_for_any_role(): + app = Flask(__name__) + workspace_id = str(uuid.uuid4()) + + @require_workspace_role() + def view(workspace_id: str) -> str: + return "ok" + + for role in ( + TenantAccountRole.OWNER, + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.NORMAL, + TenantAccountRole.DATASET_OPERATOR, + ): + with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"): + _seed(_account_ctx()) + with _stub_role(role): + assert view(workspace_id=workspace_id) == "ok" + + +def test_membership_only_still_404s_non_member(): + app = Flask(__name__) + workspace_id = str(uuid.uuid4()) + + @require_workspace_role() + def view(workspace_id: str) -> str: + return "ok" + + with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"): + 
_seed(_account_ctx()) + with _stub_role(None): + with pytest.raises(NotFound): + view(workspace_id=workspace_id) + + +# --------------------------------------------------------------------------- +# Lookup is scoped to the caller's account_id and the URL workspace_id +# --------------------------------------------------------------------------- + + +def test_lookup_is_scoped_to_caller_and_workspace(): + """The decorator must delegate the lookup keyed on + `(caller's account_id, URL workspace_id)` — otherwise a member of + workspace A could quietly hit endpoints for workspace B. Assert the + exact arguments handed to the service; the SQL those arguments compile + to is pinned in ``TestTenantService.test_get_account_role_in_tenant_*``. + """ + + app = Flask(__name__) + account_id = uuid.uuid4() + workspace_id = str(uuid.uuid4()) + + @require_workspace_role() + def view(workspace_id: str) -> str: + return "ok" + + with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"): + _seed(_account_ctx(account_id=account_id)) + with _stub_role(TenantAccountRole.NORMAL) as mocked: + view(workspace_id=workspace_id) + + _session, passed_account_id, passed_workspace_id = mocked.call_args.args + assert passed_account_id == str(account_id) + assert passed_workspace_id == workspace_id + + +# --------------------------------------------------------------------------- +# Wiring bugs surface as RuntimeError (loud), not 403 (silent) +# --------------------------------------------------------------------------- + + +def test_missing_auth_ctx_is_runtime_error(): + app = Flask(__name__) + workspace_id = str(uuid.uuid4()) + + @require_workspace_role() + def view(workspace_id: str) -> str: + return "ok" + + with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"): + with pytest.raises(RuntimeError): + view(workspace_id=workspace_id) + + +def test_seeding_only_flask_g_does_not_satisfy_gate(): + """Regression — pins the identity source to the 
ContextVar, not ``flask.g``. + + Production fills the ContextVar (``validate_bearer`` → ``set_auth_ctx``) + and never touches ``g.auth_ctx``. An earlier revision of this gate read + ``g.auth_ctx``, so every real request raised RuntimeError → 500 while the + suite stayed green (it seeded ``g`` directly). Here we seed ONLY ``g`` and + leave the ContextVar empty: the gate must still raise, proving it does not + accept ``g`` as an identity source. Reading ``g`` again would let the + membership lookup run (stubbed to succeed) and this would fail. + """ + from flask import g + + app = Flask(__name__) + workspace_id = str(uuid.uuid4()) + + @require_workspace_role() + def view(workspace_id: str) -> str: + return "ok" + + with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"): + g.auth_ctx = _account_ctx() # the wrong slot — must be ignored + with _stub_role(TenantAccountRole.OWNER): + with pytest.raises(RuntimeError): + view(workspace_id=workspace_id) + + +def test_sso_caller_is_runtime_error(): + """External SSO context has account_id=None — the caller stacked the + role gate without `accept_subjects(SubjectType.ACCOUNT)`. 
That's a + wiring bug, surface it as RuntimeError rather than 404 the SSO user.""" + + app = Flask(__name__) + workspace_id = str(uuid.uuid4()) + + @require_workspace_role() + def view(workspace_id: str) -> str: + return "ok" + + with app.test_request_context(f"/openapi/v1/workspaces/{workspace_id}/switch"): + _seed(_sso_ctx()) + with pytest.raises(RuntimeError): + view(workspace_id=workspace_id) + + +def test_missing_workspace_id_kwarg_is_runtime_error(): + app = Flask(__name__) + + @require_workspace_role() + def view() -> str: + return "ok" + + with app.test_request_context("/openapi/v1/foo"): + _seed(_account_ctx()) + with pytest.raises(RuntimeError): + view() diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_app_resolver.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_app_resolver.py deleted file mode 100644 index f051f1a71c..0000000000 --- a/api/tests/unit_tests/controllers/openapi/auth/test_step_app_resolver.py +++ /dev/null @@ -1,64 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import patch - -import pytest -from werkzeug.exceptions import BadRequest, Forbidden, NotFound - -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.steps import AppResolver -from models import TenantStatus - - -def _ctx(path_params: dict[str, str] | None) -> Context: - return Context(required_scope="apps:run", path_params=path_params or {}) - - -def _app(*, status="normal", enable_api=True): - return SimpleNamespace(id="app1", tenant_id="t1", status=status, enable_api=enable_api) - - -def _tenant(*, status=TenantStatus.NORMAL): - return SimpleNamespace(id="t1", status=status) - - -def test_resolver_rejects_missing_path_param(): - with pytest.raises(BadRequest): - AppResolver()(_ctx({})) - - -def test_resolver_rejects_empty_path_params(): - # `Pipeline.guard` always seeds an empty dict when Flask reports no - # view args, so a missing `app_id` key surfaces here as BadRequest. 
- with pytest.raises(BadRequest): - AppResolver()(_ctx(None)) - - -@patch("controllers.openapi.auth.steps.db") -def test_resolver_404_when_app_missing(db): - db.session.get.side_effect = [None] - with pytest.raises(NotFound): - AppResolver()(_ctx({"app_id": "x"})) - - -@patch("controllers.openapi.auth.steps.db") -def test_resolver_403_when_disabled(db): - db.session.get.side_effect = [_app(enable_api=False)] - with pytest.raises(Forbidden) as exc: - AppResolver()(_ctx({"app_id": "x"})) - assert "service_api_disabled" in str(exc.value.description) - - -@patch("controllers.openapi.auth.steps.db") -def test_resolver_403_when_tenant_archived(db): - db.session.get.side_effect = [_app(), _tenant(status=TenantStatus.ARCHIVE)] - with pytest.raises(Forbidden): - AppResolver()(_ctx({"app_id": "x"})) - - -@patch("controllers.openapi.auth.steps.db") -def test_resolver_populates_app_and_tenant(db): - db.session.get.side_effect = [_app(), _tenant()] - ctx = _ctx({"app_id": "x"}) - AppResolver()(ctx) - assert ctx.app.id == "app1" - assert ctx.tenant.id == "t1" diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py deleted file mode 100644 index 6a5933da3b..0000000000 --- a/api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py +++ /dev/null @@ -1,76 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import patch - -import pytest -from werkzeug.exceptions import Forbidden - -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.steps import AppAuthzCheck -from controllers.openapi.auth.strategies import AclStrategy, MembershipStrategy -from libs.oauth_bearer import SubjectType - - -def _ctx(*, subject_type, account_id="acc1"): - c = Context(required_scope="apps:run") - c.subject_type = subject_type - c.subject_email = "alice@example.com" - c.account_id = account_id - c.app = SimpleNamespace(id="app1") - c.tenant = SimpleNamespace(id="t1") - 
return c - - -@patch("controllers.openapi.auth.strategies.EnterpriseService") -def test_acl_strategy_private_calls_inner_api(ent): - ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode="private") - ent.WebAppAuth.is_user_allowed_to_access_webapp.return_value = True - assert AclStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True - ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_called_once_with( - user_id="acc1", - app_id="app1", - ) - - -@pytest.mark.parametrize( - ("access_mode", "subject_type", "expected"), - [ - ("public", SubjectType.ACCOUNT, True), - ("public", SubjectType.EXTERNAL_SSO, True), - ("sso_verified", SubjectType.ACCOUNT, True), - ("sso_verified", SubjectType.EXTERNAL_SSO, True), - ("private_all", SubjectType.ACCOUNT, True), - ("private_all", SubjectType.EXTERNAL_SSO, False), - ("private", SubjectType.EXTERNAL_SSO, False), - ], -) -@patch("controllers.openapi.auth.strategies.EnterpriseService") -def test_acl_strategy_subject_mode_matrix(ent, access_mode, subject_type, expected): - """Step 1 matrix: subject vs access-mode compatibility. 
No inner API call expected.""" - ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode=access_mode) - account_id = "acc1" if subject_type == SubjectType.ACCOUNT else None - assert AclStrategy().authorize(_ctx(subject_type=subject_type, account_id=account_id)) is expected - ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_not_called() - - -@patch("controllers.openapi.auth.strategies.TenantService.account_belongs_to_tenant") -@patch("controllers.openapi.auth.strategies.db") -def test_membership_strategy_uses_join_lookup(db_mock, member): - member.return_value = True - assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True - member.assert_called_once_with(db_mock.session, "acc1", "t1") - - -def test_membership_strategy_rejects_external_sso(): - assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None)) is False - - -def test_app_authz_check_raises_when_strategy_denies(): - deny = SimpleNamespace(authorize=lambda c: False) - with pytest.raises(Forbidden) as exc: - AppAuthzCheck(lambda: deny)(_ctx(subject_type=SubjectType.ACCOUNT)) - assert "subject_no_app_access" in str(exc.value.description) - - -def test_app_authz_check_passes_when_strategy_allows(): - allow = SimpleNamespace(authorize=lambda c: True) - AppAuthzCheck(lambda: allow)(_ctx(subject_type=SubjectType.ACCOUNT)) diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py deleted file mode 100644 index 329f158f30..0000000000 --- a/api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py +++ /dev/null @@ -1,83 +0,0 @@ -import uuid -from datetime import UTC, datetime -from unittest.mock import patch - -import pytest -from flask import Flask -from werkzeug.exceptions import Unauthorized - -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.steps import BearerCheck -from 
libs.oauth_bearer import ( - AuthContext, - InvalidBearerError, - Scope, - SubjectType, - reset_auth_ctx, - try_get_auth_ctx, -) - - -def _ctx(bearer_token: str | None) -> Context: - return Context(required_scope="apps:run", bearer_token=bearer_token) - - -def test_bearer_check_rejects_missing_header(): - app = Flask(__name__) - with app.test_request_context(), pytest.raises(Unauthorized): - BearerCheck()(_ctx(None)) - - -@patch("controllers.openapi.auth.steps.get_authenticator") -def test_bearer_check_rejects_unknown_prefix(get_auth): - get_auth.return_value.authenticate.side_effect = InvalidBearerError("invalid_bearer") - app = Flask(__name__) - with app.test_request_context(), pytest.raises(Unauthorized): - BearerCheck()(_ctx("xxx_abc")) - - -@patch("controllers.openapi.auth.steps.get_authenticator") -def test_bearer_check_populates_context_and_publishes_auth_ctx(get_auth): - tok_id = uuid.uuid4() - authn = AuthContext( - subject_type=SubjectType.ACCOUNT, - subject_email="a@x.com", - subject_issuer=None, - account_id=None, - client_id="difyctl", - scopes=frozenset({Scope.FULL}), - token_id=tok_id, - source="oauth-account", - expires_at=datetime.now(UTC), - token_hash="hash-1", - verified_tenants={}, - ) - get_auth.return_value.authenticate.return_value = authn - - app = Flask(__name__) - ctx = _ctx("dfoa_abc") - with app.test_request_context(): - BearerCheck()(ctx) - try: - assert ctx.subject_type == SubjectType.ACCOUNT - assert ctx.subject_email == "a@x.com" - assert ctx.scopes == frozenset({Scope.FULL}) - assert ctx.source == "oauth-account" - assert ctx.token_id == tok_id - assert ctx.token_hash == "hash-1" - # BearerCheck must also publish the same identity on the - # openapi auth ContextVar so the surface gate + downstream - # handlers don't see two different identity sources between - # the decorator + pipeline paths. The reset token is parked - # on `ctx.auth_ctx_reset_token` for `Pipeline.guard` to - # consume in its `finally`. 
- published = try_get_auth_ctx() - assert published is authn - assert published.client_id == "difyctl" - assert ctx.auth_ctx_reset_token is not None - finally: - # In production `Pipeline.guard` resets the ContextVar; in - # this isolated step-level test we reset it ourselves so the - # value doesn't leak into the next test on the same worker. - assert ctx.auth_ctx_reset_token is not None - reset_auth_ctx(ctx.auth_ctx_reset_token) diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_layer0.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_layer0.py deleted file mode 100644 index 82ea07d736..0000000000 --- a/api/tests/unit_tests/controllers/openapi/auth/test_step_layer0.py +++ /dev/null @@ -1,157 +0,0 @@ -"""Unit tests for WorkspaceMembershipCheck (Layer 0).""" - -from __future__ import annotations - -import uuid -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -import pytest -from werkzeug.exceptions import Forbidden - -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.steps import WorkspaceMembershipCheck -from libs.oauth_bearer import SubjectType - - -def _ctx(*, subject_type, account_id, tenant_id, cached_verified_tenants=None, token_hash=None) -> Context: - c = Context(required_scope="apps:read") - c.subject_type = subject_type - c.account_id = account_id - c.tenant = SimpleNamespace(id=tenant_id) if tenant_id else None - c.cached_verified_tenants = cached_verified_tenants - c.token_hash = token_hash - return c - - -@pytest.fixture -def step(): - return WorkspaceMembershipCheck() - - -@patch("controllers.openapi.auth.steps.dify_config") -@patch("libs.oauth_bearer.record_layer0_verdict") -@patch("libs.oauth_bearer.db") -def test_skips_when_enterprise_enabled(mock_db, mock_record, mock_cfg, step): - mock_cfg.ENTERPRISE_ENABLED = True - ctx = _ctx( - subject_type=SubjectType.ACCOUNT, - account_id=str(uuid.uuid4()), - tenant_id=str(uuid.uuid4()), - 
cached_verified_tenants={}, - token_hash="hash-1", - ) - step(ctx) # no raise - mock_db.session.execute.assert_not_called() - mock_record.assert_not_called() - - -@patch("controllers.openapi.auth.steps.dify_config") -@patch("libs.oauth_bearer.record_layer0_verdict") -@patch("libs.oauth_bearer.db") -def test_skips_for_external_sso(mock_db, mock_record, mock_cfg, step): - mock_cfg.ENTERPRISE_ENABLED = False - ctx = _ctx( - subject_type=SubjectType.EXTERNAL_SSO, - account_id=None, - tenant_id=str(uuid.uuid4()), - cached_verified_tenants={}, - token_hash="hash-1", - ) - step(ctx) # no raise - mock_db.session.execute.assert_not_called() - mock_record.assert_not_called() - - -@patch("controllers.openapi.auth.steps.dify_config") -@patch("libs.oauth_bearer.record_layer0_verdict") -@patch("libs.oauth_bearer.db") -def test_uses_cached_ok(mock_db, mock_record, mock_cfg, step): - mock_cfg.ENTERPRISE_ENABLED = False - ctx = _ctx( - subject_type=SubjectType.ACCOUNT, - account_id="a1", - tenant_id="t1", - cached_verified_tenants={"t1": True}, - token_hash="hash-1", - ) - step(ctx) - mock_db.session.execute.assert_not_called() - mock_record.assert_not_called() - - -@patch("controllers.openapi.auth.steps.dify_config") -@patch("libs.oauth_bearer.record_layer0_verdict") -@patch("libs.oauth_bearer.db") -def test_uses_cached_denied(mock_db, mock_record, mock_cfg, step): - mock_cfg.ENTERPRISE_ENABLED = False - ctx = _ctx( - subject_type=SubjectType.ACCOUNT, - account_id="a1", - tenant_id="t1", - cached_verified_tenants={"t1": False}, - token_hash="hash-1", - ) - with pytest.raises(Forbidden, match="workspace_membership_revoked"): - step(ctx) - mock_db.session.execute.assert_not_called() - mock_record.assert_not_called() - - -@patch("controllers.openapi.auth.steps.dify_config") -@patch("libs.oauth_bearer.record_layer0_verdict") -@patch("libs.oauth_bearer.db") -def test_denies_when_no_membership(mock_db, mock_record, mock_cfg, step): - mock_cfg.ENTERPRISE_ENABLED = False - 
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None - ctx = _ctx( - subject_type=SubjectType.ACCOUNT, - account_id="a1", - tenant_id="t1", - cached_verified_tenants={}, - token_hash="hash-1", - ) - with pytest.raises(Forbidden, match="workspace_membership_revoked"): - step(ctx) - mock_record.assert_called_once_with("hash-1", "t1", False) - - -@patch("controllers.openapi.auth.steps.dify_config") -@patch("libs.oauth_bearer.record_layer0_verdict") -@patch("libs.oauth_bearer.db") -def test_denies_when_account_inactive(mock_db, mock_record, mock_cfg, step): - mock_cfg.ENTERPRISE_ENABLED = False - mock_db.session.execute.side_effect = [ - MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")), - MagicMock(scalar_one_or_none=MagicMock(return_value="banned")), - ] - ctx = _ctx( - subject_type=SubjectType.ACCOUNT, - account_id="a1", - tenant_id="t1", - cached_verified_tenants={}, - token_hash="hash-1", - ) - with pytest.raises(Forbidden, match="workspace_membership_revoked"): - step(ctx) - mock_record.assert_called_once_with("hash-1", "t1", False) - - -@patch("controllers.openapi.auth.steps.dify_config") -@patch("libs.oauth_bearer.record_layer0_verdict") -@patch("libs.oauth_bearer.db") -def test_allows_active_member(mock_db, mock_record, mock_cfg, step): - mock_cfg.ENTERPRISE_ENABLED = False - mock_db.session.execute.side_effect = [ - MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")), - MagicMock(scalar_one_or_none=MagicMock(return_value="active")), - ] - ctx = _ctx( - subject_type=SubjectType.ACCOUNT, - account_id="a1", - tenant_id="t1", - cached_verified_tenants={}, - token_hash="hash-1", - ) - step(ctx) # no raise - mock_record.assert_called_once_with("hash-1", "t1", True) diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_mount.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_mount.py deleted file mode 100644 index 8c5ad38a16..0000000000 --- 
a/api/tests/unit_tests/controllers/openapi/auth/test_step_mount.py +++ /dev/null @@ -1,77 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import patch - -import pytest -from werkzeug.exceptions import Unauthorized - -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.steps import CallerMount -from controllers.openapi.auth.strategies import AccountMounter, EndUserMounter -from core.app.entities.app_invoke_entities import InvokeFrom -from libs.oauth_bearer import SubjectType - - -def _ctx(*, subject_type, account_id=None, subject_email=None): - c = Context(required_scope="apps:run") - c.subject_type = subject_type - c.account_id = account_id - c.subject_email = subject_email - c.app = SimpleNamespace(id="app1") - c.tenant = SimpleNamespace(id="t1") - return c - - -@patch("controllers.openapi.auth.strategies._login_as") -@patch("controllers.openapi.auth.strategies.db") -def test_account_mounter(db, login): - account = SimpleNamespace() - db.session.get.return_value = account - ctx = _ctx(subject_type=SubjectType.ACCOUNT, account_id="acc1") - AccountMounter().mount(ctx) - assert ctx.caller is account - assert ctx.caller.current_tenant is ctx.tenant - assert ctx.caller_kind == "account" - login.assert_called_once_with(account) - - -@patch("controllers.openapi.auth.strategies._login_as") -@patch("controllers.openapi.auth.strategies.EndUserService") -def test_end_user_mounter(svc, login): - eu = SimpleNamespace() - svc.get_or_create_end_user_by_type.return_value = eu - ctx = _ctx(subject_type=SubjectType.EXTERNAL_SSO, subject_email="a@x.com") - EndUserMounter().mount(ctx) - svc.get_or_create_end_user_by_type.assert_called_once_with( - InvokeFrom.OPENAPI, - tenant_id="t1", - app_id="app1", - user_id="a@x.com", - ) - assert ctx.caller is eu - assert ctx.caller_kind == "end_user" - - -def test_caller_mount_dispatches_by_subject_type(): - seen = {} - - class Fake: - def __init__(self, st, tag): - self._st, self._tag = st, tag - 
- def applies_to(self, st): - return st == self._st - - def mount(self, ctx): - seen["who"] = self._tag - - cm = CallerMount( - Fake(SubjectType.ACCOUNT, "acct"), - Fake(SubjectType.EXTERNAL_SSO, "sso"), - ) - cm(_ctx(subject_type=SubjectType.EXTERNAL_SSO)) - assert seen == {"who": "sso"} - - -def test_caller_mount_raises_when_none_applies(): - with pytest.raises(Unauthorized): - CallerMount()(_ctx(subject_type=SubjectType.ACCOUNT)) diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_scope.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_scope.py deleted file mode 100644 index b4adbacd1e..0000000000 --- a/api/tests/unit_tests/controllers/openapi/auth/test_step_scope.py +++ /dev/null @@ -1,25 +0,0 @@ -import pytest -from werkzeug.exceptions import Forbidden - -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.steps import ScopeCheck - - -def _ctx(scopes, required): - c = Context(required_scope=required) - c.scopes = frozenset(scopes) - return c - - -def test_scope_check_passes_on_full(): - ScopeCheck()(_ctx({"full"}, "apps:run")) - - -def test_scope_check_passes_on_explicit_match(): - ScopeCheck()(_ctx({"apps:run"}, "apps:run")) - - -def test_scope_check_rejects_when_missing(): - with pytest.raises(Forbidden) as exc: - ScopeCheck()(_ctx({"apps:read"}, "apps:run")) - assert "insufficient_scope" in str(exc.value.description) diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_surface_gate.py b/api/tests/unit_tests/controllers/openapi/auth/test_surface_gate.py deleted file mode 100644 index f3b49b18da..0000000000 --- a/api/tests/unit_tests/controllers/openapi/auth/test_surface_gate.py +++ /dev/null @@ -1,239 +0,0 @@ -"""Surface gate tests. 
- -The gate has two attachment forms — decorator (`accept_subjects`) and -pipeline step (`SurfaceCheck`) — and both must: -- 403 on mismatched subject type with a canonical-path hint -- emit `openapi.wrong_surface_denied` once with the right payload -- pass-through on match -- raise RuntimeError (not 403) if the auth ContextVar is unset — that's - a wiring bug, not a user-driven failure - -Identity is published via `libs.oauth_bearer.set_auth_ctx` / read with -`try_get_auth_ctx`. Tests wrap the publish in a `_publish_auth_ctx` -context manager so the ContextVar resets even when an assertion fails; -that keeps state from leaking into the next test on the same worker. -""" - -from __future__ import annotations - -import uuid -from collections.abc import Iterator -from contextlib import contextmanager -from datetime import UTC, datetime -from unittest.mock import patch - -import pytest -from flask import Flask -from werkzeug.exceptions import Forbidden - -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.steps import SurfaceCheck -from controllers.openapi.auth.surface_gate import _coerce_subject_type, accept_subjects, check_surface -from libs.oauth_bearer import AuthContext, Scope, SubjectType, reset_auth_ctx, set_auth_ctx - - -@contextmanager -def _publish_auth_ctx(ctx: AuthContext) -> Iterator[None]: - token = set_auth_ctx(ctx) - try: - yield - finally: - reset_auth_ctx(token) - - -def _account_ctx() -> AuthContext: - return AuthContext( - subject_type=SubjectType.ACCOUNT, - subject_email="user@example.com", - subject_issuer="dify:account", - account_id=uuid.uuid4(), - client_id="difyctl", - scopes=frozenset({Scope.FULL}), - token_id=uuid.uuid4(), - source="oauth_account", - expires_at=datetime.now(UTC), - token_hash="h1", - verified_tenants={}, - ) - - -def _sso_ctx() -> AuthContext: - return AuthContext( - subject_type=SubjectType.EXTERNAL_SSO, - subject_email="sso@partner.com", - subject_issuer="https://idp.partner.com", - 
account_id=None, - client_id="difyctl", - scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED_EXTERNAL}), - token_id=uuid.uuid4(), - source="oauth_external_sso", - expires_at=datetime.now(UTC), - token_hash="h2", - verified_tenants={}, - ) - - -# --------------------------------------------------------------------------- -# check_surface — shared core -# --------------------------------------------------------------------------- - - -def test_check_surface_passes_when_subject_in_accepted(): - app = Flask(__name__) - with app.test_request_context("/openapi/v1/apps"), _publish_auth_ctx(_account_ctx()): - check_surface(frozenset({SubjectType.ACCOUNT})) # no raise - - -def test_check_surface_rejects_on_wrong_subject_and_emits_audit(): - app = Flask(__name__) - with app.test_request_context("/openapi/v1/permitted-external-apps"), _publish_auth_ctx(_account_ctx()): - with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit: - with pytest.raises(Forbidden) as exc: - check_surface(frozenset({SubjectType.EXTERNAL_SSO})) - assert "wrong_surface" in exc.value.description - # canonical-path hint should point at the caller's surface, - # not the surface they were rejected from - assert "/openapi/v1/apps" in exc.value.description - emit.assert_called_once() - kwargs = emit.call_args.kwargs - assert kwargs["subject_type"] == SubjectType.ACCOUNT.value - assert kwargs["attempted_path"] == "/openapi/v1/permitted-external-apps" - assert kwargs["client_id"] == "difyctl" - assert kwargs["token_id"] is not None - - -def test_check_surface_rejects_sso_on_account_surface(): - app = Flask(__name__) - with app.test_request_context("/openapi/v1/apps"), _publish_auth_ctx(_sso_ctx()): - with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit: - with pytest.raises(Forbidden): - check_surface(frozenset({SubjectType.ACCOUNT})) - kwargs = emit.call_args.kwargs - assert kwargs["subject_type"] == SubjectType.EXTERNAL_SSO.value - - -def 
test_check_surface_runtime_error_when_auth_ctx_missing(): - """Missing auth ContextVar means the bearer layer didn't run — wiring - bug, not a user-driven failure. Surface as RuntimeError (loud) so a - future refactor doesn't accidentally let a route skip authentication - and return a 403 that looks identical to a legitimate wrong-surface - deny. - """ - app = Flask(__name__) - with app.test_request_context("/openapi/v1/apps"): - with pytest.raises(RuntimeError): - check_surface(frozenset({SubjectType.ACCOUNT})) - - -# --------------------------------------------------------------------------- -# @accept_subjects — decorator form -# --------------------------------------------------------------------------- - - -def _make_app() -> Flask: - app = Flask(__name__) - - @app.route("/account-only") - @accept_subjects(SubjectType.ACCOUNT) - def _account_only(): - return "ok" - - @app.route("/external-only") - @accept_subjects(SubjectType.EXTERNAL_SSO) - def _external_only(): - return "ok" - - return app - - -def test_accept_subjects_decorator_passes_on_match(): - app = _make_app() - with app.test_request_context("/account-only"), _publish_auth_ctx(_account_ctx()): - # Re-route through the decorated function by reaching for view_function - view = app.view_functions["_account_only"] - assert view() == "ok" - - -def test_accept_subjects_decorator_403_on_miss(): - app = _make_app() - with app.test_request_context("/external-only"), _publish_auth_ctx(_account_ctx()): - view = app.view_functions["_external_only"] - with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface"): - with pytest.raises(Forbidden): - view() - - -# --------------------------------------------------------------------------- -# SurfaceCheck — pipeline step form -# --------------------------------------------------------------------------- - - -def _pipeline_ctx() -> Context: - # SurfaceCheck reads ``request.path`` from Flask's global request — set up - # via ``app.test_request_context`` in the 
calling tests — not from Context. - return Context(required_scope=Scope.APPS_RUN) - - -def test_surface_check_passes_on_match(): - step = SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})) - app = Flask(__name__) - with app.test_request_context("/openapi/v1/apps/x/run"), _publish_auth_ctx(_account_ctx()): - step(_pipeline_ctx()) # no raise - - -def test_surface_check_rejects_on_miss_and_emits_audit(): - step = SurfaceCheck(accepted=frozenset({SubjectType.EXTERNAL_SSO})) - app = Flask(__name__) - with app.test_request_context("/openapi/v1/apps/x/run"), _publish_auth_ctx(_account_ctx()): - with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit: - with pytest.raises(Forbidden): - step(_pipeline_ctx()) - emit.assert_called_once() - - -# --------------------------------------------------------------------------- -# _coerce_subject_type — normalises whatever sat on ctx.subject_type -# --------------------------------------------------------------------------- -# -# The gate reads `ctx.subject_type` via `getattr(..., None)`, so the value -# could be a real enum (happy path), a raw string (e.g. rehydrated from a -# dict-shaped context), `None` (attribute missing), or something unexpected -# from a buggy upstream. The coercer must collapse all of that to -# `SubjectType | None` so `check_surface` can do a clean set-membership -# check and emit a clean audit payload. - - -def test_coerce_subject_type_returns_none_for_none(): - assert _coerce_subject_type(None) is None - - -def test_coerce_subject_type_returns_enum_instance_unchanged(): - # Identity matters: we don't want to round-trip through the string - # constructor for an already-valid enum. 
- assert _coerce_subject_type(SubjectType.ACCOUNT) is SubjectType.ACCOUNT - assert _coerce_subject_type(SubjectType.EXTERNAL_SSO) is SubjectType.EXTERNAL_SSO - - -@pytest.mark.parametrize( - ("raw", "expected"), - [ - ("account", SubjectType.ACCOUNT), - ("external_sso", SubjectType.EXTERNAL_SSO), - ], -) -def test_coerce_subject_type_parses_known_strings(raw: str, expected: SubjectType): - assert _coerce_subject_type(raw) is expected - - -def test_coerce_subject_type_raises_on_unknown_string(): - # Unknown strings reach `SubjectType(raw)` which raises ValueError. - # We surface that loudly rather than silently returning None, because - # a string that *looks* like a subject type but isn't is almost - # certainly an upstream bug worth catching. - with pytest.raises(ValueError): - _coerce_subject_type("not_a_subject") - - -@pytest.mark.parametrize("raw", [123, 1.5, b"account", object(), ["account"], {"account"}]) -def test_coerce_subject_type_returns_none_for_non_string_non_enum(raw: object): - assert _coerce_subject_type(raw) is None diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_verify.py b/api/tests/unit_tests/controllers/openapi/auth/test_verify.py new file mode 100644 index 0000000000..c7e0cd7402 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_verify.py @@ -0,0 +1,142 @@ +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, Unauthorized + +from controllers.openapi.auth.data import AuthData +from controllers.openapi.auth.verify import ( + check_acl, + check_app_access, + check_membership, + check_private_app_permission, + check_scope, +) +from libs.oauth_bearer import Scope, TokenType +from models.account import Tenant +from models.model import App +from services.enterprise.enterprise_service import WebAppAccessMode + + +def _data(**kwargs) -> AuthData: + defaults: dict = {"token_type": TokenType.OAUTH_ACCOUNT, "token_hash": "hash", "scopes": 
frozenset({Scope.FULL})} + defaults.update(kwargs) + return AuthData(**defaults) + + +def test_check_scope_passes_when_required_is_none(): + check_scope(_data(required_scope=None)) + + +def test_check_scope_passes_when_full_in_scopes(): + check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.FULL}))) + + +def test_check_scope_passes_when_exact_scope_present(): + check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.APPS_RUN}))) + + +def test_check_scope_raises_forbidden_when_scope_missing(): + with pytest.raises(Forbidden, match="insufficient_scope"): + check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.APPS_READ}))) + + +def test_check_membership_raises_unauthorized_when_tenant_none(): + with pytest.raises(Unauthorized): + check_membership(_data(tenant=None)) + + +def test_check_membership_calls_check_workspace_membership(): + tenant = MagicMock(spec=Tenant) + tenant.id = "tenant-1" + data = _data( + account_id=uuid.uuid4(), + token_hash="myhash", + tenants={"tenant-1": True}, + tenant=tenant, + ) + with patch("controllers.openapi.auth.verify.check_workspace_membership") as mock_cwm: + check_membership(data) + mock_cwm.assert_called_once_with( + account_id=data.account_id, + tenant_id="tenant-1", + token_hash="myhash", + membership_cache=data.tenants, + ) + + +def test_check_app_access_passes_when_tenant_none(): + check_app_access(_data(tenant=None)) + + +def test_check_app_access_passes_when_member(): + tenant = MagicMock(spec=Tenant) + tenant.id = "t1" + data = _data(account_id=uuid.uuid4(), tenant=tenant) + with patch("controllers.openapi.auth.verify.TenantService.account_belongs_to_tenant", return_value=True): + check_app_access(data) + + +def test_check_app_access_raises_when_not_member(): + tenant = MagicMock(spec=Tenant) + tenant.id = "t1" + data = _data(account_id=uuid.uuid4(), tenant=tenant) + with patch("controllers.openapi.auth.verify.TenantService.account_belongs_to_tenant", 
return_value=False): + with pytest.raises(Forbidden, match="subject_no_app_access"): + check_app_access(data) + + +def test_check_acl_raises_when_app_or_mode_missing(): + with pytest.raises(Forbidden): + check_acl(_data(app=None, app_access_mode=None)) + + +def test_check_acl_account_allowed_for_public(): + app = MagicMock(spec=App) + data = _data(token_type=TokenType.OAUTH_ACCOUNT, app=app, app_access_mode=WebAppAccessMode.PUBLIC) + check_acl(data) + + +def test_check_acl_external_sso_blocked_for_private(): + app = MagicMock(spec=App) + data = _data( + token_type=TokenType.OAUTH_EXTERNAL_SSO, + app=app, + app_access_mode=WebAppAccessMode.PRIVATE, + ) + with pytest.raises(Forbidden, match="subject_not_allowed_for_access_mode"): + check_acl(data) + + +def test_check_acl_external_sso_allowed_for_sso_verified(): + app = MagicMock(spec=App) + data = _data( + token_type=TokenType.OAUTH_EXTERNAL_SSO, + app=app, + app_access_mode=WebAppAccessMode.SSO_VERIFIED, + ) + check_acl(data) + + +def test_check_private_app_permission_raises_when_app_none(): + with pytest.raises(Forbidden): + check_private_app_permission(_data(app=None)) + + +def test_check_private_app_permission_raises_when_user_not_allowed(): + app = MagicMock(spec=App) + app.id = "app-1" + data = _data(account_id=uuid.uuid4(), app=app) + target = "controllers.openapi.auth.verify.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp" + with patch(target, return_value=False): + with pytest.raises(Forbidden, match="user_not_allowed_for_private_app"): + check_private_app_permission(data) + + +def test_check_private_app_permission_passes_when_allowed(): + app = MagicMock(spec=App) + app.id = "app-1" + data = _data(account_id=uuid.uuid4(), app=app) + target = "controllers.openapi.auth.verify.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp" + with patch(target, return_value=True): + check_private_app_permission(data) diff --git a/api/tests/unit_tests/controllers/openapi/conftest.py 
b/api/tests/unit_tests/controllers/openapi/conftest.py index 38dae79a11..18b3b2fabf 100644 --- a/api/tests/unit_tests/controllers/openapi/conftest.py +++ b/api/tests/unit_tests/controllers/openapi/conftest.py @@ -1,20 +1,36 @@ +import uuid + import pytest from flask import Flask from controllers.openapi import bp as openapi_bp -from controllers.openapi.auth.pipeline import Pipeline +from controllers.openapi.auth.data import AuthData +from controllers.openapi.auth.pipeline import PipelineRouter +from libs.oauth_bearer import Scope, TokenType + + +def _stub_execute(self, args, kwargs, view, *, scope=None, allowed_token_types=None, edition=None): + """Bypass all auth logic; inject minimal AuthData and call the view directly.""" + kwargs["auth_data"] = AuthData( + token_type=TokenType.OAUTH_ACCOUNT, + account_id=uuid.uuid4(), + token_hash="test", + token_id=uuid.uuid4(), + scopes=frozenset({Scope.FULL}), + required_scope=scope, + ) + return view(*args, **kwargs) @pytest.fixture def bypass_pipeline(monkeypatch): - """Stub Pipeline.run so endpoint decoration does not invoke real auth. + """Stub PipelineRouter._execute so endpoints skip real auth at request time. - Module-level @OAUTH_BEARER_PIPELINE.guard(...) captures the real - pipeline at import time; mocking the module attribute does not undo - that. Patching Pipeline.run on the class is the bypass that actually - works. + Module-level @auth_router.guard(...) captures the real router at import + time — patching guard itself does nothing. Patching _execute on the class + is the seam that fires at request time. 
""" - monkeypatch.setattr(Pipeline, "run", lambda self, ctx: None) + monkeypatch.setattr(PipelineRouter, "_execute", _stub_execute) @pytest.fixture diff --git a/api/tests/unit_tests/controllers/openapi/test_account.py b/api/tests/unit_tests/controllers/openapi/test_account.py index 15624305a3..f73dc5c0cc 100644 --- a/api/tests/unit_tests/controllers/openapi/test_account.py +++ b/api/tests/unit_tests/controllers/openapi/test_account.py @@ -86,7 +86,7 @@ def test_subject_match_for_account_filters_by_account_id(): """Account subject scopes queries via account_id.""" import uuid as _uuid - from libs.oauth_bearer import AuthContext, SubjectType + from libs.oauth_bearer import AuthContext, SubjectType, TokenType from services.oauth_device_flow import subject_match_clauses aid = _uuid.uuid4() @@ -98,7 +98,7 @@ def test_subject_match_for_account_filters_by_account_id(): client_id="difyctl", scopes=frozenset({"full"}), token_id=_uuid.uuid4(), - source="oauth_account", + token_type=TokenType.OAUTH_ACCOUNT, expires_at=None, token_hash="h1", verified_tenants={}, @@ -116,7 +116,7 @@ def test_subject_match_for_external_sso_filters_by_email_and_issuer(): """ import uuid as _uuid - from libs.oauth_bearer import AuthContext, SubjectType + from libs.oauth_bearer import AuthContext, SubjectType, TokenType from services.oauth_device_flow import subject_match_clauses ctx = AuthContext( @@ -127,7 +127,7 @@ def test_subject_match_for_external_sso_filters_by_email_and_issuer(): client_id="difyctl", scopes=frozenset({"apps:run"}), token_id=_uuid.uuid4(), - source="oauth_external_sso", + token_type=TokenType.OAUTH_EXTERNAL_SSO, expires_at=None, token_hash="h1", verified_tenants={}, diff --git a/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py b/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py index 8db5033704..8933533af0 100644 --- a/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py +++ 
b/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py @@ -57,7 +57,11 @@ def test_stop_task_endpoint_registered(openapi_app): def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, monkeypatch): + import uuid + from controllers.openapi.app_run import AppRunTaskStopApi + from controllers.openapi.auth.data import AuthData + from libs.oauth_bearer import Scope, TokenType queue_mock = Mock() graph_mock = Mock() @@ -69,15 +73,23 @@ def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, mo monkeypatch.setattr(run_module, "GraphEngineManager", graph_mock) monkeypatch.setattr(run_module, "redis_client", object()) + auth_data = AuthData.model_construct( + token_type=TokenType.OAUTH_ACCOUNT, + account_id=uuid.uuid4(), + token_hash="test", + scopes=frozenset({Scope.FULL}), + app=SimpleNamespace(id="app-1", tenant_id="t-1"), + caller=SimpleNamespace(id="acct-1"), + caller_kind="account", + ) + api = AppRunTaskStopApi() with app.test_request_context("/openapi/v1/apps/app-1/tasks/task-1/stop", method="POST"): result = api.post.__wrapped__( api, app_id="app-1", task_id="task-1", - app_model=SimpleNamespace(id="app-1", tenant_id="t-1"), - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=auth_data, ) queue_mock.set_stop_flag_no_user_check.assert_called_once_with("task-1") diff --git a/api/tests/unit_tests/controllers/openapi/test_human_input_form.py b/api/tests/unit_tests/controllers/openapi/test_human_input_form.py index 42ecfc5eb2..52fd0f89d5 100644 --- a/api/tests/unit_tests/controllers/openapi/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/openapi/test_human_input_form.py @@ -4,6 +4,7 @@ from __future__ import annotations import json import sys +import uuid from datetime import UTC, datetime from types import SimpleNamespace from unittest.mock import Mock @@ -11,9 +12,23 @@ from unittest.mock import Mock import pytest from werkzeug.exceptions import NotFound +from 
controllers.openapi.auth.data import AuthData +from libs.oauth_bearer import Scope, TokenType from models.human_input import RecipientType +def _make_auth_data(app_model, caller, caller_kind): + return AuthData.model_construct( + token_type=TokenType.OAUTH_ACCOUNT, + account_id=uuid.uuid4(), + token_hash="test", + scopes=frozenset({Scope.FULL}), + app=app_model, + caller=caller, + caller_kind=caller_kind, + ) + + class TestOpenApiHumanInputFormGet: def test_get_success(self, app, bypass_pipeline, monkeypatch): from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi @@ -43,15 +58,14 @@ class TestOpenApiHumanInputFormGet: api = OpenApiWorkflowHumanInputFormApi() app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + caller = SimpleNamespace(id="acct-1") with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"): resp = api.get.__wrapped__( api, app_id="app-1", form_token="tok-1", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) payload = json.loads(resp.get_data(as_text=True)) @@ -71,6 +85,7 @@ class TestOpenApiHumanInputFormGet: api = OpenApiWorkflowHumanInputFormApi() app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + caller = SimpleNamespace(id="acct-1") with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/bad"): with pytest.raises(NotFound): @@ -78,9 +93,7 @@ class TestOpenApiHumanInputFormGet: api, app_id="app-1", form_token="bad", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) def test_get_form_wrong_app(self, app, bypass_pipeline, monkeypatch): @@ -97,6 +110,7 @@ class TestOpenApiHumanInputFormGet: api = OpenApiWorkflowHumanInputFormApi() app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + caller = SimpleNamespace(id="acct-1") with 
app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"): with pytest.raises(NotFound): @@ -104,9 +118,7 @@ class TestOpenApiHumanInputFormGet: api, app_id="app-1", form_token="tok-1", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) def test_get_form_wrong_surface(self, app, bypass_pipeline, monkeypatch): @@ -126,6 +138,7 @@ class TestOpenApiHumanInputFormGet: api = OpenApiWorkflowHumanInputFormApi() app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + caller = SimpleNamespace(id="acct-1") with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"): with pytest.raises(NotFound): @@ -133,9 +146,7 @@ class TestOpenApiHumanInputFormGet: api, app_id="app-1", form_token="tok-1", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) @@ -172,9 +183,7 @@ class TestOpenApiHumanInputFormPost: api, app_id="app-1", form_token="tok-1", - app_model=app_model, - caller=caller, - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) service_mock.submit_form_by_token.assert_called_once_with( @@ -211,9 +220,7 @@ class TestOpenApiHumanInputFormPost: api, app_id="app-1", form_token="tok-1", - app_model=app_model, - caller=caller, - caller_kind="end_user", + auth_data=_make_auth_data(app_model, caller, "end_user"), ) service_mock.submit_form_by_token.assert_called_once_with( diff --git a/api/tests/unit_tests/controllers/openapi/test_workflow_events_openapi.py b/api/tests/unit_tests/controllers/openapi/test_workflow_events_openapi.py index 78b85460b3..78f2d0f20d 100644 --- a/api/tests/unit_tests/controllers/openapi/test_workflow_events_openapi.py +++ b/api/tests/unit_tests/controllers/openapi/test_workflow_events_openapi.py @@ -3,15 +3,30 @@ from __future__ import annotations import sys +import uuid from types import 
SimpleNamespace from unittest.mock import Mock import pytest from werkzeug.exceptions import NotFound +from controllers.openapi.auth.data import AuthData +from libs.oauth_bearer import Scope, TokenType from models.enums import CreatorUserRole +def _make_auth_data(app_model, caller, caller_kind): + return AuthData.model_construct( + token_type=TokenType.OAUTH_ACCOUNT, + account_id=uuid.uuid4(), + token_hash="test", + scopes=frozenset({Scope.FULL}), + app=app_model, + caller=caller, + caller_kind=caller_kind, + ) + + def _make_workflow_run( *, app_id="app-1", @@ -50,6 +65,7 @@ class TestOpenApiWorkflowEventsApi: from models.model import AppMode app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) + caller = SimpleNamespace(id="acct-1") with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"): with pytest.raises(NotFound): @@ -57,9 +73,7 @@ class TestOpenApiWorkflowEventsApi: api, app_id="app-1", task_id="wf-run-1", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) def test_not_found_when_run_belongs_to_different_app(self, app, bypass_pipeline, monkeypatch): @@ -77,6 +91,7 @@ class TestOpenApiWorkflowEventsApi: from models.model import AppMode app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) + caller = SimpleNamespace(id="acct-1") with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"): with pytest.raises(NotFound): @@ -84,9 +99,7 @@ class TestOpenApiWorkflowEventsApi: api, app_id="app-1", task_id="wf-run-1", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) def test_account_caller_checks_created_by_account(self, app, bypass_pipeline, monkeypatch): @@ -115,6 +128,7 @@ class TestOpenApiWorkflowEventsApi: from models.model import AppMode app_model = 
SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) + caller = SimpleNamespace(id="acct-1") api = self._get_api() with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"): @@ -123,9 +137,7 @@ class TestOpenApiWorkflowEventsApi: api, app_id="app-1", task_id="wf-run-1", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) assert resp.mimetype == "text/event-stream" @@ -143,6 +155,7 @@ class TestOpenApiWorkflowEventsApi: from models.model import AppMode app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) + caller = SimpleNamespace(id="acct-1") api = self._get_api() with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"): @@ -151,9 +164,7 @@ class TestOpenApiWorkflowEventsApi: api, app_id="app-1", task_id="wf-run-1", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) def test_end_user_caller_checks_created_by_end_user(self, app, bypass_pipeline, monkeypatch): @@ -179,6 +190,7 @@ class TestOpenApiWorkflowEventsApi: from models.model import AppMode app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) + caller = SimpleNamespace(id="eu-1") api = self._get_api() with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"): @@ -186,9 +198,7 @@ class TestOpenApiWorkflowEventsApi: api, app_id="app-1", task_id="wf-run-1", - app_model=app_model, - caller=SimpleNamespace(id="eu-1"), - caller_kind="end_user", + auth_data=_make_auth_data(app_model, caller, "end_user"), ) assert resp.mimetype == "text/event-stream" @@ -222,6 +232,7 @@ class TestOpenApiWorkflowEventsApi: from models.model import AppMode app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) + caller = SimpleNamespace(id="acct-1") api = self._get_api() with 
app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"): @@ -229,9 +240,7 @@ class TestOpenApiWorkflowEventsApi: api, app_id="app-1", task_id="wf-run-1", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) assert resp.mimetype == "text/event-stream" chunks = list(resp.response) diff --git a/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py b/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py new file mode 100644 index 0000000000..6e32487348 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py @@ -0,0 +1,928 @@ +"""Member endpoints under /openapi/v1/workspaces//... + +Coverage: +- Route registration (5 endpoints across 4 URL patterns) +- Body validation lands at 400 (per spec — not Pydantic's default 422) +- Domain exception → HTTP code mapping is preserved with the service's + original message (so CLI users see what the console user sees) +- Response shape matches the Pydantic models + +Auth-pipeline plumbing is bypassed via the `bypass_pipeline` fixture from +conftest.py; the bearer identity is seeded into the openapi auth ContextVar +via `_seed` (the slot `validate_bearer` publishes), and the role gate's DB +lookup is mocked. Tests that exercise endpoint *bodies* skip the decorators +via ``__wrapped__`` since those layers are covered in `auth/test_role_gate.py`. 
+""" + +from __future__ import annotations + +import builtins +import json +import sys +import uuid +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock + +import pytest +from flask import Flask +from flask.views import MethodView +from pydantic import ValidationError +from werkzeug.exceptions import BadRequest, Forbidden, NotFound + +from controllers.openapi import bp as openapi_bp +from controllers.openapi._models import MemberInvitePayload, MemberRoleUpdatePayload +from controllers.openapi.workspaces import ( + WorkspaceMemberApi, + WorkspaceMemberRoleApi, + WorkspaceMembersApi, + WorkspaceSwitchApi, +) +from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, reset_auth_ctx, set_auth_ctx +from models.account import AccountStatus, TenantAccountRole +from services.errors.account import ( + AccountAlreadyInTenantError, + AccountNotLinkTenantError, + AccountRegisterError, + CannotOperateSelfError, + MemberNotInTenantError, + NoPermissionError, + RoleAlreadyAssignedError, +) + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +# Tokens from `_seed`'s `set_auth_ctx` calls, drained after each test so a +# published identity can't leak into the next (the ContextVar is module-global +# and worker threads are reused). Seed via `_seed(...)`, never `flask.g` — +# production fills the ContextVar, nothing fills `g.auth_ctx`. 
+_seed_tokens: list = [] + + +def _seed(ctx: AuthContext) -> None: + _seed_tokens.append(set_auth_ctx(ctx)) + + +@pytest.fixture(autouse=True) +def _reset_auth_ctx(): + yield + while _seed_tokens: + reset_auth_ctx(_seed_tokens.pop()) + + +@pytest.fixture +def openapi_app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + app.register_blueprint(openapi_bp) + return app + + +def _rule(app: Flask, path: str): + return next(r for r in app.url_map.iter_rules() if r.rule == path) + + +def _auth_ctx(account_id: uuid.UUID | None = None) -> AuthContext: + return AuthContext( + subject_type=SubjectType.ACCOUNT, + subject_email="caller@example.com", + subject_issuer="dify:account", + account_id=account_id or uuid.uuid4(), + client_id="difyctl", + scopes=frozenset({Scope.FULL}), + token_id=uuid.uuid4(), + token_type=TokenType.OAUTH_ACCOUNT, + expires_at=datetime.now(UTC), + token_hash="h", + verified_tenants={}, + ) + + +def _auth_data(account_id: uuid.UUID) -> AuthData: + from controllers.openapi.auth.data import AuthData + from libs.oauth_bearer import Scope, TokenType + + return AuthData( + token_type=TokenType.OAUTH_ACCOUNT, + account_id=account_id, + token_hash="testhash", + scopes=frozenset({Scope.FULL}), + ) + + +def _account(account_id: str = "acct-1", email: str = "u@example.com") -> SimpleNamespace: + return SimpleNamespace( + id=account_id, + name="User", + email=email, + status=AccountStatus.ACTIVE, + avatar=None, + ) + + +def _tenant(tenant_id: str = "ws-1") -> SimpleNamespace: + return SimpleNamespace( + id=tenant_id, + name="WS", + status="normal", + created_at=datetime(2026, 5, 18, tzinfo=UTC), + ) + + +def _tenant_service(**overrides) -> SimpleNamespace: + """TenantService double for the workspaces module. 
+ + Read getters (`get_tenant_by_id`, `find_workspace_for_account`) delegate + to the session they're handed, so tests keep driving entity loads through + ``mock_db.session.get`` / ``.execute`` and their existing side_effect + ordering — the SQL those methods run is covered in test_account_service.py. + Domain mutators default to no-op Mocks; override per test as needed. + """ + methods: dict = { + "switch_tenant": Mock(), + "get_tenant_members": Mock(return_value=[]), + "remove_member_from_tenant": Mock(), + "update_member_role": Mock(), + "get_tenant_by_id": lambda session, tenant_id: session.get(None, tenant_id), + "find_workspace_for_account": lambda session, account_id, workspace_id: session.execute(None).first(), + } + methods.update(overrides) + return SimpleNamespace(**methods) + + +def _account_service(**overrides) -> SimpleNamespace: + """AccountService double; ``get_account_by_id`` delegates to the injected + session (see :func:`_tenant_service`).""" + methods: dict = { + "get_account_by_id": lambda session, account_id: session.get(None, account_id), + } + methods.update(overrides) + return SimpleNamespace(**methods) + + +# --------------------------------------------------------------------------- +# Route registration +# --------------------------------------------------------------------------- + + +def test_switch_route_registered(openapi_app: Flask): + rule = _rule(openapi_app, "/openapi/v1/workspaces//switch") + assert openapi_app.view_functions[rule.endpoint].view_class is WorkspaceSwitchApi + assert "POST" in rule.methods + + +def test_members_route_registered(openapi_app: Flask): + rule = _rule(openapi_app, "/openapi/v1/workspaces//members") + assert openapi_app.view_functions[rule.endpoint].view_class is WorkspaceMembersApi + assert "GET" in rule.methods + assert "POST" in rule.methods + + +def test_member_by_id_route_registered(openapi_app: Flask): + rule = _rule(openapi_app, "/openapi/v1/workspaces//members/") + assert 
openapi_app.view_functions[rule.endpoint].view_class is WorkspaceMemberApi + assert "DELETE" in rule.methods + + +def test_member_role_route_registered(openapi_app: Flask): + rule = _rule(openapi_app, "/openapi/v1/workspaces//members//role") + assert openapi_app.view_functions[rule.endpoint].view_class is WorkspaceMemberRoleApi + assert "PUT" in rule.methods + + +# --------------------------------------------------------------------------- +# Payload validation lands at 400 +# --------------------------------------------------------------------------- + + +def test_invite_payload_rejects_unknown_role(): + with pytest.raises(ValidationError): + MemberInvitePayload.model_validate({"email": "u@example.com", "role": "owner"}) + + +def test_invite_payload_rejects_bad_email(): + with pytest.raises(ValidationError): + MemberInvitePayload.model_validate({"email": "not-an-email", "role": "normal"}) + + +def test_invite_payload_rejects_extra_field(): + with pytest.raises(ValidationError): + MemberInvitePayload.model_validate({"email": "u@example.com", "role": "normal", "extra": "x"}) + + +def test_role_payload_rejects_owner(): + with pytest.raises(ValidationError): + MemberRoleUpdatePayload.model_validate({"role": "owner"}) + + +def test_role_payload_rejects_extra_field(): + with pytest.raises(ValidationError): + MemberRoleUpdatePayload.model_validate({"role": "normal", "extra": "x"}) + + +def test_validate_body_helper_maps_validation_error_to_400(app, monkeypatch): + """`_validate_body` is the centralized 400-mapper for invalid request bodies.""" + from controllers.openapi.workspaces import _validate_body + + with app.test_request_context( + "/openapi/v1/workspaces/ws-1/members", + method="POST", + data=json.dumps({"email": "u@example.com", "role": "owner"}), + content_type="application/json", + ): + with pytest.raises(BadRequest): + _validate_body(MemberInvitePayload) + + +# --------------------------------------------------------------------------- +# Switch endpoint 
behavior +# --------------------------------------------------------------------------- + + +def test_switch_returns_workspace_detail_with_current_true(app, bypass_pipeline, monkeypatch): + """Happy path: switch service is called, then the workspace+membership + row is re-queried so the returned `current` reflects post-commit state. + """ + ws_id = str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceSwitchApi() + + mock_db = MagicMock() + mock_db.session.get.return_value = _account(account_id=str(acct_id)) + membership = SimpleNamespace(role=TenantAccountRole.OWNER, current=True) + mock_db.session.execute.return_value.first.return_value = (_tenant(ws_id), membership) + + switch_mock = Mock() + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "TenantService", + _tenant_service(switch_tenant=switch_mock), + ) + monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + + with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/switch", method="POST"): + _seed(_auth_ctx(account_id=acct_id)) + body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) + + assert status == 200 + assert body["id"] == ws_id + assert body["current"] is True + assert switch_mock.called + + +def test_switch_404s_when_service_raises_account_not_link_tenant(app, bypass_pipeline, monkeypatch): + """If switch_tenant raises (e.g. 
Tenant.status != NORMAL), the body + surfaces as NotFound, not 500.""" + ws_id = str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceSwitchApi() + + mock_db = MagicMock() + mock_db.session.get.return_value = _account(account_id=str(acct_id)) + + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "TenantService", + _tenant_service(switch_tenant=Mock(side_effect=AccountNotLinkTenantError("…"))), + ) + monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + + with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/switch", method="POST"): + _seed(_auth_ctx(account_id=acct_id)) + with pytest.raises(NotFound): + api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) + + +# --------------------------------------------------------------------------- +# Members list +# --------------------------------------------------------------------------- + + +def test_members_list_returns_normalized_rows(app, bypass_pipeline, monkeypatch): + ws_id = str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceMembersApi() + + member = SimpleNamespace( + id="m-1", + name="Mia", + email="mia@example.com", + status=AccountStatus.ACTIVE, + avatar=None, + role=TenantAccountRole.ADMIN, + ) + + mock_db = MagicMock() + mock_db.session.get.return_value = _tenant(ws_id) + + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "TenantService", + _tenant_service(get_tenant_members=Mock(return_value=[member])), + ) + monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + + with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members"): + _seed(_auth_ctx(account_id=acct_id)) + body, status = api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) + + assert status == 200 + assert body["page"] == 1 + assert body["limit"] == 20 + assert body["total"] == 1 + assert body["has_more"] is False + assert 
body["data"][0]["email"] == "mia@example.com" + assert body["data"][0]["role"] == "admin" + assert body["data"][0]["status"] == "active" + + +def test_members_list_paginates_with_query_params(app, bypass_pipeline, monkeypatch): + """`?page=2&limit=2` slices service output and reports total/has_more.""" + ws_id = str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceMembersApi() + + members = [ + SimpleNamespace( + id=f"m-{i}", + name=f"User {i}", + email=f"u{i}@example.com", + status=AccountStatus.ACTIVE, + avatar=None, + role=TenantAccountRole.NORMAL, + ) + for i in range(5) + ] + + mock_db = MagicMock() + mock_db.session.get.return_value = _tenant(ws_id) + + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "TenantService", + _tenant_service(get_tenant_members=Mock(return_value=members)), + ) + monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + + with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members?page=2&limit=2"): + _seed(_auth_ctx(account_id=acct_id)) + body, status = api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) + + assert status == 200 + assert body["page"] == 2 + assert body["limit"] == 2 + assert body["total"] == 5 + assert body["has_more"] is True + assert [d["id"] for d in body["data"]] == ["m-2", "m-3"] + + +def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypatch): + """Strict (`extra='forbid'`) — typos like `?pg=2` surface as 400.""" + ws_id = str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceMembersApi() + + mock_db = MagicMock() + mock_db.session.get.return_value = _tenant(ws_id) + monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + + with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members?pg=2"): + _seed(_auth_ctx(account_id=acct_id)) + with pytest.raises(BadRequest): + api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, 
auth_data=_auth_data(acct_id)) + + +# --------------------------------------------------------------------------- +# Invite endpoint +# --------------------------------------------------------------------------- + + +def test_invite_happy_path_returns_invite_url_and_member_id(app, bypass_pipeline, monkeypatch): + ws_id = str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceMembersApi() + + invited = _account(account_id="new-1", email="new@example.com") + + mock_db = MagicMock() + # session.get is called twice: once for inviter Account, once for Tenant + mock_db.session.get.side_effect = [_account(account_id=str(acct_id)), _tenant(ws_id)] + + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "RegisterService", + SimpleNamespace(invite_new_member=Mock(return_value="tok-123")), + ) + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "AccountService", + _account_service(get_account_by_email_with_case_fallback=Mock(return_value=invited)), + ) + monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + + with app.test_request_context( + f"/openapi/v1/workspaces/{ws_id}/members", + method="POST", + data=json.dumps({"email": "NEW@example.com", "role": "normal"}), + content_type="application/json", + ): + _seed(_auth_ctx(account_id=acct_id)) + body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) + + assert status == 201 + assert body["result"] == "success" + assert body["email"] == "new@example.com" + assert body["role"] == "normal" + assert body["member_id"] == "new-1" + assert "token=tok-123" in body["invite_url"] + assert "email=new%40example.com" in body["invite_url"] + assert body["tenant_id"] == ws_id + + +def _features( + *, + billing_enabled: bool = False, + members_size: int = 0, + members_limit: int = 0, + workspace_members_enabled: bool = False, + workspace_members_size: int = 0, + workspace_members_limit: int = 0, +) -> SimpleNamespace: 
+ """Build a feature object matching the surface `_check_member_invite_quota` + reads: `.billing.enabled`, `.members.{size,limit}`, + `.workspace_members.{enabled, is_available(N)}`. + + Defaults model CE (both flags off, both caps inert). + """ + + def _is_available(n: int) -> bool: + return workspace_members_size + n <= workspace_members_limit + + return SimpleNamespace( + billing=SimpleNamespace(enabled=billing_enabled), + members=SimpleNamespace(size=members_size, limit=members_limit), + workspace_members=SimpleNamespace( + enabled=workspace_members_enabled, + size=workspace_members_size, + limit=workspace_members_limit, + is_available=_is_available, + ), + ) + + +def _invite_request(app, ws_id: str, acct_id: uuid.UUID): + return app.test_request_context( + f"/openapi/v1/workspaces/{ws_id}/members", + method="POST", + data=json.dumps({"email": "new@example.com", "role": "normal"}), + content_type="application/json", + ) + + +def test_invite_blocked_by_saas_members_cap(app, bypass_pipeline, monkeypatch): + """SaaS billing plan member cap → 403 with `members.limit_exceeded`. + + Verifies the envelope shape the CLI error-mapper relies on (code + + message + hint on the wire body). 
+ """ + ws_id = str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceMembersApi() + + mock_db = MagicMock() + mock_db.session.get.side_effect = [_account(account_id=str(acct_id)), _tenant(ws_id)] + + invite_mock = Mock() + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "RegisterService", + SimpleNamespace(invite_new_member=invite_mock), + ) + monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "FeatureService", + SimpleNamespace( + get_features=Mock( + return_value=_features(billing_enabled=True, members_size=10, members_limit=10), + ), + ), + ) + + with _invite_request(app, ws_id, acct_id): + _seed(_auth_ctx(account_id=acct_id)) + with pytest.raises(Forbidden) as exc_info: + api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) + + body = exc_info.value.response.json + assert body["code"] == "members.limit_exceeded" + assert "Subscription member limit" in body["message"] + assert body["hint"] + invite_mock.assert_not_called() + + +def test_invite_blocked_by_ee_workspace_members_license(app, bypass_pipeline, monkeypatch): + """EE License workspace_members cap → 403 with `workspace_members.license_exceeded`. + + Note: billing.enabled is False (EE without SaaS billing); only the + license cap fires. 
+ """ + ws_id = str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceMembersApi() + + mock_db = MagicMock() + mock_db.session.get.side_effect = [_account(account_id=str(acct_id)), _tenant(ws_id)] + + invite_mock = Mock() + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "RegisterService", + SimpleNamespace(invite_new_member=invite_mock), + ) + monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "FeatureService", + SimpleNamespace( + get_features=Mock( + return_value=_features( + workspace_members_enabled=True, + workspace_members_size=5, + workspace_members_limit=5, + ), + ), + ), + ) + + with _invite_request(app, ws_id, acct_id): + _seed(_auth_ctx(account_id=acct_id)) + with pytest.raises(Forbidden) as exc_info: + api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) + + body = exc_info.value.response.json + assert body["code"] == "workspace_members.license_exceeded" + assert "license" in body["message"].lower() + assert body["hint"] + invite_mock.assert_not_called() + + +def test_invite_ce_passes_when_both_caps_disabled(app, bypass_pipeline, monkeypatch): + """CE deployment (no billing, no license) → quota gate is a no-op, + invite proceeds normally.""" + ws_id = str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceMembersApi() + + invited = _account(account_id="new-1", email="new@example.com") + mock_db = MagicMock() + mock_db.session.get.side_effect = [_account(account_id=str(acct_id)), _tenant(ws_id)] + + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "RegisterService", + SimpleNamespace(invite_new_member=Mock(return_value="tok-ce")), + ) + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "AccountService", + _account_service(get_account_by_email_with_case_fallback=Mock(return_value=invited)), + ) + 
monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "FeatureService", + SimpleNamespace(get_features=Mock(return_value=_features())), # all defaults + ) + + with _invite_request(app, ws_id, acct_id): + _seed(_auth_ctx(account_id=acct_id)) + body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) + + assert status == 201 + assert body["email"] == "new@example.com" + + +def test_invite_400_when_already_in_tenant(app, bypass_pipeline, monkeypatch): + ws_id = str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceMembersApi() + + mock_db = MagicMock() + mock_db.session.get.side_effect = [_account(account_id=str(acct_id)), _tenant(ws_id)] + + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "RegisterService", + SimpleNamespace(invite_new_member=Mock(side_effect=AccountAlreadyInTenantError("already in tenant"))), + ) + monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + + with app.test_request_context( + f"/openapi/v1/workspaces/{ws_id}/members", + method="POST", + data=json.dumps({"email": "u@example.com", "role": "normal"}), + content_type="application/json", + ): + _seed(_auth_ctx(account_id=acct_id)) + with pytest.raises(BadRequest): + api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) + + +# --------------------------------------------------------------------------- +# Delete member +# --------------------------------------------------------------------------- + + +def test_delete_member_happy_path(app, bypass_pipeline, monkeypatch): + ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceMemberApi() + + mock_db = MagicMock() + mock_db.session.get.side_effect = [ + _account(account_id=str(acct_id)), # operator + _tenant(ws_id), # tenant + _account(account_id=member_id), # 
target member + ] + + remove_mock = Mock() + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "TenantService", + _tenant_service(remove_member_from_tenant=remove_mock), + ) + monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + + with app.test_request_context( + f"/openapi/v1/workspaces/{ws_id}/members/{member_id}", + method="DELETE", + ): + _seed(_auth_ctx(account_id=acct_id)) + body, status = api.delete.__wrapped__.__wrapped__( + api, workspace_id=ws_id, member_id=member_id, auth_data=_auth_data(acct_id) + ) + + assert status == 200 + assert body == {"result": "success"} + assert remove_mock.called + + +@pytest.mark.parametrize( + ("exc", "expected"), + [ + (CannotOperateSelfError("cannot operate self"), BadRequest), + (NoPermissionError("no permission"), BadRequest), + (MemberNotInTenantError("not in tenant"), NotFound), + ], +) +def test_delete_member_exception_mapping(app, bypass_pipeline, monkeypatch, exc, expected): + ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceMemberApi() + + mock_db = MagicMock() + mock_db.session.get.side_effect = [ + _account(account_id=str(acct_id)), + _tenant(ws_id), + _account(account_id=member_id), + ] + + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "TenantService", + _tenant_service(remove_member_from_tenant=Mock(side_effect=exc)), + ) + monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + + with app.test_request_context( + f"/openapi/v1/workspaces/{ws_id}/members/{member_id}", + method="DELETE", + ): + _seed(_auth_ctx(account_id=acct_id)) + with pytest.raises(expected): + api.delete.__wrapped__.__wrapped__( + api, + workspace_id=ws_id, + member_id=member_id, + auth_data=_auth_data(acct_id), + ) + + +def test_delete_member_404_when_member_missing(app, bypass_pipeline, monkeypatch): + ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4()) + acct_id = uuid.uuid4() + api 
= WorkspaceMemberApi() + + mock_db = MagicMock() + mock_db.session.get.side_effect = [ + _account(account_id=str(acct_id)), + _tenant(ws_id), + None, # member not found + ] + monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + + with app.test_request_context( + f"/openapi/v1/workspaces/{ws_id}/members/{member_id}", + method="DELETE", + ): + _seed(_auth_ctx(account_id=acct_id)) + with pytest.raises(NotFound): + api.delete.__wrapped__.__wrapped__( + api, + workspace_id=ws_id, + member_id=member_id, + auth_data=_auth_data(acct_id), + ) + + +# --------------------------------------------------------------------------- +# Update role +# --------------------------------------------------------------------------- + + +def test_update_role_happy_path(app, bypass_pipeline, monkeypatch): + ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceMemberRoleApi() + + mock_db = MagicMock() + mock_db.session.get.side_effect = [ + _account(account_id=str(acct_id)), + _tenant(ws_id), + _account(account_id=member_id), + ] + + update_mock = Mock() + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "TenantService", + _tenant_service(update_member_role=update_mock), + ) + monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + + with app.test_request_context( + f"/openapi/v1/workspaces/{ws_id}/members/{member_id}/role", + method="PUT", + data=json.dumps({"role": "admin"}), + content_type="application/json", + ): + _seed(_auth_ctx(account_id=acct_id)) + body, status = api.put.__wrapped__.__wrapped__( + api, workspace_id=ws_id, member_id=member_id, auth_data=_auth_data(acct_id) + ) + + assert status == 200 + assert body == {"result": "success"} + args = update_mock.call_args.args + assert args[2] == "admin" + + +@pytest.mark.parametrize( + ("exc", "expected"), + [ + (CannotOperateSelfError("cannot operate self"), BadRequest), + (NoPermissionError("no permission"), 
BadRequest), + (RoleAlreadyAssignedError("already"), BadRequest), + (MemberNotInTenantError("not in tenant"), NotFound), + ], +) +def test_update_role_exception_mapping(app, bypass_pipeline, monkeypatch, exc, expected): + ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceMemberRoleApi() + + mock_db = MagicMock() + mock_db.session.get.side_effect = [ + _account(account_id=str(acct_id)), + _tenant(ws_id), + _account(account_id=member_id), + ] + + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "TenantService", + _tenant_service(update_member_role=Mock(side_effect=exc)), + ) + monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + + with app.test_request_context( + f"/openapi/v1/workspaces/{ws_id}/members/{member_id}/role", + method="PUT", + data=json.dumps({"role": "admin"}), + content_type="application/json", + ): + _seed(_auth_ctx(account_id=acct_id)) + with pytest.raises(expected): + api.put.__wrapped__.__wrapped__( + api, + workspace_id=ws_id, + member_id=member_id, + auth_data=_auth_data(acct_id), + ) + + +# --------------------------------------------------------------------------- +# Role gate composition — non-member sees 404 even with valid bearer +# --------------------------------------------------------------------------- + + +def test_non_member_caller_gets_404_on_switch(app, bypass_pipeline, monkeypatch): + """End-to-end: caller has valid account bearer but no membership in + the requested workspace. 
The role gate must short-circuit to 404 + before any TenantService method is touched.""" + ws_id = str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceSwitchApi() + + mock_db = MagicMock() + mock_db.session.execute.return_value.scalar_one_or_none.return_value = None + + switch_mock = Mock() + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "TenantService", + _tenant_service(switch_tenant=switch_mock), + ) + monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + monkeypatch.setattr(sys.modules["controllers.openapi.auth.role_gate"], "db", mock_db) + + with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/switch", method="POST"): + _seed(_auth_ctx(account_id=acct_id)) + # Strip only the bearer + surface-gate wrappers; keep the role gate. + # Decorator stack (innermost → outermost): + # role_gate → accept_subjects → validate_bearer + # `post.__wrapped__` is now the role-gate wrapper directly (auth_router.guard is the only outer wrapper). 
+ gated = api.post.__wrapped__ + with pytest.raises(NotFound): + gated(api, workspace_id=ws_id) + + switch_mock.assert_not_called() + + +# --------------------------------------------------------------------------- +# _load_tenant rejects archived tenant +# --------------------------------------------------------------------------- + + +def test_load_tenant_rejects_archived_workspace(app, bypass_pipeline, monkeypatch): + """Member management against an archived workspace → 404.""" + ws_id = str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceMembersApi() + + archived = SimpleNamespace(id=ws_id, name="WS", status="archive", created_at=datetime(2026, 5, 18, tzinfo=UTC)) + mock_db = MagicMock() + mock_db.session.get.return_value = archived + + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "TenantService", + _tenant_service(), + ) + monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + + with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members"): + _seed(_auth_ctx(account_id=acct_id)) + with pytest.raises(NotFound): + api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) + + +# --------------------------------------------------------------------------- +# Invite catches AccountRegisterError +# --------------------------------------------------------------------------- + + +def test_invite_400_when_register_error(app, bypass_pipeline, monkeypatch): + """AccountRegisterError (frozen email, workspace creation blocked) → 400.""" + ws_id = str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceMembersApi() + + mock_db = MagicMock() + mock_db.session.get.side_effect = [_account(account_id=str(acct_id)), _tenant(ws_id)] + + monkeypatch.setattr( + sys.modules["controllers.openapi.workspaces"], + "RegisterService", + SimpleNamespace( + invite_new_member=Mock(side_effect=AccountRegisterError("Workspace is not allowed to create.")), + ), + ) + 
monkeypatch.setattr(sys.modules["controllers.openapi.workspaces"], "db", mock_db) + + with app.test_request_context( + f"/openapi/v1/workspaces/{ws_id}/members", + method="POST", + data=json.dumps({"email": "frozen@example.com", "role": "normal"}), + content_type="application/json", + ): + _seed(_auth_ctx(account_id=acct_id)) + with pytest.raises(BadRequest): + api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index fe8fc02548..2e1051ab6b 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -872,6 +872,11 @@ class TestSegmentApiPost: mock_features.billing.enabled = False mock_feature_svc.get_features.return_value = mock_features + mock_vector_space = Mock() + mock_vector_space.limit = 10 + mock_vector_space.size = 0 + mock_feature_svc.get_vector_space.return_value = mock_vector_space + mock_rate_limit = Mock() mock_rate_limit.enabled = False mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit @@ -1209,6 +1214,10 @@ class TestDatasetSegmentApiUpdate: mock_features = Mock() mock_features.billing.enabled = False mock_feature_svc.get_features.return_value = mock_features + mock_vector_space = Mock() + mock_vector_space.limit = 10 + mock_vector_space.size = 0 + mock_feature_svc.get_vector_space.return_value = mock_vector_space mock_rate_limit = Mock() mock_rate_limit.enabled = False mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit @@ -1710,6 +1719,10 @@ class TestChildChunkApiPost: mock_features = Mock() mock_features.billing.enabled = False mock_feature_svc.get_features.return_value = mock_features + mock_vector_space = Mock() + mock_vector_space.limit = 10 + mock_vector_space.size = 0 + 
mock_feature_svc.get_vector_space.return_value = mock_vector_space mock_rate_limit = Mock() mock_rate_limit.enabled = False mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index 61ec397193..2185e65326 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -950,7 +950,8 @@ class TestDocumentAddByTextApi: """Configure mocks to neutralise billing/auth decorators. ``cloud_edition_billing_resource_check`` calls - ``FeatureService.get_features`` and + ``FeatureService.get_vector_space`` for vector-space checks and + ``FeatureService.get_features`` for other resource checks. ``cloud_edition_billing_rate_limit_check`` calls ``FeatureService.get_knowledge_rate_limit``. Both call ``validate_and_get_api_token`` first. 
@@ -963,6 +964,11 @@ class TestDocumentAddByTextApi: mock_features.billing.enabled = False mock_feature_svc.get_features.return_value = mock_features + mock_vector_space = Mock() + mock_vector_space.limit = 10 + mock_vector_space.size = 0 + mock_feature_svc.get_vector_space.return_value = mock_vector_space + mock_rate_limit = Mock() mock_rate_limit.enabled = False mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit @@ -1140,6 +1146,10 @@ def _setup_billing_mocks(mock_validate_token, mock_feature_svc, tenant_id: str): mock_features = Mock() mock_features.billing.enabled = False mock_feature_svc.get_features.return_value = mock_features + mock_vector_space = Mock() + mock_vector_space.limit = 10 + mock_vector_space.size = 0 + mock_feature_svc.get_vector_space.return_value = mock_vector_space mock_rate_limit = Mock() mock_rate_limit.enabled = False mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py index 38fcb55fc0..4809cc0e8a 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py @@ -24,12 +24,53 @@ from werkzeug.exceptions import Forbidden, NotFound import services from controllers.service_api.dataset.hit_testing import HitTestingApi, HitTestingPayload from models.account import Account +from models.dataset import Dataset +from services.entities.knowledge_entities.knowledge_entities import RetrievalModel # --------------------------------------------------------------------------- # HitTestingPayload Model Tests # --------------------------------------------------------------------------- +def hit_testing_record() -> dict[str, object]: + return { + "segment": { + "id": "segment-1", + "position": 1, + "document_id": "document-1", + "content": "Chunk text", + 
"sign_content": "Chunk text", + "answer": None, + "word_count": 2, + "tokens": 3, + "keywords": None, + "index_node_id": None, + "index_node_hash": None, + "hit_count": 0, + "enabled": True, + "disabled_at": None, + "disabled_by": None, + "status": "completed", + "created_by": "account-1", + "created_at": 1_700_000_000, + "indexing_at": None, + "completed_at": None, + "error": None, + "stopped_at": None, + "document": { + "id": "document-1", + "data_source_type": "upload_file", + "name": "guide.md", + "doc_type": None, + "doc_metadata": None, + }, + }, + "child_chunks": None, + "files": None, + "score": 0.9, + } + + class TestHitTestingPayload: """Test suite for HitTestingPayload Pydantic model.""" @@ -48,7 +89,7 @@ class TestHitTestingPayload: } payload = HitTestingPayload( query="test query", - retrieval_model=retrieval_model_data, + retrieval_model=RetrievalModel.model_validate(retrieval_model_data), external_retrieval_model={"provider": "openai"}, attachment_ids=["att_1", "att_2"], ) @@ -68,6 +109,12 @@ class TestHitTestingPayload: payload = HitTestingPayload(query="x" * 250) assert len(payload.query) == 250 + def test_payload_ignores_unknown_fields_for_compatibility(self): + """Top-level fields outside the documented schema remain ignored as before.""" + payload = HitTestingPayload.model_validate({"query": "test query", "top_k": 3}) + + assert payload.model_dump(exclude_none=True) == {"query": "test query"} + # --------------------------------------------------------------------------- # HitTestingApi Tests @@ -80,8 +127,11 @@ class TestHitTestingPayload: class TestHitTestingApiPost: """Tests for HitTestingApi.post() via __wrapped__ to skip billing decorator.""" + @staticmethod + def _dataset(dataset_id: str, tenant_id: str) -> Dataset: + return Dataset(id=dataset_id, tenant_id=tenant_id, name="Dataset", created_by="account-1") + @patch("controllers.service_api.dataset.hit_testing.service_api_ns") - 
@patch("controllers.console.datasets.hit_testing_base.marshal") @patch("controllers.console.datasets.hit_testing_base.HitTestingService") @patch("controllers.console.datasets.hit_testing_base.DatasetService") @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) @@ -90,7 +140,6 @@ class TestHitTestingApiPost: mock_current_user, mock_dataset_svc, mock_hit_svc, - mock_marshal, mock_ns, app: Flask, ): @@ -98,15 +147,13 @@ class TestHitTestingApiPost: dataset_id = str(uuid.uuid4()) tenant_id = str(uuid.uuid4()) - mock_dataset = Mock() - mock_dataset.id = dataset_id + mock_dataset = self._dataset(dataset_id, tenant_id) mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None mock_hit_svc.retrieve.return_value = {"query": {"content": "test query"}, "records": []} mock_hit_svc.hit_testing_args_check.return_value = None - mock_marshal.return_value = [] mock_ns.payload = {"query": "test query"} @@ -115,11 +162,10 @@ class TestHitTestingApiPost: # Skip billing decorator via __wrapped__ response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id) - assert response["query"] == "test query" + assert response["query"] == {"content": "test query"} mock_hit_svc.retrieve.assert_called_once() @patch("controllers.service_api.dataset.hit_testing.service_api_ns") - @patch("controllers.console.datasets.hit_testing_base.marshal") @patch("controllers.console.datasets.hit_testing_base.HitTestingService") @patch("controllers.console.datasets.hit_testing_base.DatasetService") @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) @@ -128,7 +174,6 @@ class TestHitTestingApiPost: mock_current_user, mock_dataset_svc, mock_hit_svc, - mock_marshal, mock_ns, app: Flask, ): @@ -136,8 +181,7 @@ class TestHitTestingApiPost: dataset_id = str(uuid.uuid4()) tenant_id = str(uuid.uuid4()) - mock_dataset = Mock() - 
mock_dataset.id = dataset_id + mock_dataset = self._dataset(dataset_id, tenant_id) mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None @@ -152,7 +196,6 @@ class TestHitTestingApiPost: mock_hit_svc.retrieve.return_value = {"query": {"content": "complex query"}, "records": []} mock_hit_svc.hit_testing_args_check.return_value = None - mock_marshal.return_value = [] mock_ns.payload = { "query": "complex query", @@ -164,7 +207,7 @@ class TestHitTestingApiPost: api = HitTestingApi() response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id) - assert response["query"] == "complex query" + assert response["query"] == {"content": "complex query"} call_kwargs = mock_hit_svc.retrieve.call_args # retrieval_model is serialized via model_dump, verify key fields passed_retrieval_model = call_kwargs.kwargs.get("retrieval_model") @@ -173,7 +216,6 @@ class TestHitTestingApiPost: assert passed_retrieval_model["top_k"] == 10 @patch("controllers.service_api.dataset.hit_testing.service_api_ns") - @patch("controllers.console.datasets.hit_testing_base.marshal") @patch("controllers.console.datasets.hit_testing_base.HitTestingService") @patch("controllers.console.datasets.hit_testing_base.DatasetService") @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) @@ -182,7 +224,6 @@ class TestHitTestingApiPost: mock_current_user, mock_dataset_svc, mock_hit_svc, - mock_marshal, mock_ns, app: Flask, ): @@ -190,14 +231,12 @@ class TestHitTestingApiPost: dataset_id = str(uuid.uuid4()) tenant_id = str(uuid.uuid4()) - mock_dataset = Mock() - mock_dataset.id = dataset_id + mock_dataset = self._dataset(dataset_id, tenant_id) mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None mock_hit_svc.retrieve.return_value = {"query": {"content": "filtered query"}, "records": []} 
mock_hit_svc.hit_testing_args_check.return_value = None - mock_marshal.return_value = [] metadata_filtering_conditions = { "logical_operator": "and", @@ -229,7 +268,6 @@ class TestHitTestingApiPost: assert passed_retrieval_model["metadata_filtering_conditions"] == metadata_filtering_conditions @patch("controllers.service_api.dataset.hit_testing.service_api_ns") - @patch("controllers.console.datasets.hit_testing_base.marshal") @patch("controllers.console.datasets.hit_testing_base.HitTestingService") @patch("controllers.console.datasets.hit_testing_base.DatasetService") @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) @@ -238,30 +276,23 @@ class TestHitTestingApiPost: mock_current_user, mock_dataset_svc, mock_hit_svc, - mock_marshal, mock_ns, app: Flask, ): - """Test service API prepares nullable list fields from marshalled records.""" + """Test service API prepares nullable list fields from retrieval records.""" dataset_id = str(uuid.uuid4()) tenant_id = str(uuid.uuid4()) - mock_dataset = Mock() - mock_dataset.id = dataset_id + mock_dataset = self._dataset(dataset_id, tenant_id) mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None - mock_hit_svc.retrieve.return_value = {"query": {"content": "legacy query"}, "records": ["placeholder"]} + mock_hit_svc.retrieve.return_value = { + "query": {"content": "legacy query"}, + "records": [hit_testing_record()], + } mock_hit_svc.hit_testing_args_check.return_value = None - mock_marshal.return_value = [ - { - "segment": {"id": "segment-1", "keywords": None}, - "child_chunks": None, - "files": None, - "score": 0.9, - } - ] mock_ns.payload = {"query": "legacy query"} @@ -269,15 +300,15 @@ class TestHitTestingApiPost: api = HitTestingApi() response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id) - assert response["query"] == "legacy query" - assert response["records"] == [ - { - "segment": 
{"id": "segment-1", "keywords": []}, - "child_chunks": [], - "files": [], - "score": 0.9, - } - ] + assert response["query"] == {"content": "legacy query"} + record = response["records"][0] + assert record["segment"]["id"] == "segment-1" + assert record["segment"]["keywords"] == [] + assert record["child_chunks"] == [] + assert record["files"] == [] + assert record["score"] == 0.9 + assert record["tsne_position"] is None + assert record["summary"] is None @patch("controllers.service_api.dataset.hit_testing.service_api_ns") @patch("controllers.console.datasets.hit_testing_base.DatasetService") @@ -315,8 +346,7 @@ class TestHitTestingApiPost: dataset_id = str(uuid.uuid4()) tenant_id = str(uuid.uuid4()) - mock_dataset = Mock() - mock_dataset.id = dataset_id + mock_dataset = self._dataset(dataset_id, tenant_id) mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.side_effect = services.errors.account.NoPermissionError( diff --git a/api/tests/unit_tests/controllers/web/test_human_input_form.py b/api/tests/unit_tests/controllers/web/test_human_input_form.py index 5f2dc19aab..f9d49237b7 100644 --- a/api/tests/unit_tests/controllers/web/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/web/test_human_input_form.py @@ -126,7 +126,7 @@ def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask): monkeypatch.setattr( site_module.FeatureService, "get_features", - lambda tenant_id: SimpleNamespace(can_replace_logo=True), + lambda tenant_id, **_kwargs: SimpleNamespace(can_replace_logo=True), ) with app.test_request_context("/api/form/human_input/token-1", method="GET"): @@ -245,7 +245,7 @@ def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: F monkeypatch.setattr( site_module.FeatureService, "get_features", - lambda tenant_id: SimpleNamespace(can_replace_logo=True), + lambda tenant_id, **_kwargs: SimpleNamespace(can_replace_logo=True), ) with 
app.test_request_context("/api/form/human_input/token-1", method="GET"): diff --git a/api/tests/unit_tests/core/agent/test_base_agent_runner.py b/api/tests/unit_tests/core/agent/test_base_agent_runner.py index d5fb853ee3..f5e4b09993 100644 --- a/api/tests/unit_tests/core/agent/test_base_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_base_agent_runner.py @@ -61,79 +61,20 @@ class TestRepack: class TestUpdatePromptTool: - def build_param(self, mocker: MockerFixture, **kwargs): - p = mocker.MagicMock() - p.form = kwargs.get("form") - - mock_type = mocker.MagicMock() - mock_type.as_normal_type.return_value = "string" - p.type = mock_type - - p.name = kwargs.get("name", "p1") - p.llm_description = "desc" - p.input_schema = kwargs.get("input_schema") - p.options = kwargs.get("options") - p.required = kwargs.get("required", False) - return p - - def test_skip_non_llm(self, runner, mocker: MockerFixture): + def test_replaces_prompt_tool_parameters_with_tool_schema(self, runner, mocker: MockerFixture): tool = mocker.MagicMock() - param = self.build_param(mocker, form="NOT_LLM") - tool.get_runtime_parameters.return_value = [param] + schema = { + "type": "object", + "properties": {"p1": {"type": "string", "description": "desc"}}, + "required": ["p1"], + } + tool.get_llm_parameters_json_schema.return_value = schema prompt_tool = mocker.MagicMock() prompt_tool.parameters = {"properties": {}, "required": []} result = runner.update_prompt_message_tool(tool, prompt_tool) - assert result.parameters["properties"] == {} - - def test_enum_and_required(self, runner, mocker: MockerFixture): - option = mocker.MagicMock(value="opt1") - param = self.build_param( - mocker, - form=module.ToolParameter.ToolParameterForm.LLM, - options=[option], - required=True, - ) - - tool = mocker.MagicMock() - tool.get_runtime_parameters.return_value = [param] - - prompt_tool = mocker.MagicMock() - prompt_tool.parameters = {"properties": {}, "required": []} - - result = 
runner.update_prompt_message_tool(tool, prompt_tool) - assert "p1" in result.parameters["required"] - - def test_skip_file_type_param(self, runner, mocker: MockerFixture): - tool = mocker.MagicMock() - param = self.build_param(mocker, form=module.ToolParameter.ToolParameterForm.LLM) - param.type = module.ToolParameter.ToolParameterType.FILE - tool.get_runtime_parameters.return_value = [param] - - prompt_tool = mocker.MagicMock() - prompt_tool.parameters = {"properties": {}, "required": []} - - result = runner.update_prompt_message_tool(tool, prompt_tool) - assert result.parameters["properties"] == {} - - def test_duplicate_required_not_duplicated(self, runner, mocker: MockerFixture): - tool = mocker.MagicMock() - - param = self.build_param( - mocker, - form=module.ToolParameter.ToolParameterForm.LLM, - required=True, - ) - - tool.get_runtime_parameters.return_value = [param] - - prompt_tool = mocker.MagicMock() - prompt_tool.parameters = {"properties": {}, "required": ["p1"]} - - result = runner.update_prompt_message_tool(tool, prompt_tool) - - assert result.parameters["required"].count("p1") == 1 + assert result.parameters == schema # ========================================================== @@ -383,57 +324,21 @@ class TestConvertToolToPromptMessageTool: def test_basic_conversion(self, runner, mocker: MockerFixture): tool = mocker.MagicMock(tool_name="tool1") - runtime_param = mocker.MagicMock() - runtime_param.form = module.ToolParameter.ToolParameterForm.LLM - runtime_param.name = "param1" - runtime_param.llm_description = "desc" - runtime_param.required = True - runtime_param.input_schema = None - runtime_param.options = None - - mock_type = mocker.MagicMock() - mock_type.as_normal_type.return_value = "string" - runtime_param.type = mock_type - tool_entity = mocker.MagicMock() tool_entity.entity.description.llm = "desc" - tool_entity.get_merged_runtime_parameters.return_value = [runtime_param] + schema = { + "type": "object", + "properties": {"param1": 
{"type": "string", "description": "desc"}}, + "required": ["param1"], + } + tool_entity.get_llm_parameters_json_schema.return_value = schema mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity) mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool) assert entity == tool_entity - - def test_full_conversion_multiple_params(self, runner, mocker: MockerFixture): - tool = mocker.MagicMock(tool_name="tool1") - - # LLM param with input_schema override - param1 = mocker.MagicMock() - param1.form = module.ToolParameter.ToolParameterForm.LLM - param1.name = "p1" - param1.llm_description = "desc" - param1.required = True - param1.input_schema = {"type": "integer"} - param1.options = None - param1.type = mocker.MagicMock() - - # SYSTEM_FILES param should be skipped - param2 = mocker.MagicMock() - param2.form = module.ToolParameter.ToolParameterForm.LLM - param2.name = "file_param" - param2.type = module.ToolParameter.ToolParameterType.SYSTEM_FILES - - tool_entity = mocker.MagicMock() - tool_entity.entity.description.llm = "desc" - tool_entity.get_merged_runtime_parameters.return_value = [param1, param2] - - mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity) - mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) - - prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool) - - assert entity == tool_entity + assert prompt_tool.parameters == schema # ========================================================== @@ -465,29 +370,6 @@ class TestInitPromptToolsExtended: class TestAdditionalCoverage: - def test_update_prompt_with_input_schema(self, runner, mocker: MockerFixture): - tool = mocker.MagicMock() - - param = mocker.MagicMock() - param.form = module.ToolParameter.ToolParameterForm.LLM - param.name = "p1" - param.required = False - 
param.llm_description = "desc" - param.options = None - param.input_schema = {"type": "number"} - - mock_type = mocker.MagicMock() - mock_type.as_normal_type.return_value = "string" - param.type = mock_type - - tool.get_runtime_parameters.return_value = [param] - - prompt_tool = mocker.MagicMock() - prompt_tool.parameters = {"properties": {}, "required": []} - - result = runner.update_prompt_message_tool(tool, prompt_tool) - assert result.parameters["properties"]["p1"]["type"] == "number" - def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker: MockerFixture): agent = mocker.MagicMock() agent.tool = "tool1" @@ -571,33 +453,6 @@ class TestAdditionalCoverage: result = runner.organize_agent_history([]) assert isinstance(result, list) - # ================= Additional Surgical Coverage ================= - - def test_convert_tool_select_enum_branch(self, runner, mocker: MockerFixture): - tool = mocker.MagicMock(tool_name="tool1") - - param = mocker.MagicMock() - param.form = module.ToolParameter.ToolParameterForm.LLM - param.name = "select_param" - param.required = True - param.llm_description = "desc" - param.input_schema = None - - option1 = mocker.MagicMock(value="A") - option2 = mocker.MagicMock(value="B") - param.options = [option1, option2] - param.type = module.ToolParameter.ToolParameterType.SELECT - - tool_entity = mocker.MagicMock() - tool_entity.entity.description.llm = "desc" - tool_entity.get_merged_runtime_parameters.return_value = [param] - - mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity) - mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) - - prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool) - assert prompt_tool is not None - class TestConvertDatasetRetrieverTool: def test_required_param_added(self, runner, mocker: MockerFixture): @@ -663,24 +518,6 @@ class TestBaseAgentRunnerInit: class TestBaseAgentRunnerCoverage: - def 
test_convert_tool_skips_non_llm_param(self, runner, mocker: MockerFixture): - tool = mocker.MagicMock(tool_name="tool1") - - param = mocker.MagicMock() - param.form = "NOT_LLM" - param.type = mocker.MagicMock() - - tool_entity = mocker.MagicMock() - tool_entity.entity.description.llm = "desc" - tool_entity.get_merged_runtime_parameters.return_value = [param] - - mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity) - mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) - - prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool) - - assert prompt_tool.parameters["properties"] == {} - def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker: MockerFixture): dataset_tool = mocker.MagicMock() dataset_tool.entity.identity.name = "ds" @@ -693,30 +530,6 @@ class TestBaseAgentRunnerCoverage: assert tools["ds"] == dataset_tool assert len(prompt_tools) == 1 - def test_update_prompt_message_tool_select_enum(self, runner, mocker: MockerFixture): - tool = mocker.MagicMock() - - option1 = mocker.MagicMock(value="A") - option2 = mocker.MagicMock(value="B") - - param = mocker.MagicMock() - param.form = module.ToolParameter.ToolParameterForm.LLM - param.name = "select_param" - param.required = False - param.llm_description = "desc" - param.input_schema = None - param.options = [option1, option2] - param.type = module.ToolParameter.ToolParameterType.SELECT - - tool.get_runtime_parameters.return_value = [param] - - prompt_tool = mocker.MagicMock() - prompt_tool.parameters = {"properties": {}, "required": []} - - result = runner.update_prompt_message_tool(tool, prompt_tool) - - assert result.parameters["properties"]["select_param"]["enum"] == ["A", "B"] - def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker: MockerFixture): agent = mocker.MagicMock() agent.tool = "tool1" diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py 
b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py index c6eedf7be7..cd1e5babf8 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py @@ -12,6 +12,7 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.app.apps.base_app_runner import AppRunner +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -41,6 +42,23 @@ class _QueueRecorder: self.events.append(event) +class _ClosableStream: + def __init__(self, chunks: list[LLMResultChunk]) -> None: + self._chunks = chunks + self.closed = False + + def __iter__(self): + return self + + def __next__(self): + if not self._chunks: + raise StopIteration + return self._chunks.pop(0) + + def close(self) -> None: + self.closed = True + + class TestAppRunner: def test_recalc_llm_max_tokens_updates_parameters(self, monkeypatch: pytest.MonkeyPatch): runner = AppRunner() @@ -331,6 +349,28 @@ class TestAppRunner: assert queue.events[-1].llm_result.usage == usage exception_logger.assert_called_once() + def test_handle_invoke_result_stream_closes_generator_when_stopped(self): + runner = AppRunner() + chunk = LLMResultChunk( + model="stream-model", + prompt_messages=[AssistantPromptMessage(content="prompt")], + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content="a")), + ) + stream = _ClosableStream([chunk]) + + queue_manager = SimpleNamespace( + publish=MagicMock(side_effect=GenerateTaskStoppedError("stopped")), + ) + + with pytest.raises(GenerateTaskStoppedError): + runner._handle_invoke_result_stream( + invoke_result=stream, + queue_manager=queue_manager, + agent=False, + ) + + assert stream.closed is True + def 
test_handle_multimodal_image_content_fallback_return_branch(self, monkeypatch: pytest.MonkeyPatch): runner = AppRunner() diff --git a/api/tests/unit_tests/core/app/workflow/layers/test_persistence_inspector_publish.py b/api/tests/unit_tests/core/app/workflow/layers/test_persistence_inspector_publish.py new file mode 100644 index 0000000000..da8d6c5089 --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/layers/test_persistence_inspector_publish.py @@ -0,0 +1,192 @@ +"""Verify the workflow persistence layer fans Inspector deltas to redis pub/sub. + +The hook lives in ``core/app/workflow/layers/persistence.py``: +every ``_handle_node_*`` and the terminal ``_handle_graph_run_*`` handlers +call into ``services.workflow.inspector_events.publish_node_changed`` / +``publish_workflow_completed`` after the DB write succeeds. Those calls are +the only thing the Inspector SSE stream listens to, so any future refactor of +the persistence layer must keep them in place. + +We don't reconstruct a full workflow engine here — the handlers are tested +in isolation by patching just the moving parts they touch +(``_workflow_execution`` + ``_node_execution_cache``) and asserting against +the publisher module's call sites. This keeps the test compact and tied to +the contract, not the implementation. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from core.app.workflow.layers import persistence as persistence_mod +from core.app.workflow.layers.persistence import WorkflowPersistenceLayer + + +@pytest.fixture +def layer() -> WorkflowPersistenceLayer: + """Build a layer instance with all repository / trace deps stubbed. + + We bypass ``__init__`` because constructing it for real pulls in the + workflow engine's app-generate-entity, repos, and a runtime state — none + of which matter for asserting that the publish-hook fires. 
+ """ + instance = WorkflowPersistenceLayer.__new__(WorkflowPersistenceLayer) + # Minimum surface the handlers touch: + instance._workflow_execution_repository = MagicMock() + instance._workflow_node_execution_repository = MagicMock() + instance._trace_manager = None + instance._workflow_info = MagicMock(workflow_id="wf-1") + instance._application_generate_entity = MagicMock() + # Use a SimpleNamespace-like spec so Pydantic-validated callsites (e.g. + # ``WorkflowNodeExecution.new`` requires real strings) get the right types. + workflow_execution = MagicMock() + workflow_execution.id_ = "run-1" + workflow_execution.workflow_id = "wf-1" + workflow_execution.status = MagicMock(value="succeeded") + workflow_execution.outputs = {} + workflow_execution.error_message = None + workflow_execution.exceptions_count = 0 + workflow_execution.finished_at = None + instance._workflow_execution = workflow_execution + instance._node_execution_cache = {} + instance._node_snapshots = {} + instance._node_sequence = 0 + # `graph_runtime_state` is a layer-base property; stub it. 
+ instance._graph_runtime_state = MagicMock(total_tokens=0, node_run_steps=0, outputs={}, exceptions_count=0) + return instance + + +@pytest.fixture +def capture_publishes(monkeypatch: pytest.MonkeyPatch) -> dict[str, list]: + """Replace the two publishers with capture lists so each test can assert + on the exact arguments.""" + calls: dict[str, list] = {"node": [], "workflow": []} + + def fake_node(*, workflow_run_id: str, node_id: str, status: str) -> None: + calls["node"].append({"workflow_run_id": workflow_run_id, "node_id": node_id, "status": status}) + + def fake_workflow(*, workflow_run_id: str, status: str) -> None: + calls["workflow"].append({"workflow_run_id": workflow_run_id, "status": status}) + + monkeypatch.setattr(persistence_mod, "_inspector_publish_node_changed", fake_node) + monkeypatch.setattr(persistence_mod, "_inspector_publish_workflow_completed", fake_workflow) + return calls + + +# ────────────────────────────────────────────────────────────────────────────── +# Graph-level publish hooks +# ────────────────────────────────────────────────────────────────────────────── + + +def _graph_event(**kwargs: Any) -> MagicMock: + return MagicMock(**kwargs) + + +def test_graph_run_succeeded_publishes_workflow_completed(layer, capture_publishes): + layer._workflow_execution.status = MagicMock(value="succeeded") + layer._handle_graph_run_succeeded(_graph_event(outputs={"text": "hi"})) + assert capture_publishes["workflow"] == [{"workflow_run_id": "run-1", "status": "succeeded"}] + assert capture_publishes["node"] == [] + + +def test_graph_run_partial_succeeded_publishes_workflow_completed(layer, capture_publishes): + layer._workflow_execution.status = MagicMock(value="partial-succeeded") + layer._handle_graph_run_partial_succeeded(_graph_event(outputs={}, exceptions_count=1)) + assert capture_publishes["workflow"] == [{"workflow_run_id": "run-1", "status": "partial-succeeded"}] + + +def test_graph_run_failed_publishes_workflow_completed(layer, 
capture_publishes): + layer._workflow_execution.status = MagicMock(value="failed") + layer._handle_graph_run_failed(_graph_event(error="boom", exceptions_count=0)) + assert capture_publishes["workflow"] == [{"workflow_run_id": "run-1", "status": "failed"}] + + +def test_graph_run_aborted_publishes_workflow_completed(layer, capture_publishes): + layer._workflow_execution.status = MagicMock(value="stopped") + layer._handle_graph_run_aborted(_graph_event(reason="user stop")) + assert capture_publishes["workflow"] == [{"workflow_run_id": "run-1", "status": "stopped"}] + + +def test_graph_run_paused_does_not_publish_completion(layer, capture_publishes): + """Pause is not a terminal state — the Inspector keeps waiting for either + resume or a real terminal event.""" + layer._handle_graph_run_paused(_graph_event(outputs={})) + assert capture_publishes["workflow"] == [] + assert capture_publishes["node"] == [] + + +# ────────────────────────────────────────────────────────────────────────────── +# Node-level publish hooks +# ────────────────────────────────────────────────────────────────────────────── + + +def _node_started_event(node_id: str = "agent-1", exec_id: str = "exec-1") -> MagicMock: + return MagicMock( + id=exec_id, + node_id=node_id, + node_type="agent", + node_title="Greeter", + predecessor_node_id=None, + in_iteration_id=None, + in_loop_id=None, + start_at=datetime(2026, 5, 26, 0, 0, 0), + ) + + +def _seed_node_execution(layer: WorkflowPersistenceLayer, exec_id: str, node_id: str) -> None: + """Inject a domain execution into the cache so the success / fail / etc + handlers (which look it up by id) can run without going through started.""" + layer._node_execution_cache[exec_id] = MagicMock( + id=exec_id, node_id=node_id, status=MagicMock(value="running"), outputs={}, error=None + ) + + +def test_node_started_publishes_running(layer, capture_publishes): + layer._handle_node_started(_node_started_event()) + assert capture_publishes["node"] == 
[{"workflow_run_id": "run-1", "node_id": "agent-1", "status": "running"}] + + +def test_node_retry_publishes_retry(layer, capture_publishes): + _seed_node_execution(layer, exec_id="exec-1", node_id="agent-1") + event = MagicMock(id="exec-1", error="rate limit") + layer._handle_node_retry(event) + assert capture_publishes["node"] == [{"workflow_run_id": "run-1", "node_id": "agent-1", "status": "retry"}] + + +def test_node_succeeded_publishes_succeeded(layer, capture_publishes, monkeypatch: pytest.MonkeyPatch): + _seed_node_execution(layer, exec_id="exec-1", node_id="agent-1") + # Stub the inner _update_node_execution so we don't have to construct a + # full NodeRunResult — we only want to confirm the publish happens after. + monkeypatch.setattr(layer, "_update_node_execution", lambda *a, **kw: None) + event = MagicMock(id="exec-1", node_run_result=MagicMock(), finished_at=datetime.now()) + layer._handle_node_succeeded(event) + assert capture_publishes["node"] == [{"workflow_run_id": "run-1", "node_id": "agent-1", "status": "succeeded"}] + + +def test_node_failed_publishes_failed(layer, capture_publishes, monkeypatch: pytest.MonkeyPatch): + _seed_node_execution(layer, exec_id="exec-1", node_id="agent-1") + monkeypatch.setattr(layer, "_update_node_execution", lambda *a, **kw: None) + event = MagicMock(id="exec-1", node_run_result=MagicMock(), error="bad", finished_at=datetime.now()) + layer._handle_node_failed(event) + assert capture_publishes["node"] == [{"workflow_run_id": "run-1", "node_id": "agent-1", "status": "failed"}] + + +def test_node_exception_publishes_exception(layer, capture_publishes, monkeypatch: pytest.MonkeyPatch): + _seed_node_execution(layer, exec_id="exec-1", node_id="agent-1") + monkeypatch.setattr(layer, "_update_node_execution", lambda *a, **kw: None) + event = MagicMock(id="exec-1", node_run_result=MagicMock(), error="oom", finished_at=datetime.now()) + layer._handle_node_exception(event) + assert capture_publishes["node"] == 
[{"workflow_run_id": "run-1", "node_id": "agent-1", "status": "exception"}] + + +def test_node_pause_requested_does_not_publish(layer, capture_publishes, monkeypatch: pytest.MonkeyPatch): + """Node pause is not an Inspector-visible state — no publish.""" + _seed_node_execution(layer, exec_id="exec-1", node_id="agent-1") + monkeypatch.setattr(layer, "_update_node_execution", lambda *a, **kw: None) + event = MagicMock(id="exec-1", node_run_result=MagicMock()) + layer._handle_node_pause_requested(event) + assert capture_publishes["node"] == [] diff --git a/api/tests/unit_tests/core/tools/test_base_tool.py b/api/tests/unit_tests/core/tools/test_base_tool.py index 23d3e77c1d..9486144e98 100644 --- a/api/tests/unit_tests/core/tools/test_base_tool.py +++ b/api/tests/unit_tests/core/tools/test_base_tool.py @@ -8,7 +8,13 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolIdentity, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, +) class DummyCastType: @@ -25,6 +31,7 @@ class DummyParameter: default: Any = None options: list[Any] | None = None llm_description: str | None = None + input_schema: dict[str, Any] | None = None class DummyTool(Tool): @@ -149,13 +156,27 @@ def test_fork_tool_runtime_returns_new_tool_with_copied_entity(): def test_get_runtime_parameters_and_merge_runtime_parameters(): tool = _build_tool() - original = DummyParameter(name="temperature", type=DummyCastType(), form="schema", required=True, default="0.7") + original = DummyParameter( + name="temperature", + type=DummyCastType(), + form="schema", + required=True, + default="0.7", + input_schema={"type": "string"}, + ) tool.entity.parameters = cast(Any, 
[original]) default_runtime_parameters = tool.get_runtime_parameters() assert default_runtime_parameters == [original] - override = DummyParameter(name="temperature", type=DummyCastType(), form="llm", required=False, default="0.5") + override = DummyParameter( + name="temperature", + type=DummyCastType(), + form="llm", + required=False, + default="0.5", + input_schema={"type": "object"}, + ) appended = DummyParameter(name="new_param", type=DummyCastType(), form="form", required=False, default="x") tool.runtime_parameter_overrides = [override, appended] @@ -165,7 +186,93 @@ def test_get_runtime_parameters_and_merge_runtime_parameters(): assert merged[0].form == "llm" assert merged[0].required is False assert merged[0].default == "0.5" + assert merged[0].input_schema == {"type": "object"} assert merged[1].name == "new_param" + assert merged[0] is not original + assert merged[1] is not appended + assert original.form == "schema" + assert original.required is True + assert original.default == "0.7" + assert original.input_schema == {"type": "string"} + + +def test_get_llm_parameters_json_schema_uses_effective_runtime_parameters(): + tool = _build_tool() + query_parameter = ToolParameter.get_simple_instance( + name="query", + llm_description="Declared query", + typ=ToolParameter.ToolParameterType.STRING, + required=True, + ) + region_parameter = ToolParameter.get_simple_instance( + name="region", + llm_description="Search region", + typ=ToolParameter.ToolParameterType.SELECT, + required=False, + options=["global", "cn"], + ) + hidden_parameter = ToolParameter.get_simple_instance( + name="api_key", + llm_description="Hidden api key", + typ=ToolParameter.ToolParameterType.STRING, + required=True, + ) + hidden_parameter.form = ToolParameter.ToolParameterForm.FORM + file_parameter = ToolParameter.get_simple_instance( + name="attachment", + llm_description="Attachment", + typ=ToolParameter.ToolParameterType.FILE, + required=False, + ) + payload_parameter = ToolParameter( + 
name="payload", + label=I18nObject(en_US="payload", zh_Hans="payload"), + placeholder=None, + human_description=I18nObject(en_US="payload", zh_Hans="payload"), + type=ToolParameter.ToolParameterType.OBJECT, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Payload", + required=False, + input_schema={ + "type": "object", + "properties": {"nested": {"type": "string"}}, + }, + ) + tool.entity.parameters = [query_parameter, region_parameter, hidden_parameter, file_parameter, payload_parameter] + + query_override = ToolParameter.get_simple_instance( + name="query", + llm_description="Runtime query", + typ=ToolParameter.ToolParameterType.STRING, + required=True, + ) + tool.runtime_parameter_overrides = [query_override] + + schema = tool.get_llm_parameters_json_schema() + + assert schema == { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Runtime query"}, + "region": { + "type": "string", + "description": "Search region", + "enum": ["global", "cn"], + }, + "payload": { + "type": "object", + "properties": {"nested": {"type": "string"}}, + "description": "Payload", + }, + }, + "required": ["query"], + } + + schema["properties"]["payload"]["properties"]["nested"]["type"] = "number" + assert payload_parameter.input_schema == { + "type": "object", + "properties": {"nested": {"type": "string"}}, + } def test_message_factory_helpers(): diff --git a/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_plugin_tools_builder.py b/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_plugin_tools_builder.py new file mode 100644 index 0000000000..c27b560e45 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_plugin_tools_builder.py @@ -0,0 +1,439 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import Any + +import pytest + +from core.agent.entities import AgentToolEntity +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool 
import Tool +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolDescription, + ToolEntity, + ToolIdentity, + ToolInvokeMessage, + ToolParameter, +) +from core.workflow.nodes.agent_v2.plugin_tools_builder import ( + WorkflowAgentPluginToolsBuilder, + WorkflowAgentPluginToolsBuildError, +) +from models.agent_config_entities import AgentSoulToolsConfig + + +class FakeRuntimeProvider: + def __init__(self, tool: Tool | Exception) -> None: + # Either a Tool to hand back, or an exception to raise on lookup. The + # latter lets tests exercise the error-mapping branches in + # ``WorkflowAgentPluginToolsBuilder._fetch_tool_runtime``. + self.tool = tool + self.last_agent_tool: AgentToolEntity | None = None + self.last_invoke_from: InvokeFrom | None = None + + def get_agent_tool_runtime( + self, + tenant_id: str, + app_id: str, + agent_tool: AgentToolEntity, + user_id: str | None = None, + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + variable_pool: Any | None = None, + ) -> Tool: + self.last_agent_tool = agent_tool + self.last_invoke_from = invoke_from + if isinstance(self.tool, Exception): + raise self.tool + return self.tool + + +class FakeTool(Tool): + def tool_provider_type(self): + raise NotImplementedError + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]: + raise NotImplementedError + + +def _tool(*, runtime_parameters: dict[str, Any] | None = None) -> FakeTool: + if runtime_parameters is None: + runtime_parameters = {"region": "us"} + parameters = [ + ToolParameter( + name="query", + label=I18nObject(en_US="Query"), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + required=True, + 
llm_description="Search query", + ), + ToolParameter( + name="region", + label=I18nObject(en_US="Region"), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.FORM, + required=True, + ), + ] + entity = ToolEntity( + identity=ToolIdentity( + author="langgenius", + name="search", + label=I18nObject(en_US="Search"), + provider="search", + ), + description=ToolDescription(human=I18nObject(en_US="Search"), llm="Search the web."), + parameters=parameters, + ) + runtime = ToolRuntime( + tenant_id="tenant-1", + user_id="user-1", + credentials={"api_key": "secret"}, + runtime_parameters=runtime_parameters, + ) + return FakeTool(entity=entity, runtime=runtime) + + +def _build( + builder: WorkflowAgentPluginToolsBuilder, + tools: AgentSoulToolsConfig, + *, + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, +): + """Shorthand for ``builder.build(...)`` with the standard tenant/app/user + triple, so each test only highlights what's actually unique to it.""" + return builder.build( + tenant_id="tenant-1", + app_id="app-1", + user_id="user-1", + tools=tools, + invoke_from=invoke_from, + ) + + +def test_builds_dify_plugin_tools_layer_from_existing_tool_runtime(): + runtime_provider = FakeRuntimeProvider(_tool()) + builder = WorkflowAgentPluginToolsBuilder(tool_runtime_provider=runtime_provider) + tools = AgentSoulToolsConfig.model_validate( + { + "dify_tools": [ + { + "provider_id": "langgenius/search/search", + "tool_name": "search", + "credential_type": "api-key", + "credential_id": "credential-1", + "runtime_parameters": {"region": "us"}, + } + ] + } + ) + + result = _build(builder, tools) + + assert result is not None + prepared = result.tools[0] + assert prepared.plugin_id == "langgenius/search" + assert prepared.provider == "search" + assert prepared.tool_name == "search" + assert prepared.name == "search" + assert prepared.credentials == {"api_key": "secret"} + assert prepared.runtime_parameters == {"region": "us"} + assert 
prepared.parameters_json_schema["properties"]["query"]["type"] == "string" + assert "region" not in prepared.parameters_json_schema["properties"] + assert runtime_provider.last_agent_tool is not None + assert runtime_provider.last_agent_tool.credential_id == "credential-1" + # Default ``provider_type`` is now ``"plugin"`` — the agent tool entity + # must surface that so ToolManager hits the plugin provider table, not the + # built-in legacy table. + assert runtime_provider.last_agent_tool.provider_type.value == "plugin" + + +def test_rejects_duplicate_exposed_tool_names(): + builder = WorkflowAgentPluginToolsBuilder(tool_runtime_provider=FakeRuntimeProvider(_tool())) + tools = AgentSoulToolsConfig.model_validate( + { + "dify_tools": [ + { + "provider_id": "langgenius/search/search", + "tool_name": "search", + "credential_type": "api-key", + "credential_id": "credential-1", + "runtime_parameters": {"region": "us"}, + }, + { + "provider_id": "langgenius/search/search", + "tool_name": "search", + "credential_type": "api-key", + "credential_id": "credential-1", + "runtime_parameters": {"region": "us"}, + }, + ] + } + ) + + with pytest.raises(WorkflowAgentPluginToolsBuildError) as exc_info: + _build(builder, tools) + + assert exc_info.value.error_code == "agent_tool_name_duplicated" + + +def test_rejects_missing_required_runtime_parameter(): + builder = WorkflowAgentPluginToolsBuilder(tool_runtime_provider=FakeRuntimeProvider(_tool(runtime_parameters={}))) + tools = AgentSoulToolsConfig.model_validate( + { + "dify_tools": [ + { + "provider_id": "langgenius/search/search", + "tool_name": "search", + "credential_type": "api-key", + "credential_id": "credential-1", + } + ] + } + ) + + with pytest.raises(WorkflowAgentPluginToolsBuildError) as exc_info: + _build(builder, tools) + + assert exc_info.value.error_code == "agent_tool_runtime_parameter_missing" + + +# ────────────────────────────────────────────────────────────────────────────── +# invoke_from is threaded through 
to ToolManager +# ────────────────────────────────────────────────────────────────────────────── + + +def test_invoke_from_is_forwarded_to_tool_runtime_provider(): + """``WorkflowAgentRuntimeRequestBuilder`` passes the *real* runtime + invocation source (DEBUGGER for draft test run, SERVICE_API for published + run, etc.). ToolManager uses ``invoke_from`` for credential quotas / rate + limits / audit tags, so any default-falling-back here would silently + misattribute usage. Lock in the forwarding behaviour for both + representative invoke_from values.""" + for invoke_from in (InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP): + runtime_provider = FakeRuntimeProvider(_tool()) + builder = WorkflowAgentPluginToolsBuilder(tool_runtime_provider=runtime_provider) + tools = AgentSoulToolsConfig.model_validate( + { + "dify_tools": [ + { + "provider_id": "langgenius/search/search", + "tool_name": "search", + "credential_type": "api-key", + "credential_id": "credential-1", + "runtime_parameters": {"region": "us"}, + } + ] + } + ) + + _build(builder, tools, invoke_from=invoke_from) + + assert runtime_provider.last_invoke_from == invoke_from + + +# ────────────────────────────────────────────────────────────────────────────── +# disabled tools / plugin_id+provider fallback / unauthorized credentials +# ────────────────────────────────────────────────────────────────────────────── + + +def test_disabled_tools_are_skipped(): + runtime_provider = FakeRuntimeProvider(_tool()) + builder = WorkflowAgentPluginToolsBuilder(tool_runtime_provider=runtime_provider) + tools = AgentSoulToolsConfig.model_validate( + { + "dify_tools": [ + { + "provider_id": "langgenius/search/search", + "tool_name": "search", + "credential_type": "api-key", + "credential_id": "credential-1", + "runtime_parameters": {"region": "us"}, + "enabled": False, + } + ] + } + ) + + # All entries are disabled → builder short-circuits and returns None so the + # request_builder skips adding the tools 
layer entirely. + assert _build(builder, tools) is None + assert runtime_provider.last_agent_tool is None # ToolManager never queried + + +def test_plugin_id_plus_provider_fallback_when_provider_id_missing(): + """Frontend may send ``plugin_id`` + ``provider`` instead of the + concatenated ``provider_id``; the builder must accept both shapes.""" + runtime_provider = FakeRuntimeProvider(_tool()) + builder = WorkflowAgentPluginToolsBuilder(tool_runtime_provider=runtime_provider) + tools = AgentSoulToolsConfig.model_validate( + { + "dify_tools": [ + { + "plugin_id": "langgenius/search", + "provider": "search", + "tool_name": "search", + "credential_type": "api-key", + "credential_id": "credential-1", + "runtime_parameters": {"region": "us"}, + } + ] + } + ) + + result = _build(builder, tools) + + assert result is not None + assert runtime_provider.last_agent_tool is not None + assert runtime_provider.last_agent_tool.provider_id == "langgenius/search/search" + assert result.tools[0].plugin_id == "langgenius/search" + assert result.tools[0].provider == "search" + + +def test_unauthorized_tool_without_credentials(): + """``credential_type=unauthorized`` removes the ``credential_ref.id`` + requirement (e.g. 
public Wikipedia / current_time tools).""" + + def _no_credentials_tool() -> FakeTool: + tool = _tool() + assert tool.runtime is not None + tool.runtime.credentials = {} + return tool + + runtime_provider = FakeRuntimeProvider(_no_credentials_tool()) + builder = WorkflowAgentPluginToolsBuilder(tool_runtime_provider=runtime_provider) + tools = AgentSoulToolsConfig.model_validate( + { + "dify_tools": [ + { + "provider_id": "langgenius/time/time", + "tool_name": "current_time", + "credential_type": "unauthorized", + "runtime_parameters": {"region": "us"}, + } + ] + } + ) + + result = _build(builder, tools) + assert result is not None + assert result.tools[0].credential_type == "unauthorized" + assert result.tools[0].credentials == {} + + +# ────────────────────────────────────────────────────────────────────────────── +# Error-code mapping: declaration not found / credential invalid / config +# ────────────────────────────────────────────────────────────────────────────── + + +def _standard_tools_payload() -> AgentSoulToolsConfig: + return AgentSoulToolsConfig.model_validate( + { + "dify_tools": [ + { + "provider_id": "langgenius/search/search", + "tool_name": "search", + "credential_type": "api-key", + "credential_id": "credential-1", + "runtime_parameters": {"region": "us"}, + } + ] + } + ) + + +def test_tool_provider_not_found_maps_to_declaration_not_found(): + from core.tools.errors import ToolProviderNotFoundError + + builder = WorkflowAgentPluginToolsBuilder( + tool_runtime_provider=FakeRuntimeProvider(ToolProviderNotFoundError("provider gone")) + ) + with pytest.raises(WorkflowAgentPluginToolsBuildError) as exc_info: + _build(builder, _standard_tools_payload()) + assert exc_info.value.error_code == "agent_tool_declaration_not_found" + + +def test_credential_validation_error_maps_to_credential_invalid(): + from core.tools.errors import ToolProviderCredentialValidationError + + builder = WorkflowAgentPluginToolsBuilder( + 
tool_runtime_provider=FakeRuntimeProvider(ToolProviderCredentialValidationError("creds expired")) + ) + with pytest.raises(WorkflowAgentPluginToolsBuildError) as exc_info: + _build(builder, _standard_tools_payload()) + assert exc_info.value.error_code == "agent_tool_credential_invalid" + + +def test_generic_value_error_maps_to_config_invalid(): + """Bare ``ValueError`` from ToolManager (e.g. "runtime not found") becomes + ``agent_tool_config_invalid`` — distinct from + ``agent_tool_declaration_not_found`` so callers can render a different + hint.""" + builder = WorkflowAgentPluginToolsBuilder(tool_runtime_provider=FakeRuntimeProvider(ValueError("runtime missing"))) + with pytest.raises(WorkflowAgentPluginToolsBuildError) as exc_info: + _build(builder, _standard_tools_payload()) + assert exc_info.value.error_code == "agent_tool_config_invalid" + + +# ────────────────────────────────────────────────────────────────────────────── +# Non-scalar credentials rejected instead of silently str()'d +# ────────────────────────────────────────────────────────────────────────────── + + +def test_rejects_non_scalar_credential_value(): + """If a credential ever shows up shaped like ``{"access_token": "..."}``, + ``str(value)`` would forward a Python repr to the plugin daemon. 
The + builder should refuse and surface an explicit error code so an operator + fixes the credential schema instead of debugging a daemon JSON parse + failure.""" + + def _dict_credential_tool() -> FakeTool: + tool = _tool() + assert tool.runtime is not None + tool.runtime.credentials = {"oauth": {"access_token": "secret", "expires_in": 3600}} + return tool + + builder = WorkflowAgentPluginToolsBuilder(tool_runtime_provider=FakeRuntimeProvider(_dict_credential_tool())) + with pytest.raises(WorkflowAgentPluginToolsBuildError) as exc_info: + _build(builder, _standard_tools_payload()) + assert exc_info.value.error_code == "agent_tool_credential_shape_invalid" + + +# ────────────────────────────────────────────────────────────────────────────── +# Legacy payload normalization +# ────────────────────────────────────────────────────────────────────────────── + + +def test_legacy_provider_name_and_tool_parameters_normalized(): + """Old Composer save payloads used ``provider_name`` / ``tool_parameters`` + keys. 
The ``@model_validator(mode="before")`` on AgentSoulDifyToolConfig + rewrites them in-place so reading historical Agent Soul snapshots from the + DB still works.""" + config = AgentSoulToolsConfig.model_validate( + { + "dify_tools": [ + { + "provider_name": "langgenius/search/search", + "tool_name": "search", + "credential_type": "api-key", + "credential_id": "credential-1", + "tool_parameters": {"region": "us"}, + } + ] + } + ) + + tool = config.dify_tools[0] + assert tool.provider_id == "langgenius/search/search" + assert tool.runtime_parameters == {"region": "us"} + assert tool.credential_ref is not None + assert tool.credential_ref.id == "credential-1" diff --git a/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_runtime_request_builder.py b/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_runtime_request_builder.py index 7ddb5552a8..48ae0d46f2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_runtime_request_builder.py +++ b/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_runtime_request_builder.py @@ -1,7 +1,9 @@ from dataclasses import replace import pytest +from dify_agent.layers.dify_plugin import DifyPluginToolConfig, DifyPluginToolsLayerConfig +from clients.agent_backend import DIFY_EXECUTION_CONTEXT_LAYER_ID, DIFY_PLUGIN_TOOLS_LAYER_ID from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom from core.workflow.nodes.agent_v2.runtime_request_builder import ( WorkflowAgentRuntimeBuildContext, @@ -25,6 +27,38 @@ class FakeCredentialsProvider: return {"api_key": "secret-key"} +class FakePluginToolsBuilder: + def __init__(self) -> None: + # Capture the runtime invocation source so tests can assert it was + # threaded through from ``DifyRunContext.invoke_from`` rather than + # hard-coded to a placeholder like ``VALIDATION``. 
+ self.last_invoke_from: InvokeFrom | None = None + + def build(self, *, tenant_id, app_id, user_id, tools, invoke_from): + assert tenant_id == "tenant-1" + assert app_id == "app-1" + assert user_id == "user-1" + self.last_invoke_from = invoke_from + if not tools.dify_tools: + return None + return DifyPluginToolsLayerConfig( + tools=[ + DifyPluginToolConfig( + plugin_id="langgenius/time", + provider="time", + tool_name="current_time", + credential_type="unauthorized", + name="current_time", + description="Get current time.", + credentials={}, + runtime_parameters={}, + parameters=[], + parameters_json_schema={"type": "object", "properties": {}, "required": []}, + ) + ] + ) + + class FakeVariablePool: def get(self, selector): if list(selector) == ["sys", "query"]: @@ -93,9 +127,10 @@ def test_builds_create_run_request_from_agent_soul_and_node_job(): result = WorkflowAgentRuntimeRequestBuilder(credentials_provider=FakeCredentialsProvider()).build(_context()) dumped = result.request.model_dump(mode="json") - assert dumped["execution_context"]["agent_id"] == "agent-1" - assert dumped["execution_context"]["agent_config_version_id"] == "snapshot-1" - assert dumped["execution_context"]["invoke_from"] == "single_step" + layers = {layer["name"]: layer for layer in dumped["composition"]["layers"]} + assert layers[DIFY_EXECUTION_CONTEXT_LAYER_ID]["config"]["agent_id"] == "agent-1" + assert layers[DIFY_EXECUTION_CONTEXT_LAYER_ID]["config"]["agent_config_version_id"] == "snapshot-1" + assert layers[DIFY_EXECUTION_CONTEXT_LAYER_ID]["config"]["invoke_from"] == "single_step" assert dumped["idempotency_key"] == "run-1:node-exec-1" assert dumped["composition"]["layers"][0]["config"]["prefix"] == "You are careful." assert dumped["composition"]["layers"][1]["config"]["prefix"] == "Use the previous output." 
@@ -145,15 +180,68 @@ def test_builds_workflow_run_request_with_file_output_schema_and_reserved_metada result = WorkflowAgentRuntimeRequestBuilder(credentials_provider=FakeCredentialsProvider()).build(context) dumped = result.request.model_dump(mode="json") - assert dumped["execution_context"]["invoke_from"] == "workflow_run" + layers = {layer["name"]: layer for layer in dumped["composition"]["layers"]} + assert layers[DIFY_EXECUTION_CONTEXT_LAYER_ID]["config"]["invoke_from"] == "workflow_run" assert dumped["idempotency_key"] == "node-exec-1" output_schema = dumped["composition"]["layers"][-1]["config"]["json_schema"] assert output_schema["properties"]["report"]["properties"]["file_id"]["type"] == "string" assert output_schema["properties"]["confidence"]["type"] == "number" assert output_schema["required"] == ["report"] assert dumped["composition"]["layers"][4]["config"]["model_settings"] == {"temperature": 0.2} - assert result.metadata["runtime_support"]["reserved_status"]["tools"] == "reserved_not_executed" - assert result.metadata["runtime_support"]["unsupported_runtime_warnings"][0]["section"] == "agent_soul.tools" + assert result.metadata["runtime_support"]["reserved_status"]["tools.dify_tools"] == "supported_when_config_valid" + assert result.metadata["runtime_support"]["reserved_status"]["tools.cli_tools"] == "reserved_not_executed" + warnings = result.metadata["runtime_support"]["unsupported_runtime_warnings"] + assert warnings[0]["section"] == "agent_soul.tools.cli_tools" + + +def test_builds_workflow_run_request_with_dify_plugin_tools_layer(): + context = _context() + snapshot = AgentConfigSnapshot( + id="snapshot-1", + tenant_id="tenant-1", + agent_id="agent-1", + version=1, + config_snapshot=AgentSoulConfig( + prompt={"system_prompt": "You are careful."}, + model=AgentSoulModelConfig( + plugin_id="langgenius/openai", + model_provider="openai", + model="gpt-test", + ), + tools={ + "dify_tools": [ + { + "provider_id": "langgenius/time/time", + 
"tool_name": "current_time", + "credential_type": "unauthorized", + } + ] + }, + ), + ) + context = replace(context, snapshot=snapshot) + + plugin_tools_builder = FakePluginToolsBuilder() + result = WorkflowAgentRuntimeRequestBuilder( + credentials_provider=FakeCredentialsProvider(), + plugin_tools_builder=plugin_tools_builder, + ).build(context) + + dumped = result.request.model_dump(mode="json") + layers = {layer["name"]: layer for layer in dumped["composition"]["layers"]} + assert layers[DIFY_PLUGIN_TOOLS_LAYER_ID]["type"] == "dify.plugin.tools" + assert layers[DIFY_PLUGIN_TOOLS_LAYER_ID]["deps"] == {"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID} + assert layers[DIFY_PLUGIN_TOOLS_LAYER_ID]["config"]["tools"][0]["tool_name"] == "current_time" + assert result.metadata["agent_tools"] == { + "dify_tool_count": 1, + "dify_tool_names": ["current_time"], + "cli_tool_count": 0, + } + # The runtime invocation source must flow from ``DifyRunContext.invoke_from`` + # into the plugin tools builder so ToolManager attributes credential + # quotas / rate limits / audit tags to the real call site instead of a + # hard-coded ``VALIDATION`` placeholder. 
+ assert plugin_tools_builder.last_invoke_from == context.dify_context.invoke_from def test_requires_agent_soul_model_config(): diff --git a/api/tests/unit_tests/libs/test_oauth_bearer_rate_limit_ordering.py b/api/tests/unit_tests/libs/test_oauth_bearer_rate_limit_ordering.py index dd4304ccb1..d3d2d583f2 100644 --- a/api/tests/unit_tests/libs/test_oauth_bearer_rate_limit_ordering.py +++ b/api/tests/unit_tests/libs/test_oauth_bearer_rate_limit_ordering.py @@ -11,6 +11,7 @@ from libs.oauth_bearer import ( SubjectType, TokenKind, TokenKindRegistry, + TokenType, ) @@ -21,7 +22,7 @@ def _registry_with_resolver(resolver) -> TokenKindRegistry: prefix="dfoa_", subject_type=SubjectType.ACCOUNT, scopes=frozenset({Scope.FULL}), - source="oauth_account", + token_type=TokenType.OAUTH_ACCOUNT, resolver=resolver, ) ] @@ -63,7 +64,7 @@ def test_unknown_prefix_raises_generic_invalid_bearer(): prefix="dfoa_", subject_type=SubjectType.ACCOUNT, scopes=frozenset({Scope.FULL}), - source="oauth_account", + token_type=TokenType.OAUTH_ACCOUNT, resolver=MagicMock(), ) ] diff --git a/api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py b/api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py index 898e4578e6..e8204a6e2e 100644 --- a/api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py +++ b/api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py @@ -19,6 +19,7 @@ from libs.oauth_bearer import ( AuthContext, Scope, SubjectType, + TokenType, require_scope, reset_auth_ctx, set_auth_ctx, @@ -50,7 +51,7 @@ def _ctx(scopes) -> AuthContext: client_id="difyctl", scopes=scopes, token_id=uuid.uuid4(), - source="oauth_account", + token_type=TokenType.OAUTH_ACCOUNT, expires_at=None, token_hash="h1", verified_tenants={}, diff --git a/api/tests/unit_tests/libs/test_workspace_member_helper.py b/api/tests/unit_tests/libs/test_workspace_member_helper.py index 540e19ad9e..f4933e7f59 100644 --- a/api/tests/unit_tests/libs/test_workspace_member_helper.py +++ 
b/api/tests/unit_tests/libs/test_workspace_member_helper.py @@ -8,7 +8,7 @@ from unittest.mock import MagicMock, patch import pytest from werkzeug.exceptions import Forbidden -from libs.oauth_bearer import AuthContext, Scope, SubjectType, require_workspace_member +from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, require_workspace_member def _ctx(verified: dict[str, bool] | None = None, *, account: bool = True) -> AuthContext: @@ -20,7 +20,7 @@ def _ctx(verified: dict[str, bool] | None = None, *, account: bool = True) -> Au client_id="difyctl", scopes=frozenset({Scope.FULL}), token_id=uuid.uuid4(), - source="oauth_account", + token_type=TokenType.OAUTH_ACCOUNT if account else TokenType.OAUTH_EXTERNAL_SSO, expires_at=None, token_hash="h1", verified_tenants=dict(verified or {}), diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 5e89d9fb42..e09102a788 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch import pytest from configs import dify_config -from models.account import Account, AccountStatus, TenantStatus +from models.account import Account, AccountStatus, TenantAccountRole, TenantStatus from services.account_service import AccountService, RegisterService, TenantService from services.errors.account import ( AccountAlreadyInTenantError, @@ -567,6 +567,52 @@ class TestTenantService: with pytest.raises(exception_type): callable_func(*args, **kwargs) + # ==================== get_account_role_in_tenant Tests ==================== + # Backs `require_workspace_role`: None => non-member (gate maps to 404), + # otherwise the caller's role (gate maps an out-of-set role to 403). 
+ + def test_get_account_role_in_tenant_returns_role_for_member(self): + """A row in TenantAccountJoin yields the caller's role.""" + mock_session = MagicMock() + mock_session.execute.return_value.scalar_one_or_none.return_value = TenantAccountRole.ADMIN + + role = TenantService.get_account_role_in_tenant(mock_session, "account-1", "tenant-1") + + assert role == TenantAccountRole.ADMIN + + def test_get_account_role_in_tenant_returns_none_for_non_member(self): + """No join row => None, so the gate cannot leak the workspace's existence.""" + mock_session = MagicMock() + mock_session.execute.return_value.scalar_one_or_none.return_value = None + + role = TenantService.get_account_role_in_tenant(mock_session, "account-1", "tenant-1") + + assert role is None + + def test_get_account_role_in_tenant_short_circuits_empty_account_id(self): + """None/empty account_id (SSO bearer, missing identity) returns None + without ever touching the session.""" + mock_session = MagicMock() + + assert TenantService.get_account_role_in_tenant(mock_session, None, "tenant-1") is None + mock_session.execute.assert_not_called() + + def test_get_account_role_in_tenant_query_is_scoped(self): + """The lookup must filter on BOTH tenant_id and account_id — otherwise + a member of workspace A could read their role for workspace B. 
Compile + the statement and assert both identifiers appear in the WHERE clause.""" + account_id = "11111111-1111-1111-1111-111111111111" + tenant_id = "22222222-2222-2222-2222-222222222222" + mock_session = MagicMock() + mock_session.execute.return_value.scalar_one_or_none.return_value = TenantAccountRole.NORMAL + + TenantService.get_account_role_in_tenant(mock_session, account_id, tenant_id) + + stmt = mock_session.execute.call_args.args[0] + compiled = str(stmt.compile(compile_kwargs={"literal_binds": True})) + assert account_id in compiled + assert tenant_id in compiled + # ==================== Tenant Creation Tests ==================== def test_create_owner_tenant_if_not_exist_new_user( diff --git a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py index f5879d973d..352a765de2 100644 --- a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py +++ b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py @@ -39,7 +39,7 @@ def fake_features(monkeypatch: pytest.MonkeyPatch): ) monkeypatch.setattr( "services.dataset_service.FeatureService.get_features", - lambda tenant_id: features, + lambda tenant_id, **_kwargs: features, ) return features diff --git a/api/tests/unit_tests/services/test_oauth_device_flow.py b/api/tests/unit_tests/services/test_oauth_device_flow.py index b2e95c93a3..fcb3f29a76 100644 --- a/api/tests/unit_tests/services/test_oauth_device_flow.py +++ b/api/tests/unit_tests/services/test_oauth_device_flow.py @@ -3,7 +3,7 @@ from __future__ import annotations import uuid from unittest.mock import MagicMock -from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT, AuthContext, SubjectType +from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT, AuthContext, SubjectType, TokenType from services.oauth_device_flow import ( list_active_sessions, revoke_oauth_token, @@ -21,7 +21,7 @@ def _account_ctx() -> AuthContext: client_id="difyctl", 
scopes=frozenset({"full"}), token_id=uuid.uuid4(), - source="oauth_account", + token_type=TokenType.OAUTH_ACCOUNT, expires_at=None, token_hash="h1", verified_tenants={}, @@ -37,7 +37,7 @@ def _sso_ctx() -> AuthContext: client_id="difyctl", scopes=frozenset({"apps:run"}), token_id=uuid.uuid4(), - source="oauth_external_sso", + token_type=TokenType.OAUTH_EXTERNAL_SSO, expires_at=None, token_hash="h1", verified_tenants={}, diff --git a/api/tests/unit_tests/services/workflow/test_inspector_events.py b/api/tests/unit_tests/services/workflow/test_inspector_events.py new file mode 100644 index 0000000000..8e1992861f --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_inspector_events.py @@ -0,0 +1,224 @@ +"""Unit tests for :mod:`services.workflow.inspector_events`. + +The publisher and subscriber both touch redis, so we mock it out at the +``redis_client`` boundary. The goal is to lock down: + +1. the channel-naming convention (frontend SSE doesn't need to know it but + tests catch accidental renames), +2. the JSON envelope (``kind / workflow_run_id / node_id / status``), +3. publisher robustness when redis is unavailable, +4. subscriber's tolerance of malformed payloads and bytes-vs-str messages, +5. subscriber's heartbeat-on-idle behaviour. 
+""" + +from __future__ import annotations + +import json +from collections.abc import Iterator +from typing import Any +from unittest.mock import MagicMock, patch + +from services.workflow import inspector_events +from services.workflow.inspector_events import InspectorMessage + +# ────────────────────────────────────────────────────────────────────────────── +# Channel + envelope +# ────────────────────────────────────────────────────────────────────────────── + + +def test_channel_for_returns_namespaced_key(): + assert inspector_events.channel_for("run-42") == "dify:inspector:workflow_run:run-42" + + +def test_inspector_message_to_json_round_trip(): + msg = InspectorMessage(kind="node_changed", workflow_run_id="r1", node_id="agent-1", status="succeeded") + parsed = json.loads(msg.to_json()) + assert parsed == {"kind": "node_changed", "workflow_run_id": "r1", "node_id": "agent-1", "status": "succeeded"} + + +def test_inspector_message_from_json_rejects_bad_kind(): + blob = json.dumps({"kind": "something_else", "workflow_run_id": "r1"}) + assert InspectorMessage.from_json(blob) is None + + +def test_inspector_message_from_json_rejects_bad_workflow_run_id(): + blob = json.dumps({"kind": "node_changed", "workflow_run_id": ""}) + assert InspectorMessage.from_json(blob) is None + + +def test_inspector_message_from_json_rejects_non_string_node_id(): + blob = json.dumps({"kind": "node_changed", "workflow_run_id": "r1", "node_id": 42}) + assert InspectorMessage.from_json(blob) is None + + +def test_inspector_message_from_json_returns_none_for_invalid_json(): + assert InspectorMessage.from_json("{not json") is None + + +def test_inspector_message_from_json_rejects_non_dict_payload(): + """Defensive: a JSON array or scalar is not an InspectorMessage.""" + assert InspectorMessage.from_json("[1, 2, 3]") is None + assert InspectorMessage.from_json('"plain string"') is None + + +def test_inspector_message_from_json_rejects_non_string_status(): + """Status field, if present, 
must be a string.""" + blob = json.dumps({"kind": "workflow_completed", "workflow_run_id": "r1", "status": 42}) + assert InspectorMessage.from_json(blob) is None + + +# ────────────────────────────────────────────────────────────────────────────── +# Publisher +# ────────────────────────────────────────────────────────────────────────────── + + +def test_publish_node_changed_writes_to_run_channel(): + fake_redis = MagicMock() + with patch.object(inspector_events, "redis_client", fake_redis): + inspector_events.publish_node_changed(workflow_run_id="run-1", node_id="agent-1", status="running") + + fake_redis.publish.assert_called_once() + channel, blob = fake_redis.publish.call_args.args + assert channel == "dify:inspector:workflow_run:run-1" + msg = InspectorMessage.from_json(blob) + assert msg is not None + assert msg.kind == "node_changed" + assert msg.node_id == "agent-1" + assert msg.status == "running" + + +def test_publish_workflow_completed_emits_terminal_message(): + fake_redis = MagicMock() + with patch.object(inspector_events, "redis_client", fake_redis): + inspector_events.publish_workflow_completed(workflow_run_id="run-1", status="succeeded") + + blob = fake_redis.publish.call_args.args[1] + msg = InspectorMessage.from_json(blob) + assert msg is not None + assert msg.kind == "workflow_completed" + assert msg.node_id is None + assert msg.status == "succeeded" + + +def test_publish_swallows_redis_errors(): + """Persistence must not crash if redis blows up — we publish best-effort.""" + + class _BrokenRedis: + def publish(self, *_args: Any, **_kwargs: Any) -> None: + raise RuntimeError("redis offline") + + with patch.object(inspector_events, "redis_client", _BrokenRedis()): + # No exception should escape. 
+ inspector_events.publish_node_changed(workflow_run_id="run-1", node_id="agent-1", status="running") + + +# ────────────────────────────────────────────────────────────────────────────── +# Subscriber +# ────────────────────────────────────────────────────────────────────────────── + + +def _make_fake_pubsub(messages: list[dict[str, Any] | None]) -> MagicMock: + """Build a redis pubsub stub that replays ``messages``, then returns ``None`` once exhausted.""" + pubsub = MagicMock() + it: Iterator[dict[str, Any] | None] = iter(messages) + pubsub.get_message.side_effect = lambda **_kwargs: next(it, None) + return pubsub + + +def test_subscribe_yields_heartbeat_then_real_message(): + """Idle ticks (``get_message`` returns None) surface as a sentinel; real + payloads decode to ``InspectorMessage`` instances.""" + payload = json.dumps( + {"kind": "node_changed", "workflow_run_id": "run-1", "node_id": "agent-1", "status": "succeeded"} + ) + fake_redis = MagicMock() + fake_redis.pubsub.return_value = _make_fake_pubsub( + [ + None, # heartbeat tick + {"data": payload.encode("utf-8")}, # bytes payload, real message + None, # heartbeat + ] + ) + with patch.object(inspector_events, "redis_client", fake_redis): + gen = inspector_events.subscribe("run-1", timeout_seconds=0.0) + first = next(gen) + second = next(gen) + third = next(gen) + + # First message is the heartbeat sentinel (both node_id and status are None). + assert first.node_id is None + assert first.status is None + # Second is the real one. + assert second.kind == "node_changed" + assert second.node_id == "agent-1" + assert second.status == "succeeded" + # Third is another heartbeat.
+ assert third.node_id is None + + +def test_subscribe_skips_malformed_payloads(): + fake_redis = MagicMock() + fake_redis.pubsub.return_value = _make_fake_pubsub( + [ + {"data": b"not json at all"}, + {"data": json.dumps({"kind": "node_changed", "workflow_run_id": "run-1"}).encode("utf-8")}, + ] + ) + with patch.object(inspector_events, "redis_client", fake_redis): + gen = inspector_events.subscribe("run-1", timeout_seconds=0.0) + msg = next(gen) + assert msg.kind == "node_changed" + assert msg.node_id is None + + +def test_subscribe_unsubscribes_on_teardown(): + fake_pubsub = _make_fake_pubsub([None]) + fake_redis = MagicMock() + fake_redis.pubsub.return_value = fake_pubsub + with patch.object(inspector_events, "redis_client", fake_redis): + gen = inspector_events.subscribe("run-1", timeout_seconds=0.0) + next(gen) + gen.close() + fake_pubsub.unsubscribe.assert_called_once_with("dify:inspector:workflow_run:run-1") + fake_pubsub.close.assert_called_once() + + +def test_subscribe_swallows_teardown_errors(): + """``unsubscribe`` / ``close`` failures must not propagate out of the + generator — they're best-effort cleanup.""" + fake_pubsub = MagicMock() + fake_pubsub.get_message.return_value = None + fake_pubsub.unsubscribe.side_effect = RuntimeError("redis offline") + fake_pubsub.close.side_effect = RuntimeError("close failed") + fake_redis = MagicMock() + fake_redis.pubsub.return_value = fake_pubsub + with patch.object(inspector_events, "redis_client", fake_redis): + gen = inspector_events.subscribe("run-1", timeout_seconds=0.0) + next(gen) + # The teardown path runs in ``finally``; closing the generator + # exercises it. No exception should escape. 
+ gen.close() + + +def test_subscribe_skips_non_string_data_payloads(): + """``raw["data"]`` can be ``None`` / int / str — only the str payload is + decoded; the ``None`` and int payloads are silently skipped.""" + fake_pubsub = MagicMock() + msgs: list[dict[str, Any] | None] = [ + {"data": None}, # missing payload + {"data": 12345}, # int payload (shouldn't happen, defensive) + { + "data": json.dumps( + {"kind": "node_changed", "workflow_run_id": "run-1", "node_id": "agent-1", "status": "running"} + ) + }, + ] + it = iter(msgs) + fake_pubsub.get_message.side_effect = lambda **_kw: next(it, None) + fake_redis = MagicMock() + fake_redis.pubsub.return_value = fake_pubsub + with patch.object(inspector_events, "redis_client", fake_redis): + gen = inspector_events.subscribe("run-1", timeout_seconds=0.0) + msg = next(gen) + assert msg.kind == "node_changed" + assert msg.node_id == "agent-1" diff --git a/api/tests/unit_tests/services/workflow/test_node_output_inspector_service.py b/api/tests/unit_tests/services/workflow/test_node_output_inspector_service.py new file mode 100644 index 0000000000..601a51e830 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_node_output_inspector_service.py @@ -0,0 +1,499 @@ +"""Unit tests for NodeOutputInspectorService (Stage 4 §8). + +The service reads from postgres and resolves agent v2 bindings; this suite +mocks ``session_factory`` and the binding resolver so we exercise the +view-construction logic without DB / network access.
+""" + +from __future__ import annotations + +import json +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from models.agent_config_entities import ( + DeclaredArrayItem, + DeclaredOutputConfig, + DeclaredOutputType, +) +from models.enums import WorkflowRunTriggeredFrom +from services.workflow.node_output_inspector_service import ( + NodeOutputInspectorError, + NodeOutputInspectorService, + NodeOutputStatus, + NodeStatus, +) + +# ────────────────────────────────────────────────────────────────────────────── +# Fixtures +# ────────────────────────────────────────────────────────────────────────────── + + +def _app_model(*, tenant_id: str = "tenant-1", app_id: str = "app-1"): + return SimpleNamespace(tenant_id=tenant_id, id=app_id) + + +def _workflow_run( + *, + run_id: str = "run-1", + workflow_id: str = "workflow-1", + tenant_id: str = "tenant-1", + app_id: str = "app-1", + triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING, + status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING, + nodes: list[dict[str, Any]] | None = None, +): + return SimpleNamespace( + id=run_id, + workflow_id=workflow_id, + tenant_id=tenant_id, + app_id=app_id, + triggered_from=triggered_from, + status=status, + graph=json.dumps({"nodes": nodes or []}), + ) + + +def _execution( + *, + node_id: str, + node_type: str = "agent", + title: str = "", + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.SUCCEEDED, + outputs: dict[str, Any] | None = None, + execution_metadata: dict[str, Any] | None = None, + index: int = 1, + created_at: datetime | None = None, + finished_at: datetime | None = None, +): + return SimpleNamespace( + node_id=node_id, + node_type=node_type, + title=title or node_id, + status=status, + outputs=json.dumps(outputs) if outputs is not 
None else None, + execution_metadata=json.dumps(execution_metadata) if execution_metadata is not None else None, + index=index, + created_at=created_at or datetime.now(UTC), + finished_at=finished_at, + ) + + +def _agent_v2_node(*, node_id: str = "agent-node-1", title: str = "My Agent") -> dict[str, Any]: + return { + "id": node_id, + "data": {"type": "agent", "version": "2", "title": title}, + } + + +def _non_agent_node(*, node_id: str = "tool-node-1", node_type: str = "tool", title: str = "Slack") -> dict[str, Any]: + return { + "id": node_id, + "data": {"type": node_type, "title": title}, + } + + +def _patch_session( + *, + workflow_run: SimpleNamespace | None, + executions: list[SimpleNamespace] | None = None, +): + """Patch ``session_factory.create_session`` to return the configured rows. + + Returns a context manager that the test uses with ``with``. + """ + executions = executions or [] + mock_session = MagicMock() + mock_session.scalar.return_value = workflow_run + mock_session.scalars.return_value.all.return_value = executions + cm = MagicMock() + cm.__enter__.return_value = mock_session + cm.__exit__.return_value = False + return patch( + "services.workflow.node_output_inspector_service.session_factory.create_session", + return_value=cm, + ) + + +def _stub_binding_resolver(*, declared_outputs: list[DeclaredOutputConfig]): + """Build a fake ``WorkflowAgentBindingResolver`` whose ``.resolve`` returns + a binding with ``node_job_config_dict.declared_outputs``.""" + binding = SimpleNamespace( + id="binding-1", + node_job_config_dict={ + "workflow_prompt": "stub", + "declared_outputs": [o.model_dump() for o in declared_outputs], + }, + ) + bundle = SimpleNamespace(binding=binding, agent=None, snapshot=None) + resolver = MagicMock() + resolver.resolve.return_value = bundle + return resolver + + +def _make_service(declared_outputs: list[DeclaredOutputConfig] | None = None) -> NodeOutputInspectorService: + return 
NodeOutputInspectorService(binding_resolver=_stub_binding_resolver(declared_outputs=declared_outputs or [])) + + +# ────────────────────────────────────────────────────────────────────────────── +# 404 paths +# ────────────────────────────────────────────────────────────────────────────── + + +def test_snapshot_404_when_workflow_run_missing(): + service = _make_service() + with _patch_session(workflow_run=None): + with pytest.raises(NodeOutputInspectorError) as exc: + service.snapshot_workflow_run(app_model=_app_model(), workflow_run_id="missing") + assert exc.value.code == "workflow_run_not_found" + + +def test_snapshot_accepts_published_run_d1_lifted(): + """D-1 was lifted 2026-05-26: any ``triggered_from`` is now accepted.""" + service = _make_service() + run = _workflow_run( + nodes=[_agent_v2_node(node_id="agent-1")], + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + with _patch_session(workflow_run=run, executions=[]): + snapshot = service.snapshot_workflow_run(app_model=_app_model(), workflow_run_id="run-1") + assert snapshot.workflow_run_id == "run-1" + assert [n.node_id for n in snapshot.node_outputs] == ["agent-1"] + + +def test_snapshot_accepts_webhook_triggered_run(): + """Webhook / schedule / plugin triggers are also published-side.""" + service = _make_service() + run = _workflow_run( + nodes=[_agent_v2_node(node_id="agent-1")], + triggered_from=WorkflowRunTriggeredFrom.WEBHOOK, + ) + with _patch_session(workflow_run=run, executions=[]): + snapshot = service.snapshot_workflow_run(app_model=_app_model(), workflow_run_id="run-1") + assert snapshot.workflow_run_id == "run-1" + + +def test_node_detail_404_when_node_id_absent_from_graph(): + service = _make_service() + run = _workflow_run(nodes=[_agent_v2_node(node_id="agent-1")]) + with _patch_session(workflow_run=run, executions=[]): + with pytest.raises(NodeOutputInspectorError) as exc: + service.node_detail(app_model=_app_model(), workflow_run_id="run-1", node_id="ghost") + assert exc.value.code 
== "node_not_in_workflow_run" + + +def test_output_preview_404_when_output_name_unknown(): + service = _make_service( + declared_outputs=[DeclaredOutputConfig(name="text", type=DeclaredOutputType.STRING)], + ) + run = _workflow_run(nodes=[_agent_v2_node(node_id="agent-1")]) + ex = _execution(node_id="agent-1", outputs={"text": "hello"}) + with _patch_session(workflow_run=run, executions=[ex]): + with pytest.raises(NodeOutputInspectorError) as exc: + service.output_preview( + app_model=_app_model(), + workflow_run_id="run-1", + node_id="agent-1", + output_name="missing", + ) + assert exc.value.code == "node_output_not_declared" + + +# ────────────────────────────────────────────────────────────────────────────── +# Snapshot happy path +# ────────────────────────────────────────────────────────────────────────────── + + +def test_snapshot_status_pending_when_node_has_no_execution(): + service = _make_service( + declared_outputs=[DeclaredOutputConfig(name="text", type=DeclaredOutputType.STRING)], + ) + run = _workflow_run(nodes=[_agent_v2_node(node_id="agent-1")]) + with _patch_session(workflow_run=run, executions=[]): + snapshot = service.snapshot_workflow_run(app_model=_app_model(), workflow_run_id="run-1") + + assert len(snapshot.node_outputs) == 1 + node = snapshot.node_outputs[0] + assert node.node_status == NodeStatus.IDLE + assert node.outputs[0].status == NodeOutputStatus.PENDING + + +def test_snapshot_status_running(): + service = _make_service( + declared_outputs=[DeclaredOutputConfig(name="text", type=DeclaredOutputType.STRING)], + ) + run = _workflow_run(nodes=[_agent_v2_node(node_id="agent-1")]) + ex = _execution(node_id="agent-1", status=WorkflowNodeExecutionStatus.RUNNING) + with _patch_session(workflow_run=run, executions=[ex]): + snapshot = service.snapshot_workflow_run(app_model=_app_model(), workflow_run_id="run-1") + assert snapshot.node_outputs[0].node_status == NodeStatus.RUNNING + assert snapshot.node_outputs[0].outputs[0].status == 
NodeOutputStatus.RUNNING + + +def test_snapshot_status_failed_node_marks_all_outputs_failed(): + service = _make_service( + declared_outputs=[ + DeclaredOutputConfig(name="a", type=DeclaredOutputType.STRING), + DeclaredOutputConfig(name="b", type=DeclaredOutputType.NUMBER), + ], + ) + run = _workflow_run(nodes=[_agent_v2_node(node_id="agent-1")]) + ex = _execution(node_id="agent-1", status=WorkflowNodeExecutionStatus.FAILED) + with _patch_session(workflow_run=run, executions=[ex]): + snapshot = service.snapshot_workflow_run(app_model=_app_model(), workflow_run_id="run-1") + statuses = {o.name: o.status for o in snapshot.node_outputs[0].outputs} + assert statuses == {"a": NodeOutputStatus.FAILED, "b": NodeOutputStatus.FAILED} + + +def test_snapshot_status_ready_when_outputs_present_and_no_failure_metadata(): + service = _make_service( + declared_outputs=[DeclaredOutputConfig(name="text", type=DeclaredOutputType.STRING)], + ) + run = _workflow_run(nodes=[_agent_v2_node(node_id="agent-1")]) + ex = _execution(node_id="agent-1", outputs={"text": "hello"}) + with _patch_session(workflow_run=run, executions=[ex]): + snapshot = service.snapshot_workflow_run(app_model=_app_model(), workflow_run_id="run-1") + output = snapshot.node_outputs[0].outputs[0] + assert output.status == NodeOutputStatus.READY + assert output.value_preview == "hello" + + +def test_snapshot_marks_type_check_failure(): + service = _make_service( + declared_outputs=[DeclaredOutputConfig(name="text", type=DeclaredOutputType.STRING)], + ) + run = _workflow_run(nodes=[_agent_v2_node(node_id="agent-1")]) + ex = _execution( + node_id="agent-1", + outputs={"text": "ok"}, + execution_metadata={ + "output_type_check": { + "passed": False, + "results": [{"name": "text", "type": "string", "status": "type_check_failed", "reason": "wrong shape"}], + } + }, + ) + with _patch_session(workflow_run=run, executions=[ex]): + snapshot = service.snapshot_workflow_run(app_model=_app_model(), workflow_run_id="run-1") + 
output = snapshot.node_outputs[0].outputs[0] + assert output.status == NodeOutputStatus.TYPE_CHECK_FAILED + assert output.type_check is not None + assert output.type_check.passed is False + assert output.type_check.reason == "wrong shape" + + +def test_snapshot_marks_output_check_failure_when_type_check_passed(): + service = _make_service( + declared_outputs=[ + DeclaredOutputConfig( + name="report", + type=DeclaredOutputType.FILE, + ) + ], + ) + run = _workflow_run(nodes=[_agent_v2_node(node_id="agent-1")]) + ex = _execution( + node_id="agent-1", + outputs={"report": {"file_id": "550e8400-e29b-41d4-a716-446655440000"}}, + execution_metadata={ + "output_type_check": {"passed": True, "results": [{"name": "report", "status": "ready"}]}, + "output_check": { + "passed": False, + "results": [{"name": "report", "status": "failed", "reason": "benchmark mismatch"}], + }, + }, + ) + with ( + _patch_session(workflow_run=run, executions=[ex]), + patch( + "services.workflow.node_output_inspector_service.file_helpers.get_signed_file_url", + return_value="https://signed.example/x", + ), + ): + snapshot = service.snapshot_workflow_run(app_model=_app_model(), workflow_run_id="run-1") + output = snapshot.node_outputs[0].outputs[0] + assert output.status == NodeOutputStatus.OUTPUT_CHECK_FAILED + assert output.output_check is not None + assert output.output_check.passed is False + assert output.output_check.reason == "benchmark mismatch" + + +def test_snapshot_marks_not_produced_when_declared_output_missing_from_payload(): + service = _make_service( + declared_outputs=[ + DeclaredOutputConfig(name="text", type=DeclaredOutputType.STRING), + DeclaredOutputConfig(name="optional_meta", type=DeclaredOutputType.OBJECT, required=False), + ], + ) + run = _workflow_run(nodes=[_agent_v2_node(node_id="agent-1")]) + ex = _execution(node_id="agent-1", outputs={"text": "hi"}) # optional_meta missing + with _patch_session(workflow_run=run, executions=[ex]): + snapshot = 
service.snapshot_workflow_run(app_model=_app_model(), workflow_run_id="run-1") + statuses = {o.name: o.status for o in snapshot.node_outputs[0].outputs} + assert statuses == {"text": NodeOutputStatus.READY, "optional_meta": NodeOutputStatus.NOT_PRODUCED} + + +# ────────────────────────────────────────────────────────────────────────────── +# Non-agent node — outputs inferred from execution payload +# ────────────────────────────────────────────────────────────────────────────── + + +def test_non_agent_node_outputs_inferred_from_payload_keys(): + service = _make_service() + run = _workflow_run(nodes=[_non_agent_node(node_id="tool-1", node_type="tool")]) + ex = _execution( + node_id="tool-1", + node_type="tool", + outputs={"message": "sent", "thread_ts": "1234"}, + ) + with _patch_session(workflow_run=run, executions=[ex]): + snapshot = service.snapshot_workflow_run(app_model=_app_model(), workflow_run_id="run-1") + output_names = sorted(o.name for o in snapshot.node_outputs[0].outputs) + assert output_names == ["message", "thread_ts"] + # All inferred outputs should have ``type=None`` since we don't know the + # schema yet. 
+ assert all(o.type is None for o in snapshot.node_outputs[0].outputs) + + +# ────────────────────────────────────────────────────────────────────────────── +# File preview / signed URL +# ────────────────────────────────────────────────────────────────────────────── + + +def test_file_output_preview_includes_signed_url(): + service = _make_service( + declared_outputs=[ + DeclaredOutputConfig(name="report", type=DeclaredOutputType.FILE), + ], + ) + run = _workflow_run(nodes=[_agent_v2_node(node_id="agent-1")]) + file_payload = {"file_id": "550e8400-e29b-41d4-a716-446655440000", "filename": "x.pdf"} + ex = _execution(node_id="agent-1", outputs={"report": file_payload}) + with ( + _patch_session(workflow_run=run, executions=[ex]), + patch( + "services.workflow.node_output_inspector_service.file_helpers.get_signed_file_url", + return_value="https://signed.example/x.pdf", + ), + ): + snapshot = service.snapshot_workflow_run(app_model=_app_model(), workflow_run_id="run-1") + preview_value = snapshot.node_outputs[0].outputs[0].value_preview + assert isinstance(preview_value, dict) + assert preview_value["preview_url"] == "https://signed.example/x.pdf" + assert preview_value["filename"] == "x.pdf" + + +def test_file_output_preview_endpoint_returns_full_value_with_signed_url(): + service = _make_service( + declared_outputs=[ + DeclaredOutputConfig(name="report", type=DeclaredOutputType.FILE), + ], + ) + run = _workflow_run(nodes=[_agent_v2_node(node_id="agent-1")]) + file_payload = {"file_id": "550e8400-e29b-41d4-a716-446655440000", "filename": "x.pdf"} + ex = _execution(node_id="agent-1", outputs={"report": file_payload}) + with ( + _patch_session(workflow_run=run, executions=[ex]), + patch( + "services.workflow.node_output_inspector_service.file_helpers.get_signed_file_url", + return_value="https://signed.example/x.pdf", + ), + ): + preview = service.output_preview( + app_model=_app_model(), + workflow_run_id="run-1", + node_id="agent-1", + output_name="report", + ) + 
assert preview.output_name == "report" + assert preview.status == NodeOutputStatus.READY + assert isinstance(preview.value, dict) + assert preview.value["preview_url"] == "https://signed.example/x.pdf" + + +# ────────────────────────────────────────────────────────────────────────────── +# Retry / metadata +# ────────────────────────────────────────────────────────────────────────────── + + +def test_retried_count_pulled_from_attempt_metadata(): + service = _make_service( + declared_outputs=[DeclaredOutputConfig(name="text", type=DeclaredOutputType.STRING)], + ) + run = _workflow_run(nodes=[_agent_v2_node(node_id="agent-1")]) + ex = _execution( + node_id="agent-1", + outputs={"text": "ok"}, + execution_metadata={"attempt": 2}, + ) + with _patch_session(workflow_run=run, executions=[ex]): + snapshot = service.snapshot_workflow_run(app_model=_app_model(), workflow_run_id="run-1") + assert snapshot.node_outputs[0].outputs[0].retried == 2 + + +# ────────────────────────────────────────────────────────────────────────────── +# Latest-execution-per-node grouping +# ────────────────────────────────────────────────────────────────────────────── + + +def test_keeps_latest_execution_per_node_by_index(): + """When a node has multiple executions (retries / iterations) keep the + canonical one — the row with the highest ``index``.""" + service = _make_service( + declared_outputs=[DeclaredOutputConfig(name="text", type=DeclaredOutputType.STRING)], + ) + run = _workflow_run(nodes=[_agent_v2_node(node_id="agent-1")]) + older = _execution(node_id="agent-1", outputs={"text": "old"}, index=1) + newer = _execution(node_id="agent-1", outputs={"text": "new"}, index=5) + with _patch_session(workflow_run=run, executions=[older, newer]): + snapshot = service.snapshot_workflow_run(app_model=_app_model(), workflow_run_id="run-1") + assert snapshot.node_outputs[0].outputs[0].value_preview == "new" + + +# ────────────────────────────────────────────────────────────────────────────── +# Array 
item declarations round-trip correctly +# ────────────────────────────────────────────────────────────────────────────── + + +def test_array_typed_output_with_array_item_renders_correctly(): + service = _make_service( + declared_outputs=[ + DeclaredOutputConfig( + name="files", + type=DeclaredOutputType.ARRAY, + array_item=DeclaredArrayItem(type=DeclaredOutputType.FILE), + ) + ], + ) + run = _workflow_run(nodes=[_agent_v2_node(node_id="agent-1")]) + ex = _execution(node_id="agent-1", outputs={"files": []}) + with _patch_session(workflow_run=run, executions=[ex]): + snapshot = service.snapshot_workflow_run(app_model=_app_model(), workflow_run_id="run-1") + output = snapshot.node_outputs[0].outputs[0] + assert output.type == DeclaredOutputType.ARRAY + + +# ────────────────────────────────────────────────────────────────────────────── +# Graph parsing edge cases +# ────────────────────────────────────────────────────────────────────────────── + + +def test_unparseable_graph_blob_yields_empty_snapshot_not_500(): + service = _make_service() + run = SimpleNamespace( + id="run-1", + workflow_id="workflow-1", + tenant_id="tenant-1", + app_id="app-1", + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status=WorkflowExecutionStatus.RUNNING, + graph="{not valid json", + ) + with _patch_session(workflow_run=run, executions=[]): + snapshot = service.snapshot_workflow_run(app_model=_app_model(), workflow_run_id="run-1") + assert snapshot.node_outputs == [] diff --git a/api/tests/unit_tests/tasks/test_delete_account_task.py b/api/tests/unit_tests/tasks/test_delete_account_task.py deleted file mode 100644 index f949c13158..0000000000 --- a/api/tests/unit_tests/tasks/test_delete_account_task.py +++ /dev/null @@ -1,115 +0,0 @@ -""" -Unit tests for delete_account_task. 
- -Covers: -- Billing enabled with existing account: calls billing and sends success email -- Billing disabled with existing account: skips billing, sends success email -- Account not found: still calls billing when enabled, does not send email -- Billing deletion raises: logs and re-raises, no email -""" - -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -import pytest - -from tasks.delete_account_task import delete_account_task - - -@pytest.fixture -def mock_db_session(): - """Mock session via session_factory.create_session().""" - with patch("tasks.delete_account_task.session_factory") as mock_sf: - session = MagicMock() - cm = MagicMock() - cm.__enter__.return_value = session - cm.__exit__.return_value = None - mock_sf.create_session.return_value = cm - - yield session - - -@pytest.fixture -def mock_deps(): - """Patch external dependencies: BillingService and send_deletion_success_task.""" - with ( - patch("tasks.delete_account_task.BillingService") as mock_billing, - patch("tasks.delete_account_task.send_deletion_success_task") as mock_mail_task, - ): - # ensure .delay exists on the mail task - mock_mail_task.delay = MagicMock() - yield { - "billing": mock_billing, - "mail_task": mock_mail_task, - } - - -def _set_account_found(mock_db_session, email: str = "user@example.com"): - account = SimpleNamespace(email=email) - mock_db_session.scalar.return_value = account - return account - - -def _set_account_missing(mock_db_session): - mock_db_session.scalar.return_value = None - - -class TestDeleteAccountTask: - def test_billing_enabled_account_exists_calls_billing_and_sends_email(self, mock_db_session, mock_deps): - # Arrange - account_id = "acc-123" - account = _set_account_found(mock_db_session, email="a@b.com") - - # Enable billing - with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True): - # Act - delete_account_task(account_id) - - # Assert - 
mock_deps["billing"].delete_account.assert_called_once_with(account_id) - mock_deps["mail_task"].delay.assert_called_once_with(account.email) - - def test_billing_disabled_account_exists_sends_email_only(self, mock_db_session, mock_deps): - # Arrange - account_id = "acc-456" - account = _set_account_found(mock_db_session, email="x@y.com") - - # Disable billing - with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", False): - # Act - delete_account_task(account_id) - - # Assert - mock_deps["billing"].delete_account.assert_not_called() - mock_deps["mail_task"].delay.assert_called_once_with(account.email) - - def test_account_not_found_billing_enabled_calls_billing_no_email(self, mock_db_session, mock_deps, caplog): - # Arrange - account_id = "missing-id" - _set_account_missing(mock_db_session) - - # Enable billing - with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True): - # Act - delete_account_task(account_id) - - # Assert - mock_deps["billing"].delete_account.assert_called_once_with(account_id) - mock_deps["mail_task"].delay.assert_not_called() - # Optional: verify log contains not found message - assert any("not found" in rec.getMessage().lower() for rec in caplog.records) - - def test_billing_delete_raises_propagates_and_no_email(self, mock_db_session, mock_deps): - # Arrange - account_id = "acc-err" - _set_account_found(mock_db_session, email="err@ex.com") - mock_deps["billing"].delete_account.side_effect = RuntimeError("billing down") - - # Enable billing - with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True): - # Act & Assert - with pytest.raises(RuntimeError): - delete_account_task(account_id) - - # Ensure email was not sent - mock_deps["mail_task"].delay.assert_not_called() diff --git a/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py index 37b7a85451..f8f9ec9971 100644 --- 
a/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py @@ -54,7 +54,7 @@ def test_dispatch_human_input_email_task_sends_to_each_recipient(monkeypatch: py monkeypatch.setattr( task_module.FeatureService, "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), + lambda _tenant_id, **_kwargs: SimpleNamespace(human_input_email_delivery_enabled=True), ) jobs: Sequence[task_module._EmailDeliveryJob] = [_build_job(recipient_count=2)] monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: jobs) @@ -78,7 +78,7 @@ def test_dispatch_human_input_email_task_skips_when_feature_disabled(monkeypatch monkeypatch.setattr( task_module.FeatureService, "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False), + lambda _tenant_id, **_kwargs: SimpleNamespace(human_input_email_delivery_enabled=False), ) monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: []) @@ -109,7 +109,7 @@ def test_dispatch_human_input_email_task_replaces_body_variables(monkeypatch: py monkeypatch.setattr( task_module.FeatureService, "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), + lambda _tenant_id, **_kwargs: SimpleNamespace(human_input_email_delivery_enabled=True), ) monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: [job]) monkeypatch.setattr(task_module, "_load_variable_pool", lambda _workflow_run_id: variable_pool) @@ -142,7 +142,7 @@ def test_dispatch_human_input_email_task_sanitizes_subject( monkeypatch.setattr( task_module.FeatureService, "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), + lambda _tenant_id, **_kwargs: SimpleNamespace(human_input_email_delivery_enabled=True), ) monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: [job]) 
monkeypatch.setattr(task_module, "_load_variable_pool", lambda _workflow_run_id: None) diff --git a/api/uv.lock b/api/uv.lock index 5e8792207e..7a54f130a1 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -595,16 +595,16 @@ wheels = [ [[package]] name = "boto3" -version = "1.43.10" +version = "1.43.14" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore" }, { name = "jmespath" }, { name = "s3transfer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ff/27/ae1a71e945ce7bde39b0677b252fe7d8a0ad7fa3d6b724d78b81469c08fe/boto3-1.43.10.tar.gz", hash = "sha256:27342e5d5f6170fcc8d1e21cdd939af2448d58ac56b08d494250eaad998e30c7", size = 113159, upload-time = "2026-05-18T20:42:34.454Z" } +sdist = { url = "https://files.pythonhosted.org/packages/79/4b/616367e871ce3f1cb3e8545a97736b6331b9fb081497f2d44c5b2aa6959d/boto3-1.43.14.tar.gz", hash = "sha256:5c0a994b3182061ee101812e721100717a4d664f9f4ceaf4a86b6d032ce9fc2d", size = 113142, upload-time = "2026-05-22T19:28:47.861Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0e/1b/439234598449f846b17333e67ec63c3dd8f8880c13de9089383b4bab58c3/boto3-1.43.10-py3-none-any.whl", hash = "sha256:83918184d95967e4c6e9ed1e9a2f58250b291e6ea2cb847ab0825d52596b39e5", size = 140534, upload-time = "2026-05-18T20:42:32.009Z" }, + { url = "https://files.pythonhosted.org/packages/cb/00/59cb9329c18e2d3aa23062ceaa87d065f2e81e7d2931df24d64e9a7815aa/boto3-1.43.14-py3-none-any.whl", hash = "sha256:574335744656cfed0b362a0a0467aaf2eb2bf15526edcd02d31d3c661f4b09e4", size = 140536, upload-time = "2026-05-22T19:28:46.49Z" }, ] [[package]] @@ -627,16 +627,16 @@ bedrock-runtime = [ [[package]] name = "botocore" -version = "1.43.10" +version = "1.43.14" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jmespath" }, { name = "python-dateutil" }, { name = "urllib3" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/e2/4e/c127dd0628c551f10cb890e279a9c0e367523b880c4cd3e81a1e76886174/botocore-1.43.10.tar.gz", hash = "sha256:2f4af585b41dbccdfc9f49677d7bd72d713a12ef89a1dc9c8538a927649498bf", size = 15365344, upload-time = "2026-05-18T20:42:21.562Z" } +sdist = { url = "https://files.pythonhosted.org/packages/78/3c/798d2f7deb118241930c7c6bcfb0b970d3f0245bf580700663199aeed2c3/botocore-1.43.14.tar.gz", hash = "sha256:b9e500737e43d2f147c9d4e23b54360335e77d4c0ba90a318f51b65e06cb8516", size = 15382604, upload-time = "2026-05-22T19:28:36.363Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a6/0e/41f64d6c267edf03f4fe8f461edc4c644243e77c8d5a1fef1e0166ac4ed0/botocore-1.43.10-py3-none-any.whl", hash = "sha256:8a0176d8c2f8bebe95d4f923a824a1ace04b02f360e220681c388e097f32c3b6", size = 15043571, upload-time = "2026-05-18T20:42:16.664Z" }, + { url = "https://files.pythonhosted.org/packages/27/7e/6e64821077cd2efc4aa51b7d638fb6d48e1c7c450201c529fbaf1de8bfd3/botocore-1.43.14-py3-none-any.whl", hash = "sha256:1f4a2a95ea78c10398e78431e98c1fe47adb54a7b10a32975144c1f541186658", size = 15061424, upload-time = "2026-05-22T19:28:32.682Z" }, ] [[package]] @@ -1612,7 +1612,7 @@ requires-dist = [ { name = "aliyun-log-python-sdk", specifier = "==0.9.44" }, { name = "azure-identity", specifier = ">=1.25.3,<2.0.0" }, { name = "bleach", specifier = ">=6.3.0,<7.0.0" }, - { name = "boto3", specifier = ">=1.43.10,<2.0.0" }, + { name = "boto3", specifier = ">=1.43.14,<2.0.0" }, { name = "celery", specifier = ">=5.6.3,<6.0.0" }, { name = "croniter", specifier = ">=6.2.2,<7.0.0" }, { name = "dify-agent", directory = "../dify-agent" }, diff --git a/cli/AGENTS.md b/cli/AGENTS.md index 0d579af2c7..96df6f2bdc 100644 --- a/cli/AGENTS.md +++ b/cli/AGENTS.md @@ -47,7 +47,7 @@ Layer rules: - Commands thin shells. Use `this.authedCtx(opts)` for bearer context; delegate to domain function. - Domain receives deps via options; never imports `src/framework/`. 
- Only `src/http/client.ts` and `src/api/*` import ky at runtime; elsewhere use `import type { KyInstance }`. -- `process.*` lives in `src/io/`, `src/config/dir.ts`, `src/util/browser.ts`. Nowhere else. +- `process.*` lives in `src/io/`, `src/store/dir.ts`, `src/util/browser.ts`. Nowhere else. - No circular imports. `types/` pure leaf. ## Dev commands diff --git a/cli/ARD.md b/cli/ARD.md index b8813fe920..de7a4b359f 100644 --- a/cli/ARD.md +++ b/cli/ARD.md @@ -103,7 +103,7 @@ import { ErrorCode } from '../../errors/codes.js' throw new BaseError({ code: ErrorCode.UsageMissingArg, message: 'workspace id required', - hint: 'pass --workspace or run \'difyctl auth use \'', + hint: 'pass --workspace or run \'difyctl use workspace \'', }) ``` diff --git a/cli/package.json b/cli/package.json index 1b10986d7f..59286c2880 100644 --- a/cli/package.json +++ b/cli/package.json @@ -50,6 +50,7 @@ "eventsource-parser": "catalog:", "js-yaml": "catalog:", "ky": "catalog:", + "lockfile": "catalog:", "open": "catalog:", "ora": "catalog:", "picocolors": "catalog:", @@ -60,6 +61,7 @@ "@dify/tsconfig": "workspace:*", "@hono/node-server": "catalog:", "@types/js-yaml": "catalog:", + "@types/lockfile": "catalog:", "@types/node": "catalog:", "@vitest/coverage-v8": "catalog:", "eslint": "catalog:", diff --git a/cli/src/api/account-sessions.ts b/cli/src/api/account-sessions.ts index 102927bf8e..83950c9bde 100644 --- a/cli/src/api/account-sessions.ts +++ b/cli/src/api/account-sessions.ts @@ -8,8 +8,15 @@ export class AccountSessionsClient { this.http = http } - async list(): Promise { - return this.http.get('account/sessions').json() + async list(q?: { page?: number, limit?: number }): Promise { + const params = new URLSearchParams() + if (q?.page !== undefined) + params.set('page', String(q.page)) + if (q?.limit !== undefined) + params.set('limit', String(q.limit)) + const hasParams = Array.from(params.keys()).length > 0 + const opts = hasParams ? 
{ searchParams: params } : undefined + return this.http.get('account/sessions', opts).json() } async revoke(sessionId: string): Promise { diff --git a/cli/src/api/app-meta.test.ts b/cli/src/api/app-meta.test.ts index b1b2cf00ab..1ec9e4b698 100644 --- a/cli/src/api/app-meta.test.ts +++ b/cli/src/api/app-meta.test.ts @@ -6,6 +6,8 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { startMock } from '../../test/fixtures/dify-mock/server.js' import { loadAppInfoCache } from '../cache/app-info.js' import { createClient } from '../http/client.js' +import { CACHE_APP_INFO, cachePath } from '../store/manager.js' +import { YamlStore } from '../store/store.js' import { FieldInfo, FieldParameters } from '../types/app-meta.js' import { AppMetaClient } from './app-meta.js' import { AppsClient } from './apps.js' @@ -23,7 +25,7 @@ describe('AppMetaClient', () => { }) it('cache miss → fetch → populate; warm hit skips network', async () => { - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) const apps = new AppsClient(createClient({ host: mock.url, bearer: 'dfoa_test' })) const spy = vi.spyOn(apps, 'describe') const client = new AppMetaClient({ apps, host: mock.url, cache }) @@ -38,7 +40,7 @@ describe('AppMetaClient', () => { }) it('slim hit + full request triggers fresh fetch + merges', async () => { - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) const apps = new AppsClient(createClient({ host: mock.url, bearer: 'dfoa_test' })) const spy = vi.spyOn(apps, 'describe') const client = new AppMetaClient({ apps, host: mock.url, cache }) @@ -52,7 +54,7 @@ describe('AppMetaClient', () => { }) it('expired cache entry refetches', async () => { - const cache = await loadAppInfoCache({ configDir: dir, ttlMs: 100, now: () => new 
Date('2026-05-09T00:00:00Z') }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)), ttlMs: 100, now: () => new Date('2026-05-09T00:00:00Z') }) const apps = new AppsClient(createClient({ host: mock.url, bearer: 'dfoa_test' })) const spy = vi.spyOn(apps, 'describe') const client = new AppMetaClient({ apps, host: mock.url, cache, now: () => new Date('2026-05-09T00:00:00Z') }) @@ -66,7 +68,7 @@ describe('AppMetaClient', () => { }) it('invalidate forces next get to fetch', async () => { - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) const apps = new AppsClient(createClient({ host: mock.url, bearer: 'dfoa_test' })) const spy = vi.spyOn(apps, 'describe') const client = new AppMetaClient({ apps, host: mock.url, cache }) diff --git a/cli/src/api/members.test.ts b/cli/src/api/members.test.ts new file mode 100644 index 0000000000..5442e296c3 --- /dev/null +++ b/cli/src/api/members.test.ts @@ -0,0 +1,280 @@ +import type { AddressInfo } from 'node:net' +import { Buffer } from 'node:buffer' +import * as http from 'node:http' +import { afterEach, describe, expect, it } from 'vitest' +import { isBaseError } from '../errors/base.js' +import { createClient } from '../http/client.js' +import { MembersClient } from './members.js' + +type StubServer = { + url: string + lastRequest: { method?: string, url?: string, body?: string } + stop: () => Promise +} + +function jsonResponder( + status: number, + body: unknown, + captured: StubServer['lastRequest'], +): http.RequestListener { + return (req, res) => { + captured.method = req.method + captured.url = req.url + const chunks: Buffer[] = [] + req.on('data', c => chunks.push(c)) + req.on('end', () => { + captured.body = Buffer.concat(chunks).toString('utf8') + const payload = JSON.stringify(body) + res.writeHead(status, { + 'content-type': 'application/json', + 'content-length': 
Buffer.byteLength(payload), + }) + res.end(payload) + }) + } +} + +function startServer(handler: http.RequestListener): Promise { + const captured: StubServer['lastRequest'] = {} + return new Promise((resolve, reject) => { + const server = http.createServer((req, res) => handler(req, res)) + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo + resolve({ + url: `http://127.0.0.1:${addr.port}`, + lastRequest: captured, + stop: () => + new Promise((res, rej) => server.close(err => (err ? rej(err) : res()))), + }) + }) + server.on('error', reject) + }) +} + +function makeClient(host: string): MembersClient { + return new MembersClient(createClient({ host, bearer: 'dfoa_test' })) +} + +describe('MembersClient.list', () => { + let stub: StubServer + + afterEach(async () => { + await stub?.stop() + }) + + it('GETs /workspaces//members and returns parsed envelope', async () => { + const captured: StubServer['lastRequest'] = {} + stub = await startServer( + jsonResponder( + 200, + { + page: 1, + limit: 20, + total: 1, + has_more: false, + data: [ + { id: 'm-1', name: 'Mia', email: 'mia@e.com', role: 'admin', status: 'active' }, + ], + }, + captured, + ), + ) + stub.lastRequest = captured + + const result = await makeClient(stub.url).list('ws-1') + expect(captured.method).toBe('GET') + expect(captured.url).toBe('/openapi/v1/workspaces/ws-1/members') + expect(result.data[0].email).toBe('mia@e.com') + }) + + it('URL-encodes workspace id', async () => { + const captured: StubServer['lastRequest'] = {} + stub = await startServer( + jsonResponder(200, { page: 1, limit: 20, total: 0, has_more: false, data: [] }, captured), + ) + stub.lastRequest = captured + + await makeClient(stub.url).list('ws with space') + expect(captured.url).toBe('/openapi/v1/workspaces/ws%20with%20space/members') + }) + + it('forwards page/limit as query params', async () => { + const captured: StubServer['lastRequest'] = {} + stub = await startServer( + jsonResponder(200, 
{ page: 2, limit: 50, total: 0, has_more: false, data: [] }, captured), + ) + stub.lastRequest = captured + + await makeClient(stub.url).list('ws-1', { page: 2, limit: 50 }) + expect(captured.url).toBe('/openapi/v1/workspaces/ws-1/members?page=2&limit=50') + }) + + it('propagates server 403 as HTTPError', async () => { + const captured: StubServer['lastRequest'] = {} + stub = await startServer(jsonResponder(403, { error: 'forbidden' }, captured)) + + await expect(makeClient(stub.url).list('ws-1')).rejects.toSatisfy( + err => isBaseError(err) && err.httpStatus === 403, + ) + }) + + it('propagates 404 as classified BaseError', async () => { + const captured: StubServer['lastRequest'] = {} + stub = await startServer(jsonResponder(404, { error: 'not found' }, captured)) + + await expect(makeClient(stub.url).list('ws-missing')).rejects.toSatisfy( + err => isBaseError(err) && err.httpStatus === 404, + ) + }) +}) + +describe('MembersClient.invite', () => { + let stub: StubServer + + afterEach(async () => { + await stub?.stop() + }) + + it('POSTs JSON body and returns parsed invite response', async () => { + const captured: StubServer['lastRequest'] = {} + stub = await startServer( + jsonResponder( + 201, + { + result: 'success', + email: 'new@e.com', + role: 'normal', + member_id: 'acct-9', + invite_url: 'https://console.example.com/activate?email=new&token=tok', + tenant_id: 'ws-1', + }, + captured, + ), + ) + stub.lastRequest = captured + + const result = await makeClient(stub.url).invite('ws-1', { + email: 'new@e.com', + role: 'normal', + }) + expect(captured.method).toBe('POST') + expect(captured.url).toBe('/openapi/v1/workspaces/ws-1/members') + expect(JSON.parse(captured.body ?? 
'{}')).toEqual({ + email: 'new@e.com', + role: 'normal', + }) + expect(result.member_id).toBe('acct-9') + expect(result.invite_url).toContain('token=tok') + }) + + it('propagates 400 (already in tenant) as HTTPError', async () => { + const captured: StubServer['lastRequest'] = {} + stub = await startServer(jsonResponder(400, { error: 'already in tenant' }, captured)) + + await expect( + makeClient(stub.url).invite('ws-1', { email: 'u@e.com', role: 'normal' }), + ).rejects.toSatisfy(err => isBaseError(err) && err.httpStatus === 400) + }) +}) + +describe('MembersClient.remove', () => { + let stub: StubServer + + afterEach(async () => { + await stub?.stop() + }) + + it('DELETEs member by id and returns success', async () => { + const captured: StubServer['lastRequest'] = {} + stub = await startServer(jsonResponder(200, { result: 'success' }, captured)) + stub.lastRequest = captured + + const result = await makeClient(stub.url).remove('ws-1', 'm-1') + expect(captured.method).toBe('DELETE') + expect(captured.url).toBe('/openapi/v1/workspaces/ws-1/members/m-1') + expect(result.result).toBe('success') + }) + + it('propagates 400 (cannot operate self / cannot remove owner)', async () => { + const captured: StubServer['lastRequest'] = {} + stub = await startServer(jsonResponder(400, { error: 'cannot operate self' }, captured)) + + await expect(makeClient(stub.url).remove('ws-1', 'm-1')).rejects.toSatisfy( + err => isBaseError(err) && err.httpStatus === 400, + ) + }) +}) + +describe('MembersClient.updateRole', () => { + let stub: StubServer + + afterEach(async () => { + await stub?.stop() + }) + + it('PUTs role payload to /role subresource', async () => { + const captured: StubServer['lastRequest'] = {} + stub = await startServer(jsonResponder(200, { result: 'success' }, captured)) + stub.lastRequest = captured + + const result = await makeClient(stub.url).updateRole('ws-1', 'm-1', { role: 'admin' }) + expect(captured.method).toBe('PUT') + 
expect(captured.url).toBe('/openapi/v1/workspaces/ws-1/members/m-1/role') + expect(JSON.parse(captured.body ?? '{}')).toEqual({ role: 'admin' }) + expect(result.result).toBe('success') + }) + + it('propagates 400 (admin cannot demote owner)', async () => { + const captured: StubServer['lastRequest'] = {} + stub = await startServer(jsonResponder(400, { error: 'no permission' }, captured)) + + await expect( + makeClient(stub.url).updateRole('ws-1', 'm-1', { role: 'admin' }), + ).rejects.toSatisfy(err => isBaseError(err) && err.httpStatus === 400) + }) +}) + +describe('WorkspacesClient.switch (integration with stub)', () => { + let stub: StubServer + + afterEach(async () => { + await stub?.stop() + }) + + it('POSTs /workspaces//switch and returns workspace detail', async () => { + const captured: StubServer['lastRequest'] = {} + stub = await startServer( + jsonResponder( + 200, + { + id: 'ws-1', + name: 'Workspace 1', + role: 'owner', + status: 'normal', + current: true, + created_at: '2026-05-18T00:00:00Z', + }, + captured, + ), + ) + stub.lastRequest = captured + + const { WorkspacesClient } = await import('./workspaces.js') + const client = new WorkspacesClient(createClient({ host: stub.url, bearer: 'dfoa_test' })) + const result = await client.switch('ws-1') + expect(captured.method).toBe('POST') + expect(captured.url).toBe('/openapi/v1/workspaces/ws-1/switch') + expect(result.current).toBe(true) + }) + + it('propagates 404 (non-member)', async () => { + const captured: StubServer['lastRequest'] = {} + stub = await startServer(jsonResponder(404, { error: 'not found' }, captured)) + + const { WorkspacesClient } = await import('./workspaces.js') + const client = new WorkspacesClient(createClient({ host: stub.url, bearer: 'dfoa_test' })) + await expect(client.switch('ws-x')).rejects.toSatisfy( + err => isBaseError(err) && err.httpStatus === 404, + ) + }) +}) diff --git a/cli/src/api/members.ts b/cli/src/api/members.ts new file mode 100644 index 0000000000..1152ae1474 
--- /dev/null +++ b/cli/src/api/members.ts @@ -0,0 +1,61 @@ +import type { + MemberActionResponse, + MemberInvitePayload, + MemberInviteResponse, + MemberListResponse, + MemberRoleUpdatePayload, +} from '@dify/contracts/api/openapi/types.gen' +import type { KyInstance } from 'ky' + +/** + * Thin client for /openapi/v1/workspaces//members. + * + * Errors are surfaced as ky HTTPErrors with the server's status code + * (400/403/404/422). The CLI's AuthedCommand base layer maps those to + * user-visible messages — clients never swallow status codes here. + */ +export class MembersClient { + private readonly http: KyInstance + + constructor(http: KyInstance) { + this.http = http + } + + async list(workspaceId: string, q?: { page?: number, limit?: number }): Promise { + const params = new URLSearchParams() + if (q?.page !== undefined) + params.set('page', String(q.page)) + if (q?.limit !== undefined) + params.set('limit', String(q.limit)) + const hasParams = Array.from(params.keys()).length > 0 + const opts = hasParams ? 
{ searchParams: params } : undefined + return this.http + .get(`workspaces/${encodeURIComponent(workspaceId)}/members`, opts) + .json() + } + + async invite(workspaceId: string, payload: MemberInvitePayload): Promise { + return this.http + .post(`workspaces/${encodeURIComponent(workspaceId)}/members`, { json: payload }) + .json() + } + + async remove(workspaceId: string, memberId: string): Promise { + return this.http + .delete(`workspaces/${encodeURIComponent(workspaceId)}/members/${encodeURIComponent(memberId)}`) + .json() + } + + async updateRole( + workspaceId: string, + memberId: string, + payload: MemberRoleUpdatePayload, + ): Promise { + return this.http + .put( + `workspaces/${encodeURIComponent(workspaceId)}/members/${encodeURIComponent(memberId)}/role`, + { json: payload }, + ) + .json() + } +} diff --git a/cli/src/api/workspaces.ts b/cli/src/api/workspaces.ts index a3feac23d0..f497ae25db 100644 --- a/cli/src/api/workspaces.ts +++ b/cli/src/api/workspaces.ts @@ -1,4 +1,4 @@ -import type { WorkspaceListResponse } from '@dify/contracts/api/openapi/types.gen' +import type { WorkspaceDetailResponse, WorkspaceListResponse } from '@dify/contracts/api/openapi/types.gen' import type { KyInstance } from 'ky' export class WorkspacesClient { @@ -11,4 +11,19 @@ export class WorkspacesClient { async list(): Promise { return this.http.get('workspaces').json() } + + /** + * Server-side workspace switch via OpenAPI POST + * `/workspaces/{id}/switch` — the bearer-authed equivalent of the + * console's POST `/workspaces/switch`. The server updates the caller's + * `current` tenant_account_join row. Callers MUST refresh their local + * `hosts.yml` only after this resolves — never fall back to a local + * write if the request fails, or `hosts.yml` will drift from the + * server's state. 
+ */ + async switch(workspaceId: string): Promise { + return this.http + .post(`workspaces/${encodeURIComponent(workspaceId)}/switch`) + .json() + } } diff --git a/cli/src/auth/file-backend.test.ts b/cli/src/auth/file-backend.test.ts index 65ee66f6a9..e633d1d724 100644 --- a/cli/src/auth/file-backend.test.ts +++ b/cli/src/auth/file-backend.test.ts @@ -2,7 +2,7 @@ import { mkdtemp, rm, stat, writeFile } from 'node:fs/promises' import { tmpdir } from 'node:os' import { join } from 'node:path' import { afterEach, beforeEach, describe, expect, it } from 'vitest' -import { FILE_PERM } from '../config/dir.js' +import { FILE_PERM } from '../store/dir.js' import { FileBackend, TOKENS_FILE_NAME } from './file-backend.js' describe('FileBackend', () => { diff --git a/cli/src/auth/file-backend.ts b/cli/src/auth/file-backend.ts index 49bf4d44ed..0f8c2280c9 100644 --- a/cli/src/auth/file-backend.ts +++ b/cli/src/auth/file-backend.ts @@ -2,7 +2,7 @@ import type { TokenStore } from './store.js' import { mkdir, readFile, rename, stat, unlink, writeFile } from 'node:fs/promises' import { join } from 'node:path' import yaml from 'js-yaml' -import { DIR_PERM, FILE_PERM } from '../config/dir.js' +import { DIR_PERM, FILE_PERM } from '../store/dir.js' export const TOKENS_FILE_NAME = 'tokens.yml' diff --git a/cli/src/auth/hosts.test.ts b/cli/src/auth/hosts.test.ts index 2bc1b2fea9..9f1c50fb25 100644 --- a/cli/src/auth/hosts.test.ts +++ b/cli/src/auth/hosts.test.ts @@ -2,7 +2,7 @@ import { mkdtemp, readFile, rm, stat, writeFile } from 'node:fs/promises' import { tmpdir } from 'node:os' import { join } from 'node:path' import { afterEach, beforeEach, describe, expect, it } from 'vitest' -import { FILE_PERM } from '../config/dir.js' +import { FILE_PERM } from '../store/dir.js' import { HOSTS_FILE_NAME, HostsBundleSchema, loadHosts, saveHosts } from './hosts.js' describe('HostsBundleSchema', () => { diff --git a/cli/src/auth/hosts.ts b/cli/src/auth/hosts.ts index fc90b3238c..f6504dd06c 100644 
--- a/cli/src/auth/hosts.ts +++ b/cli/src/auth/hosts.ts @@ -2,7 +2,7 @@ import { mkdir, readFile, rename, unlink, writeFile } from 'node:fs/promises' import { join } from 'node:path' import yaml from 'js-yaml' import { z } from 'zod' -import { DIR_PERM, FILE_PERM } from '../config/dir.js' +import { DIR_PERM, FILE_PERM } from '../store/dir.js' export const HOSTS_FILE_NAME = 'hosts.yml' diff --git a/cli/src/cache/app-info.test.ts b/cli/src/cache/app-info.test.ts index c562519790..6fcf53cc2f 100644 --- a/cli/src/cache/app-info.test.ts +++ b/cli/src/cache/app-info.test.ts @@ -2,9 +2,17 @@ import type { AppMeta } from '../types/app-meta.js' import { mkdtemp, readFile, rm } from 'node:fs/promises' import { tmpdir } from 'node:os' import { join } from 'node:path' +import yaml from 'js-yaml' import { afterEach, beforeEach, describe, expect, it } from 'vitest' +import { CACHE_APP_INFO, cachePath } from '../store/manager.js' +import { YamlStore } from '../store/store.js' +import { platform } from '../sys/index.js' import { FieldInfo, FieldParameters } from '../types/app-meta.js' -import { APP_INFO_TTL_MS, cachePath, loadAppInfoCache } from './app-info.js' +import { APP_INFO_TTL_MS, loadAppInfoCache } from './app-info.js' + +function appInfoPath(dir: string): string { + return cachePath(dir, CACHE_APP_INFO) +} function metaInfoOnly(): AppMeta { return { @@ -35,10 +43,10 @@ describe('app-info disk cache', () => { }) it('round-trips an entry across reloads', async () => { - const c1 = await loadAppInfoCache({ configDir: dir }) + const c1 = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await c1.set('http://localhost:9999', 'app-1', metaInfoOnly()) - const c2 = await loadAppInfoCache({ configDir: dir }) + const c2 = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) const got = c2.get('http://localhost:9999', 'app-1') expect(got).toBeDefined() expect(got?.meta.info?.id).toBe('app-1') @@ -47,7 +55,7 @@ 
describe('app-info disk cache', () => { it('isFresh respects TTL', async () => { const now = new Date('2026-05-09T00:00:00Z') - const c = await loadAppInfoCache({ configDir: dir, now: () => now }) + const c = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)), now: () => now }) await c.set('h', 'app-1', metaInfoOnly()) const r = c.get('h', 'app-1') expect(r).toBeDefined() @@ -58,45 +66,44 @@ describe('app-info disk cache', () => { }) it('keys by (host, app_id) — different hosts isolate', async () => { - const c = await loadAppInfoCache({ configDir: dir }) + const c = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await c.set('h1', 'app-1', metaInfoOnly()) expect(c.get('h2', 'app-1')).toBeUndefined() expect(c.get('h1', 'app-1')).toBeDefined() }) it('delete removes entry from disk', async () => { - const c1 = await loadAppInfoCache({ configDir: dir }) + const c1 = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await c1.set('h', 'app-1', metaInfoOnly()) await c1.delete('h', 'app-1') - const c2 = await loadAppInfoCache({ configDir: dir }) + const c2 = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) expect(c2.get('h', 'app-1')).toBeUndefined() }) it('writes file with 0600 permission', async () => { - const c = await loadAppInfoCache({ configDir: dir }) + const c = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await c.set('h', 'app-1', metaInfoOnly()) const { stat } = await import('node:fs/promises') - const s = await stat(cachePath(dir)) - if (process.platform !== 'win32') + const s = await stat(appInfoPath(dir)) + if (platform() !== 'win32') expect(s.mode & 0o777).toBe(0o600) }) it('missing cache file is not an error', async () => { - const c = await loadAppInfoCache({ configDir: dir }) + const c = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) expect(c.get('h', 
'app-1')).toBeUndefined() }) it('corrupt cache file is treated as empty', async () => { - const { mkdir, writeFile } = await import('node:fs/promises') - await mkdir(join(dir, 'cache'), { recursive: true }) - await writeFile(cachePath(dir), '{not json', 'utf8') - const c = await loadAppInfoCache({ configDir: dir }) + const { writeFile } = await import('node:fs/promises') + await writeFile(appInfoPath(dir), ': : not valid yaml', 'utf8') + const c = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) expect(c.get('h', 'app-1')).toBeUndefined() }) it('updates same key in place (no growth)', async () => { - const c = await loadAppInfoCache({ configDir: dir }) + const c = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await c.set('h', 'app-1', metaInfoOnly()) const slim: AppMeta = { ...metaInfoOnly(), @@ -104,8 +111,8 @@ describe('app-info disk cache', () => { parameters: { opening_statement: 'hi' }, } await c.set('h', 'app-1', slim) - const raw = await readFile(cachePath(dir), 'utf8') - const parsed = JSON.parse(raw) as { entries: Record } + const raw = await readFile(appInfoPath(dir), 'utf8') + const parsed = yaml.load(raw) as { entries: Record } expect(Object.keys(parsed.entries)).toHaveLength(1) }) }) diff --git a/cli/src/cache/app-info.ts b/cli/src/cache/app-info.ts index e6aef5a168..5d8ea27642 100644 --- a/cli/src/cache/app-info.ts +++ b/cli/src/cache/app-info.ts @@ -1,15 +1,14 @@ +import type { Store } from '../store/store.js' import type { AppMeta, AppMetaCacheRecord, AppMetaFieldKey } from '../types/app-meta.js' -import { mkdir, readFile, rename, writeFile } from 'node:fs/promises' -import { dirname, join } from 'node:path' -import { DIR_PERM, FILE_PERM } from '../config/dir.js' +import { CACHE_APP_INFO, getCache } from '../store/manager.js' import { FieldInfo, FieldInputSchema, FieldParameters } from '../types/app-meta.js' -const CACHE_FILE = 'app-info.json' export const APP_INFO_TTL_MS = 60 * 60 * 
1000 -type DiskShape = { - entries: Record -} +// All entries live under one top-level key; the inner record uses +// `host::appId` composites that contain `::` (never `.`), so they're +// safe as map keys without colliding with Store's dot-path semantics. +const ENTRIES_KEY = { key: 'entries', default: {} as Record } as const type DiskEntry = { meta: SerializedMeta @@ -35,26 +34,25 @@ type State = { } export type AppInfoCacheOptions = { - readonly configDir: string + readonly store?: Store readonly ttlMs?: number readonly now?: () => Date } -export async function loadAppInfoCache(opts: AppInfoCacheOptions): Promise { - const path = cachePath(opts.configDir) +export async function loadAppInfoCache(opts: AppInfoCacheOptions = {}): Promise { + const store = opts.store ?? getCache(CACHE_APP_INFO) const ttlMs = opts.ttlMs ?? APP_INFO_TTL_MS - const state: State = { entries: new Map() } - await readDisk(path, state) + const state: State = { entries: readEntries(store) } return { get: (host, appId) => state.entries.get(key(host, appId)), set: async (host, appId, meta) => { const record: AppMetaCacheRecord = { meta, fetchedAt: (opts.now ?? (() => new Date()))().toISOString() } state.entries.set(key(host, appId), record) - await persist(path, state) + writeEntries(store, state.entries) }, delete: async (host, appId) => { state.entries.delete(key(host, appId)) - await persist(path, state) + writeEntries(store, state.entries) }, isFresh: (record, now) => { const t = (now ?? 
new Date()).getTime() - new Date(record.fetchedAt).getTime() @@ -63,36 +61,22 @@ export async function loadAppInfoCache(opts: AppInfoCacheOptions): Promise { - let raw: string +function readEntries(store: Store): Map { + const out = new Map() + let raw: Record try { - raw = await readFile(path, 'utf8') - } - catch (err) { - if ((err as NodeJS.ErrnoException).code === 'ENOENT') - return - throw err - } - let parsed: DiskShape - try { - parsed = JSON.parse(raw) as DiskShape + raw = store.get(ENTRIES_KEY) } catch { - return - } - if (parsed.entries === undefined) - return - for (const [k, e] of Object.entries(parsed.entries)) { - state.entries.set(k, deserialize(e)) + return out } + for (const [k, e] of Object.entries(raw)) + out.set(k, deserialize(e)) + return out } function deserialize(e: DiskEntry): AppMetaCacheRecord { @@ -127,12 +111,8 @@ function serialize(record: AppMetaCacheRecord): DiskEntry { } } -async function persist(path: string, state: State): Promise { - const dir = dirname(path) - await mkdir(dir, { recursive: true, mode: DIR_PERM }) - const disk: DiskShape = { entries: {} } - for (const [k, v] of state.entries) disk.entries[k] = serialize(v) - const tmp = `${path}.${process.pid}.${Date.now()}.tmp` - await writeFile(tmp, JSON.stringify(disk), { mode: FILE_PERM }) - await rename(tmp, path) +function writeEntries(store: Store, entries: Map): void { + const out: Record = {} + for (const [k, v] of entries) out[k] = serialize(v) + store.set(ENTRIES_KEY, out) } diff --git a/cli/src/cache/nudge-store.test.ts b/cli/src/cache/nudge-store.test.ts index 90068a1821..974b094620 100644 --- a/cli/src/cache/nudge-store.test.ts +++ b/cli/src/cache/nudge-store.test.ts @@ -1,8 +1,15 @@ import { mkdtemp, readFile, rm, writeFile } from 'node:fs/promises' import { tmpdir } from 'node:os' import { dirname, join } from 'node:path' +import yaml from 'js-yaml' import { afterEach, beforeEach, describe, expect, it } from 'vitest' -import { loadNudgeStore, nudgeStorePath, 
WARN_INTERVAL_MS } from './nudge-store.js' +import { CACHE_NUDGE, cachePath } from '../store/manager.js' +import { YamlStore } from '../store/store.js' +import { loadNudgeStore, WARN_INTERVAL_MS } from './nudge-store.js' + +function nudgeStorePath(dir: string): string { + return cachePath(dir, CACHE_NUDGE) +} const HOST = 'https://cloud.dify.ai' @@ -16,13 +23,13 @@ describe('NudgeStore', () => { }) it('canWarn=true when no prior record exists', async () => { - const store = await loadNudgeStore({ configDir: dir }) + const store = await loadNudgeStore({ store: new YamlStore(cachePath(dir, CACHE_NUDGE)) }) expect(store.canWarn(HOST)).toBe(true) }) it('canWarn=false within the silence window, true past it', async () => { const t0 = new Date('2026-05-19T12:00:00.000Z') - const store = await loadNudgeStore({ configDir: dir, now: () => t0 }) + const store = await loadNudgeStore({ store: new YamlStore(cachePath(dir, CACHE_NUDGE)), now: () => t0 }) await store.markWarned(HOST) expect(store.canWarn(HOST, new Date('2026-05-19T18:00:00.000Z'))).toBe(false) expect(store.canWarn(HOST, new Date('2026-05-20T12:00:00.000Z'))).toBe(true) @@ -30,7 +37,7 @@ describe('NudgeStore', () => { it('canWarn clamps negative elapsed under clock skew (treats as still in window)', async () => { const t0 = new Date('2026-05-19T12:00:00.000Z') - const store = await loadNudgeStore({ configDir: dir, now: () => t0 }) + const store = await loadNudgeStore({ store: new YamlStore(cachePath(dir, CACHE_NUDGE)), now: () => t0 }) await store.markWarned(HOST) const pastClock = new Date('2026-05-19T11:00:00.000Z') // clock moved backwards 1h expect(store.canWarn(HOST, pastClock)).toBe(false) @@ -38,33 +45,25 @@ describe('NudgeStore', () => { it('markWarned persists across store reloads', async () => { const t0 = new Date('2026-05-19T12:00:00.000Z') - const s1 = await loadNudgeStore({ configDir: dir, now: () => t0 }) + const s1 = await loadNudgeStore({ store: new YamlStore(cachePath(dir, CACHE_NUDGE)), now: () 
=> t0 }) await s1.markWarned(HOST) - const s2 = await loadNudgeStore({ configDir: dir, now: () => t0 }) + const s2 = await loadNudgeStore({ store: new YamlStore(cachePath(dir, CACHE_NUDGE)), now: () => t0 }) expect(s2.canWarn(HOST)).toBe(false) }) it('treats a corrupt cache file as empty', async () => { const path = nudgeStorePath(dir) await writeCacheFile(path, '{ not valid json') - const store = await loadNudgeStore({ configDir: dir }) + const store = await loadNudgeStore({ store: new YamlStore(cachePath(dir, CACHE_NUDGE)) }) expect(store.canWarn(HOST)).toBe(true) }) - it('ignores file with mismatched schema', async () => { - const path = nudgeStorePath(dir) - await writeCacheFile(path, JSON.stringify({ schema: 99, warned: { [HOST]: '2026-05-19T12:00:00.000Z' } })) - const store = await loadNudgeStore({ configDir: dir }) - expect(store.canWarn(HOST)).toBe(true) - }) - - it('writes ISO timestamps under schema:1/warned on disk', async () => { + it('writes ISO timestamps under warned/ on disk', async () => { const t = new Date('2026-05-19T12:00:00.000Z') - const store = await loadNudgeStore({ configDir: dir, now: () => t }) + const store = await loadNudgeStore({ store: new YamlStore(cachePath(dir, CACHE_NUDGE)), now: () => t }) await store.markWarned(HOST) const raw = await readFile(nudgeStorePath(dir), 'utf8') - const parsed = JSON.parse(raw) as Record - expect(parsed.schema).toBe(1) + const parsed = yaml.load(raw) as Record expect((parsed.warned as Record)[HOST]).toBe(t.toISOString()) }) @@ -73,11 +72,11 @@ describe('NudgeStore', () => { // warns about a different host. Without merge-on-write the second writer // would clobber the first. 
const t = new Date('2026-05-19T12:00:00.000Z') - const a = await loadNudgeStore({ configDir: dir, now: () => t }) - const b = await loadNudgeStore({ configDir: dir, now: () => t }) + const a = await loadNudgeStore({ store: new YamlStore(cachePath(dir, CACHE_NUDGE)), now: () => t }) + const b = await loadNudgeStore({ store: new YamlStore(cachePath(dir, CACHE_NUDGE)), now: () => t }) await a.markWarned('https://a.example') await b.markWarned('https://b.example') - const reread = await loadNudgeStore({ configDir: dir, now: () => t }) + const reread = await loadNudgeStore({ store: new YamlStore(cachePath(dir, CACHE_NUDGE)), now: () => t }) expect(reread.canWarn('https://a.example')).toBe(false) expect(reread.canWarn('https://b.example')).toBe(false) }) diff --git a/cli/src/cache/nudge-store.ts b/cli/src/cache/nudge-store.ts index 2a0d0ab994..e16c4a435f 100644 --- a/cli/src/cache/nudge-store.ts +++ b/cli/src/cache/nudge-store.ts @@ -1,37 +1,29 @@ -import { randomUUID } from 'node:crypto' -import { mkdir, readFile, rename, writeFile } from 'node:fs/promises' -import { dirname, join } from 'node:path' -import { DIR_PERM, FILE_PERM } from '../config/dir.js' +import type { Store } from '../store/store.js' +import { CACHE_NUDGE, getCache } from '../store/manager.js' -const CACHE_FILE = 'nudge.json' -const DISK_SCHEMA = 1 export const WARN_INTERVAL_MS = 24 * 60 * 60 * 1000 +// Single top-level key holding host→ISO map. Hosts contain dots +// (cloud.dify.ai), so we cannot use them as Store paths directly — +// `doSet` would split on dots and create nested objects. 
+const WARNED_KEY = { key: 'warned', default: {} as Record } as const + export type NudgeStore = { readonly canWarn: (host: string, now?: Date) => boolean readonly markWarned: (host: string, now?: Date) => Promise } export type NudgeStoreOptions = { - readonly configDir: string + readonly store?: Store readonly now?: () => Date readonly intervalMs?: number } -type DiskShape = { - schema?: number - warned?: Record -} - -export function nudgeStorePath(configDir: string): string { - return join(configDir, 'cache', CACHE_FILE) -} - -export async function loadNudgeStore(opts: NudgeStoreOptions): Promise { - const path = nudgeStorePath(opts.configDir) +export async function loadNudgeStore(opts: NudgeStoreOptions = {}): Promise { + const store = opts.store ?? getCache(CACHE_NUDGE) const intervalMs = opts.intervalMs ?? WARN_INTERVAL_MS const clock = opts.now ?? (() => new Date()) - const memory = await readDisk(path) + const memory = readWarned(store) return { canWarn: (host, now) => { @@ -47,34 +39,23 @@ export async function loadNudgeStore(opts: NudgeStoreOptions): Promise> { +function readWarned(store: Store): Map { const out = new Map() - let raw: string + let raw: Record try { - raw = await readFile(path, 'utf8') - } - catch (err) { - if ((err as NodeJS.ErrnoException).code === 'ENOENT') - return out - throw err - } - let parsed: DiskShape - try { - parsed = JSON.parse(raw) as DiskShape + raw = store.get(WARNED_KEY) } catch { return out } - if (parsed.schema !== DISK_SCHEMA || parsed.warned === undefined) - return out - for (const [host, iso] of Object.entries(parsed.warned)) { + for (const [host, iso] of Object.entries(raw)) { const t = Date.parse(iso) if (!Number.isNaN(t)) out.set(host, t) @@ -82,15 +63,9 @@ async function readDisk(path: string): Promise> { return out } -async function persist(path: string, state: Map): Promise { - const dir = dirname(path) - await mkdir(dir, { recursive: true, mode: DIR_PERM }) - const disk: DiskShape = { schema: DISK_SCHEMA, 
warned: {} } +function writeWarned(store: Store, state: Map): void { + const warned: Record = {} for (const [host, t] of state) - disk.warned![host] = new Date(t).toISOString() - // randomUUID is collision-proof even when two writers stamp the same - // millisecond — pid+timestamp alone can still collide under tight loops. - const tmp = `${path}.${randomUUID()}.tmp` - await writeFile(tmp, JSON.stringify(disk), { mode: FILE_PERM }) - await rename(tmp, path) + warned[host] = new Date(t).toISOString() + store.set(WARNED_KEY, warned) } diff --git a/cli/src/commands/_shared/authed-command.ts b/cli/src/commands/_shared/authed-command.ts index 97d8648f3f..67c1378f24 100644 --- a/cli/src/commands/_shared/authed-command.ts +++ b/cli/src/commands/_shared/authed-command.ts @@ -2,17 +2,18 @@ import type { KyInstance } from 'ky' import type { HostsBundle } from '../../auth/hosts.js' import type { AppInfoCache } from '../../cache/app-info.js' import type { Command } from '../../framework/command.js' -import type { IOStreams } from '../../io/streams.js' +import type { IOStreams } from '../../sys/io/streams' import { META_PROBE_TIMEOUT_MS, MetaClient } from '../../api/meta.js' import { loadHosts } from '../../auth/hosts.js' import { loadAppInfoCache } from '../../cache/app-info.js' import { loadNudgeStore } from '../../cache/nudge-store.js' -import { resolveConfigDir } from '../../config/dir.js' +import { getEnv } from '../../env/registry.js' import { BaseError } from '../../errors/base.js' import { ErrorCode } from '../../errors/codes.js' import { formatErrorForCli } from '../../errors/format.js' import { createClient } from '../../http/client.js' -import { realStreams } from '../../io/streams.js' +import { resolveConfigDir } from '../../store/dir.js' +import { realStreams } from '../../sys/io/streams' import { hostWithScheme } from '../../util/host.js' import { versionInfo } from '../../version/info.js' import { maybeNudgeCompat } from '../../version/nudge.js' @@ -38,6 +39,7 @@ 
export async function buildAuthedContext( opts: AuthedContextOptions, ): Promise { const configDir = resolveConfigDir() + const io = realStreams(opts.format ?? '') const bundle = await loadHosts(configDir) if (bundle === undefined || bundle.tokens?.bearer === undefined || bundle.tokens.bearer === '') { const err = new BaseError({ @@ -45,20 +47,19 @@ export async function buildAuthedContext( message: 'not logged in', hint: 'run \'difyctl auth login\'', }) - cmd.error(formatErrorForCli(err, { format: opts.format, isErrTTY: process.stderr.isTTY }), { exit: err.exit() }) + cmd.error(formatErrorForCli(err, { format: opts.format, isErrTTY: io.isErrTTY }), { exit: err.exit() }) } const host = hostWithScheme(bundle.current_host, bundle.scheme) const retryAttempts = resolveRetryAttempts({ flag: opts.retryFlag, - env: (k: string) => process.env[k], + env: getEnv, }) const http = createClient({ host, bearer: bundle.tokens.bearer, retryAttempts }) - const io = realStreams(opts.format ?? '') - const cache = opts.withCache === true ? await loadAppInfoCache({ configDir }) : undefined + const cache = opts.withCache === true ? await loadAppInfoCache() : undefined - await runCompatNudge({ configDir, host, io }) + await runCompatNudge({ host, io }) return { bundle, http, host, io, configDir, cache } } @@ -66,12 +67,11 @@ export async function buildAuthedContext( // Best-effort nudge: never throws, never blocks. Lives here so every authed // command flows through it without per-command wiring. 
async function runCompatNudge(opts: { - readonly configDir: string readonly host: string readonly io: IOStreams }): Promise { try { - const store = await loadNudgeStore({ configDir: opts.configDir }) + const store = await loadNudgeStore() await maybeNudgeCompat(opts.host, { store, probe: async (host) => { diff --git a/cli/src/commands/auth/devices/_shared/devices.test.ts b/cli/src/commands/auth/devices/_shared/devices.test.ts index 92e9bc8826..5d96a6ca64 100644 --- a/cli/src/commands/auth/devices/_shared/devices.test.ts +++ b/cli/src/commands/auth/devices/_shared/devices.test.ts @@ -1,15 +1,17 @@ +import type { SessionListResponse, SessionRow } from '@dify/contracts/api/openapi/types.gen' import type { DifyMock } from '../../../../../test/fixtures/dify-mock/server.js' +import type { AccountSessionsClient } from '../../../../api/account-sessions.js' import type { HostsBundle } from '../../../../auth/hosts.js' import type { TokenStore } from '../../../../auth/store.js' import { mkdtemp, readFile, rm } from 'node:fs/promises' import { tmpdir } from 'node:os' import { join } from 'node:path' -import { afterEach, beforeEach, describe, expect, it } from 'vitest' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { startMock } from '../../../../../test/fixtures/dify-mock/server.js' import { saveHosts } from '../../../../auth/hosts.js' import { createClient } from '../../../../http/client.js' -import { bufferStreams } from '../../../../io/streams.js' -import { runDevicesList, runDevicesRevoke } from './devices.js' +import { bufferStreams } from '../../../../sys/io/streams' +import { listAllSessions, runDevicesList, runDevicesRevoke } from './devices.js' class MemStore implements TokenStore { readonly entries = new Map() @@ -187,3 +189,47 @@ describe('runDevicesRevoke', () => { .toThrow(/specify a device label/) }) }) + +describe('listAllSessions', () => { + const row = (id: string, label = `dev-${id}`): SessionRow => ({ + id, + prefix: 
'dfoa_xxx', + client_id: 'difyctl', + device_label: label, + created_at: null, + last_used_at: null, + expires_at: null, + }) + + function stubClient(pages: readonly SessionListResponse[]): { client: AccountSessionsClient, list: ReturnType } { + const list = vi.fn(async (q?: { page?: number, limit?: number }) => { + const page = q?.page ?? 1 + const env = pages[page - 1] + if (env === undefined) + throw new Error(`stub: no page ${page}`) + return env + }) + return { client: { list } as unknown as AccountSessionsClient, list } + } + + it('exhausts pages until has_more=false', async () => { + const { client, list } = stubClient([ + { page: 1, limit: 200, total: 250, has_more: true, data: Array.from({ length: 200 }, (_, i) => row(`s-${i}`)) }, + { page: 2, limit: 200, total: 250, has_more: false, data: Array.from({ length: 50 }, (_, i) => row(`s-${200 + i}`)) }, + ]) + const all = await listAllSessions(client) + expect(all.length).toBe(250) + expect(list).toHaveBeenCalledTimes(2) + expect(list).toHaveBeenNthCalledWith(1, { page: 1, limit: 200 }) + expect(list).toHaveBeenNthCalledWith(2, { page: 2, limit: 200 }) + }) + + it('single page (has_more=false): one call', async () => { + const { client, list } = stubClient([ + { page: 1, limit: 200, total: 3, has_more: false, data: [row('a'), row('b'), row('c')] }, + ]) + const all = await listAllSessions(client) + expect(all.length).toBe(3) + expect(list).toHaveBeenCalledTimes(1) + }) +}) diff --git a/cli/src/commands/auth/devices/_shared/devices.ts b/cli/src/commands/auth/devices/_shared/devices.ts index 5af41cd4db..78a0cba73b 100644 --- a/cli/src/commands/auth/devices/_shared/devices.ts +++ b/cli/src/commands/auth/devices/_shared/devices.ts @@ -2,37 +2,73 @@ import type { SessionRow } from '@dify/contracts/api/openapi/types.gen' import type { KyInstance } from 'ky' import type { HostsBundle } from '../../../../auth/hosts.js' import type { TokenStore } from '../../../../auth/store.js' -import type { IOStreams } from 
'../../../../io/streams.js' +import type { IOStreams } from '../../../../sys/io/streams' import { unlink } from 'node:fs/promises' import { join } from 'node:path' import { AccountSessionsClient } from '../../../../api/account-sessions.js' import { HOSTS_FILE_NAME } from '../../../../auth/hosts.js' import { BaseError } from '../../../../errors/base.js' import { ErrorCode } from '../../../../errors/codes.js' -import { colorEnabled, colorScheme } from '../../../../io/color.js' -import { runWithSpinner } from '../../../../io/spinner.js' +import { LIMIT_DEFAULT, LIMIT_MAX, parseLimit } from '../../../../limit/limit.js' +import { colorEnabled, colorScheme } from '../../../../sys/io/color.js' +import { runWithSpinner } from '../../../../sys/io/spinner.js' export type DevicesListOptions = { readonly io: IOStreams readonly bundle: HostsBundle | undefined readonly http: KyInstance readonly json?: boolean + readonly page?: number + readonly limitRaw?: string + readonly envLookup?: (k: string) => string | undefined } export async function runDevicesList(opts: DevicesListOptions): Promise { const b = requireLogin(opts.bundle) const sessions = new AccountSessionsClient(opts.http) - const env = await runWithSpinner( + const env = opts.envLookup ?? ((k: string) => process.env[k]) + const limit = resolveLimit(opts.limitRaw, env) + const page = opts.page === undefined || opts.page <= 0 ? 1 : opts.page + const envelope = await runWithSpinner( { io: opts.io, label: 'Fetching devices' }, - () => sessions.list(), + () => sessions.list({ page, limit }), ) if (opts.json === true) { - opts.io.out.write(`${JSON.stringify(env)}\n`) + opts.io.out.write(`${JSON.stringify(envelope)}\n`) return } - opts.io.out.write(renderTable(env.data, b.token_id ?? '')) + opts.io.out.write(renderTable(envelope.data, b.token_id ?? 
'')) +} + +function resolveLimit(raw: string | undefined, env: (k: string) => string | undefined): number { + if (raw !== undefined && raw !== '') + return parseLimit(raw, '--limit') + const envValue = env('DIFY_LIMIT') + if (envValue !== undefined && envValue !== '') + return parseLimit(envValue, 'DIFY_LIMIT') + return LIMIT_DEFAULT +} + +/** + * Fetches every session across all pages. Used by revoke paths so that a + * session sitting on page 2+ is still findable / revocable. Uses the max + * page size (LIMIT_MAX) to minimize round-trips. + */ +export async function listAllSessions(client: AccountSessionsClient): Promise { + const out: SessionRow[] = [] + let page = 1 + // Hard guard against a misbehaving server that lies about has_more. + const MAX_PAGES = 100 + while (page <= MAX_PAGES) { + const env = await client.list({ page, limit: LIMIT_MAX }) + out.push(...env.data) + if (!env.has_more) + return out + page++ + } + return out } export type DevicesRevokeOptions = { @@ -58,8 +94,8 @@ export async function runDevicesRevoke(opts: DevicesRevokeOptions): Promise auth devices list', '<%= config.bin %> auth devices list --json', + '<%= config.bin %> auth devices list --page 2 --limit 50', ] static override flags = { 'http-retry': httpRetryFlag, 'json': Flags.boolean({ description: 'emit JSON', default: false }), + 'page': Flags.integer({ description: 'page number', default: 1 }), + 'limit': Flags.string({ description: 'page size [1..200]' }), } async run(argv: string[]): Promise { const { flags } = this.parse(DevicesList, argv) const format = flags.json ? 
'json' : '' const ctx = await this.authedCtx({ retryFlag: flags['http-retry'], format }) - await runDevicesList({ io: ctx.io, bundle: ctx.bundle, http: ctx.http, json: flags.json }) + await runDevicesList({ + io: ctx.io, + bundle: ctx.bundle, + http: ctx.http, + json: flags.json, + page: flags.page, + limitRaw: flags.limit, + }) } } diff --git a/cli/src/commands/auth/login/index.ts b/cli/src/commands/auth/login/index.ts index ab8c32cd74..dadce6f990 100644 --- a/cli/src/commands/auth/login/index.ts +++ b/cli/src/commands/auth/login/index.ts @@ -1,6 +1,6 @@ -import { resolveConfigDir } from '../../../config/dir.js' import { Flags } from '../../../framework/flags.js' -import { realStreams } from '../../../io/streams.js' +import { resolveConfigDir } from '../../../store/dir.js' +import { realStreams } from '../../../sys/io/streams' import { DifyCommand } from '../../_shared/dify-command.js' import { runLogin } from './login.js' diff --git a/cli/src/commands/auth/login/login.test.ts b/cli/src/commands/auth/login/login.test.ts index 522623982b..c2e4749a3f 100644 --- a/cli/src/commands/auth/login/login.test.ts +++ b/cli/src/commands/auth/login/login.test.ts @@ -8,7 +8,7 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest' import { startMock } from '../../../../test/fixtures/dify-mock/server.js' import { DeviceFlowApi } from '../../../api/oauth-device.js' import { createClient } from '../../../http/client.js' -import { bufferStreams } from '../../../io/streams.js' +import { bufferStreams } from '../../../sys/io/streams' import { runLogin } from './login.js' const noopClock: Clock = { diff --git a/cli/src/commands/auth/login/login.ts b/cli/src/commands/auth/login/login.ts index de05f52997..77a5c00b94 100644 --- a/cli/src/commands/auth/login/login.ts +++ b/cli/src/commands/auth/login/login.ts @@ -1,7 +1,7 @@ import type { CodeResponse, PollSuccess } from '../../../api/oauth-device.js' import type { HostsBundle, StorageMode, Workspace } from 
'../../../auth/hosts.js' import type { TokenStore } from '../../../auth/store.js' -import type { IOStreams } from '../../../io/streams.js' +import type { IOStreams } from '../../../sys/io/streams' import type { BrowserEnv, BrowserOpener } from '../../../util/browser.js' import type { Clock } from './device-flow.js' import * as os from 'node:os' @@ -10,7 +10,7 @@ import { DeviceFlowApi } from '../../../api/oauth-device.js' import { saveHosts } from '../../../auth/hosts.js' import { selectStore } from '../../../auth/store.js' import { createClient } from '../../../http/client.js' -import { colorEnabled, colorScheme } from '../../../io/color.js' +import { colorEnabled, colorScheme } from '../../../sys/io/color.js' import { decideOpen, OpenDecision, openUrl, realEnv } from '../../../util/browser.js' import { bareHost, DEFAULT_HOST, resolveHost, validateVerificationURI } from '../../../util/host.js' import { awaitAuthorization, realClock } from './device-flow.js' diff --git a/cli/src/commands/auth/logout/index.ts b/cli/src/commands/auth/logout/index.ts index c11ca97284..7915abb242 100644 --- a/cli/src/commands/auth/logout/index.ts +++ b/cli/src/commands/auth/logout/index.ts @@ -1,10 +1,10 @@ import type { KyInstance } from 'ky' import { loadHosts } from '../../../auth/hosts.js' import { selectStore } from '../../../auth/store.js' -import { resolveConfigDir } from '../../../config/dir.js' import { createClient } from '../../../http/client.js' -import { runWithSpinner } from '../../../io/spinner.js' -import { realStreams } from '../../../io/streams.js' +import { resolveConfigDir } from '../../../store/dir.js' +import { runWithSpinner } from '../../../sys/io/spinner.js' +import { realStreams } from '../../../sys/io/streams' import { hostWithScheme } from '../../../util/host.js' import { DifyCommand } from '../../_shared/dify-command.js' import { runLogout } from './logout.js' diff --git a/cli/src/commands/auth/logout/logout.test.ts 
b/cli/src/commands/auth/logout/logout.test.ts index 4fd3f53e8b..73bd8429bb 100644 --- a/cli/src/commands/auth/logout/logout.test.ts +++ b/cli/src/commands/auth/logout/logout.test.ts @@ -8,7 +8,7 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest' import { startMock } from '../../../../test/fixtures/dify-mock/server.js' import { saveHosts } from '../../../auth/hosts.js' import { createClient } from '../../../http/client.js' -import { bufferStreams } from '../../../io/streams.js' +import { bufferStreams } from '../../../sys/io/streams' import { runLogout } from './logout.js' class MemStore implements TokenStore { diff --git a/cli/src/commands/auth/logout/logout.ts b/cli/src/commands/auth/logout/logout.ts index 48660b6b35..ddcee3b5d4 100644 --- a/cli/src/commands/auth/logout/logout.ts +++ b/cli/src/commands/auth/logout/logout.ts @@ -1,14 +1,14 @@ import type { KyInstance } from 'ky' import type { HostsBundle } from '../../../auth/hosts.js' import type { TokenStore } from '../../../auth/store.js' -import type { IOStreams } from '../../../io/streams.js' +import type { IOStreams } from '../../../sys/io/streams' import { unlink } from 'node:fs/promises' import { join } from 'node:path' import { AccountSessionsClient } from '../../../api/account-sessions.js' import { HOSTS_FILE_NAME } from '../../../auth/hosts.js' import { BaseError } from '../../../errors/base.js' import { ErrorCode } from '../../../errors/codes.js' -import { colorEnabled, colorScheme } from '../../../io/color.js' +import { colorEnabled, colorScheme } from '../../../sys/io/color.js' export type LogoutOptions = { readonly configDir: string diff --git a/cli/src/commands/auth/status/index.ts b/cli/src/commands/auth/status/index.ts index c779595d1f..57208b93f4 100644 --- a/cli/src/commands/auth/status/index.ts +++ b/cli/src/commands/auth/status/index.ts @@ -1,7 +1,7 @@ import { loadHosts } from '../../../auth/hosts.js' -import { resolveConfigDir } from '../../../config/dir.js' import { 
Flags } from '../../../framework/flags.js' -import { realStreams } from '../../../io/streams.js' +import { resolveConfigDir } from '../../../store/dir.js' +import { realStreams } from '../../../sys/io/streams' import { DifyCommand } from '../../_shared/dify-command.js' import { runStatus } from './status.js' diff --git a/cli/src/commands/auth/status/status.test.ts b/cli/src/commands/auth/status/status.test.ts index 0000e9cd59..f039d54866 100644 --- a/cli/src/commands/auth/status/status.test.ts +++ b/cli/src/commands/auth/status/status.test.ts @@ -1,6 +1,6 @@ import type { HostsBundle } from '../../../auth/hosts.js' import { describe, expect, it } from 'vitest' -import { bufferStreams } from '../../../io/streams.js' +import { bufferStreams } from '../../../sys/io/streams' import { runStatus } from './status.js' function accountBundle(): HostsBundle { diff --git a/cli/src/commands/auth/status/status.ts b/cli/src/commands/auth/status/status.ts index c666b08b0a..83ca626827 100644 --- a/cli/src/commands/auth/status/status.ts +++ b/cli/src/commands/auth/status/status.ts @@ -1,5 +1,5 @@ import type { HostsBundle } from '../../../auth/hosts.js' -import type { IOStreams } from '../../../io/streams.js' +import type { IOStreams } from '../../../sys/io/streams' import { BaseError } from '../../../errors/base.js' import { ErrorCode } from '../../../errors/codes.js' diff --git a/cli/src/commands/auth/use/index.ts b/cli/src/commands/auth/use/index.ts deleted file mode 100644 index 5803e6450e..0000000000 --- a/cli/src/commands/auth/use/index.ts +++ /dev/null @@ -1,25 +0,0 @@ -import { loadHosts } from '../../../auth/hosts.js' -import { resolveConfigDir } from '../../../config/dir.js' -import { Args } from '../../../framework/flags.js' -import { realStreams } from '../../../io/streams.js' -import { DifyCommand } from '../../_shared/dify-command.js' -import { runUse } from './use.js' - -export default class Use extends DifyCommand { - static override description = 'Switch the active 
workspace for the current host' - - static override examples = [ - '<%= config.bin %> auth use ws-abc123', - ] - - static override args = { - workspaceId: Args.string({ description: 'workspace id to activate', required: true }), - } - - async run(argv: string[]): Promise { - const { args } = this.parse(Use, argv) - const configDir = resolveConfigDir() - const bundle = await loadHosts(configDir) - await runUse({ configDir, io: realStreams(), bundle, workspaceId: args.workspaceId }) - } -} diff --git a/cli/src/commands/auth/use/use.test.ts b/cli/src/commands/auth/use/use.test.ts deleted file mode 100644 index 178785a630..0000000000 --- a/cli/src/commands/auth/use/use.test.ts +++ /dev/null @@ -1,71 +0,0 @@ -import type { HostsBundle } from '../../../auth/hosts.js' -import { mkdtemp, rm } from 'node:fs/promises' -import { tmpdir } from 'node:os' -import { join } from 'node:path' -import { afterEach, beforeEach, describe, expect, it } from 'vitest' -import { loadHosts, saveHosts } from '../../../auth/hosts.js' -import { bufferStreams } from '../../../io/streams.js' -import { runUse } from './use.js' - -function accountBundle(): HostsBundle { - return { - current_host: 'cloud.dify.ai', - token_storage: 'file', - token_id: 'tok-1', - tokens: { bearer: 'dfoa_test' }, - account: { id: 'acct-1', email: 'tester@dify.ai', name: 'Test Tester' }, - workspace: { id: 'ws-1', name: 'Default', role: 'owner' }, - available_workspaces: [ - { id: 'ws-1', name: 'Default', role: 'owner' }, - { id: 'ws-2', name: 'Other', role: 'normal' }, - ], - } -} - -describe('runUse', () => { - let configDir: string - beforeEach(async () => { - configDir = await mkdtemp(join(tmpdir(), 'difyctl-use-')) - }) - afterEach(async () => { - await rm(configDir, { recursive: true, force: true }) - }) - - it('switches workspace + persists hosts.yml', async () => { - const io = bufferStreams() - const b = accountBundle() - await saveHosts(configDir, b) - const next = await runUse({ configDir, io, bundle: b, 
workspaceId: 'ws-2' }) - expect(next.workspace).toEqual({ id: 'ws-2', name: 'Other', role: 'normal' }) - const reloaded = await loadHosts(configDir) - expect(reloaded?.workspace?.id).toBe('ws-2') - expect(io.outBuf()).toContain('Switched to workspace Other (ws-2)') - }) - - it('not-logged-in: throws NotLoggedIn', async () => { - const io = bufferStreams() - await expect(runUse({ configDir, io, bundle: undefined, workspaceId: 'ws-1' })) - .rejects - .toThrow(/not logged in/) - }) - - it('sso: throws workspace-unavailable', async () => { - const io = bufferStreams() - const b: HostsBundle = { - current_host: 'cloud.dify.ai', - token_storage: 'file', - tokens: { bearer: 'dfoe_test' }, - external_subject: { email: 'sso@dify.ai', issuer: 'https://issuer.example' }, - } - await expect(runUse({ configDir, io, bundle: b, workspaceId: 'ws-1' })) - .rejects - .toThrow(/workspace context unavailable/) - }) - - it('unknown workspace: throws UsageMissingArg', async () => { - const io = bufferStreams() - await expect(runUse({ configDir, io, bundle: accountBundle(), workspaceId: 'ws-bogus' })) - .rejects - .toThrow(/ws-bogus.*not found/) - }) -}) diff --git a/cli/src/commands/auth/use/use.ts b/cli/src/commands/auth/use/use.ts deleted file mode 100644 index 04454785b2..0000000000 --- a/cli/src/commands/auth/use/use.ts +++ /dev/null @@ -1,49 +0,0 @@ -import type { HostsBundle, Workspace } from '../../../auth/hosts.js' -import type { IOStreams } from '../../../io/streams.js' -import { saveHosts } from '../../../auth/hosts.js' -import { BaseError } from '../../../errors/base.js' -import { ErrorCode } from '../../../errors/codes.js' -import { colorEnabled, colorScheme } from '../../../io/color.js' - -export type UseOptions = { - readonly configDir: string - readonly io: IOStreams - readonly bundle: HostsBundle | undefined - readonly workspaceId: string -} - -export async function runUse(opts: UseOptions): Promise { - const cs = colorScheme(colorEnabled(opts.io.isErrTTY)) - const b = 
opts.bundle - if (b === undefined || b.tokens?.bearer === undefined || b.tokens.bearer === '') { - throw new BaseError({ - code: ErrorCode.NotLoggedIn, - message: 'not logged in', - hint: 'run \'difyctl auth login\'', - }) - } - if (b.external_subject !== undefined) { - throw new BaseError({ - code: ErrorCode.UsageInvalidFlag, - message: 'workspace context unavailable for external SSO sessions', - hint: 'external SSO subjects don\'t carry tenant memberships in difyctl', - }) - } - - const found = (b.available_workspaces ?? []).find(w => w.id === opts.workspaceId) - if (found === undefined) { - throw new BaseError({ - code: ErrorCode.UsageMissingArg, - message: `workspace "${opts.workspaceId}" not found in available_workspaces; run 'difyctl auth status' to list`, - }) - } - - const next: HostsBundle = { ...b, workspace: pickWorkspace(found) } - await saveHosts(opts.configDir, next) - opts.io.out.write(`${cs.successIcon()} Switched to workspace ${found.name} (${found.id})\n`) - return next -} - -function pickWorkspace(w: Workspace): Workspace { - return { id: w.id, name: w.name, role: w.role } -} diff --git a/cli/src/commands/auth/whoami/index.ts b/cli/src/commands/auth/whoami/index.ts index dbf51fb1e3..a89a6e76bf 100644 --- a/cli/src/commands/auth/whoami/index.ts +++ b/cli/src/commands/auth/whoami/index.ts @@ -1,7 +1,7 @@ import { loadHosts } from '../../../auth/hosts.js' -import { resolveConfigDir } from '../../../config/dir.js' import { Flags } from '../../../framework/flags.js' -import { realStreams } from '../../../io/streams.js' +import { resolveConfigDir } from '../../../store/dir.js' +import { realStreams } from '../../../sys/io/streams' import { DifyCommand } from '../../_shared/dify-command.js' import { runWhoami } from './whoami.js' diff --git a/cli/src/commands/auth/whoami/whoami.test.ts b/cli/src/commands/auth/whoami/whoami.test.ts index f38a4b634f..98ea0a9bcb 100644 --- a/cli/src/commands/auth/whoami/whoami.test.ts +++ 
b/cli/src/commands/auth/whoami/whoami.test.ts @@ -1,6 +1,6 @@ import type { HostsBundle } from '../../../auth/hosts.js' import { describe, expect, it } from 'vitest' -import { bufferStreams } from '../../../io/streams.js' +import { bufferStreams } from '../../../sys/io/streams' import { runWhoami } from './whoami.js' function accountBundle(): HostsBundle { diff --git a/cli/src/commands/auth/whoami/whoami.ts b/cli/src/commands/auth/whoami/whoami.ts index fca750ae86..908daaddec 100644 --- a/cli/src/commands/auth/whoami/whoami.ts +++ b/cli/src/commands/auth/whoami/whoami.ts @@ -1,5 +1,5 @@ import type { HostsBundle } from '../../../auth/hosts.js' -import type { IOStreams } from '../../../io/streams.js' +import type { IOStreams } from '../../../sys/io/streams' import { BaseError } from '../../../errors/base.js' import { ErrorCode } from '../../../errors/codes.js' diff --git a/cli/src/commands/config/get/index.ts b/cli/src/commands/config/get/index.ts index 1505077f98..d02fcdb18d 100644 --- a/cli/src/commands/config/get/index.ts +++ b/cli/src/commands/config/get/index.ts @@ -1,6 +1,6 @@ -import { resolveConfigDir } from '../../../config/dir.js' import { Args } from '../../../framework/flags.js' import { raw } from '../../../framework/output.js' +import { getConfigurationStore } from '../../../store/manager.js' import { DifyCommand } from '../../_shared/dify-command.js' import { runConfigGet } from './run.js' @@ -17,6 +17,6 @@ export default class ConfigGet extends DifyCommand { async run(argv: string[]) { const { args } = this.parse(ConfigGet, argv) - return raw(await runConfigGet({ dir: resolveConfigDir(), key: args.key })) + return raw(runConfigGet({ store: getConfigurationStore(), key: args.key })) } } diff --git a/cli/src/commands/config/get/run.test.ts b/cli/src/commands/config/get/run.test.ts index 7274a6e624..5f594bd7de 100644 --- a/cli/src/commands/config/get/run.test.ts +++ b/cli/src/commands/config/get/run.test.ts @@ -5,8 +5,13 @@ import { beforeEach, 
describe, expect, it } from 'vitest' import { FILE_NAME } from '../../../config/schema.js' import { isBaseError } from '../../../errors/base.js' import { ErrorCode } from '../../../errors/codes.js' +import { YamlStore } from '../../../store/store.js' import { runConfigGet } from './run.js' +function makeStore(dir: string): YamlStore { + return new YamlStore(join(dir, FILE_NAME)) +} + describe('runConfigGet', () => { let dir: string @@ -20,19 +25,19 @@ describe('runConfigGet', () => { 'schema_version: 1\ndefaults:\n format: yaml\n', 'utf8', ) - const out = await runConfigGet({ dir, key: 'defaults.format' }) + const out = runConfigGet({ store: makeStore(dir), key: 'defaults.format' }) expect(out).toBe('yaml\n') }) - it('returns empty line when key is unset (matches Go fmt.Fprintln)', async () => { - const out = await runConfigGet({ dir, key: 'defaults.format' }) + it('returns empty line when key is unset (matches Go fmt.Fprintln)', () => { + const out = runConfigGet({ store: makeStore(dir), key: 'defaults.format' }) expect(out).toBe('\n') }) - it('throws BaseError(config_invalid_key) on unknown key', async () => { + it('throws BaseError(config_invalid_key) on unknown key', () => { let caught: unknown try { - await runConfigGet({ dir, key: 'bogus.key' }) + runConfigGet({ store: makeStore(dir), key: 'bogus.key' }) } catch (err) { caught = err } expect(isBaseError(caught)).toBe(true) @@ -46,7 +51,7 @@ describe('runConfigGet', () => { 'schema_version: 1\ndefaults:\n limit: 75\n', 'utf8', ) - const out = await runConfigGet({ dir, key: 'defaults.limit' }) + const out = runConfigGet({ store: makeStore(dir), key: 'defaults.limit' }) expect(out).toBe('75\n') }) }) diff --git a/cli/src/commands/config/get/run.ts b/cli/src/commands/config/get/run.ts index 0f43213318..8fb486e60b 100644 --- a/cli/src/commands/config/get/run.ts +++ b/cli/src/commands/config/get/run.ts @@ -1,15 +1,16 @@ import type { ConfigFile } from '../../../config/schema.js' +import type { YamlStore } from 
'../../../store/store.js' +import { loadConfig } from '../../../config/config-loader.js' import { getKey } from '../../../config/keys.js' -import { loadConfig } from '../../../config/loader.js' import { emptyConfig } from '../../../config/schema.js' export type RunConfigGetOptions = { readonly key: string - readonly dir: string + readonly store: YamlStore } -export async function runConfigGet(opts: RunConfigGetOptions): Promise { - const loaded = await loadConfig(opts.dir) +export function runConfigGet(opts: RunConfigGetOptions): string { + const loaded = loadConfig(opts.store) const config: ConfigFile = loaded.found ? loaded.config : emptyConfig() return `${getKey(config, opts.key)}\n` } diff --git a/cli/src/commands/config/path/index.ts b/cli/src/commands/config/path/index.ts index 1f529ec385..466aa6a6db 100644 --- a/cli/src/commands/config/path/index.ts +++ b/cli/src/commands/config/path/index.ts @@ -1,5 +1,5 @@ -import { resolveConfigDir } from '../../../config/dir.js' import { raw } from '../../../framework/output.js' +import { resolveConfigDir } from '../../../store/dir.js' import { DifyCommand } from '../../_shared/dify-command.js' import { runConfigPath } from './run.js' diff --git a/cli/src/commands/config/set/index.ts b/cli/src/commands/config/set/index.ts index b8f22eed2b..d747a8783b 100644 --- a/cli/src/commands/config/set/index.ts +++ b/cli/src/commands/config/set/index.ts @@ -1,6 +1,6 @@ -import { resolveConfigDir } from '../../../config/dir.js' import { Args } from '../../../framework/flags.js' import { raw } from '../../../framework/output.js' +import { getConfigurationStore } from '../../../store/manager.js' import { DifyCommand } from '../../_shared/dify-command.js' import { runConfigSet } from './run.js' @@ -19,6 +19,6 @@ export default class ConfigSet extends DifyCommand { async run(argv: string[]) { const { args } = this.parse(ConfigSet, argv) - return raw(await runConfigSet({ dir: resolveConfigDir(), key: args.key, value: args.value })) + 
return raw(runConfigSet({ store: getConfigurationStore(), key: args.key, value: args.value })) } } diff --git a/cli/src/commands/config/set/run.test.ts b/cli/src/commands/config/set/run.test.ts index 959b331344..54be290271 100644 --- a/cli/src/commands/config/set/run.test.ts +++ b/cli/src/commands/config/set/run.test.ts @@ -5,8 +5,13 @@ import { beforeEach, describe, expect, it } from 'vitest' import { FILE_NAME } from '../../../config/schema.js' import { isBaseError } from '../../../errors/base.js' import { ErrorCode, ExitCode } from '../../../errors/codes.js' +import { YamlStore } from '../../../store/store.js' import { runConfigSet } from './run.js' +function makeStore(dir: string): YamlStore { + return new YamlStore(join(dir, FILE_NAME)) +} + describe('runConfigSet', () => { let dir: string @@ -15,7 +20,7 @@ describe('runConfigSet', () => { }) it('writes config.yml and returns "set k = v\\n"', async () => { - const out = await runConfigSet({ dir, key: 'defaults.format', value: 'json' }) + const out = runConfigSet({ store: makeStore(dir), key: 'defaults.format', value: 'json' }) expect(out).toBe('set defaults.format = json\n') const raw = await readFile(join(dir, FILE_NAME), 'utf8') expect(raw).toContain('format: json') @@ -24,7 +29,7 @@ describe('runConfigSet', () => { it('rejects invalid format value with config_invalid_value', async () => { let caught: unknown try { - await runConfigSet({ dir, key: 'defaults.format', value: 'csv' }) + runConfigSet({ store: makeStore(dir), key: 'defaults.format', value: 'csv' }) } catch (err) { caught = err } expect(isBaseError(caught)).toBe(true) @@ -32,10 +37,10 @@ describe('runConfigSet', () => { expect(caught.code).toBe(ErrorCode.ConfigInvalidValue) }) - it('rejects unknown key with config_invalid_key', async () => { + it('rejects unknown key with config_invalid_key', () => { let caught: unknown try { - await runConfigSet({ dir, key: 'bogus', value: 'x' }) + runConfigSet({ store: makeStore(dir), key: 'bogus', value: 'x' }) 
} catch (err) { caught = err } expect(isBaseError(caught)).toBe(true) @@ -44,17 +49,17 @@ describe('runConfigSet', () => { }) it('preserves prior keys when setting a new one', async () => { - await runConfigSet({ dir, key: 'defaults.format', value: 'yaml' }) - await runConfigSet({ dir, key: 'defaults.limit', value: '40' }) + runConfigSet({ store: makeStore(dir), key: 'defaults.format', value: 'yaml' }) + runConfigSet({ store: makeStore(dir), key: 'defaults.limit', value: '40' }) const raw = await readFile(join(dir, FILE_NAME), 'utf8') expect(raw).toContain('format: yaml') expect(raw).toContain('limit: 40') }) - it('exit code for invalid value is Usage (2)', async () => { + it('exit code for invalid value is Usage (2)', () => { let caught: unknown try { - await runConfigSet({ dir, key: 'defaults.format', value: 'csv' }) + runConfigSet({ store: makeStore(dir), key: 'defaults.format', value: 'csv' }) } catch (err) { caught = err } expect(isBaseError(caught)).toBe(true) @@ -62,10 +67,10 @@ describe('runConfigSet', () => { expect(caught.exit()).toBe(ExitCode.Usage) }) - it('exit code for unknown key is Usage (2)', async () => { + it('exit code for unknown key is Usage (2)', () => { let caught: unknown try { - await runConfigSet({ dir, key: 'bogus', value: 'x' }) + runConfigSet({ store: makeStore(dir), key: 'bogus', value: 'x' }) } catch (err) { caught = err } expect(isBaseError(caught)).toBe(true) @@ -73,10 +78,10 @@ describe('runConfigSet', () => { expect(caught.exit()).toBe(ExitCode.Usage) }) - it('typed wrap chain: invalid defaults.limit surfaces ConfigInvalidValue (not UsageInvalidFlag)', async () => { + it('typed wrap chain: invalid defaults.limit surfaces ConfigInvalidValue (not UsageInvalidFlag)', () => { let caught: unknown try { - await runConfigSet({ dir, key: 'defaults.limit', value: 'abc' }) + runConfigSet({ store: makeStore(dir), key: 'defaults.limit', value: 'abc' }) } catch (err) { caught = err } expect(isBaseError(caught)).toBe(true) diff --git 
a/cli/src/commands/config/set/run.ts b/cli/src/commands/config/set/run.ts index d59b065a4d..c7a1e752b5 100644 --- a/cli/src/commands/config/set/run.ts +++ b/cli/src/commands/config/set/run.ts @@ -1,19 +1,20 @@ import type { ConfigFile } from '../../../config/schema.js' +import type { YamlStore } from '../../../store/store.js' +import { loadConfig } from '../../../config/config-loader.js' import { setKey } from '../../../config/keys.js' -import { loadConfig } from '../../../config/loader.js' import { emptyConfig } from '../../../config/schema.js' -import { saveConfig } from '../../../config/writer.js' +import { saveConfig } from '../../../store/config-writer.js' export type RunConfigSetOptions = { readonly key: string readonly value: string - readonly dir: string + readonly store: YamlStore } -export async function runConfigSet(opts: RunConfigSetOptions): Promise { - const loaded = await loadConfig(opts.dir) +export function runConfigSet(opts: RunConfigSetOptions): string { + const loaded = loadConfig(opts.store) const config: ConfigFile = loaded.found ? 
loaded.config : emptyConfig() const next = setKey(config, opts.key, opts.value) - await saveConfig(opts.dir, next) + saveConfig(opts.store, next) return `set ${opts.key} = ${opts.value}\n` } diff --git a/cli/src/commands/config/unset/index.ts b/cli/src/commands/config/unset/index.ts index f1e9a48be3..a7e7d08096 100644 --- a/cli/src/commands/config/unset/index.ts +++ b/cli/src/commands/config/unset/index.ts @@ -1,6 +1,6 @@ -import { resolveConfigDir } from '../../../config/dir.js' import { Args } from '../../../framework/flags.js' import { raw } from '../../../framework/output.js' +import { getConfigurationStore } from '../../../store/manager.js' import { DifyCommand } from '../../_shared/dify-command.js' import { runConfigUnset } from './run.js' @@ -17,6 +17,6 @@ export default class ConfigUnset extends DifyCommand { async run(argv: string[]) { const { args } = this.parse(ConfigUnset, argv) - return raw(await runConfigUnset({ dir: resolveConfigDir(), key: args.key })) + return raw(runConfigUnset({ store: getConfigurationStore(), key: args.key })) } } diff --git a/cli/src/commands/config/unset/run.test.ts b/cli/src/commands/config/unset/run.test.ts index e67753149d..53fbce6735 100644 --- a/cli/src/commands/config/unset/run.test.ts +++ b/cli/src/commands/config/unset/run.test.ts @@ -5,8 +5,13 @@ import { beforeEach, describe, expect, it } from 'vitest' import { FILE_NAME } from '../../../config/schema.js' import { isBaseError } from '../../../errors/base.js' import { ErrorCode } from '../../../errors/codes.js' +import { YamlStore } from '../../../store/store.js' import { runConfigUnset } from './run.js' +function makeStore(dir: string): YamlStore { + return new YamlStore(join(dir, FILE_NAME)) +} + describe('runConfigUnset', () => { let dir: string @@ -20,7 +25,7 @@ describe('runConfigUnset', () => { 'schema_version: 1\ndefaults:\n format: json\n limit: 25\n', 'utf8', ) - const out = await runConfigUnset({ dir, key: 'defaults.format' }) + const out = runConfigUnset({ 
store: makeStore(dir), key: 'defaults.format' }) expect(out).toBe('unset defaults.format\n') const raw = await readFile(join(dir, FILE_NAME), 'utf8') expect(raw).not.toContain('format:') @@ -28,16 +33,16 @@ describe('runConfigUnset', () => { }) it('is a no-op (writes empty config) when key was already unset', async () => { - const out = await runConfigUnset({ dir, key: 'defaults.format' }) + const out = runConfigUnset({ store: makeStore(dir), key: 'defaults.format' }) expect(out).toBe('unset defaults.format\n') const raw = await readFile(join(dir, FILE_NAME), 'utf8') expect(raw).toContain('schema_version: 1') }) - it('rejects unknown key', async () => { + it('rejects unknown key', () => { let caught: unknown try { - await runConfigUnset({ dir, key: 'bogus' }) + runConfigUnset({ store: makeStore(dir), key: 'bogus' }) } catch (err) { caught = err } expect(isBaseError(caught)).toBe(true) diff --git a/cli/src/commands/config/unset/run.ts b/cli/src/commands/config/unset/run.ts index 8bd0a512a5..377013ff0f 100644 --- a/cli/src/commands/config/unset/run.ts +++ b/cli/src/commands/config/unset/run.ts @@ -1,18 +1,19 @@ import type { ConfigFile } from '../../../config/schema.js' +import type { YamlStore } from '../../../store/store.js' +import { loadConfig } from '../../../config/config-loader.js' import { unsetKey } from '../../../config/keys.js' -import { loadConfig } from '../../../config/loader.js' import { emptyConfig } from '../../../config/schema.js' -import { saveConfig } from '../../../config/writer.js' +import { saveConfig } from '../../../store/config-writer.js' export type RunConfigUnsetOptions = { readonly key: string - readonly dir: string + readonly store: YamlStore } -export async function runConfigUnset(opts: RunConfigUnsetOptions): Promise { - const loaded = await loadConfig(opts.dir) +export function runConfigUnset(opts: RunConfigUnsetOptions): string { + const loaded = loadConfig(opts.store) const config: ConfigFile = loaded.found ? 
loaded.config : emptyConfig() const next = unsetKey(config, opts.key) - await saveConfig(opts.dir, next) + saveConfig(opts.store, next) return `unset ${opts.key}\n` } diff --git a/cli/src/commands/config/view/index.ts b/cli/src/commands/config/view/index.ts index 89401f4497..f9e216ade1 100644 --- a/cli/src/commands/config/view/index.ts +++ b/cli/src/commands/config/view/index.ts @@ -1,6 +1,6 @@ -import { resolveConfigDir } from '../../../config/dir.js' import { Flags } from '../../../framework/flags.js' import { raw } from '../../../framework/output.js' +import { getConfigurationStore } from '../../../store/manager.js' import { DifyCommand } from '../../_shared/dify-command.js' import { runConfigView } from './run.js' @@ -18,6 +18,6 @@ export default class ConfigView extends DifyCommand { async run(argv: string[]) { const { flags } = this.parse(ConfigView, argv) - return raw(await runConfigView({ dir: resolveConfigDir(), json: flags.json })) + return raw(runConfigView({ store: getConfigurationStore(), json: flags.json })) } } diff --git a/cli/src/commands/config/view/run.test.ts b/cli/src/commands/config/view/run.test.ts index b3bc93115e..4716aad2f4 100644 --- a/cli/src/commands/config/view/run.test.ts +++ b/cli/src/commands/config/view/run.test.ts @@ -3,8 +3,13 @@ import { tmpdir } from 'node:os' import { join } from 'node:path' import { afterEach, beforeEach, describe, expect, it } from 'vitest' import { FILE_NAME } from '../../../config/schema.js' +import { YamlStore } from '../../../store/store.js' import { runConfigView } from './run.js' +function makeStore(dir: string): YamlStore { + return new YamlStore(join(dir, FILE_NAME)) +} + describe('runConfigView', () => { let dir: string @@ -16,8 +21,8 @@ describe('runConfigView', () => { // tmpdir cleanup is best-effort }) - it('text format: empty config returns empty string', async () => { - const out = await runConfigView({ dir }) + it('text format: empty config returns empty string', () => { + const out = 
runConfigView({ store: makeStore(dir) }) expect(out).toBe('') }) @@ -27,7 +32,7 @@ describe('runConfigView', () => { 'schema_version: 1\ndefaults:\n format: json\n limit: 50\nstate:\n current_app: app-1\n', 'utf8', ) - const out = await runConfigView({ dir }) + const out = runConfigView({ store: makeStore(dir) }) expect(out).toBe( 'defaults.format = json\ndefaults.limit = 50\nstate.current_app = app-1\n', ) @@ -39,14 +44,14 @@ describe('runConfigView', () => { 'schema_version: 1\ndefaults:\n format: yaml\n', 'utf8', ) - const out = await runConfigView({ dir }) + const out = runConfigView({ store: makeStore(dir) }) expect(out).toBe('defaults.format = yaml\n') expect(out).not.toContain('defaults.limit') expect(out).not.toContain('state.current_app') }) - it('json format: empty config returns "{}\\n"', async () => { - const out = await runConfigView({ dir, json: true }) + it('json format: empty config returns "{}\\n"', () => { + const out = runConfigView({ store: makeStore(dir), json: true }) expect(out).toBe('{}\n') }) @@ -56,15 +61,15 @@ describe('runConfigView', () => { 'schema_version: 1\ndefaults:\n format: table\n limit: 100\nstate:\n current_app: app-x\n', 'utf8', ) - const out = await runConfigView({ dir, json: true }) + const out = runConfigView({ store: makeStore(dir), json: true }) const parsed = JSON.parse(out) as Record expect(parsed['defaults.format']).toBe('table') expect(parsed['defaults.limit']).toBe(100) expect(parsed['state.current_app']).toBe('app-x') }) - it('json format: trailing newline matches Go encoder.Encode', async () => { - const out = await runConfigView({ dir, json: true }) + it('json format: trailing newline matches Go encoder.Encode', () => { + const out = runConfigView({ store: makeStore(dir), json: true }) expect(out.endsWith('\n')).toBe(true) }) }) diff --git a/cli/src/commands/config/view/run.ts b/cli/src/commands/config/view/run.ts index bda070ef46..78b9e1ca52 100644 --- a/cli/src/commands/config/view/run.ts +++ 
b/cli/src/commands/config/view/run.ts @@ -1,17 +1,18 @@ import type { ConfigFile } from '../../../config/schema.js' +import type { YamlStore } from '../../../store/store.js' +import { loadConfig } from '../../../config/config-loader.js' import { knownKeyNames, lookupKey } from '../../../config/keys.js' -import { loadConfig } from '../../../config/loader.js' import { emptyConfig } from '../../../config/schema.js' export type RunConfigViewOptions = { readonly json?: boolean - readonly dir: string + readonly store: YamlStore } type ViewOut = Record -export async function runConfigView(opts: RunConfigViewOptions): Promise { - const loaded = await loadConfig(opts.dir) +export function runConfigView(opts: RunConfigViewOptions): string { + const loaded = loadConfig(opts.store) const config: ConfigFile = loaded.found ? loaded.config : emptyConfig() const out = collect(config) if (opts.json) diff --git a/cli/src/commands/create/member/handlers.ts b/cli/src/commands/create/member/handlers.ts new file mode 100644 index 0000000000..c2e43577e7 --- /dev/null +++ b/cli/src/commands/create/member/handlers.ts @@ -0,0 +1,23 @@ +import type { MemberInviteResponse } from '@dify/contracts/api/openapi/types.gen' + +export class InviteOutput { + readonly response: MemberInviteResponse + readonly textLine: string + + constructor(response: MemberInviteResponse, textLine: string) { + this.response = response + this.textLine = textLine + } + + text(): string { + return this.textLine + } + + json(): MemberInviteResponse { + return this.response + } + + name(): string { + return this.response.member_id + } +} diff --git a/cli/src/commands/create/member/index.ts b/cli/src/commands/create/member/index.ts new file mode 100644 index 0000000000..fe5b712769 --- /dev/null +++ b/cli/src/commands/create/member/index.ts @@ -0,0 +1,40 @@ +import { Flags } from '../../../framework/flags.js' +import { formatted } from '../../../framework/output.js' +import { DifyCommand } from 
'../../_shared/dify-command.js' +import { httpRetryFlag } from '../../_shared/global-flags.js' +import { runCreateMember } from './run.js' + +export default class CreateMember extends DifyCommand { + static override description = 'Invite a member to the active (or specified) workspace by email' + + static override examples = [ + '<%= config.bin %> create member --email user@example.com --role normal', + '<%= config.bin %> create member --email user@example.com --role admin -w ws-1', + '<%= config.bin %> create member --email user@example.com --role normal -o json', + ] + + static override flags = { + 'email': Flags.string({ description: 'invitee email address', required: true }), + 'role': Flags.string({ + description: 'role to assign (normal|admin); owner is not assignable here', + required: true, + }), + 'workspace': Flags.string({ + char: 'w', + description: 'workspace id (overrides DIFY_WORKSPACE_ID and stored default)', + }), + 'http-retry': httpRetryFlag, + 'output': Flags.string({ char: 'o', description: 'output format (json|yaml|name|text)', default: '' }), + } + + async run(argv: string[]) { + const { flags } = this.parse(CreateMember, argv) + const format = flags.output + const ctx = await this.authedCtx({ retryFlag: flags['http-retry'], format }) + const result = await runCreateMember( + { email: flags.email, role: flags.role, workspace: flags.workspace, format }, + { bundle: ctx.bundle, http: ctx.http, io: ctx.io }, + ) + return formatted({ format, data: result.data }) + } +} diff --git a/cli/src/commands/create/member/run.test.ts b/cli/src/commands/create/member/run.test.ts new file mode 100644 index 0000000000..5c739f5a2a --- /dev/null +++ b/cli/src/commands/create/member/run.test.ts @@ -0,0 +1,102 @@ +import type { KyInstance } from 'ky' +import type { HostsBundle } from '../../../auth/hosts.js' +import { describe, expect, it, vi } from 'vitest' +import { bufferStreams } from '../../../sys/io/streams.js' +import { runCreateMember } from './run.js' + 
+function bundle(): HostsBundle { + return { + current_host: 'cloud.dify.ai', + token_storage: 'file', + tokens: { bearer: 'dfoa_test' }, + account: { id: 'acct-1', email: 'inviter@example.com', name: 'Inviter' }, + workspace: { id: 'ws-1', name: 'Default', role: 'owner' }, + available_workspaces: [{ id: 'ws-1', name: 'Default', role: 'owner' }], + } +} + +function fakeClient() { + return { + invite: vi.fn((_ws: string, body: { email: string, role: string }) => + Promise.resolve({ + result: 'success' as const, + email: body.email.toLowerCase(), + role: body.role, + member_id: 'acct-new', + invite_url: 'https://console.example.com/activate?email=x&token=tok', + tenant_id: 'ws-1', + })), + } +} + +describe('runCreateMember', () => { + it('happy path: POSTs invite, returns InviteOutput with text/json/name', async () => { + const client = fakeClient() + const result = await runCreateMember( + { email: 'new@example.com', role: 'normal' }, + { + bundle: bundle(), + http: {} as KyInstance, + io: bufferStreams(), + membersFactory: () => client as never, + }, + ) + expect(client.invite).toHaveBeenCalledWith('ws-1', { email: 'new@example.com', role: 'normal' }) + expect(result.data.text()).toMatch(/Invited new@example\.com as normal/) + expect(result.data.name()).toBe('acct-new') + expect(result.data.json()).toMatchObject({ + email: 'new@example.com', + role: 'normal', + member_id: 'acct-new', + invite_url: 'https://console.example.com/activate?email=x&token=tok', + tenant_id: 'ws-1', + }) + expect(result.workspaceId).toBe('ws-1') + }) + + it('rejects unknown role before any HTTP call', async () => { + const client = fakeClient() + await expect( + runCreateMember( + { email: 'new@example.com', role: 'owner' }, + { + bundle: bundle(), + http: {} as KyInstance, + io: bufferStreams(), + membersFactory: () => client as never, + }, + ), + ).rejects.toThrow(/invalid --role/) + expect(client.invite).not.toHaveBeenCalled() + }) + + it('rejects empty email', async () => { + const 
client = fakeClient() + await expect( + runCreateMember( + { email: '', role: 'normal' }, + { + bundle: bundle(), + http: {} as KyInstance, + io: bufferStreams(), + membersFactory: () => client as never, + }, + ), + ).rejects.toThrow(/--email is required/) + expect(client.invite).not.toHaveBeenCalled() + }) + + it('-w flag overrides resolved workspace', async () => { + const client = fakeClient() + await runCreateMember( + { email: 'new@example.com', role: 'admin', workspace: 'ws-9' }, + { + bundle: bundle(), + http: {} as KyInstance, + io: bufferStreams(), + membersFactory: () => client as never, + }, + ) + expect(client.invite).toHaveBeenCalledWith('ws-9', { email: 'new@example.com', role: 'admin' }) + }) +}) diff --git a/cli/src/commands/create/member/run.ts b/cli/src/commands/create/member/run.ts new file mode 100644 index 0000000000..0608b2fb7b --- /dev/null +++ b/cli/src/commands/create/member/run.ts @@ -0,0 +1,75 @@ +import type { KyInstance } from 'ky' +import type { HostsBundle } from '../../../auth/hosts.js' +import type { IOStreams } from '../../../sys/io/streams.js' +import { MembersClient } from '../../../api/members.js' +import { BaseError } from '../../../errors/base.js' +import { ErrorCode } from '../../../errors/codes.js' +import { colorEnabled, colorScheme } from '../../../sys/io/color.js' +import { runWithSpinner } from '../../../sys/io/spinner.js' +import { nullStreams } from '../../../sys/io/streams.js' +import { resolveWorkspaceId } from '../../../workspace/resolver.js' +import { InviteOutput } from './handlers.js' + +export type CreateMemberOptions = { + readonly email: string + readonly role: string + readonly workspace?: string + readonly format?: string +} + +export type CreateMemberDeps = { + readonly bundle: HostsBundle + readonly http: KyInstance + readonly io?: IOStreams + readonly envLookup?: (k: string) => string | undefined + readonly membersFactory?: (http: KyInstance) => MembersClient +} + +export type CreateMemberResult = { + 
readonly data: InviteOutput + readonly workspaceId: string +} + +// `owner` is intentionally absent — ownership transfer is console-only. +const ASSIGNABLE_ROLES = new Set(['normal', 'admin']) + +export async function runCreateMember( + opts: CreateMemberOptions, + deps: CreateMemberDeps, +): Promise { + if (opts.email === undefined || opts.email === '') { + throw new BaseError({ + code: ErrorCode.UsageMissingArg, + message: '--email is required', + }) + } + if (!ASSIGNABLE_ROLES.has(opts.role)) { + throw new BaseError({ + code: ErrorCode.UsageInvalidFlag, + message: `invalid --role "${opts.role}"`, + hint: 'expected: normal | admin (ownership transfer is console-only)', + }) + } + + const env = deps.envLookup ?? ((k: string) => process.env[k]) + const factory = deps.membersFactory ?? ((h: KyInstance) => new MembersClient(h)) + const io = deps.io ?? nullStreams() + const cs = colorScheme(colorEnabled(io.isErrTTY)) + + const wsId = resolveWorkspaceId({ + flag: opts.workspace, + env: env('DIFY_WORKSPACE_ID'), + bundle: deps.bundle, + }) + + const response = await runWithSpinner( + { io, label: `Inviting ${opts.email}` }, + () => factory(deps.http).invite(wsId, { + email: opts.email, + role: opts.role as 'normal' | 'admin', + }), + ) + + const textLine = `${cs.successIcon()} Invited ${response.email} as ${response.role}\n` + return { data: new InviteOutput(response, textLine), workspaceId: wsId } +} diff --git a/cli/src/commands/delete/member/handlers.ts b/cli/src/commands/delete/member/handlers.ts new file mode 100644 index 0000000000..1e88ee419e --- /dev/null +++ b/cli/src/commands/delete/member/handlers.ts @@ -0,0 +1,26 @@ +export type DeletedMemberPayload = { + readonly id: string + readonly deleted: true +} + +export class DeleteMemberOutput { + readonly payload: DeletedMemberPayload + readonly textLine: string + + constructor(memberId: string, textLine: string) { + this.payload = { id: memberId, deleted: true } + this.textLine = textLine + } + + text(): string { 
+ return this.textLine + } + + json(): DeletedMemberPayload { + return this.payload + } + + name(): string { + return this.payload.id + } +} diff --git a/cli/src/commands/delete/member/index.ts b/cli/src/commands/delete/member/index.ts new file mode 100644 index 0000000000..f455de9fbb --- /dev/null +++ b/cli/src/commands/delete/member/index.ts @@ -0,0 +1,40 @@ +import { Args, Flags } from '../../../framework/flags.js' +import { formatted } from '../../../framework/output.js' +import { DifyCommand } from '../../_shared/dify-command.js' +import { httpRetryFlag } from '../../_shared/global-flags.js' +import { runDeleteMember } from './run.js' + +export default class DeleteMember extends DifyCommand { + static override description = 'Remove a member from the active (or specified) workspace' + + static override examples = [ + '<%= config.bin %> delete member acct-1', + '<%= config.bin %> delete member acct-1 -w ws-1', + '<%= config.bin %> delete member acct-1 -o json', + ] + + static override args = { + memberId: Args.string({ description: 'account id of the member to remove', required: true }), + } + + static override flags = { + 'workspace': Flags.string({ + char: 'w', + description: 'workspace id (overrides DIFY_WORKSPACE_ID and stored default)', + }), + 'http-retry': httpRetryFlag, + 'output': Flags.string({ char: 'o', description: 'output format (json|yaml|name|text)', default: '' }), + 'yes': Flags.boolean({ char: 'y', description: 'skip confirmation prompt', default: false }), + } + + async run(argv: string[]) { + const { args, flags } = this.parse(DeleteMember, argv) + const format = flags.output + const ctx = await this.authedCtx({ retryFlag: flags['http-retry'], format }) + const result = await runDeleteMember( + { memberId: args.memberId, workspace: flags.workspace, format, yes: flags.yes }, + { bundle: ctx.bundle, http: ctx.http, io: ctx.io }, + ) + return formatted({ format, data: result.data }) + } +} diff --git a/cli/src/commands/delete/member/run.test.ts 
b/cli/src/commands/delete/member/run.test.ts new file mode 100644 index 0000000000..15a4f66db2 --- /dev/null +++ b/cli/src/commands/delete/member/run.test.ts @@ -0,0 +1,72 @@ +import type { KyInstance } from 'ky' +import type { HostsBundle } from '../../../auth/hosts.js' +import { describe, expect, it, vi } from 'vitest' +import { bufferStreams } from '../../../sys/io/streams.js' +import { runDeleteMember } from './run.js' + +function bundle(): HostsBundle { + return { + current_host: 'cloud.dify.ai', + token_storage: 'file', + tokens: { bearer: 'dfoa_test' }, + account: { id: 'acct-1', email: 'me@example.com', name: 'Me' }, + workspace: { id: 'ws-1', name: 'Default', role: 'owner' }, + available_workspaces: [{ id: 'ws-1', name: 'Default', role: 'owner' }], + } +} + +function fakeClient() { + return { + remove: vi.fn(() => Promise.resolve({ result: 'success' as const })), + } +} + +describe('runDeleteMember', () => { + it('happy path: DELETE, returns DeleteMemberOutput with text/json/name', async () => { + const client = fakeClient() + const result = await runDeleteMember( + { memberId: 'acct-2' }, + { + bundle: bundle(), + http: {} as KyInstance, + io: bufferStreams(), + membersFactory: () => client as never, + }, + ) + expect(client.remove).toHaveBeenCalledExactlyOnceWith('ws-1', 'acct-2') + expect(result.data.text()).toMatch(/Removed acct-2/) + expect(result.data.name()).toBe('acct-2') + expect(result.data.json()).toEqual({ id: 'acct-2', deleted: true }) + expect(result.workspaceId).toBe('ws-1') + }) + + it('-w flag overrides resolved workspace', async () => { + const client = fakeClient() + await runDeleteMember( + { memberId: 'acct-2', workspace: 'ws-9' }, + { + bundle: bundle(), + http: {} as KyInstance, + io: bufferStreams(), + membersFactory: () => client as never, + }, + ) + expect(client.remove).toHaveBeenCalledWith('ws-9', 'acct-2') + }) + + it('rejects empty member id before any HTTP call', async () => { + const client = fakeClient() + await expect( + 
runDeleteMember( + { memberId: '' }, + { + bundle: bundle(), + http: {} as KyInstance, + io: bufferStreams(), + membersFactory: () => client as never, + }, + ), + ).rejects.toThrow(/member id is required/) + expect(client.remove).not.toHaveBeenCalled() + }) +}) diff --git a/cli/src/commands/delete/member/run.ts b/cli/src/commands/delete/member/run.ts new file mode 100644 index 0000000000..f89afe86c7 --- /dev/null +++ b/cli/src/commands/delete/member/run.ts @@ -0,0 +1,90 @@ +import type { KyInstance } from 'ky' +import type { HostsBundle } from '../../../auth/hosts.js' +import type { IOStreams } from '../../../sys/io/streams.js' +import * as readline from 'node:readline' +import { MembersClient } from '../../../api/members.js' +import { BaseError } from '../../../errors/base.js' +import { ErrorCode } from '../../../errors/codes.js' +import { colorEnabled, colorScheme } from '../../../sys/io/color.js' +import { runWithSpinner } from '../../../sys/io/spinner.js' +import { nullStreams } from '../../../sys/io/streams.js' +import { resolveWorkspaceId } from '../../../workspace/resolver.js' +import { DeleteMemberOutput } from './handlers.js' + +export type DeleteMemberOptions = { + readonly memberId: string + readonly workspace?: string + readonly format?: string + readonly yes?: boolean +} + +export type DeleteMemberDeps = { + readonly bundle: HostsBundle + readonly http: KyInstance + readonly io?: IOStreams + readonly envLookup?: (k: string) => string | undefined + readonly membersFactory?: (http: KyInstance) => MembersClient +} + +export type DeleteMemberResult = { + readonly data: DeleteMemberOutput + readonly workspaceId: string +} + +export async function runDeleteMember( + opts: DeleteMemberOptions, + deps: DeleteMemberDeps, +): Promise { + if (opts.memberId === undefined || opts.memberId === '') { + throw new BaseError({ + code: ErrorCode.UsageMissingArg, + message: 'member id is required', + hint: 'pass it positionally: difyctl delete member ', + }) + } + + const 
env = deps.envLookup ?? ((k: string) => process.env[k]) + const factory = deps.membersFactory ?? ((h: KyInstance) => new MembersClient(h)) + const io = deps.io ?? nullStreams() + const cs = colorScheme(colorEnabled(io.isErrTTY)) + + const wsId = resolveWorkspaceId({ + flag: opts.workspace, + env: env('DIFY_WORKSPACE_ID'), + bundle: deps.bundle, + }) + + if (!opts.yes && io.isErrTTY) { + const confirmed = await promptConfirm(io, `Remove member ${opts.memberId}? [y/N] `) + if (!confirmed) { + throw new BaseError({ + code: ErrorCode.UsageMissingArg, + message: 'aborted by user', + hint: 'pass --yes to skip confirmation', + }) + } + } + + await runWithSpinner( + { io, label: `Removing ${opts.memberId}` }, + () => factory(deps.http).remove(wsId, opts.memberId), + ) + + const textLine = `${cs.successIcon()} Removed ${opts.memberId}\n` + return { + data: new DeleteMemberOutput(opts.memberId, textLine), + workspaceId: wsId, + } +} + +async function promptConfirm(io: IOStreams, message: string): Promise { + io.err.write(message) + const rl = readline.createInterface({ input: io.in, output: io.err, terminal: false }) + try { + const line: string = await new Promise(resolve => rl.once('line', resolve)) + return line.trim().toLowerCase() === 'y' + } + finally { + rl.close() + } +} diff --git a/cli/src/commands/describe/app/run.test.ts b/cli/src/commands/describe/app/run.test.ts index 5e35a0986f..d769d5db3f 100644 --- a/cli/src/commands/describe/app/run.test.ts +++ b/cli/src/commands/describe/app/run.test.ts @@ -8,6 +8,8 @@ import { startMock } from '../../../../test/fixtures/dify-mock/server.js' import { loadAppInfoCache } from '../../../cache/app-info.js' import { formatted, stringifyOutput } from '../../../framework/output.js' import { createClient } from '../../../http/client.js' +import { CACHE_APP_INFO, cachePath } from '../../../store/manager.js' +import { YamlStore } from '../../../store/store.js' import { runDescribeApp } from './run.js' function bundle(): HostsBundle 
{ @@ -37,7 +39,7 @@ describe('runDescribeApp', () => { }) async function render(opts: Parameters[0]): Promise { - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) const data = await runDescribeApp( opts, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test' }), host: mock.url, cache }, @@ -80,7 +82,7 @@ describe('runDescribeApp', () => { }) it('refresh: bypasses cache', async () => { - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await runDescribeApp( { appId: 'app-1' }, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test' }), host: mock.url, cache }, diff --git a/cli/src/commands/describe/app/run.ts b/cli/src/commands/describe/app/run.ts index 089274b8f1..f332cc992a 100644 --- a/cli/src/commands/describe/app/run.ts +++ b/cli/src/commands/describe/app/run.ts @@ -1,11 +1,12 @@ import type { KyInstance } from 'ky' import type { HostsBundle } from '../../../auth/hosts.js' import type { AppInfoCache } from '../../../cache/app-info.js' -import type { IOStreams } from '../../../io/streams.js' +import type { IOStreams } from '../../../sys/io/streams' import { AppMetaClient } from '../../../api/app-meta.js' import { AppsClient } from '../../../api/apps.js' -import { runWithSpinner } from '../../../io/spinner.js' -import { nullStreams } from '../../../io/streams.js' +import { getEnv } from '../../../sys/index.js' +import { runWithSpinner } from '../../../sys/io/spinner.js' +import { nullStreams } from '../../../sys/io/streams' import { FieldInfo, FieldInputSchema, FieldParameters } from '../../../types/app-meta.js' import { resolveWorkspaceId } from '../../../workspace/resolver.js' import { AppDescribeOutput } from './handlers.js' @@ -27,7 +28,7 @@ export type DescribeAppDeps = { } export async function 
runDescribeApp(opts: DescribeAppOptions, deps: DescribeAppDeps): Promise { - const env = deps.envLookup ?? ((k: string) => process.env[k]) + const env = deps.envLookup ?? getEnv const wsId = resolveWorkspaceId({ flag: opts.workspace, env: env('DIFY_WORKSPACE_ID'), bundle: deps.bundle }) const apps = new AppsClient(deps.http) const meta = new AppMetaClient({ apps, host: deps.host, cache: deps.cache }) diff --git a/cli/src/commands/env/list/run-list.ts b/cli/src/commands/env/list/run-list.ts index 2fce948fc5..379c81d4af 100644 --- a/cli/src/commands/env/list/run-list.ts +++ b/cli/src/commands/env/list/run-list.ts @@ -1,4 +1,5 @@ import { ENV_REGISTRY } from '../../../env/registry.js' +import { getEnv } from '../../../sys/index.js' export type EnvLookup = (name: string) => string | undefined @@ -17,7 +18,7 @@ export type EnvListJsonRow = { const COLUMN_PADDING = 2 export function runEnvList(opts: RunEnvListOptions = {}): string { - const lookup = opts.lookup ?? defaultLookup + const lookup = opts.lookup ?? 
getEnv if (opts.json) { const rows: EnvListJsonRow[] = ENV_REGISTRY.map(v => ({ name: v.name, @@ -67,7 +68,3 @@ function renderTable(rows: readonly (readonly string[])[]): string { } return out } - -function defaultLookup(name: string): string | undefined { - return process.env[name] -} diff --git a/cli/src/commands/get/app/run.ts b/cli/src/commands/get/app/run.ts index 4bfb300fb1..ccb091db57 100644 --- a/cli/src/commands/get/app/run.ts +++ b/cli/src/commands/get/app/run.ts @@ -1,12 +1,13 @@ import type { AppDescribeResponse, AppListResponse, AppMode } from '@dify/contracts/api/openapi/types.gen' import type { KyInstance } from 'ky' import type { HostsBundle } from '../../../auth/hosts.js' -import type { IOStreams } from '../../../io/streams.js' +import type { IOStreams } from '../../../sys/io/streams' import { AppsClient } from '../../../api/apps.js' import { WorkspacesClient } from '../../../api/workspaces.js' -import { runWithSpinner } from '../../../io/spinner.js' -import { nullStreams } from '../../../io/streams.js' import { LIMIT_DEFAULT, parseLimit } from '../../../limit/limit.js' +import { getEnv } from '../../../sys/index.js' +import { runWithSpinner } from '../../../sys/io/spinner.js' +import { nullStreams } from '../../../sys/io/streams' import { resolveWorkspaceId } from '../../../workspace/resolver.js' import { AppListOutput, AppRow } from './handlers.js' @@ -38,7 +39,7 @@ export type GetAppResult = { } export async function runGetApp(opts: GetAppOptions, deps: GetAppDeps): Promise { - const env = deps.envLookup ?? ((k: string) => process.env[k]) + const env = deps.envLookup ?? getEnv const appsFactory = deps.appsFactory ?? ((h: KyInstance) => new AppsClient(h)) const wsFactory = deps.workspacesFactory ?? 
((h: KyInstance) => new WorkspacesClient(h)) diff --git a/cli/src/commands/get/member/handlers.ts b/cli/src/commands/get/member/handlers.ts new file mode 100644 index 0000000000..b231916bec --- /dev/null +++ b/cli/src/commands/get/member/handlers.ts @@ -0,0 +1,89 @@ +import type { MemberListResponse, MemberResponse } from '@dify/contracts/api/openapi/types.gen' +import type { TableCell } from '../../../framework/output.js' +import type { TableColumn } from '../../../printers/format-table.js' + +export const MEMBER_MODE_KEY = 'member' +const CURRENT_MARKER = '*' + +export const MEMBER_COLUMNS: readonly TableColumn[] = [ + { name: 'ID', priority: 0 }, + { name: 'NAME', priority: 0 }, + { name: 'EMAIL', priority: 0 }, + { name: 'ROLE', priority: 0 }, + { name: 'STATUS', priority: 0 }, + { name: 'CURRENT', priority: 0 }, +] + +export class MemberRow { + readonly id: string + readonly displayName: string + readonly email: string + readonly role: string + readonly status: string + readonly current: boolean + + constructor(member: MemberResponse, current: boolean) { + this.id = member.id + this.displayName = member.name + this.email = member.email + this.role = member.role + this.status = member.status + this.current = current + } + + tableRow(): readonly TableCell[] { + return [ + this.id, + this.displayName, + this.email, + this.role, + this.status, + this.current ? 
CURRENT_MARKER : '', + ] + } + + name(): string { + return this.id + } + + json() { + return { + id: this.id, + name: this.displayName, + email: this.email, + role: this.role, + status: this.status, + current: this.current, + } + } +} + +export class MemberListOutput { + readonly rows: readonly MemberRow[] + readonly envelope: MemberListResponse + + constructor(rows: readonly MemberRow[], envelope: MemberListResponse) { + this.rows = rows + this.envelope = envelope + } + + static tableColumns(): readonly TableColumn[] { + return MEMBER_COLUMNS + } + + tableColumns(): readonly TableColumn[] { + return MemberListOutput.tableColumns() + } + + tableRows(): readonly (readonly TableCell[])[] { + return this.rows.map(row => row.tableRow()) + } + + name(): string { + return this.rows.map(row => row.name()).join('\n') + } + + json(): MemberListResponse { + return this.envelope + } +} diff --git a/cli/src/commands/get/member/index.ts b/cli/src/commands/get/member/index.ts new file mode 100644 index 0000000000..44a3dd241a --- /dev/null +++ b/cli/src/commands/get/member/index.ts @@ -0,0 +1,44 @@ +import { Flags } from '../../../framework/flags.js' +import { table } from '../../../framework/output.js' +import { DifyCommand } from '../../_shared/dify-command.js' +import { httpRetryFlag } from '../../_shared/global-flags.js' +import { runGetMember } from './run.js' + +export default class GetMember extends DifyCommand { + static override description = 'List members of the active (or specified) workspace' + + static override examples = [ + '<%= config.bin %> get member', + '<%= config.bin %> get member -w ws-1', + '<%= config.bin %> get member --page 2 --limit 50', + '<%= config.bin %> get member -o json', + '<%= config.bin %> get member -o name', + ] + + static override flags = { + 'workspace': Flags.string({ + char: 'w', + description: 'workspace id (overrides DIFY_WORKSPACE_ID and stored default)', + }), + 'page': Flags.integer({ description: 'page number', default: 1 }), + 
'limit': Flags.string({ description: 'page size [1..200]' }), + 'http-retry': httpRetryFlag, + 'output': Flags.string({ char: 'o', description: 'output format (json|yaml|name|wide)', default: '' }), + } + + async run(argv: string[]) { + const { flags } = this.parse(GetMember, argv) + const format = flags.output + const ctx = await this.authedCtx({ retryFlag: flags['http-retry'], format }) + const result = await runGetMember( + { + workspace: flags.workspace, + page: flags.page, + limitRaw: flags.limit, + format, + }, + { bundle: ctx.bundle, http: ctx.http, io: ctx.io }, + ) + return table({ format, data: result.data }) + } +} diff --git a/cli/src/commands/get/member/run.test.ts b/cli/src/commands/get/member/run.test.ts new file mode 100644 index 0000000000..d32b172eb9 --- /dev/null +++ b/cli/src/commands/get/member/run.test.ts @@ -0,0 +1,153 @@ +import type { MemberListResponse } from '@dify/contracts/api/openapi/types.gen' +import type { KyInstance } from 'ky' +import type { HostsBundle } from '../../../auth/hosts.js' +import { describe, expect, it, vi } from 'vitest' +import { bufferStreams } from '../../../sys/io/streams.js' +import { runGetMember } from './run.js' + +function bundle(): HostsBundle { + return { + current_host: 'cloud.dify.ai', + token_storage: 'file', + tokens: { bearer: 'dfoa_test' }, + account: { id: 'acct-1', email: 'me@example.com', name: 'Me' }, + workspace: { id: 'ws-1', name: 'Default', role: 'owner' }, + available_workspaces: [{ id: 'ws-1', name: 'Default', role: 'owner' }], + } +} + +function fakeClient(envelope: MemberListResponse) { + return { list: vi.fn(() => Promise.resolve(envelope)) } +} + +describe('runGetMember', () => { + const env: MemberListResponse = { + page: 1, + limit: 20, + total: 2, + has_more: false, + data: [ + { id: 'acct-1', name: 'Me', email: 'me@example.com', role: 'owner', status: 'active' }, + { id: 'acct-2', name: 'Mate', email: 'mate@example.com', role: 'admin', status: 'active' }, + ], + } + + it('lists 
members and marks the calling account with current=true', async () => { + const client = fakeClient(env) + const r = await runGetMember( + {}, + { + bundle: bundle(), + http: {} as KyInstance, + io: bufferStreams(), + membersFactory: () => client as never, + }, + ) + expect(client.list).toHaveBeenCalledExactlyOnceWith('ws-1', { page: 1, limit: 20 }) + expect(r.workspaceId).toBe('ws-1') + expect(r.data.rows.map(row => row.current)).toEqual([true, false]) + expect(r.data.rows.map(row => row.id)).toEqual(['acct-1', 'acct-2']) + }) + + it('-w flag overrides resolved workspace', async () => { + const client = fakeClient(env) + const r = await runGetMember( + { workspace: 'ws-9' }, + { + bundle: bundle(), + http: {} as KyInstance, + io: bufferStreams(), + membersFactory: () => client as never, + }, + ) + expect(client.list).toHaveBeenCalledWith('ws-9', { page: 1, limit: 20 }) + expect(r.workspaceId).toBe('ws-9') + }) + + it('--page/--limit are forwarded to the client', async () => { + const client = fakeClient(env) + await runGetMember( + { page: 3, limitRaw: '50' }, + { + bundle: bundle(), + http: {} as KyInstance, + io: bufferStreams(), + membersFactory: () => client as never, + }, + ) + expect(client.list).toHaveBeenCalledWith('ws-1', { page: 3, limit: 50 }) + }) + + it('marks no row when bundle has no account id', async () => { + const client = fakeClient(env) + const b = bundle() + b.account = { id: '', email: '', name: '' } + const r = await runGetMember( + {}, + { + bundle: b, + http: {} as KyInstance, + io: bufferStreams(), + membersFactory: () => client as never, + }, + ) + expect(r.data.rows.every(row => !row.current)).toBe(true) + }) + + it('throws when no workspace can be resolved', async () => { + const client = fakeClient(env) + await expect( + runGetMember( + {}, + { + bundle: { + current_host: '', + token_storage: 'file', + tokens: { bearer: 'dfoa_test' }, + account: { id: 'acct-1', email: '', name: '' }, + }, + http: {} as KyInstance, + io: 
bufferStreams(), + envLookup: () => undefined, + membersFactory: () => client as never, + }, + ), + ).rejects.toThrow(/no workspace selected/) + expect(client.list).not.toHaveBeenCalled() + }) +}) + +describe('MemberListOutput shape', () => { + it('builds table with CURRENT marker column', async () => { + const env: MemberListResponse = { + page: 1, + limit: 20, + total: 1, + has_more: false, + data: [ + { id: 'acct-1', name: 'Me', email: 'me@example.com', role: 'owner', status: 'active' }, + ], + } + const client = fakeClient(env) + const r = await runGetMember( + {}, + { + bundle: bundle(), + http: {} as KyInstance, + io: bufferStreams(), + membersFactory: () => client as never, + }, + ) + expect(r.data.tableColumns().map(c => c.name)).toEqual([ + 'ID', + 'NAME', + 'EMAIL', + 'ROLE', + 'STATUS', + 'CURRENT', + ]) + expect(r.data.tableRows()[0]?.[5]).toBe('*') + expect(r.data.name()).toBe('acct-1') + expect(r.data.json().data[0]?.email).toBe('me@example.com') + }) +}) diff --git a/cli/src/commands/get/member/run.ts b/cli/src/commands/get/member/run.ts new file mode 100644 index 0000000000..011cdb1572 --- /dev/null +++ b/cli/src/commands/get/member/run.ts @@ -0,0 +1,65 @@ +import type { KyInstance } from 'ky' +import type { HostsBundle } from '../../../auth/hosts.js' +import type { IOStreams } from '../../../sys/io/streams.js' +import { MembersClient } from '../../../api/members.js' +import { LIMIT_DEFAULT, parseLimit } from '../../../limit/limit.js' +import { runWithSpinner } from '../../../sys/io/spinner.js' +import { nullStreams } from '../../../sys/io/streams.js' +import { resolveWorkspaceId } from '../../../workspace/resolver.js' +import { MemberListOutput, MemberRow } from './handlers.js' + +export type GetMemberOptions = { + readonly workspace?: string + readonly page?: number + readonly limitRaw?: string + readonly format?: string +} + +export type GetMemberDeps = { + readonly bundle: HostsBundle + readonly http: KyInstance + readonly io?: IOStreams + 
readonly envLookup?: (k: string) => string | undefined + readonly membersFactory?: (http: KyInstance) => MembersClient +} + +export type GetMemberResult = { + readonly data: MemberListOutput + readonly workspaceId: string +} + +export async function runGetMember( + opts: GetMemberOptions, + deps: GetMemberDeps, +): Promise { + const env = deps.envLookup ?? ((k: string) => process.env[k]) + const factory = deps.membersFactory ?? ((h: KyInstance) => new MembersClient(h)) + const io = deps.io ?? nullStreams() + + const wsId = resolveWorkspaceId({ + flag: opts.workspace, + env: env('DIFY_WORKSPACE_ID'), + bundle: deps.bundle, + }) + + const limit = resolveLimit(opts.limitRaw, env) + const page = opts.page === undefined || opts.page <= 0 ? 1 : opts.page + + const envelope = await runWithSpinner( + { io, label: 'Fetching members' }, + () => factory(deps.http).list(wsId, { page, limit }), + ) + + const callerId = deps.bundle.account?.id ?? '' + const rows = envelope.data.map(m => new MemberRow(m, callerId !== '' && m.id === callerId)) + return { data: new MemberListOutput(rows, envelope), workspaceId: wsId } +} + +function resolveLimit(raw: string | undefined, env: (k: string) => string | undefined): number { + if (raw !== undefined && raw !== '') + return parseLimit(raw, '--limit') + const envValue = env('DIFY_LIMIT') + if (envValue !== undefined && envValue !== '') + return parseLimit(envValue, 'DIFY_LIMIT') + return LIMIT_DEFAULT +} diff --git a/cli/src/commands/get/workspace/run.ts b/cli/src/commands/get/workspace/run.ts index f2015f4817..f3b86f3c1d 100644 --- a/cli/src/commands/get/workspace/run.ts +++ b/cli/src/commands/get/workspace/run.ts @@ -1,9 +1,9 @@ import type { KyInstance } from 'ky' import type { HostsBundle } from '../../../auth/hosts.js' -import type { IOStreams } from '../../../io/streams.js' +import type { IOStreams } from '../../../sys/io/streams' import { WorkspacesClient } from '../../../api/workspaces.js' -import { runWithSpinner } from 
'../../../io/spinner.js' -import { nullStreams } from '../../../io/streams.js' +import { runWithSpinner } from '../../../sys/io/spinner.js' +import { nullStreams } from '../../../sys/io/streams' import { WorkspaceListOutput, WorkspaceRow } from './handlers.js' export const EMPTY_WORKSPACES_MESSAGE diff --git a/cli/src/commands/resume/app/run.ts b/cli/src/commands/resume/app/run.ts index 06280ebaad..bcd109b21f 100644 --- a/cli/src/commands/resume/app/run.ts +++ b/cli/src/commands/resume/app/run.ts @@ -1,12 +1,13 @@ import type { KyInstance } from 'ky' import type { HostsBundle } from '../../../auth/hosts.js' import type { AppInfoCache } from '../../../cache/app-info.js' -import type { IOStreams } from '../../../io/streams.js' +import type { IOStreams } from '../../../sys/io/streams' import type { RunContext } from '../../run/app/_strategies/index.js' import { AppMetaClient } from '../../../api/app-meta.js' import { AppRunClient } from '../../../api/app-run.js' import { AppsClient } from '../../../api/apps.js' -import { colorEnabled, colorScheme } from '../../../io/color.js' +import { getEnv, processExit } from '../../../sys/index.js' +import { colorEnabled, colorScheme } from '../../../sys/io/color.js' import { FieldInfo } from '../../../types/app-meta.js' import { resolveWorkspaceId } from '../../../workspace/resolver.js' import { pickStrategy } from '../../run/app/_strategies/index.js' @@ -76,7 +77,7 @@ async function resolveInputs( } export async function resumeApp(opts: ResumeAppOptions, deps: ResumeAppDeps): Promise { - const env = deps.envLookup ?? ((k: string) => process.env[k]) + const env = deps.envLookup ?? getEnv const wsId = resolveWorkspaceId({ flag: opts.workspace, env: env('DIFY_WORKSPACE_ID'), bundle: deps.bundle }) const apps = new AppsClient(deps.http) @@ -85,7 +86,7 @@ export async function resumeApp(opts: ResumeAppOptions, deps: ResumeAppDeps): Pr const mode = m.info?.mode ?? 
RUN_MODES.Workflow const runClient = new AppRunClient(deps.http) - const exit = deps.exit ?? ((code: number) => process.exit(code) as never) + const exit = deps.exit ?? processExit let action = opts.action if (action === undefined) { diff --git a/cli/src/commands/run/app/_strategies/streaming-structured.ts b/cli/src/commands/run/app/_strategies/streaming-structured.ts index b85db1d808..b59550ca6c 100644 --- a/cli/src/commands/run/app/_strategies/streaming-structured.ts +++ b/cli/src/commands/run/app/_strategies/streaming-structured.ts @@ -1,9 +1,9 @@ import type { SseEvent } from '../../../../http/sse.js' import type { RunContext, RunStrategy } from './index.js' import { buildRunBody } from '../../../../api/app-run.js' -import { colorEnabled, colorScheme } from '../../../../io/color.js' -import { startSpinner } from '../../../../io/spinner.js' -import { extractThinkBlocks, stripThinkBlocks } from '../../../../io/think-filter.js' +import { colorEnabled, colorScheme } from '../../../../sys/io/color.js' +import { startSpinner } from '../../../../sys/io/spinner.js' +import { extractThinkBlocks, stripThinkBlocks } from '../../../../sys/io/think-filter.js' import { chatConversationHint, newAppRunObject, RUN_MODES } from '../handlers.js' import { renderHitlHint, renderHitlOutput } from '../hitl-render.js' import { collect, HitlPauseError } from '../sse-collector.js' diff --git a/cli/src/commands/run/app/_strategies/streaming-text.ts b/cli/src/commands/run/app/_strategies/streaming-text.ts index 6a5405a918..6f88e394dd 100644 --- a/cli/src/commands/run/app/_strategies/streaming-text.ts +++ b/cli/src/commands/run/app/_strategies/streaming-text.ts @@ -1,5 +1,6 @@ import type { RunContext, RunStrategy } from './index.js' import { buildRunBody } from '../../../../api/app-run.js' +import { handle, unhandle } from '../../../../sys/index.js' import { renderHitlHint, renderHitlOutput } from '../hitl-render.js' import { decodeStreamError, HitlPauseError } from '../sse-collector.js' 
@@ -22,7 +23,8 @@ export class StreamingTextStrategy implements RunStrategy { ctrl.abort() exit(1) } - process.once('SIGINT', cleanup) + + handle('SIGINT', cleanup) try { const events = await ctx.runClient.runStream(opts.appId, body, { signal: ctrl.signal }) @@ -60,7 +62,7 @@ export class StreamingTextStrategy implements RunStrategy { throw err } finally { - process.off('SIGINT', cleanup) + unhandle('SIGINT', cleanup) } } } diff --git a/cli/src/commands/run/app/handlers.ts b/cli/src/commands/run/app/handlers.ts index 2cc11b026c..b22bfec58d 100644 --- a/cli/src/commands/run/app/handlers.ts +++ b/cli/src/commands/run/app/handlers.ts @@ -1,5 +1,5 @@ -import type { ColorScheme } from '../../../io/color.js' import type { TextHandler } from '../../../printers/format-text.js' +import type { ColorScheme } from '../../../sys/io/color.js' export const RUN_MODES = { Chat: 'chat', diff --git a/cli/src/commands/run/app/hitl-render.ts b/cli/src/commands/run/app/hitl-render.ts index f9c3ed6ac9..da02ecf5cb 100644 --- a/cli/src/commands/run/app/hitl-render.ts +++ b/cli/src/commands/run/app/hitl-render.ts @@ -1,5 +1,5 @@ import type { HitlPausePayload } from './sse-collector.js' -import { colorEnabled, colorScheme } from '../../../io/color.js' +import { colorEnabled, colorScheme } from '../../../sys/io/color.js' export type HitlExitObject = { status: 'paused' diff --git a/cli/src/commands/run/app/run.test.ts b/cli/src/commands/run/app/run.test.ts index a778abc078..5af12e2a41 100644 --- a/cli/src/commands/run/app/run.test.ts +++ b/cli/src/commands/run/app/run.test.ts @@ -7,7 +7,9 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest' import { startMock } from '../../../../test/fixtures/dify-mock/server.js' import { loadAppInfoCache } from '../../../cache/app-info.js' import { createClient } from '../../../http/client.js' -import { bufferStreams } from '../../../io/streams.js' +import { CACHE_APP_INFO, cachePath } from '../../../store/manager.js' +import { YamlStore } 
from '../../../store/store.js' +import { bufferStreams } from '../../../sys/io/streams' import { resumeApp } from '../../resume/app/run.js' import { runApp } from './run.js' @@ -39,7 +41,7 @@ describe('runApp', () => { it('chat: prints answer + conversation hint to stderr', async () => { const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await runApp( { appId: 'app-1', message: 'hi' }, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test' }), host: mock.url, io, cache }, @@ -50,7 +52,7 @@ describe('runApp', () => { it('workflow: rejects positional message with usage error', async () => { const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await expect(runApp( { appId: 'app-2', message: 'hi' }, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test' }), host: mock.url, io, cache }, @@ -59,7 +61,7 @@ describe('runApp', () => { it('workflow: prints single-string output as plain text', async () => { const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await runApp( { appId: 'app-2', inputs: { x: '1' } }, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test' }), host: mock.url, io, cache }, @@ -69,7 +71,7 @@ describe('runApp', () => { it('json: passes through full envelope', async () => { const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await runApp( { appId: 'app-1', message: 'hi', format: 'json' }, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test' }), host: 
mock.url, io, cache }, @@ -102,7 +104,7 @@ describe('runApp', () => { it('--stream chat: streams answer to stdout and hint to stderr', async () => { const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await runApp( { appId: 'app-1', message: 'hi', stream: true }, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test' }), host: mock.url, io, cache }, @@ -114,7 +116,7 @@ describe('runApp', () => { it('--stream -o json chat: aggregates into blocking-shape envelope', async () => { const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await runApp( { appId: 'app-1', message: 'hi', stream: true, format: 'json' }, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test' }), host: mock.url, io, cache }, @@ -127,7 +129,7 @@ describe('runApp', () => { it('agent-chat without --stream: collects and prints answer', async () => { const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await runApp( { appId: 'app-4', workspace: 'ws-2', message: 'do research' }, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test' }), host: mock.url, io, cache }, @@ -138,7 +140,7 @@ describe('runApp', () => { it('agent-chat with --stream: live-prints answer and thoughts to stderr', async () => { const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await runApp( { appId: 'app-4', workspace: 'ws-2', message: 'go', stream: true }, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test' }), host: mock.url, 
io, cache }, @@ -149,7 +151,7 @@ describe('runApp', () => { it('--stream workflow -o json: aggregates from workflow_finished', async () => { const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await runApp( { appId: 'app-2', inputs: { x: '1' }, stream: true, format: 'json' }, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test' }), host: mock.url, io, cache }, @@ -162,7 +164,7 @@ describe('runApp', () => { it('stream-error scenario: error event surfaces typed BaseError', async () => { mock.setScenario('stream-error') const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await expect(runApp( { appId: 'app-1', message: 'hi', stream: true }, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test', retryAttempts: 0 }), host: mock.url, io, cache }, @@ -171,7 +173,7 @@ describe('runApp', () => { it('--inputs-file: reads inputs from file', async () => { const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) const inputsFile = join(dir, 'inputs.json') const { writeFile } = await import('node:fs/promises') await writeFile(inputsFile, JSON.stringify({ x: 'from-file' })) @@ -195,7 +197,7 @@ describe('runApp', () => { it('--inputs: accepts JSON object string', async () => { const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await runApp( { appId: 'app-2', inputsJson: '{"x":"hello"}' }, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test' }), host: mock.url, io, cache }, @@ -217,7 +219,7 @@ 
describe('runApp', () => { it('hitl pause (text): writes readable block to stdout, hint to stderr, exits 0', async () => { mock.setScenario('hitl-pause') const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) let exitCode = -1 await expect(runApp( { appId: 'app-2', inputs: {} }, @@ -246,7 +248,7 @@ describe('runApp', () => { it('hitl pause (json): writes JSON envelope to stdout, exits 0', async () => { mock.setScenario('hitl-pause') const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) let exitCode = -1 await expect(runApp( { appId: 'app-2', inputs: {}, format: 'json' }, @@ -272,7 +274,7 @@ describe('runApp', () => { it('resume: withHistory: false completes successfully', async () => { mock.setScenario('hitl-resume') const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await resumeApp( { appId: 'app-2', formToken: 'ft-hitl-1', workflowRunId: 'wf-run-hitl-1', action: 'submit', inputs: {}, withHistory: false }, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test' }), host: mock.url, io, cache }, @@ -283,7 +285,7 @@ describe('runApp', () => { it('resume: submits form and streams workflow to completion', async () => { mock.setScenario('hitl-resume') const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await resumeApp( { appId: 'app-2', formToken: 'ft-hitl-1', workflowRunId: 'wf-run-hitl-1', action: 'submit', inputs: {} }, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test' }), host: mock.url, io, cache }, @@ 
-294,7 +296,7 @@ describe('runApp', () => { it('resume --stream: live-prints workflow node progress to stderr', async () => { mock.setScenario('hitl-resume') const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await resumeApp( { appId: 'app-2', formToken: 'ft-hitl-1', workflowRunId: 'wf-run-hitl-1', action: 'submit', inputs: {}, stream: true }, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test' }), host: mock.url, io, cache }, @@ -305,7 +307,7 @@ describe('runApp', () => { it('workflow: --file remote URL is passed as remote_url input variable', async () => { const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await runApp( { appId: 'app-2', files: ['doc=https://example.com/report.pdf'] }, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test' }), host: mock.url, io, cache }, @@ -324,7 +326,7 @@ describe('runApp', () => { it('workflow: --file @path uploads file and passes local_file input variable', async () => { const { writeFile } = await import('node:fs/promises') const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) const filePath = join(dir, 'test.pdf') await writeFile(filePath, 'fake pdf content') await runApp( @@ -343,7 +345,7 @@ describe('runApp', () => { it('workflow: --file overrides same-named key from --inputs (file wins)', async () => { const io = bufferStreams() - const cache = await loadAppInfoCache({ configDir: dir }) + const cache = await loadAppInfoCache({ store: new YamlStore(cachePath(dir, CACHE_APP_INFO)) }) await runApp( { appId: 'app-2', inputs: { doc: 'old-value' }, files: 
['doc=https://example.com/override.pdf'] }, { bundle: bundle(), http: createClient({ host: mock.url, bearer: 'dfoa_test' }), host: mock.url, io, cache }, diff --git a/cli/src/commands/run/app/run.ts b/cli/src/commands/run/app/run.ts index d63787090e..eb9a2e4d53 100644 --- a/cli/src/commands/run/app/run.ts +++ b/cli/src/commands/run/app/run.ts @@ -1,13 +1,14 @@ import type { KyInstance } from 'ky' import type { HostsBundle } from '../../../auth/hosts.js' import type { AppInfoCache } from '../../../cache/app-info.js' -import type { IOStreams } from '../../../io/streams.js' +import type { IOStreams } from '../../../sys/io/streams' import { AppMetaClient } from '../../../api/app-meta.js' import { AppRunClient } from '../../../api/app-run.js' import { AppsClient } from '../../../api/apps.js' import { FileUploadClient } from '../../../api/file-upload.js' import { BaseError } from '../../../errors/base.js' import { ErrorCode } from '../../../errors/codes.js' +import { getEnv, processExit } from '../../../sys/index.js' import { FieldInfo } from '../../../types/app-meta.js' import { resolveWorkspaceId } from '../../../workspace/resolver.js' import { pickStrategy } from './_strategies/index.js' @@ -78,7 +79,7 @@ async function resolveInputs( } export async function runApp(opts: RunAppOptions, deps: RunAppDeps): Promise { - const env = deps.envLookup ?? ((k: string) => process.env[k]) + const env = deps.envLookup ?? getEnv const wsId = resolveWorkspaceId({ flag: opts.workspace, env: env('DIFY_WORKSPACE_ID'), bundle: deps.bundle }) const apps = new AppsClient(deps.http) const meta = new AppMetaClient({ apps, host: deps.host, cache: deps.cache }) @@ -111,7 +112,7 @@ export async function runApp(opts: RunAppOptions, deps: RunAppDeps): Promise process.exit(code) as never) + const exit = deps.exit ?? processExit const ctx = { opts: { ...opts, inputs }, deps, mode, format, isText, livePrint, runClient, printFlags, exit, think: opts.think ?? 
false } await pickStrategy(isText, livePrint).execute(ctx) } diff --git a/cli/src/commands/run/app/stream-handlers.ts b/cli/src/commands/run/app/stream-handlers.ts index a54dbfe54b..53dc746602 100644 --- a/cli/src/commands/run/app/stream-handlers.ts +++ b/cli/src/commands/run/app/stream-handlers.ts @@ -3,8 +3,8 @@ import type { StreamPrinter } from '../../../printers/stream-printer.js' import type { HitlPausePayload } from './sse-collector.js' import { newError } from '../../../errors/base.js' import { ErrorCode } from '../../../errors/codes.js' -import { colorEnabled, colorScheme } from '../../../io/color.js' -import { ThinkChunkFilter } from '../../../io/think-filter.js' +import { colorEnabled, colorScheme } from '../../../sys/io/color.js' +import { ThinkChunkFilter } from '../../../sys/io/think-filter.js' import { RUN_MODES } from './handlers.js' import { HitlPauseError } from './sse-collector.js' diff --git a/cli/src/commands/set/member/handlers.ts b/cli/src/commands/set/member/handlers.ts new file mode 100644 index 0000000000..23bd04c521 --- /dev/null +++ b/cli/src/commands/set/member/handlers.ts @@ -0,0 +1,26 @@ +export type SetMemberPayload = { + readonly id: string + readonly role: 'normal' | 'admin' +} + +export class SetMemberOutput { + readonly payload: SetMemberPayload + readonly textLine: string + + constructor(payload: SetMemberPayload, textLine: string) { + this.payload = payload + this.textLine = textLine + } + + text(): string { + return this.textLine + } + + json(): SetMemberPayload { + return this.payload + } + + name(): string { + return this.payload.id + } +} diff --git a/cli/src/commands/set/member/index.ts b/cli/src/commands/set/member/index.ts new file mode 100644 index 0000000000..3cbf3bf106 --- /dev/null +++ b/cli/src/commands/set/member/index.ts @@ -0,0 +1,43 @@ +import { Args, Flags } from '../../../framework/flags.js' +import { formatted } from '../../../framework/output.js' +import { DifyCommand } from '../../_shared/dify-command.js' 
+import { httpRetryFlag } from '../../_shared/global-flags.js' +import { runSetMember } from './run.js' + +export default class SetMember extends DifyCommand { + static override description = 'Change a member\'s role in the active (or specified) workspace' + + static override examples = [ + '<%= config.bin %> set member acct-1 --role admin', + '<%= config.bin %> set member acct-1 --role normal -w ws-1', + '<%= config.bin %> set member acct-1 --role admin -o json', + ] + + static override args = { + memberId: Args.string({ description: 'account id of the member to update', required: true }), + } + + static override flags = { + 'role': Flags.string({ + description: 'new role (normal|admin); owner is not assignable here', + required: true, + }), + 'workspace': Flags.string({ + char: 'w', + description: 'workspace id (overrides DIFY_WORKSPACE_ID and stored default)', + }), + 'http-retry': httpRetryFlag, + 'output': Flags.string({ char: 'o', description: 'output format (json|yaml|name|text)', default: '' }), + } + + async run(argv: string[]) { + const { args, flags } = this.parse(SetMember, argv) + const format = flags.output + const ctx = await this.authedCtx({ retryFlag: flags['http-retry'], format }) + const result = await runSetMember( + { memberId: args.memberId, role: flags.role, workspace: flags.workspace, format }, + { bundle: ctx.bundle, http: ctx.http, io: ctx.io }, + ) + return formatted({ format, data: result.data }) + } +} diff --git a/cli/src/commands/set/member/run.test.ts b/cli/src/commands/set/member/run.test.ts new file mode 100644 index 0000000000..ff987b815f --- /dev/null +++ b/cli/src/commands/set/member/run.test.ts @@ -0,0 +1,87 @@ +import type { KyInstance } from 'ky' +import type { HostsBundle } from '../../../auth/hosts.js' +import { describe, expect, it, vi } from 'vitest' +import { bufferStreams } from '../../../sys/io/streams' +import { runSetMember } from './run.js' + +function bundle(): HostsBundle { + return { + current_host: 
'cloud.dify.ai', + token_storage: 'file', + tokens: { bearer: 'dfoa_test' }, + account: { id: 'acct-1', email: 'me@example.com', name: 'Me' }, + workspace: { id: 'ws-1', name: 'Default', role: 'owner' }, + available_workspaces: [{ id: 'ws-1', name: 'Default', role: 'owner' }], + } +} + +function fakeClient() { + return { + updateRole: vi.fn(() => Promise.resolve({ result: 'success' as const })), + } +} + +describe('runSetMember', () => { + it('happy path: PUT new role, returns SetMemberOutput with text/json/name', async () => { + const client = fakeClient() + const result = await runSetMember( + { memberId: 'acct-2', role: 'admin' }, + { + bundle: bundle(), + http: {} as KyInstance, + io: bufferStreams(), + membersFactory: () => client as never, + }, + ) + expect(client.updateRole).toHaveBeenCalledExactlyOnceWith('ws-1', 'acct-2', { role: 'admin' }) + expect(result.data.text()).toMatch(/Set acct-2 role to admin/) + expect(result.data.name()).toBe('acct-2') + expect(result.data.json()).toEqual({ id: 'acct-2', role: 'admin' }) + expect(result.workspaceId).toBe('ws-1') + }) + + it('rejects unknown role before any HTTP call', async () => { + const client = fakeClient() + await expect( + runSetMember( + { memberId: 'acct-2', role: 'owner' }, + { + bundle: bundle(), + http: {} as KyInstance, + io: bufferStreams(), + membersFactory: () => client as never, + }, + ), + ).rejects.toThrow(/invalid --role/) + expect(client.updateRole).not.toHaveBeenCalled() + }) + + it('rejects empty member id', async () => { + const client = fakeClient() + await expect( + runSetMember( + { memberId: '', role: 'admin' }, + { + bundle: bundle(), + http: {} as KyInstance, + io: bufferStreams(), + membersFactory: () => client as never, + }, + ), + ).rejects.toThrow(/member id is required/) + }) + + it('-w flag overrides resolved workspace', async () => { + const client = fakeClient() + await runSetMember( + { memberId: 'acct-2', role: 'normal', workspace: 'ws-9' }, + { + bundle: bundle(), + http: 
{} as KyInstance, + io: bufferStreams(), + membersFactory: () => client as never, + }, + ) + expect(client.updateRole).toHaveBeenCalledWith('ws-9', 'acct-2', { role: 'normal' }) + }) +}) diff --git a/cli/src/commands/set/member/run.ts b/cli/src/commands/set/member/run.ts new file mode 100644 index 0000000000..f77e09d4e6 --- /dev/null +++ b/cli/src/commands/set/member/run.ts @@ -0,0 +1,78 @@ +import type { KyInstance } from 'ky' +import type { HostsBundle } from '../../../auth/hosts.js' +import type { IOStreams } from '../../../sys/io/streams.js' +import { MembersClient } from '../../../api/members.js' +import { BaseError } from '../../../errors/base.js' +import { ErrorCode } from '../../../errors/codes.js' +import { colorEnabled, colorScheme } from '../../../sys/io/color.js' +import { runWithSpinner } from '../../../sys/io/spinner.js' +import { nullStreams } from '../../../sys/io/streams.js' +import { resolveWorkspaceId } from '../../../workspace/resolver.js' +import { SetMemberOutput } from './handlers.js' + +export type SetMemberOptions = { + readonly memberId: string + readonly role: string + readonly workspace?: string + readonly format?: string +} + +export type SetMemberDeps = { + readonly bundle: HostsBundle + readonly http: KyInstance + readonly io?: IOStreams + readonly envLookup?: (k: string) => string | undefined + readonly membersFactory?: (http: KyInstance) => MembersClient +} + +export type SetMemberResult = { + readonly data: SetMemberOutput + readonly workspaceId: string +} + +const ASSIGNABLE_ROLES = new Set(['normal', 'admin']) + +export async function runSetMember( + opts: SetMemberOptions, + deps: SetMemberDeps, +): Promise { + if (opts.memberId === undefined || opts.memberId === '') { + throw new BaseError({ + code: ErrorCode.UsageMissingArg, + message: 'member id is required', + hint: 'pass it positionally: difyctl set member --role ', + }) + } + if (!ASSIGNABLE_ROLES.has(opts.role)) { + throw new BaseError({ + code: 
ErrorCode.UsageInvalidFlag, + message: `invalid --role "${opts.role}"`, + hint: 'expected: normal | admin (ownership transfer is console-only)', + }) + } + + const env = deps.envLookup ?? ((k: string) => process.env[k]) + const factory = deps.membersFactory ?? ((h: KyInstance) => new MembersClient(h)) + const io = deps.io ?? nullStreams() + const cs = colorScheme(colorEnabled(io.isErrTTY)) + + const wsId = resolveWorkspaceId({ + flag: opts.workspace, + env: env('DIFY_WORKSPACE_ID'), + bundle: deps.bundle, + }) + + await runWithSpinner( + { io, label: `Updating role for ${opts.memberId}` }, + () => factory(deps.http).updateRole(wsId, opts.memberId, { + role: opts.role as 'normal' | 'admin', + }), + ) + + const role = opts.role as 'normal' | 'admin' + const textLine = `${cs.successIcon()} Set ${opts.memberId} role to ${role}\n` + return { + data: new SetMemberOutput({ id: opts.memberId, role }, textLine), + workspaceId: wsId, + } +} diff --git a/cli/src/commands/tree.generated.ts b/cli/src/commands/tree.generated.ts index 666884917c..51a77d1a99 100644 --- a/cli/src/commands/tree.generated.ts +++ b/cli/src/commands/tree.generated.ts @@ -7,22 +7,26 @@ import AuthDevicesRevoke from './auth/devices/revoke/index.js' import AuthLogin from './auth/login/index.js' import AuthLogout from './auth/logout/index.js' import AuthStatus from './auth/status/index.js' -import AuthUse from './auth/use/index.js' import AuthWhoami from './auth/whoami/index.js' import ConfigGet from './config/get/index.js' import ConfigPath from './config/path/index.js' import ConfigSet from './config/set/index.js' import ConfigUnset from './config/unset/index.js' import ConfigView from './config/view/index.js' +import CreateMember from './create/member/index.js' +import DeleteMember from './delete/member/index.js' import DescribeApp from './describe/app/index.js' import EnvList from './env/list/index.js' import GetApp from './get/app/index.js' +import GetMember from './get/member/index.js' import 
GetWorkspace from './get/workspace/index.js' import HelpAccount from './help/account/index.js' import HelpEnvironment from './help/environment/index.js' import HelpExternal from './help/external/index.js' import ResumeApp from './resume/app/index.js' import RunApp from './run/app/index.js' +import SetMember from './set/member/index.js' +import UseWorkspace from './use/workspace/index.js' import Version from './version/index.js' export const commandTree: CommandTree = { @@ -37,7 +41,6 @@ export const commandTree: CommandTree = { login: { command: AuthLogin, subcommands: {} }, logout: { command: AuthLogout, subcommands: {} }, status: { command: AuthStatus, subcommands: {} }, - use: { command: AuthUse, subcommands: {} }, whoami: { command: AuthWhoami, subcommands: {} }, }, }, @@ -50,6 +53,16 @@ export const commandTree: CommandTree = { view: { command: ConfigView, subcommands: {} }, }, }, + create: { + subcommands: { + member: { command: CreateMember, subcommands: {} }, + }, + }, + delete: { + subcommands: { + member: { command: DeleteMember, subcommands: {} }, + }, + }, describe: { subcommands: { app: { command: DescribeApp, subcommands: {} }, @@ -63,6 +76,7 @@ export const commandTree: CommandTree = { get: { subcommands: { app: { command: GetApp, subcommands: {} }, + member: { command: GetMember, subcommands: {} }, workspace: { command: GetWorkspace, subcommands: {} }, }, }, @@ -83,5 +97,15 @@ export const commandTree: CommandTree = { app: { command: RunApp, subcommands: {} }, }, }, + set: { + subcommands: { + member: { command: SetMember, subcommands: {} }, + }, + }, + use: { + subcommands: { + workspace: { command: UseWorkspace, subcommands: {} }, + }, + }, version: { command: Version, subcommands: {} }, } diff --git a/cli/src/commands/use/workspace/index.ts b/cli/src/commands/use/workspace/index.ts new file mode 100644 index 0000000000..239ac9a44f --- /dev/null +++ b/cli/src/commands/use/workspace/index.ts @@ -0,0 +1,31 @@ +import { Args } from 
'../../../framework/flags.js' +import { DifyCommand } from '../../_shared/dify-command.js' +import { httpRetryFlag } from '../../_shared/global-flags.js' +import { runUseWorkspace } from './use.js' + +export default class UseWorkspace extends DifyCommand { + static override description = 'Switch the active workspace on the server and refresh hosts.yml' + + static override examples = [ + '<%= config.bin %> use workspace ws-abc123', + ] + + static override args = { + workspaceId: Args.string({ description: 'workspace id to switch to', required: true }), + } + + static override flags = { + 'http-retry': httpRetryFlag, + } + + async run(argv: string[]): Promise { + const { args, flags } = this.parse(UseWorkspace, argv) + const ctx = await this.authedCtx({ retryFlag: flags['http-retry'] }) + await runUseWorkspace({ workspaceId: args.workspaceId }, { + configDir: ctx.configDir, + bundle: ctx.bundle, + http: ctx.http, + io: ctx.io, + }) + } +} diff --git a/cli/src/commands/use/workspace/use.test.ts b/cli/src/commands/use/workspace/use.test.ts new file mode 100644 index 0000000000..b3c78988f4 --- /dev/null +++ b/cli/src/commands/use/workspace/use.test.ts @@ -0,0 +1,199 @@ +import type { + WorkspaceDetailResponse, + WorkspaceListResponse, +} from '@dify/contracts/api/openapi/types.gen' +import type { KyInstance } from 'ky' +import type { HostsBundle } from '../../../auth/hosts.js' +import { mkdtemp, rm } from 'node:fs/promises' +import { tmpdir } from 'node:os' +import { join } from 'node:path' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { loadHosts, saveHosts } from '../../../auth/hosts.js' +import { bufferStreams } from '../../../sys/io/streams.js' +import { runUseWorkspace } from './use.js' + +function bundle(): HostsBundle { + return { + current_host: 'cloud.dify.ai', + token_storage: 'file', + tokens: { bearer: 'dfoa_test' }, + account: { id: 'acct-1', email: 'tester@dify.ai', name: 'Tester' }, + workspace: { id: 'ws-1', name: 
'Default', role: 'owner' }, + available_workspaces: [ + { id: 'ws-1', name: 'Default', role: 'owner' }, + { id: 'ws-2', name: 'Stale Name', role: 'normal' }, + ], + } +} + +function fakeClient(opts: { + switch?: () => Promise + list?: () => Promise +}) { + return { + switch: vi.fn(opts.switch ?? (() => Promise.resolve({ + id: 'ws-2', + name: 'Switched', + role: 'normal', + status: 'normal', + current: true, + created_at: '2026-05-18T00:00:00Z', + }))), + list: vi.fn(opts.list ?? (() => Promise.resolve({ + workspaces: [ + { id: 'ws-1', name: 'Default', role: 'owner', status: 'normal', current: false }, + { id: 'ws-2', name: 'Switched', role: 'normal', status: 'normal', current: true }, + ], + }))), + } +} + +describe('runUseWorkspace', () => { + let configDir: string + + beforeEach(async () => { + configDir = await mkdtemp(join(tmpdir(), 'difyctl-use-workspace-')) + }) + afterEach(async () => { + await rm(configDir, { recursive: true, force: true }) + }) + + it('happy path: POST /switch → GET /workspaces → write hosts.yml', async () => { + const io = bufferStreams() + const b = bundle() + await saveHosts(configDir, b) + const client = fakeClient({}) + + const next = await runUseWorkspace( + { workspaceId: 'ws-2' }, + { + configDir, + bundle: b, + http: {} as KyInstance, + io, + workspacesFactory: () => client as never, + }, + ) + + expect(client.switch).toHaveBeenCalledExactlyOnceWith('ws-2') + expect(client.list).toHaveBeenCalledOnce() + expect(next.workspace).toEqual({ id: 'ws-2', name: 'Switched', role: 'normal' }) + expect(next.available_workspaces).toEqual([ + { id: 'ws-1', name: 'Default', role: 'owner' }, + { id: 'ws-2', name: 'Switched', role: 'normal' }, + ]) + const reloaded = await loadHosts(configDir) + expect(reloaded?.workspace?.id).toBe('ws-2') + expect(reloaded?.workspace?.name).toBe('Switched') + expect(io.outBuf()).toMatch(/Switched to Switched \(ws-2\)/) + }) + + it('refreshes stale workspace name from server', async () => { + // bundle has ws-2 
named "Stale Name"; server returns "Switched". + // We expect saveHosts to record the fresh name from the server. + const io = bufferStreams() + const b = bundle() + await saveHosts(configDir, b) + const client = fakeClient({}) + + await runUseWorkspace( + { workspaceId: 'ws-2' }, + { configDir, bundle: b, http: {} as KyInstance, io, workspacesFactory: () => client as never }, + ) + + const reloaded = await loadHosts(configDir) + expect(reloaded?.workspace?.name).toBe('Switched') + expect(reloaded?.available_workspaces?.find(w => w.id === 'ws-2')?.name).toBe('Switched') + }) + + it('does NOT mutate hosts.yml when POST /switch fails', async () => { + const io = bufferStreams() + const b = bundle() + await saveHosts(configDir, b) + const before = await loadHosts(configDir) + + const client = fakeClient({ + switch: () => Promise.reject(new Error('forbidden')), + }) + + await expect( + runUseWorkspace( + { workspaceId: 'ws-2' }, + { + configDir, + bundle: b, + http: {} as KyInstance, + io, + workspacesFactory: () => client as never, + }, + ), + ).rejects.toThrow(/forbidden/) + + expect(client.list).not.toHaveBeenCalled() + const after = await loadHosts(configDir) + expect(after).toEqual(before) + expect(after?.workspace?.id).toBe('ws-1') + }) + + it('does NOT mutate hosts.yml when GET /workspaces fails after switch', async () => { + const io = bufferStreams() + const b = bundle() + await saveHosts(configDir, b) + const before = await loadHosts(configDir) + + const client = fakeClient({ + list: () => Promise.reject(new Error('transient list failure')), + }) + + await expect( + runUseWorkspace( + { workspaceId: 'ws-2' }, + { + configDir, + bundle: b, + http: {} as KyInstance, + io, + workspacesFactory: () => client as never, + }, + ), + ).rejects.toThrow(/transient list failure/) + + const after = await loadHosts(configDir) + expect(after).toEqual(before) + }) + + it('throws when server returns switch= but id is missing from /workspaces list', async () => { + const io = 
bufferStreams() + const b = bundle() + await saveHosts(configDir, b) + + const client = fakeClient({ + switch: () => Promise.resolve({ + id: 'ws-7', + name: 'Ghost', + role: 'normal', + status: 'normal', + current: true, + created_at: null as unknown as string, + }), + list: () => Promise.resolve({ + workspaces: [ + { id: 'ws-1', name: 'Default', role: 'owner', status: 'normal', current: false }, + ], + }), + }) + + await expect( + runUseWorkspace( + { workspaceId: 'ws-7' }, + { + configDir, + bundle: b, + http: {} as KyInstance, + io, + workspacesFactory: () => client as never, + }, + ), + ).rejects.toThrow(/not visible in \/workspaces/) + }) +}) diff --git a/cli/src/commands/use/workspace/use.ts b/cli/src/commands/use/workspace/use.ts new file mode 100644 index 0000000000..b97b9dd224 --- /dev/null +++ b/cli/src/commands/use/workspace/use.ts @@ -0,0 +1,76 @@ +import type { KyInstance } from 'ky' +import type { HostsBundle, Workspace } from '../../../auth/hosts.js' +import type { IOStreams } from '../../../sys/io/streams.js' +import { WorkspacesClient } from '../../../api/workspaces.js' +import { saveHosts } from '../../../auth/hosts.js' +import { BaseError } from '../../../errors/base.js' +import { ErrorCode } from '../../../errors/codes.js' +import { colorEnabled, colorScheme } from '../../../sys/io/color.js' +import { runWithSpinner } from '../../../sys/io/spinner.js' + +export type UseWorkspaceOptions = { + readonly workspaceId: string +} + +export type UseWorkspaceDeps = { + readonly configDir: string + readonly bundle: HostsBundle + readonly http: KyInstance + readonly io: IOStreams + readonly workspacesFactory?: (http: KyInstance) => WorkspacesClient +} + +/** + * Switch the caller's active workspace. + * + * Strict ordering: + * 1. POST /workspaces//switch — if this fails (403/404/etc.) we abort + * with no `hosts.yml` mutation, so local state never diverges from the + * server. 
Any fallback to a pure-local update is explicitly disallowed + * (see workspace-plan.md decision D4). + * 2. GET /workspaces — refresh the membership list so `available_workspaces` + * stays in sync. Failure here also aborts; the server-side current has + * already moved, but the local file is left untouched. A follow-up + * `difyctl get workspace` will reconcile. + * 3. Persist `workspace` + `available_workspaces` atomically via `saveHosts`. + */ +export async function runUseWorkspace( + opts: UseWorkspaceOptions, + deps: UseWorkspaceDeps, +): Promise { + const cs = colorScheme(colorEnabled(deps.io.isErrTTY)) + const factory = deps.workspacesFactory ?? ((h: KyInstance) => new WorkspacesClient(h)) + const client = factory(deps.http) + + const detail = await runWithSpinner( + { io: deps.io, label: `Switching to ${opts.workspaceId}` }, + () => client.switch(opts.workspaceId), + ) + + const list = await runWithSpinner( + { io: deps.io, label: 'Refreshing workspaces' }, + () => client.list(), + ) + + const matched = list.workspaces.find(w => w.id === detail.id) + if (matched === undefined) { + throw new BaseError({ + code: ErrorCode.Unknown, + message: `server returned switch=${detail.id} but it is not visible in /workspaces`, + hint: 'try again or contact your workspace admin', + }) + } + + const next: HostsBundle = { + ...deps.bundle, + workspace: { id: matched.id, name: matched.name, role: matched.role }, + available_workspaces: list.workspaces.map(w => ({ + id: w.id, + name: w.name, + role: w.role, + })), + } + await saveHosts(deps.configDir, next) + deps.io.out.write(`${cs.successIcon()} Switched to ${matched.name} (${matched.id})\n`) + return next +} diff --git a/cli/src/commands/version/index.ts b/cli/src/commands/version/index.ts index b91169b1d4..01d398a001 100644 --- a/cli/src/commands/version/index.ts +++ b/cli/src/commands/version/index.ts @@ -1,7 +1,7 @@ import { Flags } from '../../framework/flags.js' import { formatted, raw, stringifyOutput } from 
'../../framework/output.js' -import { colorEnabled } from '../../io/color.js' -import { realStreams } from '../../io/streams.js' +import { colorEnabled } from '../../sys/io/color.js' +import { realStreams } from '../../sys/io/streams' import { versionInfo } from '../../version/info.js' import { runVersionProbe } from '../../version/probe.js' import { renderVersionText } from '../../version/render.js' @@ -54,7 +54,7 @@ export default class Version extends DifyCommand { // Emit the full report first so `difyctl version -o json --check-compat | jq` // works exactly like the success path: stdout gets the canonical envelope, // stderr gets the one-line failure reason, exit code signals the verdict. - process.stdout.write(stringifyOutput(output)) + io.out.write(stringifyOutput(output)) this.error(report.compat.detail, { exit: COMPAT_FAIL_EXIT_CODE }) } diff --git a/cli/src/config/loader.test.ts b/cli/src/config/config-loader.test.ts similarity index 80% rename from cli/src/config/loader.test.ts rename to cli/src/config/config-loader.test.ts index da7bac2c1f..54e3d5d027 100644 --- a/cli/src/config/loader.test.ts +++ b/cli/src/config/config-loader.test.ts @@ -2,10 +2,15 @@ import { mkdir, mkdtemp, writeFile } from 'node:fs/promises' import { tmpdir } from 'node:os' import { join } from 'node:path' import { afterEach, beforeEach, describe, expect, it } from 'vitest' -import { isBaseError } from '../errors/base.js' -import { ErrorCode } from '../errors/codes.js' -import { loadConfig } from './loader.js' -import { FILE_NAME } from './schema.js' +import { isBaseError } from '../errors/base' +import { ErrorCode } from '../errors/codes' +import { YamlStore } from '../store/store' +import { loadConfig } from './config-loader' +import { FILE_NAME } from './schema' + +function makeStore(dir: string): YamlStore { + return new YamlStore(join(dir, FILE_NAME)) +} describe('loadConfig', () => { let dir: string @@ -18,14 +23,14 @@ describe('loadConfig', () => { await mkdir(dir, { 
recursive: true }).catch(() => {}) }) - it('returns found:false when config.yml is missing', async () => { - const r = await loadConfig(dir) + it('returns found:false when config.yml is missing', () => { + const r = loadConfig(makeStore(dir)) expect(r.found).toBe(false) }) it('parses a minimal valid config.yml', async () => { await writeFile(join(dir, FILE_NAME), 'schema_version: 1\n', 'utf8') - const r = await loadConfig(dir) + const r = loadConfig(makeStore(dir)) expect(r.found).toBe(true) if (r.found) expect(r.config.schema_version).toBe(1) @@ -37,7 +42,7 @@ describe('loadConfig', () => { 'schema_version: 1\ndefaults:\n format: json\n limit: 100\nstate:\n current_app: app-1\n', 'utf8', ) - const r = await loadConfig(dir) + const r = loadConfig(makeStore(dir)) expect(r.found).toBe(true) if (r.found) { expect(r.config.defaults.format).toBe('json') @@ -50,7 +55,7 @@ describe('loadConfig', () => { await writeFile(join(dir, FILE_NAME), '::not yaml::: {{[', 'utf8') let caught: unknown try { - await loadConfig(dir) + loadConfig(makeStore(dir)) } catch (err) { caught = err } expect(isBaseError(caught)).toBe(true) @@ -62,7 +67,7 @@ describe('loadConfig', () => { await writeFile(join(dir, FILE_NAME), 'defaults:\n limit: 9999\n', 'utf8') let caught: unknown try { - await loadConfig(dir) + loadConfig(makeStore(dir)) } catch (err) { caught = err } expect(isBaseError(caught)).toBe(true) @@ -74,7 +79,7 @@ describe('loadConfig', () => { await writeFile(join(dir, FILE_NAME), 'schema_version: 2\n', 'utf8') let caught: unknown try { - await loadConfig(dir) + loadConfig(makeStore(dir)) } catch (err) { caught = err } expect(isBaseError(caught)).toBe(true) diff --git a/cli/src/config/config-loader.ts b/cli/src/config/config-loader.ts new file mode 100644 index 0000000000..c260d49d8e --- /dev/null +++ b/cli/src/config/config-loader.ts @@ -0,0 +1,42 @@ +import type { YamlStore } from '../store/store' +import type { ConfigFile } from './schema' +import { newError } from '../errors/base' 
+import { ErrorCode } from '../errors/codes' +import { ConfigFileSchema, CURRENT_SCHEMA_VERSION } from './schema' + +export type LoadResult + = | { found: false } + | { found: true, config: ConfigFile } + +export function loadConfig(store: YamlStore): LoadResult { + let raw: Record | null + try { + raw = store.getTyped>() + } + catch (err) { + throw newError( + ErrorCode.ConfigSchemaUnsupported, + `parse config.yml: ${(err as Error).message}`, + ).wrap(err).withHint('config.yml is not valid YAML') + } + + if (raw === null) + return { found: false } + + const result = ConfigFileSchema.safeParse(raw) + if (!result.success) { + throw newError( + ErrorCode.ConfigSchemaUnsupported, + `validate config.yml: ${result.error.issues.map(i => i.message).join('; ')}`, + ).withHint('config.yml does not match the v1 schema') + } + + if (result.data.schema_version > CURRENT_SCHEMA_VERSION) { + throw newError( + ErrorCode.ConfigSchemaUnsupported, + `config.yml schema_version=${result.data.schema_version} is newer than this binary supports (max=${CURRENT_SCHEMA_VERSION})`, + ).withHint('upgrade difyctl, or remove config.yml') + } + + return { found: true, config: result.data } +} diff --git a/cli/src/config/dir.test.ts b/cli/src/config/dir.test.ts deleted file mode 100644 index 24ecde3986..0000000000 --- a/cli/src/config/dir.test.ts +++ /dev/null @@ -1,71 +0,0 @@ -import { describe, expect, it } from 'vitest' -import { DIR_PERM, FILE_PERM, resolveConfigDir } from './dir.js' - -function fakeEnv(opts: { - override?: string - xdg?: string - home?: string - appData?: string - platform: NodeJS.Platform -}) { - return { - getEnv: (name: string) => { - if (name === 'DIFY_CONFIG_DIR') - return opts.override - if (name === 'XDG_CONFIG_HOME') - return opts.xdg - return undefined - }, - homeDir: () => opts.home ?? 
'/home/u', - platform: () => opts.platform, - appData: () => opts.appData, - } -} - -describe('config dir', () => { - it('FILE_PERM is 0o600 + DIR_PERM is 0o700 (POSIX defaults)', () => { - expect(FILE_PERM).toBe(0o600) - expect(DIR_PERM).toBe(0o700) - }) - - it('DIFY_CONFIG_DIR override wins on every platform', () => { - for (const platform of ['linux', 'darwin', 'win32'] as const) { - expect(resolveConfigDir(fakeEnv({ override: '/tmp/x', platform }))) - .toBe('/tmp/x') - } - }) - - it('linux uses XDG_CONFIG_HOME when set', () => { - expect(resolveConfigDir(fakeEnv({ xdg: '/x', platform: 'linux' }))) - .toBe('/x/difyctl') - }) - - it('linux falls back to ~/.config when XDG unset', () => { - expect(resolveConfigDir(fakeEnv({ home: '/h', platform: 'linux' }))) - .toBe('/h/.config/difyctl') - }) - - it('linux ignores empty XDG_CONFIG_HOME', () => { - expect(resolveConfigDir(fakeEnv({ xdg: '', home: '/h', platform: 'linux' }))) - .toBe('/h/.config/difyctl') - }) - - it('macos uses ~/.config (not XDG, matches gh/kubectl)', () => { - expect(resolveConfigDir(fakeEnv({ xdg: '/ignored', home: '/h', platform: 'darwin' }))) - .toBe('/h/.config/difyctl') - }) - - it('windows uses APPDATA', () => { - expect(resolveConfigDir(fakeEnv({ appData: 'C:\\Users\\u\\AppData\\Roaming', platform: 'win32' }))) - .toMatch(/difyctl$/) - }) - - it('windows throws if APPDATA unresolvable', () => { - expect(() => resolveConfigDir(fakeEnv({ platform: 'win32' }))).toThrow(/APPDATA/) - }) - - it('unknown platform falls back to ~/.config', () => { - expect(resolveConfigDir(fakeEnv({ home: '/h', platform: 'freebsd' as NodeJS.Platform }))) - .toBe('/h/.config/difyctl') - }) -}) diff --git a/cli/src/config/dir.ts b/cli/src/config/dir.ts deleted file mode 100644 index 6d92953769..0000000000 --- a/cli/src/config/dir.ts +++ /dev/null @@ -1,45 +0,0 @@ -import { homedir } from 'node:os' -import { join } from 'node:path' - -export const ENV_CONFIG_DIR = 'DIFY_CONFIG_DIR' -export const ENV_XDG_CONFIG_HOME = 
'XDG_CONFIG_HOME' -export const SUBDIR = 'difyctl' -export const FILE_PERM = 0o600 -export const DIR_PERM = 0o700 - -export type ConfigEnvironment = { - readonly getEnv: (name: string) => string | undefined - readonly homeDir: () => string - readonly platform: () => NodeJS.Platform - readonly appData: () => string | undefined -} - -export const realEnvironment: ConfigEnvironment = { - getEnv: name => process.env[name], - homeDir: () => homedir(), - platform: () => process.platform, - appData: () => process.env.APPDATA ?? process.env.LOCALAPPDATA, -} - -export function resolveConfigDir(env: ConfigEnvironment = realEnvironment): string { - const override = env.getEnv(ENV_CONFIG_DIR) - if (override !== undefined && override !== '') - return override - - const platform = env.platform() - if (platform === 'linux') { - const xdg = env.getEnv(ENV_XDG_CONFIG_HOME) - if (xdg !== undefined && xdg !== '') - return join(xdg, SUBDIR) - return join(env.homeDir(), '.config', SUBDIR) - } - if (platform === 'darwin') - return join(env.homeDir(), '.config', SUBDIR) - if (platform === 'win32') { - const appData = env.appData() - if (appData === undefined || appData === '') - throw new Error('cannot resolve %APPDATA% on Windows') - return join(appData, SUBDIR) - } - return join(env.homeDir(), '.config', SUBDIR) -} diff --git a/cli/src/config/loader.ts b/cli/src/config/loader.ts deleted file mode 100644 index 8ff00b3631..0000000000 --- a/cli/src/config/loader.ts +++ /dev/null @@ -1,58 +0,0 @@ -import type { ConfigFile } from './schema.js' -import { readFile } from 'node:fs/promises' -import { join } from 'node:path' -import { load as parseYaml } from 'js-yaml' -import { newError } from '../errors/base.js' -import { ErrorCode } from '../errors/codes.js' -import { - - ConfigFileSchema, - CURRENT_SCHEMA_VERSION, - FILE_NAME, -} from './schema.js' - -export type LoadResult - = | { found: false } - | { found: true, config: ConfigFile } - -export async function loadConfig(dir: string): 
Promise { - const path = join(dir, FILE_NAME) - let raw: string - try { - raw = await readFile(path, 'utf8') - } - catch (err) { - if ((err as NodeJS.ErrnoException).code === 'ENOENT') - return { found: false } - throw newError(ErrorCode.Unknown, `read ${path}: ${(err as Error).message}`) - .wrap(err) - } - - let parsed: unknown - try { - parsed = parseYaml(raw) - } - catch (err) { - throw newError( - ErrorCode.ConfigSchemaUnsupported, - `parse ${path}: ${(err as Error).message}`, - ).wrap(err).withHint('config.yml is not valid YAML') - } - - const result = ConfigFileSchema.safeParse(parsed ?? {}) - if (!result.success) { - throw newError( - ErrorCode.ConfigSchemaUnsupported, - `validate ${path}: ${result.error.issues.map(i => i.message).join('; ')}`, - ).withHint('config.yml does not match the v1 schema') - } - - if (result.data.schema_version > CURRENT_SCHEMA_VERSION) { - throw newError( - ErrorCode.ConfigSchemaUnsupported, - `config.yml schema_version=${result.data.schema_version} is newer than this binary supports (max=${CURRENT_SCHEMA_VERSION})`, - ).withHint('upgrade difyctl, or remove config.yml') - } - - return { found: true, config: result.data } -} diff --git a/cli/src/config/writer.ts b/cli/src/config/writer.ts deleted file mode 100644 index 8362ebf884..0000000000 --- a/cli/src/config/writer.ts +++ /dev/null @@ -1,39 +0,0 @@ -import type { ConfigFile } from './schema.js' -import { mkdir, rename, unlink, writeFile } from 'node:fs/promises' -import { join } from 'node:path' -import { dump as dumpYaml } from 'js-yaml' -import { newError } from '../errors/base.js' -import { ErrorCode } from '../errors/codes.js' -import { DIR_PERM, FILE_PERM } from './dir.js' -import { - - CURRENT_SCHEMA_VERSION, - FILE_NAME, -} from './schema.js' - -export async function saveConfig(dir: string, config: ConfigFile): Promise { - await mkdir(dir, { recursive: true, mode: DIR_PERM }) - - const stamped: ConfigFile = { ...config, schema_version: CURRENT_SCHEMA_VERSION } - const 
yaml = dumpYaml(stamped, { lineWidth: -1, noRefs: true }) - - const target = join(dir, FILE_NAME) - const tmp = `${target}.tmp.${process.pid}.${Date.now()}` - - try { - await writeFile(tmp, yaml, { mode: FILE_PERM }) - await rename(tmp, target) - } - catch (err) { - try { - await unlink(tmp) - } - catch { - // tmp may not exist if writeFile failed before creating it - } - throw newError( - ErrorCode.Unknown, - `save ${target}: ${(err as Error).message}`, - ).wrap(err) - } -} diff --git a/cli/src/env/registry.ts b/cli/src/env/registry.ts index 5a7938e01b..7e257364ee 100644 --- a/cli/src/env/registry.ts +++ b/cli/src/env/registry.ts @@ -1,4 +1,5 @@ import { parseLimit } from '../limit/limit.js' +import { getEnv } from '../sys/index.js' export type EnvVar = { readonly name: string @@ -55,9 +56,7 @@ export function lookupEnv(name: string): EnvVar | undefined { return BY_NAME.get(name) } -export function getEnv(name: string): string | undefined { - return process.env[name] -} +export { getEnv } export function resolveEnv(name: string): unknown { const entry = lookupEnv(name) diff --git a/cli/src/errors/format.ts b/cli/src/errors/format.ts index a65b466f56..4b80f08900 100644 --- a/cli/src/errors/format.ts +++ b/cli/src/errors/format.ts @@ -1,5 +1,5 @@ import type { BaseError } from './base.js' -import { colorEnabled, colorScheme } from '../io/color.js' +import { colorEnabled, colorScheme } from '../sys/io/color.js' import { renderEnvelope } from './envelope.js' export type FormatErrorOptions = { diff --git a/cli/src/config/writer.test.ts b/cli/src/store/config-writer.test.ts similarity index 67% rename from cli/src/config/writer.test.ts rename to cli/src/store/config-writer.test.ts index 0fb08f70de..1463699193 100644 --- a/cli/src/config/writer.test.ts +++ b/cli/src/store/config-writer.test.ts @@ -2,9 +2,15 @@ import { mkdtemp, readdir, readFile, stat } from 'node:fs/promises' import { tmpdir } from 'node:os' import { join } from 'node:path' import { beforeEach, 
describe, expect, it } from 'vitest' -import { loadConfig } from './loader.js' -import { emptyConfig, FILE_NAME } from './schema.js' -import { saveConfig } from './writer.js' +import { loadConfig } from '../config/config-loader' +import { emptyConfig, FILE_NAME } from '../config/schema' +import { platform } from '../sys' +import { saveConfig } from './config-writer' +import { YamlStore } from './store' + +function makeStore(dir: string): YamlStore { + return new YamlStore(join(dir, FILE_NAME)) +} describe('saveConfig', () => { let dir: string @@ -14,26 +20,26 @@ describe('saveConfig', () => { }) it('writes config.yml in the target dir', async () => { - await saveConfig(dir, { ...emptyConfig(), schema_version: 1 }) + saveConfig(makeStore(dir), { ...emptyConfig(), schema_version: 1 }) const stats = await stat(join(dir, FILE_NAME)) expect(stats.isFile()).toBe(true) }) - it('stamps schema_version=1 even if caller passed 0', async () => { - await saveConfig(dir, { ...emptyConfig() }) - const r = await loadConfig(dir) + it('stamps schema_version=1 even if caller passed 0', () => { + saveConfig(makeStore(dir), { ...emptyConfig() }) + const r = loadConfig(makeStore(dir)) expect(r.found).toBe(true) if (r.found) expect(r.config.schema_version).toBe(1) }) - it('round-trips defaults + state through YAML', async () => { - await saveConfig(dir, { + it('round-trips defaults + state through YAML', () => { + saveConfig(makeStore(dir), { schema_version: 1, defaults: { format: 'wide', limit: 75 }, state: { current_app: 'app-xyz' }, }) - const r = await loadConfig(dir) + const r = loadConfig(makeStore(dir)) expect(r.found).toBe(true) if (r.found) { expect(r.config.defaults.format).toBe('wide') @@ -43,32 +49,32 @@ describe('saveConfig', () => { }) it('writes file with mode 0o600 (POSIX)', async () => { - if (process.platform === 'win32') + if (platform() === 'win32') return - await saveConfig(dir, emptyConfig()) + saveConfig(makeStore(dir), emptyConfig()) const s = await stat(join(dir, 
FILE_NAME)) expect(s.mode & 0o777).toBe(0o600) }) it('does not leave a tmp file on success', async () => { - await saveConfig(dir, emptyConfig()) + saveConfig(makeStore(dir), emptyConfig()) const entries = await readdir(dir) expect(entries.filter(f => f.endsWith('.tmp'))).toHaveLength(0) expect(entries.filter(f => f.includes('.tmp.'))).toHaveLength(0) }) it('creates parent dir at 0o700 if absent', async () => { - if (process.platform === 'win32') + if (platform() === 'win32') return const nested = join(dir, 'nested', 'sub') - await saveConfig(nested, emptyConfig()) + saveConfig(makeStore(nested), emptyConfig()) const s = await stat(nested) expect(s.isDirectory()).toBe(true) expect(s.mode & 0o777).toBe(0o700) }) it('emits parseable YAML (round-trip via fs.readFile + js-yaml)', async () => { - await saveConfig(dir, { + saveConfig(makeStore(dir), { schema_version: 1, defaults: { format: 'json' }, state: {}, diff --git a/cli/src/store/config-writer.ts b/cli/src/store/config-writer.ts new file mode 100644 index 0000000000..79b8a23d65 --- /dev/null +++ b/cli/src/store/config-writer.ts @@ -0,0 +1,8 @@ +import type { ConfigFile } from '../config/schema' +import type { YamlStore } from './store' +import { CURRENT_SCHEMA_VERSION } from '../config/schema' + +export function saveConfig(store: YamlStore, config: ConfigFile): void { + const stamped: ConfigFile = { ...config, schema_version: CURRENT_SCHEMA_VERSION } + store.setTyped(stamped) +} diff --git a/cli/src/store/dir.ts b/cli/src/store/dir.ts new file mode 100644 index 0000000000..c75e1dbdfd --- /dev/null +++ b/cli/src/store/dir.ts @@ -0,0 +1,20 @@ +import { getEnv, resolvePlatform } from '../sys' + +export const ENV_CONFIG_DIR = 'DIFY_CONFIG_DIR' +export const ENV_CACHE_DIR = 'DIFY_CACHE_DIR' +export const FILE_PERM = 0o600 +export const DIR_PERM = 0o700 + +export function resolveCacheDir(): string { + const override = getEnv(ENV_CACHE_DIR) + if (override !== undefined && override !== '') + return override + return 
resolvePlatform().cacheDir() +} + +export function resolveConfigDir(): string { + const override = getEnv(ENV_CONFIG_DIR) + if (override !== undefined && override !== '') + return override + return resolvePlatform().configDir() +} diff --git a/cli/src/store/manager.ts b/cli/src/store/manager.ts new file mode 100644 index 0000000000..76e116b917 --- /dev/null +++ b/cli/src/store/manager.ts @@ -0,0 +1,28 @@ +import type { Store } from './store' +import { join } from 'node:path' +import { FILE_NAME } from '../config/schema' +import { resolveCacheDir, resolveConfigDir } from './dir' +import { YamlStore } from './store' + +export const CACHE_APP_INFO = 'app-info' +export const CACHE_NUDGE = 'nudge' + +function getStore(filePath: string): YamlStore { + return new YamlStore(filePath) +} + +function resolveConfigurationPath(): string { + return join(resolveConfigDir(), FILE_NAME) +} + +export function cachePath(cacheDir: string, name: string): string { + return join(cacheDir, `${name}.yml`) +} + +export function getConfigurationStore(): YamlStore { + return getStore(resolveConfigurationPath()) +} + +export function getCache(cacheName: string): Store { + return getStore(cachePath(resolveCacheDir(), cacheName)) +} diff --git a/cli/src/store/store.test.ts b/cli/src/store/store.test.ts new file mode 100644 index 0000000000..3f0c3de1e7 --- /dev/null +++ b/cli/src/store/store.test.ts @@ -0,0 +1,193 @@ +import { readFileSync, writeFileSync } from 'node:fs' +import { mkdtemp, rm, writeFile } from 'node:fs/promises' +import { tmpdir } from 'node:os' +import { join } from 'node:path' +import { afterEach, beforeEach, describe, expect, it } from 'vitest' +import { ConcurrentAccessError, YamlStore } from './store' + +describe('YamlStore.doGet', () => { + it('returns default when content is undefined', () => { + const store = new YamlStore('/irrelevant') + expect(store.doGet({ key: 'name', default: 'fallback' })).toBe('fallback') + }) + + it('reads a flat key', () => { + const store = 
new YamlStore('/irrelevant') + store.raw_content = 'name: alice\n' + expect(store.doGet({ key: 'name', default: '' })).toBe('alice') + }) + + it('reads a nested key via dot notation', () => { + const store = new YamlStore('/irrelevant') + store.raw_content = 'user:\n id: 42\n' + expect(store.doGet({ key: 'user.id', default: 0 })).toBe(42) + }) + + it('returns default for a missing flat key', () => { + const store = new YamlStore('/irrelevant') + store.raw_content = 'name: alice\n' + expect(store.doGet({ key: 'age', default: -1 })).toBe(-1) + }) + + it('returns default when an intermediate path segment is absent', () => { + const store = new YamlStore('/irrelevant') + store.raw_content = 'user:\n name: bob\n' + expect(store.doGet({ key: 'user.address.city', default: 'unknown' })).toBe('unknown') + }) + + it('returns default when an intermediate path segment is a scalar', () => { + const store = new YamlStore('/irrelevant') + store.raw_content = 'user: scalar\n' + expect(store.doGet({ key: 'user.id', default: 0 })).toBe(0) + }) +}) + +describe('YamlStore.doSet', () => { + it('sets a flat key from empty content', () => { + const store = new YamlStore('/irrelevant') + store.doSet({ key: 'name', default: '' }, 'alice') + expect(store.doGet({ key: 'name', default: '' })).toBe('alice') + }) + + it('sets a nested key, creating intermediate objects', () => { + const store = new YamlStore('/irrelevant') + store.doSet({ key: 'user.id', default: 0 }, 42) + expect(store.doGet({ key: 'user.id', default: 0 })).toBe(42) + }) + + it('overwrites an existing key without disturbing siblings', () => { + const store = new YamlStore('/irrelevant') + store.raw_content = 'name: alice\nage: 30\n' + store.doSet({ key: 'name', default: '' }, 'bob') + expect(store.doGet({ key: 'name', default: '' })).toBe('bob') + expect(store.doGet({ key: 'age', default: 0 })).toBe(30) + }) + + it('replaces a scalar intermediate with an object when path deepens', () => { + const store = new 
YamlStore('/irrelevant') + store.raw_content = 'user: scalar\n' + store.doSet({ key: 'user.id', default: 0 }, 99) + expect(store.doGet({ key: 'user.id', default: 0 })).toBe(99) + }) +}) + +describe('FileBasedStore.withLock concurrency', () => { + let dir: string + + beforeEach(async () => { + dir = await mkdtemp(join(tmpdir(), 'difyctl-yaml-store-')) + }) + + afterEach(async () => { + await rm(dir, { recursive: true, force: true }) + }) + + it('second get throws while first holds the lock, succeeds after release', async () => { + const path = join(dir, 'config.yml') + await writeFile(path, 'key: value\n') + + const s1 = new YamlStore(path) + const s2 = new YamlStore(path) + + s1.lock() + + expect(() => s2.get({ key: 'key', default: '' })).toThrow(ConcurrentAccessError) + + s1.unlock() + + expect(s2.get({ key: 'key', default: '' })).toBe('value') + }) + + it('second set throws while first holds the lock, succeeds after release', async () => { + const path = join(dir, 'config.yml') + await writeFile(path, 'key: original\n') + + const s1 = new YamlStore(path) + const s2 = new YamlStore(path) + + s1.lock() + + expect(() => s2.set({ key: 'key', default: '' }, 'blocked')).toThrow(ConcurrentAccessError) + + s1.unlock() + + s2.set({ key: 'key', default: '' }, 'written') + expect(s2.get({ key: 'key', default: '' })).toBe('written') + }) +}) + +describe('YamlStore persistence', () => { + let dir: string + + beforeEach(async () => { + dir = await mkdtemp(join(tmpdir(), 'difyctl-yaml-store-')) + }) + + afterEach(async () => { + await rm(dir, { recursive: true, force: true }) + }) + + it('round-trips a flat value through disk', async () => { + const path = join(dir, 'config.yml') + await writeFile(path, '') + + const s1 = new YamlStore(path) + s1.raw_content = '' + s1.doSet({ key: 'workspace', default: '' }, 'ws-123') + writeFileSync(path, s1.raw_content ?? 
'') + + const s2 = new YamlStore(path) + s2.raw_content = readFileSync(path, 'utf8') + expect(s2.doGet({ key: 'workspace', default: '' })).toBe('ws-123') + }) + + it('round-trips a deep nested value through disk', async () => { + const path = join(dir, 'config.yml') + await writeFile(path, '') + + const s1 = new YamlStore(path) + s1.raw_content = '' + s1.doSet({ key: 'a.b.c', default: '' }, 'deep') + writeFileSync(path, s1.raw_content ?? '') + + const s2 = new YamlStore(path) + s2.raw_content = readFileSync(path, 'utf8') + expect(s2.doGet({ key: 'a.b.c', default: '' })).toBe('deep') + }) + + it('second doSet on a reloaded store does not clobber the first key', async () => { + const path = join(dir, 'config.yml') + await writeFile(path, '') + + const s1 = new YamlStore(path) + s1.raw_content = '' + s1.doSet({ key: 'x', default: '' }, 'first') + writeFileSync(path, s1.raw_content ?? '') + + const s2 = new YamlStore(path) + s2.raw_content = readFileSync(path, 'utf8') + s2.doSet({ key: 'y', default: '' }, 'second') + writeFileSync(path, s2.raw_content ?? 
'') + + const s3 = new YamlStore(path) + s3.raw_content = readFileSync(path, 'utf8') + expect(s3.doGet({ key: 'x', default: '' })).toBe('first') + expect(s3.doGet({ key: 'y', default: '' })).toBe('second') + }) + + it('load → doSet → flush writes the value to disk', async () => { + const path = join(dir, 'config.yml') + await writeFile(path, 'existing: value\n') + + const store = new YamlStore(path) + store.load() + store.doSet({ key: 'token', default: '' }, 'abc-123') + store.flush() + + const raw = readFileSync(path, 'utf8') + const store2 = new YamlStore(path) + store2.raw_content = raw + expect(store2.doGet({ key: 'token', default: '' })).toBe('abc-123') + expect(store2.doGet({ key: 'existing', default: '' })).toBe('value') + }) +}) diff --git a/cli/src/store/store.ts b/cli/src/store/store.ts new file mode 100644 index 0000000000..f1cb8a2302 --- /dev/null +++ b/cli/src/store/store.ts @@ -0,0 +1,165 @@ +import type { Platform } from '../sys' +import fs from 'node:fs' +import { dirname } from 'node:path' +import yaml from 'js-yaml' +import lockfile from 'lockfile' +import { pid, resolvePlatform } from '../sys' + +const FILE_PERM = 0o600 +const DIR_PERM = 0o700 + +type Key = { + default: T + key: string +} + +export type Store = { + get: (key: Key) => T + set: (key: Key, value: T) => void +} + +export class ConcurrentAccessError extends Error { + constructor(filePath: string) { + super(`Another process is modifying the file ${filePath}. 
remove ${filePath}.lock to reset lock.`) + } +} + +abstract class FileBasedStore implements Store { + file_path: string + raw_content: string | undefined + private readonly platform: Platform + + constructor(file_path: string) { + this.file_path = file_path + this.platform = resolvePlatform() + fs.mkdirSync(dirname(this.file_path), { recursive: true, mode: DIR_PERM }) + } + + unlock(): void { + lockfile.unlockSync(`${this.file_path}.lock`) + } + + /** + * atomically write raw_content (if any) + */ + flush(): void { + if (this.raw_content !== undefined) { + const tmp = `${this.file_path}.tmp.${pid()}.${Date.now()}` + try { + fs.writeFileSync(tmp, this.raw_content, { mode: FILE_PERM }) + this.platform.atomicReplace(tmp, this.file_path) + } + catch (err) { + try { + fs.unlinkSync(tmp) + } + catch { /* tmp may not exist */ } + throw err + } + } + } + + lock(): void { + try { + lockfile.lockSync(`${this.file_path}.lock`) + } + catch (err) { + const code = (err as NodeJS.ErrnoException).code + if (code === 'EEXIST') { + throw new ConcurrentAccessError(this.file_path) + } + throw err + } + } + + load(): void { + try { + this.raw_content = fs.readFileSync(this.file_path, 'utf8') + } + catch (err) { + const code = (err as NodeJS.ErrnoException).code + if (code !== 'ENOENT') { + throw err + } + } + } + + protected withLock(body: () => R): R { + this.lock() + try { + this.load() + return body() + } + finally { + this.unlock() + } + } + + get(key: Key): T { + return this.withLock(() => this.doGet(key)) + } + + set(key: Key, value: T) { + this.withLock(() => { + this.doSet(key, value) + this.flush() + }) + } + + abstract doGet(key: Key): T + abstract doSet(key: Key, value: T): void +} + +export class YamlStore extends FileBasedStore { + constructor(file_path: string) { + super(file_path) + } + + doGet(key: Key): T { + const data = loadYaml(this.raw_content) + const parts = key.key.split('.') + let current: unknown = data + for (const part of parts) { + if (current === null || 
current === undefined || typeof current !== 'object') + return key.default + current = (current as Record)[part] + } + return (current as T) ?? key.default + } + + getTyped(): T | null { + return this.withLock(() => { + this.load() + return loadYaml(this.raw_content) as T + }) + } + + setTyped(data: T): void { + this.withLock(() => { + this.raw_content = yaml.dump(data, { lineWidth: -1, noRefs: true }) + this.flush() + }) + } + + doSet(key: Key, value: T): void { + const data = loadYaml(this.raw_content) || {} + const parts = key.key.split('.') + const lastKey = parts.pop() + if (lastKey === undefined) + return + let current: Record = data + for (const part of parts) { + if (current[part] === null || current[part] === undefined || typeof current[part] !== 'object') + current[part] = {} + current = current[part] as Record + } + current[lastKey] = value + this.raw_content = yaml.dump(data, { lineWidth: -1, noRefs: true }) + } +} + +function loadYaml(raw: string | undefined): Record | null { + if (raw === undefined) + return null + return (yaml.load(raw) ?? 
{}) as Record +} diff --git a/cli/src/sys/index.test.ts b/cli/src/sys/index.test.ts new file mode 100644 index 0000000000..6aeaff19cf --- /dev/null +++ b/cli/src/sys/index.test.ts @@ -0,0 +1,37 @@ +import { homedir } from 'node:os' +import { join } from 'node:path' +import { describe, expect, it } from 'vitest' +import { resolvePlatform, SUBDIR } from './index.js' + +describe('resolvePlatform', () => { + it('id matches process.platform', () => { + expect(resolvePlatform().id()).toBe(process.platform) + }) + + it('configDir ends with the difyctl subdir', () => { + const p = resolvePlatform() + if (p.id() === 'win32') { + expect(p.configDir()).toMatch(/difyctl$/) + } + else { + expect(p.configDir()).toBe(join(homedir(), '.config', SUBDIR)) + } + }) + + it('cacheDir ends with the difyctl subdir', () => { + const p = resolvePlatform() + if (p.id() === 'win32') { + expect(p.cacheDir()).toMatch(/difyctl$/) + } + else if (p.id() === 'darwin') { + expect(p.cacheDir()).toBe(join(homedir(), 'Library', 'Caches', SUBDIR)) + } + else { + expect(p.cacheDir()).toBe(join(homedir(), '.cache', SUBDIR)) + } + }) + + it('atomicReplace is a function', () => { + expect(resolvePlatform().atomicReplace).toBeTypeOf('function') + }) +}) diff --git a/cli/src/sys/index.ts b/cli/src/sys/index.ts new file mode 100644 index 0000000000..3faeaf15d9 --- /dev/null +++ b/cli/src/sys/index.ts @@ -0,0 +1,122 @@ +import fs from 'node:fs' +import { homedir } from 'node:os' +import { join } from 'node:path' + +export function getEnv(name: string): string | undefined { + return process.env[name] +} + +export function env(): NodeJS.ProcessEnv { + return process.env +} + +export function processExit(code: number): never { + return process.exit(code) as never +} + +export function io() { + return { + out: process.stdout, + err: process.stderr, + in: process.stdin, + isOutTTY: Boolean(process.stdout.isTTY), + isErrTTY: Boolean(process.stderr.isTTY), + } +} + +export function handle(sig: string, handler: () => 
void) { + process.once(sig, handler) +} + +export function unhandle(sig: string, handler: () => void) { + process.off(sig, handler) +} + +export function platform(): NodeJS.Platform { + return process.platform +} + +export function arch(): string { + return process.arch +} + +export function pid(): number { + return Number(process.pid) +} + +export type Platform = { + id: () => NodeJS.Platform + configDir: () => string + cacheDir: () => string + atomicReplace: (src: string, dst: string) => void +} + +export const SUBDIR = 'difyctl' +export const ENV_XDG_CONFIG_HOME = 'XDG_CONFIG_HOME' +export const ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' + +function appDataDir(): string | undefined { + return getEnv('APPDATA') ?? getEnv('LOCALAPPDATA') +} + +type PlatformFactory = () => Platform + +function posixAtomicReplace(src: string, dst: string): void { + fs.renameSync(src, dst) +} + +function win32AtomicReplace(src: string, dst: string): void { + try { + fs.unlinkSync(dst) + } + catch { } + fs.renameSync(src, dst) +} + +const platformImpls: Partial> = { + linux: () => ({ + id: () => 'linux', + configDir: () => { + const xdg = getEnv(ENV_XDG_CONFIG_HOME) + return (xdg !== undefined && xdg !== '') ? join(xdg, SUBDIR) : join(homedir(), '.config', SUBDIR) + }, + cacheDir: () => { + const xdg = getEnv(ENV_XDG_CACHE_HOME) + return (xdg !== undefined && xdg !== '') ? 
join(xdg, SUBDIR) : join(homedir(), '.cache', SUBDIR) + }, + atomicReplace: posixAtomicReplace, + }), + darwin: () => ({ + id: () => 'darwin', + configDir: () => join(homedir(), '.config', SUBDIR), + cacheDir: () => join(homedir(), 'Library', 'Caches', SUBDIR), + atomicReplace: posixAtomicReplace, + }), + win32: () => ({ + id: () => 'win32', + configDir: () => { + const appData = appDataDir() + if (appData === undefined || appData === '') + throw new Error('cannot resolve %APPDATA% on Windows') + return join(appData, SUBDIR) + }, + cacheDir: () => { + const appData = appDataDir() + if (appData === undefined || appData === '') + throw new Error('cannot resolve %LOCALAPPDATA% on Windows') + return join(appData, SUBDIR) + }, + atomicReplace: win32AtomicReplace, + }), +} + +const defaultPlatformFactory: PlatformFactory = () => ({ + id: () => platform(), + configDir: () => join(homedir(), '.config', SUBDIR), + cacheDir: () => join(homedir(), '.cache', SUBDIR), + atomicReplace: posixAtomicReplace, +}) + +export function resolvePlatform(): Platform { + return (platformImpls[platform()] ?? defaultPlatformFactory)() +} diff --git a/cli/src/io/color.ts b/cli/src/sys/io/color.ts similarity index 100% rename from cli/src/io/color.ts rename to cli/src/sys/io/color.ts diff --git a/cli/src/io/spinner.ts b/cli/src/sys/io/spinner.ts similarity index 100% rename from cli/src/io/spinner.ts rename to cli/src/sys/io/spinner.ts diff --git a/cli/src/io/streams.ts b/cli/src/sys/io/streams.ts similarity index 89% rename from cli/src/io/streams.ts rename to cli/src/sys/io/streams.ts index a51f630f62..62f215f31c 100644 --- a/cli/src/io/streams.ts +++ b/cli/src/sys/io/streams.ts @@ -1,5 +1,6 @@ import { Buffer } from 'node:buffer' import { PassThrough, Readable, Writable } from 'node:stream' +import { io } from '..' 
export type IOStreams = { out: NodeJS.WritableStream @@ -16,12 +17,8 @@ export function nullStreams(): IOStreams { export function realStreams(outputFormat = ''): IOStreams { return { - out: process.stdout, - err: process.stderr, - in: process.stdin, - isOutTTY: Boolean(process.stdout.isTTY), - isErrTTY: Boolean(process.stderr.isTTY), outputFormat, + ...io(), } } diff --git a/cli/src/io/think-filter.test.ts b/cli/src/sys/io/think-filter.test.ts similarity index 100% rename from cli/src/io/think-filter.test.ts rename to cli/src/sys/io/think-filter.test.ts diff --git a/cli/src/io/think-filter.ts b/cli/src/sys/io/think-filter.ts similarity index 100% rename from cli/src/io/think-filter.ts rename to cli/src/sys/io/think-filter.ts diff --git a/cli/src/util/browser.ts b/cli/src/util/browser.ts index 3a272cc77a..5ec813dd77 100644 --- a/cli/src/util/browser.ts +++ b/cli/src/util/browser.ts @@ -1,4 +1,5 @@ import openModule from 'open' +import { platform } from '../sys' export const OpenDecision = { Auto: 'auto-open', @@ -19,7 +20,7 @@ export type BrowserEnv = { export function realEnv(): BrowserEnv { return { getEnv: k => process.env[k], - platform: process.platform, + platform: platform(), isOutTTY: Boolean(process.stdout.isTTY), isErrTTY: Boolean(process.stderr.isTTY), } diff --git a/cli/src/version/info.ts b/cli/src/version/info.ts index 5f4b6245e9..c40484bbc3 100644 --- a/cli/src/version/info.ts +++ b/cli/src/version/info.ts @@ -1,3 +1,4 @@ +import { arch, platform } from '../sys/index.js' import { compatString } from './compat.js' export type Channel = 'dev' | 'rc' | 'stable' @@ -27,5 +28,5 @@ export function longVersion(): string { } export function userAgent(): string { - return `difyctl/${versionInfo.version} (${process.platform}; ${process.arch}; ${versionInfo.channel})` + return `difyctl/${versionInfo.version} (${platform()}; ${arch()}; ${versionInfo.channel})` } diff --git a/cli/src/version/nudge.test.ts b/cli/src/version/nudge.test.ts index 
2eeaa32050..a581d40700 100644 --- a/cli/src/version/nudge.test.ts +++ b/cli/src/version/nudge.test.ts @@ -5,6 +5,8 @@ import { tmpdir } from 'node:os' import { join } from 'node:path' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { loadNudgeStore } from '../cache/nudge-store.js' +import { CACHE_NUDGE, cachePath } from '../store/manager.js' +import { YamlStore } from '../store/store.js' import { maybeNudgeCompat } from './nudge.js' const HOST = 'https://cloud.dify.ai' @@ -44,7 +46,7 @@ describe('maybeNudgeCompat', () => { beforeEach(async () => { dir = await mkdtemp(join(tmpdir(), 'difyctl-nudge-')) - store = await loadNudgeStore({ configDir: dir, now: fixedNow }) + store = await loadNudgeStore({ store: new YamlStore(cachePath(dir, CACHE_NUDGE)), now: fixedNow }) }) afterEach(async () => { await rm(dir, { recursive: true, force: true }) @@ -76,12 +78,12 @@ describe('maybeNudgeCompat', () => { it('warns again after the silence window has elapsed', async () => { const yesterday = new Date(NOW.getTime() - 25 * 60 * 60 * 1000) - const tStore = await loadNudgeStore({ configDir: dir, now: () => yesterday }) + const tStore = await loadNudgeStore({ store: new YamlStore(cachePath(dir, CACHE_NUDGE)), now: () => yesterday }) await tStore.markWarned(HOST) const probe = vi.fn(async () => UNSUPPORTED) const { emit, lines } = emitterSpy() - const freshStore = await loadNudgeStore({ configDir: dir, now: fixedNow }) + const freshStore = await loadNudgeStore({ store: new YamlStore(cachePath(dir, CACHE_NUDGE)), now: fixedNow }) await maybeNudgeCompat(HOST, baseDeps({ store: freshStore, probe, emit })) expect(probe).toHaveBeenCalledOnce() diff --git a/cli/src/version/nudge.ts b/cli/src/version/nudge.ts index f5c9866f92..ad4c0fdde0 100644 --- a/cli/src/version/nudge.ts +++ b/cli/src/version/nudge.ts @@ -1,6 +1,6 @@ import type { ServerVersionResponse } from '@dify/contracts/api/openapi/types.gen' import type { NudgeStore } from '../cache/nudge-store.js' 
-import { colorScheme } from '../io/color.js' +import { colorScheme } from '../sys/io/color.js' import { difyCompat, evaluateCompat } from './compat.js' // Formats whose stdout is structured data (json/yaml) or a single name token — diff --git a/cli/src/version/probe.test.ts b/cli/src/version/probe.test.ts index 28e4335b12..77a20ef2e4 100644 --- a/cli/src/version/probe.test.ts +++ b/cli/src/version/probe.test.ts @@ -1,12 +1,13 @@ import type { ServerVersionResponse } from '@dify/contracts/api/openapi/types.gen' import type { HostsBundle } from '../auth/hosts.js' import { mkdtemp, rm } from 'node:fs/promises' -import { tmpdir } from 'node:os' +import { platform, tmpdir } from 'node:os' import { join } from 'node:path' import { describe, expect, it } from 'vitest' import { startMock } from '../../test/fixtures/dify-mock/server.js' import { saveHosts } from '../auth/hosts.js' -import { ENV_CONFIG_DIR } from '../config/dir.js' +import { ENV_CONFIG_DIR } from '../store/dir.js' +import { arch } from '../sys/index.js' import { runVersionProbe } from './probe.js' function bundle(overrides: Partial = {}): HostsBundle { @@ -195,7 +196,7 @@ describe('runVersionProbe', () => { expect(report.client.version).toBeTypeOf('string') expect(report.client.commit).toBeTypeOf('string') expect(report.client.channel).toMatch(/^(dev|rc|stable)$/) - expect(report.client.platform).toBe(process.platform) - expect(report.client.arch).toBe(process.arch) + expect(report.client.platform).toBe(platform()) + expect(report.client.arch).toBe(arch()) }) }) diff --git a/cli/src/version/probe.ts b/cli/src/version/probe.ts index 25af8b3d61..09fc373661 100644 --- a/cli/src/version/probe.ts +++ b/cli/src/version/probe.ts @@ -4,8 +4,9 @@ import type { CompatVerdict } from './compat.js' import type { Channel } from './info.js' import { META_PROBE_TIMEOUT_MS, MetaClient } from '../api/meta.js' import { loadHosts } from '../auth/hosts.js' -import { resolveConfigDir } from '../config/dir.js' import { 
createClient } from '../http/client.js' +import { resolveConfigDir } from '../store/dir.js' +import { arch, platform } from '../sys/index.js' import { hostWithScheme } from '../util/host.js' import { difyCompat, evaluateCompat } from './compat.js' import { versionInfo } from './info.js' @@ -60,8 +61,8 @@ function buildClientBlock(): ClientBlock { commit: versionInfo.commit, buildDate: versionInfo.buildDate, channel: versionInfo.channel, - platform: process.platform, - arch: process.arch, + platform: platform(), + arch: arch(), } } diff --git a/cli/src/version/render.ts b/cli/src/version/render.ts index 1398622df2..2777eca1d3 100644 --- a/cli/src/version/render.ts +++ b/cli/src/version/render.ts @@ -1,5 +1,5 @@ import type { VersionReport } from './probe.js' -import { colorScheme } from '../io/color.js' +import { colorScheme } from '../sys/io/color.js' const RC_WARNING_LINES = [ 'WARNING: This build is a release candidate. It is in beta test, not stable,', diff --git a/cli/src/workspace/resolver.ts b/cli/src/workspace/resolver.ts index 225e65f666..1be313cd63 100644 --- a/cli/src/workspace/resolver.ts +++ b/cli/src/workspace/resolver.ts @@ -25,7 +25,7 @@ export function resolveWorkspaceId(inputs: WorkspaceResolveInputs): string { throw new BaseError({ code: ErrorCode.UsageMissingArg, message: 'no workspace selected', - hint: 'pass --workspace, set DIFY_WORKSPACE_ID, or run \'difyctl auth use\'', + hint: 'pass --workspace, set DIFY_WORKSPACE_ID, or run \'difyctl use workspace \'', }) } diff --git a/cli/tsconfig.json b/cli/tsconfig.json index dc04c33f30..41b24e8690 100644 --- a/cli/tsconfig.json +++ b/cli/tsconfig.json @@ -2,6 +2,14 @@ "extends": "@dify/tsconfig/node.json", "compilerOptions": { "rootDir": "src", + "paths": { + "@/*": [ + "./*" + ], + "~@/*": [ + "./*" + ] + }, "types": ["node"], "declaration": true, "declarationMap": true, @@ -10,5 +18,5 @@ "sourceMap": true }, "include": ["src/**/*.ts"], - "exclude": ["dist", "test", "node_modules", "**/*.test.ts"] + 
"exclude": ["node_modules", "**/*.test.ts"] } diff --git a/dify-agent/docs/agenton/index.md b/dify-agent/docs/agenton/index.md index f96db54256..2af61df226 100644 --- a/dify-agent/docs/agenton/index.md +++ b/dify-agent/docs/agenton/index.md @@ -2,5 +2,3 @@ - [User guide](guide/index.md) explains how to compose layers, register config-backed plugins, use system/user prompts, and snapshot sessions. -- [API reference](api/index.md) lists the public Agenton classes, methods, and extension - points. diff --git a/dify-agent/docs/dify-agent/get-started/index.md b/dify-agent/docs/dify-agent/get-started/index.md index 517552e903..ff755aa183 100644 --- a/dify-agent/docs/dify-agent/get-started/index.md +++ b/dify-agent/docs/dify-agent/get-started/index.md @@ -111,11 +111,10 @@ import sys from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig from dify_agent.client import Client +from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig from dify_agent.layers.dify_plugin import ( - DIFY_PLUGIN_LAYER_TYPE_ID, DIFY_PLUGIN_LLM_LAYER_TYPE_ID, DifyPluginLLMLayerConfig, - DifyPluginLayerConfig, ) from dify_agent.protocol import DIFY_AGENT_MODEL_LAYER_ID, CreateRunRequest, RunComposition, RunLayerSpec @@ -147,19 +146,20 @@ def build_request() -> CreateRunRequest: config=PromptLayerConfig(prefix=SYSTEM_PROMPT, user=USER_PROMPT), ), RunLayerSpec( - name="plugin", - type=DIFY_PLUGIN_LAYER_TYPE_ID, - config=DifyPluginLayerConfig( + name="execution_context", + type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + config=DifyExecutionContextLayerConfig( tenant_id=TENANT_ID, - plugin_id=PLUGIN_ID, user_id=USER_ID, + invoke_from="workflow_run", ), ), RunLayerSpec( name=DIFY_AGENT_MODEL_LAYER_ID, type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID, - deps={"plugin": "plugin"}, + deps={"execution_context": "execution_context"}, config=DifyPluginLLMLayerConfig( + plugin_id=PLUGIN_ID, model_provider=MODEL_PROVIDER, 
model=MODEL_NAME, credentials=MODEL_CREDENTIALS, diff --git a/dify-agent/docs/dify-agent/guide/index.md b/dify-agent/docs/dify-agent/guide/index.md index 012bd3a598..e191caa613 100644 --- a/dify-agent/docs/dify-agent/guide/index.md +++ b/dify-agent/docs/dify-agent/guide/index.md @@ -61,9 +61,11 @@ record TTL so active runs that keep producing events remain observable. ## Scheduling and shutdown semantics -`POST /runs` validates the composition, persists a `running` run record, and starts -an `asyncio` task in the same process. There is no Redis job stream, consumer -group, pending reclaim, or automatic retry layer. +`POST /runs` persists a `running` run record and starts an `asyncio` task in the +same process. There is no Redis job stream, consumer group, pending reclaim, or +automatic retry layer. Request-shaped runtime failures such as bad composition, +prompt, output, or snapshot inputs are reported later as failed runs rather than +rejected synchronously once the request DTO itself is accepted. During FastAPI shutdown the scheduler rejects new runs, waits up to `DIFY_AGENT_SHUTDOWN_GRACE_SECONDS` for active tasks, then cancels remaining tasks diff --git a/dify-agent/docs/dify-agent/index.md b/dify-agent/docs/dify-agent/index.md index 39cedf74b9..2683fa78bd 100644 --- a/dify-agent/docs/dify-agent/index.md +++ b/dify-agent/docs/dify-agent/index.md @@ -4,5 +4,4 @@ Dify Agent hosts Agenton-composed Pydantic AI runs behind a FastAPI API. Its source code stays under `src/dify_agent`, while framework-neutral Agenton code stays under `src/agenton` and `src/agenton_collections`. -See the [operations guide](guide/index.md) for local server behavior and the -[run API](api/index.md) for request and event schemas. +See the [operations guide](guide/index.md) for local server behavior. 
diff --git a/dify-agent/docs/dify-agent/user-manual/execution-context-layer/index.md b/dify-agent/docs/dify-agent/user-manual/execution-context-layer/index.md new file mode 100644 index 0000000000..e73fd3c19b --- /dev/null +++ b/dify-agent/docs/dify-agent/user-manual/execution-context-layer/index.md @@ -0,0 +1,67 @@ +# Execution context layer + +The execution-context layer carries shared Dify run identifiers plus the tenant +and optional user context needed for plugin-daemon calls. Server settings still +provide the plugin daemon URL and API key. + +Use it together with a [plugin LLM layer](../plugin-llm-layer/index.md) and, +when the caller wants Dify tools exposed to the model, a +[plugin tool layer](../plugin-tool-layer/index.md). Both business layers depend +on this layer to reach the plugin daemon. + +## Config fields + +| Field | Type | Meaning | +| --- | --- | --- | +| `tenant_id` | `str` | Dify tenant/workspace id used when calling the plugin daemon. | +| `user_id` | `str \| None` | Optional end-user id passed through to the plugin daemon. | +| `invoke_from` | `Literal[...]` | Dify caller category recorded for observability and correlation. | +| `app_id` / `workflow_id` / `workflow_run_id` / `node_id` / `node_execution_id` / `conversation_id` / `agent_id` / `agent_config_version_id` / `trace_id` | `str \| None` | Optional Dify-owned execution identifiers forwarded with the run. | + +The execution-context layer type id is `dify.execution_context`. 
+ +## Basic usage + +```python {test="skip" lint="skip"} +from dify_agent.layers.execution_context import ( + DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + DifyExecutionContextLayerConfig, +) +from dify_agent.protocol import RunLayerSpec + + +execution_context_layer = RunLayerSpec( + name="execution_context", + type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + config=DifyExecutionContextLayerConfig( + tenant_id="replace-with-tenant-id", + user_id="replace-with-user-id", + invoke_from="workflow_run", + ), +) +``` + +If you do not need a user id, omit `user_id` or pass `None`. Most optional +execution identifiers may also be omitted when they are not available. + +## Server-side settings + +The execution-context layer config does not include daemon transport settings. +Configure these on the Dify Agent server instead: + +```env +DIFY_AGENT_PLUGIN_DAEMON_URL=http://localhost:5002 +DIFY_AGENT_PLUGIN_DAEMON_API_KEY=replace-with-plugin-daemon-server-key +``` + +This keeps server credentials out of client-submitted layer config and out of +session snapshots. + +## Notes + +- The execution-context layer does not open, cache, close, or snapshot HTTP clients. +- Concrete `plugin_id` values belong to the business layer that invokes the + daemon: the plugin LLM layer for model calls and each plugin tool config for + tool calls. +- The conventional layer name is `execution_context`. If you use another name, + point the LLM and tool layer dependencies at that name. diff --git a/dify-agent/docs/dify-agent/user-manual/plugin-layer/index.md b/dify-agent/docs/dify-agent/user-manual/plugin-layer/index.md deleted file mode 100644 index 2164da9882..0000000000 --- a/dify-agent/docs/dify-agent/user-manual/plugin-layer/index.md +++ /dev/null @@ -1,59 +0,0 @@ -# Plugin layer - -The plugin layer carries Dify plugin daemon identity for a run. It identifies the -tenant, plugin, and optional user context; server settings provide the plugin -daemon URL and API key. 
- -Use it together with a [plugin LLM layer](../plugin-llm-layer/index.md). The LLM -layer depends on this layer to reach the plugin daemon. - -## Config fields - -| Field | Type | Meaning | -| --- | --- | --- | -| `tenant_id` | `str` | Dify tenant/workspace id used when calling the plugin daemon. | -| `plugin_id` | `str` | Plugin id, for example `langgenius/openai`. | -| `user_id` | `str \| None` | Optional end-user id passed through to the plugin daemon. | - -The plugin layer type id is `dify.plugin`. - -## Basic usage - -```python {test="skip" lint="skip"} -from dify_agent.layers.dify_plugin import DIFY_PLUGIN_LAYER_TYPE_ID, DifyPluginLayerConfig -from dify_agent.protocol import RunLayerSpec - - -plugin_layer = RunLayerSpec( - name="plugin", - type=DIFY_PLUGIN_LAYER_TYPE_ID, - config=DifyPluginLayerConfig( - tenant_id="replace-with-tenant-id", - plugin_id="langgenius/openai", - user_id="replace-with-user-id", - ), -) -``` - -If you do not need a user id, omit `user_id` or pass `None`. - -## Server-side settings - -The plugin layer config does not include daemon transport settings. Configure -these on the Dify Agent server instead: - -```env -DIFY_AGENT_PLUGIN_DAEMON_URL=http://localhost:5002 -DIFY_AGENT_PLUGIN_DAEMON_API_KEY=replace-with-plugin-daemon-server-key -``` - -This keeps server credentials out of client-submitted layer config and out of -session snapshots. - -## Notes - -- The plugin layer does not open, cache, close, or snapshot HTTP clients. -- `plugin_id` selects the plugin package. The business model provider and model - name belong to the plugin LLM layer, not this layer. -- The conventional layer name is `plugin`. If you use another name, point the LLM - layer dependency at that name. 
diff --git a/dify-agent/docs/dify-agent/user-manual/plugin-llm-layer/index.md b/dify-agent/docs/dify-agent/user-manual/plugin-llm-layer/index.md index 624d7cfd14..889d67778b 100644 --- a/dify-agent/docs/dify-agent/user-manual/plugin-llm-layer/index.md +++ b/dify-agent/docs/dify-agent/user-manual/plugin-llm-layer/index.md @@ -1,17 +1,18 @@ # Plugin LLM layer -The plugin LLM layer selects the model provider, model name, provider credentials, -and optional model settings for the current run. Dify Agent reads the model from -the reserved layer name `llm`. +The plugin LLM layer selects the plugin package, model provider, model name, +provider credentials, and optional model settings for the current run. Dify +Agent reads the model from the reserved layer name `llm`. -It must depend on a [plugin layer](../plugin-layer/index.md), because the plugin -layer supplies the daemon identity and transport context. +It must depend on an [execution context layer](../execution-context-layer/index.md), +because that layer supplies the daemon identity and transport context. ## Config fields | Field | Type | Meaning | | --- | --- | --- | -| `model_provider` | `str` | Provider name inside the selected plugin. Use the value of `DIFY_AGENT_PROVIDER` from `dify-agent/.env`. | +| `plugin_id` | `str` | Plugin package id, for example `langgenius/openai`. | +| `model_provider` | `str` | Provider name inside `plugin_id`. Use the value of `DIFY_AGENT_PROVIDER` from `dify-agent/.env`. | | `model` | `str` | Model name. Use the value of `DIFY_AGENT_MODEL_NAME` from `dify-agent/.env`. | | `credentials` | `dict[str, str \| int \| float \| bool \| None]` | Provider-specific credential object. | | `model_settings` | `ModelSettings \| None` | Optional pydantic-ai model settings. 
| @@ -27,12 +28,14 @@ from dify_agent.protocol import DIFY_AGENT_MODEL_LAYER_ID, RunLayerSpec MODEL_PROVIDER = "replace-with-provider-from-dify-agent-env" MODEL_NAME = "replace-with-model-from-dify-agent-env" +PLUGIN_ID = "langgenius/openai" llm_layer = RunLayerSpec( name=DIFY_AGENT_MODEL_LAYER_ID, type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID, - deps={"plugin": "plugin"}, + deps={"execution_context": "execution_context"}, config=DifyPluginLLMLayerConfig( + plugin_id=PLUGIN_ID, model_provider=MODEL_PROVIDER, model=MODEL_NAME, credentials={"api_key": "replace-with-provider-key"}, @@ -40,29 +43,30 @@ llm_layer = RunLayerSpec( ) ``` -`deps={"plugin": "plugin"}` means: bind the LLM layer's dependency field named -`plugin` to the composition layer named `plugin`. +`deps={"execution_context": "execution_context"}` means: bind the LLM layer's +dependency field named `execution_context` to the composition layer named +`execution_context`. Set `MODEL_PROVIDER` and `MODEL_NAME` to the same values as `DIFY_AGENT_PROVIDER` and `DIFY_AGENT_MODEL_NAME` in `dify-agent/.env`. 
## Complete minimal model composition -Most runs include a prompt, plugin context, and LLM layer: +Most runs include a prompt, execution-context layer, and LLM layer: ```python {test="skip" lint="skip"} from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig +from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig from dify_agent.layers.dify_plugin import ( - DIFY_PLUGIN_LAYER_TYPE_ID, DIFY_PLUGIN_LLM_LAYER_TYPE_ID, DifyPluginLLMLayerConfig, - DifyPluginLayerConfig, ) from dify_agent.protocol import DIFY_AGENT_MODEL_LAYER_ID, RunComposition, RunLayerSpec MODEL_PROVIDER = "replace-with-provider-from-dify-agent-env" MODEL_NAME = "replace-with-model-from-dify-agent-env" +PLUGIN_ID = "langgenius/openai" composition = RunComposition( layers=[ @@ -72,18 +76,19 @@ composition = RunComposition( config=PromptLayerConfig(prefix="You are concise.", user="Say hello."), ), RunLayerSpec( - name="plugin", - type=DIFY_PLUGIN_LAYER_TYPE_ID, - config=DifyPluginLayerConfig( + name="execution_context", + type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + config=DifyExecutionContextLayerConfig( tenant_id="replace-with-tenant-id", - plugin_id="langgenius/openai", + invoke_from="workflow_run", ), ), RunLayerSpec( name=DIFY_AGENT_MODEL_LAYER_ID, type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID, - deps={"plugin": "plugin"}, + deps={"execution_context": "execution_context"}, config=DifyPluginLLMLayerConfig( + plugin_id=PLUGIN_ID, model_provider=MODEL_PROVIDER, model=MODEL_NAME, credentials={"api_key": "replace-with-provider-key"}, @@ -96,6 +101,9 @@ composition = RunComposition( ## Notes - The model layer must use the reserved name `llm` (`DIFY_AGENT_MODEL_LAYER_ID`). +- `plugin_id` belongs here because model calls are plugin-specific business + calls. The shared execution-context layer only carries Dify run and + tenant/user daemon context. 
- Credential shape depends on the selected plugin provider; the OpenAI-style `api_key` field above is only an example. - Client-submitted model credentials remain in the scheduled request memory and diff --git a/dify-agent/docs/dify-agent/user-manual/plugin-tool-layer/index.md b/dify-agent/docs/dify-agent/user-manual/plugin-tool-layer/index.md new file mode 100644 index 0000000000..a7618462c9 --- /dev/null +++ b/dify-agent/docs/dify-agent/user-manual/plugin-tool-layer/index.md @@ -0,0 +1,130 @@ +# Plugin tool layer + +The plugin tool layer exposes Dify plugin tools to the model. It is designed for +Dify API to build after it has resolved a user's tool selections, plugin daemon +declarations, credentials, and manual/runtime inputs. + +Unlike the plugin LLM layer, this layer may contain tools from multiple plugin +packages. Each tool config carries its own `plugin_id`, while the shared +[execution context layer](../execution-context-layer/index.md) still carries +only tenant/user daemon context. + +## Responsibilities + +Dify API prepares the tool config before submitting the run request: + +- resolve the selected provider and tool name; +- merge declared parameters with runtime parameters; +- produce the model-visible JSON schema; +- provide hidden/manual `runtime_parameters` and credentials; +- choose the daemon `credential_type` for invocation. + +Dify Agent consumes that prepared config. At run time it validates required +hidden inputs, applies defaults, casts invocation values, calls plugin daemon, +and turns tool responses into model observations. + +## Config fields + +The plugin tools layer type id is `dify.plugin.tools`. + +`DifyPluginToolsLayerConfig` contains a list of `DifyPluginToolConfig` objects: + +| Field | Type | Meaning | +| --- | --- | --- | +| `tools` | `list[DifyPluginToolConfig]` | Prepared plugin tools to expose to the model. 
| + +Each tool config has these fields: + +| Field | Type | Meaning | +| --- | --- | --- | +| `plugin_id` | `str` | Plugin package id for this tool, for example `langgenius/wikipedia`. | +| `provider` | `str` | Tool provider name inside the plugin. | +| `tool_name` | `str` | Daemon tool name to invoke. | +| `credential_type` | `"api-key" \| "oauth2" \| "unauthorized"` | Credential mode sent to plugin daemon. | +| `name` | `str \| None` | Optional model-visible tool name. Defaults to `tool_name`. | +| `description` | `str \| None` | Optional model-visible description. Defaults to the tool name. | +| `credentials` | `dict[str, str \| int \| float \| bool \| None]` | Provider-specific tool credentials. | +| `runtime_parameters` | `dict[str, JsonValue]` | Hidden/manual values merged into daemon invocation but omitted from the model schema. | +| `parameters` | `list[DifyPluginToolParameter]` | API-prepared effective parameter declarations used for validation, defaults, and casting. | +| `parameters_json_schema` | `dict[str, JsonValue]` | API-prepared JSON schema shown to the model. | + +## Example: Dify API prepared Wikipedia tool + +```python {test="skip" lint="skip"} +from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig +from dify_agent.layers.dify_plugin import ( + DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID, + DifyPluginToolConfig, + DifyPluginToolParameter, + DifyPluginToolsLayerConfig, +) +from dify_agent.protocol import RunComposition, RunLayerSpec + + +# Dify API side: resolve the selected tool into the API-side Tool runtime first, +# for example with ToolManager.get_agent_tool_runtime(...). Then adapt its +# effective ToolParameter objects at the protocol boundary. Dify Agent accepts +# both ToolParameter attribute objects and ToolParameter.model_dump(mode="json") +# dictionaries, ignoring API-only fields such as label and human_description. +tool_runtime = ToolManager.get_agent_tool_runtime(...) 
+effective_parameters = tool_runtime.get_merged_runtime_parameters() +prepared_parameters = [ + DifyPluginToolParameter.model_validate(parameter) + # If the API serializes first, use: + # DifyPluginToolParameter.model_validate(parameter.model_dump(mode="json")) + for parameter in effective_parameters +] +parameters_json_schema = tool_runtime.get_llm_parameters_json_schema() + +composition = RunComposition( + layers=[ + RunLayerSpec( + name="execution_context", + type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + config=DifyExecutionContextLayerConfig( + tenant_id="replace-with-tenant-id", + user_id="replace-with-user-id", + invoke_from="workflow_run", + ), + ), + RunLayerSpec( + name="tools", + type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID, + deps={"execution_context": "execution_context"}, + config=DifyPluginToolsLayerConfig( + tools=[ + DifyPluginToolConfig( + plugin_id="langgenius/wikipedia", + provider="wikipedia", + tool_name="wikipedia_search", + credential_type="unauthorized", + name="wikipedia_search", + description="Search Wikipedia for relevant pages.", + parameters=prepared_parameters, + runtime_parameters={"language": "en"}, + parameters_json_schema=parameters_json_schema, + ) + ] + ), + ), + ] +) +``` + +`deps={"execution_context": "execution_context"}` means: bind the tool layer's +dependency field named `execution_context` to the composition layer named +`execution_context`. + +## Notes for Dify API callers + +- Do not ask Dify Agent to discover tool declarations. Resolve and prepare them + in API before creating the run. +- `parameters` should include all effective parameters, including hidden/manual + ones needed for validation and default application. +- `parameters_json_schema` should include only model-visible parameters. Omit + hidden/manual parameters and file/system-file parameters unless they are truly + intended for model input. +- `runtime_parameters` should contain hidden/manual values selected by the user + or derived from workflow variables. 
+- Put each tool's `plugin_id` on the tool config. The shared execution-context + layer has no package-specific identity. diff --git a/dify-agent/examples/dify_agent/dify_agent_examples/run_server_consumer.py b/dify-agent/examples/dify_agent/dify_agent_examples/run_server_consumer.py index a3d0474b46..fb07c352d1 100644 --- a/dify-agent/examples/dify_agent/dify_agent_examples/run_server_consumer.py +++ b/dify-agent/examples/dify_agent/dify_agent_examples/run_server_consumer.py @@ -17,12 +17,11 @@ import asyncio from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig from dify_agent.client import Client +from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig from dify_agent.layers.dify_plugin import ( - DIFY_PLUGIN_LAYER_TYPE_ID, DIFY_PLUGIN_LLM_LAYER_TYPE_ID, DifyPluginCredentialValue, DifyPluginLLMLayerConfig, - DifyPluginLayerConfig, ) from dify_agent.protocol import DIFY_AGENT_MODEL_LAYER_ID, CreateRunRequest, RunComposition, RunLayerSpec @@ -50,20 +49,67 @@ async def main() -> None: ), ), RunLayerSpec( - name="plugin", - type=DIFY_PLUGIN_LAYER_TYPE_ID, - config=DifyPluginLayerConfig(tenant_id=TENANT_ID, plugin_id=PLUGIN_ID), + name="execution_context", + type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + config=DifyExecutionContextLayerConfig( + tenant_id=TENANT_ID, + invoke_from="workflow_run", + ), ), RunLayerSpec( name=DIFY_AGENT_MODEL_LAYER_ID, type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID, - deps={"plugin": "plugin"}, + deps={"execution_context": "execution_context"}, config=DifyPluginLLMLayerConfig( + plugin_id=PLUGIN_ID, model_provider=PLUGIN_PROVIDER, model=MODEL_NAME, credentials=MODEL_CREDENTIALS, ), ), + # Minimal plugin-tools example. API callers should pass + # prepared parameters + JSON schema instead of relying on + # dify-agent to fetch and merge daemon declarations. 
+ # from dify_agent.layers.dify_plugin import ( + # DifyPluginToolConfig, + # DifyPluginToolParameter, + # DifyPluginToolParameterForm, + # DifyPluginToolParameterType, + # DifyPluginToolsLayerConfig, + # ) + # RunLayerSpec( + # name="tools", + # type="dify.plugin.tools", + # deps={"execution_context": "execution_context"}, + # config=DifyPluginToolsLayerConfig( + # tools=[ + # DifyPluginToolConfig( + # plugin_id="langgenius/search", + # provider="search", + # tool_name="web_search", + # credential_type="api-key", + # credentials={"api_key": "replace-with-tool-key"}, + # runtime_parameters={"site": "docs.dify.ai"}, + # parameters=[ + # DifyPluginToolParameter( + # name="query", + # type=DifyPluginToolParameterType.STRING, + # form=DifyPluginToolParameterForm.LLM, + # required=True, + # llm_description="Search query", + # ), + # ], + # parameters_json_schema={ + # "type": "object", + # "properties": { + # "query": {"type": "string", "description": "Search query"} + # }, + # "required": ["query"], + # }, + # ) + # ] + # ), + # ), ], ), ) diff --git a/dify-agent/examples/dify_agent/dify_agent_examples/run_server_sync_client.py b/dify-agent/examples/dify_agent/dify_agent_examples/run_server_sync_client.py index 3c789571f1..90ae65d39b 100644 --- a/dify-agent/examples/dify_agent/dify_agent_examples/run_server_sync_client.py +++ b/dify-agent/examples/dify_agent/dify_agent_examples/run_server_sync_client.py @@ -10,12 +10,11 @@ assuming the original request was not accepted. 
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig from dify_agent.client import Client +from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig from dify_agent.layers.dify_plugin import ( - DIFY_PLUGIN_LAYER_TYPE_ID, DIFY_PLUGIN_LLM_LAYER_TYPE_ID, DifyPluginCredentialValue, DifyPluginLLMLayerConfig, - DifyPluginLayerConfig, ) from dify_agent.protocol import DIFY_AGENT_MODEL_LAYER_ID, CreateRunRequest, RunComposition, RunLayerSpec @@ -43,20 +42,67 @@ def main() -> None: ), ), RunLayerSpec( - name="plugin", - type=DIFY_PLUGIN_LAYER_TYPE_ID, - config=DifyPluginLayerConfig(tenant_id=TENANT_ID, plugin_id=PLUGIN_ID), + name="execution_context", + type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + config=DifyExecutionContextLayerConfig( + tenant_id=TENANT_ID, + invoke_from="workflow_run", + ), ), RunLayerSpec( name=DIFY_AGENT_MODEL_LAYER_ID, type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID, - deps={"plugin": "plugin"}, + deps={"execution_context": "execution_context"}, config=DifyPluginLLMLayerConfig( + plugin_id=PLUGIN_ID, model_provider=PLUGIN_PROVIDER, model=MODEL_NAME, credentials=MODEL_CREDENTIALS, ), ), + # Minimal plugin-tools example. API callers should pass + # prepared parameters + JSON schema instead of relying on + # dify-agent to fetch and merge daemon declarations. 
+ # from dify_agent.layers.dify_plugin import ( + # DifyPluginToolConfig, + # DifyPluginToolParameter, + # DifyPluginToolParameterForm, + # DifyPluginToolParameterType, + # DifyPluginToolsLayerConfig, + # ) + # RunLayerSpec( + # name="tools", + # type="dify.plugin.tools", + # deps={"execution_context": "execution_context"}, + # config=DifyPluginToolsLayerConfig( + # tools=[ + # DifyPluginToolConfig( + # plugin_id="langgenius/search", + # provider="search", + # tool_name="web_search", + # credential_type="api-key", + # credentials={"api_key": "replace-with-tool-key"}, + # runtime_parameters={"site": "docs.dify.ai"}, + # parameters=[ + # DifyPluginToolParameter( + # name="query", + # type=DifyPluginToolParameterType.STRING, + # form=DifyPluginToolParameterForm.LLM, + # required=True, + # llm_description="Search query", + # ), + # ], + # parameters_json_schema={ + # "type": "object", + # "properties": { + # "query": {"type": "string", "description": "Search query"} + # }, + # "required": ["query"], + # }, + # ) + # ] + # ), + # ), ], ), ) diff --git a/dify-agent/mkdocs.yml b/dify-agent/mkdocs.yml index c2fae8487f..ab66d3e72c 100644 --- a/dify-agent/mkdocs.yml +++ b/dify-agent/mkdocs.yml @@ -11,19 +11,18 @@ nav: - Agenton: - Overview: agenton/index.md - Guide: agenton/guide/index.md - - API Reference: agenton/api/index.md - Examples: agenton/examples/index.md - Dify Agent: - Overview: dify-agent/index.md - User Manual: - Get Started: dify-agent/get-started/index.md - Prompt Layer: dify-agent/user-manual/prompt-layer/index.md - - Plugin Layer: dify-agent/user-manual/plugin-layer/index.md + - Execution Context Layer: dify-agent/user-manual/execution-context-layer/index.md - Plugin LLM Layer: dify-agent/user-manual/plugin-llm-layer/index.md + - Plugin Tool Layer: dify-agent/user-manual/plugin-tool-layer/index.md - History Layer: dify-agent/user-manual/history-layer/index.md - Structured Output Layer: dify-agent/user-manual/structured-output-layer/index.md - Operations 
Guide: dify-agent/guide/index.md - - Run API: dify-agent/api/index.md - Examples: dify-agent/examples/index.md theme: diff --git a/dify-agent/src/dify_agent/adapters/llm/provider.py b/dify-agent/src/dify_agent/adapters/llm/provider.py index 6e7b92f646..a210cce1e3 100644 --- a/dify-agent/src/dify_agent/adapters/llm/provider.py +++ b/dify-agent/src/dify_agent/adapters/llm/provider.py @@ -8,7 +8,6 @@ this provider. from __future__ import annotations -import json from collections.abc import AsyncIterator, Callable, Mapping from dataclasses import dataclass, field from typing import NoReturn @@ -22,6 +21,12 @@ from typing_extensions import override from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UnexpectedModelBehavior, UserError from pydantic_ai.providers import Provider +from dify_agent.plugin_daemon_transport import ( + decode_plugin_daemon_error_payload, + to_plugin_daemon_jsonable, + unwrap_plugin_daemon_error, +) + _DEFAULT_DAEMON_TIMEOUT: float | httpx.Timeout | None = 600.0 @@ -83,7 +88,7 @@ class DifyPluginDaemonLLMClient: request_data: Mapping[str, object], response_model: type[T], ) -> AsyncIterator[T]: - payload: dict[str, object] = {"data": _to_jsonable(request_data)} + payload: dict[str, object] = {"data": to_plugin_daemon_jsonable(request_data)} if self.user_id is not None: payload["user_id"] = self.user_id @@ -97,14 +102,18 @@ class DifyPluginDaemonLLMClient: async with self.http_client.stream("POST", url, headers=headers, json=payload) as response: if response.is_error: body = (await response.aread()).decode("utf-8", errors="replace") - error = _decode_plugin_daemon_error_payload(body) + error = decode_plugin_daemon_error_payload(body) if error is not None: - _raise_plugin_daemon_error( - model_name=model_name, + resolved_error = unwrap_plugin_daemon_error( error_type=error["error_type"], message=error["message"], + ) + _raise_plugin_daemon_error( + model_name=model_name, + error_type=resolved_error["error_type"], + 
message=resolved_error["message"], status_code=response.status_code, - body=error, + body=resolved_error, ) raise ModelHTTPError(response.status_code, model_name, body or None) @@ -117,13 +126,17 @@ class DifyPluginDaemonLLMClient: wrapped = PluginDaemonBasicResponse.model_validate_json(line) if wrapped.code != 0: - error = _decode_plugin_daemon_error_payload(wrapped.message) + error = decode_plugin_daemon_error_payload(wrapped.message) if error is not None: - _raise_plugin_daemon_error( - model_name=model_name, + resolved_error = unwrap_plugin_daemon_error( error_type=error["error_type"], message=error["message"], - body=error, + ) + _raise_plugin_daemon_error( + model_name=model_name, + error_type=resolved_error["error_type"], + message=resolved_error["message"], + body=resolved_error, ) raise ModelAPIError( model_name, @@ -199,32 +212,6 @@ class DifyPluginDaemonProvider(Provider[DifyPluginDaemonLLMClient]): return self._client -def _to_jsonable(value: object) -> object: - if isinstance(value, BaseModel): - return value.model_dump(mode="json") - if isinstance(value, dict): - return {key: _to_jsonable(item) for key, item in value.items()} - if isinstance(value, list | tuple): - return [_to_jsonable(item) for item in value] - return value - - -def _decode_plugin_daemon_error_payload(raw_message: str) -> dict[str, str] | None: - try: - parsed = json.loads(raw_message) - except json.JSONDecodeError: - return None - - if not isinstance(parsed, dict): - return None - - error_type = parsed.get("error_type") - message = parsed.get("message") - if not isinstance(error_type, str) or not isinstance(message, str): - return None - return {"error_type": error_type, "message": message} - - def _raise_plugin_daemon_error( *, model_name: str, @@ -236,17 +223,6 @@ def _raise_plugin_daemon_error( http_error_body = body or {"error_type": error_type, "message": message} match error_type: - case "PluginInvokeError": - nested_error = _decode_plugin_daemon_error_payload(message) - if 
nested_error is not None: - _raise_plugin_daemon_error( - model_name=model_name, - error_type=nested_error["error_type"], - message=nested_error["message"], - status_code=status_code, - body=nested_error, - ) - raise ModelAPIError(model_name, message) case "PluginDaemonUnauthorizedError" | "InvokeAuthorizationError": raise ModelHTTPError(status_code or 401, model_name, http_error_body) case "PluginPermissionDeniedError": diff --git a/dify-agent/src/dify_agent/layers/dify_plugin/__init__.py b/dify-agent/src/dify_agent/layers/dify_plugin/__init__.py index 5b2f1dccce..bd719cea8f 100644 --- a/dify-agent/src/dify_agent/layers/dify_plugin/__init__.py +++ b/dify-agent/src/dify_agent/layers/dify_plugin/__init__.py @@ -1,21 +1,35 @@ -"""Client-safe exports for Dify plugin DTOs and public layer type ids. +"""Client-safe exports for Dify plugin business-layer DTOs and type ids. Implementation layers live in sibling modules and require server-side runtime dependencies. Keep this package root import-safe for client-only installs. 
""" from dify_agent.layers.dify_plugin.configs import ( - DIFY_PLUGIN_LAYER_TYPE_ID, DIFY_PLUGIN_LLM_LAYER_TYPE_ID, + DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID, DifyPluginCredentialValue, DifyPluginLLMLayerConfig, - DifyPluginLayerConfig, + DifyPluginToolCredentialType, + DifyPluginToolConfig, + DifyPluginToolOption, + DifyPluginToolParameter, + DifyPluginToolParameterForm, + DifyPluginToolParameterType, + DifyPluginToolsLayerConfig, + DifyPluginToolValue, ) __all__ = [ - "DIFY_PLUGIN_LAYER_TYPE_ID", "DIFY_PLUGIN_LLM_LAYER_TYPE_ID", + "DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID", "DifyPluginCredentialValue", "DifyPluginLLMLayerConfig", - "DifyPluginLayerConfig", + "DifyPluginToolCredentialType", + "DifyPluginToolConfig", + "DifyPluginToolOption", + "DifyPluginToolParameter", + "DifyPluginToolParameterForm", + "DifyPluginToolParameterType", + "DifyPluginToolsLayerConfig", + "DifyPluginToolValue", ] diff --git a/dify-agent/src/dify_agent/layers/dify_plugin/configs.py b/dify-agent/src/dify_agent/layers/dify_plugin/configs.py index 5fff7dde51..25651482a5 100644 --- a/dify-agent/src/dify_agent/layers/dify_plugin/configs.py +++ b/dify-agent/src/dify_agent/layers/dify_plugin/configs.py @@ -1,38 +1,111 @@ -"""Client-safe DTOs for Dify plugin-backed Agenton layers. +"""Client-safe DTOs for Dify plugin-backed Agenton business layers. This module intentionally contains only public config schemas and scalar type -aliases plus stable layer type identifiers. Runtime objects such as HTTP -clients, server settings, and adapter implementations live in sibling -implementation modules so clients can build run requests without importing -server-only dependencies. +aliases plus stable plugin business-layer type identifiers. Runtime objects +such as HTTP clients, server settings, and adapter implementations live in +sibling implementation modules so clients can build run requests without +importing server-only dependencies. 
+ +Shared tenant/user/run context now lives in the sibling +``dify_agent.layers.execution_context`` package. This module only covers the +plugin-backed LLM and tools layers that invoke daemon features with concrete +``plugin_id`` values. Tool configs also carry the API-side prepared parameter +declarations and model-visible JSON schema so the agent runtime does not have to +re-fetch and re-merge tool declarations at execution time. """ -from typing import ClassVar, Final, TypeAlias +from enum import StrEnum +from typing import ClassVar, Final, Literal, TypeAlias -from pydantic import ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator from pydantic_ai.settings import ModelSettings from agenton.layers import LayerConfig DifyPluginCredentialValue: TypeAlias = str | int | float | bool | None -DIFY_PLUGIN_LAYER_TYPE_ID: Final[str] = "dify.plugin" +DifyPluginToolCredentialType: TypeAlias = Literal["api-key", "oauth2", "unauthorized"] +DifyPluginToolValue: TypeAlias = JsonValue DIFY_PLUGIN_LLM_LAYER_TYPE_ID: Final[str] = "dify.plugin.llm" +DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID: Final[str] = "dify.plugin.tools" -class DifyPluginLayerConfig(LayerConfig): - """Public config for the plugin daemon tenant/plugin context layer.""" +class DifyPluginToolOption(BaseModel): + """Selectable tool option value exposed to the model. - tenant_id: str - plugin_id: str - user_id: str | None = None + The DTO also accepts API-side option dumps and attribute objects. Fields + such as ``label`` or ``icon`` are intentionally ignored because Dify Agent + only preserves the normalized option ``value`` for tool invocation and + model-visible schema generation. 
+ """ - model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + value: str + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="ignore", from_attributes=True) + + @field_validator("value", mode="before") + @classmethod + def stringify_value(cls, value: object) -> str: + return value if isinstance(value, str) else str(value) + + +class DifyPluginToolParameterType(StrEnum): + STRING = "string" + NUMBER = "number" + BOOLEAN = "boolean" + SELECT = "select" + SECRET_INPUT = "secret-input" + FILE = "file" + FILES = "files" + APP_SELECTOR = "app-selector" + MODEL_SELECTOR = "model-selector" + ANY = "any" + DYNAMIC_SELECT = "dynamic-select" + CHECKBOX = "checkbox" + SYSTEM_FILES = "system-files" + ARRAY = "array" + OBJECT = "object" + + def as_normal_type(self) -> str: + if self in { + DifyPluginToolParameterType.SECRET_INPUT, + DifyPluginToolParameterType.SELECT, + DifyPluginToolParameterType.CHECKBOX, + }: + return "string" + return self.value + + +class DifyPluginToolParameterForm(StrEnum): + SCHEMA = "schema" + FORM = "form" + LLM = "llm" + + +class DifyPluginToolParameter(BaseModel): + """Prepared tool parameter declaration supplied by the API side. + + The DTO intentionally accepts both API-side ``ToolParameter`` dumps and + attribute objects so callers can adapt existing tool runtime declarations + without coupling Dify Agent to API-internal model classes. 
+ """ + + name: str + type: DifyPluginToolParameterType + form: DifyPluginToolParameterForm + required: bool = False + default: DifyPluginToolValue = None + llm_description: str | None = None + input_schema: dict[str, JsonValue] | None = None + options: list[DifyPluginToolOption] = Field(default_factory=list) + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="ignore", from_attributes=True) class DifyPluginLLMLayerConfig(LayerConfig): - """Public config for selecting a business provider/model from a plugin.""" + """Public config for selecting a plugin-backed business provider/model.""" + plugin_id: str model_provider: str model: str credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict) @@ -41,10 +114,64 @@ class DifyPluginLLMLayerConfig(LayerConfig): model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True) +class DifyPluginToolConfig(LayerConfig): + """Public config for exposing one plugin tool to the agent model. + + ``credential_type`` is an explicit caller-supplied daemon transport choice, + not an auto-discovered property. It must match the actual credential mode of + ``credentials`` for the configured plugin tool, for example ``"api-key"`` + versus ``"oauth2"``. A wrong value can make invocation fail at runtime even + when the config itself validates successfully. + + ``runtime_parameters`` mirrors Dify's agent-node hidden/manual tool inputs: + those values are merged into the actual daemon invocation but omitted from + the tool schema shown to the model. + + ``parameters`` and ``parameters_json_schema`` are API-side prepared tool + declaration artifacts. They let the agent runtime validate hidden/default + inputs and expose the correct LLM-facing schema without re-fetching or + re-merging daemon declarations at run time. 
+ """ + + plugin_id: str + provider: str + tool_name: str + credential_type: DifyPluginToolCredentialType + name: str | None = None + description: str | None = None + credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict) + runtime_parameters: dict[str, DifyPluginToolValue] = Field(default_factory=dict) + parameters: list[DifyPluginToolParameter] = Field(default_factory=list) + parameters_json_schema: dict[str, JsonValue] = Field( + default_factory=lambda: {"type": "object", "properties": {}, "required": []} + ) + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + +class DifyPluginToolsLayerConfig(LayerConfig): + """Public config for the Dify plugin tools layer. + + Callers configure the tools layer with this wrapper object and supply one + or more prepared ``DifyPluginToolConfig`` entries in ``tools``. + """ + + tools: list[DifyPluginToolConfig] = Field(default_factory=list) + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + __all__ = [ - "DIFY_PLUGIN_LAYER_TYPE_ID", "DIFY_PLUGIN_LLM_LAYER_TYPE_ID", + "DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID", "DifyPluginCredentialValue", "DifyPluginLLMLayerConfig", - "DifyPluginLayerConfig", + "DifyPluginToolCredentialType", + "DifyPluginToolConfig", + "DifyPluginToolOption", + "DifyPluginToolParameter", + "DifyPluginToolParameterForm", + "DifyPluginToolParameterType", + "DifyPluginToolsLayerConfig", + "DifyPluginToolValue", ] diff --git a/dify-agent/src/dify_agent/layers/dify_plugin/llm_layer.py b/dify-agent/src/dify_agent/layers/dify_plugin/llm_layer.py index 4ac053df3f..48e6c5508d 100644 --- a/dify-agent/src/dify_agent/layers/dify_plugin/llm_layer.py +++ b/dify-agent/src/dify_agent/layers/dify_plugin/llm_layer.py @@ -1,15 +1,17 @@ """Dify plugin LLM model layer. This layer owns model capability resolution for Dify plugin-backed LLMs. 
It -depends on ``DifyPluginLayer`` for daemon identity through Agenton's direct -dependency binding and returns a Pydantic AI model adapter configured from the -public LLM layer DTO. Runtime code supplies the FastAPI lifespan-owned shared -HTTP client to ``get_model``; the layer does not own or discover live resources. -The daemon provider carries plugin transport identity, while the DTO's -``model_provider`` is passed to the adapter as request-level model identity. +depends on ``DifyExecutionContextLayer`` for shared daemon settings through +Agenton's direct dependency binding and returns a Pydantic AI model adapter +configured from the public LLM layer DTO. Runtime code supplies the FastAPI +lifespan-owned shared HTTP client to ``get_model``; the layer does not own or +discover live resources. The daemon provider carries plugin transport identity, +while the DTO's ``model_provider`` is passed to the adapter as request-level +model identity. """ from dataclasses import dataclass +from typing import ClassVar import httpx from typing_extensions import Self, override @@ -17,20 +19,20 @@ from typing_extensions import Self, override from agenton.layers import LayerDeps, PlainLayer from dify_agent.adapters.llm import DifyLLMAdapterModel from dify_agent.layers.dify_plugin.configs import DIFY_PLUGIN_LLM_LAYER_TYPE_ID, DifyPluginLLMLayerConfig -from dify_agent.layers.dify_plugin.plugin_layer import DifyPluginLayer +from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer class DifyPluginLLMDeps(LayerDeps): """Dependencies required by ``DifyPluginLLMLayer``.""" - plugin: DifyPluginLayer # pyright: ignore[reportUninitializedInstanceVariable] + execution_context: DifyExecutionContextLayer # pyright: ignore[reportUninitializedInstanceVariable] @dataclass(slots=True) class DifyPluginLLMLayer(PlainLayer[DifyPluginLLMDeps, DifyPluginLLMLayerConfig]): """Layer that creates the Dify plugin-daemon Pydantic AI model.""" - type_id = DIFY_PLUGIN_LLM_LAYER_TYPE_ID 
+ type_id: ClassVar[str] = DIFY_PLUGIN_LLM_LAYER_TYPE_ID config: DifyPluginLLMLayerConfig @@ -41,8 +43,11 @@ class DifyPluginLLMLayer(PlainLayer[DifyPluginLLMDeps, DifyPluginLLMLayerConfig] return cls(config=config) def get_model(self, *, http_client: httpx.AsyncClient) -> DifyLLMAdapterModel: - """Return the configured model using the directly bound plugin dependency.""" - provider = self.deps.plugin.create_daemon_provider(http_client=http_client) + """Return the configured model using the directly bound execution context.""" + provider = self.deps.execution_context.create_daemon_provider( + plugin_id=self.config.plugin_id, + http_client=http_client, + ) return DifyLLMAdapterModel( model=self.config.model, daemon_provider=provider, diff --git a/dify-agent/src/dify_agent/layers/dify_plugin/plugin_layer.py b/dify-agent/src/dify_agent/layers/dify_plugin/plugin_layer.py deleted file mode 100644 index 71c649b6de..0000000000 --- a/dify-agent/src/dify_agent/layers/dify_plugin/plugin_layer.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Runtime Dify plugin context layer. - -The public config identifies tenant/plugin/user context only. Plugin daemon URL -and API key are server-side settings injected by the provider factory. The layer -is intentionally config/settings-only under Agenton's state-only core: it does -not open, cache, close, or snapshot HTTP clients, and its lifecycle hooks remain -the inherited no-op hooks. Runtime code passes the FastAPI lifespan-owned shared -``httpx.AsyncClient`` into ``create_daemon_provider`` for each model adapter. -Business model-provider names belong to the LLM layer/model request, not this -daemon context layer. 
-""" - -from dataclasses import dataclass - -import httpx -from typing_extensions import Self, override - -from agenton.layers import EmptyRuntimeState, NoLayerDeps, PlainLayer -from dify_agent.adapters.llm import DifyPluginDaemonProvider -from dify_agent.layers.dify_plugin.configs import DIFY_PLUGIN_LAYER_TYPE_ID, DifyPluginLayerConfig - - -@dataclass(slots=True) -class DifyPluginLayer(PlainLayer[NoLayerDeps, DifyPluginLayerConfig, EmptyRuntimeState]): - """Layer that carries plugin daemon identity without owning live resources.""" - - type_id = DIFY_PLUGIN_LAYER_TYPE_ID - - config: DifyPluginLayerConfig - daemon_url: str - daemon_api_key: str - - @classmethod - @override - def from_config(cls, config: DifyPluginLayerConfig) -> Self: - """Reject construction without server-injected daemon settings.""" - del config - raise TypeError("DifyPluginLayer requires server-side daemon settings and must use a provider factory.") - - @classmethod - def from_config_with_settings( - cls, - config: DifyPluginLayerConfig, - *, - daemon_url: str, - daemon_api_key: str, - ) -> Self: - """Create a plugin layer from public config plus server-only daemon settings.""" - return cls(config=config, daemon_url=daemon_url, daemon_api_key=daemon_api_key) - - def create_daemon_provider(self, *, http_client: httpx.AsyncClient) -> DifyPluginDaemonProvider: - """Return a daemon provider backed by the shared plugin daemon client. - - Raises: - RuntimeError: if ``http_client`` has already been closed. 
- """ - if http_client.is_closed: - raise RuntimeError("DifyPluginLayer.create_daemon_provider() requires an open shared HTTP client.") - return DifyPluginDaemonProvider( - tenant_id=self.config.tenant_id, - plugin_id=self.config.plugin_id, - plugin_daemon_url=self.daemon_url, - plugin_daemon_api_key=self.daemon_api_key, - user_id=self.config.user_id, - http_client=http_client, - ) - - -__all__ = ["DifyPluginLayer"] diff --git a/dify-agent/src/dify_agent/layers/dify_plugin/tool_client.py b/dify-agent/src/dify_agent/layers/dify_plugin/tool_client.py new file mode 100644 index 0000000000..db65265464 --- /dev/null +++ b/dify-agent/src/dify_agent/layers/dify_plugin/tool_client.py @@ -0,0 +1,333 @@ +"""Async plugin-daemon client for Dify plugin tool invocation. + +The agent runtime talks to the plugin daemon rather than importing provider SDKs +directly. The tools layer now consumes API-prepared declarations from config, so +this module only keeps the invoke-time boundary: + +- POST ``/plugin/{tenant_id}/dispatch/tool/invoke`` +- request headers ``X-Api-Key``, ``X-Plugin-ID``, and ``Content-Type`` +- top-level ``user_id`` forwarding when shared execution context includes one +- stream decoding and blob-chunk merging for agent observations + +The shared execution-context layer still owns tenant/user daemon context, while +each tool's own ``plugin_id`` determines the transport identity placed in +``X-Plugin-ID``. 
+""" + +from __future__ import annotations + +import base64 +from collections.abc import AsyncIterator, Mapping +from dataclasses import dataclass, field +from enum import StrEnum + +import httpx +from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator + +from dify_agent.layers.dify_plugin.configs import DifyPluginToolCredentialType +from dify_agent.plugin_daemon_transport import ( + decode_plugin_daemon_error_payload, + to_plugin_daemon_jsonable, + unwrap_plugin_daemon_error, +) + + +class PluginDaemonBasicResponse(BaseModel): + """Common plugin-daemon stream and JSON wrapper.""" + + code: int + message: str + data: object | None = None + + +@dataclass(slots=True) +class FileChunk: + """Buffer for accumulating streamed blob chunks.""" + + total_length: int + bytes_written: int = field(default=0, init=False) + data: bytearray = field(init=False) + + def __post_init__(self) -> None: + self.data = bytearray(self.total_length) + + +class DifyPluginToolInvokeMessage(BaseModel): + """Subset of Dify tool stream messages needed for agent observations.""" + + class TextMessage(BaseModel): + text: str + + class JsonMessage(BaseModel): + json_object: dict[str, object] | list[object] + suppress_output: bool = False + + class BlobMessage(BaseModel): + blob: bytes + + class BlobChunkMessage(BaseModel): + id: str + sequence: int + total_length: int + blob: bytes + end: bool + + class FileMessage(BaseModel): + file_marker: str = "file_marker" + + @model_validator(mode="before") + @classmethod + def validate_file_marker(cls, values: object) -> object: + if isinstance(values, dict) and "file_marker" not in values: + raise ValueError("Invalid FileMessage: missing file_marker") + return values + + class VariableMessage(BaseModel): + variable_name: str + variable_value: object + stream: bool = False + + class LogMessage(BaseModel): + id: str + label: str + parent_id: str | None = None + error: str | None = None + status: str + data: Mapping[str, object] 
= Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) + + class MessageType(StrEnum): + TEXT = "text" + IMAGE = "image" + LINK = "link" + BLOB = "blob" + JSON = "json" + IMAGE_LINK = "image_link" + BINARY_LINK = "binary_link" + VARIABLE = "variable" + FILE = "file" + LOG = "log" + BLOB_CHUNK = "blob_chunk" + + type: MessageType = MessageType.TEXT + message: ( + TextMessage | JsonMessage | BlobChunkMessage | BlobMessage | LogMessage | FileMessage | VariableMessage | None + ) + meta: dict[str, object] | None = None + + @field_validator("message", mode="before") + @classmethod + def decode_message(cls, value: object, info: ValidationInfo) -> object: + if isinstance(value, dict) and "blob" in value: + try: + value = {**value, "blob": base64.b64decode(value["blob"])} + except Exception: + return value + + msg_type = info.data.get("type") if isinstance(info.data, dict) else None + if msg_type == cls.MessageType.JSON and isinstance(value, dict) and "json_object" not in value: + return {"json_object": value} + if msg_type == cls.MessageType.FILE and isinstance(value, dict): + return {"file_marker": value.get("file_marker", "file_marker")} + return value + + +class DifyPluginToolClientError(Exception): + """Raised when the plugin daemon rejects a tool-layer request.""" + + error_type: str | None + status_code: int | None + + def __init__(self, message: str, *, error_type: str | None = None, status_code: int | None = None) -> None: + super().__init__(message) + self.error_type = error_type + self.status_code = status_code + + +@dataclass(slots=True) +class DifyPluginDaemonToolClient: + """HTTP wrapper for the invoke-only plugin-daemon tool boundary. + + Callers provide business-level provider/tool/credential data per invocation, + while this client supplies daemon transport identity from shared runtime + context: tenant path segment, daemon API key, plugin-specific ``X-Plugin-ID`` + header, and optional top-level ``user_id``. 
+ """ + + plugin_daemon_url: str + plugin_daemon_api_key: str + tenant_id: str + plugin_id: str + user_id: str | None + http_client: httpx.AsyncClient = field(repr=False) + + def __post_init__(self) -> None: + self.plugin_daemon_url = self.plugin_daemon_url.rstrip("/") + + async def invoke( + self, + *, + provider: str, + tool_name: str, + credential_type: DifyPluginToolCredentialType, + credentials: dict[str, object], + tool_parameters: Mapping[str, object], + ) -> list[DifyPluginToolInvokeMessage]: + """Invoke a plugin tool and collect its observation stream.""" + raw_messages = [ + item + async for item in self._iter_stream_response( + path=f"plugin/{self.tenant_id}/dispatch/tool/invoke", + request_data={ + "provider": provider, + "tool": tool_name, + "credentials": credentials, + "credential_type": credential_type, + "tool_parameters": dict(tool_parameters), + }, + response_model=DifyPluginToolInvokeMessage, + ) + ] + return merge_blob_chunks(raw_messages) + + async def _iter_stream_response[T: BaseModel]( + self, + *, + path: str, + request_data: Mapping[str, object], + response_model: type[T], + ) -> AsyncIterator[T]: + """Send one daemon stream request and yield typed items. + + The daemon expects the actual invoke payload nested under ``data``. When + the shared plugin context included ``user_id``, it is forwarded as a + top-level peer to ``data`` so daemon-side auditing and credential logic + can attribute the request to the end user. 
+ """ + payload: dict[str, object] = {"data": to_plugin_daemon_jsonable(dict(request_data))} + if self.user_id is not None: + payload["user_id"] = self.user_id + + url = f"{self.plugin_daemon_url}/{path}" + async with self.http_client.stream("POST", url, headers=self._headers(), json=payload) as response: + if response.is_error: + body = (await response.aread()).decode("utf-8", errors="replace") + error = decode_plugin_daemon_error_payload(body) + if error is not None: + resolved_error = unwrap_plugin_daemon_error( + error_type=error["error_type"], + message=error["message"], + ) + _raise_tool_daemon_error( + error_type=resolved_error["error_type"], + message=resolved_error["message"], + status_code=response.status_code, + ) + raise DifyPluginToolClientError( + body or "Plugin daemon stream request failed.", status_code=response.status_code + ) + + async for raw_line in response.aiter_lines(): + line = raw_line.strip() + if not line: + continue + if line.startswith("data:"): + line = line[5:].strip() + + wrapped = PluginDaemonBasicResponse.model_validate_json(line) + if wrapped.code != 0: + error = decode_plugin_daemon_error_payload(wrapped.message) + if error is not None: + resolved_error = unwrap_plugin_daemon_error( + error_type=error["error_type"], + message=error["message"], + ) + _raise_tool_daemon_error( + error_type=resolved_error["error_type"], + message=resolved_error["message"], + ) + raise DifyPluginToolClientError(wrapped.message or "Plugin daemon returned an error stream item.") + if wrapped.data is None: + raise DifyPluginToolClientError("Plugin daemon returned an empty stream item.") + yield response_model.model_validate(wrapped.data) + + def _headers(self) -> dict[str, str]: + """Build required plugin-daemon transport headers for tool invocation.""" + return { + "X-Api-Key": self.plugin_daemon_api_key, + "X-Plugin-ID": self.plugin_id, + "Content-Type": "application/json", + } + + +def merge_blob_chunks( + response: 
list[DifyPluginToolInvokeMessage], + *, + max_file_size: int = 30 * 1024 * 1024, + max_chunk_size: int = 8192, +) -> list[DifyPluginToolInvokeMessage]: + """Merge streamed blob chunks into complete blob messages. + + This mirrors Dify API's plugin-daemon chunk-merging behavior before the + higher-level observation conversion logic sees tool stream messages. + """ + files: dict[str, FileChunk] = {} + merged_messages: list[DifyPluginToolInvokeMessage] = [] + + for resp in response: + if resp.type is DifyPluginToolInvokeMessage.MessageType.BLOB_CHUNK: + if not isinstance(resp.message, DifyPluginToolInvokeMessage.BlobChunkMessage): + raise TypeError("Blob chunk responses must carry BlobChunkMessage payloads.") + + chunk_id = resp.message.id + total_length = resp.message.total_length + blob_data = resp.message.blob + is_end = resp.message.end + + if chunk_id not in files: + files[chunk_id] = FileChunk(total_length) + + if files[chunk_id].bytes_written + len(blob_data) > max_file_size: + del files[chunk_id] + raise ValueError(f"File is too large which reached the limit of {max_file_size / 1024 / 1024}MB") + if len(blob_data) > max_chunk_size: + raise ValueError(f"File chunk is too large which reached the limit of {max_chunk_size / 1024}KB") + + files[chunk_id].data[files[chunk_id].bytes_written : files[chunk_id].bytes_written + len(blob_data)] = ( + blob_data + ) + files[chunk_id].bytes_written += len(blob_data) + + if is_end: + merged_messages.append( + DifyPluginToolInvokeMessage( + type=DifyPluginToolInvokeMessage.MessageType.BLOB, + message=DifyPluginToolInvokeMessage.BlobMessage( + blob=bytes(files[chunk_id].data[: files[chunk_id].bytes_written]) + ), + meta=resp.meta, + ) + ) + del files[chunk_id] + else: + merged_messages.append(resp) + + return merged_messages + + +def _raise_tool_daemon_error( + *, + error_type: str, + message: str, + status_code: int | None = None, +) -> None: + raise DifyPluginToolClientError(message, error_type=error_type, 
status_code=status_code) + + +__all__ = [ + "DifyPluginDaemonToolClient", + "DifyPluginToolClientError", + "DifyPluginToolCredentialType", + "DifyPluginToolInvokeMessage", + "merge_blob_chunks", +] diff --git a/dify-agent/src/dify_agent/layers/dify_plugin/tools_layer.py b/dify-agent/src/dify_agent/layers/dify_plugin/tools_layer.py new file mode 100644 index 0000000000..5ed4a5ea33 --- /dev/null +++ b/dify-agent/src/dify_agent/layers/dify_plugin/tools_layer.py @@ -0,0 +1,341 @@ +"""Dify plugin tools layer for agent-accessible plugin tools. + +This layer consumes API-prepared plugin tool declarations. The API side is +responsible for resolving daemon declarations, applying runtime-parameter +overrides, and producing the clean LLM-facing JSON schema. At run time the layer +only validates hidden/manual inputs, prepares invocation arguments, and maps +daemon responses into agent-friendly observations. + +Like the LLM layer, this layer never owns live HTTP clients. The runtime passes +the FastAPI lifespan-owned shared client into ``get_tools`` so the layer can +build Pydantic AI tool adapters on demand. 
+""" + +from __future__ import annotations + +from copy import deepcopy +import json +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from typing import ClassVar + +import httpx +from pydantic_ai import RunContext, Tool +from pydantic_ai.tools import ToolDefinition +from typing_extensions import Self, override + +from agenton.layers import LayerDeps, PlainLayer +from dify_agent.layers.dify_plugin.configs import ( + DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID, + DifyPluginToolConfig, + DifyPluginToolParameter, + DifyPluginToolParameterForm, + DifyPluginToolParameterType, + DifyPluginToolsLayerConfig, +) +from dify_agent.layers.dify_plugin.tool_client import ( + DifyPluginDaemonToolClient, + DifyPluginToolClientError, + DifyPluginToolInvokeMessage, +) +from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer + + +# Plugin tools intentionally do not expose a per-tool strictness override in the +# public config. The API supplies already-prepared schemas, but Dify Agent always +# registers those tools in loose mode so daemon tool invocation stays tolerant of +# plugin schema differences and older API-prepared payloads. 
+PLUGIN_TOOL_STRICT = False + + +class DifyPluginToolsDeps(LayerDeps): + """Dependencies required by ``DifyPluginToolsLayer``.""" + + execution_context: DifyExecutionContextLayer # pyright: ignore[reportUninitializedInstanceVariable] + + +@dataclass(slots=True) +class DifyPluginToolsLayer(PlainLayer[DifyPluginToolsDeps, DifyPluginToolsLayerConfig]): + """Layer that resolves Dify plugin tools into Pydantic AI tools.""" + + type_id: ClassVar[str] = DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID + + config: DifyPluginToolsLayerConfig + + @classmethod + @override + def from_config(cls, config: DifyPluginToolsLayerConfig) -> Self: + """Create the tools layer from validated public config.""" + return cls(config=DifyPluginToolsLayerConfig.model_validate(config)) + + async def get_tools(self, *, http_client: httpx.AsyncClient) -> list[Tool[object]]: + """Build Pydantic AI tool adapters from prepared plugin tool config.""" + tool_clients: dict[str, DifyPluginDaemonToolClient] = {} + tools: list[Tool[object]] = [] + + for tool_config in self.config.tools: + client = tool_clients.get(tool_config.plugin_id) + if client is None: + client = self.deps.execution_context.create_tool_client( + plugin_id=tool_config.plugin_id, + http_client=http_client, + ) + tool_clients[tool_config.plugin_id] = client + effective_parameters = [parameter.model_copy(deep=True) for parameter in tool_config.parameters] + _validate_required_hidden_parameters(tool_config, effective_parameters) + + tools.append( + _build_pydantic_ai_tool( + client=client, + tool_config=tool_config, + effective_parameters=effective_parameters, + ) + ) + + return tools + + +def _validate_required_hidden_parameters( + tool_config: DifyPluginToolConfig, + effective_parameters: Sequence[DifyPluginToolParameter], +) -> None: + missing_names = [ + parameter.name + for parameter in effective_parameters + if parameter.form is not DifyPluginToolParameterForm.LLM + and parameter.required + and parameter.default is None + and parameter.name not in 
tool_config.runtime_parameters + ] + if missing_names: + names = ", ".join(sorted(missing_names)) + raise ValueError(f"Tool '{tool_config.tool_name}' requires non-LLM runtime_parameters for: {names}.") + + +def _build_pydantic_ai_tool( + *, + client: DifyPluginDaemonToolClient, + tool_config: DifyPluginToolConfig, + effective_parameters: Sequence[DifyPluginToolParameter], +) -> Tool[object]: + tool_name = tool_config.name or tool_config.tool_name + tool_description = tool_config.description or tool_name + tool_schema = deepcopy(tool_config.parameters_json_schema) + + async def invoke_tool(_ctx: RunContext[object], **tool_arguments: object) -> str: + try: + merged_arguments = _prepare_tool_arguments(effective_parameters, tool_config, tool_arguments) + messages = await client.invoke( + provider=tool_config.provider, + tool_name=tool_config.tool_name, + credential_type=tool_config.credential_type, + credentials=dict(tool_config.credentials), + tool_parameters=merged_arguments, + ) + return _convert_tool_response_to_text(messages) + except DifyPluginToolClientError as exc: + return _tool_error_text(tool_name=tool_name, error=exc) + except ValueError as exc: + return f"tool parameters validation error: {exc}, please check your tool parameters" + + async def prepare_tool_definition(_ctx: RunContext[object], tool_def: ToolDefinition) -> ToolDefinition: + return ToolDefinition( + name=tool_def.name, + description=tool_def.description, + parameters_json_schema=tool_schema, + strict=PLUGIN_TOOL_STRICT, + sequential=tool_def.sequential, + metadata=tool_def.metadata, + timeout=tool_def.timeout, + defer_loading=tool_def.defer_loading, + kind=tool_def.kind, + return_schema=tool_def.return_schema, + include_return_schema=tool_def.include_return_schema, + ) + + return Tool( + invoke_tool, + takes_ctx=True, + name=tool_name, + description=tool_description, + prepare=prepare_tool_definition, + ) + + +def _prepare_tool_arguments( + effective_parameters: 
Sequence[DifyPluginToolParameter], + tool_config: DifyPluginToolConfig, + tool_arguments: Mapping[str, object], +) -> dict[str, object]: + """Build the daemon invocation payload from prepared config + model args. + + Argument precedence intentionally mirrors the old Dify tool runtime contract: + + 1. start from config-supplied ``runtime_parameters`` for hidden/manual inputs; + 2. let model-supplied tool arguments override same-named entries; + 3. if neither provided a value, fall back to the prepared parameter default; + 4. if a required parameter still has no value, raise validation error. + + Only parameters declared in ``effective_parameters`` are type-cast here; + extra merged keys are passed through unchanged for forward compatibility with + prepared config that may contain additional daemon inputs. + """ + merged_arguments: dict[str, object] = dict(tool_config.runtime_parameters) + merged_arguments.update(tool_arguments) + prepared_arguments: dict[str, object] = {} + + for parameter in effective_parameters: + if parameter.name in merged_arguments: + value = merged_arguments[parameter.name] + elif parameter.default is not None: + value = parameter.default + elif parameter.required: + raise ValueError(f"tool parameter {parameter.name} not found in tool config") + else: + continue + prepared_arguments[parameter.name] = _cast_tool_parameter_value(parameter.type, value) + + for key, value in merged_arguments.items(): + prepared_arguments.setdefault(key, value) + return prepared_arguments + + +def _cast_tool_parameter_value(parameter_type: DifyPluginToolParameterType, value: object) -> object: + """Cast prepared tool argument values into daemon-facing wire shapes. + + The API side prepares declaration metadata, but the actual invocation payload + still needs to match Dify plugin-daemon expectations. 
This helper keeps the + runtime-side coercion rules for common scalar, collection, file, and selector + parameter types so model-supplied JSON values and config-supplied hidden + inputs are normalized before transport. + """ + match parameter_type: + case ( + DifyPluginToolParameterType.STRING + | DifyPluginToolParameterType.SECRET_INPUT + | DifyPluginToolParameterType.SELECT + | DifyPluginToolParameterType.CHECKBOX + | DifyPluginToolParameterType.DYNAMIC_SELECT + ): + return "" if value is None else value if isinstance(value, str) else str(value) + case DifyPluginToolParameterType.BOOLEAN: + if value is None: + return False + if isinstance(value, str): + lowered = value.lower() + if lowered in {"true", "yes", "y", "1"}: + return True + if lowered in {"false", "no", "n", "0"}: + return False + return value if isinstance(value, bool) else bool(value) + case DifyPluginToolParameterType.NUMBER: + if isinstance(value, int | float): + return value + if isinstance(value, str) and value: + return float(value) if "." 
in value else int(value) + return value + case DifyPluginToolParameterType.SYSTEM_FILES | DifyPluginToolParameterType.FILES: + return value if isinstance(value, list) else [value] + case DifyPluginToolParameterType.FILE: + if isinstance(value, list): + if len(value) != 1: + raise ValueError("This parameter only accepts one file but got multiple files while invoking.") + return value[0] + return value + case DifyPluginToolParameterType.MODEL_SELECTOR | DifyPluginToolParameterType.APP_SELECTOR: + if not isinstance(value, dict): + raise ValueError("The selector must be a dictionary.") + return value + case DifyPluginToolParameterType.ANY: + if value is not None and not isinstance(value, dict | list | str | int | float | bool): + raise ValueError("The var selector must be a string, dictionary, list or number.") + return value + case DifyPluginToolParameterType.ARRAY: + if isinstance(value, list): + return value + if isinstance(value, str): + try: + parsed_value = json.loads(value) + except json.JSONDecodeError: + return [value] + if isinstance(parsed_value, list): + return parsed_value + return [value] + case DifyPluginToolParameterType.OBJECT: + if isinstance(value, dict): + return value + if isinstance(value, str): + try: + parsed_value = json.loads(value) + except json.JSONDecodeError: + return {} + if isinstance(parsed_value, dict): + return parsed_value + return {} + + raise AssertionError(f"Unsupported tool parameter type: {parameter_type}") + + +def _tool_error_text(*, tool_name: str, error: DifyPluginToolClientError) -> str: + """Map expected daemon/tool failures into agent-visible observation text. + + Only known plugin-daemon rejection categories should be softened into tool + observations. Unexpected local bugs are intentionally not handled here and + should propagate so tests and callers notice the regression. 
+ """ + error_type = error.error_type or "" + if any(token in error_type for token in ("Credential", "Authorization", "Unauthorized")): + return "Please check your tool provider credentials" + if any(token in error_type for token in ("ToolNotFound", "ProviderNotFound")): + return f"there is not a tool named {tool_name}" + if error.status_code == 400 or any(token in error_type for token in ("BadRequest", "Validate", "Validation")): + return f"tool parameters validation error: {error}, please check your tool parameters" + return f"tool invoke error: {error}" + + +def _convert_tool_response_to_text(tool_response: Sequence[DifyPluginToolInvokeMessage]) -> str: + """Convert daemon stream messages into the plain-text tool observation. + + This preserves the user-facing semantics Dify's agent tool runtime relies on: + text is appended directly, links/images become user-check instructions, JSON + output is included unless explicitly suppressed, variable messages stay + internal, and everything else falls back to ``str(message)``. JSON fragments + are deduplicated against existing text so mixed text/JSON streams do not + repeat the same content unnecessarily. + """ + parts: list[str] = [] + json_parts: list[str] = [] + + for response in tool_response: + if response.type is DifyPluginToolInvokeMessage.MessageType.TEXT: + text_message = response.message + if isinstance(text_message, DifyPluginToolInvokeMessage.TextMessage): + parts.append(text_message.text) + elif response.type is DifyPluginToolInvokeMessage.MessageType.LINK: + link_message = response.message + if isinstance(link_message, DifyPluginToolInvokeMessage.TextMessage): + parts.append(f"result link: {link_message.text}. 
please tell user to check it.") + elif response.type in { + DifyPluginToolInvokeMessage.MessageType.IMAGE_LINK, + DifyPluginToolInvokeMessage.MessageType.IMAGE, + }: + parts.append( + "image has been created and sent to user already, " + "you do not need to create it, just tell the user to check it now." + ) + elif response.type is DifyPluginToolInvokeMessage.MessageType.JSON: + json_message = response.message + if isinstance(json_message, DifyPluginToolInvokeMessage.JsonMessage) and not json_message.suppress_output: + json_parts.append(json.dumps(json_message.json_object, ensure_ascii=False, default=str)) + elif response.type is DifyPluginToolInvokeMessage.MessageType.VARIABLE: + continue + else: + parts.append(str(response.message)) + + if json_parts: + existing_parts = set(parts) + parts.extend(part for part in json_parts if part not in existing_parts) + return "".join(parts) + + +__all__ = ["DifyPluginToolsDeps", "DifyPluginToolsLayer"] diff --git a/dify-agent/src/dify_agent/layers/execution_context/__init__.py b/dify-agent/src/dify_agent/layers/execution_context/__init__.py new file mode 100644 index 0000000000..daf67ef7db --- /dev/null +++ b/dify-agent/src/dify_agent/layers/execution_context/__init__.py @@ -0,0 +1,18 @@ +"""Client-safe exports for the Dify execution-context layer DTOs. + +Implementation layers live in sibling modules and require server-side runtime +dependencies. Keep this package root import-safe for client code that only +needs to build run requests. 
+""" + +from dify_agent.layers.execution_context.configs import ( + DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + DifyExecutionContextInvokeFrom, + DifyExecutionContextLayerConfig, +) + +__all__ = [ + "DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID", + "DifyExecutionContextInvokeFrom", + "DifyExecutionContextLayerConfig", +] diff --git a/dify-agent/src/dify_agent/layers/execution_context/configs.py b/dify-agent/src/dify_agent/layers/execution_context/configs.py new file mode 100644 index 0000000000..e5eedbba3c --- /dev/null +++ b/dify-agent/src/dify_agent/layers/execution_context/configs.py @@ -0,0 +1,50 @@ +"""Client-safe DTOs for the Dify execution-context Agenton layer. + +This layer carries Dify-owned execution identifiers plus the tenant/user daemon +transport context shared by plugin-backed business layers. The identifiers are +for observability and product correlation only; callers must not treat them as +authorization proof. Server-only plugin-daemon settings are injected by the +runtime provider factory and therefore do not appear in this public schema. 
+""" + +from typing import ClassVar, Final, Literal, TypeAlias + +from pydantic import ConfigDict + +from agenton.layers import LayerConfig + + +DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID: Final[str] = "dify.execution_context" +DifyExecutionContextInvokeFrom: TypeAlias = Literal[ + "workflow_run", + "single_step", + "agent_app", + "babysit", + "fasten", +] + + +class DifyExecutionContextLayerConfig(LayerConfig): + """Public config for Dify execution identity and daemon transport context.""" + + tenant_id: str + user_id: str | None = None + app_id: str | None = None + workflow_id: str | None = None + workflow_run_id: str | None = None + node_id: str | None = None + node_execution_id: str | None = None + conversation_id: str | None = None + agent_id: str | None = None + agent_config_version_id: str | None = None + invoke_from: DifyExecutionContextInvokeFrom + trace_id: str | None = None + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + +__all__ = [ + "DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID", + "DifyExecutionContextInvokeFrom", + "DifyExecutionContextLayerConfig", +] diff --git a/dify-agent/src/dify_agent/layers/execution_context/layer.py b/dify-agent/src/dify_agent/layers/execution_context/layer.py new file mode 100644 index 0000000000..06ef41ecf4 --- /dev/null +++ b/dify-agent/src/dify_agent/layers/execution_context/layer.py @@ -0,0 +1,95 @@ +"""Runtime Dify execution-context layer. + +The public config carries Dify-owned execution identifiers plus the tenant/user +daemon context needed by plugin-backed business layers. Server-only daemon URL +and API key are injected by the provider factory. The layer is intentionally +config/settings-only under Agenton's state-only core: it does not open, cache, +close, or snapshot HTTP clients, and its lifecycle hooks remain the inherited +no-op hooks. 
Runtime code passes the FastAPI lifespan-owned shared +``httpx.AsyncClient`` into ``create_daemon_provider`` or ``create_tool_client`` +for each invocation. +""" + +from dataclasses import dataclass +from typing import ClassVar + +import httpx +from typing_extensions import Self, override + +from agenton.layers import EmptyRuntimeState, NoLayerDeps, PlainLayer +from dify_agent.adapters.llm import DifyPluginDaemonProvider +from dify_agent.layers.dify_plugin.tool_client import DifyPluginDaemonToolClient +from dify_agent.layers.execution_context.configs import ( + DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + DifyExecutionContextLayerConfig, +) + + +@dataclass(slots=True) +class DifyExecutionContextLayer(PlainLayer[NoLayerDeps, DifyExecutionContextLayerConfig, EmptyRuntimeState]): + """Layer that carries Dify execution context without owning live resources.""" + + type_id: ClassVar[str] = DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID + + config: DifyExecutionContextLayerConfig + daemon_url: str + daemon_api_key: str + + @classmethod + @override + def from_config(cls, config: DifyExecutionContextLayerConfig) -> Self: + """Reject construction without server-injected daemon settings.""" + del config + raise TypeError( + "DifyExecutionContextLayer requires server-side daemon settings and must use a provider factory." + ) + + @classmethod + def from_config_with_settings( + cls, + config: DifyExecutionContextLayerConfig, + *, + daemon_url: str, + daemon_api_key: str, + ) -> Self: + """Create the layer from public config plus server-only daemon settings.""" + return cls(config=config, daemon_url=daemon_url, daemon_api_key=daemon_api_key) + + def create_daemon_provider(self, *, plugin_id: str, http_client: httpx.AsyncClient) -> DifyPluginDaemonProvider: + """Return a daemon provider backed by the shared plugin daemon client. + + Raises: + RuntimeError: if ``http_client`` has already been closed. 
+ """ + if http_client.is_closed: + raise RuntimeError( + "DifyExecutionContextLayer.create_daemon_provider() requires an open shared HTTP client." + ) + return DifyPluginDaemonProvider( + tenant_id=self.config.tenant_id, + plugin_id=plugin_id, + plugin_daemon_url=self.daemon_url, + plugin_daemon_api_key=self.daemon_api_key, + user_id=self.config.user_id, + http_client=http_client, + ) + + def create_tool_client(self, *, plugin_id: str, http_client: httpx.AsyncClient) -> DifyPluginDaemonToolClient: + """Return a plugin-daemon tool client backed by the shared HTTP client. + + Raises: + RuntimeError: if ``http_client`` has already been closed. + """ + if http_client.is_closed: + raise RuntimeError("DifyExecutionContextLayer.create_tool_client() requires an open shared HTTP client.") + return DifyPluginDaemonToolClient( + tenant_id=self.config.tenant_id, + plugin_id=plugin_id, + plugin_daemon_url=self.daemon_url, + plugin_daemon_api_key=self.daemon_api_key, + user_id=self.config.user_id, + http_client=http_client, + ) + + +__all__ = ["DifyExecutionContextLayer"] diff --git a/dify-agent/src/dify_agent/plugin_daemon_transport.py b/dify-agent/src/dify_agent/plugin_daemon_transport.py new file mode 100644 index 0000000000..dc88c3f01e --- /dev/null +++ b/dify-agent/src/dify_agent/plugin_daemon_transport.py @@ -0,0 +1,72 @@ +"""Shared plugin-daemon transport helpers. + +These helpers define the common request-payload and nested-error semantics used +by Dify Agent's LLM and tools daemon clients so the two transport adapters do +not drift when the daemon protocol evolves. 
+""" + +from __future__ import annotations + +import json +from typing import TypedDict + +from pydantic import BaseModel + + +class PluginDaemonErrorPayload(TypedDict): + """Decoded plugin-daemon error payload.""" + + error_type: str + message: str + + +def to_plugin_daemon_jsonable(value: object) -> object: + """Convert nested request data into JSON-safe daemon payload values.""" + if isinstance(value, BaseModel): + return value.model_dump(mode="json") + if isinstance(value, dict): + return {key: to_plugin_daemon_jsonable(item) for key, item in value.items()} + if isinstance(value, list | tuple): + return [to_plugin_daemon_jsonable(item) for item in value] + return value + + +def decode_plugin_daemon_error_payload(raw_message: str) -> PluginDaemonErrorPayload | None: + """Decode one plugin-daemon JSON error payload if present.""" + try: + parsed = json.loads(raw_message) + except json.JSONDecodeError: + return None + + if not isinstance(parsed, dict): + return None + + error_type = parsed.get("error_type") + message = parsed.get("message") + if not isinstance(error_type, str) or not isinstance(message, str): + return None + return {"error_type": error_type, "message": message} + + +def unwrap_plugin_daemon_error( + *, + error_type: str, + message: str, +) -> PluginDaemonErrorPayload: + """Unwrap nested ``PluginInvokeError`` payloads to their effective error.""" + if error_type == "PluginInvokeError": + nested_error = decode_plugin_daemon_error_payload(message) + if nested_error is not None: + return unwrap_plugin_daemon_error( + error_type=nested_error["error_type"], + message=nested_error["message"], + ) + return {"error_type": error_type, "message": message} + + +__all__ = [ + "PluginDaemonErrorPayload", + "decode_plugin_daemon_error_payload", + "to_plugin_daemon_jsonable", + "unwrap_plugin_daemon_error", +] diff --git a/dify-agent/src/dify_agent/protocol/__init__.py b/dify-agent/src/dify_agent/protocol/__init__.py index d1daba54d6..2e3c959548 100644 --- 
a/dify-agent/src/dify_agent/protocol/__init__.py +++ b/dify-agent/src/dify_agent/protocol/__init__.py @@ -11,8 +11,6 @@ from .schemas import ( CreateRunRequest, CreateRunResponse, EmptyRunEventData, - ExecutionContext, - InvokeFrom, LayerExitSignals, PydanticAIStreamRunEvent, RunCancelledEvent, @@ -46,8 +44,6 @@ __all__ = [ "DIFY_AGENT_MODEL_LAYER_ID", "DIFY_AGENT_OUTPUT_LAYER_ID", "EmptyRunEventData", - "ExecutionContext", - "InvokeFrom", "LayerExitSignals", "PydanticAIStreamRunEvent", "RUN_EVENT_ADAPTER", diff --git a/dify-agent/src/dify_agent/protocol/schemas.py b/dify-agent/src/dify_agent/protocol/schemas.py index 6990ce7f57..9a989976c7 100644 --- a/dify-agent/src/dify_agent/protocol/schemas.py +++ b/dify-agent/src/dify_agent/protocol/schemas.py @@ -47,7 +47,6 @@ DIFY_AGENT_HISTORY_LAYER_ID: Final[str] = "history" DIFY_AGENT_OUTPUT_LAYER_ID: Final[str] = "output" RunStatus = Literal["running", "paused", "succeeded", "failed", "cancelled"] RunPurpose = Literal["workflow_node", "single_step", "agent_app", "babysit", "fasten_preview"] -InvokeFrom = Literal["workflow_run", "single_step", "agent_app", "babysit", "fasten"] RunEventType = Literal[ "run_started", "pydantic_ai_event", @@ -106,29 +105,6 @@ class RunComposition(BaseModel): model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") -class ExecutionContext(BaseModel): - """Dify-owned execution identifiers attached to one Agent backend run. - - The Agent backend stores and replays this context for observability and - product correlation only. It must not use these identifiers as authorization - proof; API backend remains responsible for tenant and user access checks. 
- """ - - tenant_id: str - app_id: str | None = None - workflow_id: str | None = None - workflow_run_id: str | None = None - node_id: str | None = None - node_execution_id: str | None = None - conversation_id: str | None = None - agent_id: str | None = None - agent_config_version_id: str | None = None - invoke_from: InvokeFrom - trace_id: str | None = None - - model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") - - class CreateRunRequest(BaseModel): """Request body for creating one async agent run. @@ -142,11 +118,13 @@ class CreateRunRequest(BaseModel): explicitly request delete for one or more layers. Session snapshots do not preserve output-layer config, so resume requests that rely on structured output must include the same ``output`` layer in ``composition.layers[]`` to - keep snapshot compatibility and rebuild the output schema. + keep snapshot compatibility and rebuild the output schema. Dify tenant, + user, and run-correlation identifiers must be submitted through a + ``dify.execution_context`` entry in ``composition.layers[]``; there is no + parallel top-level ``execution_context`` request field. """ composition: RunComposition - execution_context: ExecutionContext | None = None purpose: RunPurpose = "workflow_node" idempotency_key: str | None = None metadata: dict[str, JsonValue] = Field(default_factory=dict) @@ -356,8 +334,6 @@ __all__ = [ "DIFY_AGENT_MODEL_LAYER_ID", "DIFY_AGENT_OUTPUT_LAYER_ID", "EmptyRunEventData", - "ExecutionContext", - "InvokeFrom", "LayerExitSignals", "PydanticAIStreamRunEvent", "RUN_EVENT_ADAPTER", diff --git a/dify-agent/src/dify_agent/runtime/compositor_factory.py b/dify-agent/src/dify_agent/runtime/compositor_factory.py index 8750dbc71d..f3cc3b37b3 100644 --- a/dify-agent/src/dify_agent/runtime/compositor_factory.py +++ b/dify-agent/src/dify_agent/runtime/compositor_factory.py @@ -2,12 +2,18 @@ Only explicitly allowed provider type ids are constructible here. 
The default provider set contains prompt layers, the optional pydantic-ai history layer, the -state-free Dify structured output layer, plus Dify plugin LLM layers. Public -DTOs provide tenant/plugin/model data, while server-only plugin daemon settings -are injected through the provider factory for ``DifyPluginLayer``. The resulting -``Compositor`` remains Agenton state-only: live resources such as the plugin -daemon HTTP client are supplied later by the runtime and never enter providers, -layers, or session snapshots. +state-free Dify structured output layer, the Dify execution-context layer, and +the Dify plugin business-layer family: + +- ``dify.execution_context`` for shared tenant/user/run daemon context, +- ``dify.plugin.llm`` for plugin-backed model selection, and +- ``dify.plugin.tools`` for prepared plugin tool exposure. + +Public DTOs provide Dify context plus plugin/model/tool data, while server-only +plugin daemon settings are injected through the provider factory for +``DifyExecutionContextLayer``. The resulting ``Compositor`` remains Agenton +state-only: live resources such as the plugin daemon HTTP client are supplied +later by the runtime and never enter providers, layers, or session snapshots. 
""" from collections.abc import Mapping, Sequence @@ -20,9 +26,10 @@ from agenton.layers.types import AllPromptTypes, AllToolTypes, AllUserPromptType from agenton_collections.layers.pydantic_ai import PydanticAIHistoryLayer from agenton_collections.layers.plain.basic import PromptLayer from agenton_collections.transformers.pydantic_ai import PYDANTIC_AI_TRANSFORMERS -from dify_agent.layers.dify_plugin.configs import DifyPluginLayerConfig from dify_agent.layers.dify_plugin.llm_layer import DifyPluginLLMLayer -from dify_agent.layers.dify_plugin.plugin_layer import DifyPluginLayer +from dify_agent.layers.dify_plugin.tools_layer import DifyPluginToolsLayer +from dify_agent.layers.execution_context.configs import DifyExecutionContextLayerConfig +from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer from dify_agent.layers.output.output_layer import DifyOutputLayer @@ -40,14 +47,15 @@ def create_default_layer_providers( LayerProvider.from_layer_type(PydanticAIHistoryLayer), LayerProvider.from_layer_type(DifyOutputLayer), LayerProvider.from_factory( - layer_type=DifyPluginLayer, - create=lambda config: DifyPluginLayer.from_config_with_settings( - DifyPluginLayerConfig.model_validate(config), + layer_type=DifyExecutionContextLayer, + create=lambda config: DifyExecutionContextLayer.from_config_with_settings( + DifyExecutionContextLayerConfig.model_validate(config), daemon_url=plugin_daemon_url, daemon_api_key=plugin_daemon_api_key, ), ), LayerProvider.from_layer_type(DifyPluginLLMLayer), + LayerProvider.from_layer_type(DifyPluginToolsLayer), ) diff --git a/dify-agent/src/dify_agent/runtime/run_scheduler.py b/dify-agent/src/dify_agent/runtime/run_scheduler.py index 5d8b229461..9dfc93b846 100644 --- a/dify-agent/src/dify_agent/runtime/run_scheduler.py +++ b/dify-agent/src/dify_agent/runtime/run_scheduler.py @@ -5,12 +5,11 @@ The scheduler is intentionally process-local: it persists a run record, starts a task registry. 
Redis remains the durable source for status and event streams, but there is no Redis job queue or cross-process handoff. If the process crashes, currently active runs are lost until an external operator marks or retries them. -Create-run validation enters a lightweight Agenton run before persistence so the -same transformed user prompts, temporary system-prompt history assembly, -optional structured output contract, and top-level ``on_exit`` policy used by -execution are checked without relying on removed session/control APIs; Dify's -default layers keep lifecycle hooks side-effect free so this validation does not -open plugin daemon clients. +Create-run requests are accepted once the scheduler is not stopping and storage +can persist the run record. Request-shaped execution failures are left to +``AgentRunRunner`` so bad compositions, ``on_exit`` policies, prompts, +structured-output schemas, or session snapshots become asynchronous +``run_failed`` outcomes instead of synchronous HTTP rejections. 
""" import asyncio @@ -21,15 +20,10 @@ from typing import Protocol import httpx from agenton.compositor import LayerProviderInput -from dify_agent.protocol.schemas import CreateRunRequest, normalize_composition -from dify_agent.runtime.agenton_validation import is_agenton_enter_validation_runtime_error -from dify_agent.runtime.compositor_factory import build_pydantic_ai_compositor, create_default_layer_providers +from dify_agent.protocol.schemas import CreateRunRequest +from dify_agent.runtime.compositor_factory import create_default_layer_providers from dify_agent.runtime.event_sink import RunEventSink, emit_run_failed -from dify_agent.runtime.history import build_run_message_history, get_history_layer, validate_history_layer_composition -from dify_agent.runtime.layer_exit_signals import apply_layer_exit_signals, validate_layer_exit_signals -from dify_agent.runtime.output_type import resolve_run_output_contract, validate_output_layer_composition from dify_agent.runtime.runner import AgentRunRunner -from dify_agent.runtime.user_prompt_validation import EMPTY_USER_PROMPTS_ERROR, has_non_blank_user_prompt from dify_agent.server.schemas import RunRecord logger = logging.getLogger(__name__) @@ -39,10 +33,6 @@ class SchedulerStoppingError(RuntimeError): """Raised when a create-run request arrives after shutdown has started.""" -class RunRequestValidationError(ValueError): - """Raised when a create-run request cannot produce an executable Agenton run.""" - - class RunStore(RunEventSink, Protocol): """Persistence boundary needed by the scheduler.""" @@ -68,9 +58,8 @@ class RunScheduler: ``active_tasks`` is mutated only on the event loop that calls ``create_run`` and ``shutdown``. The task registry is not durable; it exists so the lifespan hook can wait for in-flight work and mark cancelled runs failed before Redis is - closed. 
A lock guards the stopping flag, lightweight request validation, run - persistence, and task registration so shutdown cannot begin after a request is - admitted and no validation runs once stopping has been set. + closed. A lock guards the stopping flag, run persistence, and task + registration so shutdown cannot begin after a request is admitted. """ store: RunStore @@ -101,15 +90,16 @@ class RunScheduler: self._lifecycle_lock = asyncio.Lock() async def create_run(self, request: CreateRunRequest) -> RunRecord: - """Validate, persist, and schedule one run in the current process. + """Persist and schedule one run in the current process. The returned record is already ``running``. The background task is removed from ``active_tasks`` when it finishes, regardless of success or failure. + Request-shaped runtime failures are intentionally deferred to the runner so + callers can observe them through the normal event/status stream. """ async with self._lifecycle_lock: if self.stopping: raise SchedulerStoppingError("run scheduler is shutting down") - await validate_run_request(request, layer_providers=self.layer_providers) record = await self.store.create_run() task = asyncio.create_task(self._run_record(record, request), name=f"dify-agent-run-{record.run_id}") self.active_tasks[record.run_id] = task @@ -164,52 +154,4 @@ class RunScheduler: logger.exception("failed to mark cancelled run failed", extra={"run_id": run_id}) -async def validate_run_request( - request: CreateRunRequest, - *, - layer_providers: tuple[LayerProviderInput, ...] | None = None, -) -> None: - """Validate create-run semantics that require an entered Agenton run. - - This boundary rejects unsupported output/history-layer graph shapes, unknown - ``on_exit`` layer ids, effectively empty transformed user prompts, and known - enter-time snapshot lifecycle errors before the scheduler persists a run - record. 
It also exercises provider config validation, temporary - system-prompt history assembly, structured output contract construction, and - snapshot hydration without touching external services because Dify plugin - daemon clients are owned by the FastAPI lifespan, not Agenton lifecycle - hooks. - """ - resolved_layer_providers = layer_providers if layer_providers is not None else create_default_layer_providers() - entered_run = False - try: - validate_output_layer_composition(request.composition) - validate_history_layer_composition(request.composition) - graph_config, layer_configs = normalize_composition(request.composition) - compositor = build_pydantic_ai_compositor( - graph_config, - providers=resolved_layer_providers, - ) - validate_layer_exit_signals(compositor, request.on_exit) - async with compositor.enter(configs=layer_configs, session_snapshot=request.session_snapshot) as run: - entered_run = True - apply_layer_exit_signals(run, request.on_exit) - history_layer = get_history_layer(run) - _ = await build_run_message_history( - system_prompts=run.prompts, - stored_history=history_layer.message_history if history_layer is not None else (), - ) - if not has_non_blank_user_prompt(run.user_prompts): - raise RunRequestValidationError(EMPTY_USER_PROMPTS_ERROR) - _ = resolve_run_output_contract(run) - except RunRequestValidationError: - raise - except RuntimeError as exc: - if not entered_run and is_agenton_enter_validation_runtime_error(exc): - raise RunRequestValidationError(str(exc)) from exc - raise - except (KeyError, TypeError, ValueError) as exc: - raise RunRequestValidationError(str(exc)) from exc - - -__all__ = ["RunRequestValidationError", "RunScheduler", "SchedulerStoppingError", "validate_run_request"] +__all__ = ["RunScheduler", "SchedulerStoppingError"] diff --git a/dify-agent/src/dify_agent/runtime/runner.py b/dify-agent/src/dify_agent/runtime/runner.py index b1e1758928..11e99bb838 100644 --- a/dify-agent/src/dify_agent/runtime/runner.py +++ 
b/dify-agent/src/dify_agent/runtime/runner.py @@ -21,14 +21,17 @@ snapshot; there are no separate output or snapshot events to correlate. """ from collections.abc import AsyncIterable -from typing import cast +from collections import Counter +from typing import Any, cast import httpx from pydantic import JsonValue, TypeAdapter from pydantic_ai.messages import AgentStreamEvent from agenton.compositor import CompositorSessionSnapshot, LayerProviderInput +from agenton.layers.types import PydanticAITool from dify_agent.layers.dify_plugin.llm_layer import DifyPluginLLMLayer +from dify_agent.layers.dify_plugin.tools_layer import DifyPluginToolsLayer from dify_agent.protocol.schemas import DIFY_AGENT_MODEL_LAYER_ID, CreateRunRequest, normalize_composition from dify_agent.runtime.agent_factory import create_agent, normalize_user_input from dify_agent.runtime.agenton_validation import is_agenton_enter_validation_runtime_error @@ -149,12 +152,13 @@ class AgentRunRunner: ) llm_layer = run.get_layer(DIFY_AGENT_MODEL_LAYER_ID, DifyPluginLLMLayer) model = llm_layer.get_model(http_client=self.plugin_daemon_http_client) + tools = await _resolve_run_tools(run, http_client=self.plugin_daemon_http_client) except (KeyError, TypeError, RuntimeError, ValueError) as exc: raise AgentRunValidationError(str(exc)) from exc agent = create_agent( model, - tools=run.tools, + tools=tools, output_type=output_contract.output_type, ) result = await agent.run( @@ -180,4 +184,27 @@ def _serialize_agent_output(output: object) -> JsonValue: return cast(JsonValue, _AGENT_OUTPUT_ADAPTER.dump_python(output, mode="json")) +async def _resolve_run_tools( + run: Any, + *, + http_client: httpx.AsyncClient, +) -> list[PydanticAITool[object]]: + """Return the static compositor tools plus any Dify plugin runtime tools.""" + resolved_tools = list(cast(list[PydanticAITool[object]], run.tools)) + for slot in run.slots.values(): + layer = slot.layer + if isinstance(layer, DifyPluginToolsLayer): + 
resolved_tools.extend(await layer.get_tools(http_client=http_client)) + _validate_unique_tool_names(resolved_tools) + return resolved_tools + + +def _validate_unique_tool_names(tools: list[PydanticAITool[object]]) -> None: + """Reject duplicate tool names across static and dynamic tool sources.""" + duplicate_names = sorted(name for name, count in Counter(tool.name for tool in tools).items() if count > 1) + if duplicate_names: + names = ", ".join(duplicate_names) + raise ValueError(f"Agent run requires unique tool names across all layers, got duplicates: {names}.") + + __all__ = ["AgentRunRunner", "AgentRunValidationError"] diff --git a/dify-agent/src/dify_agent/runtime/user_prompt_validation.py b/dify-agent/src/dify_agent/runtime/user_prompt_validation.py index 8e8602c864..c2cda83acc 100644 --- a/dify-agent/src/dify_agent/runtime/user_prompt_validation.py +++ b/dify-agent/src/dify_agent/runtime/user_prompt_validation.py @@ -1,8 +1,8 @@ """Validation for effective user prompts produced by Agenton runs. -Validation happens after safe compositor construction and run entry so scheduler -and runner paths use the same transformed prompts as the actual pydantic-ai -input. Blank string fragments do not count as meaningful input; non-string +Validation happens after safe compositor construction and run entry so runtime +execution uses the same transformed prompts as the actual pydantic-ai input. +Blank string fragments do not count as meaningful input; non-string ``UserContent`` is treated as intentional content because rich media/message parts do not have a universal whitespace representation. """ diff --git a/dify-agent/src/dify_agent/server/routes/runs.py b/dify-agent/src/dify_agent/server/routes/runs.py index a5dff09218..1cbd9d2094 100644 --- a/dify-agent/src/dify_agent/server/routes/runs.py +++ b/dify-agent/src/dify_agent/server/routes/runs.py @@ -1,10 +1,13 @@ """FastAPI routes for asynchronous agent runs. 
-Controllers translate known validation and shutdown errors into HTTP status codes. -Unexpected scheduler or storage failures are intentionally left for FastAPI's -server-error handling so infrastructure problems are not reported as client input -errors. Created runs are scheduled in the current process and observed through -status polling or SSE replay backed by Redis event streams. +Controllers translate shutdown errors into HTTP status codes. Runtime request +failures are intentionally not pre-mapped here: once a request passes DTO +validation it is accepted for background execution, and bad compositions or +snapshots fail later through normal run events/status. Unexpected scheduler or +storage failures are intentionally left for FastAPI's server-error handling so +infrastructure problems are not reported as client input errors. Created runs +are scheduled in the current process and observed through status polling or SSE +replay backed by Redis event streams. """ from collections.abc import Callable @@ -21,7 +24,7 @@ from dify_agent.protocol.schemas import ( RunEventsResponse, RunStatusResponse, ) -from dify_agent.runtime.run_scheduler import RunRequestValidationError, RunScheduler, SchedulerStoppingError +from dify_agent.runtime.run_scheduler import RunScheduler, SchedulerStoppingError from dify_agent.server.sse import sse_event_stream from dify_agent.storage.redis_run_store import RedisRunStore, RunNotFoundError @@ -46,8 +49,6 @@ def create_runs_router( ) -> CreateRunResponse: try: record = await scheduler.create_run(request) - except RunRequestValidationError as exc: - raise HTTPException(status_code=422, detail=str(exc)) from exc except SchedulerStoppingError as exc: raise HTTPException(status_code=503, detail="run scheduler is shutting down") from exc return CreateRunResponse(run_id=record.run_id, status=record.status) diff --git a/dify-agent/tests/local/dify_agent/layers/dify_plugin/test_configs.py 
b/dify-agent/tests/local/dify_agent/layers/dify_plugin/test_configs.py index f6f84772ba..a3db61a06c 100644 --- a/dify-agent/tests/local/dify_agent/layers/dify_plugin/test_configs.py +++ b/dify-agent/tests/local/dify_agent/layers/dify_plugin/test_configs.py @@ -1,3 +1,5 @@ +from types import SimpleNamespace + import pytest from pydantic import ValidationError @@ -5,55 +7,54 @@ import dify_agent.layers.dify_plugin as dify_plugin_exports from dify_agent.layers.dify_plugin import ( DifyPluginCredentialValue, DifyPluginLLMLayerConfig, - DifyPluginLayerConfig, + DifyPluginToolCredentialType, + DifyPluginToolConfig, + DifyPluginToolParameter, + DifyPluginToolParameterForm, + DifyPluginToolParameterType, + DifyPluginToolsLayerConfig, + DifyPluginToolValue, ) def test_dify_plugin_package_exports_client_safe_config_symbols_only() -> None: assert dify_plugin_exports.__all__ == [ - "DIFY_PLUGIN_LAYER_TYPE_ID", "DIFY_PLUGIN_LLM_LAYER_TYPE_ID", + "DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID", "DifyPluginCredentialValue", "DifyPluginLLMLayerConfig", - "DifyPluginLayerConfig", + "DifyPluginToolCredentialType", + "DifyPluginToolConfig", + "DifyPluginToolOption", + "DifyPluginToolParameter", + "DifyPluginToolParameterForm", + "DifyPluginToolParameterType", + "DifyPluginToolsLayerConfig", + "DifyPluginToolValue", ] - assert dify_plugin_exports.DIFY_PLUGIN_LAYER_TYPE_ID == "dify.plugin" assert dify_plugin_exports.DIFY_PLUGIN_LLM_LAYER_TYPE_ID == "dify.plugin.llm" - assert not hasattr(dify_plugin_exports, "DifyPluginLayer") + assert dify_plugin_exports.DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID == "dify.plugin.tools" assert not hasattr(dify_plugin_exports, "DifyPluginLLMLayer") -def test_dify_plugin_layer_config_forbids_runtime_settings() -> None: - config = DifyPluginLayerConfig(tenant_id="tenant-1", plugin_id="plugin-1", user_id="user-1") - - assert config.tenant_id == "tenant-1" - assert config.plugin_id == "plugin-1" - assert config.user_id == "user-1" - with pytest.raises(ValidationError): - _ = 
DifyPluginLayerConfig.model_validate( - { - "tenant_id": "tenant-1", - "plugin_id": "plugin-1", - "daemon_url": "http://daemon", - } - ) - - def test_dify_plugin_llm_config_accepts_scalar_credentials_and_model_settings() -> None: credential: DifyPluginCredentialValue = "secret" config = DifyPluginLLMLayerConfig( + plugin_id="langgenius/openai", model_provider="openai", model="gpt-4o-mini", credentials={"api_key": credential, "enabled": True, "retries": 2, "ratio": 0.5, "empty": None}, model_settings={"temperature": 0.2, "max_tokens": 64}, ) + assert config.plugin_id == "langgenius/openai" assert config.model_provider == "openai" assert config.credentials == {"api_key": "secret", "enabled": True, "retries": 2, "ratio": 0.5, "empty": None} assert config.model_settings == {"temperature": 0.2, "max_tokens": 64} with pytest.raises(ValidationError): _ = DifyPluginLLMLayerConfig.model_validate( { + "plugin_id": "langgenius/openai", "model_provider": "openai", "model": "gpt-4o-mini", "credentials": {"nested": {"not": "allowed"}}, @@ -66,6 +67,154 @@ def test_dify_plugin_llm_config_rejects_old_provider_field() -> None: _ = DifyPluginLLMLayerConfig.model_validate( { "provider": "openai", + "plugin_id": "langgenius/openai", "model": "gpt-4o-mini", } ) + + +def test_dify_plugin_tools_layer_config_accepts_prepared_parameters_and_schema() -> None: + runtime_value: DifyPluginToolValue = {"locale": "en-US", "max_results": 5} + credential_type: DifyPluginToolCredentialType = "api-key" + config = DifyPluginToolsLayerConfig( + tools=[ + DifyPluginToolConfig( + plugin_id="langgenius/tools", + provider="search", + tool_name="web_search", + credential_type=credential_type, + name="search_web", + description="Search the web.", + credentials={"api_key": "secret"}, + runtime_parameters={"settings": runtime_value}, + parameters=[ + DifyPluginToolParameter( + name="query", + type=DifyPluginToolParameterType.STRING, + form=DifyPluginToolParameterForm.LLM, + required=True, + 
llm_description="Search query", + ) + ], + parameters_json_schema={ + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + }, + "required": ["query"], + }, + ) + ] + ) + + assert config.tools[0].plugin_id == "langgenius/tools" + assert config.tools[0].provider == "search" + assert config.tools[0].tool_name == "web_search" + assert config.tools[0].credential_type == "api-key" + assert config.tools[0].name == "search_web" + assert config.tools[0].runtime_parameters == {"settings": {"locale": "en-US", "max_results": 5}} + assert config.tools[0].parameters[0].name == "query" + assert config.tools[0].parameters_json_schema["required"] == ["query"] + + +def test_dify_plugin_tool_parameter_accepts_api_tool_parameter_dump_shape() -> None: + parameter = DifyPluginToolParameter.model_validate( + { + "name": "query", + "label": {"en_US": "Query"}, + "placeholder": None, + "human_description": {"en_US": "Visible in UI"}, + "type": "select", + "form": "llm", + "required": True, + "default": "dify", + "llm_description": "Search query", + "input_schema": {"type": "string"}, + "options": [ + { + "value": "dify", + "label": {"en_US": "Dify"}, + } + ], + } + ) + + assert parameter.name == "query" + assert parameter.type is DifyPluginToolParameterType.SELECT + assert parameter.form is DifyPluginToolParameterForm.LLM + assert parameter.required is True + assert parameter.default == "dify" + assert parameter.input_schema == {"type": "string"} + assert [option.value for option in parameter.options] == ["dify"] + + +def test_dify_plugin_tool_parameter_accepts_api_tool_parameter_attributes() -> None: + parameter = DifyPluginToolParameter.model_validate( + SimpleNamespace( + name="language", + label=SimpleNamespace(en_US="Language"), + type="string", + form="form", + required=False, + default="en", + llm_description=None, + input_schema=None, + options=[SimpleNamespace(value="en", label=SimpleNamespace(en_US="English"))], + ) + ) + + assert 
parameter.name == "language" + assert parameter.type is DifyPluginToolParameterType.STRING + assert parameter.form is DifyPluginToolParameterForm.FORM + assert parameter.default == "en" + assert [option.value for option in parameter.options] == ["en"] + + +def test_dify_plugin_tool_config_rejects_non_json_runtime_parameters() -> None: + with pytest.raises(ValidationError): + _ = DifyPluginToolConfig.model_validate( + { + "plugin_id": "langgenius/tools", + "provider": "search", + "tool_name": "web_search", + "credential_type": "api-key", + "runtime_parameters": {"bad": object()}, + } + ) + + +def test_dify_plugin_tool_config_rejects_non_json_schema_values() -> None: + with pytest.raises(ValidationError): + _ = DifyPluginToolConfig.model_validate( + { + "plugin_id": "langgenius/tools", + "provider": "search", + "tool_name": "web_search", + "credential_type": "api-key", + "parameters_json_schema": {"type": object()}, + } + ) + + +def test_dify_plugin_tool_config_rejects_strict_flag() -> None: + with pytest.raises(ValidationError): + _ = DifyPluginToolConfig.model_validate( + { + "plugin_id": "langgenius/tools", + "provider": "search", + "tool_name": "web_search", + "credential_type": "api-key", + "strict": True, + } + ) + + +def test_dify_plugin_tool_config_requires_explicit_credential_type() -> None: + with pytest.raises(ValidationError): + _ = DifyPluginToolConfig.model_validate( + { + "plugin_id": "langgenius/tools", + "provider": "search", + "tool_name": "web_search", + } + ) diff --git a/dify-agent/tests/local/dify_agent/layers/dify_plugin/test_layers.py b/dify-agent/tests/local/dify_agent/layers/dify_plugin/test_layers.py index 78c833d946..515e187ef3 100644 --- a/dify-agent/tests/local/dify_agent/layers/dify_plugin/test_layers.py +++ b/dify-agent/tests/local/dify_agent/layers/dify_plugin/test_layers.py @@ -1,26 +1,36 @@ import asyncio +import json import httpx import pytest +from pydantic import JsonValue from agenton.compositor import Compositor, LayerNode, 
LayerProvider from dify_agent.adapters.llm import DifyLLMAdapterModel from dify_agent.layers.dify_plugin.configs import ( - DIFY_PLUGIN_LAYER_TYPE_ID, DIFY_PLUGIN_LLM_LAYER_TYPE_ID, + DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID, DifyPluginLLMLayerConfig, - DifyPluginLayerConfig, + DifyPluginToolConfig, + DifyPluginToolOption, + DifyPluginToolParameter, + DifyPluginToolParameterForm, + DifyPluginToolParameterType, + DifyPluginToolsLayerConfig, ) from dify_agent.layers.dify_plugin.llm_layer import DifyPluginLLMLayer -from dify_agent.layers.dify_plugin.plugin_layer import DifyPluginLayer +from dify_agent.layers.dify_plugin.tools_layer import DifyPluginToolsLayer +from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig +from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer -def _plugin_config() -> DifyPluginLayerConfig: - return DifyPluginLayerConfig(tenant_id="tenant-1", plugin_id="langgenius/openai", user_id="user-1") +def _execution_context_config() -> DifyExecutionContextLayerConfig: + return DifyExecutionContextLayerConfig(tenant_id="tenant-1", user_id="user-1", invoke_from="workflow_run") def _llm_config() -> DifyPluginLLMLayerConfig: return DifyPluginLLMLayerConfig( + plugin_id="langgenius/openai", model_provider="openai", model="demo-model", credentials={"api_key": "secret"}, @@ -28,82 +38,192 @@ def _llm_config() -> DifyPluginLLMLayerConfig: ) -def _plugin_layer() -> DifyPluginLayer: - return DifyPluginLayer.from_config_with_settings( - _plugin_config(), - daemon_url="http://plugin-daemon", - daemon_api_key="daemon-secret", +def _tools_config() -> DifyPluginToolsLayerConfig: + return DifyPluginToolsLayerConfig( + tools=[ + DifyPluginToolConfig( + plugin_id="langgenius/tools", + provider="search", + tool_name="web_search", + credential_type="api-key", + description="Search the web.", + credentials={"api_key": "secret"}, + runtime_parameters={"api_version": "2026-01", "auth_scope": "workspace"}, + 
parameters=_prepared_tool_parameters(), + parameters_json_schema=_prepared_tool_schema(), + ) + ] ) -def _plugin_provider() -> LayerProvider[DifyPluginLayer]: +def _missing_hidden_parameter_tools_config() -> DifyPluginToolsLayerConfig: + return DifyPluginToolsLayerConfig( + tools=[ + DifyPluginToolConfig( + plugin_id="langgenius/tools", + provider="search", + tool_name="web_search", + credential_type="api-key", + description="Search the web.", + credentials={"api_key": "secret"}, + runtime_parameters={"api_version": "2026-01"}, + parameters=_prepared_tool_parameters(), + parameters_json_schema=_prepared_tool_schema(), + ) + ] + ) + + +def _execution_context_provider() -> LayerProvider[DifyExecutionContextLayer]: return LayerProvider.from_factory( - layer_type=DifyPluginLayer, - create=lambda config: DifyPluginLayer.from_config_with_settings( - DifyPluginLayerConfig.model_validate(config), + layer_type=DifyExecutionContextLayer, + create=lambda config: DifyExecutionContextLayer.from_config_with_settings( + DifyExecutionContextLayerConfig.model_validate(config), daemon_url="http://plugin-daemon", daemon_api_key="daemon-secret", ), ) +def _prepared_tool_parameters() -> list[DifyPluginToolParameter]: + return [ + DifyPluginToolParameter( + name="query", + type=DifyPluginToolParameterType.STRING, + form=DifyPluginToolParameterForm.LLM, + required=True, + llm_description="Search query", + ), + DifyPluginToolParameter( + name="region", + type=DifyPluginToolParameterType.SELECT, + form=DifyPluginToolParameterForm.LLM, + required=False, + llm_description="Search region", + options=[DifyPluginToolOption(value="global"), DifyPluginToolOption(value="cn")], + ), + DifyPluginToolParameter( + name="api_version", + type=DifyPluginToolParameterType.STRING, + form=DifyPluginToolParameterForm.FORM, + required=True, + llm_description="Hidden API version", + ), + DifyPluginToolParameter( + name="auth_scope", + type=DifyPluginToolParameterType.STRING, + 
form=DifyPluginToolParameterForm.FORM, + required=True, + llm_description="Hidden auth scope", + ), + ] + + +def _prepared_tool_schema() -> dict[str, JsonValue]: + return { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + "region": { + "type": "string", + "description": "Search region", + "enum": ["global", "cn"], + }, + }, + "required": ["query"], + } + + +def _llm_only_parameter(*, name: str, description: str, default: JsonValue = None) -> DifyPluginToolParameter: + return DifyPluginToolParameter( + name=name, + type=DifyPluginToolParameterType.STRING, + form=DifyPluginToolParameterForm.LLM, + required=default is None, + default=default, + llm_description=description, + ) + + +def _invoke_stream_response( + *, + error_payload: dict[str, object] | None = None, + chunked_blob: bool = False, +) -> httpx.Response: + if error_payload is not None: + return httpx.Response(400, json=error_payload) + + if chunked_blob: + stream_payload = "\n".join( + [ + f"data: {json.dumps({'code': 0, 'message': 'ok', 'data': {'type': 'blob_chunk', 'message': {'id': 'blob-1', 'sequence': 0, 'total_length': 11, 'blob': 'aGVsbG8g', 'end': False}}})}", + f"data: {json.dumps({'code': 0, 'message': 'ok', 'data': {'type': 'blob_chunk', 'message': {'id': 'blob-1', 'sequence': 1, 'total_length': 11, 'blob': 'd29ybGQ=', 'end': True}}})}", + "", + ] + ) + return httpx.Response(200, text=stream_payload) + + stream_payload = "\n".join( + [ + f"data: {json.dumps({'code': 0, 'message': 'ok', 'data': {'type': 'text', 'message': {'text': 'found '}}})}", + f"data: {json.dumps({'code': 0, 'message': 'ok', 'data': {'type': 'json', 'message': {'json_object': {'count': 1}}}})}", + "", + ] + ) + return httpx.Response(200, text=stream_payload) + + +def _tool_transport( + *, + invoke_error_payload: dict[str, object] | None = None, + chunked_blob: bool = False, +) -> httpx.MockTransport: + def handler(request: httpx.Request) -> httpx.Response: + if 
request.url.path.endswith("/dispatch/tool/invoke"): + payload = json.loads(request.content.decode("utf-8")) + assert payload["user_id"] == "user-1" + assert payload["data"]["provider"] == "search" + assert payload["data"]["tool"] == "web_search" + assert payload["data"]["credential_type"] == "api-key" + assert payload["data"]["tool_parameters"] == { + "query": "dify", + "region": "global", + "api_version": "2026-01", + "auth_scope": "workspace", + } + return _invoke_stream_response(error_payload=invoke_error_payload, chunked_blob=chunked_blob) + + raise AssertionError(f"Unexpected request path: {request.url.path}") + + return httpx.MockTransport(handler) + + def test_dify_plugin_type_id_constants_match_implementation_classes() -> None: - assert DIFY_PLUGIN_LAYER_TYPE_ID == DifyPluginLayer.type_id assert DIFY_PLUGIN_LLM_LAYER_TYPE_ID == DifyPluginLLMLayer.type_id - - -def test_dify_plugin_layer_creates_daemon_provider_from_shared_http_client() -> None: - async def scenario() -> None: - plugin = _plugin_layer() - async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _request: httpx.Response(200))) as client: - provider = plugin.create_daemon_provider(http_client=client) - - assert provider.name == "DifyPlugin/langgenius/openai" - assert provider.client.http_client is client - assert provider.client.tenant_id == "tenant-1" - assert provider.client.plugin_id == "langgenius/openai" - assert provider.client.user_id == "user-1" - - async with provider: - pass - assert client.is_closed is False - - asyncio.run(scenario()) - - -def test_dify_plugin_layer_rejects_closed_shared_http_client() -> None: - async def scenario() -> None: - plugin = _plugin_layer() - client = httpx.AsyncClient() - await client.aclose() - - with pytest.raises(RuntimeError, match="open shared HTTP client"): - _ = plugin.create_daemon_provider(http_client=client) - - asyncio.run(scenario()) + assert DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID == DifyPluginToolsLayer.type_id def 
test_dify_plugin_llm_layer_builds_adapter_model_from_direct_dependency() -> None: async def scenario() -> None: compositor = Compositor( [ - LayerNode("renamed-plugin", _plugin_provider()), - LayerNode("llm", DifyPluginLLMLayer, deps={"plugin": "renamed-plugin"}), + LayerNode("renamed-execution-context", _execution_context_provider()), + LayerNode("llm", DifyPluginLLMLayer, deps={"execution_context": "renamed-execution-context"}), ] ) async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _request: httpx.Response(200))) as client: async with compositor.enter( configs={ - "renamed-plugin": _plugin_config(), + "renamed-execution-context": _execution_context_config(), "llm": _llm_config(), } ) as run: - plugin = run.get_layer("renamed-plugin", DifyPluginLayer) + execution_context = run.get_layer("renamed-execution-context", DifyExecutionContextLayer) llm = run.get_layer("llm", DifyPluginLLMLayer) model = llm.get_model(http_client=client) - assert llm.deps.plugin is plugin + assert llm.deps.execution_context is execution_context assert isinstance(model, DifyLLMAdapterModel) assert model.model_name == "demo-model" assert model.model_provider == "openai" @@ -114,17 +234,436 @@ def test_dify_plugin_llm_layer_builds_adapter_model_from_direct_dependency() -> asyncio.run(scenario()) -def test_dify_plugin_layer_lifecycle_does_not_manage_http_client() -> None: +def test_dify_plugin_tools_layer_uses_prepared_tool_definition_and_invokes_daemon() -> None: async def scenario() -> None: - compositor = Compositor([LayerNode("plugin", _plugin_provider())]) - async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _request: httpx.Response(200))) as client: - async with compositor.enter(configs={"plugin": _plugin_config()}) as run: - plugin = run.get_layer("plugin", DifyPluginLayer) - provider = plugin.create_daemon_provider(http_client=client) - run.suspend_layer_on_exit("plugin") + compositor = Compositor( + [ + LayerNode("execution_context", 
_execution_context_provider()), + LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient(transport=_tool_transport()) as client: + async with compositor.enter( + configs={"execution_context": _execution_context_config(), "tools": _tools_config()} + ) as run: + tools_layer = run.get_layer("tools", DifyPluginToolsLayer) + tool = (await tools_layer.get_tools(http_client=client))[0] - assert run.session_snapshot is not None - assert provider.client.http_client is client - assert client.is_closed is False + tool_def = await tool.prepare_tool_def(None) # pyright: ignore[reportArgumentType] + result = await tool.function_schema.call( + {"query": "dify", "region": "global"}, + None, # pyright: ignore[reportArgumentType] + ) + + assert tool.name == "web_search" + assert tool.description == "Search the web." + assert tool_def is not None + assert tool_def.parameters_json_schema == _prepared_tool_schema() + assert tool_def.strict is False + assert result == 'found {"count": 1}' + + asyncio.run(scenario()) + + +def test_dify_plugin_tools_layer_uses_each_tool_plugin_id_for_transport() -> None: + async def scenario() -> None: + seen_requests: list[tuple[str, str, str, str]] = [] + + def handler(request: httpx.Request) -> httpx.Response: + if request.url.path.endswith("/dispatch/tool/invoke"): + payload = json.loads(request.content.decode("utf-8")) + seen_requests.append( + ( + request.headers["X-Plugin-ID"], + payload["user_id"], + payload["data"]["provider"], + payload["data"]["tool"], + ) + ) + return _invoke_stream_response() + + raise AssertionError(f"Unexpected request path: {request.url.path}") + + tools_config = DifyPluginToolsLayerConfig( + tools=[ + DifyPluginToolConfig( + plugin_id="langgenius/tools-a", + provider="search-a", + tool_name="web_search_a", + credential_type="api-key", + parameters=[_llm_only_parameter(name="query", description="Search query A")], + parameters_json_schema={ + 
"type": "object", + "properties": {"query": {"type": "string", "description": "Search query A"}}, + "required": ["query"], + }, + ), + DifyPluginToolConfig( + plugin_id="langgenius/tools-b", + provider="search-b", + tool_name="web_search_b", + credential_type="api-key", + parameters=[_llm_only_parameter(name="query", description="Search query B")], + parameters_json_schema={ + "type": "object", + "properties": {"query": {"type": "string", "description": "Search query B"}}, + "required": ["query"], + }, + ), + ] + ) + + compositor = Compositor( + [ + LayerNode("execution_context", _execution_context_provider()), + LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: + async with compositor.enter( + configs={"execution_context": _execution_context_config(), "tools": tools_config} + ) as run: + tools = await run.get_layer("tools", DifyPluginToolsLayer).get_tools(http_client=client) + + await tools[0].function_schema.call({"query": "first"}, None) # pyright: ignore[reportArgumentType] + await tools[1].function_schema.call({"query": "second"}, None) # pyright: ignore[reportArgumentType] + + assert seen_requests == [ + ("langgenius/tools-a", "user-1", "search-a", "web_search_a"), + ("langgenius/tools-b", "user-1", "search-b", "web_search_b"), + ] + + asyncio.run(scenario()) + + +def test_dify_plugin_tools_layer_casts_prepared_parameter_values_before_invocation() -> None: + async def scenario() -> None: + def handler(request: httpx.Request) -> httpx.Response: + if request.url.path.endswith("/dispatch/tool/invoke"): + payload = json.loads(request.content.decode("utf-8")) + assert payload["user_id"] == "user-1" + assert payload["data"]["tool_parameters"] == { + "enabled": True, + "count": 7, + "tags": ["a", "b"], + "metadata": {"source": "docs"}, + "model": {"provider": "openai", "model": "gpt-4o-mini"}, + } + return _invoke_stream_response() + + 
raise AssertionError(f"Unexpected request path: {request.url.path}") + + tools_config = DifyPluginToolsLayerConfig( + tools=[ + DifyPluginToolConfig( + plugin_id="langgenius/tools", + provider="search", + tool_name="web_search", + credential_type="api-key", + parameters=[ + DifyPluginToolParameter( + name="enabled", + type=DifyPluginToolParameterType.BOOLEAN, + form=DifyPluginToolParameterForm.LLM, + required=True, + llm_description="Enable search", + ), + DifyPluginToolParameter( + name="count", + type=DifyPluginToolParameterType.NUMBER, + form=DifyPluginToolParameterForm.LLM, + required=True, + llm_description="Result count", + ), + DifyPluginToolParameter( + name="tags", + type=DifyPluginToolParameterType.ARRAY, + form=DifyPluginToolParameterForm.LLM, + required=True, + llm_description="Tags", + input_schema={"type": "array", "items": {"type": "string"}}, + ), + DifyPluginToolParameter( + name="metadata", + type=DifyPluginToolParameterType.OBJECT, + form=DifyPluginToolParameterForm.LLM, + required=True, + llm_description="Metadata", + input_schema={"type": "object", "additionalProperties": True}, + ), + DifyPluginToolParameter( + name="model", + type=DifyPluginToolParameterType.MODEL_SELECTOR, + form=DifyPluginToolParameterForm.LLM, + required=True, + llm_description="Model selector", + input_schema={"type": "object", "additionalProperties": True}, + ), + ], + parameters_json_schema={ + "type": "object", + "properties": { + "enabled": {"type": "boolean", "description": "Enable search"}, + "count": {"type": "number", "description": "Result count"}, + "tags": {"type": "array", "items": {"type": "string"}, "description": "Tags"}, + "metadata": { + "type": "object", + "additionalProperties": True, + "description": "Metadata", + }, + "model": { + "type": "object", + "additionalProperties": True, + "description": "Model selector", + }, + }, + "required": ["enabled", "count", "tags", "metadata", "model"], + }, + ) + ] + ) + + compositor = Compositor( + [ + 
LayerNode("execution_context", _execution_context_provider()), + LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: + async with compositor.enter( + configs={"execution_context": _execution_context_config(), "tools": tools_config} + ) as run: + tool = (await run.get_layer("tools", DifyPluginToolsLayer).get_tools(http_client=client))[0] + + result = await tool.function_schema.call( + { + "enabled": "yes", + "count": "7", + "tags": '["a", "b"]', + "metadata": '{"source": "docs"}', + "model": {"provider": "openai", "model": "gpt-4o-mini"}, + }, + None, # pyright: ignore[reportArgumentType] + ) + + assert result == 'found {"count": 1}' + + asyncio.run(scenario()) + + +def test_dify_plugin_tools_layer_sends_prepared_parameter_defaults_to_daemon() -> None: + async def scenario() -> None: + def handler(request: httpx.Request) -> httpx.Response: + if request.url.path.endswith("/dispatch/tool/invoke"): + payload = json.loads(request.content.decode("utf-8")) + assert payload["data"]["tool_parameters"] == { + "query": "dify", + "region": "global", + } + return _invoke_stream_response() + + raise AssertionError(f"Unexpected request path: {request.url.path}") + + tools_config = DifyPluginToolsLayerConfig( + tools=[ + DifyPluginToolConfig( + plugin_id="langgenius/tools", + provider="search", + tool_name="web_search", + credential_type="api-key", + parameters=[ + _llm_only_parameter(name="query", description="Search query"), + _llm_only_parameter(name="region", description="Search region", default="global"), + ], + parameters_json_schema={ + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + "region": {"type": "string", "description": "Search region"}, + }, + "required": ["query"], + }, + ) + ] + ) + + compositor = Compositor( + [ + LayerNode("execution_context", _execution_context_provider()), + 
LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: + async with compositor.enter( + configs={"execution_context": _execution_context_config(), "tools": tools_config} + ) as run: + tool = (await run.get_layer("tools", DifyPluginToolsLayer).get_tools(http_client=client))[0] + + result = await tool.function_schema.call( + {"query": "dify"}, + None, # pyright: ignore[reportArgumentType] + ) + + assert result == 'found {"count": 1}' + + asyncio.run(scenario()) + + +def test_dify_plugin_tools_layer_requires_hidden_runtime_parameters_in_prepared_config() -> None: + async def scenario() -> None: + compositor = Compositor( + [ + LayerNode("execution_context", _execution_context_provider()), + LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient(transport=_tool_transport()) as client: + async with compositor.enter( + configs={ + "execution_context": _execution_context_config(), + "tools": _missing_hidden_parameter_tools_config(), + } + ) as run: + with pytest.raises(ValueError, match="requires non-LLM runtime_parameters for: auth_scope"): + await run.get_layer("tools", DifyPluginToolsLayer).get_tools(http_client=client) + + asyncio.run(scenario()) + + +def test_dify_plugin_tools_layer_returns_agent_friendly_error_text() -> None: + async def scenario() -> None: + compositor = Compositor( + [ + LayerNode("execution_context", _execution_context_provider()), + LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient( + transport=_tool_transport( + invoke_error_payload={ + "error_type": "PluginDaemonBadRequestError", + "message": "missing query", + } + ) + ) as client: + async with compositor.enter( + configs={"execution_context": _execution_context_config(), "tools": _tools_config()} + ) as run: + tool = 
(await run.get_layer("tools", DifyPluginToolsLayer).get_tools(http_client=client))[0] + result = await tool.function_schema.call( + {"query": "dify", "region": "global"}, + None, # pyright: ignore[reportArgumentType] + ) + + assert result == "tool parameters validation error: missing query, please check your tool parameters" + + asyncio.run(scenario()) + + +def test_dify_plugin_tools_layer_propagates_unexpected_transport_errors() -> None: + async def scenario() -> None: + compositor = Compositor( + [ + LayerNode("execution_context", _execution_context_provider()), + LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}), + ] + ) + + def handler(request: httpx.Request) -> httpx.Response: + if request.url.path.endswith("/dispatch/tool/invoke"): + raise RuntimeError("unexpected transport failure") + + raise AssertionError(f"Unexpected request path: {request.url.path}") + + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: + async with compositor.enter( + configs={"execution_context": _execution_context_config(), "tools": _tools_config()} + ) as run: + tool = (await run.get_layer("tools", DifyPluginToolsLayer).get_tools(http_client=client))[0] + + with pytest.raises(RuntimeError, match="unexpected transport failure"): + await tool.function_schema.call( + {"query": "dify", "region": "global"}, + None, # pyright: ignore[reportArgumentType] + ) + + asyncio.run(scenario()) + + +@pytest.mark.parametrize( + ("invoke_error_payload", "expected_text"), + [ + ( + { + "error_type": "PluginInvokeError", + "message": json.dumps( + { + "error_type": "PluginDaemonUnauthorizedError", + "message": "invalid api key", + } + ), + }, + "Please check your tool provider credentials", + ), + ( + { + "error_type": "PluginInvokeError", + "message": json.dumps( + { + "error_type": "ToolNotFoundError", + "message": "missing plugin tool", + } + ), + }, + "there is not a tool named web_search", + ), + ], +) +def 
test_dify_plugin_tools_layer_maps_nested_plugin_invoke_errors_to_agent_text( + invoke_error_payload: dict[str, object], + expected_text: str, +) -> None: + async def scenario() -> None: + compositor = Compositor( + [ + LayerNode("execution_context", _execution_context_provider()), + LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient(transport=_tool_transport(invoke_error_payload=invoke_error_payload)) as client: + async with compositor.enter( + configs={"execution_context": _execution_context_config(), "tools": _tools_config()} + ) as run: + tool = (await run.get_layer("tools", DifyPluginToolsLayer).get_tools(http_client=client))[0] + result = await tool.function_schema.call( + {"query": "dify", "region": "global"}, + None, # pyright: ignore[reportArgumentType] + ) + + assert result == expected_text + + asyncio.run(scenario()) + + +def test_dify_plugin_tools_layer_merges_blob_chunks_before_observation_conversion() -> None: + async def scenario() -> None: + compositor = Compositor( + [ + LayerNode("execution_context", _execution_context_provider()), + LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient(transport=_tool_transport(chunked_blob=True)) as client: + async with compositor.enter( + configs={"execution_context": _execution_context_config(), "tools": _tools_config()} + ) as run: + tool = (await run.get_layer("tools", DifyPluginToolsLayer).get_tools(http_client=client))[0] + result = await tool.function_schema.call( + {"query": "dify", "region": "global"}, + None, # pyright: ignore[reportArgumentType] + ) + + assert "hello world" in result + assert "sequence=0" not in result asyncio.run(scenario()) diff --git a/dify-agent/tests/local/dify_agent/layers/execution_context/test_configs.py b/dify-agent/tests/local/dify_agent/layers/execution_context/test_configs.py new file mode 100644 index 
0000000000..691a483b65 --- /dev/null +++ b/dify-agent/tests/local/dify_agent/layers/execution_context/test_configs.py @@ -0,0 +1,47 @@ +import pytest +from pydantic import ValidationError + +import dify_agent.layers.execution_context as execution_context_exports +from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig + + +def test_execution_context_package_exports_client_safe_config_symbols_only() -> None: + assert execution_context_exports.__all__ == [ + "DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID", + "DifyExecutionContextInvokeFrom", + "DifyExecutionContextLayerConfig", + ] + assert execution_context_exports.DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID == "dify.execution_context" + assert not hasattr(execution_context_exports, "DifyExecutionContextLayer") + + +def test_execution_context_layer_config_forbids_runtime_settings_and_unknown_fields() -> None: + config = DifyExecutionContextLayerConfig( + tenant_id="tenant-1", + user_id="user-1", + workflow_id="workflow-1", + invoke_from="workflow_run", + ) + + assert config.tenant_id == "tenant-1" + assert config.user_id == "user-1" + assert config.workflow_id == "workflow-1" + assert config.invoke_from == "workflow_run" + + with pytest.raises(ValidationError): + _ = DifyExecutionContextLayerConfig.model_validate( + { + "tenant_id": "tenant-1", + "invoke_from": "workflow_run", + "daemon_url": "http://daemon", + } + ) + + with pytest.raises(ValidationError): + _ = DifyExecutionContextLayerConfig.model_validate( + { + "tenant_id": "tenant-1", + "invoke_from": "workflow_run", + "unknown": "value", + } + ) diff --git a/dify-agent/tests/local/dify_agent/layers/execution_context/test_layer.py b/dify-agent/tests/local/dify_agent/layers/execution_context/test_layer.py new file mode 100644 index 0000000000..6757cf7d7b --- /dev/null +++ b/dify-agent/tests/local/dify_agent/layers/execution_context/test_layer.py @@ -0,0 +1,107 @@ +import asyncio + +import httpx +import pytest + +from dify_agent.layers.execution_context 
import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig +from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer + + +def _execution_context_layer() -> DifyExecutionContextLayer: + return DifyExecutionContextLayer.from_config_with_settings( + DifyExecutionContextLayerConfig(tenant_id="tenant-1", user_id="user-1", invoke_from="workflow_run"), + daemon_url="http://plugin-daemon", + daemon_api_key="daemon-secret", + ) + + +def test_execution_context_type_id_constant_matches_implementation_class() -> None: + assert DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID == DifyExecutionContextLayer.type_id + + +def test_execution_context_layer_creates_daemon_provider_from_shared_http_client() -> None: + async def scenario() -> None: + execution_context = _execution_context_layer() + async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _request: httpx.Response(200))) as client: + provider = execution_context.create_daemon_provider(plugin_id="langgenius/openai", http_client=client) + + assert provider.name == "DifyPlugin/langgenius/openai" + assert provider.client.http_client is client + assert provider.client.tenant_id == "tenant-1" + assert provider.client.plugin_id == "langgenius/openai" + assert provider.client.user_id == "user-1" + + async with provider: + pass + assert client.is_closed is False + + asyncio.run(scenario()) + + +def test_execution_context_layer_creates_tool_client_from_shared_http_client() -> None: + async def scenario() -> None: + execution_context = _execution_context_layer() + async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _request: httpx.Response(200))) as client: + tool_client = execution_context.create_tool_client(plugin_id="langgenius/tools", http_client=client) + + assert tool_client.http_client is client + assert tool_client.tenant_id == "tenant-1" + assert tool_client.user_id == "user-1" + assert tool_client.plugin_id == "langgenius/tools" + assert tool_client.plugin_daemon_url == 
"http://plugin-daemon" + assert tool_client.plugin_daemon_api_key == "daemon-secret" + assert client.is_closed is False + + asyncio.run(scenario()) + + +def test_execution_context_layer_rejects_closed_shared_http_client() -> None: + async def scenario() -> None: + execution_context = _execution_context_layer() + client = httpx.AsyncClient() + await client.aclose() + + with pytest.raises(RuntimeError, match="open shared HTTP client"): + _ = execution_context.create_daemon_provider(plugin_id="langgenius/openai", http_client=client) + with pytest.raises(RuntimeError, match="open shared HTTP client"): + _ = execution_context.create_tool_client(plugin_id="langgenius/tools", http_client=client) + + asyncio.run(scenario()) + + +def test_execution_context_layer_lifecycle_does_not_manage_http_client() -> None: + from agenton.compositor import Compositor, LayerNode, LayerProvider + + provider = LayerProvider.from_factory( + layer_type=DifyExecutionContextLayer, + create=lambda config: DifyExecutionContextLayer.from_config_with_settings( + DifyExecutionContextLayerConfig.model_validate(config), + daemon_url="http://plugin-daemon", + daemon_api_key="daemon-secret", + ), + ) + + async def scenario() -> None: + compositor = Compositor([LayerNode("execution_context", provider)]) + async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _request: httpx.Response(200))) as client: + async with compositor.enter( + configs={ + "execution_context": DifyExecutionContextLayerConfig( + tenant_id="tenant-1", + user_id="user-1", + invoke_from="workflow_run", + ) + } + ) as run: + execution_context = run.get_layer("execution_context", DifyExecutionContextLayer) + daemon_provider = execution_context.create_daemon_provider( + plugin_id="langgenius/openai", + http_client=client, + ) + run.suspend_layer_on_exit("execution_context") + + assert run.session_snapshot is not None + assert daemon_provider.client.http_client is client + assert client.is_closed is False + + 
asyncio.run(scenario()) diff --git a/dify-agent/tests/local/dify_agent/protocol/test_protocol_schemas.py b/dify-agent/tests/local/dify_agent/protocol/test_protocol_schemas.py index ce39511302..e64eb4953f 100644 --- a/dify-agent/tests/local/dify_agent/protocol/test_protocol_schemas.py +++ b/dify-agent/tests/local/dify_agent/protocol/test_protocol_schemas.py @@ -6,13 +6,13 @@ from agenton.compositor import CompositorSessionSnapshot from agenton.layers import ExitIntent from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig import dify_agent.protocol as protocol_exports -from dify_agent.layers.dify_plugin import DIFY_PLUGIN_LAYER_TYPE_ID, DIFY_PLUGIN_LLM_LAYER_TYPE_ID +from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig +from dify_agent.layers.dify_plugin import DIFY_PLUGIN_LLM_LAYER_TYPE_ID, DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig from dify_agent.protocol import DIFY_AGENT_HISTORY_LAYER_ID, DIFY_AGENT_MODEL_LAYER_ID, DIFY_AGENT_OUTPUT_LAYER_ID from dify_agent.protocol.schemas import ( RUN_EVENT_ADAPTER, CreateRunRequest, - ExecutionContext, LayerExitSignals, PydanticAIStreamRunEvent, RunCancelledEvent, @@ -28,7 +28,14 @@ from dify_agent.protocol.schemas import ( RunSucceededEventData, normalize_composition, ) -from dify_agent.layers.dify_plugin.configs import DifyPluginLLMLayerConfig, DifyPluginLayerConfig +from dify_agent.layers.dify_plugin.configs import ( + DifyPluginLLMLayerConfig, + DifyPluginToolConfig, + DifyPluginToolParameter, + DifyPluginToolParameterForm, + DifyPluginToolParameterType, + DifyPluginToolsLayerConfig, +) def test_run_event_adapter_round_trips_typed_variants() -> None: @@ -87,10 +94,23 @@ def test_create_run_request_rejects_old_compositor_payload_and_model_layer_id_is ) +def test_protocol_package_no_longer_exports_execution_context_dto() -> None: + assert not 
hasattr(protocol_exports, "ExecutionContext") + + def test_create_run_request_accepts_dto_first_public_composition_and_normalizes_graph_config() -> None: prompt_config = PromptLayerConfig(prefix="system", user="hello") - plugin_config = DifyPluginLayerConfig(tenant_id="tenant-1", plugin_id="langgenius/openai") + execution_context_config = DifyExecutionContextLayerConfig( + tenant_id="tenant-1", + workflow_id="workflow-1", + workflow_run_id="workflow-run-1", + node_id="node-1", + node_execution_id="node-execution-1", + invoke_from="workflow_run", + trace_id="trace-1", + ) llm_config = DifyPluginLLMLayerConfig( + plugin_id="langgenius/openai", model_provider="openai", model="demo-model", credentials={"api_key": "secret"}, @@ -104,26 +124,21 @@ def test_create_run_request_accepts_dto_first_public_composition_and_normalizes_ } ) request = CreateRunRequest( - execution_context=ExecutionContext( - tenant_id="tenant-1", - workflow_id="workflow-1", - workflow_run_id="workflow-run-1", - node_id="node-1", - node_execution_id="node-execution-1", - invoke_from="workflow_run", - trace_id="trace-1", - ), purpose="workflow_node", idempotency_key="workflow-run-1:node-execution-1", metadata={"source": "unit_test"}, composition=RunComposition( layers=[ RunLayerSpec(name="prompt", type=PLAIN_PROMPT_LAYER_TYPE_ID, config=prompt_config), - RunLayerSpec(name="plugin", type=DIFY_PLUGIN_LAYER_TYPE_ID, config=plugin_config), + RunLayerSpec( + name="execution_context", + type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + config=execution_context_config, + ), RunLayerSpec( name=DIFY_AGENT_MODEL_LAYER_ID, type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID, - deps={"plugin": "plugin"}, + deps={"execution_context": "execution_context"}, config=llm_config, ), RunLayerSpec( @@ -138,8 +153,9 @@ def test_create_run_request_accepts_dto_first_public_composition_and_normalizes_ graph_config, layer_configs = normalize_composition(request.composition) payload = request.model_dump(mode="json") - assert 
payload["execution_context"] == { + assert payload["composition"]["layers"][1]["config"] == { "tenant_id": "tenant-1", + "user_id": None, "app_id": None, "workflow_id": "workflow-1", "workflow_run_id": "workflow-run-1", @@ -157,11 +173,16 @@ def test_create_run_request_accepts_dto_first_public_composition_and_normalizes_ assert payload["composition"]["layers"][0]["config"] == {"prefix": "system", "user": "hello", "suffix": []} assert [layer.model_dump(mode="json") for layer in graph_config.layers] == [ {"name": "prompt", "type": PLAIN_PROMPT_LAYER_TYPE_ID, "deps": {}, "metadata": {}}, - {"name": "plugin", "type": DIFY_PLUGIN_LAYER_TYPE_ID, "deps": {}, "metadata": {}}, + { + "name": "execution_context", + "type": DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + "deps": {}, + "metadata": {}, + }, { "name": DIFY_AGENT_MODEL_LAYER_ID, "type": DIFY_PLUGIN_LLM_LAYER_TYPE_ID, - "deps": {"plugin": "plugin"}, + "deps": {"execution_context": "execution_context"}, "metadata": {}, }, { @@ -173,12 +194,118 @@ def test_create_run_request_accepts_dto_first_public_composition_and_normalizes_ ] assert layer_configs == { "prompt": prompt_config, - "plugin": plugin_config, + "execution_context": execution_context_config, DIFY_AGENT_MODEL_LAYER_ID: llm_config, DIFY_AGENT_OUTPUT_LAYER_ID: output_config, } +def test_create_run_request_accepts_plugin_tools_layer_with_prepared_parameters_and_schema() -> None: + request = CreateRunRequest.model_validate( + { + "composition": { + "layers": [ + {"name": "prompt", "type": PLAIN_PROMPT_LAYER_TYPE_ID, "config": {"user": "hello"}}, + { + "name": "execution_context", + "type": DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + "config": {"tenant_id": "tenant-1", "invoke_from": "workflow_run"}, + }, + { + "name": DIFY_AGENT_MODEL_LAYER_ID, + "type": DIFY_PLUGIN_LLM_LAYER_TYPE_ID, + "deps": {"execution_context": "execution_context"}, + "config": { + "plugin_id": "langgenius/openai", + "model_provider": "openai", + "model": "demo-model", + "credentials": {"api_key": 
"secret"}, + }, + }, + { + "name": "tools", + "type": DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID, + "deps": {"execution_context": "execution_context"}, + "config": { + "tools": [ + { + "plugin_id": "langgenius/search", + "provider": "search", + "tool_name": "web_search", + "credential_type": "api-key", + "runtime_parameters": {"site": "docs.dify.ai"}, + "parameters": [ + { + "name": "query", + "type": "string", + "form": "llm", + "required": True, + "llm_description": "Search query", + }, + { + "name": "site", + "type": "string", + "form": "form", + "required": True, + "llm_description": "Hidden site", + }, + ], + "parameters_json_schema": { + "type": "object", + "properties": {"query": {"type": "string", "description": "Search query"}}, + "required": ["query"], + }, + } + ] + }, + }, + ] + } + } + ) + + graph_config, layer_configs = normalize_composition(request.composition) + + assert [layer.type for layer in graph_config.layers] == [ + PLAIN_PROMPT_LAYER_TYPE_ID, + DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + DIFY_PLUGIN_LLM_LAYER_TYPE_ID, + DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID, + ] + assert DifyPluginToolsLayerConfig.model_validate(layer_configs["tools"]) == DifyPluginToolsLayerConfig( + tools=[ + DifyPluginToolConfig( + plugin_id="langgenius/search", + provider="search", + tool_name="web_search", + credential_type="api-key", + runtime_parameters={"site": "docs.dify.ai"}, + parameters=[ + DifyPluginToolParameter( + name="query", + type=DifyPluginToolParameterType.STRING, + form=DifyPluginToolParameterForm.LLM, + required=True, + llm_description="Search query", + ), + DifyPluginToolParameter( + name="site", + type=DifyPluginToolParameterType.STRING, + form=DifyPluginToolParameterForm.FORM, + required=True, + llm_description="Hidden site", + ), + ], + parameters_json_schema={ + "type": "object", + "properties": {"query": {"type": "string", "description": "Search query"}}, + "required": ["query"], + }, + ) + ] + ) + + def test_on_exit_default_to_suspend_and_are_public() -> None: 
assert protocol_exports.LayerExitSignals is LayerExitSignals assert protocol_exports.RunComposition is RunComposition @@ -206,13 +333,12 @@ def test_on_exit_accept_layer_overrides() -> None: assert request.on_exit.layers == {"prompt": ExitIntent.SUSPEND, "llm": ExitIntent.DELETE} -def test_execution_context_rejects_unknown_fields() -> None: +def test_create_run_request_rejects_removed_top_level_execution_context() -> None: with pytest.raises(ValidationError): - _ = ExecutionContext.model_validate( + _ = CreateRunRequest.model_validate( { - "tenant_id": "tenant-1", - "invoke_from": "workflow_run", - "unknown": "value", + "composition": {"layers": []}, + "execution_context": {"tenant_id": "tenant-1", "invoke_from": "workflow_run"}, } ) diff --git a/dify-agent/tests/local/dify_agent/runtime/test_run_scheduler.py b/dify-agent/tests/local/dify_agent/runtime/test_run_scheduler.py index d5fce5adb1..a4a5ad8429 100644 --- a/dify-agent/tests/local/dify_agent/runtime/test_run_scheduler.py +++ b/dify-agent/tests/local/dify_agent/runtime/test_run_scheduler.py @@ -6,25 +6,18 @@ import httpx import pytest from agenton.compositor import CompositorSessionSnapshot, LayerSessionSnapshot -from agenton.layers import ExitIntent, LifecycleState -from agenton_collections.layers.pydantic_ai import PYDANTIC_AI_HISTORY_LAYER_TYPE_ID +from agenton.layers import LifecycleState from agenton_collections.layers.plain import PromptLayerConfig from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig -from dify_agent.protocol import DIFY_AGENT_HISTORY_LAYER_ID, DIFY_AGENT_OUTPUT_LAYER_ID +from dify_agent.protocol import DIFY_AGENT_OUTPUT_LAYER_ID from dify_agent.protocol.schemas import ( CreateRunRequest, - LayerExitSignals, RunComposition, RunEvent, RunLayerSpec, RunStatus, ) -from dify_agent.runtime.run_scheduler import ( - RunRequestValidationError, - RunScheduler, - SchedulerStoppingError, - validate_run_request, -) +from dify_agent.runtime.run_scheduler import 
RunScheduler, SchedulerStoppingError from dify_agent.server.schemas import RunRecord @@ -168,390 +161,64 @@ def test_shutdown_marks_unfinished_runs_failed_and_appends_event() -> None: asyncio.run(scenario()) -def test_create_run_rejects_blank_prompt_before_persisting() -> None: +def test_create_run_accepts_blank_prompt_and_runner_fails_asynchronously() -> None: async def scenario() -> None: store = FakeStore() async with httpx.AsyncClient() as client: scheduler = RunScheduler(store=store, plugin_daemon_http_client=client) - with pytest.raises(ValueError, match="run.user_prompts must not be empty"): - await scheduler.create_run(_request(["", " "])) + record = await scheduler.create_run(_request(["", " "])) + await asyncio.wait_for(scheduler.active_tasks[record.run_id], timeout=1) - assert store.records == {} + assert store.records == {record.run_id: record} + assert [event.type for event in store.events[record.run_id]] == ["run_started", "run_failed"] + assert store.statuses[record.run_id] == "failed" + assert store.errors[record.run_id] == "run.user_prompts must not be empty" asyncio.run(scenario()) -def test_create_run_rejects_invalid_output_schema_before_persisting() -> None: +def test_create_run_accepts_invalid_output_schema_and_runner_fails_asynchronously() -> None: async def scenario() -> None: store = FakeStore() async with httpx.AsyncClient() as client: scheduler = RunScheduler(store=store, plugin_daemon_http_client=client) - with pytest.raises(ValueError, match=r"Recursive \$defs refs are not supported"): - await scheduler.create_run( - _request( - output_config={ - "json_schema": _recursive_output_schema(), - } - ) - ) - - assert store.records == {} - - asyncio.run(scenario()) - - -def test_create_run_rejects_remote_ref_output_schema_before_persisting() -> None: - async def scenario() -> None: - store = FakeStore() - async with httpx.AsyncClient() as client: - scheduler = RunScheduler(store=store, plugin_daemon_http_client=client) - - with 
pytest.raises(ValueError, match=r"Remote \$ref values are not supported"): - await scheduler.create_run( - _request( - output_config={ - "json_schema": { - "type": "object", - "properties": { - "title": {"$ref": "https://example.com/schema.json"}, - }, - }, - } - ) - ) - - assert store.records == {} - - asyncio.run(scenario()) - - -def test_create_run_rejects_non_object_output_schema_before_persisting() -> None: - async def scenario() -> None: - store = FakeStore() - async with httpx.AsyncClient() as client: - scheduler = RunScheduler(store=store, plugin_daemon_http_client=client) - - with pytest.raises(ValueError, match="Schema must declare an object output"): - await scheduler.create_run( - _request( - output_config={ - "json_schema": { - "type": "array", - "items": {"type": "string"}, - }, - } - ) - ) - - assert store.records == {} - - asyncio.run(scenario()) - - -def test_create_run_rejects_public_output_tool_name_override_before_persisting() -> None: - async def scenario() -> None: - store = FakeStore() - async with httpx.AsyncClient() as client: - scheduler = RunScheduler(store=store, plugin_daemon_http_client=client) - - with pytest.raises(ValueError, match="Extra inputs are not permitted"): - await scheduler.create_run( - _request( - output_config={ - "name": "incident_summary", - "json_schema": { - "type": "object", - "properties": {"title": {"type": "string"}}, - "required": ["title"], - "additionalProperties": False, - }, - } - ) - ) - - assert store.records == {} - - asyncio.run(scenario()) - - -def test_create_run_rejects_non_defs_local_ref_in_direct_object_schema_before_persisting() -> None: - async def scenario() -> None: - store = FakeStore() - async with httpx.AsyncClient() as client: - scheduler = RunScheduler(store=store, plugin_daemon_http_client=client) - - with pytest.raises(ValueError, match=r"Only local refs under '#/\$defs/' are supported"): - await scheduler.create_run( - _request( - output_config={ - "json_schema": { - "type": "object", - 
"properties": { - "items": {"$ref": "#/definitions/itemArray"}, - }, - "required": ["items"], - "definitions": { - "itemArray": { - "type": "array", - "items": {"type": "string"}, - }, - }, - }, - } - ) - ) - - assert store.records == {} - - asyncio.run(scenario()) - - -def test_create_run_rejects_misnamed_output_layer_before_persisting() -> None: - async def scenario() -> None: - store = FakeStore() - async with httpx.AsyncClient() as client: - scheduler = RunScheduler(store=store, plugin_daemon_http_client=client) - - request = CreateRunRequest( - composition=RunComposition( - layers=[ - RunLayerSpec(name="prompt", type="plain.prompt", config=PromptLayerConfig(user="hello")), - RunLayerSpec( - name="structured-output", - type=DIFY_OUTPUT_LAYER_TYPE_ID, - config=DifyOutputLayerConfig( - json_schema={ - "type": "object", - "properties": {"title": {"type": "string"}}, - "required": ["title"], - "additionalProperties": False, - } - ), - ), - ] + record = await scheduler.create_run( + _request( + output_config={ + "json_schema": _recursive_output_schema(), + } ) ) + await asyncio.wait_for(scheduler.active_tasks[record.run_id], timeout=1) - with pytest.raises(ValueError, match="must use reserved layer name 'output'"): - await scheduler.create_run(request) - - assert store.records == {} + assert store.records == {record.run_id: record} + assert [event.type for event in store.events[record.run_id]] == ["run_started", "run_failed"] + assert store.statuses[record.run_id] == "failed" + assert "Recursive $defs refs are not supported" in (store.errors[record.run_id] or "") asyncio.run(scenario()) -def test_create_run_rejects_multiple_output_layers_before_persisting() -> None: - async def scenario() -> None: - store = FakeStore() - async with httpx.AsyncClient() as client: - scheduler = RunScheduler(store=store, plugin_daemon_http_client=client) - - request = CreateRunRequest( - composition=RunComposition( - layers=[ - RunLayerSpec(name="prompt", type="plain.prompt", 
config=PromptLayerConfig(user="hello")), - RunLayerSpec( - name=DIFY_AGENT_OUTPUT_LAYER_ID, - type=DIFY_OUTPUT_LAYER_TYPE_ID, - config=DifyOutputLayerConfig( - json_schema={ - "type": "object", - "properties": {"title": {"type": "string"}}, - "required": ["title"], - "additionalProperties": False, - } - ), - ), - RunLayerSpec( - name="secondary-output", - type=DIFY_OUTPUT_LAYER_TYPE_ID, - config=DifyOutputLayerConfig( - json_schema={ - "type": "object", - "properties": {"summary": {"type": "string"}}, - "required": ["summary"], - "additionalProperties": False, - } - ), - ), - ] - ) - ) - - with pytest.raises(ValueError, match="Only one 'dify.output' layer is supported"): - await scheduler.create_run(request) - - assert store.records == {} - - asyncio.run(scenario()) - - -def test_create_run_rejects_reserved_output_name_with_wrong_layer_type_before_persisting() -> None: - async def scenario() -> None: - store = FakeStore() - async with httpx.AsyncClient() as client: - scheduler = RunScheduler(store=store, plugin_daemon_http_client=client) - - request = CreateRunRequest( - composition=RunComposition( - layers=[ - RunLayerSpec(name="prompt", type="plain.prompt", config=PromptLayerConfig(user="hello")), - RunLayerSpec( - name=DIFY_AGENT_OUTPUT_LAYER_ID, type="plain.prompt", config=PromptLayerConfig(user="hi") - ), - ] - ) - ) - - with pytest.raises(ValueError, match=r"Layer 'output' must be DifyOutputLayer, got PromptLayer"): - await scheduler.create_run(request) - - assert store.records == {} - - asyncio.run(scenario()) - - -def test_validate_run_request_honors_explicit_empty_layer_providers() -> None: - async def scenario() -> None: - with pytest.raises(RunRequestValidationError, match="plain.prompt"): - await validate_run_request(_request(), layer_providers=()) - - asyncio.run(scenario()) - - -def test_validate_run_request_rejects_misnamed_output_layer_before_provider_checks() -> None: - async def scenario() -> None: - request = CreateRunRequest( - 
composition=RunComposition( - layers=[ - RunLayerSpec(name="prompt", type="plain.prompt", config=PromptLayerConfig(user="hello")), - RunLayerSpec( - name="structured-output", - type=DIFY_OUTPUT_LAYER_TYPE_ID, - config=DifyOutputLayerConfig( - json_schema={ - "type": "object", - "properties": {"title": {"type": "string"}}, - "required": ["title"], - "additionalProperties": False, - } - ), - ), - ] - ) - ) - - with pytest.raises(RunRequestValidationError, match="must use reserved layer name 'output'"): - await validate_run_request(request, layer_providers=()) - - asyncio.run(scenario()) - - -def test_validate_run_request_accepts_reserved_history_layer() -> None: - async def scenario() -> None: - request = CreateRunRequest( - composition=RunComposition( - layers=[ - RunLayerSpec(name="prompt", type="plain.prompt", config=PromptLayerConfig(user="hello")), - RunLayerSpec(name=DIFY_AGENT_HISTORY_LAYER_ID, type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID), - ] - ) - ) - - await validate_run_request(request) - - asyncio.run(scenario()) - - -def test_validate_run_request_rejects_misnamed_history_layer_before_provider_checks() -> None: - async def scenario() -> None: - request = CreateRunRequest( - composition=RunComposition( - layers=[ - RunLayerSpec(name="prompt", type="plain.prompt", config=PromptLayerConfig(user="hello")), - RunLayerSpec(name="chat-history", type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID), - ] - ) - ) - - with pytest.raises(RunRequestValidationError, match="must use reserved layer name 'history'"): - await validate_run_request(request, layer_providers=()) - - asyncio.run(scenario()) - - -def test_validate_run_request_rejects_multiple_history_layers_before_provider_checks() -> None: - async def scenario() -> None: - request = CreateRunRequest( - composition=RunComposition( - layers=[ - RunLayerSpec(name="prompt", type="plain.prompt", config=PromptLayerConfig(user="hello")), - RunLayerSpec(name=DIFY_AGENT_HISTORY_LAYER_ID, type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID), - 
RunLayerSpec(name="secondary-history", type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID), - ] - ) - ) - - with pytest.raises(RunRequestValidationError, match="Only one 'pydantic_ai.history' layer is supported"): - await validate_run_request(request, layer_providers=()) - - asyncio.run(scenario()) - - -def test_validate_run_request_rejects_history_layer_dependencies_before_provider_checks() -> None: - async def scenario() -> None: - request = CreateRunRequest( - composition=RunComposition( - layers=[ - RunLayerSpec(name="prompt", type="plain.prompt", config=PromptLayerConfig(user="hello")), - RunLayerSpec( - name=DIFY_AGENT_HISTORY_LAYER_ID, - type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID, - deps={"prompt": "prompt"}, - ), - ] - ) - ) - - with pytest.raises(RunRequestValidationError, match="does not support dependencies"): - await validate_run_request(request, layer_providers=()) - - asyncio.run(scenario()) - - -def test_create_run_rejects_unknown_layer_exit_signal_before_persisting() -> None: - async def scenario() -> None: - store = FakeStore() - async with httpx.AsyncClient() as client: - scheduler = RunScheduler(store=store, plugin_daemon_http_client=client) - request = _request() - request.on_exit = LayerExitSignals(layers={"missing": ExitIntent.DELETE}) - - with pytest.raises(ValueError, match="missing"): - await scheduler.create_run(request) - - assert store.records == {} - - asyncio.run(scenario()) - - -def test_create_run_honors_explicit_empty_layer_providers_before_persisting() -> None: +def test_create_run_honors_explicit_empty_layer_providers_by_failing_after_persisting() -> None: async def scenario() -> None: store = FakeStore() async with httpx.AsyncClient() as client: scheduler = RunScheduler(store=store, plugin_daemon_http_client=client, layer_providers=()) - with pytest.raises(RunRequestValidationError, match="plain.prompt"): - await scheduler.create_run(_request()) + record = await scheduler.create_run(_request()) + await 
asyncio.wait_for(scheduler.active_tasks[record.run_id], timeout=1) - assert store.records == {} + assert store.records == {record.run_id: record} + assert [event.type for event in store.events[record.run_id]] == ["run_started", "run_failed"] + assert store.statuses[record.run_id] == "failed" + assert "plain.prompt" in (store.errors[record.run_id] or "") asyncio.run(scenario()) -def test_create_run_rejects_closed_session_snapshot_before_persisting() -> None: +def test_create_run_accepts_closed_session_snapshot_and_runner_fails_asynchronously() -> None: async def scenario() -> None: store = FakeStore() async with httpx.AsyncClient() as client: @@ -567,10 +234,13 @@ def test_create_run_rejects_closed_session_snapshot_before_persisting() -> None: ] ) - with pytest.raises(ValueError, match="CLOSED snapshots cannot be entered"): - _ = await scheduler.create_run(request) + record = await scheduler.create_run(request) + await asyncio.wait_for(scheduler.active_tasks[record.run_id], timeout=1) - assert store.records == {} + assert store.records == {record.run_id: record} + assert [event.type for event in store.events[record.run_id]] == ["run_started", "run_failed"] + assert store.statuses[record.run_id] == "failed" + assert "CLOSED snapshots cannot be entered" in (store.errors[record.run_id] or "") asyncio.run(scenario()) diff --git a/dify-agent/tests/local/dify_agent/runtime/test_runner.py b/dify-agent/tests/local/dify_agent/runtime/test_runner.py index ddf860beb6..6683f982a8 100644 --- a/dify-agent/tests/local/dify_agent/runtime/test_runner.py +++ b/dify-agent/tests/local/dify_agent/runtime/test_runner.py @@ -1,9 +1,11 @@ import asyncio from collections.abc import Mapping -from typing import Any +from typing import Any, ClassVar, cast import httpx import pytest +from pydantic import JsonValue +from pydantic_ai import Tool from pydantic_ai.exceptions import UnexpectedModelBehavior from pydantic_ai.messages import ( ModelMessage, @@ -18,12 +20,22 @@ from pydantic_ai.models 
import ModelRequestParameters from pydantic_ai.models.test import TestModel from pydantic_ai.settings import ModelSettings -from agenton.compositor import CompositorSessionSnapshot, LayerSessionSnapshot +from agenton.compositor import CompositorSessionSnapshot, LayerProvider, LayerSessionSnapshot from agenton.layers import ExitIntent, LifecycleState from agenton_collections.layers.pydantic_ai import PYDANTIC_AI_HISTORY_LAYER_TYPE_ID, PydanticAIHistoryRuntimeState -from agenton_collections.layers.plain import PromptLayerConfig -from dify_agent.layers.dify_plugin.configs import DifyPluginLLMLayerConfig, DifyPluginLayerConfig +from agenton_collections.layers.plain import PromptLayerConfig, ToolsLayer +from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig +from dify_agent.layers.dify_plugin.configs import ( + DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID, + DifyPluginLLMLayerConfig, + DifyPluginToolConfig, + DifyPluginToolParameter, + DifyPluginToolParameterForm, + DifyPluginToolParameterType, + DifyPluginToolsLayerConfig, +) from dify_agent.layers.dify_plugin.llm_layer import DifyPluginLLMLayer +from dify_agent.layers.dify_plugin.tools_layer import DifyPluginToolsLayer from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig from dify_agent.protocol import DIFY_AGENT_HISTORY_LAYER_ID, DIFY_AGENT_MODEL_LAYER_ID, DIFY_AGENT_OUTPUT_LAYER_ID from dify_agent.protocol.schemas import ( @@ -34,15 +46,20 @@ from dify_agent.protocol.schemas import ( RunSucceededEvent, ) from dify_agent.runtime.event_sink import InMemoryRunEventSink +from dify_agent.runtime.compositor_factory import create_default_layer_providers from dify_agent.runtime.runner import AgentRunRunner, AgentRunValidationError +class StaticToolsTestLayer(ToolsLayer): + type_id: ClassVar[str] = "test.static.tools" + + def _request( user: str | list[str] = "hello", *, include_history: bool = False, llm_layer_name: str = 
DIFY_AGENT_MODEL_LAYER_ID, - plugin_layer_name: str = "plugin", + execution_context_layer_name: str = "execution_context", on_exit: LayerExitSignals | None = None, output_config: Mapping[str, object] | DifyOutputLayerConfig | None = None, ) -> CreateRunRequest: @@ -58,15 +75,16 @@ def _request( else [] ), RunLayerSpec( - name=plugin_layer_name, - type="dify.plugin", - config=DifyPluginLayerConfig(tenant_id="tenant-1", plugin_id="langgenius/openai"), + name=execution_context_layer_name, + type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + config=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"), ), RunLayerSpec( name=llm_layer_name, type="dify.plugin.llm", - deps={"plugin": plugin_layer_name}, + deps={"execution_context": execution_context_layer_name}, config=DifyPluginLLMLayerConfig( + plugin_id="langgenius/openai", model_provider="openai", model="demo-model", credentials={"api_key": "secret"}, @@ -103,6 +121,35 @@ def _recursive_output_schema() -> dict[str, object]: } +def _prepared_plugin_tool_parameters() -> list[DifyPluginToolParameter]: + return [ + DifyPluginToolParameter( + name="query", + type=DifyPluginToolParameterType.STRING, + form=DifyPluginToolParameterForm.LLM, + required=True, + llm_description="Search query", + ), + DifyPluginToolParameter( + name="auth_scope", + type=DifyPluginToolParameterType.STRING, + form=DifyPluginToolParameterForm.FORM, + required=True, + llm_description="Hidden auth scope", + ), + ] + + +def _prepared_plugin_tool_schema() -> dict[str, JsonValue]: + return { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + }, + "required": ["query"], + } + + class SequenceOutputTestModel(TestModel): outputs: list[str | dict[str, Any] | None] request_count: int @@ -170,7 +217,7 @@ def _history_session_snapshot( lifecycle_state=LifecycleState.SUSPENDED, runtime_state=PydanticAIHistoryRuntimeState(messages=messages).model_dump(mode="json"), ), - 
LayerSessionSnapshot(name="plugin", lifecycle_state=LifecycleState.SUSPENDED, runtime_state={}), + LayerSessionSnapshot(name="execution_context", lifecycle_state=LifecycleState.SUSPENDED, runtime_state={}), LayerSessionSnapshot( name=DIFY_AGENT_MODEL_LAYER_ID, lifecycle_state=LifecycleState.SUSPENDED, runtime_state={} ), @@ -198,12 +245,12 @@ def test_runner_emits_terminal_success_and_snapshot(monkeypatch: pytest.MonkeyPa def fake_get_model(self: DifyPluginLLMLayer, *, http_client: httpx.AsyncClient): assert self.config.model == "demo-model" - assert self.deps.plugin.config.plugin_id == "langgenius/openai" + assert self.config.plugin_id == "langgenius/openai" seen_clients.append(http_client) return TestModel(custom_output_text="done") # pyright: ignore[reportReturnType] monkeypatch.setattr(DifyPluginLLMLayer, "get_model", fake_get_model) - request = _request(plugin_layer_name="renamed-plugin") + request = _request(execution_context_layer_name="renamed-execution-context") sink = InMemoryRunEventSink() async def scenario() -> None: @@ -230,7 +277,7 @@ def test_runner_emits_terminal_success_and_snapshot(monkeypatch: pytest.MonkeyPa assert terminal.data.output == "done" assert [layer.name for layer in terminal.data.session_snapshot.layers] == [ "prompt", - "renamed-plugin", + "renamed-execution-context", DIFY_AGENT_MODEL_LAYER_ID, ] assert [layer.lifecycle_state for layer in terminal.data.session_snapshot.layers] == [ @@ -241,6 +288,315 @@ def test_runner_emits_terminal_success_and_snapshot(monkeypatch: pytest.MonkeyPa assert sink.statuses["run-1"] == "succeeded" +def test_runner_passes_dynamic_dify_plugin_tools_to_agent(monkeypatch: pytest.MonkeyPatch) -> None: + seen_tools: list[Tool[object]] = [] + + async def plugin_tool() -> str: + return "tool" + + def fake_get_model(_self: DifyPluginLLMLayer, *, http_client: httpx.AsyncClient): + assert http_client.is_closed is False + return TestModel(custom_output_text="done") # pyright: ignore[reportReturnType] + + async def 
fake_get_tools(self: DifyPluginToolsLayer, *, http_client: httpx.AsyncClient) -> list[Tool[object]]: + assert self.config.tools[0].tool_name == "web_search" + assert http_client.is_closed is False + return [Tool(plugin_tool, name="web_search")] + + class FakeResult: + output: str = "done" + + def new_messages(self) -> list[ModelMessage]: + return [] + + class FakeAgent: + async def run(self, *_args: object, **_kwargs: object) -> FakeResult: + return FakeResult() + + def fake_create_agent(model: object, *, tools: list[Tool[object]], output_type: object) -> FakeAgent: + del model, output_type + seen_tools.extend(tools) + return FakeAgent() + + monkeypatch.setattr(DifyPluginLLMLayer, "get_model", fake_get_model) + monkeypatch.setattr(DifyPluginToolsLayer, "get_tools", fake_get_tools) + monkeypatch.setattr("dify_agent.runtime.runner.create_agent", fake_create_agent) + + request = CreateRunRequest( + composition=RunComposition( + layers=[ + RunLayerSpec( + name="prompt", + type="plain.prompt", + config=PromptLayerConfig(prefix="system", user="hello"), + ), + RunLayerSpec( + name="execution_context", + type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + config=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"), + ), + RunLayerSpec( + name=DIFY_AGENT_MODEL_LAYER_ID, + type="dify.plugin.llm", + deps={"execution_context": "execution_context"}, + config=DifyPluginLLMLayerConfig( + plugin_id="langgenius/openai", + model_provider="openai", + model="demo-model", + credentials={"api_key": "secret"}, + ), + ), + RunLayerSpec( + name="tools", + type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID, + deps={"execution_context": "execution_context"}, + config=DifyPluginToolsLayerConfig( + tools=[ + DifyPluginToolConfig( + plugin_id="langgenius/tools", + provider="search", + tool_name="web_search", + credential_type="api-key", + parameters=_prepared_plugin_tool_parameters(), + parameters_json_schema=_prepared_plugin_tool_schema(), + ) + ] + ), + ), + ] + ) + ) + sink = 
InMemoryRunEventSink() + + async def scenario() -> None: + async with httpx.AsyncClient() as client: + await AgentRunRunner( + sink=sink, + request=request, + run_id="run-tools", + plugin_daemon_http_client=client, + ).run() + + asyncio.run(scenario()) + + assert [tool.name for tool in seen_tools] == ["web_search"] + terminal = sink.events["run-tools"][-1] + assert isinstance(terminal, RunSucceededEvent) + assert terminal.data.output == "done" + + +def test_runner_rejects_duplicate_tool_names_across_dynamic_tool_layers( + monkeypatch: pytest.MonkeyPatch, +) -> None: + create_agent_called = False + + async def duplicate_tool() -> str: + return "tool" + + def fake_get_model(_self: DifyPluginLLMLayer, *, http_client: httpx.AsyncClient): + assert http_client.is_closed is False + return TestModel(custom_output_text="done") # pyright: ignore[reportReturnType] + + async def fake_get_tools(_self: DifyPluginToolsLayer, *, http_client: httpx.AsyncClient) -> list[Tool[object]]: + assert http_client.is_closed is False + return [Tool(duplicate_tool, name="shared_tool")] + + def fake_create_agent(model: object, *, tools: list[Tool[object]], output_type: object) -> object: + del model, tools, output_type + nonlocal create_agent_called + create_agent_called = True + raise AssertionError("create_agent should not be called when duplicate tool names are detected") + + monkeypatch.setattr(DifyPluginLLMLayer, "get_model", fake_get_model) + monkeypatch.setattr(DifyPluginToolsLayer, "get_tools", fake_get_tools) + monkeypatch.setattr("dify_agent.runtime.runner.create_agent", fake_create_agent) + + request = CreateRunRequest( + composition=RunComposition( + layers=[ + RunLayerSpec( + name="prompt", + type="plain.prompt", + config=PromptLayerConfig(prefix="system", user="hello"), + ), + RunLayerSpec( + name="execution_context", + type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + config=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"), + ), + RunLayerSpec( + 
name=DIFY_AGENT_MODEL_LAYER_ID, + type="dify.plugin.llm", + deps={"execution_context": "execution_context"}, + config=DifyPluginLLMLayerConfig( + plugin_id="langgenius/openai", + model_provider="openai", + model="demo-model", + credentials={"api_key": "secret"}, + ), + ), + RunLayerSpec( + name="tools-1", + type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID, + deps={"execution_context": "execution_context"}, + config=DifyPluginToolsLayerConfig( + tools=[ + DifyPluginToolConfig( + plugin_id="langgenius/tools", + provider="search", + tool_name="web_search", + credential_type="api-key", + parameters=_prepared_plugin_tool_parameters(), + parameters_json_schema=_prepared_plugin_tool_schema(), + ) + ] + ), + ), + RunLayerSpec( + name="tools-2", + type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID, + deps={"execution_context": "execution_context"}, + config=DifyPluginToolsLayerConfig( + tools=[ + DifyPluginToolConfig( + plugin_id="langgenius/tools", + provider="search", + tool_name="web_search_two", + credential_type="api-key", + parameters=_prepared_plugin_tool_parameters(), + parameters_json_schema=_prepared_plugin_tool_schema(), + ) + ] + ), + ), + ] + ) + ) + sink = InMemoryRunEventSink() + + async def scenario() -> None: + async with httpx.AsyncClient() as client: + with pytest.raises( + AgentRunValidationError, + match="unique tool names across all layers, got duplicates: shared_tool", + ): + await AgentRunRunner( + sink=sink, + request=request, + run_id="run-duplicate-tools", + plugin_daemon_http_client=client, + ).run() + + asyncio.run(scenario()) + + assert create_agent_called is False + assert [event.type for event in sink.events["run-duplicate-tools"]] == ["run_started", "run_failed"] + assert sink.statuses["run-duplicate-tools"] == "failed" + + +def test_runner_rejects_duplicate_tool_names_between_static_and_dynamic_tools( + monkeypatch: pytest.MonkeyPatch, +) -> None: + create_agent_called = False + + def web_search(query: str) -> str: + return query + + async def 
dynamic_duplicate_tool() -> str: + return "tool" + + def fake_get_model(_self: DifyPluginLLMLayer, *, http_client: httpx.AsyncClient): + assert http_client.is_closed is False + return TestModel(custom_output_text="done") # pyright: ignore[reportReturnType] + + async def fake_get_tools(_self: DifyPluginToolsLayer, *, http_client: httpx.AsyncClient) -> list[Tool[object]]: + assert http_client.is_closed is False + return [Tool(dynamic_duplicate_tool, name="web_search")] + + def fake_create_agent(model: object, *, tools: list[Tool[object]], output_type: object) -> object: + del model, tools, output_type + nonlocal create_agent_called + create_agent_called = True + raise AssertionError("create_agent should not be called when duplicate tool names are detected") + + monkeypatch.setattr(DifyPluginLLMLayer, "get_model", fake_get_model) + monkeypatch.setattr(DifyPluginToolsLayer, "get_tools", fake_get_tools) + monkeypatch.setattr("dify_agent.runtime.runner.create_agent", fake_create_agent) + + static_tools_provider = LayerProvider.from_factory( + layer_type=StaticToolsTestLayer, + create=lambda _config: StaticToolsTestLayer(tool_entries=(web_search,)), + ) + layer_providers = (*create_default_layer_providers(), static_tools_provider) + + request = CreateRunRequest( + composition=RunComposition( + layers=[ + RunLayerSpec( + name="prompt", + type="plain.prompt", + config=PromptLayerConfig(prefix="system", user="hello"), + ), + RunLayerSpec(name="static-tools", type=cast(str, StaticToolsTestLayer.type_id)), + RunLayerSpec( + name="execution_context", + type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + config=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"), + ), + RunLayerSpec( + name=DIFY_AGENT_MODEL_LAYER_ID, + type="dify.plugin.llm", + deps={"execution_context": "execution_context"}, + config=DifyPluginLLMLayerConfig( + plugin_id="langgenius/openai", + model_provider="openai", + model="demo-model", + credentials={"api_key": "secret"}, + ), + ), + 
RunLayerSpec( + name="tools", + type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID, + deps={"execution_context": "execution_context"}, + config=DifyPluginToolsLayerConfig( + tools=[ + DifyPluginToolConfig( + plugin_id="langgenius/tools", + provider="search", + tool_name="web_search", + credential_type="api-key", + parameters=_prepared_plugin_tool_parameters(), + parameters_json_schema=_prepared_plugin_tool_schema(), + ) + ] + ), + ), + ] + ) + ) + sink = InMemoryRunEventSink() + + async def scenario() -> None: + async with httpx.AsyncClient() as client: + with pytest.raises( + AgentRunValidationError, + match="unique tool names across all layers, got duplicates: web_search", + ): + await AgentRunRunner( + sink=sink, + request=request, + run_id="run-static-dynamic-duplicate-tools", + plugin_daemon_http_client=client, + layer_providers=layer_providers, + ).run() + + asyncio.run(scenario()) + + assert create_agent_called is False + assert [event.type for event in sink.events["run-static-dynamic-duplicate-tools"]] == ["run_started", "run_failed"] + assert sink.statuses["run-static-dynamic-duplicate-tools"] == "failed" + + def test_runner_passes_temporary_system_prompt_prefix_without_history_layer(monkeypatch: pytest.MonkeyPatch) -> None: model = RecordingTestModel(custom_output_text="done") @@ -271,7 +627,7 @@ def test_runner_passes_temporary_system_prompt_prefix_without_history_layer(monk assert isinstance(terminal, RunSucceededEvent) assert [layer.name for layer in terminal.data.session_snapshot.layers] == [ "prompt", - "plugin", + "execution_context", DIFY_AGENT_MODEL_LAYER_ID, ] @@ -440,7 +796,7 @@ def test_runner_applies_on_exit_overrides_to_success_snapshot(monkeypatch: pytes assert isinstance(terminal, RunSucceededEvent) assert {layer.name: layer.lifecycle_state for layer in terminal.data.session_snapshot.layers} == { "prompt": LifecycleState.CLOSED, - "plugin": LifecycleState.SUSPENDED, + "execution_context": LifecycleState.SUSPENDED, DIFY_AGENT_MODEL_LAYER_ID: 
LifecycleState.CLOSED, } @@ -478,7 +834,12 @@ def test_runner_passes_output_layer_spec_to_agent_and_serializes_structured_resu ) ) sink = InMemoryRunEventSink() - expected_snapshot_layer_names = ["prompt", "plugin", DIFY_AGENT_MODEL_LAYER_ID, DIFY_AGENT_OUTPUT_LAYER_ID] + expected_snapshot_layer_names = [ + "prompt", + "execution_context", + DIFY_AGENT_MODEL_LAYER_ID, + DIFY_AGENT_OUTPUT_LAYER_ID, + ] async def scenario() -> None: async with httpx.AsyncClient() as client: @@ -682,15 +1043,16 @@ def test_runner_rejects_misnamed_output_layer_before_model_resolution(monkeypatc config=PromptLayerConfig(prefix="system", user="hello"), ), RunLayerSpec( - name="plugin", - type="dify.plugin", - config=DifyPluginLayerConfig(tenant_id="tenant-1", plugin_id="langgenius/openai"), + name="execution_context", + type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + config=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"), ), RunLayerSpec( name=DIFY_AGENT_MODEL_LAYER_ID, type="dify.plugin.llm", - deps={"plugin": "plugin"}, + deps={"execution_context": "execution_context"}, config=DifyPluginLLMLayerConfig( + plugin_id="langgenius/openai", model_provider="openai", model="demo-model", credentials={"api_key": "secret"}, @@ -750,15 +1112,16 @@ def test_runner_rejects_multiple_output_layers_before_model_resolution(monkeypat config=PromptLayerConfig(prefix="system", user="hello"), ), RunLayerSpec( - name="plugin", - type="dify.plugin", - config=DifyPluginLayerConfig(tenant_id="tenant-1", plugin_id="langgenius/openai"), + name="execution_context", + type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + config=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"), ), RunLayerSpec( name=DIFY_AGENT_MODEL_LAYER_ID, type="dify.plugin.llm", - deps={"plugin": "plugin"}, + deps={"execution_context": "execution_context"}, config=DifyPluginLLMLayerConfig( + plugin_id="langgenius/openai", model_provider="openai", model="demo-model", credentials={"api_key": 
"secret"}, @@ -840,15 +1203,16 @@ def test_runner_rejects_reserved_output_name_with_wrong_layer_type_before_model_ config=PromptLayerConfig(prefix="system", user="hello"), ), RunLayerSpec( - name="plugin", - type="dify.plugin", - config=DifyPluginLayerConfig(tenant_id="tenant-1", plugin_id="langgenius/openai"), + name="execution_context", + type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + config=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"), ), RunLayerSpec( name=DIFY_AGENT_MODEL_LAYER_ID, type="dify.plugin.llm", - deps={"plugin": "plugin"}, + deps={"execution_context": "execution_context"}, config=DifyPluginLLMLayerConfig( + plugin_id="langgenius/openai", model_provider="openai", model="demo-model", credentials={"api_key": "secret"}, @@ -1042,7 +1406,7 @@ def test_runner_rejects_closed_session_snapshot_as_validation_error() -> None: runtime_state={}, ), LayerSessionSnapshot( - name="plugin", + name="execution_context", lifecycle_state=LifecycleState.NEW, runtime_state={}, ), diff --git a/dify-agent/tests/local/dify_agent/server/test_app.py b/dify-agent/tests/local/dify_agent/server/test_app.py index 73bfde69bd..a0415058d4 100644 --- a/dify-agent/tests/local/dify_agent/server/test_app.py +++ b/dify-agent/tests/local/dify_agent/server/test_app.py @@ -6,9 +6,9 @@ import pytest from fastapi.testclient import TestClient import dify_agent.server.app as app_module +from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig +from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer from dify_agent.runtime.compositor_factory import DifyAgentLayerProvider -from dify_agent.layers.dify_plugin.configs import DifyPluginLayerConfig -from dify_agent.layers.dify_plugin.plugin_layer import DifyPluginLayer from dify_agent.server.app import create_app, create_plugin_daemon_http_client from dify_agent.server.settings import ServerSettings from dify_agent.storage.redis_run_store import RedisRunStore @@ -148,11 
+148,15 @@ def test_create_app_creates_scheduler_and_closes_after_shutdown(monkeypatch: pyt assert scheduler.shutdown_grace_seconds == 5 layer_providers = scheduler.layer_providers assert isinstance(layer_providers, tuple) - plugin_provider = next(provider for provider in layer_providers if provider.type_id == "dify.plugin") - plugin_layer = plugin_provider.create_layer(DifyPluginLayerConfig(tenant_id="tenant-1", plugin_id="plugin-1")) - assert isinstance(plugin_layer, DifyPluginLayer) - assert plugin_layer.daemon_url == "http://plugin-daemon" - assert plugin_layer.daemon_api_key == "daemon-secret" + execution_context_provider = next( + provider for provider in layer_providers if provider.type_id == "dify.execution_context" + ) + execution_context_layer = execution_context_provider.create_layer( + DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run") + ) + assert isinstance(execution_context_layer, DifyExecutionContextLayer) + assert execution_context_layer.daemon_url == "http://plugin-daemon" + assert execution_context_layer.daemon_api_key == "daemon-secret" http_client = scheduler.plugin_daemon_http_client assert http_client is fake_http_client assert http_client.is_closed is False diff --git a/dify-agent/tests/local/dify_agent/server/test_runs_routes.py b/dify-agent/tests/local/dify_agent/server/test_runs_routes.py index bed7883170..a33590e208 100644 --- a/dify-agent/tests/local/dify_agent/server/test_runs_routes.py +++ b/dify-agent/tests/local/dify_agent/server/test_runs_routes.py @@ -1,7 +1,7 @@ from fastapi.testclient import TestClient from dify_agent.protocol import DIFY_AGENT_MODEL_LAYER_ID -from dify_agent.runtime.run_scheduler import RunRequestValidationError, SchedulerStoppingError +from dify_agent.runtime.run_scheduler import SchedulerStoppingError from dify_agent.server.routes.runs import create_runs_router from dify_agent.server.schemas import RunRecord @@ -9,14 +9,14 @@ from dify_agent.server.schemas import RunRecord class 
FakeScheduler: async def create_run(self, request: object) -> object: del request - raise RunRequestValidationError("run.user_prompts must not be empty") + return RunRecord(run_id="run-1", status="running") class FakeStore: pass -def test_create_run_rejects_effectively_blank_user_prompt_list() -> None: +def test_create_run_accepts_effectively_blank_user_prompt_list() -> None: from fastapi import FastAPI app = FastAPI() @@ -35,8 +35,8 @@ def test_create_run_rejects_effectively_blank_user_prompt_list() -> None: }, ) - assert response.status_code == 422 - assert response.json()["detail"] == "run.user_prompts must not be empty" + assert response.status_code == 202 + assert response.json() == {"run_id": "run-1", "status": "running"} def test_create_run_returns_running_from_scheduler() -> None: @@ -104,15 +104,16 @@ def test_create_run_accepts_valid_full_plugin_graph() -> None: "layers": [ {"name": "prompt", "type": "plain.prompt", "config": {"user": "hello"}}, { - "name": "plugin-renamed", - "type": "dify.plugin", - "config": {"tenant_id": "tenant-1", "plugin_id": "langgenius/openai"}, + "name": "execution-context-renamed", + "type": "dify.execution_context", + "config": {"tenant_id": "tenant-1", "invoke_from": "workflow_run"}, }, { "name": DIFY_AGENT_MODEL_LAYER_ID, "type": "dify.plugin.llm", - "deps": {"plugin": "plugin-renamed"}, + "deps": {"execution_context": "execution-context-renamed"}, "config": { + "plugin_id": "langgenius/openai", "model_provider": "openai", "model": "gpt-4o-mini", "credentials": {"api_key": "secret"}, @@ -128,17 +129,12 @@ def test_create_run_accepts_valid_full_plugin_graph() -> None: assert response.json() == {"run_id": "run-1", "status": "running"} -def test_create_run_rejects_unknown_layer_exit_signal_before_scheduling() -> None: +def test_create_run_accepts_unknown_layer_exit_signal_request() -> None: from fastapi import FastAPI - class UnknownSignalScheduler: - async def create_run(self, request: object) -> RunRecord: - del request - 
raise RunRequestValidationError("on_exit.layers references unknown layer ids: missing.") - app = FastAPI() app.include_router( - create_runs_router(lambda: FakeStore(), lambda: UnknownSignalScheduler()) # pyright: ignore[reportArgumentType] + create_runs_router(lambda: FakeStore(), lambda: FakeScheduler()) # pyright: ignore[reportArgumentType] ) client = TestClient(app) @@ -153,21 +149,16 @@ def test_create_run_rejects_unknown_layer_exit_signal_before_scheduling() -> Non }, ) - assert response.status_code == 422 - assert "missing" in response.json()["detail"] + assert response.status_code == 202 + assert response.json() == {"run_id": "run-1", "status": "running"} -def test_create_run_rejects_closed_session_snapshot_with_422() -> None: +def test_create_run_accepts_closed_session_snapshot_request() -> None: from fastapi import FastAPI - class ClosedSnapshotScheduler: - async def create_run(self, request: object) -> RunRecord: - del request - raise RunRequestValidationError("Layer 'prompt' is closed; CLOSED snapshots cannot be entered.") - app = FastAPI() app.include_router( - create_runs_router(lambda: FakeStore(), lambda: ClosedSnapshotScheduler()) # pyright: ignore[reportArgumentType] + create_runs_router(lambda: FakeStore(), lambda: FakeScheduler()) # pyright: ignore[reportArgumentType] ) client = TestClient(app) @@ -191,8 +182,8 @@ def test_create_run_rejects_closed_session_snapshot_with_422() -> None: }, ) - assert response.status_code == 422 - assert "CLOSED snapshots cannot be entered" in response.json()["detail"] + assert response.status_code == 202 + assert response.json() == {"run_id": "run-1", "status": "running"} def test_create_run_returns_503_when_scheduler_is_stopping() -> None: diff --git a/dify-agent/tests/local/dify_agent/test_import_boundaries.py b/dify-agent/tests/local/dify_agent/test_import_boundaries.py index 7ff55b167b..92a3d45f2a 100644 --- a/dify-agent/tests/local/dify_agent/test_import_boundaries.py +++ 
b/dify-agent/tests/local/dify_agent/test_import_boundaries.py @@ -79,8 +79,9 @@ def test_protocol_and_dify_plugin_exports_do_not_import_server_only_modules() -> blocked_imports=[ "anthropic", "dify_agent.adapters.llm", + "dify_agent.layers.execution_context.layer", "dify_agent.layers.dify_plugin.llm_layer", - "dify_agent.layers.dify_plugin.plugin_layer", + "dify_agent.layers.dify_plugin.tools_layer", "dify_agent.layers.output.output_layer", "dify_agent.runtime", "dify_agent.server", @@ -91,10 +92,16 @@ def test_protocol_and_dify_plugin_exports_do_not_import_server_only_modules() -> "pydantic_settings", "redis", ], - imports=["dify_agent.protocol", "dify_agent.layers.dify_plugin", "dify_agent.layers.output"], + imports=[ + "dify_agent.protocol", + "dify_agent.layers.execution_context", + "dify_agent.layers.dify_plugin", + "dify_agent.layers.output", + ], assertions=[ "assert hasattr(dify_agent_protocol, 'PydanticAIStreamRunEvent')", - "assert dify_agent_layers_dify_plugin.__all__ == ['DIFY_PLUGIN_LAYER_TYPE_ID', 'DIFY_PLUGIN_LLM_LAYER_TYPE_ID', 'DifyPluginCredentialValue', 'DifyPluginLLMLayerConfig', 'DifyPluginLayerConfig']", + "assert dify_agent_layers_execution_context.__all__ == ['DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID', 'DifyExecutionContextInvokeFrom', 'DifyExecutionContextLayerConfig']", + "assert dify_agent_layers_dify_plugin.__all__ == ['DIFY_PLUGIN_LLM_LAYER_TYPE_ID', 'DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID', 'DifyPluginCredentialValue', 'DifyPluginLLMLayerConfig', 'DifyPluginToolCredentialType', 'DifyPluginToolConfig', 'DifyPluginToolOption', 'DifyPluginToolParameter', 'DifyPluginToolParameterForm', 'DifyPluginToolParameterType', 'DifyPluginToolsLayerConfig', 'DifyPluginToolValue']", "assert dify_agent_layers_output.__all__ == ['DIFY_OUTPUT_LAYER_TYPE_ID', 'DifyOutputLayerConfig']", ], ) diff --git a/eslint-suppressions.json b/eslint-suppressions.json index 2624746723..c3bb13dc79 100644 --- a/eslint-suppressions.json +++ b/eslint-suppressions.json @@ -67,6 +67,11 
@@ "count": 2 } }, + "web/__tests__/goto-anything/search-error-handling.test.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "web/__tests__/i18n-upload-features.test.ts": { "no-console": { "count": 3 @@ -142,6 +147,11 @@ "count": 1 } }, + "web/app/(commonLayout)/snippets/[snippetId]/page.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/(humanInputLayout)/form/[token]/form.tsx": { "react/set-state-in-effect": { "count": 1 @@ -192,7 +202,7 @@ }, "web/app/account/(commonLayout)/account-page/email-change-modal.tsx": { "no-restricted-imports": { - "count": 1 + "count": 2 } }, "web/app/account/(commonLayout)/account-page/index.tsx": { @@ -238,6 +248,11 @@ "count": 1 } }, + "web/app/components/app-sidebar/nav-link/index.tsx": { + "tailwindcss/enforce-consistent-class-order": { + "count": 3 + } + }, "web/app/components/app/annotation/add-annotation-modal/edit-item/index.tsx": { "erasable-syntax-only/enums": { "count": 1 @@ -818,6 +833,11 @@ "count": 1 } }, + "web/app/components/base/chat/chat/__tests__/hooks.spec.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/base/chat/chat/answer/agent-content.tsx": { "style/multiline-ternary": { "count": 2 @@ -860,6 +880,9 @@ } }, "web/app/components/base/chat/chat/hooks.ts": { + "no-restricted-imports": { + "count": 2 + }, "react/set-state-in-effect": { "count": 2 }, @@ -1000,6 +1023,11 @@ "count": 1 } }, + "web/app/components/base/file-uploader/__tests__/utils.spec.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/base/file-uploader/dynamic-pdf-preview.tsx": { "ts/no-explicit-any": { "count": 1 @@ -1039,6 +1067,9 @@ } }, "web/app/components/base/file-uploader/utils.ts": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 3 } @@ -1401,6 +1432,9 @@ } }, "web/app/components/base/image-uploader/__tests__/utils.spec.ts": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 3 } @@ -1421,6 +1455,9 
@@ } }, "web/app/components/base/image-uploader/utils.ts": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 2 } @@ -1747,7 +1784,15 @@ "count": 1 } }, + "web/app/components/base/text-generation/__tests__/hooks.spec.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/base/text-generation/hooks.ts": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 1 } @@ -1757,14 +1802,6 @@ "count": 1 } }, - "web/app/components/base/textarea/index.stories.tsx": { - "no-console": { - "count": 1 - }, - "ts/no-explicit-any": { - "count": 1 - } - }, "web/app/components/base/voice-input/__tests__/index.spec.tsx": { "ts/no-explicit-any": { "count": 3 @@ -1899,6 +1936,11 @@ "count": 1 } }, + "web/app/components/datasets/create/file-uploader/hooks/use-file-upload.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/datasets/create/notion-page-preview/index.tsx": { "react/set-state-in-effect": { "count": 1 @@ -2027,6 +2069,9 @@ } }, "web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 1 } @@ -2037,6 +2082,9 @@ } }, "web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react/set-state-in-effect": { "count": 5 } @@ -2065,6 +2113,9 @@ } }, "web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/index.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 2 } @@ -2102,6 +2153,11 @@ "count": 2 } }, + "web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/datasets/documents/detail/completed/child-segment-list.tsx": { "no-restricted-imports": { "count": 1 @@ -2379,6 +2435,11 @@ "count": 2 } }, + 
"web/app/components/goto-anything/actions/plugin.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/goto-anything/actions/types.ts": { "ts/no-explicit-any": { "count": 2 @@ -2646,6 +2707,9 @@ "web/app/components/plugins/marketplace/hooks.ts": { "@tanstack/query/exhaustive-deps": { "count": 1 + }, + "no-restricted-imports": { + "count": 1 } }, "web/app/components/plugins/marketplace/search-box/tags-filter.tsx": { @@ -2999,6 +3063,9 @@ } }, "web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.ts": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 2 } @@ -3014,6 +3081,9 @@ } }, "web/app/components/rag-pipeline/hooks/use-pipeline-run.ts": { + "no-restricted-imports": { + "count": 2 + }, "ts/no-explicit-any": { "count": 1 } @@ -3043,11 +3113,26 @@ "count": 1 } }, + "web/app/components/share/text-generation/result/__tests__/index.spec.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/app/components/share/text-generation/result/__tests__/workflow-stream-handlers.spec.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/share/text-generation/result/index.tsx": { "ts/no-explicit-any": { "count": 1 } }, + "web/app/components/share/text-generation/result/workflow-stream-handlers.ts": { + "no-restricted-imports": { + "count": 2 + } + }, "web/app/components/share/text-generation/run-batch/csv-download/index.tsx": { "react/static-components": { "count": 2 @@ -3087,20 +3172,22 @@ "count": 2 } }, - "web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx": { + "web/app/components/snippets/hooks/use-nodes-sync-draft.ts": { "no-restricted-imports": { "count": 1 } }, + "web/app/components/snippets/hooks/use-snippet-run.ts": { + "no-restricted-imports": { + "count": 2 + } + }, "web/app/components/tools/edit-custom-collection-modal/get-schema.tsx": { "no-restricted-imports": { "count": 1 } }, "web/app/components/tools/edit-custom-collection-modal/index.tsx": 
{ - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 4 }, @@ -3213,6 +3300,9 @@ } }, "web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 2 } @@ -3227,6 +3317,21 @@ "count": 2 } }, + "web/app/components/workflow-app/hooks/use-workflow-run-callbacks.ts": { + "no-restricted-imports": { + "count": 2 + } + }, + "web/app/components/workflow-app/hooks/use-workflow-run-utils.ts": { + "no-restricted-imports": { + "count": 3 + } + }, + "web/app/components/workflow-app/hooks/use-workflow-run.ts": { + "no-restricted-imports": { + "count": 2 + } + }, "web/app/components/workflow-app/hooks/use-workflow-template.ts": { "ts/no-explicit-any": { "count": 2 @@ -3247,6 +3352,11 @@ "count": 1 } }, + "web/app/components/workflow/block-selector/blocks.tsx": { + "unused-imports/no-unused-imports": { + "count": 1 + } + }, "web/app/components/workflow/block-selector/hooks.ts": { "react/set-state-in-effect": { "count": 1 @@ -3347,6 +3457,9 @@ } }, "web/app/components/workflow/hooks-store/store.ts": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 6 } @@ -3667,6 +3780,9 @@ } }, "web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts": { + "no-restricted-imports": { + "count": 1 + }, "react/set-state-in-effect": { "count": 2 }, @@ -4526,6 +4642,11 @@ "count": 1 } }, + "web/app/components/workflow/panel/debug-and-preview/__tests__/hooks.spec.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx": { "ts/no-explicit-any": { "count": 6 @@ -4537,6 +4658,9 @@ } }, "web/app/components/workflow/panel/debug-and-preview/hooks.ts": { + "no-restricted-imports": { + "count": 2 + }, "ts/no-explicit-any": { "count": 12 } @@ -4805,6 +4929,11 @@ "count": 1 } }, + "web/app/device/page.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, 
"web/app/education-apply/hooks.ts": { "react/set-state-in-effect": { "count": 5 @@ -4882,7 +5011,7 @@ }, "web/app/signin/components/mail-and-password-auth.tsx": { "no-restricted-imports": { - "count": 1 + "count": 2 }, "ts/no-explicit-any": { "count": 1 @@ -5101,42 +5230,106 @@ "count": 6 } }, + "web/service/__tests__/base.spec.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/__tests__/use-pipeline.spec.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/__tests__/use-snippet-workflows.spec.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/service/access-control.ts": { "@tanstack/query/exhaustive-deps": { "count": 1 + }, + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/annotation.spec.ts": { + "no-restricted-imports": { + "count": 1 } }, "web/service/annotation.ts": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 4 } }, "web/service/apps.ts": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 7 } }, + "web/service/base.spec.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "web/service/base.ts": { + "no-restricted-imports": { + "count": 2 + }, "ts/no-explicit-any": { "count": 3 } }, + "web/service/billing.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/client.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "web/service/common.ts": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 29 } }, "web/service/datasets.ts": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 6 } }, "web/service/debug.ts": { + "no-restricted-imports": { + "count": 2 + }, "ts/no-explicit-any": { "count": 6 } }, + "web/service/fetch.spec.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "web/service/fetch.ts": { + "no-restricted-imports": { + "count": 1 + }, "regexp/no-unused-capturing-group": { "count": 1 }, @@ -5144,20 +5337,104 @@ "count": 2 
} }, + "web/service/knowledge/use-create-dataset.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/knowledge/use-dataset.spec.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/knowledge/use-dataset.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/knowledge/use-document.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/knowledge/use-hit-testing.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/knowledge/use-import.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/knowledge/use-metadata.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/knowledge/use-segment.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/log.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/plugins.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "web/service/share.ts": { "erasable-syntax-only/enums": { "count": 1 }, + "no-restricted-imports": { + "count": 2 + }, "ts/no-explicit-any": { "count": 3 } }, + "web/service/sso.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/strategy.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/tools.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/try-app.spec.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "web/service/try-app.ts": { "no-barrel-files/no-barrel-files": { "count": 2 + }, + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/use-apps.ts": { + "no-restricted-imports": { + "count": 1 } }, "web/service/use-common.ts": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-empty-object-type": { "count": 1 }, @@ -5165,7 +5442,20 @@ "count": 1 } }, + "web/service/use-datasource.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/use-education.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "web/service/use-endpoints.ts": 
{ + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 7 } @@ -5175,17 +5465,41 @@ "count": 1 } }, + "web/service/use-log.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/use-models.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/use-oauth.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "web/service/use-pipeline.ts": { "@tanstack/query/exhaustive-deps": { "count": 1 + }, + "no-restricted-imports": { + "count": 1 } }, "web/service/use-plugins-auth.ts": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 4 } }, "web/service/use-plugins.ts": { + "no-restricted-imports": { + "count": 1 + }, "react/set-state-in-effect": { "count": 1 }, @@ -5196,24 +5510,55 @@ "count": 3 } }, + "web/service/use-snippet-workflows.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "web/service/use-tools.ts": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 1 } }, + "web/service/use-triggers.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "web/service/use-workflow.ts": { "@tanstack/query/exhaustive-deps": { "count": 1 }, + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 3 } }, + "web/service/use-workspace.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "web/service/utils.spec.ts": { "ts/no-explicit-any": { "count": 2 } }, + "web/service/webapp-auth.ts": { + "no-restricted-imports": { + "count": 1 + } + }, + "web/service/workflow.ts": { + "no-restricted-imports": { + "count": 1 + } + }, "web/types/app.ts": { "erasable-syntax-only/enums": { "count": 9 diff --git a/package.json b/package.json index 374d4f6df0..19dbb2e617 100644 --- a/package.json +++ b/package.json @@ -3,6 +3,13 @@ "type": "module", "private": true, "packageManager": "pnpm@11.2.2", + "devEngines": { + "runtime": { + "name": "node", + "version": "^22.22.1", + "onFail": "download" + } + }, "engines": { "node": "^22.22.1" }, 
diff --git a/packages/contracts/generated/api/console/agents/types.gen.ts b/packages/contracts/generated/api/console/agents/types.gen.ts index 8a4540f933..e63c57f03b 100644 --- a/packages/contracts/generated/api/console/agents/types.gen.ts +++ b/packages/contracts/generated/api/console/agents/types.gen.ts @@ -121,9 +121,7 @@ export type AgentSoulToolsConfig = { cli_tools?: Array<{ [key: string]: unknown }> - dify_tools?: Array<{ - [key: string]: unknown - }> + dify_tools?: Array } export type AgentKnowledgeQueryMode = 'generated_query' | 'user_query' @@ -134,6 +132,28 @@ export type AgentSoulModelCredentialRef = { type: string } +export type AgentSoulDifyToolConfig = { + credential_ref?: AgentSoulDifyToolCredentialRef + credential_type?: 'api-key' | 'oauth2' | 'unauthorized' + description?: string | null + enabled?: boolean + name?: string | null + plugin_id?: string | null + provider?: string | null + provider_id?: string | null + provider_type?: string + runtime_parameters?: { + [key: string]: unknown + } + tool_name: string +} + +export type AgentSoulDifyToolCredentialRef = { + id?: string | null + provider?: string | null + type?: 'provider' | 'tool' +} + export type GetAgentsData = { body?: never path?: never diff --git a/packages/contracts/generated/api/console/agents/zod.gen.ts b/packages/contracts/generated/api/console/agents/zod.gen.ts index f84b5fc411..130144d48d 100644 --- a/packages/contracts/generated/api/console/agents/zod.gen.ts +++ b/packages/contracts/generated/api/console/agents/zod.gen.ts @@ -78,14 +78,6 @@ export const zAgentSoulSkillsFilesConfig = z.object({ skills: z.array(z.record(z.string(), z.unknown())).optional(), }) -/** - * AgentSoulToolsConfig - */ -export const zAgentSoulToolsConfig = z.object({ - cli_tools: z.array(z.record(z.string(), z.unknown())).optional(), - dify_tools: z.array(z.record(z.string(), z.unknown())).optional(), -}) - /** * AgentKnowledgeQueryMode */ @@ -124,6 +116,53 @@ export const zAgentSoulModelConfig = 
z.object({ plugin_id: z.string().min(1).max(255), }) +/** + * AgentSoulDifyToolCredentialRef + * + * Reference to a stored Dify Plugin Tool credential. + * + * Secret values are resolved only at runtime. The legacy ``credential_id`` + * field is accepted by :class:`AgentSoulDifyToolConfig` and normalized here so + * old Agent tool payloads can be read while new payloads stay explicit. + */ +export const zAgentSoulDifyToolCredentialRef = z.object({ + id: z.string().max(255).nullish(), + provider: z.string().max(255).nullish(), + type: z.enum(['provider', 'tool']).optional().default('tool'), +}) + +/** + * AgentSoulDifyToolConfig + * + * One Dify Plugin Tool configured on Agent Soul. + * + * The API backend prepares this persisted product shape into + * ``DifyPluginToolConfig`` before sending a run request to Agent backend. + * ``provider_id`` keeps compatibility with existing Agent tool config payloads; + * new callers should send ``plugin_id`` + ``provider`` when available. + */ +export const zAgentSoulDifyToolConfig = z.object({ + credential_ref: zAgentSoulDifyToolCredentialRef.optional(), + credential_type: z.enum(['api-key', 'oauth2', 'unauthorized']).optional().default('api-key'), + description: z.string().nullish(), + enabled: z.boolean().optional().default(true), + name: z.string().max(255).nullish(), + plugin_id: z.string().max(255).nullish(), + provider: z.string().max(255).nullish(), + provider_id: z.string().max(255).nullish(), + provider_type: z.string().optional().default('plugin'), + runtime_parameters: z.record(z.string(), z.unknown()).optional(), + tool_name: z.string().min(1).max(255), +}) + +/** + * AgentSoulToolsConfig + */ +export const zAgentSoulToolsConfig = z.object({ + cli_tools: z.array(z.record(z.string(), z.unknown())).optional(), + dify_tools: z.array(zAgentSoulDifyToolConfig).optional(), +}) + /** * AgentSoulConfig */ diff --git a/packages/contracts/generated/api/console/apps/orpc.gen.ts 
b/packages/contracts/generated/api/console/apps/orpc.gen.ts index c3fdc93491..a4f6183130 100644 --- a/packages/contracts/generated/api/console/apps/orpc.gen.ts +++ b/packages/contracts/generated/api/console/apps/orpc.gen.ts @@ -167,6 +167,14 @@ import { zGetAppsByAppIdWorkflowsDraftNodesByNodeIdVariablesResponse, zGetAppsByAppIdWorkflowsDraftPath, zGetAppsByAppIdWorkflowsDraftResponse, + zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewPath, + zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewResponse, + zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdPath, + zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdResponse, + zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsEventsPath, + zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsEventsResponse, + zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsPath, + zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsResponse, zGetAppsByAppIdWorkflowsDraftSystemVariablesPath, zGetAppsByAppIdWorkflowsDraftSystemVariablesResponse, zGetAppsByAppIdWorkflowsDraftVariablesByVariableIdPath, @@ -175,6 +183,14 @@ import { zGetAppsByAppIdWorkflowsDraftVariablesQuery, zGetAppsByAppIdWorkflowsDraftVariablesResponse, zGetAppsByAppIdWorkflowsPath, + zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewPath, + zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewResponse, + zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdPath, + zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdResponse, + zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsEventsPath, + zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsEventsResponse, + zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsPath, + zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsResponse, zGetAppsByAppIdWorkflowsPublishPath, zGetAppsByAppIdWorkflowsPublishResponse, zGetAppsByAppIdWorkflowsQuery, @@ -3787,13 
+3803,132 @@ export const run10 = { } /** - * Get system variables for workflow + * Server-Sent Events stream of inspector deltas for a draft workflow run. * * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. * * @deprecated */ export const get59 = oc + .route({ + deprecated: true, + description: + 'Server-Sent Events stream of inspector deltas for a draft workflow run.\n\nGenerated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.', + inputStructure: 'detailed', + method: 'GET', + operationId: 'getAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsEvents', + path: '/apps/{app_id}/workflows/draft/runs/{run_id}/node-outputs/events', + tags: ['console'], + }) + .input(z.object({ params: zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsEventsPath })) + .output(zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsEventsResponse) + +export const events = { + get: get59, +} + +/** + * Full value for one declared output, including signed download URL for files. + * + * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. + * + * @deprecated + */ +export const get60 = oc + .route({ + deprecated: true, + description: + 'Full value for one declared output, including signed download URL for files.\n\nGenerated contract types may be inaccurate because backend OpenAPI annotations are incomplete. 
Do not migrate callers until the generated contract is accurate.', + inputStructure: 'detailed', + method: 'GET', + operationId: 'getAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdByOutputNamePreview', + path: '/apps/{app_id}/workflows/draft/runs/{run_id}/node-outputs/{node_id}/{output_name}/preview', + tags: ['console'], + }) + .input( + z.object({ + params: zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewPath, + }), + ) + .output(zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewResponse) + +export const preview3 = { + get: get60, +} + +export const byOutputName = { + preview: preview3, +} + +/** + * One node's declared outputs for a draft workflow run. + * + * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. + * + * @deprecated + */ +export const get61 = oc + .route({ + deprecated: true, + description: + 'One node\'s declared outputs for a draft workflow run.\n\nGenerated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.', + inputStructure: 'detailed', + method: 'GET', + operationId: 'getAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeId', + path: '/apps/{app_id}/workflows/draft/runs/{run_id}/node-outputs/{node_id}', + tags: ['console'], + }) + .input(z.object({ params: zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdPath })) + .output(zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdResponse) + +export const byNodeId8 = { + get: get61, + byOutputName, +} + +/** + * Snapshot of every node's declared outputs for a draft workflow run. + * + * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. 
+ * + * @deprecated + */ +export const get62 = oc + .route({ + deprecated: true, + description: + 'Snapshot of every node\'s declared outputs for a draft workflow run.\n\nGenerated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.', + inputStructure: 'detailed', + method: 'GET', + operationId: 'getAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputs', + path: '/apps/{app_id}/workflows/draft/runs/{run_id}/node-outputs', + tags: ['console'], + }) + .input(z.object({ params: zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsPath })) + .output(zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsResponse) + +export const nodeOutputs = { + get: get62, + events, + byNodeId: byNodeId8, +} + +export const byRunId2 = { + nodeOutputs, +} + +export const runs = { + byRunId: byRunId2, +} + +/** + * Get system variables for workflow + * + * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. 
+ * + * @deprecated + */ +export const get63 = oc .route({ deprecated: true, description: @@ -3808,7 +3943,7 @@ export const get59 = oc .output(zGetAppsByAppIdWorkflowsDraftSystemVariablesResponse) export const systemVariables = { - get: get59, + get: get63, } /** @@ -3930,7 +4065,7 @@ export const delete9 = oc * * @deprecated */ -export const get60 = oc +export const get64 = oc .route({ deprecated: true, description: @@ -3972,7 +4107,7 @@ export const patch2 = oc export const byVariableId = { delete: delete9, - get: get60, + get: get64, patch: patch2, reset, } @@ -4002,7 +4137,7 @@ export const delete10 = oc * * @deprecated */ -export const get61 = oc +export const get65 = oc .route({ deprecated: true, description: @@ -4024,7 +4159,7 @@ export const get61 = oc export const variables2 = { delete: delete10, - get: get61, + get: get65, byVariableId, } @@ -4037,7 +4172,7 @@ export const variables2 = { * * @deprecated */ -export const get62 = oc +export const get66 = oc .route({ deprecated: true, description: @@ -4082,7 +4217,7 @@ export const post55 = oc .output(zPostAppsByAppIdWorkflowsDraftResponse) export const draft2 = { - get: get62, + get: get66, post: post55, conversationVariables: conversationVariables2, environmentVariables, @@ -4092,6 +4227,7 @@ export const draft2 = { loop: loop2, nodes: nodes7, run: run10, + runs, systemVariables, trigger: trigger2, variables: variables2, @@ -4106,7 +4242,7 @@ export const draft2 = { * * @deprecated */ -export const get63 = oc +export const get67 = oc .route({ deprecated: true, description: @@ -4149,10 +4285,137 @@ export const post56 = oc .output(zPostAppsByAppIdWorkflowsPublishResponse) export const publish = { - get: get63, + get: get67, post: post56, } +/** + * Server-Sent Events stream of inspector deltas for a published workflow run. + * + * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. 
+ * + * @deprecated + */ +export const get68 = oc + .route({ + deprecated: true, + description: + 'Server-Sent Events stream of inspector deltas for a published workflow run.\n\nGenerated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.', + inputStructure: 'detailed', + method: 'GET', + operationId: 'getAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsEvents', + path: '/apps/{app_id}/workflows/published/runs/{run_id}/node-outputs/events', + tags: ['console'], + }) + .input(z.object({ params: zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsEventsPath })) + .output(zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsEventsResponse) + +export const events2 = { + get: get68, +} + +/** + * Full value for one declared output of a published run. + * + * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. + * + * @deprecated + */ +export const get69 = oc + .route({ + deprecated: true, + description: + 'Full value for one declared output of a published run.\n\nGenerated contract types may be inaccurate because backend OpenAPI annotations are incomplete. 
Do not migrate callers until the generated contract is accurate.', + inputStructure: 'detailed', + method: 'GET', + operationId: + 'getAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdByOutputNamePreview', + path: '/apps/{app_id}/workflows/published/runs/{run_id}/node-outputs/{node_id}/{output_name}/preview', + tags: ['console'], + }) + .input( + z.object({ + params: + zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewPath, + }), + ) + .output( + zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewResponse, + ) + +export const preview4 = { + get: get69, +} + +export const byOutputName2 = { + preview: preview4, +} + +/** + * One node's declared outputs for a published workflow run. + * + * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. + * + * @deprecated + */ +export const get70 = oc + .route({ + deprecated: true, + description: + 'One node\'s declared outputs for a published workflow run.\n\nGenerated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.', + inputStructure: 'detailed', + method: 'GET', + operationId: 'getAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeId', + path: '/apps/{app_id}/workflows/published/runs/{run_id}/node-outputs/{node_id}', + tags: ['console'], + }) + .input(z.object({ params: zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdPath })) + .output(zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdResponse) + +export const byNodeId9 = { + get: get70, + byOutputName: byOutputName2, +} + +/** + * Snapshot of every node's declared outputs for a published workflow run. + * + * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. 
Do not migrate callers until the generated contract is accurate. + * + * @deprecated + */ +export const get71 = oc + .route({ + deprecated: true, + description: + 'Snapshot of every node\'s declared outputs for a published workflow run.\n\nGenerated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.', + inputStructure: 'detailed', + method: 'GET', + operationId: 'getAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputs', + path: '/apps/{app_id}/workflows/published/runs/{run_id}/node-outputs', + tags: ['console'], + }) + .input(z.object({ params: zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsPath })) + .output(zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsResponse) + +export const nodeOutputs2 = { + get: get71, + events: events2, + byNodeId: byNodeId9, +} + +export const byRunId3 = { + nodeOutputs: nodeOutputs2, +} + +export const runs2 = { + byRunId: byRunId3, +} + +export const published = { + runs: runs2, +} + /** * Get webhook trigger for a node * @@ -4160,7 +4423,7 @@ export const publish = { * * @deprecated */ -export const get64 = oc +export const get72 = oc .route({ deprecated: true, description: @@ -4181,7 +4444,7 @@ export const get64 = oc .output(zGetAppsByAppIdWorkflowsTriggersWebhookResponse) export const webhook = { - get: get64, + get: get72, } export const triggers2 = { @@ -4279,7 +4542,7 @@ export const byWorkflowId = { * * @deprecated */ -export const get65 = oc +export const get73 = oc .route({ deprecated: true, description: @@ -4300,10 +4563,11 @@ export const get65 = oc .output(zGetAppsByAppIdWorkflowsResponse) export const workflows3 = { - get: get65, + get: get73, defaultWorkflowBlockConfigs, draft: draft2, publish, + published, triggers: triggers2, byWorkflowId, } @@ -4336,7 +4600,7 @@ export const delete12 = oc * * @deprecated */ -export const get66 = oc +export const get74 = oc .route({ deprecated: true, description: @@ -4377,7 
+4641,7 @@ export const put7 = oc export const byAppId2 = { delete: delete12, - get: get66, + get: get74, put: put7, advancedChat, agentComposer, @@ -4446,7 +4710,7 @@ export const byApiKeyId = { * * Get all API keys for an app */ -export const get67 = oc +export const get75 = oc .route({ description: 'Get all API keys for an app', inputStructure: 'detailed', @@ -4479,7 +4743,7 @@ export const post58 = oc .output(zPostAppsByResourceIdApiKeysResponse) export const apiKeys = { - get: get67, + get: get75, post: post58, byApiKeyId, } @@ -4495,7 +4759,7 @@ export const byResourceId = { * * @deprecated */ -export const get68 = oc +export const get76 = oc .route({ deprecated: true, description: @@ -4510,7 +4774,7 @@ export const get68 = oc .output(zGetAppsByServerIdServerRefreshResponse) export const refresh = { - get: get68, + get: get76, } export const server2 = { @@ -4526,7 +4790,7 @@ export const byServerId = { * * Get list of applications with pagination and filtering */ -export const get69 = oc +export const get77 = oc .route({ description: 'Get list of applications with pagination and filtering', inputStructure: 'detailed', @@ -4565,7 +4829,7 @@ export const post59 = oc .output(zPostAppsResponse) export const apps = { - get: get69, + get: get77, post: post59, imports, workflows, diff --git a/packages/contracts/generated/api/console/apps/types.gen.ts b/packages/contracts/generated/api/console/apps/types.gen.ts index 71ad8486ee..c62f34bb44 100644 --- a/packages/contracts/generated/api/console/apps/types.gen.ts +++ b/packages/contracts/generated/api/console/apps/types.gen.ts @@ -1430,9 +1430,7 @@ export type AgentSoulToolsConfig = { cli_tools?: Array<{ [key: string]: unknown }> - dify_tools?: Array<{ - [key: string]: unknown - }> + dify_tools?: Array } export type DeclaredOutputConfig = { @@ -1525,6 +1523,22 @@ export type AgentSoulModelCredentialRef = { type: string } +export type AgentSoulDifyToolConfig = { + credential_ref?: AgentSoulDifyToolCredentialRef + 
credential_type?: 'api-key' | 'oauth2' | 'unauthorized' + description?: string | null + enabled?: boolean + name?: string | null + plugin_id?: string | null + provider?: string | null + provider_id?: string | null + provider_type?: string + runtime_parameters?: { + [key: string]: unknown + } + tool_name: string +} + export type DeclaredArrayItem = { description?: string | null type: DeclaredOutputType @@ -1562,6 +1576,12 @@ export type UserActionConfig = { export type FormInputConfig = unknown +export type AgentSoulDifyToolCredentialRef = { + id?: string | null + provider?: string | null + type?: 'provider' | 'tool' +} + export type OutputErrorStrategy = 'default_value' | 'fail_branch' | 'stop' export type DeclaredOutputRetryConfig = { @@ -4750,6 +4770,122 @@ export type PostAppsByAppIdWorkflowsDraftRunResponses = { export type PostAppsByAppIdWorkflowsDraftRunResponse = PostAppsByAppIdWorkflowsDraftRunResponses[keyof PostAppsByAppIdWorkflowsDraftRunResponses] +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsData = { + body?: never + path: { + app_id: string + run_id: string + } + query?: never + url: '/apps/{app_id}/workflows/draft/runs/{run_id}/node-outputs' +} + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsErrors = { + 404: { + [key: string]: unknown + } +} + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsError + = GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsErrors[keyof GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsErrors] + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsResponses = { + 200: { + [key: string]: unknown + } +} + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsResponse + = GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsResponses[keyof GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsResponses] + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsEventsData = { + body?: never + path: { + app_id: string + run_id: string + } + query?: never + url: 
'/apps/{app_id}/workflows/draft/runs/{run_id}/node-outputs/events' +} + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsEventsErrors = { + 404: { + [key: string]: unknown + } +} + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsEventsError + = GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsEventsErrors[keyof GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsEventsErrors] + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsEventsResponses = { + 200: { + [key: string]: unknown + } +} + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsEventsResponse + = GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsEventsResponses[keyof GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsEventsResponses] + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdData = { + body?: never + path: { + app_id: string + node_id: string + run_id: string + } + query?: never + url: '/apps/{app_id}/workflows/draft/runs/{run_id}/node-outputs/{node_id}' +} + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdErrors = { + 404: { + [key: string]: unknown + } +} + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdError + = GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdErrors[keyof GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdErrors] + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdResponses = { + 200: { + [key: string]: unknown + } +} + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdResponse + = GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdResponses[keyof GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdResponses] + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewData = { + body?: never + path: { + app_id: string + node_id: string + output_name: string + run_id: string + } + query?: never + url: 
'/apps/{app_id}/workflows/draft/runs/{run_id}/node-outputs/{node_id}/{output_name}/preview' +} + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewErrors = { + 404: { + [key: string]: unknown + } +} + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewError + = GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewErrors[keyof GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewErrors] + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewResponses + = { + 200: { + [key: string]: unknown + } + } + +export type GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewResponse + = GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewResponses[keyof GetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewResponses] + export type GetAppsByAppIdWorkflowsDraftSystemVariablesData = { body?: never path: { @@ -5006,6 +5142,124 @@ export type PostAppsByAppIdWorkflowsPublishResponses = { export type PostAppsByAppIdWorkflowsPublishResponse = PostAppsByAppIdWorkflowsPublishResponses[keyof PostAppsByAppIdWorkflowsPublishResponses] +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsData = { + body?: never + path: { + app_id: string + run_id: string + } + query?: never + url: '/apps/{app_id}/workflows/published/runs/{run_id}/node-outputs' +} + +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsErrors = { + 404: { + [key: string]: unknown + } +} + +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsError + = GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsErrors[keyof GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsErrors] + +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsResponses = { + 200: { + [key: string]: unknown + } +} + +export type 
GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsResponse + = GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsResponses[keyof GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsResponses] + +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsEventsData = { + body?: never + path: { + app_id: string + run_id: string + } + query?: never + url: '/apps/{app_id}/workflows/published/runs/{run_id}/node-outputs/events' +} + +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsEventsErrors = { + 404: { + [key: string]: unknown + } +} + +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsEventsError + = GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsEventsErrors[keyof GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsEventsErrors] + +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsEventsResponses = { + 200: { + [key: string]: unknown + } +} + +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsEventsResponse + = GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsEventsResponses[keyof GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsEventsResponses] + +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdData = { + body?: never + path: { + app_id: string + node_id: string + run_id: string + } + query?: never + url: '/apps/{app_id}/workflows/published/runs/{run_id}/node-outputs/{node_id}' +} + +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdErrors = { + 404: { + [key: string]: unknown + } +} + +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdError + = GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdErrors[keyof GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdErrors] + +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdResponses = { + 200: { + [key: string]: unknown + } +} + +export type 
GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdResponse + = GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdResponses[keyof GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdResponses] + +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewData + = { + body?: never + path: { + app_id: string + node_id: string + output_name: string + run_id: string + } + query?: never + url: '/apps/{app_id}/workflows/published/runs/{run_id}/node-outputs/{node_id}/{output_name}/preview' + } + +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewErrors + = { + 404: { + [key: string]: unknown + } + } + +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewError + = GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewErrors[keyof GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewErrors] + +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewResponses + = { + 200: { + [key: string]: unknown + } + } + +export type GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewResponse + = GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewResponses[keyof GetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewResponses] + export type GetAppsByAppIdWorkflowsTriggersWebhookData = { body?: never path: { diff --git a/packages/contracts/generated/api/console/apps/zod.gen.ts b/packages/contracts/generated/api/console/apps/zod.gen.ts index 445127b966..2016a802cf 100644 --- a/packages/contracts/generated/api/console/apps/zod.gen.ts +++ b/packages/contracts/generated/api/console/apps/zod.gen.ts @@ -1533,14 +1533,6 @@ export const zAgentSoulSkillsFilesConfig = z.object({ skills: z.array(z.record(z.string(), z.unknown())).optional(), }) -/** - * AgentSoulToolsConfig - 
*/ -export const zAgentSoulToolsConfig = z.object({ - cli_tools: z.array(z.record(z.string(), z.unknown())).optional(), - dify_tools: z.array(z.record(z.string(), z.unknown())).optional(), -}) - /** * WorkflowNodeJobMode */ @@ -1771,25 +1763,6 @@ export const zAgentSoulModelConfig = z.object({ plugin_id: z.string().min(1).max(255), }) -/** - * AgentSoulConfig - */ -export const zAgentSoulConfig = z.object({ - app_features: z.record(z.string(), z.unknown()).optional(), - app_variables: z.array(zAppVariableConfig).optional(), - env: zAgentSoulEnvConfig.optional(), - human: zAgentSoulHumanConfig.optional(), - knowledge: zAgentSoulKnowledgeConfig.optional(), - memory: zAgentSoulMemoryConfig.optional(), - misc_legacy: z.record(z.string(), z.unknown()).optional(), - model: zAgentSoulModelConfig.optional(), - prompt: zAgentSoulPromptConfig.optional(), - sandbox: zAgentSoulSandboxConfig.optional(), - schema_version: z.int().optional().default(1), - skills_files: zAgentSoulSkillsFilesConfig.optional(), - tools: zAgentSoulToolsConfig.optional(), -}) - /** * DeclaredOutputCheckConfig * @@ -1842,6 +1815,72 @@ export const zDeclaredArrayItem = z.object({ export const zFormInputConfig = z.unknown() +/** + * AgentSoulDifyToolCredentialRef + * + * Reference to a stored Dify Plugin Tool credential. + * + * Secret values are resolved only at runtime. The legacy ``credential_id`` + * field is accepted by :class:`AgentSoulDifyToolConfig` and normalized here so + * old Agent tool payloads can be read while new payloads stay explicit. + */ +export const zAgentSoulDifyToolCredentialRef = z.object({ + id: z.string().max(255).nullish(), + provider: z.string().max(255).nullish(), + type: z.enum(['provider', 'tool']).optional().default('tool'), +}) + +/** + * AgentSoulDifyToolConfig + * + * One Dify Plugin Tool configured on Agent Soul. + * + * The API backend prepares this persisted product shape into + * ``DifyPluginToolConfig`` before sending a run request to Agent backend. 
+ * ``provider_id`` keeps compatibility with existing Agent tool config payloads; + * new callers should send ``plugin_id`` + ``provider`` when available. + */ +export const zAgentSoulDifyToolConfig = z.object({ + credential_ref: zAgentSoulDifyToolCredentialRef.optional(), + credential_type: z.enum(['api-key', 'oauth2', 'unauthorized']).optional().default('api-key'), + description: z.string().nullish(), + enabled: z.boolean().optional().default(true), + name: z.string().max(255).nullish(), + plugin_id: z.string().max(255).nullish(), + provider: z.string().max(255).nullish(), + provider_id: z.string().max(255).nullish(), + provider_type: z.string().optional().default('plugin'), + runtime_parameters: z.record(z.string(), z.unknown()).optional(), + tool_name: z.string().min(1).max(255), +}) + +/** + * AgentSoulToolsConfig + */ +export const zAgentSoulToolsConfig = z.object({ + cli_tools: z.array(z.record(z.string(), z.unknown())).optional(), + dify_tools: z.array(zAgentSoulDifyToolConfig).optional(), +}) + +/** + * AgentSoulConfig + */ +export const zAgentSoulConfig = z.object({ + app_features: z.record(z.string(), z.unknown()).optional(), + app_variables: z.array(zAppVariableConfig).optional(), + env: zAgentSoulEnvConfig.optional(), + human: zAgentSoulHumanConfig.optional(), + knowledge: zAgentSoulKnowledgeConfig.optional(), + memory: zAgentSoulMemoryConfig.optional(), + misc_legacy: z.record(z.string(), z.unknown()).optional(), + model: zAgentSoulModelConfig.optional(), + prompt: zAgentSoulPromptConfig.optional(), + sandbox: zAgentSoulSandboxConfig.optional(), + schema_version: z.int().optional().default(1), + skills_files: zAgentSoulSkillsFilesConfig.optional(), + tools: zAgentSoulToolsConfig.optional(), +}) + /** * OutputErrorStrategy * @@ -3834,6 +3873,60 @@ export const zPostAppsByAppIdWorkflowsDraftRunPath = z.object({ */ export const zPostAppsByAppIdWorkflowsDraftRunResponse = z.record(z.string(), z.unknown()) +export const 
zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsPath = z.object({ + app_id: z.string(), + run_id: z.string(), +}) + +/** + * Success + */ +export const zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsResponse = z.record( + z.string(), + z.unknown(), +) + +export const zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsEventsPath = z.object({ + app_id: z.string(), + run_id: z.string(), +}) + +/** + * Success + */ +export const zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsEventsResponse = z.record( + z.string(), + z.unknown(), +) + +export const zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdPath = z.object({ + app_id: z.string(), + node_id: z.string(), + run_id: z.string(), +}) + +/** + * Success + */ +export const zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdResponse = z.record( + z.string(), + z.unknown(), +) + +export const zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewPath + = z.object({ + app_id: z.string(), + node_id: z.string(), + output_name: z.string(), + run_id: z.string(), + }) + +/** + * Success + */ +export const zGetAppsByAppIdWorkflowsDraftRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewResponse + = z.record(z.string(), z.unknown()) + export const zGetAppsByAppIdWorkflowsDraftSystemVariablesPath = z.object({ app_id: z.string(), }) @@ -3954,6 +4047,60 @@ export const zPostAppsByAppIdWorkflowsPublishPath = z.object({ */ export const zPostAppsByAppIdWorkflowsPublishResponse = z.record(z.string(), z.unknown()) +export const zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsPath = z.object({ + app_id: z.string(), + run_id: z.string(), +}) + +/** + * Success + */ +export const zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsResponse = z.record( + z.string(), + z.unknown(), +) + +export const zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsEventsPath = z.object({ + app_id: z.string(), + run_id: z.string(), +}) + +/** + * Success + */ +export const 
zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsEventsResponse = z.record( + z.string(), + z.unknown(), +) + +export const zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdPath = z.object({ + app_id: z.string(), + node_id: z.string(), + run_id: z.string(), +}) + +/** + * Success + */ +export const zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdResponse = z.record( + z.string(), + z.unknown(), +) + +export const zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewPath + = z.object({ + app_id: z.string(), + node_id: z.string(), + output_name: z.string(), + run_id: z.string(), + }) + +/** + * Success + */ +export const zGetAppsByAppIdWorkflowsPublishedRunsByRunIdNodeOutputsByNodeIdByOutputNamePreviewResponse + = z.record(z.string(), z.unknown()) + export const zGetAppsByAppIdWorkflowsTriggersWebhookPath = z.object({ app_id: z.string(), }) diff --git a/packages/contracts/generated/api/console/datasets/types.gen.ts b/packages/contracts/generated/api/console/datasets/types.gen.ts index 6c96c80d4a..938331f9c9 100644 --- a/packages/contracts/generated/api/console/datasets/types.gen.ts +++ b/packages/contracts/generated/api/console/datasets/types.gen.ts @@ -396,8 +396,8 @@ export type HitTestingPayload = { } export type HitTestingResponse = { - query: string - records?: Array + query: HitTestingQuery + records: Array } export type DocumentStatusListResponse = { @@ -666,13 +666,17 @@ export type DocumentStatusResponse = { total_segments?: number | null } +export type HitTestingQuery = { + content: string +} + export type HitTestingRecord = { - child_chunks?: Array - files?: Array - score?: number | null - segment?: HitTestingSegment - summary?: string | null - tsne_position?: unknown + child_chunks: Array + files: Array + score: number | null + segment: HitTestingSegment + summary: string | null + tsne_position: unknown } export type DatasetMetadataListItemResponse = { @@ -768,45 +772,45 @@ export type 
MetadataDetail = { } export type HitTestingChildChunk = { - content?: string | null - id?: string | null - position?: number | null - score?: number | null + content: string + id: string + position: number + score: number } export type HitTestingFile = { - extension?: string | null - id?: string | null - mime_type?: string | null - name?: string | null - size?: number | null - source_url?: string | null + extension: string + id: string + mime_type: string + name: string + size: number + source_url: string } export type HitTestingSegment = { - answer?: string | null - completed_at?: number | null - content?: string | null - created_at?: number | null - created_by?: string | null - disabled_at?: number | null - disabled_by?: string | null - document?: HitTestingDocument - document_id?: string | null - enabled?: boolean | null - error?: string | null - hit_count?: number | null - id?: string | null - index_node_hash?: string | null - index_node_id?: string | null - indexing_at?: number | null - keywords?: Array - position?: number | null - sign_content?: string | null - status?: string | null - stopped_at?: number | null - tokens?: number | null - word_count?: number | null + answer: string | null + completed_at: number | null + content: string + created_at: number + created_by: string + disabled_at: number | null + disabled_by: string | null + document: HitTestingDocument + document_id: string + enabled: boolean + error: string | null + hit_count: number + id: string + index_node_hash: string | null + index_node_id: string | null + indexing_at: number | null + keywords: Array + position: number + sign_content: string | null + status: string + stopped_at: number | null + tokens: number + word_count: number } export type DatasetQueryContentResponse = { @@ -898,11 +902,11 @@ export type WeightVectorSetting = { } export type HitTestingDocument = { - data_source_type?: string | null - doc_metadata?: unknown - doc_type?: string | null - id?: string | null - name?: string | 
null + data_source_type: string + doc_metadata: unknown + doc_type: string | null + id: string + name: string } export type DatasetQueryFileInfoResponse = { diff --git a/packages/contracts/generated/api/console/datasets/zod.gen.ts b/packages/contracts/generated/api/console/datasets/zod.gen.ts index f795d17f2f..082695d39d 100644 --- a/packages/contracts/generated/api/console/datasets/zod.gen.ts +++ b/packages/contracts/generated/api/console/datasets/zod.gen.ts @@ -530,6 +530,13 @@ export const zDocumentStatusListResponse = z.object({ data: z.array(zDocumentStatusResponse), }) +/** + * HitTestingQuery + */ +export const zHitTestingQuery = z.object({ + content: z.string(), +}) + /** * DatasetMetadataListItemResponse */ @@ -632,22 +639,22 @@ export const zMetadataOperationData = z.object({ * HitTestingChildChunk */ export const zHitTestingChildChunk = z.object({ - content: z.string().nullish(), - id: z.string().nullish(), - position: z.int().nullish(), - score: z.number().nullish(), + content: z.string(), + id: z.string(), + position: z.int(), + score: z.number(), }) /** * HitTestingFile */ export const zHitTestingFile = z.object({ - extension: z.string().nullish(), - id: z.string().nullish(), - mime_type: z.string().nullish(), - name: z.string().nullish(), - size: z.int().nullish(), - source_url: z.string().nullish(), + extension: z.string(), + id: z.string(), + mime_type: z.string(), + name: z.string(), + size: z.int(), + source_url: z.string(), }) /** @@ -1036,60 +1043,60 @@ export const zHitTestingPayload = z.object({ * HitTestingDocument */ export const zHitTestingDocument = z.object({ - data_source_type: z.string().nullish(), - doc_metadata: z.unknown().optional(), - doc_type: z.string().nullish(), - id: z.string().nullish(), - name: z.string().nullish(), + data_source_type: z.string(), + doc_metadata: z.unknown(), + doc_type: z.string().nullable(), + id: z.string(), + name: z.string(), }) /** * HitTestingSegment */ export const zHitTestingSegment = z.object({ - 
answer: z.string().nullish(), - completed_at: z.int().nullish(), - content: z.string().nullish(), - created_at: z.int().nullish(), - created_by: z.string().nullish(), - disabled_at: z.int().nullish(), - disabled_by: z.string().nullish(), - document: zHitTestingDocument.optional(), - document_id: z.string().nullish(), - enabled: z.boolean().nullish(), - error: z.string().nullish(), - hit_count: z.int().nullish(), - id: z.string().nullish(), - index_node_hash: z.string().nullish(), - index_node_id: z.string().nullish(), - indexing_at: z.int().nullish(), - keywords: z.array(z.string()).optional(), - position: z.int().nullish(), - sign_content: z.string().nullish(), - status: z.string().nullish(), - stopped_at: z.int().nullish(), - tokens: z.int().nullish(), - word_count: z.int().nullish(), + answer: z.string().nullable(), + completed_at: z.int().nullable(), + content: z.string(), + created_at: z.int(), + created_by: z.string(), + disabled_at: z.int().nullable(), + disabled_by: z.string().nullable(), + document: zHitTestingDocument, + document_id: z.string(), + enabled: z.boolean(), + error: z.string().nullable(), + hit_count: z.int(), + id: z.string(), + index_node_hash: z.string().nullable(), + index_node_id: z.string().nullable(), + indexing_at: z.int().nullable(), + keywords: z.array(z.string()), + position: z.int(), + sign_content: z.string().nullable(), + status: z.string(), + stopped_at: z.int().nullable(), + tokens: z.int(), + word_count: z.int(), }) /** * HitTestingRecord */ export const zHitTestingRecord = z.object({ - child_chunks: z.array(zHitTestingChildChunk).optional(), - files: z.array(zHitTestingFile).optional(), - score: z.number().nullish(), - segment: zHitTestingSegment.optional(), - summary: z.string().nullish(), - tsne_position: z.unknown().optional(), + child_chunks: z.array(zHitTestingChildChunk), + files: z.array(zHitTestingFile), + score: z.number().nullable(), + segment: zHitTestingSegment, + summary: z.string().nullable(), + tsne_position: 
z.unknown(), }) /** * HitTestingResponse */ export const zHitTestingResponse = z.object({ - query: z.string(), - records: z.array(zHitTestingRecord).optional(), + query: zHitTestingQuery, + records: z.array(zHitTestingRecord), }) /** diff --git a/packages/contracts/generated/api/openapi/orpc.gen.ts b/packages/contracts/generated/api/openapi/orpc.gen.ts index c445ab877b..c837f6258e 100644 --- a/packages/contracts/generated/api/openapi/orpc.gen.ts +++ b/packages/contracts/generated/api/openapi/orpc.gen.ts @@ -7,6 +7,8 @@ import { zDeleteAccountSessionsBySessionIdPath, zDeleteAccountSessionsBySessionIdResponse, zDeleteAccountSessionsSelfResponse, + zDeleteWorkspacesByWorkspaceIdMembersByMemberIdPath, + zDeleteWorkspacesByWorkspaceIdMembersByMemberIdResponse, zGetAccountResponse, zGetAccountSessionsResponse, zGetAppsByAppIdDescribePath, @@ -23,6 +25,9 @@ import { zGetOauthDeviceLookupResponse, zGetPermittedExternalAppsResponse, zGetVersionResponse, + zGetWorkspacesByWorkspaceIdMembersPath, + zGetWorkspacesByWorkspaceIdMembersQuery, + zGetWorkspacesByWorkspaceIdMembersResponse, zGetWorkspacesByWorkspaceIdPath, zGetWorkspacesByWorkspaceIdResponse, zGetWorkspacesResponse, @@ -44,6 +49,14 @@ import { zPostOauthDeviceDenyResponse, zPostOauthDeviceTokenBody, zPostOauthDeviceTokenResponse, + zPostWorkspacesByWorkspaceIdMembersBody, + zPostWorkspacesByWorkspaceIdMembersPath, + zPostWorkspacesByWorkspaceIdMembersResponse, + zPostWorkspacesByWorkspaceIdSwitchPath, + zPostWorkspacesByWorkspaceIdSwitchResponse, + zPutWorkspacesByWorkspaceIdMembersByMemberIdRoleBody, + zPutWorkspacesByWorkspaceIdMembersByMemberIdRolePath, + zPutWorkspacesByWorkspaceIdMembersByMemberIdRoleResponse, } from './zod.gen' /** @@ -461,7 +474,97 @@ export const permittedExternalApps = { get: get10, } +export const put = oc + .route({ + inputStructure: 'detailed', + method: 'PUT', + operationId: 'putWorkspacesByWorkspaceIdMembersByMemberIdRole', + path: '/workspaces/{workspace_id}/members/{member_id}/role', 
+ tags: ['openapi'], + }) + .input( + z.object({ + body: zPutWorkspacesByWorkspaceIdMembersByMemberIdRoleBody, + params: zPutWorkspacesByWorkspaceIdMembersByMemberIdRolePath, + }), + ) + .output(zPutWorkspacesByWorkspaceIdMembersByMemberIdRoleResponse) + +export const role = { + put, +} + +export const delete3 = oc + .route({ + inputStructure: 'detailed', + method: 'DELETE', + operationId: 'deleteWorkspacesByWorkspaceIdMembersByMemberId', + path: '/workspaces/{workspace_id}/members/{member_id}', + tags: ['openapi'], + }) + .input(z.object({ params: zDeleteWorkspacesByWorkspaceIdMembersByMemberIdPath })) + .output(zDeleteWorkspacesByWorkspaceIdMembersByMemberIdResponse) + +export const byMemberId = { + delete: delete3, + role, +} + export const get11 = oc + .route({ + inputStructure: 'detailed', + method: 'GET', + operationId: 'getWorkspacesByWorkspaceIdMembers', + path: '/workspaces/{workspace_id}/members', + tags: ['openapi'], + }) + .input( + z.object({ + params: zGetWorkspacesByWorkspaceIdMembersPath, + query: zGetWorkspacesByWorkspaceIdMembersQuery.optional(), + }), + ) + .output(zGetWorkspacesByWorkspaceIdMembersResponse) + +export const post9 = oc + .route({ + inputStructure: 'detailed', + method: 'POST', + operationId: 'postWorkspacesByWorkspaceIdMembers', + path: '/workspaces/{workspace_id}/members', + successStatus: 201, + tags: ['openapi'], + }) + .input( + z.object({ + body: zPostWorkspacesByWorkspaceIdMembersBody, + params: zPostWorkspacesByWorkspaceIdMembersPath, + }), + ) + .output(zPostWorkspacesByWorkspaceIdMembersResponse) + +export const members = { + get: get11, + post: post9, + byMemberId, +} + +export const post10 = oc + .route({ + inputStructure: 'detailed', + method: 'POST', + operationId: 'postWorkspacesByWorkspaceIdSwitch', + path: '/workspaces/{workspace_id}/switch', + tags: ['openapi'], + }) + .input(z.object({ params: zPostWorkspacesByWorkspaceIdSwitchPath })) + .output(zPostWorkspacesByWorkspaceIdSwitchResponse) + +export const switch_ 
= { + post: post10, +} + +export const get12 = oc .route({ inputStructure: 'detailed', method: 'GET', @@ -473,10 +576,12 @@ export const get11 = oc .output(zGetWorkspacesByWorkspaceIdResponse) export const byWorkspaceId = { - get: get11, + get: get12, + members, + switch: switch_, } -export const get12 = oc +export const get13 = oc .route({ inputStructure: 'detailed', method: 'GET', @@ -487,7 +592,7 @@ export const get12 = oc .output(zGetWorkspacesResponse) export const workspaces = { - get: get12, + get: get13, byWorkspaceId, } diff --git a/packages/contracts/generated/api/openapi/types.gen.ts b/packages/contracts/generated/api/openapi/types.gen.ts index b0ce8f427b..194ec6f363 100644 --- a/packages/contracts/generated/api/openapi/types.gen.ts +++ b/packages/contracts/generated/api/openapi/types.gen.ts @@ -169,6 +169,50 @@ export type HumanInputFormSubmitPayload = { export type JsonValue = unknown +export type MemberActionResponse = { + result?: string +} + +export type MemberInvitePayload = { + email: string + role: 'admin' | 'normal' +} + +export type MemberInviteResponse = { + email: string + invite_url: string + member_id: string + result?: string + role: string + tenant_id: string +} + +export type MemberListQuery = { + limit?: number + page?: number +} + +export type MemberListResponse = { + data: Array + has_more: boolean + limit: number + page: number + total: number +} + +export type MemberResponse = { + avatar?: string | null + email: string + id: string + name: string + role: string + status: string +} + +export type MemberRoleUpdatePayload = { + role: 'admin' | 'normal' +} + export type MessageMetadata = { retriever_resources?: Array<{ [key: string]: unknown @@ -638,3 +682,88 @@ export type GetWorkspacesByWorkspaceIdResponses = { export type GetWorkspacesByWorkspaceIdResponse = GetWorkspacesByWorkspaceIdResponses[keyof GetWorkspacesByWorkspaceIdResponses] + +export type GetWorkspacesByWorkspaceIdMembersData = { + body?: never + path: { + workspace_id: 
string + } + query?: { + limit?: number + page?: number + } + url: '/workspaces/{workspace_id}/members' +} + +export type GetWorkspacesByWorkspaceIdMembersResponses = { + 200: MemberListResponse +} + +export type GetWorkspacesByWorkspaceIdMembersResponse + = GetWorkspacesByWorkspaceIdMembersResponses[keyof GetWorkspacesByWorkspaceIdMembersResponses] + +export type PostWorkspacesByWorkspaceIdMembersData = { + body: MemberInvitePayload + path: { + workspace_id: string + } + query?: never + url: '/workspaces/{workspace_id}/members' +} + +export type PostWorkspacesByWorkspaceIdMembersResponses = { + 201: MemberInviteResponse +} + +export type PostWorkspacesByWorkspaceIdMembersResponse + = PostWorkspacesByWorkspaceIdMembersResponses[keyof PostWorkspacesByWorkspaceIdMembersResponses] + +export type DeleteWorkspacesByWorkspaceIdMembersByMemberIdData = { + body?: never + path: { + member_id: string + workspace_id: string + } + query?: never + url: '/workspaces/{workspace_id}/members/{member_id}' +} + +export type DeleteWorkspacesByWorkspaceIdMembersByMemberIdResponses = { + 200: MemberActionResponse +} + +export type DeleteWorkspacesByWorkspaceIdMembersByMemberIdResponse + = DeleteWorkspacesByWorkspaceIdMembersByMemberIdResponses[keyof DeleteWorkspacesByWorkspaceIdMembersByMemberIdResponses] + +export type PutWorkspacesByWorkspaceIdMembersByMemberIdRoleData = { + body: MemberRoleUpdatePayload + path: { + member_id: string + workspace_id: string + } + query?: never + url: '/workspaces/{workspace_id}/members/{member_id}/role' +} + +export type PutWorkspacesByWorkspaceIdMembersByMemberIdRoleResponses = { + 200: MemberActionResponse +} + +export type PutWorkspacesByWorkspaceIdMembersByMemberIdRoleResponse + = PutWorkspacesByWorkspaceIdMembersByMemberIdRoleResponses[keyof PutWorkspacesByWorkspaceIdMembersByMemberIdRoleResponses] + +export type PostWorkspacesByWorkspaceIdSwitchData = { + body?: never + path: { + workspace_id: string + } + query?: never + url: 
'/workspaces/{workspace_id}/switch' +} + +export type PostWorkspacesByWorkspaceIdSwitchResponses = { + 200: WorkspaceDetailResponse +} + +export type PostWorkspacesByWorkspaceIdSwitchResponse + = PostWorkspacesByWorkspaceIdSwitchResponses[keyof PostWorkspacesByWorkspaceIdSwitchResponses] diff --git a/packages/contracts/generated/api/openapi/zod.gen.ts b/packages/contracts/generated/api/openapi/zod.gen.ts index 6f76b2a6d7..a98f0e3a86 100644 --- a/packages/contracts/generated/api/openapi/zod.gen.ts +++ b/packages/contracts/generated/api/openapi/zod.gen.ts @@ -150,6 +150,73 @@ export const zHumanInputFormSubmitPayload = z.object({ inputs: z.record(z.string(), zJsonValue), }) +/** + * MemberActionResponse + */ +export const zMemberActionResponse = z.object({ + result: z.string().optional().default('success'), +}) + +/** + * MemberInvitePayload + */ +export const zMemberInvitePayload = z.object({ + email: z.string(), + role: z.enum(['admin', 'normal']), +}) + +/** + * MemberInviteResponse + */ +export const zMemberInviteResponse = z.object({ + email: z.string(), + invite_url: z.string(), + member_id: z.string(), + result: z.string().optional().default('success'), + role: z.string(), + tenant_id: z.string(), +}) + +/** + * MemberListQuery + * + * Strict (extra='forbid'). 
+ */ +export const zMemberListQuery = z.object({ + limit: z.int().gte(1).lte(200).optional().default(20), + page: z.int().gte(1).optional().default(1), +}) + +/** + * MemberResponse + */ +export const zMemberResponse = z.object({ + avatar: z.string().nullish(), + email: z.string(), + id: z.string(), + name: z.string(), + role: z.string(), + status: z.string(), +}) + +/** + * MemberListResponse + */ +export const zMemberListResponse = z.object({ + data: z.array(zMemberResponse), + has_more: z.boolean(), + limit: z.int(), + page: z.int(), + total: z.int(), +}) + +/** + * MemberRoleUpdatePayload + */ +export const zMemberRoleUpdatePayload = z.object({ + role: z.enum(['admin', 'normal']), +}) + /** * PermittedExternalAppsListQuery * @@ -546,3 +613,59 @@ export const zGetWorkspacesByWorkspaceIdPath = z.object({ * Workspace detail */ export const zGetWorkspacesByWorkspaceIdResponse = zWorkspaceDetailResponse + +export const zGetWorkspacesByWorkspaceIdMembersPath = z.object({ + workspace_id: z.string(), +}) + +export const zGetWorkspacesByWorkspaceIdMembersQuery = z.object({ + limit: z.int().gte(1).lte(200).optional().default(20), + page: z.int().gte(1).optional().default(1), +}) + +/** + * Member list + */ +export const zGetWorkspacesByWorkspaceIdMembersResponse = zMemberListResponse + +export const zPostWorkspacesByWorkspaceIdMembersBody = zMemberInvitePayload + +export const zPostWorkspacesByWorkspaceIdMembersPath = z.object({ + workspace_id: z.string(), +}) + +/** + * Member invited + */ +export const zPostWorkspacesByWorkspaceIdMembersResponse = zMemberInviteResponse + +export const zDeleteWorkspacesByWorkspaceIdMembersByMemberIdPath = z.object({ + member_id: z.string(), + workspace_id: z.string(), +}) + +/** + * Member removed + */ +export const zDeleteWorkspacesByWorkspaceIdMembersByMemberIdResponse = zMemberActionResponse + +export const zPutWorkspacesByWorkspaceIdMembersByMemberIdRoleBody = zMemberRoleUpdatePayload + +export const 
zPutWorkspacesByWorkspaceIdMembersByMemberIdRolePath = z.object({ + member_id: z.string(), + workspace_id: z.string(), +}) + +/** + * Role updated + */ +export const zPutWorkspacesByWorkspaceIdMembersByMemberIdRoleResponse = zMemberActionResponse + +export const zPostWorkspacesByWorkspaceIdSwitchPath = z.object({ + workspace_id: z.string(), +}) + +/** + * Workspace detail + */ +export const zPostWorkspacesByWorkspaceIdSwitchResponse = zWorkspaceDetailResponse diff --git a/packages/contracts/generated/api/service/types.gen.ts b/packages/contracts/generated/api/service/types.gen.ts index 54ce811a95..cd84f94d81 100644 --- a/packages/contracts/generated/api/service/types.gen.ts +++ b/packages/contracts/generated/api/service/types.gen.ts @@ -469,6 +469,30 @@ export type FileResponse = { user_id?: string | null } +export type HitTestingChildChunk = { + content: string + id: string + position: number + score: number +} + +export type HitTestingDocument = { + data_source_type: string + doc_metadata: unknown + doc_type: string | null + id: string + name: string +} + +export type HitTestingFile = { + extension: string + id: string + mime_type: string + name: string + size: number + source_url: string +} + export type HitTestingPayload = { attachment_ids?: Array | null external_retrieval_model?: { @@ -478,6 +502,50 @@ export type HitTestingPayload = { retrieval_model?: RetrievalModel } +export type HitTestingQuery = { + content: string +} + +export type HitTestingRecord = { + child_chunks: Array + files: Array + score: number | null + segment: HitTestingSegment + summary: string | null + tsne_position: unknown +} + +export type HitTestingResponse = { + query: HitTestingQuery + records: Array +} + +export type HitTestingSegment = { + answer: string | null + completed_at: number | null + content: string + created_at: number + created_by: string + disabled_at: number | null + disabled_by: string | null + document: HitTestingDocument + document_id: string + enabled: boolean + 
error: string | null + hit_count: number + id: string + index_node_hash: string | null + index_node_id: string | null + indexing_at: number | null + keywords: Array + position: number + sign_content: string | null + status: string + stopped_at: number | null + tokens: number + word_count: number +} + export type HumanInputFormSubmitPayload = { action: string inputs: { @@ -2510,9 +2578,7 @@ export type PostDatasetsByDatasetIdHitTestingError = PostDatasetsByDatasetIdHitTestingErrors[keyof PostDatasetsByDatasetIdHitTestingErrors] export type PostDatasetsByDatasetIdHitTestingResponses = { - 200: { - [key: string]: unknown - } + 200: HitTestingResponse } export type PostDatasetsByDatasetIdHitTestingResponse @@ -2794,9 +2860,7 @@ export type PostDatasetsByDatasetIdRetrieveError = PostDatasetsByDatasetIdRetrieveErrors[keyof PostDatasetsByDatasetIdRetrieveErrors] export type PostDatasetsByDatasetIdRetrieveResponses = { - 200: { - [key: string]: unknown - } + 200: HitTestingResponse } export type PostDatasetsByDatasetIdRetrieveResponse diff --git a/packages/contracts/generated/api/service/zod.gen.ts b/packages/contracts/generated/api/service/zod.gen.ts index 22e4b24721..e3008ddfbf 100644 --- a/packages/contracts/generated/api/service/zod.gen.ts +++ b/packages/contracts/generated/api/service/zod.gen.ts @@ -553,6 +553,95 @@ export const zFileResponse = z.object({ user_id: z.string().nullish(), }) +/** + * HitTestingChildChunk + */ +export const zHitTestingChildChunk = z.object({ + content: z.string(), + id: z.string(), + position: z.int(), + score: z.number(), +}) + +/** + * HitTestingDocument + */ +export const zHitTestingDocument = z.object({ + data_source_type: z.string(), + doc_metadata: z.unknown(), + doc_type: z.string().nullable(), + id: z.string(), + name: z.string(), +}) + +/** + * HitTestingFile + */ +export const zHitTestingFile = z.object({ + extension: z.string(), + id: z.string(), + mime_type: z.string(), + name: z.string(), + size: z.int(), + source_url: 
z.string(), +}) + +/** + * HitTestingQuery + */ +export const zHitTestingQuery = z.object({ + content: z.string(), +}) + +/** + * HitTestingSegment + */ +export const zHitTestingSegment = z.object({ + answer: z.string().nullable(), + completed_at: z.int().nullable(), + content: z.string(), + created_at: z.int(), + created_by: z.string(), + disabled_at: z.int().nullable(), + disabled_by: z.string().nullable(), + document: zHitTestingDocument, + document_id: z.string(), + enabled: z.boolean(), + error: z.string().nullable(), + hit_count: z.int(), + id: z.string(), + index_node_hash: z.string().nullable(), + index_node_id: z.string().nullable(), + indexing_at: z.int().nullable(), + keywords: z.array(z.string()), + position: z.int(), + sign_content: z.string().nullable(), + status: z.string(), + stopped_at: z.int().nullable(), + tokens: z.int(), + word_count: z.int(), +}) + +/** + * HitTestingRecord + */ +export const zHitTestingRecord = z.object({ + child_chunks: z.array(zHitTestingChildChunk), + files: z.array(zHitTestingFile), + score: z.number().nullable(), + segment: zHitTestingSegment, + summary: z.string().nullable(), + tsne_position: z.unknown(), +}) + +/** + * HitTestingResponse + */ +export const zHitTestingResponse = z.object({ + query: zHitTestingQuery, + records: z.array(zHitTestingRecord), +}) + /** * IndexInfoResponse */ @@ -1720,7 +1809,7 @@ export const zPostDatasetsByDatasetIdHitTestingPath = z.object({ /** * Hit testing results */ -export const zPostDatasetsByDatasetIdHitTestingResponse = z.record(z.string(), z.unknown()) +export const zPostDatasetsByDatasetIdHitTestingResponse = zHitTestingResponse export const zGetDatasetsByDatasetIdMetadataPath = z.object({ dataset_id: z.string(), @@ -1834,7 +1923,7 @@ export const zPostDatasetsByDatasetIdRetrievePath = z.object({ /** * Hit testing results */ -export const zPostDatasetsByDatasetIdRetrieveResponse = z.record(z.string(), z.unknown()) +export const zPostDatasetsByDatasetIdRetrieveResponse = 
zHitTestingResponse export const zGetDatasetsByDatasetIdTagsPath = z.object({ dataset_id: z.string(), diff --git a/packages/dify-ui/README.md b/packages/dify-ui/README.md index dbf6dbac17..157f6eb752 100644 --- a/packages/dify-ui/README.md +++ b/packages/dify-ui/README.md @@ -32,8 +32,10 @@ import { Dialog, DialogContent, DialogTrigger } from '@langgenius/dify-ui/dialog import { Drawer, DrawerPopup, DrawerTrigger } from '@langgenius/dify-ui/drawer' import { FieldControl, FieldLabel, FieldRoot } from '@langgenius/dify-ui/field' import { Form } from '@langgenius/dify-ui/form' +import { Kbd, KbdGroup } from '@langgenius/dify-ui/kbd' import { Popover, PopoverContent, PopoverTrigger } from '@langgenius/dify-ui/popover' import { SegmentedControl, SegmentedControlItem } from '@langgenius/dify-ui/segmented-control' +import { Textarea } from '@langgenius/dify-ui/textarea' import '@langgenius/dify-ui/styles.css' // once, in the app root ``` @@ -41,17 +43,18 @@ Importing from `@langgenius/dify-ui` (no subpath) is intentionally not supported ## Primitives -| Category | Subpath | Notes | -| ---------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------ | -| Actions | `./button` | Design-system CTA primitive with `cva` variants. | -| Controls | `./segmented-control` | SegmentedControl for mode, filter, and view selection. | -| Feedback | `./meter`, `./toast` | Meter is inline status; Toast owns the `z-60` layer. | -| Form | `./form`, `./field`, `./fieldset`, `./input`, `./checkbox`, `./checkbox-group`, `./radio`, `./radio-group`, `./number-field`, `./select`, `./slider`, `./switch` | Native form boundary, field semantics, and controls. | -| Layout | `./scroll-area` | Custom-styled scrollbar over the host viewport. | -| Media | `./avatar` | Avatar root, image, and fallback primitives. 
| -| Navigation | `./pagination`, `./tabs` | Pagination for page navigation; Tabs for panels. | -| Overlay / menu | `./alert-dialog`, `./context-menu`, `./dialog`, `./drawer`, `./dropdown-menu`, `./popover`, `./preview-card`, `./tooltip` | Portalled. See [Overlay & portal contract] below. | -| Search / pickers | `./autocomplete`, `./combobox`, `./select` | Search input, searchable picker, and closed picker. | +| Category | Subpath | Notes | +| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------ | +| Actions | `./button` | Design-system CTA primitive with `cva` variants. | +| Controls | `./segmented-control` | SegmentedControl for mode, filter, and view selection. | +| Display | `./kbd` | Keyboard input and shortcut keycap primitives. | +| Feedback | `./meter`, `./toast` | Meter is inline status; Toast owns the `z-60` layer. | +| Form | `./form`, `./field`, `./fieldset`, `./input`, `./textarea`, `./checkbox`, `./checkbox-group`, `./radio`, `./radio-group`, `./number-field`, `./select`, `./slider`, `./switch` | Native form boundary, field semantics, and controls. | +| Layout | `./scroll-area` | Custom-styled scrollbar over the host viewport. | +| Media | `./avatar` | Avatar root, image, and fallback primitives. | +| Navigation | `./pagination`, `./tabs` | Pagination for page navigation; Tabs for panels. | +| Overlay / menu | `./alert-dialog`, `./context-menu`, `./dialog`, `./drawer`, `./dropdown-menu`, `./popover`, `./preview-card`, `./tooltip` | Portalled. See [Overlay & portal contract] below. | +| Search / pickers | `./autocomplete`, `./combobox`, `./select` | Search input, searchable picker, and closed picker. | Utilities: @@ -72,7 +75,7 @@ Use `Form` for the submit boundary. It renders a native `
`, preserves Ente Use `FieldRoot` for each standalone named field. A field must have a stable `name`, a label relationship, and either a `FieldControl` or another control that participates in the same Base UI field context. Prefer a visible label for normal form rows; when the surrounding UI already supplies the visible text, use the matching label primitive visually hidden or put `aria-label` on the actual interactive control. `FieldDescription` and `FieldError` provide the message relationships that screen readers need, while the Dify wrapper adds the default Form Input Set styling from the design system. -Choose the label primitive by the control semantics. Text-like inputs, input-based `Combobox` / `Autocomplete`, single `Checkbox` / `Radio`, `Switch`, and `NumberField` use `FieldLabel`. Trigger-based `Select` fields use `SelectLabel`; `Slider` fields use `SliderLabel`, with per-thumb `aria-label` only when the thumbs need distinct names. `SelectGroupLabel` and `AutocompleteGroupLabel` only label grouped options inside their popup content; they are not field labels. +Choose the label primitive by the control semantics. Text-like inputs, `Textarea`, input-based `Combobox` / `Autocomplete`, single `Checkbox` / `Radio`, `Switch`, and `NumberField` use `FieldLabel`. Trigger-based `Select` fields use `SelectLabel`; `Slider` fields use `SliderLabel`, with per-thumb `aria-label` only when the thumbs need distinct names. `SelectGroupLabel` and `AutocompleteGroupLabel` only label grouped options inside their popup content; they are not field labels. Use `FieldsetRoot` and `FieldsetLegend` when one field is represented by a group of related controls, such as checkbox groups, radio groups, multi-thumb sliders, or a section that combines several inputs. 
For checkbox and radio groups, wrap each option with `FieldItem` and give each option its own label: diff --git a/packages/dify-ui/package.json b/packages/dify-ui/package.json index f85210e8f6..3f4a7d7999 100644 --- a/packages/dify-ui/package.json +++ b/packages/dify-ui/package.json @@ -69,6 +69,10 @@ "types": "./src/input/index.tsx", "import": "./src/input/index.tsx" }, + "./kbd": { + "types": "./src/kbd/index.tsx", + "import": "./src/kbd/index.tsx" + }, "./meter": { "types": "./src/meter/index.tsx", "import": "./src/meter/index.tsx" @@ -129,6 +133,10 @@ "types": "./src/tabs/index.tsx", "import": "./src/tabs/index.tsx" }, + "./textarea": { + "types": "./src/textarea/index.tsx", + "import": "./src/textarea/index.tsx" + }, "./toast": { "types": "./src/toast/index.tsx", "import": "./src/toast/index.tsx" @@ -167,6 +175,7 @@ "@storybook/addon-themes": "catalog:", "@storybook/react-vite": "catalog:", "@tailwindcss/vite": "catalog:", + "@tanstack/react-hotkeys": "catalog:", "@tanstack/react-virtual": "catalog:", "@types/react": "catalog:", "@types/react-dom": "catalog:", diff --git a/packages/dify-ui/src/autocomplete/__tests__/index.spec.tsx b/packages/dify-ui/src/autocomplete/__tests__/index.spec.tsx index 72ed042033..d0cab69401 100644 --- a/packages/dify-ui/src/autocomplete/__tests__/index.spec.tsx +++ b/packages/dify-ui/src/autocomplete/__tests__/index.spec.tsx @@ -218,6 +218,8 @@ describe('Autocomplete wrappers', () => { await expect.element(screen.getByText('Workflow')).toHaveClass('system-sm-medium') await expect.element(screen.getByTestId('status')).toHaveClass('text-text-tertiary') await expect.element(screen.getByTestId('empty')).toHaveClass('system-sm-regular') + await expect.element(screen.getByTestId('empty')).toHaveClass('empty:p-0') + expect(screen.getByTestId('empty').element().getBoundingClientRect().height).toBe(0) expect(screen.getByText('Workflow').element().parentElement?.querySelector('.i-ri-arrow-right-line')).toHaveAttribute('aria-hidden', 'true') 
}) @@ -248,5 +250,34 @@ describe('Autocomplete wrappers', () => { await expect.element(screen.getByText('Workflow')).toHaveClass('custom-text') await expect.element(screen.getByTestId('indicator')).toHaveClass('custom-indicator') }) + + it('should navigate function-rendered items with arrow keys', async () => { + const screen = await renderWithSafeViewport( + + + + + + + {(item: string) => ( + + {item} + + )} + + + , + ) + + const input = asHTMLElement(screen.getByRole('combobox', { name: 'Search resources' }).element()) + + input.focus() + input.dispatchEvent(new KeyboardEvent('keydown', { key: 'ArrowDown', bubbles: true, cancelable: true })) + await expect.element(screen.getByRole('option', { name: 'workflow' })).toHaveAttribute('data-highlighted') + + input.dispatchEvent(new KeyboardEvent('keydown', { key: 'ArrowDown', bubbles: true, cancelable: true })) + + await expect.element(screen.getByRole('option', { name: 'dataset' })).toHaveAttribute('data-highlighted') + }) }) }) diff --git a/packages/dify-ui/src/autocomplete/index.stories.tsx b/packages/dify-ui/src/autocomplete/index.stories.tsx index 79f8983cd2..5b98b39f1f 100644 --- a/packages/dify-ui/src/autocomplete/index.stories.tsx +++ b/packages/dify-ui/src/autocomplete/index.stories.tsx @@ -2,7 +2,7 @@ import type { Meta, StoryObj } from '@storybook/react-vite' import type { Virtualizer } from '@tanstack/react-virtual' import type { RefObject } from 'react' import { useVirtualizer } from '@tanstack/react-virtual' -import { useEffect, useMemo, useRef, useState } from 'react' +import { useEffect, useMemo, useRef, useState, useTransition } from 'react' import { Autocomplete, AutocompleteClear, @@ -23,6 +23,7 @@ import { useAutocompleteFilteredItems, } from '.' 
import { cn } from '../cn' +import { Kbd } from '../kbd' type Suggestion = { value: string @@ -159,13 +160,60 @@ const virtualizedSuggestions: Suggestion[] = Array.from({ length: 1000 }, (_, in const getSuggestionLabel = (item: Suggestion) => item.label +async function searchSuggestions( + suggestions: Suggestion[], + query: string, + filter: (item: string, query: string) => boolean, +): Promise<{ items: Suggestion[], error: string | null }> { + await new Promise(resolve => window.setTimeout(resolve, 500)) + + if (query === 'will_error') { + return { + items: [], + error: 'Failed to load suggestions. Please try again.', + } + } + + return { + items: suggestions.filter(item => ( + filter(item.label, query) + || (item.description ? filter(item.description, query) : false) + )), + error: null, + } +} + const SuggestionItem = ({ + item, + dense, +}: { + item: Suggestion + dense?: boolean +}) => ( + + {item.icon && +) + +// Only virtualized items receive an explicit index; ordinary lists must let Base UI register items by DOM order for keyboard navigation. +const VirtualizedSuggestionItem = ({ item, index, dense, }: { item: Suggestion - index?: number + index: number dense?: boolean }) => ( @@ -186,12 +234,10 @@ const SuggestionItem = ({ const TagSuggestionItem = ({ item, - index, }: { item: Suggestion - index?: number }) => ( - + {item.label} {item.description && {item.description}} @@ -205,6 +251,7 @@ const BasicTagAutocomplete = ({ @@ -215,8 +262,8 @@ const BasicTagAutocomplete = ({ - {(item: Suggestion, index: number) => ( - + {(item: Suggestion) => ( + )} No tag suggestion. Keep the typed value. 
@@ -263,9 +310,9 @@ const CommandPaletteList = () => { {item.description} - + Enter - + )} @@ -289,32 +336,64 @@ const LimitedStatus = ({ } const AsyncSearchDemo = () => { - const [value, setValue] = useState('agent') - const [loading, setLoading] = useState(false) - const [items, setItems] = useState(remoteSuggestions) + const [searchValue, setSearchValue] = useState('') + const [searchResults, setSearchResults] = useState([]) + const [error, setError] = useState(null) + const [isPending, startTransition] = useTransition() + const { contains } = useAutocompleteFilter() + const abortControllerRef = useRef(null) - useEffect(() => { - setLoading(true) - const timeout = window.setTimeout(() => { - setItems( - value.trim() - ? remoteSuggestions.filter(item => item.label.toLowerCase().includes(value.trim().toLowerCase())) - : remoteSuggestions, - ) - setLoading(false) - }, 500) + const status = (() => { + if (isPending) + return 'Searching remote suggestions…' - return () => window.clearTimeout(timeout) - }, [value]) + if (error) + return error + + if (searchValue === '') + return null + + if (searchResults.length === 0) + return `No remote suggestion matches "${searchValue}".` + + return `${searchResults.length} remote suggestion${searchResults.length === 1 ? '' : 's'} found` + })() return (
{ + setSearchValue(nextSearchValue) + + const controller = new AbortController() + abortControllerRef.current?.abort() + abortControllerRef.current = controller + + if (nextSearchValue === '') { + setSearchResults([]) + setError(null) + return + } + + startTransition(async () => { + setError(null) + + const result = await searchSuggestions(remoteSuggestions, nextSearchValue, contains) + + if (controller.signal.aborted) + return + + startTransition(() => { + setSearchResults(result.items) + setError(result.error) + }) + }) + }} itemToStringValue={getSuggestionLabel} - openOnInputClick + filter={null} + mode="list" > - + - {loading ? 'Loading suggestions…' : `${items.length} remote suggestions`} + {status} - {(item: Suggestion, index: number) => ( - + {(item: Suggestion) => ( + )} - No remote suggestion. Keep the typed query.
@@ -384,7 +462,7 @@ const VirtualizedSuggestionList = ({ transform: `translateY(${virtualItem.start}px)`, }} > - + ) })} @@ -445,6 +523,7 @@ const FuzzyMatchingDemo = () => { onValueChange={setValue} filter={contains} itemToStringValue={getSuggestionLabel} + mode="list" openOnInputClick > @@ -455,8 +534,8 @@ const FuzzyMatchingDemo = () => { - {(item: Suggestion, index: number) => ( - + {(item: Suggestion) => ( + {item.icon &&