diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md index 280fcb6341..69c099a262 100644 --- a/.agents/skills/frontend-testing/SKILL.md +++ b/.agents/skills/frontend-testing/SKILL.md @@ -204,6 +204,16 @@ When assigned to test a directory/path, test **ALL content** within that path: > See [Test Structure Template](#test-structure-template) for correct import/mock patterns. +### `nuqs` Query State Testing (Required for URL State Hooks) + +When a component or hook uses `useQueryState` / `useQueryStates`: + +- ✅ Use `NuqsTestingAdapter` (prefer shared helpers in `web/test/nuqs-testing.tsx`) +- ✅ Assert URL synchronization via `onUrlUpdate` (`searchParams`, `options.history`) +- ✅ For custom parsers (`createParser`), keep `parse` and `serialize` bijective and add round-trip edge cases (`%2F`, `%25`, spaces, legacy encoded values) +- ✅ Verify default-clearing behavior (default values should be removed from URL when applicable) +- ⚠️ Only mock `nuqs` directly when URL behavior is explicitly out of scope for the test + ## Core Principles ### 1. AAA Pattern (Arrange-Act-Assert) diff --git a/.agents/skills/frontend-testing/references/checklist.md b/.agents/skills/frontend-testing/references/checklist.md index 1ff2b27bbb..10b8fb66f9 100644 --- a/.agents/skills/frontend-testing/references/checklist.md +++ b/.agents/skills/frontend-testing/references/checklist.md @@ -80,6 +80,9 @@ Use this checklist when generating or reviewing tests for Dify frontend componen - [ ] Router mocks match actual Next.js API - [ ] Mocks reflect actual component conditional behavior - [ ] Only mock: API services, complex context providers, third-party libs +- [ ] For `nuqs` URL-state tests, wrap with `NuqsTestingAdapter` (prefer `web/test/nuqs-testing.tsx`) +- [ ] For `nuqs` URL-state tests, assert `onUrlUpdate` payload (`searchParams`, `options.history`) +- [ ] If custom `nuqs` parser exists, add round-trip tests for encoded edge cases (`%2F`, `%25`, spaces, legacy encoded values) ### Queries diff --git a/.agents/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md index 86bd375987..f58377c4a5 100644 --- a/.agents/skills/frontend-testing/references/mocking.md +++ b/.agents/skills/frontend-testing/references/mocking.md @@ -125,6 +125,31 @@ describe('Component', () => { }) ``` +### 2.1 `nuqs` Query State (Preferred: Testing Adapter) + +For tests that validate URL query behavior, use `NuqsTestingAdapter` instead of mocking `nuqs` directly. + +```typescript +import { renderHookWithNuqs } from '@/test/nuqs-testing' + +it('should sync query to URL with push history', async () => { + const { result, onUrlUpdate } = renderHookWithNuqs(() => useMyQueryState(), { + searchParams: '?page=1', + }) + + act(() => { + result.current.setQuery({ page: 2 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.options.history).toBe('push') + expect(update.searchParams.get('page')).toBe('2') +}) +``` + +Use direct `vi.mock('nuqs')` only when URL synchronization is intentionally out of scope. + ### 3. Portal Components (with Shared State) ```typescript diff --git a/.agents/skills/orpc-contract-first/SKILL.md b/.agents/skills/orpc-contract-first/SKILL.md index 4e3bfc7a37..b5cd62dfb5 100644 --- a/.agents/skills/orpc-contract-first/SKILL.md +++ b/.agents/skills/orpc-contract-first/SKILL.md @@ -1,43 +1,100 @@ --- name: orpc-contract-first -description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Triggers when creating new API contracts, adding service endpoints, integrating TanStack Query with typed contracts, or migrating legacy service calls to oRPC. Use for all API layer work in web/contract and web/service directories. +description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Trigger when creating or updating contracts in web/contract, wiring router composition, integrating TanStack Query with typed contracts, migrating legacy service calls to oRPC, or deciding whether to call queryOptions directly vs extracting a helper or use-* hook in web/service. --- # oRPC Contract-First Development -## Project Structure +## Intent -``` +- Keep contract as single source of truth in `web/contract/*`. +- Default query usage: call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract. +- Keep abstractions minimal and preserve TypeScript inference. + +## Minimal Structure + +```text web/contract/ -├── base.ts # Base contract (inputStructure: 'detailed') -├── router.ts # Router composition & type exports -├── marketplace.ts # Marketplace contracts -└── console/ # Console contracts by domain - ├── system.ts - └── billing.ts +├── base.ts +├── router.ts +├── marketplace.ts +└── console/ + ├── billing.ts + └── ...other domains +web/service/client.ts ``` -## Workflow +## Core Workflow -1. **Create contract** in `web/contract/console/{domain}.ts` - - Import `base` from `../base` and `type` from `@orpc/contract` - - Define route with `path`, `method`, `input`, `output` +1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts` + - Use `base.route({...}).output(type<...>())` as baseline. + - Add `.input(type<...>())` only when request has `params/query/body`. + - For `GET` without input, omit `.input(...)` (do not use `.input(type())`). +2. Register contract in `web/contract/router.ts` + - Import directly from domain files and nest by API prefix. +3. Consume from UI call sites via oRPC query utils. -2. **Register in router** at `web/contract/router.ts` - - Import directly from domain file (no barrel files) - - Nest by API prefix: `billing: { invoices, bindPartnerStack }` +```typescript +import { useQuery } from '@tanstack/react-query' +import { consoleQuery } from '@/service/client' -3. **Create hooks** in `web/service/use-{domain}.ts` - - Use `consoleQuery.{group}.{contract}.queryKey()` for query keys - - Use `consoleClient.{group}.{contract}()` for API calls +const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({ + staleTime: 5 * 60 * 1000, + throwOnError: true, + select: invoice => invoice.url, +})) +``` -## Key Rules +## Query Usage Decision Rule + +1. Default: call site directly uses `*.queryOptions(...)`. +2. If 3+ call sites share the same extra options (for example `retry: false`), extract a small queryOptions helper, not a `use-*` passthrough hook. +3. Create `web/service/use-{domain}.ts` only for orchestration: + - Combine multiple queries/mutations. + - Share domain-level derived state or invalidation helpers. + +```typescript +const invoicesBaseQueryOptions = () => + consoleQuery.billing.invoices.queryOptions({ retry: false }) + +const invoiceQuery = useQuery({ + ...invoicesBaseQueryOptions(), + throwOnError: true, +}) +``` + +## Mutation Usage Decision Rule + +1. Default: call mutation helpers from `consoleQuery` / `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`. +2. If mutation flow is heavily custom, use oRPC clients as `mutationFn` (for example `consoleClient.xxx` / `marketplaceClient.xxx`), instead of generic handwritten non-oRPC mutation logic. + +## Key API Guide (`.key` vs `.queryKey` vs `.mutationKey`) + +- `.key(...)`: + - Use for partial matching operations (recommended for invalidation/refetch/cancel patterns). + - Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })` +- `.queryKey(...)`: + - Use for a specific query's full key (exact query identity / direct cache addressing). +- `.mutationKey(...)`: + - Use for a specific mutation's full key. + - Typical use cases: mutation defaults registration, mutation-status filtering (`useIsMutating`, `queryClient.isMutating`), or explicit devtools grouping. + +## Anti-Patterns + +- Do not wrap `useQuery` with `options?: Partial`. +- Do not split local `queryKey/queryFn` when oRPC `queryOptions` already exists and fits the use case. +- Do not create thin `use-*` passthrough hooks for a single endpoint. +- Reason: these patterns can degrade inference (`data` may become `unknown`, especially around `throwOnError`/`select`) and add unnecessary indirection. + +## Contract Rules - **Input structure**: Always use `{ params, query?, body? }` format +- **No-input GET**: Omit `.input(...)`; do not use `.input(type())` - **Path params**: Use `{paramName}` in path, match in `params` object -- **Router nesting**: Group by API prefix (e.g., `/billing/*` → `billing: {}`) +- **Router nesting**: Group by API prefix (e.g., `/billing/*` -> `billing: {}`) - **No barrel files**: Import directly from specific files - **Types**: Import from `@/types/`, use `type()` helper +- **Mutations**: Prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults/filtering/devtools ## Type Export diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index bfb1c85436..1bb7d06232 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -36,7 +36,7 @@ /api/core/workflow/graph/ @laipz8200 @QuantumGhost /api/core/workflow/graph_events/ @laipz8200 @QuantumGhost /api/core/workflow/node_events/ @laipz8200 @QuantumGhost -/api/core/model_runtime/ @laipz8200 @QuantumGhost +/api/dify_graph/model_runtime/ @laipz8200 @QuantumGhost # Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM) /api/core/workflow/nodes/agent/ @Nov1c444 diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 1a57bb0050..78f6eefd0d 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,25 +1,37 @@ version: 2 -multi-ecosystem-groups: - python: - schedule: - interval: "weekly" # or whatever schedule you want - updates: - package-ecosystem: "pip" directory: "/api" open-pull-requests-limit: 2 - patterns: ["*"] schedule: interval: "weekly" + groups: + python-dependencies: + patterns: + - "*" - package-ecosystem: "uv" directory: "/api" open-pull-requests-limit: 2 - patterns: ["*"] schedule: interval: "weekly" + groups: + uv-dependencies: + patterns: + - "*" - package-ecosystem: "npm" directory: "/web" schedule: interval: "weekly" open-pull-requests-limit: 2 + groups: + storybook: + patterns: + - "storybook" + - "@storybook/*" + npm-dependencies: + patterns: + - "*" + exclude-patterns: + - "storybook" + - "@storybook/*" diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index cbd6edf94b..eb13c3d096 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -89,7 +89,7 @@ jobs: uses: actions/setup-node@v6 if: steps.changed-files.outputs.any_changed == 'true' with: - node-version: 24 + node-version: 22 cache: pnpm cache-dependency-path: ./web/pnpm-lock.yaml diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml index ec392cb3b2..d9a1168636 100644 --- a/.github/workflows/tool-test-sdks.yaml +++ b/.github/workflows/tool-test-sdks.yaml @@ -28,7 +28,7 @@ jobs: - name: Use Node.js uses: actions/setup-node@v6 with: - node-version: 24 + node-version: 22 cache: '' cache-dependency-path: 'pnpm-lock.yaml' diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index 5d9440ff35..b431c36a8b 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -57,7 +57,7 @@ jobs: - name: Set up Node.js uses: actions/setup-node@v6 with: - node-version: 24 + node-version: 22 cache: pnpm cache-dependency-path: ./web/pnpm-lock.yaml diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index f50689636b..659620b2a9 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -39,7 +39,7 @@ jobs: - name: Setup Node.js uses: actions/setup-node@v6 with: - node-version: 24 + node-version: 22 cache: pnpm cache-dependency-path: ./web/pnpm-lock.yaml @@ -83,7 +83,7 @@ jobs: - name: Setup Node.js uses: actions/setup-node@v6 with: - node-version: 24 + node-version: 22 cache: pnpm cache-dependency-path: ./web/pnpm-lock.yaml @@ -457,7 +457,7 @@ jobs: uses: actions/setup-node@v6 if: steps.changed-files.outputs.any_changed == 'true' with: - node-version: 24 + node-version: 22 cache: pnpm cache-dependency-path: ./web/pnpm-lock.yaml diff --git a/.gitignore b/.gitignore index dce9f66d2e..a621324775 100644 --- a/.gitignore +++ b/.gitignore @@ -224,6 +224,7 @@ mise.toml # AI Assistant .sisyphus/ .roo/ +/.claude/worktrees/ api/.env.backup /clickzetta diff --git a/AGENTS.md b/AGENTS.md index 51fa6e4527..d25d2eed96 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -29,7 +29,7 @@ The codebase is split into: ## Language Style -- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). +- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). Prefer `TypedDict` over `dict` or `Mapping` for type safety and better code documentation. - **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types. ## General Practices diff --git a/api/.env.example b/api/.env.example index 2e155ce2d8..9ee733831b 100644 --- a/api/.env.example +++ b/api/.env.example @@ -45,6 +45,8 @@ REFRESH_TOKEN_EXPIRE_DAYS=30 # redis configuration REDIS_HOST=localhost REDIS_PORT=6379 +# Optional: limit total connections in connection pool (unset for default) +# REDIS_MAX_CONNECTIONS=200 REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false diff --git a/api/.importlinter b/api/.importlinter index 28c853bb62..e4536b1f10 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -1,6 +1,7 @@ [importlinter] root_packages = core + dify_graph configs controllers extensions @@ -21,48 +22,39 @@ layers = runtime entities containers = - core.workflow + dify_graph ignore_imports = - core.workflow.nodes.base.node -> core.workflow.graph_events - core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events - core.workflow.nodes.loop.loop_node -> core.workflow.graph_events + dify_graph.nodes.base.node -> dify_graph.graph_events + dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_events + dify_graph.nodes.loop.loop_node -> dify_graph.graph_events - core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory - core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory - core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota - core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota - - core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine - core.workflow.nodes.iteration.iteration_node -> core.workflow.graph - core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels - core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine - core.workflow.nodes.loop.loop_node -> core.workflow.graph - core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels + dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine + dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine # TODO(QuantumGhost): fix the import violation later - core.workflow.entities.pause_reason -> core.workflow.nodes.human_input.entities + dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities [importlinter:contract:workflow-infrastructure-dependencies] name = Workflow Infrastructure Dependencies type = forbidden source_modules = - core.workflow + dify_graph forbidden_modules = extensions.ext_database extensions.ext_redis allow_indirect_imports = True ignore_imports = - core.workflow.nodes.agent.agent_node -> extensions.ext_database - core.workflow.nodes.llm.file_saver -> extensions.ext_database - core.workflow.nodes.llm.node -> extensions.ext_database - core.workflow.nodes.tool.tool_node -> extensions.ext_database - # TODO(QuantumGhost): use DI to avoid depending on global DB. - core.workflow.nodes.human_input.human_input_node -> extensions.ext_database + dify_graph.nodes.agent.agent_node -> extensions.ext_database + dify_graph.nodes.llm.file_saver -> extensions.ext_database + dify_graph.nodes.llm.node -> extensions.ext_database + dify_graph.nodes.tool.tool_node -> extensions.ext_database + dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis + dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis [importlinter:contract:workflow-external-imports] name = Workflow External Imports type = forbidden source_modules = - core.workflow + dify_graph forbidden_modules = configs controllers @@ -100,129 +92,63 @@ forbidden_modules = core.trigger core.variables ignore_imports = - core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory - core.workflow.workflow_entry -> core.app.workflow.layers.observability - core.workflow.nodes.agent.agent_node -> core.model_manager - core.workflow.nodes.agent.agent_node -> core.provider_manager - core.workflow.nodes.agent.agent_node -> core.tools.tool_manager - core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy - core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory - core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota - core.workflow.nodes.llm.llm_utils -> core.model_manager - core.workflow.nodes.llm.protocols -> core.model_manager - core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model - core.workflow.nodes.llm.node -> core.tools.signature - core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler - core.workflow.nodes.tool.tool_node -> core.tools.tool_engine - core.workflow.nodes.tool.tool_node -> core.tools.tool_manager - core.workflow.workflow_entry -> configs - core.workflow.workflow_entry -> models.workflow - core.workflow.nodes.agent.agent_node -> core.agent.entities - core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities - core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities - core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities - core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities - core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform - core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform - core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model - core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform - core.workflow.workflow_entry -> core.app.apps.exc - core.workflow.workflow_entry -> core.app.entities.app_invoke_entities - core.workflow.workflow_entry -> core.app.workflow.layers.llm_quota - core.workflow.workflow_entry -> core.app.workflow.node_factory - core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager - core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager - core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer - core.workflow.nodes.tool.tool_node -> models - core.workflow.nodes.agent.agent_node -> models.model - core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy - core.workflow.nodes.llm.node -> core.helper.code_executor - core.workflow.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor - core.workflow.nodes.llm.node -> core.llm_generator.output_parser.errors - core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output - core.workflow.nodes.llm.node -> core.model_manager - core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities - core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities - core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities - core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util - core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities - core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities - core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util - core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities - core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util - core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods - core.workflow.nodes.llm.node -> models.dataset - core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer - core.workflow.nodes.llm.file_saver -> core.tools.signature - core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager - core.workflow.nodes.tool.tool_node -> core.tools.errors - core.workflow.nodes.agent.agent_node -> extensions.ext_database - core.workflow.nodes.llm.file_saver -> extensions.ext_database - core.workflow.nodes.llm.node -> extensions.ext_database - core.workflow.nodes.tool.tool_node -> extensions.ext_database - core.workflow.nodes.human_input.human_input_node -> extensions.ext_database - core.workflow.nodes.human_input.human_input_node -> core.repositories.human_input_repository - core.workflow.workflow_entry -> extensions.otel.runtime - core.workflow.nodes.agent.agent_node -> models - core.workflow.nodes.base.node -> models.enums - core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota - core.workflow.nodes.llm.node -> models.model - core.workflow.workflow_entry -> models.enums - core.workflow.nodes.agent.agent_node -> services - core.workflow.nodes.tool.tool_node -> services - -[importlinter:contract:model-runtime-no-internal-imports] -name = Model Runtime Internal Imports -type = forbidden -source_modules = - core.model_runtime -forbidden_modules = - configs - controllers - extensions - models - services - tasks - core.agent - core.app - core.base - core.callback_handler - core.datasource - core.db - core.entities - core.errors - core.extension - core.external_data_tool - core.file - core.helper - core.hosting_configuration - core.indexing_runner - core.llm_generator - core.logging - core.mcp - core.memory - core.model_manager - core.moderation - core.ops - core.plugin - core.prompt - core.provider_manager - core.rag - core.repositories - core.schemas - core.tools - core.trigger - core.variables - core.workflow -ignore_imports = - core.model_runtime.model_providers.__base.ai_model -> configs - core.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis - core.model_runtime.model_providers.__base.large_language_model -> configs - core.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type - core.model_runtime.model_providers.model_provider_factory -> configs - core.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis - core.model_runtime.model_providers.model_provider_factory -> models.provider_ids + dify_graph.nodes.agent.agent_node -> core.model_manager + dify_graph.nodes.agent.agent_node -> core.provider_manager + dify_graph.nodes.agent.agent_node -> core.tools.tool_manager + dify_graph.nodes.llm.llm_utils -> core.model_manager + dify_graph.nodes.llm.protocols -> core.model_manager + dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model + dify_graph.nodes.llm.node -> core.tools.signature + dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler + dify_graph.nodes.tool.tool_node -> core.tools.tool_engine + dify_graph.nodes.tool.tool_node -> core.tools.tool_manager + dify_graph.nodes.agent.agent_node -> core.agent.entities + dify_graph.nodes.agent.agent_node -> core.agent.plugin_entities + dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities + dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform + dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform + dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model + dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform + dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager + dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager + dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer + dify_graph.nodes.tool.tool_node -> models + dify_graph.nodes.agent.agent_node -> models.model + dify_graph.nodes.llm.file_saver -> core.helper.ssrf_proxy + dify_graph.nodes.llm.node -> core.helper.code_executor + dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors + dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output + dify_graph.nodes.llm.node -> core.model_manager + dify_graph.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities + dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities + dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities + dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util + dify_graph.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities + dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities + dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util + dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities + dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util + dify_graph.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods + dify_graph.nodes.llm.node -> models.dataset + dify_graph.nodes.agent.agent_node -> core.tools.utils.message_transformer + dify_graph.nodes.llm.file_saver -> core.tools.signature + dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager + dify_graph.nodes.tool.tool_node -> core.tools.errors + dify_graph.nodes.agent.agent_node -> extensions.ext_database + dify_graph.nodes.llm.file_saver -> extensions.ext_database + dify_graph.nodes.llm.node -> extensions.ext_database + dify_graph.nodes.tool.tool_node -> extensions.ext_database + dify_graph.nodes.agent.agent_node -> models + dify_graph.nodes.llm.node -> models.model + dify_graph.nodes.agent.agent_node -> services + dify_graph.nodes.tool.tool_node -> services + dify_graph.model_runtime.model_providers.__base.ai_model -> configs + dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis + dify_graph.model_runtime.model_providers.__base.large_language_model -> configs + dify_graph.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type + dify_graph.model_runtime.model_providers.model_provider_factory -> configs + dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis + dify_graph.model_runtime.model_providers.model_provider_factory -> models.provider_ids [importlinter:contract:rsc] name = RSC @@ -231,7 +157,7 @@ layers = graph_engine response_coordinator containers = - core.workflow.graph_engine + dify_graph.graph_engine [importlinter:contract:worker] name = Worker @@ -240,7 +166,7 @@ layers = graph_engine worker containers = - core.workflow.graph_engine + dify_graph.graph_engine [importlinter:contract:graph-engine-architecture] name = Graph Engine Architecture @@ -256,28 +182,28 @@ layers = worker_management domain containers = - core.workflow.graph_engine + dify_graph.graph_engine [importlinter:contract:domain-isolation] name = Domain Model Isolation type = forbidden source_modules = - core.workflow.graph_engine.domain + dify_graph.graph_engine.domain forbidden_modules = - core.workflow.graph_engine.worker_management - core.workflow.graph_engine.command_channels - core.workflow.graph_engine.layers - core.workflow.graph_engine.protocols + dify_graph.graph_engine.worker_management + dify_graph.graph_engine.command_channels + dify_graph.graph_engine.layers + dify_graph.graph_engine.protocols [importlinter:contract:worker-management] name = Worker Management type = forbidden source_modules = - core.workflow.graph_engine.worker_management + dify_graph.graph_engine.worker_management forbidden_modules = - core.workflow.graph_engine.orchestration - core.workflow.graph_engine.command_processing - core.workflow.graph_engine.event_management + dify_graph.graph_engine.orchestration + dify_graph.graph_engine.command_processing + dify_graph.graph_engine.event_management [importlinter:contract:graph-traversal-components] @@ -287,11 +213,11 @@ layers = edge_processor skip_propagator containers = - core.workflow.graph_engine.graph_traversal + dify_graph.graph_engine.graph_traversal [importlinter:contract:command-channels] name = Command Channels Independence type = independence modules = - core.workflow.graph_engine.command_channels.in_memory_channel - core.workflow.graph_engine.command_channels.redis_channel + dify_graph.graph_engine.command_channels.in_memory_channel + dify_graph.graph_engine.command_channels.redis_channel diff --git a/api/.ruff.toml b/api/.ruff.toml index 3301452ad9..b0947eb619 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -100,7 +100,7 @@ ignore = [ "configs/*" = [ "N802", # invalid-function-name ] -"core/model_runtime/callbacks/base_callback.py" = ["T201"] +"dify_graph/model_runtime/callbacks/base_callback.py" = ["T201"] "core/workflow/callbacks/workflow_logging_callback.py" = ["T201"] "libs/gmpy2_pkcs10aep_cipher.py" = [ "N803", # invalid-argument-name diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 4705b28c69..367cb52731 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -111,3 +111,8 @@ class RedisConfig(BaseSettings): description="Enable client side cache in redis", default=False, ) + + REDIS_MAX_CONNECTIONS: PositiveInt | None = Field( + description="Maximum connections in the Redis connection pool (unset for library default)", + default=None, + ) diff --git a/api/configs/middleware/cache/redis_pubsub_config.py b/api/configs/middleware/cache/redis_pubsub_config.py index a72e1dd28f..8cddc5677a 100644 --- a/api/configs/middleware/cache/redis_pubsub_config.py +++ b/api/configs/middleware/cache/redis_pubsub_config.py @@ -1,7 +1,7 @@ from typing import Literal, Protocol from urllib.parse import quote_plus, urlunparse -from pydantic import Field +from pydantic import AliasChoices, Field from pydantic_settings import BaseSettings @@ -23,41 +23,56 @@ class RedisConfigDefaultsMixin: class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin): """ - Configuration settings for Redis pub/sub streaming. + Configuration settings for event transport between API and workers. + + Supported transports: + - pubsub: Redis PUBLISH/SUBSCRIBE (at-most-once) + - sharded: Redis 7+ Sharded Pub/Sub (at-most-once, better scaling) + - streams: Redis Streams (at-least-once, supports late subscribers) """ PUBSUB_REDIS_URL: str | None = Field( - alias="PUBSUB_REDIS_URL", + validation_alias=AliasChoices("EVENT_BUS_REDIS_URL", "PUBSUB_REDIS_URL"), description=( - "Redis connection URL for pub/sub streaming events between API " - "and celery worker, defaults to url constructed from " - "`REDIS_*` configurations" + "Redis connection URL for streaming events between API and celery worker; " + "defaults to URL constructed from `REDIS_*` configurations. Also accepts ENV: EVENT_BUS_REDIS_URL." ), default=None, ) PUBSUB_REDIS_USE_CLUSTERS: bool = Field( + validation_alias=AliasChoices("EVENT_BUS_REDIS_CLUSTERS", "PUBSUB_REDIS_USE_CLUSTERS"), description=( - "Enable Redis Cluster mode for pub/sub streaming. It's highly " - "recommended to enable this for large deployments." + "Enable Redis Cluster mode for pub/sub or streams transport. Recommended for large deployments. " + "Also accepts ENV: EVENT_BUS_REDIS_CLUSTERS." ), default=False, ) - PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field( + PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded", "streams"] = Field( + validation_alias=AliasChoices("EVENT_BUS_REDIS_CHANNEL_TYPE", "PUBSUB_REDIS_CHANNEL_TYPE"), description=( - "Pub/sub channel type for streaming events. " - "Valid options are:\n" - "\n" - " - pubsub: for normal Pub/Sub\n" - " - sharded: for sharded Pub/Sub\n" - "\n" - "It's highly recommended to use sharded Pub/Sub AND redis cluster " - "for large deployments." + "Event transport type. Options are:\n\n" + " - pubsub: normal Pub/Sub (at-most-once)\n" + " - sharded: sharded Pub/Sub (at-most-once)\n" + " - streams: Redis Streams (at-least-once, recommended to avoid subscriber races)\n\n" + "Note: Before enabling 'streams' in production, estimate your expected event volume and retention needs.\n" + "Configure Redis memory limits and stream trimming appropriately (e.g., MAXLEN and key expiry) to reduce\n" + "the risk of data loss from Redis auto-eviction under memory pressure.\n" + "Also accepts ENV: EVENT_BUS_REDIS_CHANNEL_TYPE." ), default="pubsub", ) + PUBSUB_STREAMS_RETENTION_SECONDS: int = Field( + validation_alias=AliasChoices("EVENT_BUS_STREAMS_RETENTION_SECONDS", "PUBSUB_STREAMS_RETENTION_SECONDS"), + description=( + "When using 'streams', expire each stream key this many seconds after the last event is published. " + "Also accepts ENV: EVENT_BUS_STREAMS_RETENTION_SECONDS." + ), + default=600, + ) + def _build_default_pubsub_url(self) -> str: defaults = self._redis_defaults() if not defaults.REDIS_HOST or not defaults.REDIS_PORT: diff --git a/api/context/__init__.py b/api/context/__init__.py index aebf9750ce..969e5f583d 100644 --- a/api/context/__init__.py +++ b/api/context/__init__.py @@ -12,7 +12,7 @@ or any other web framework. import contextvars from collections.abc import Callable -from core.workflow.context.execution_context import ( +from dify_graph.context.execution_context import ( ExecutionContext, IExecutionContext, NullAppContext, diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py index 2d465c8cf4..324a9ee8b4 100644 --- a/api/context/flask_app_context.py +++ b/api/context/flask_app_context.py @@ -10,8 +10,8 @@ from typing import Any, final from flask import Flask, current_app, g -from core.workflow.context import register_context_capturer -from core.workflow.context.execution_context import ( +from dify_graph.context import register_context_capturer +from dify_graph.context.execution_context import ( AppContext, IExecutionContext, ) diff --git a/api/controllers/cli_api/dify_cli/cli_api.py b/api/controllers/cli_api/dify_cli/cli_api.py index ef1fdab996..90a0b88cb4 100644 --- a/api/controllers/cli_api/dify_cli/cli_api.py +++ b/api/controllers/cli_api/dify_cli/cli_api.py @@ -1,3 +1,4 @@ +from core.workflow.file.helpers import get_signed_file_url_for_plugin from flask import abort from flask_restx import Resource from pydantic import BaseModel @@ -22,7 +23,6 @@ from core.session.cli_api import CliContext from core.skill.entities import ToolInvocationRequest from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager -from core.workflow.file.helpers import get_signed_file_url_for_plugin from libs.helper import length_prefixed_response from models.account import Account from models.model import EndUser, Tenant diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index 9b30db8b75..ff5326dade 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -4,7 +4,7 @@ from typing import Any, TypeAlias from pydantic import BaseModel, ConfigDict, computed_field -from core.workflow.file import helpers as file_helpers +from dify_graph.file import helpers as file_helpers from models.model import IconType JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any] diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index a4f93e1016..a268c9e9ff 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -26,8 +26,8 @@ from controllers.console.wraps import ( ) from core.ops.ops_trace_manager import OpsTraceManager from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.workflow.enums import NodeType, WorkflowExecutionStatus -from core.workflow.file import helpers as file_helpers +from dify_graph.enums import NodeType, WorkflowExecutionStatus +from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from libs.login import current_account_with_tenant, login_required from models import App, DatasetPermissionEnum, Workflow diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 941db325bf..2c5e8d29ee 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -22,7 +22,7 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from libs.login import login_required from models import App, AppMode from services.audio_service import AudioService diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 2922121a54..4d7ddfea13 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -26,7 +26,7 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user, login_required diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index f4c58f510f..3e3139e4f8 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -24,7 +24,7 @@ from core.llm_generator.context_models import ( ) from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from libs.login import current_account_with_tenant, login_required from models import App diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index b0c85aecf2..25661dd1b7 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -24,7 +24,7 @@ from controllers.console.wraps import ( ) from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from fields.raws import FilesContainedField from libs.helper import TimestampField, uuid_value diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index ec99fd520a..e15b72f272 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -21,7 +21,6 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.trace_id_helper import get_external_trace_id -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginInvokeError from core.trigger.debug.event_selectors import ( TriggerDebugEvent, @@ -29,9 +28,10 @@ from core.trigger.debug.event_selectors import ( create_event_poller, select_trigger_debug_events, ) -from core.workflow.enums import NodeType -from core.workflow.file.models import File -from core.workflow.graph_engine.manager import GraphEngineManager +from dify_graph.enums import NodeType +from dify_graph.file.models import File +from dify_graph.graph_engine.manager import GraphEngineManager +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import file_factory, variable_factory diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 6736f24a2e..9b148c3f18 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import Session from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.enums import WorkflowExecutionStatus from extensions.ext_database import db from fields.workflow_app_log_fields import ( build_workflow_app_log_pagination_model, diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 7845800bba..2cc9b81754 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -15,11 +15,11 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.file import helpers as file_helpers -from core.workflow.variables.segment_group import SegmentGroup -from core.workflow.variables.segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment -from core.workflow.variables.types import SegmentType +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from dify_graph.file import helpers as file_helpers +from dify_graph.variables.segment_group import SegmentGroup +from dify_graph.variables.segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment +from dify_graph.variables.types import SegmentType from extensions.ext_database import db from factories import variable_factory from factories.file_factory import build_from_mapping, build_from_mappings diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 9ac45cf2da..7ac653395e 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -12,8 +12,8 @@ from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import NotFoundError -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import WorkflowExecutionStatus from extensions.ext_database import db from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 38ea5d2dae..6e59d4203c 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from werkzeug.exceptions import BadRequest, NotFound from controllers.console.wraps import account_initialization_required, setup_required -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models import Account from models.model import OAuthProviderApp diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index a06b872846..54303b2482 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -25,12 +25,12 @@ from controllers.console.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.indexing_runner import IndexingRunner -from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from core.rag.retrieval.retrieval_methods import RetrievalMethod +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from fields.app_fields import app_detail_kernel_fields, related_app_list from fields.dataset_fields import ( @@ -53,7 +53,7 @@ from fields.dataset_fields import ( from fields.document_fields import document_status_fields from libs.login import current_account_with_tenant, login_required from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile -from models.dataset import DatasetPermissionEnum +from models.dataset import DatasetPermission, DatasetPermissionEnum from models.provider_ids import ModelProviderID from services.api_token_service import ApiTokenCache from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService @@ -119,6 +119,14 @@ def _validate_indexing_technique(value: str | None) -> str | None: return value +def _validate_doc_form(value: str | None) -> str | None: + if value is None: + return value + if value not in Dataset.DOC_FORM_LIST: + raise ValueError("Invalid doc_form.") + return value + + class DatasetCreatePayload(BaseModel): name: str = Field(..., min_length=1, max_length=40) description: str = Field("", max_length=400) @@ -179,6 +187,14 @@ class IndexingEstimatePayload(BaseModel): raise ValueError("indexing_technique is required.") return result + @field_validator("doc_form") + @classmethod + def validate_doc_form(cls, value: str) -> str: + result = _validate_doc_form(value) + if result is None: + return "text_model" + return result + class ConsoleDatasetListQuery(BaseModel): page: int = Field(default=1, description="Page number") @@ -323,6 +339,18 @@ class DatasetListApi(Resource): model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields)) + dataset_ids = [item["id"] for item in data if item.get("permission") == "partial_members"] + partial_members_map: dict[str, list[str]] = {} + if dataset_ids: + permissions = db.session.execute( + select(DatasetPermission.dataset_id, DatasetPermission.account_id).where( + DatasetPermission.dataset_id.in_(dataset_ids) + ) + ).all() + + for dataset_id, account_id in permissions: + partial_members_map.setdefault(dataset_id, []).append(account_id) + for item in data: # convert embedding_model_provider to plugin standard format if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: @@ -336,8 +364,7 @@ class DatasetListApi(Resource): item["embedding_available"] = True if item.get("permission") == "partial_members": - part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"]) - item.update({"partial_member_list": part_users_list}) + item.update({"partial_member_list": partial_members_map.get(item["id"], [])}) else: item.update({"partial_member_list": []}) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index bf097d374a..ee726bc470 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -24,11 +24,11 @@ from core.errors.error import ( ) from core.indexing_runner import IndexingRunner from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.plugin.impl.exc import PluginDaemonClientSideError from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from fields.dataset_fields import dataset_fields from fields.document_fields import ( diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 23a668112d..3fd0f3b712 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -26,7 +26,7 @@ from controllers.console.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index db1a874437..99ff49d79d 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -19,7 +19,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from fields.hit_testing_fields import hit_testing_record_fields from libs.login import current_user from models.account import Account diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 1a47e226e5..a4498005d8 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -9,9 +9,9 @@ from configs import dify_config from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.oauth import OAuthHandler +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.provider_ids import DatasourceProviderID from services.datasource_provider_service import DatasourceProviderService diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index 7e285c8da9..4c441a5d07 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -21,8 +21,8 @@ from controllers.console.app.workflow_draft_variable import ( from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.variables.types import SegmentType +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from dify_graph.variables.types import SegmentType from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 29b6b64b94..51cdcc0c7a 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -33,7 +33,7 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from factories import variable_factory from libs import helper diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 0311db1584..ffb9e5bb6e 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -19,7 +19,7 @@ from controllers.console.app.error import ( ) from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index a6e5b2822a..fcd52d2818 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -24,7 +24,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from libs import helper from libs.datetime_utils import naive_utc_now diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 88487ac96f..53970dbd3b 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -21,7 +21,7 @@ from controllers.console.explore.error import ( from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse from libs import helper diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index f6f731df36..25bb8ed7fe 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -41,8 +41,8 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from core.model_runtime.errors.invoke import InvokeError -from core.workflow.graph_engine.manager import GraphEngineManager +from dify_graph.graph_engine.manager import GraphEngineManager +from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.app_fields import ( diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index b841bda323..7801cee473 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -21,8 +21,8 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from core.model_runtime.errors.invoke import InvokeError -from core.workflow.graph_engine.manager import GraphEngineManager +from dify_graph.graph_engine.manager import GraphEngineManager +from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_redis import redis_client from libs import helper from libs.login import current_account_with_tenant diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index f3738319df..49162d4dae 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -13,7 +13,7 @@ from controllers.common.errors import ( ) from controllers.console import console_ns from core.helper import ssrf_proxy -from core.workflow.file import helpers as file_helpers +from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from libs.login import current_account_with_tenant, login_required diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 9527fe782e..e2b504751b 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -2,7 +2,7 @@ from flask_restx import Resource, fields from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.agent_service import AgentService diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index 1897cbdca7..538c5fb561 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -7,8 +7,8 @@ from pydantic import BaseModel, Field from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginPermissionDeniedError +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index ccb60b1461..0a9e54de99 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -5,8 +5,8 @@ from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError from libs.login import current_account_with_tenant, login_required from models import TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 7bada2fa12..db3b02ae94 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -7,9 +7,9 @@ from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 583e3e3057..d7eceb656c 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -8,9 +8,9 @@ from pydantic import BaseModel, Field, field_validator from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index d1485bc1c0..2f06f72f29 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -12,8 +12,8 @@ from controllers.common.schema import register_enum_models, register_schema_mode from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginDaemonClientSideError +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService diff --git a/api/controllers/console/workspace/sandbox_providers.py b/api/controllers/console/workspace/sandbox_providers.py index 95b8d77dbf..ab63decb4a 100644 --- a/api/controllers/console/workspace/sandbox_providers.py +++ b/api/controllers/console/workspace/sandbox_providers.py @@ -1,12 +1,12 @@ import logging +from core.model_runtime.utils.encoders import jsonable_encoder from flask import request from flask_restx import Resource, fields from pydantic import BaseModel from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from core.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.sandbox.sandbox_provider_service import SandboxProviderService diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 5bfa895849..b38f05795a 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -23,10 +23,10 @@ from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError from core.mcp.mcp_client import MCPClient -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 6b642af613..ad78d2a623 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -10,11 +10,11 @@ from werkzeug.exceptions import BadRequest, Forbidden from configs import dify_config from controllers.common.schema import register_schema_models from controllers.web.error import NotFoundError -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.trigger.entities.entities import SubscriptionBuilderUpdater from core.trigger.trigger_manager import TriggerManager +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from libs.login import current_user, login_required from models.account import Account diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index b34412ef6d..52690a12e1 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -8,7 +8,7 @@ from werkzeug.exceptions import Forbidden import services from core.tools.tool_file_manager import ToolFileManager -from core.workflow.file.helpers import verify_plugin_file_signature +from dify_graph.file.helpers import verify_plugin_file_signature from fields.file_fields import FileResponse from ..common.errors import ( diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 61a1815013..838b622d6a 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -4,7 +4,6 @@ from controllers.console.wraps import setup_required from controllers.inner_api import inner_api_ns from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data from controllers.inner_api.wraps import plugin_inner_api_only -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse from core.plugin.backwards_invocation.encrypt import PluginEncrypter @@ -29,7 +28,8 @@ from core.plugin.entities.request import ( RequestRequestUploadFile, ) from core.tools.entities.tool_entities import ToolProviderType -from core.workflow.file.helpers import get_signed_file_url_for_plugin +from dify_graph.file.helpers import get_signed_file_url_for_plugin +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.helper import length_prefixed_response from models import Account, Tenant from models.model import EndUser diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 991a9166c7..2bc6640807 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -10,7 +10,7 @@ from controllers.console.app.mcp_server import AppMCPServerStatus from controllers.mcp import mcp_ns from core.mcp import types as mcp_types from core.mcp.server.streamable_http import handle_mcp_request -from core.workflow.variables.input_entities import VariableEntity +from dify_graph.variables.input_entities import VariableEntity from extensions.ext_database import db from libs import helper from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index ef254ca357..c22190cbc9 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -185,4 +185,4 @@ class AnnotationUpdateDeleteApi(Resource): def delete(self, app_model: App, annotation_id: str): """Delete an annotation.""" AppAnnotationService.delete_app_annotation(app_model.id, annotation_id) - return {"result": "success"}, 204 + return "", 204 diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index e383920460..38d292d0b9 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -21,7 +21,7 @@ from controllers.service_api.app.error import ( ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from models.model import App, EndUser from services.audio_service import AudioService from services.errors.audio import ( diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 9d8431f066..98f09c44a1 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -28,7 +28,7 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 8e29c9ff0f..edbf011656 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -14,7 +14,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import ( - ConversationDelete, ConversationInfiniteScrollPagination, SimpleConversation, ) @@ -163,7 +162,7 @@ class ConversationDetailApi(Resource): ConversationService.delete(app_model, conversation_id, end_user) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return ConversationDelete(result="success").model_dump(mode="json"), 204 + return "", 204 @service_api_ns.route("/conversations//name") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 2ce8f05f75..35dd22c801 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -27,9 +27,9 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from core.model_runtime.errors.invoke import InvokeError -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.graph_engine.manager import GraphEngineManager +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.graph_engine.manager import GraphEngineManager +from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model @@ -132,6 +132,8 @@ class WorkflowRunDetailApi(Resource): app_id=app_model.id, run_id=workflow_run_id, ) + if not workflow_run: + raise NotFound("Workflow run not found.") return workflow_run diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index c06b81b775..83d07087ab 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -14,8 +14,8 @@ from controllers.service_api.wraps import ( DatasetApiResource, cloud_edition_billing_rate_limit_check, ) -from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager +from dify_graph.model_runtime.entities.model_entities import ModelType from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import DataSetTag from libs.login import current_user diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 0aeb4a2d36..dc8da025d4 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -4,7 +4,7 @@ from uuid import UUID from flask import request from flask_restx import marshal -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy import desc, select from werkzeug.exceptions import Forbidden, NotFound @@ -60,6 +60,13 @@ class DocumentTextCreatePayload(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None + @field_validator("doc_form") + @classmethod + def validate_doc_form(cls, value: str) -> str: + if value not in Dataset.DOC_FORM_LIST: + raise ValueError("Invalid doc_form.") + return value + DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -72,6 +79,13 @@ class DocumentTextUpdate(BaseModel): doc_language: str = "English" retrieval_model: RetrievalModel | None = None + @field_validator("doc_form") + @classmethod + def validate_doc_form(cls, value: str) -> str: + if value not in Dataset.DOC_FORM_LIST: + raise ValueError("Invalid doc_form.") + return value + @model_validator(mode="after") def check_text_and_name(self) -> Self: if self.text is not None and self.name is None: diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 4eb4fed29a..2e3b7fd85e 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -17,7 +17,7 @@ from controllers.service_api.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields from libs.login import current_account_with_tenant diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index fffcb47bd4..35aed40a59 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -3,7 +3,7 @@ from flask_restx import Resource from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_dataset_token -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from services.model_provider_service import ModelProviderService diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 15828cc208..2b8f752668 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -20,7 +20,7 @@ from controllers.web.error import ( ) from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index a97d745471..8634c1f43c 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -25,7 +25,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from models.model import AppMode diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 80035ba818..bbae1ce266 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -20,7 +20,7 @@ from controllers.web.error import ( from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import ResultResponse from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem from libs import helper diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index 1cdae0fe56..6a93ef6748 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -11,7 +11,7 @@ from controllers.common.errors import ( UnsupportedFileTypeError, ) from core.helper import ssrf_proxy -from core.workflow.file import helpers as file_helpers +from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from services.file_service import FileService diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index a309ef3dad..508d1a756a 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -22,8 +22,8 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from core.model_runtime.errors.invoke import InvokeError -from core.workflow.graph_engine.manager import GraphEngineManager +from dify_graph.graph_engine.manager import GraphEngineManager +from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_redis import redis_client from libs import helper from models.model import App, AppMode, EndUser diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 8ab6900ded..9312217835 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -19,7 +19,15 @@ from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackH from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities import ( +from core.prompt.utils.extract_thread_messages import extract_thread_messages +from core.tools.__base.tool import Tool +from core.tools.entities.tool_entities import ( + ToolParameter, +) +from core.tools.tool_manager import ToolManager +from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool +from dify_graph.file import file_manager +from dify_graph.model_runtime.entities import ( AssistantPromptMessage, LLMUsage, PromptMessage, @@ -29,17 +37,9 @@ from core.model_runtime.entities import ( ToolPromptMessage, UserPromptMessage, ) -from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from core.model_runtime.entities.model_entities import ModelFeature -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.utils.extract_thread_messages import extract_thread_messages -from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ( - ToolParameter, -) -from core.tools.tool_manager import ToolManager -from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool -from core.workflow.file import file_manager +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from dify_graph.model_runtime.entities.model_entities import ModelFeature +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db from factories import file_factory from models.enums import CreatorUserRole diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 7c8f09e6b9..82676f1ebd 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -4,7 +4,7 @@ from collections.abc import Generator from typing import Union from core.agent.entities import AgentScratchpadUnit -from core.model_runtime.entities.llm_entities import LLMResultChunk +from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk class CotAgentOutputParser: diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index b816c8d7d0..558b6e69a0 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -4,10 +4,10 @@ from core.app.app_config.entities import EasyUIBasedAppConfig from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.provider_manager import ProviderManager +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel class ModelConfigConverter: diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index c391a279b5..e4e750c735 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -2,9 +2,9 @@ from collections.abc import Mapping from typing import Any from core.app.app_config.entities import ModelConfigEntity -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.provider_manager import ProviderManager +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider_ids import ModelProviderID diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 21614c010c..01b9601965 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -4,8 +4,8 @@ from core.app.app_config.entities import ( AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity, ) -from core.model_runtime.entities.message_entities import PromptMessageRole from core.prompt.simple_prompt_transform import ModelMode +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole from models.model import AppMode diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 22d602a190..157e5d8bc0 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -2,7 +2,7 @@ import re from core.app.app_config.entities import ExternalDataVariableEntity from core.external_data_tool.factory import ExternalDataToolFactory -from core.workflow.variables.input_entities import VariableEntity, VariableEntityType +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( [ diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 062cc6a0b3..f26351d93e 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -4,10 +4,10 @@ from typing import Any, Literal from pydantic import BaseModel, Field -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.file import FileUploadConfig -from core.workflow.variables.input_entities import VariableEntity as WorkflowVariableEntity +from dify_graph.file import FileUploadConfig +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.variables.input_entities import VariableEntity as WorkflowVariableEntity from models.model import AppMode diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index d69fa85801..0c4266fbeb 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -2,7 +2,7 @@ from collections.abc import Mapping from typing import Any from constants import DEFAULT_FILE_NUMBER_LIMITS -from core.workflow.file import FileUploadConfig +from dify_graph.file import FileUploadConfig class FileUploadConfigManager: diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index ec7d85a09f..d2a9a73380 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -1,7 +1,7 @@ import re from core.app.app_config.entities import RagPipelineVariableEntity -from core.workflow.variables.input_entities import VariableEntity +from dify_graph.variables.input_entities import VariableEntity from models.workflow import Workflow diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 65fc15e065..5e6c8e5ab4 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -32,19 +32,19 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.helper.trace_id_helper import extract_external_trace_id_from_args -from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.repositories import DifyCoreRepositoryFactory from core.sandbox import Sandbox -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.repositories.draft_variable_repository import ( +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from dify_graph.repositories.draft_variable_repository import ( DraftVariableSaverFactory, ) -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.runtime import GraphRuntimeState -from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.runtime import GraphRuntimeState +from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 35d7a6386e..3c416d22b3 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -26,16 +26,16 @@ from core.db.session_factory import session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration from core.sandbox import Sandbox -from core.workflow.enums import WorkflowType -from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variable_loader import VariableLoader -from core.workflow.variables.variables import Variable from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.enums import WorkflowType +from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variable_loader import VariableLoader +from dify_graph.variables.variables import Variable from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 6bf5b4bc29..8754644857 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -65,16 +65,16 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk -from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.nodes import NodeType -from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory -from core.workflow.runtime import GraphRuntimeState -from core.workflow.system_variable import SystemVariable +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.nodes import NodeType +from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory +from dify_graph.runtime import GraphRuntimeState +from dify_graph.system_variable import SystemVariable from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, LLMGenerationDetail, Message, MessageFile @@ -916,7 +916,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): def _load_human_input_form_id(self, *, node_id: str) -> str | None: form_repository = HumanInputFormRepositoryImpl( - session_factory=db.engine, tenant_id=self._workflow_tenant_id, ) form = form_repository.get_form(self._workflow_run_id, node_id) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 7bd3b8a56e..76a067d7b6 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -20,8 +20,8 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom -from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 699732d74f..521bba307d 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -12,9 +12,9 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.model_entities import ModelFeature -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.moderation.base import ModerationError +from dify_graph.model_runtime.entities.model_entities import ModelFeature +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db from models.model import App, Conversation, Message diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index d1e2f16b6f..77950a832a 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -6,7 +6,7 @@ from typing import Any, Union from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 81617c5fb2..20e6ac98ea 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -4,21 +4,21 @@ from typing import TYPE_CHECKING, Any, Union, final from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.enums import NodeType -from core.workflow.file import File, FileUploadConfig -from core.workflow.repositories.draft_variable_repository import ( +from dify_graph.enums import NodeType +from dify_graph.file import File, FileUploadConfig +from dify_graph.repositories.draft_variable_repository import ( DraftVariableSaver, DraftVariableSaverFactory, NoopDraftVariableSaver, ) -from core.workflow.variables.input_entities import VariableEntityType +from dify_graph.variables.input_entities import VariableEntityType from factories import file_factory from libs.orjson import orjson_dumps from models import Account, EndUser from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl if TYPE_CHECKING: - from core.workflow.variables.input_entities import VariableEntity + from dify_graph.variables.input_entities import VariableEntity class BaseAppGenerator: diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index af1f1d7c66..5addd41815 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -20,7 +20,7 @@ from core.app.entities.queue_entities import ( QueueStopEvent, WorkflowQueueMessage, ) -from core.workflow.runtime import GraphRuntimeState +from dify_graph.runtime import GraphRuntimeState from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index b98e85dbe9..88714f3837 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -24,27 +24,27 @@ from core.app.features.hosting_moderation.hosting_moderation import HostingModer from core.external_data_tool.external_data_fetch import ExternalDataFetch from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - TextPromptMessageContent, -) -from core.model_runtime.entities.model_entities import ModelPropertyKey -from core.model_runtime.errors.invoke import InvokeBadRequestError from core.moderation.input_moderation import InputModeration from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from core.tools.tool_file_manager import ToolFileManager -from core.workflow.file.enums import FileTransferMethod, FileType +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + TextPromptMessageContent, +) +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError from extensions.ext_database import db from models.enums import CreatorUserRole from models.model import App, AppMode, Message, MessageAnnotation, MessageFile if TYPE_CHECKING: - from core.workflow.file.models import File + from dify_graph.file.models import File _logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index c1251d2feb..91cf54c774 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -19,8 +19,8 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom -from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory from models import Account diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 4870a56281..23546a47bb 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -13,10 +13,10 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.workflow.file import File +from dify_graph.file import File +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db from models.model import App, Conversation, Message diff --git a/api/core/app/apps/common/graph_runtime_state_support.py b/api/core/app/apps/common/graph_runtime_state_support.py index 0b03149665..6a8e436163 100644 --- a/api/core/app/apps/common/graph_runtime_state_support.py +++ b/api/core/app/apps/common/graph_runtime_state_support.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import TYPE_CHECKING -from core.workflow.runtime import GraphRuntimeState +from dify_graph.runtime import GraphRuntimeState if TYPE_CHECKING: from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index f00a805d89..102cf66aff 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -49,21 +49,21 @@ from core.plugin.impl.datasource import PluginDatasourceManager from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.trigger.trigger_manager import TriggerManager -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import ( +from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.enums import ( NodeType, SystemVariableKey, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.file import FILE_MODEL_IDENTITY, File -from core.workflow.runtime import GraphRuntimeState -from core.workflow.system_variable import SystemVariable -from core.workflow.variables.segments import ArrayFileSegment, FileSegment, Segment -from core.workflow.workflow_entry import WorkflowEntry -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from dify_graph.file import FILE_MODEL_IDENTITY, File +from dify_graph.runtime import GraphRuntimeState +from dify_graph.system_variable import SystemVariable +from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, EndUser diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 843328f904..e8b0e4f179 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -19,8 +19,8 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom -from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser, Message diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 30e1a609f8..ac05172945 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -11,10 +11,10 @@ from core.app.entities.app_invoke_entities import ( ) from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.workflow.file import File +from dify_graph.file import File +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db from models.model import App, Message diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index eca96cb074..dcfc1415e8 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -33,13 +33,13 @@ from core.datasource.entities.datasource_entities import ( ) from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin from core.entities.knowledge_entities import PipelineDataset, PipelineDocument -from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.index_processor.constant.built_in_field import BuiltInField from core.repositories.factory import DifyCoreRepositoryFactory -from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from libs.flask_utils import preserve_flask_contexts from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 02caf8f511..4222aae809 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -8,23 +8,24 @@ from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import ( InvokeFrom, RagPipelineGenerateEntity, + UserFrom, + build_dify_run_context, ) from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities.graph_init_params import GraphInitParams -from core.workflow.enums import WorkflowType -from core.workflow.graph import Graph -from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variable_loader import VariableLoader -from core.workflow.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput +from core.workflow.node_factory import DifyNodeFactory from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.entities.graph_init_params import GraphInitParams +from dify_graph.enums import WorkflowType +from dify_graph.graph import Graph +from dify_graph.graph_events import GraphEngineEvent, GraphRunFailedEvent +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variable_loader import VariableLoader +from dify_graph.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from extensions.ext_database import db from models.dataset import Document, Pipeline -from models.enums import UserFrom from models.model import EndUser from models.workflow import Workflow @@ -257,13 +258,15 @@ class PipelineRunner(WorkflowBasedAppRunner): # init graph # Create required parameters for Graph.init graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=self._app_id, workflow_id=workflow.id, graph_config=graph_config, - user_id=self.application_generate_entity.user_id, - user_from=user_from, - invoke_from=invoke_from, + run_context=build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=self._app_id, + user_id=self.application_generate_entity.user_id, + user_from=user_from, + invoke_from=invoke_from, + ), call_depth=0, ) diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index a0b2730abe..55675506ab 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -29,16 +29,16 @@ from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, Pau from core.app.layers.sandbox_layer import SandboxLayer from core.db.session_factory import session_factory from core.helper.trace_id_helper import extract_external_trace_id_from_args -from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.repositories import DifyCoreRepositoryFactory from core.sandbox.sandbox import Sandbox -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.runtime import GraphRuntimeState -from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.runtime import GraphRuntimeState +from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index a45466c5da..f176c2a1a7 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -9,15 +9,15 @@ from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer from core.sandbox import Sandbox -from core.workflow.enums import WorkflowType -from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.enums import WorkflowType +from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variable_loader import VariableLoader from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span from libs.datetime_utils import naive_utc_now diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index fe0ef138c6..02c42b59ce 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -56,11 +56,11 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory -from core.workflow.runtime import GraphRuntimeState -from core.workflow.system_variable import SystemVariable +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory +from dify_graph.runtime import GraphRuntimeState +from dify_graph.system_variable import SystemVariable from extensions.ext_database import db from models import Account from models.enums import CreatorUserRole diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index be3c1e3025..7eb2a25e70 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence from typing import Any, cast from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.app.entities.queue_entities import ( AppQueueEvent, QueueAgentLogEvent, @@ -29,12 +29,13 @@ from core.app.entities.queue_entities import ( QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.graph import Graph -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events import ( +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.entities import GraphInitParams +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.graph import Graph +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import ( GraphEngineEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, @@ -60,14 +61,12 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.graph_events.graph import GraphRunAbortedEvent -from core.workflow.nodes import NodeType -from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool -from core.workflow.workflow_entry import WorkflowEntry -from models.enums import UserFrom +from dify_graph.graph_events.graph import GraphRunAbortedEvent +from dify_graph.nodes import NodeType +from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task @@ -119,13 +118,15 @@ class WorkflowBasedAppRunner: # Create required parameters for Graph.init graph_init_params = GraphInitParams( - tenant_id=tenant_id or "", - app_id=self._app_id, workflow_id=workflow_id, graph_config=graph_config, - user_id=user_id, - user_from=user_from, - invoke_from=invoke_from, + run_context=build_dify_run_context( + tenant_id=tenant_id or "", + app_id=self._app_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + ), call_depth=0, ) @@ -267,13 +268,15 @@ class WorkflowBasedAppRunner: # Create required parameters for Graph.init graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=self._app_id, workflow_id=workflow.id, graph_config=graph_config, - user_id="", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context=build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=self._app_id, + user_id="", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ), call_depth=0, ) diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 7b7a8db62f..97c3c4c804 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -7,81 +7,77 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat from constants import UUID_NIL from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from core.model_runtime.entities.model_entities import AIModelEntity -from core.workflow.file import File, FileUploadConfig +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.file import File, FileUploadConfig +from dify_graph.model_runtime.entities.model_entities import AIModelEntity if TYPE_CHECKING: from core.ops.ops_trace_manager import TraceQueueManager +class UserFrom(StrEnum): + ACCOUNT = "account" + END_USER = "end-user" + + class InvokeFrom(StrEnum): - """ - Invoke From. - """ - - # SERVICE_API indicates that this invocation is from an API call to Dify app. - # - # Description of service api in Dify docs: - # https://docs.dify.ai/en/guides/application-publishing/developing-with-apis SERVICE_API = "service-api" - - # WEB_APP indicates that this invocation is from - # the web app of the workflow (or chatflow). - # - # Description of web app in Dify docs: - # https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README WEB_APP = "web-app" - - # TRIGGER indicates that this invocation is from a trigger. - # this is used for plugin trigger and webhook trigger. TRIGGER = "trigger" - - # AGENT indicates that this invocation is from an agent. AGENT = "agent" - # EXPLORE indicates that this invocation is from - # the workflow (or chatflow) explore page. EXPLORE = "explore" - # DEBUGGER indicates that this invocation is from - # the workflow (or chatflow) edit page. DEBUGGER = "debugger" - # PUBLISHED_PIPELINE indicates that this invocation runs a published RAG pipeline workflow. PUBLISHED_PIPELINE = "published" - - # VALIDATION indicates that this invocation is from validation. VALIDATION = "validation" @classmethod - def value_of(cls, value: str): - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid invoke from value {value}") + def value_of(cls, value: str) -> "InvokeFrom": + return cls(value) def to_source(self) -> str: - """ - Get source of invoke from. + source_mapping = { + InvokeFrom.WEB_APP: "web_app", + InvokeFrom.DEBUGGER: "dev", + InvokeFrom.EXPLORE: "explore_app", + InvokeFrom.TRIGGER: "trigger", + InvokeFrom.SERVICE_API: "api", + } + return source_mapping.get(self, "dev") - :return: source - """ - if self == InvokeFrom.WEB_APP: - return "web_app" - elif self == InvokeFrom.DEBUGGER: - return "dev" - elif self == InvokeFrom.EXPLORE: - return "explore_app" - elif self == InvokeFrom.TRIGGER: - return "trigger" - elif self == InvokeFrom.SERVICE_API: - return "api" - return "dev" +class DifyRunContext(BaseModel): + tenant_id: str + app_id: str + user_id: str + user_from: UserFrom + invoke_from: InvokeFrom + + +def build_dify_run_context( + *, + tenant_id: str, + app_id: str, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + extra_context: Mapping[str, Any] | None = None, +) -> dict[str, Any]: + """ + Build graph run_context with the reserved Dify runtime payload. + + `extra_context` can carry user-defined context keys. The reserved `_dify` + payload is always overwritten by this function to keep one canonical source. + """ + run_context = dict(extra_context) if extra_context else {} + run_context[DIFY_RUN_CONTEXT_KEY] = DifyRunContext( + tenant_id=tenant_id, + app_id=app_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + ) + return run_context class ModelConfigWithCredentialsEntity(BaseModel): diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 2262b571fa..fa000fa13e 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -5,13 +5,13 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult -from core.workflow.entities.pause_reason import PauseReason -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import WorkflowNodeExecutionMetadataKey -from core.workflow.nodes import NodeType +from dify_graph.entities import AgentNodeStrategyInit, ToolCall, ToolResult +from dify_graph.entities.pause_reason import PauseReason +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.enums import WorkflowNodeExecutionMetadataKey +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from dify_graph.nodes import NodeType class QueueEvent(StrEnum): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index a0e2488376..5c0383d3a7 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -4,12 +4,12 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.workflow.entities import AgentNodeStrategyInit -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.nodes.human_input.entities import FormInput, UserAction +from dify_graph.entities import AgentNodeStrategyInit +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from dify_graph.nodes.human_input.entities import FormInput, UserAction class AnnotationReplyAccount(BaseModel): diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index a5a5486581..5ed1fadc41 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -2,7 +2,7 @@ import logging from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.helper import moderation -from core.model_runtime.entities.message_entities import PromptMessage +from dify_graph.model_runtime.entities.message_entities import PromptMessage logger = logging.getLogger(__name__) diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py index a748d90387..e495abf855 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -1,12 +1,12 @@ import logging -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.conversation_variable_updater import ConversationVariableUpdater -from core.workflow.enums import NodeType -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent -from core.workflow.nodes.variable_assigner.common import helpers as common_helpers -from core.workflow.variables import VariableBase +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID +from dify_graph.conversation_variable_updater import ConversationVariableUpdater +from dify_graph.enums import NodeType +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import GraphEngineEvent, NodeRunSucceededEvent +from dify_graph.nodes.variable_assigner.common import helpers as common_helpers +from dify_graph.variables import VariableBase logger = logging.getLogger(__name__) diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 1c267091a4..4370c01a0b 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -6,9 +6,9 @@ from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events.base import GraphEngineEvent -from core.workflow.graph_events.graph import GraphRunPausedEvent +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events.base import GraphEngineEvent +from dify_graph.graph_events.graph import GraphRunPausedEvent from models.model import AppMode from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/core/app/layers/sandbox_layer.py b/api/core/app/layers/sandbox_layer.py index 85ed53c4d6..f66054d41c 100644 --- a/api/core/app/layers/sandbox_layer.py +++ b/api/core/app/layers/sandbox_layer.py @@ -1,9 +1,10 @@ import logging -from core.sandbox import Sandbox from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events.base import GraphEngineEvent +from core.sandbox import Sandbox + logger = logging.getLogger(__name__) diff --git a/api/core/app/layers/suspend_layer.py b/api/core/app/layers/suspend_layer.py index 0a107de012..2adaf14a35 100644 --- a/api/core/app/layers/suspend_layer.py +++ b/api/core/app/layers/suspend_layer.py @@ -1,6 +1,6 @@ -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events.base import GraphEngineEvent -from core.workflow.graph_events.graph import GraphRunPausedEvent +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events.base import GraphEngineEvent +from dify_graph.graph_events.graph import GraphRunPausedEvent class SuspendLayer(GraphEngineLayer): diff --git a/api/core/app/layers/timeslice_layer.py b/api/core/app/layers/timeslice_layer.py index f82397deca..d7ca45f209 100644 --- a/api/core/app/layers/timeslice_layer.py +++ b/api/core/app/layers/timeslice_layer.py @@ -4,9 +4,9 @@ from typing import ClassVar from apscheduler.schedulers.background import BackgroundScheduler # type: ignore -from core.workflow.graph_engine.entities.commands import CommandType, GraphEngineCommand -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events.base import GraphEngineEvent +from dify_graph.graph_engine.entities.commands import CommandType, GraphEngineCommand +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events.base import GraphEngineEvent from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index a7ea9ef446..a4019a83e1 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -5,9 +5,9 @@ from typing import Any, ClassVar from pydantic import TypeAdapter from core.db.session_factory import session_factory -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events.base import GraphEngineEvent -from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events.base import GraphEngineEvent +from dify_graph.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from models.enums import WorkflowTriggerStatus from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index ebae830389..a63ff39fa5 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -5,11 +5,11 @@ from typing import Any from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager -from core.workflow.nodes.llm.entities import ModelConfig -from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.nodes.llm.entities import ModelConfig +from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory class DifyCredentialsProvider: diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index 1c66c8c1ff..7aa3bf15ab 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -6,7 +6,7 @@ from core.entities.model_entities import ModelStatus from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities.llm_entities import LLMUsage from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.provider import Provider, ProviderType diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 26c7e60a4c..0d5e0acec6 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -16,8 +16,8 @@ from core.app.entities.task_entities import ( PingStreamResponse, ) from core.errors.error import QuotaExceededError -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.moderation.output_moderation import ModerationRule, OutputModeration +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus from models.model import Message diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 7feeb5fe2c..6093413c5d 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -47,19 +47,19 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - TextPromptMessageContent, -) -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.tools.signature import sign_tool_file -from core.workflow.file import helpers as file_helpers -from core.workflow.file.enums import FileTransferMethod +from dify_graph.file import helpers as file_helpers +from dify_graph.file.enums import FileTransferMethod +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + TextPromptMessageContent, +) +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from events.message_event import message_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now diff --git a/api/core/app/workflow/__init__.py b/api/core/app/workflow/__init__.py index 172ee5d703..3bca7f5c34 100644 --- a/api/core/app/workflow/__init__.py +++ b/api/core/app/workflow/__init__.py @@ -1,3 +1,3 @@ -from .node_factory import DifyNodeFactory +from core.workflow.node_factory import DifyNodeFactory __all__ = ["DifyNodeFactory"] diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index 954638b901..e0f8d27111 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -5,13 +5,13 @@ from collections.abc import Generator from configs import dify_config from core.helper.ssrf_proxy import ssrf_proxy from core.tools.signature import sign_tool_file -from core.workflow.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol -from core.workflow.file.runtime import set_workflow_file_runtime +from dify_graph.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol +from dify_graph.file.runtime import set_workflow_file_runtime from extensions.ext_storage import storage class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): - """Production runtime wiring for ``core.workflow.file``.""" + """Production runtime wiring for ``dify_graph.file``.""" @property def files_url(self) -> str: diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index 45fb84c81f..2e930a1f58 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -12,17 +12,17 @@ from typing_extensions import override from core.app.llm import deduct_llm_quota, ensure_llm_quota_available from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance -from core.workflow.enums import NodeType -from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase -from core.workflow.graph_events.node import NodeRunSucceededEvent -from core.workflow.nodes.base.node import Node +from dify_graph.enums import NodeType +from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase +from dify_graph.graph_events.node import NodeRunSucceededEvent +from dify_graph.nodes.base.node import Node if TYPE_CHECKING: - from core.workflow.nodes.llm.node import LLMNode - from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode - from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode + from dify_graph.nodes.llm.node import LLMNode + from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode + from dify_graph.nodes.question_classifier.question_classifier_node import QuestionClassifierNode logger = logging.getLogger(__name__) @@ -75,8 +75,9 @@ class LLMQuotaLayer(GraphEngineLayer): return try: + dify_ctx = node.require_dify_context() deduct_llm_quota( - tenant_id=node.tenant_id, + tenant_id=dify_ctx.tenant_id, model_instance=model_instance, usage=result_event.node_run_result.llm_usage, ) diff --git a/api/core/app/workflow/layers/observability.py b/api/core/app/workflow/layers/observability.py index 94839c8ae3..ab73db59f1 100644 --- a/api/core/app/workflow/layers/observability.py +++ b/api/core/app/workflow/layers/observability.py @@ -16,10 +16,10 @@ from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_ from typing_extensions import override from configs import dify_config -from core.workflow.enums import NodeType -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events import GraphNodeEventBase -from core.workflow.nodes.base.node import Node +from dify_graph.enums import NodeType +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import GraphNodeEventBase +from dify_graph.nodes.base.node import Node from extensions.otel.parser import ( DefaultNodeOTelParser, LLMNodeOTelParser, diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index 132302efe1..65653a1edf 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -17,17 +17,17 @@ from typing import Any, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask -from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution -from core.workflow.enums import ( +from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from dify_graph.entities import WorkflowExecution, WorkflowNodeExecution +from dify_graph.enums import ( SystemVariableKey, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, WorkflowType, ) -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events import ( +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import ( GraphEngineEvent, GraphRunAbortedEvent, GraphRunFailedEvent, @@ -42,9 +42,9 @@ from core.workflow.graph_events import ( NodeRunStartedEvent, NodeRunSucceededEvent, ) -from core.workflow.node_events import NodeRunResult -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.node_events import NodeRunResult +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from libs.datetime_utils import naive_utc_now diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index f83aaa0006..beda515666 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -15,8 +15,8 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.message_entities import TextPromptMessageContent -from core.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent +from dify_graph.model_runtime.entities.model_entities import ModelType class AudioTrunk: diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index f67bfb6ead..5971c1e013 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -213,6 +213,6 @@ class DatasourceFileManager: # init tool_file_parser -# from core.workflow.file.datasource_file_parser import datasource_file_manager +# from dify_graph.file.datasource_file_parser import datasource_file_manager # # datasource_file_manager["manager"] = DatasourceFileManager diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 9c48f755a9..15cd319750 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -24,12 +24,12 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController from core.db.session_factory import session_factory from core.plugin.impl.datasource import PluginDatasourceManager -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import WorkflowNodeExecutionMetadataKey -from core.workflow.file import File -from core.workflow.file.enums import FileTransferMethod, FileType -from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from core.workflow.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionMetadataKey +from dify_graph.file import File +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from dify_graph.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam from factories import file_factory from models.model import UploadFile from models.tools import ToolFile diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 1179537570..4c9ff64479 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -3,8 +3,8 @@ from typing import Literal, Optional from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter -from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.common_entities import I18nObject +from dify_graph.model_runtime.utils.encoders import jsonable_encoder class DatasourceApiEntity(BaseModel): diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index ab3302bd6e..2881888e27 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -4,7 +4,7 @@ from mimetypes import guess_extension, guess_type from core.datasource.entities.datasource_entities import DatasourceMessage from core.tools.tool_file_manager import ToolFileManager -from core.workflow.file import File, FileTransferMethod, FileType +from dify_graph.file import File, FileTransferMethod, FileType from models.tools import ToolFile logger = logging.getLogger(__name__) diff --git a/api/core/entities/execution_extra_content.py b/api/core/entities/execution_extra_content.py index 46006f4381..1343bd8e82 100644 --- a/api/core/entities/execution_extra_content.py +++ b/api/core/entities/execution_extra_content.py @@ -5,7 +5,7 @@ from typing import Any, TypeAlias from pydantic import BaseModel, ConfigDict, Field -from core.workflow.nodes.human_input.entities import FormInput, UserAction +from dify_graph.nodes.human_input.entities import FormInput, UserAction from models.execution_extra_content import ExecutionContentType diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index 5902c03e27..d214652e9c 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -15,7 +15,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -from core.workflow.file import helpers as file_helpers +from dify_graph.file import helpers as file_helpers if TYPE_CHECKING: from models.tools import MCPToolProvider diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index a123fb0321..3427fc54b1 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -3,9 +3,9 @@ from enum import StrEnum, auto from pydantic import BaseModel, ConfigDict -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import ModelType, ProviderModel -from core.model_runtime.entities.provider_entities import ProviderEntity +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType, ProviderModel +from dify_graph.model_runtime.entities.provider_entities import ProviderEntity class ModelStatus(StrEnum): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 8a26b2e91b..9f8d06e322 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -19,15 +19,15 @@ from core.entities.provider_entities import ( ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from core.model_runtime.entities.provider_entities import ( +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from dify_graph.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, FormType, ProviderEntity, ) -from core.model_runtime.model_providers.__base.ai_model import AIModel -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from libs.datetime_utils import naive_utc_now from models.engine import db from models.provider import ( diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 0078ec7e4f..a830f227a9 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -11,8 +11,8 @@ from core.entities.parameter_entities import ( ModelSelectorScope, ToolSelectorScope, ) -from core.model_runtime.entities.model_entities import ModelType from core.tools.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType class ProviderQuotaType(StrEnum): diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index d581b3ac39..4251cfd30b 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -13,7 +13,7 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer from core.helper.http_client_pooling import get_pooled_http_client -from core.workflow.nodes.code.entities import CodeLanguage +from dify_graph.nodes.code.entities import CodeLanguage logger = logging.getLogger(__name__) code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index 1b56eaba21..c569e066f4 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -5,7 +5,7 @@ from base64 import b64encode from collections.abc import Mapping from typing import Any -from core.workflow.variables.utils import dumps_with_segments +from dify_graph.variables.utils import dumps_with_segments class TemplateTransformer(ABC): diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index 86bac4119a..873f6a4093 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -4,10 +4,10 @@ from typing import cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.invoke import InvokeBadRequestError -from core.model_runtime.model_providers.__base.moderation_model import ModerationModel -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError +from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions.ext_hosting_provider import hosting_configuration from models.provider import ProviderType diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 370e64e385..600a444357 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from configs import dify_config from core.entities import DEFAULT_PLUGIN_ID from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel -from core.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.model_entities import ModelType class HostingQuota(BaseModel): diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 4e3ad7bb75..7eebd9ec95 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -15,7 +15,6 @@ from configs import dify_config from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore @@ -31,6 +30,7 @@ from core.rag.splitter.fixed_text_splitter import ( ) from core.rag.splitter.text_splitter import TextSplitter from core.tools.utils.web_reader_tool import get_image_upload_file_ids +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 370b814cd2..81c42c6269 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -34,15 +34,15 @@ from core.llm_generator.prompts import ( WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, ) from core.model_manager import ModelManager -from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey +from dify_graph.model_runtime.entities.llm_entities import LLMResult +from dify_graph.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from extensions.ext_database import db from extensions.ext_storage import storage from models import App, Message, WorkflowNodeExecutionModel diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index a483775823..63e73d24fc 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -15,18 +15,18 @@ from core.llm_generator.prompts import ( STRUCTURED_OUTPUT_TOOL_CALL_PROMPT, ) from core.model_manager import ModelInstance -from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import ( +from dify_graph.model_runtime.callbacks.base_callback import Callback +from dify_graph.model_runtime.entities.llm_entities import ( LLMResult, LLMResultWithStructuredOutput, ) -from core.model_runtime.entities.message_entities import ( +from dify_graph.model_runtime.entities.message_entities import ( PromptMessage, PromptMessageTool, SystemPromptMessage, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ParameterRule +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ParameterRule class ResponseFormat(StrEnum): diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index da747d2c1f..de68eb268b 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -7,7 +7,7 @@ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types as mcp_types -from core.workflow.variables.input_entities import VariableEntity, VariableEntityType +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index 84bef7b935..db9cb726d7 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -8,7 +8,7 @@ from httpx_sse import connect_sse from configs import dify_config from core.mcp.types import ErrorData, JSONRPCError -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.utils.encoders import jsonable_encoder HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 838d29398d..a021d4d8da 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -6,16 +6,16 @@ from sqlalchemy.orm import sessionmaker from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.memory.base import BaseMemory from core.model_manager import ModelInstance -from core.model_runtime.entities import ( +from core.prompt.utils.extract_thread_messages import extract_thread_messages +from dify_graph.file import file_manager +from dify_graph.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes -from core.prompt.utils.extract_thread_messages import extract_thread_messages -from core.workflow.file import file_manager +from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes from extensions.ext_database import db from factories import file_factory from models.model import AppMode, Conversation, Message, MessageFile diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 2b3a3be1b9..0f710a8fcf 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -7,20 +7,20 @@ from core.entities.embedding_type import EmbeddingInputType from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.errors.error import ProviderTokenNotInitError -from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.model_runtime.entities.model_entities import ModelFeature, ModelType -from core.model_runtime.entities.rerank_entities import RerankResult -from core.model_runtime.entities.text_embedding_entities import EmbeddingResult -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.__base.moderation_model import ModerationModel -from core.model_runtime.model_providers.__base.rerank_model import RerankModel -from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.provider_manager import ProviderManager +from dify_graph.model_runtime.callbacks.base_callback import Callback +from dify_graph.model_runtime.entities.llm_entities import LLMResult +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType +from dify_graph.model_runtime.entities.rerank_entities import RerankResult +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel +from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel +from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel from extensions.ext_redis import redis_client from models.provider import ProviderType from services.enterprise.plugin_manager_service import PluginCredentialType diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 5cab4841f5..06676f5cf4 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -1,6 +1,6 @@ from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult +from dify_graph.model_runtime.entities.model_entities import ModelType class OpenAIModeration(Moderation): diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index 46c129099d..19111cc917 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -57,8 +57,8 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory -from core.workflow.entities import WorkflowNodeExecution -from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/ops/aliyun_trace/utils.py b/api/core/ops/aliyun_trace/utils.py index 7f68889e92..45319f24c1 100644 --- a/api/core/ops/aliyun_trace/utils.py +++ b/api/core/ops/aliyun_trace/utils.py @@ -14,8 +14,8 @@ from core.ops.aliyun_trace.entities.semconv import ( GenAISpanKind, ) from core.rag.models.document import Document -from core.workflow.entities import WorkflowNodeExecution -from core.workflow.enums import WorkflowNodeExecutionStatus +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import WorkflowNodeExecutionStatus from extensions.ext_database import db from models import EndUser diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 4de4f403ce..28e800e6c7 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( ) from core.ops.utils import filter_none_values from core.repositories import DifyCoreRepositoryFactory -from core.workflow.enums import NodeType +from dify_graph.enums import NodeType from extensions.ext_database import db from models import EndUser, WorkflowNodeExecutionTriggeredFrom from models.enums import MessageStatus diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 8b8117b24c..b40bc89b71 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -28,7 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( ) from core.ops.utils import filter_none_values, generate_dotted_order from core.repositories import DifyCoreRepositoryFactory -from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/core/ops/mlflow_trace/mlflow_trace.py index df6e016632..ba2cb9e0c3 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/core/ops/mlflow_trace/mlflow_trace.py @@ -23,7 +23,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.workflow.enums import NodeType +from dify_graph.enums import NodeType from extensions.ext_database import db from models import EndUser from models.workflow import WorkflowNodeExecutionModel diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 8050c59db9..eeae489c68 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory -from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 177991e645..33782e7949 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -35,7 +35,7 @@ from models.workflow import WorkflowAppLog from tasks.ops_trace_task import process_trace_tasks if TYPE_CHECKING: - from core.workflow.entities import WorkflowExecution + from dify_graph.entities import WorkflowExecution logger = logging.getLogger(__name__) diff --git a/api/core/ops/tencent_trace/client.py b/api/core/ops/tencent_trace/client.py index 99ccf00400..c39093bf4c 100644 --- a/api/core/ops/tencent_trace/client.py +++ b/api/core/ops/tencent_trace/client.py @@ -120,7 +120,8 @@ class TencentTraceClient: # Metrics exporter and instruments try: - from opentelemetry.sdk.metrics import Histogram, MeterProvider + from opentelemetry.sdk.metrics import Histogram as SdkHistogram + from opentelemetry.sdk.metrics import MeterProvider from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader protocol = os.getenv("OTEL_EXPORTER_OTLP_PROTOCOL", "").strip().lower() @@ -128,7 +129,7 @@ class TencentTraceClient: use_http_json = protocol in {"http/json", "http-json"} # Tencent APM works best with delta aggregation temporality - preferred_temporality: dict[type, AggregationTemporality] = {Histogram: AggregationTemporality.DELTA} + preferred_temporality: dict[type, AggregationTemporality] = {SdkHistogram: AggregationTemporality.DELTA} def _create_metric_exporter(exporter_cls, **kwargs): """Create metric exporter with preferred_temporality support""" diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/core/ops/tencent_trace/span_builder.py index 26e8779e3e..0a6013e244 100644 --- a/api/core/ops/tencent_trace/span_builder.py +++ b/api/core/ops/tencent_trace/span_builder.py @@ -41,7 +41,7 @@ from core.ops.tencent_trace.entities.semconv import ( from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData from core.ops.tencent_trace.utils import TencentTraceUtils from core.rag.models.document import Document -from core.workflow.entities.workflow_node_execution import ( +from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/core/ops/tencent_trace/tencent_trace.py index 93ec186863..cbff1c9e1c 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/core/ops/tencent_trace/tencent_trace.py @@ -24,10 +24,10 @@ from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData from core.ops.tencent_trace.span_builder import TencentSpanBuilder from core.ops.tencent_trace.utils import TencentTraceUtils from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.entities.workflow_node_execution import ( +from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, ) -from core.workflow.nodes import NodeType +from dify_graph.nodes import NodeType from extensions.ext_database import db from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 2134be0bce..7b62207366 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -31,7 +31,7 @@ from core.ops.entities.trace_entity import ( ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import DifyCoreRepositoryFactory -from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index 0b768fb98e..cc3e1d9422 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -5,18 +5,6 @@ from collections.abc import Generator from core.app.llm import deduct_llm_quota from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.model_manager import ModelManager -from core.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, -) -from core.model_runtime.entities.message_entities import ( - PromptMessage, - SystemPromptMessage, - UserPromptMessage, -) from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from core.plugin.entities.request import ( RequestInvokeLLM, @@ -30,6 +18,18 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils +from dify_graph.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) +from dify_graph.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) from models.account import Tenant diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index 9fbcbf55b4..33c45c0007 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -1,17 +1,17 @@ from core.plugin.backwards_invocation.base import BaseBackwardsInvocation -from core.workflow.enums import NodeType -from core.workflow.nodes.parameter_extractor.entities import ( +from dify_graph.enums import NodeType +from dify_graph.nodes.parameter_extractor.entities import ( ModelConfig as ParameterExtractorModelConfig, ) -from core.workflow.nodes.parameter_extractor.entities import ( +from dify_graph.nodes.parameter_extractor.entities import ( ParameterConfig, ParameterExtractorNodeData, ) -from core.workflow.nodes.question_classifier.entities import ( +from dify_graph.nodes.question_classifier.entities import ( ClassConfig, QuestionClassifierNodeData, ) -from core.workflow.nodes.question_classifier.entities import ( +from dify_graph.nodes.question_classifier.entities import ( ModelConfig as QuestionClassifierModelConfig, ) from services.workflow_service import WorkflowService diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py index cf1f7ff0dd..81e1e12c5f 100644 --- a/api/core/plugin/entities/marketplace.py +++ b/api/core/plugin/entities/marketplace.py @@ -1,10 +1,10 @@ from pydantic import BaseModel, Field, computed_field, model_validator -from core.model_runtime.entities.provider_entities import ProviderEntity from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.plugin.entities.plugin import PluginResourceRequirements from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity +from dify_graph.model_runtime.entities.provider_entities import ProviderEntity class MarketplacePluginDeclaration(BaseModel): diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 9e1a9edf82..7a3780f7de 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -8,12 +8,12 @@ from pydantic import BaseModel, Field, field_validator, model_validator from core.agent.plugin_entities import AgentStrategyProviderEntity from core.datasource.entities.datasource_entities import DatasourceProviderEntity -from core.model_runtime.entities.provider_entities import ProviderEntity from core.plugin.entities.base import BasePluginEntity from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity from core.trigger.entities.entities import TriggerProviderEntity +from dify_graph.model_runtime.entities.provider_entities import ProviderEntity class PluginInstallationSource(StrEnum): diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 6674228dc0..2dc540e6a8 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -10,14 +10,14 @@ from pydantic import BaseModel, ConfigDict, Field from core.agent.plugin_entities import AgentProviderEntityWithPlugin from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin -from core.model_runtime.entities.model_entities import AIModelEntity -from core.model_runtime.entities.provider_entities import ProviderEntity from core.plugin.entities.base import BasePluginEntity from core.plugin.entities.parameters import PluginParameterOption from core.plugin.entities.plugin import PluginDeclaration, PluginEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin from core.trigger.entities.entities import TriggerProviderEntity +from dify_graph.model_runtime.entities.model_entities import AIModelEntity +from dify_graph.model_runtime.entities.provider_entities import ProviderEntity T = TypeVar("T", bound=(BaseModel | dict | list | bool | str)) diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index e1684f9748..1390323458 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -7,7 +7,8 @@ from flask import Response from pydantic import BaseModel, ConfigDict, Field, field_validator from core.entities.provider_entities import BasicProviderConfig -from core.model_runtime.entities.message_entities import ( +from core.plugin.utils.http_parser import deserialize_response +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageRole, @@ -16,18 +17,17 @@ from core.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import ModelType -from core.plugin.utils.http_parser import deserialize_response -from core.workflow.nodes.parameter_extractor.entities import ( +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.nodes.parameter_extractor.entities import ( ModelConfig as ParameterExtractorModelConfig, ) -from core.workflow.nodes.parameter_extractor.entities import ( +from dify_graph.nodes.parameter_extractor.entities import ( ParameterConfig, ) -from core.workflow.nodes.question_classifier.entities import ( +from dify_graph.nodes.question_classifier.entities import ( ClassConfig, ) -from core.workflow.nodes.question_classifier.entities import ( +from dify_graph.nodes.question_classifier.entities import ( ModelConfig as QuestionClassifierModelConfig, ) diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 7a6a598a2f..737d204105 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -9,14 +9,6 @@ from pydantic import BaseModel from yarl import URL from configs import dify_config -from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.plugin.endpoint.exc import EndpointSetupFailedError from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError from core.plugin.impl.exc import ( @@ -35,6 +27,14 @@ from core.trigger.errors import ( TriggerPluginInvokeError, TriggerProviderCredentialValidationError, ) +from dify_graph.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) _plugin_daemon_timeout_config = cast( diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 5d70980967..49ee5d79cb 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -2,12 +2,6 @@ import binascii from collections.abc import Generator, Sequence from typing import IO -from core.model_runtime.entities.llm_entities import LLMResultChunk -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.model_runtime.entities.model_entities import AIModelEntity -from core.model_runtime.entities.rerank_entities import RerankResult -from core.model_runtime.entities.text_embedding_entities import EmbeddingResult -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, PluginDaemonInnerError, @@ -19,6 +13,12 @@ from core.plugin.entities.plugin_daemon import ( PluginVoicesResponse, ) from core.plugin.impl.base import BasePluginClient +from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import AIModelEntity +from dify_graph.model_runtime.entities.rerank_entities import RerankResult +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult +from dify_graph.model_runtime.utils.encoders import jsonable_encoder class PluginModelClient(BasePluginClient): diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py index 3fe1b84dfa..53bcd9e9c6 100644 --- a/api/core/plugin/utils/converter.py +++ b/api/core/plugin/utils/converter.py @@ -1,7 +1,7 @@ from typing import Any from core.tools.entities.tool_entities import ToolSelector -from core.workflow.file.models import File +from dify_graph.file.models import File def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]: diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index b9e41c9250..bb9138874e 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -5,7 +5,12 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter from core.memory.base import BaseMemory from core.model_manager import ModelInstance -from core.model_runtime.entities import ( +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from dify_graph.file import file_manager +from dify_graph.file.models import File +from dify_graph.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, PromptMessageRole, @@ -13,13 +18,8 @@ from core.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.file import file_manager -from core.workflow.file.models import File -from core.workflow.runtime import VariablePool +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from dify_graph.runtime import VariablePool class AdvancedPromptTransform(PromptTransform): diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index c1ae47709f..d09a46bfde 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -4,13 +4,13 @@ from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, ) from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import ( +from core.prompt.prompt_transform import PromptTransform +from dify_graph.model_runtime.entities.message_entities import ( PromptMessage, SystemPromptMessage, UserPromptMessage, ) -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.prompt_transform import PromptTransform +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel class AgentHistoryPromptTransform(PromptTransform): diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py index 457800bad2..c5faa42e9b 100644 --- a/api/core/prompt/entities/advanced_prompt_entities.py +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -3,7 +3,7 @@ from typing import Literal from pydantic import BaseModel -from core.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole class MemoryMode(StrEnum): diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index e37ac1d481..004837c72b 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -3,9 +3,9 @@ from typing import Any from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.base import BaseMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import PromptMessage -from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from dify_graph.model_runtime.entities.message_entities import PromptMessage +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey class PromptTransform: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 936a093488..10c44349ae 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -7,7 +7,11 @@ from typing import TYPE_CHECKING, Any, cast from core.app.app_config.entities import PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import ( +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from dify_graph.file import file_manager +from dify_graph.model_runtime.entities.message_entities import ( ImagePromptMessageContent, PromptMessage, PromptMessageContentUnionTypes, @@ -15,14 +19,10 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.file import file_manager from models.model import AppMode if TYPE_CHECKING: - from core.workflow.file.models import File + from dify_graph.file.models import File class ModelMode(StrEnum): diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 0a7a467227..85a2201395 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,7 +1,8 @@ from collections.abc import Sequence from typing import Any, cast -from core.model_runtime.entities import ( +from core.prompt.simple_prompt_transform import ModelMode +from dify_graph.model_runtime.entities import ( AssistantPromptMessage, AudioPromptMessageContent, ImagePromptMessageContent, @@ -10,7 +11,6 @@ from core.model_runtime.entities import ( PromptMessageRole, TextPromptMessageContent, ) -from core.prompt.simple_prompt_transform import ModelMode class PromptMessageUtil: diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index fdbfca4330..f82c3a846b 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -28,14 +28,14 @@ from core.entities.provider_entities import ( from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.helper.position_helper import is_filtered -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ( +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, FormType, ProviderEntity, ) -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions import ext_hosting_provider from extensions.ext_database import db from extensions.ext_redis import redis_client diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index bfa8781e9f..2b73ef5f26 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,6 +1,4 @@ from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document @@ -8,6 +6,8 @@ from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_type import RerankMode +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError class DataPostProcessor: diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 91c16ce079..e8a3a05e19 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -10,7 +10,6 @@ from sqlalchemy.orm import Session, load_only from configs import dify_config from core.db.session_factory import session_factory from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector @@ -23,6 +22,7 @@ from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.dataset import ( ChildChunk, diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index b9772b3c08..3225764693 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -8,13 +8,13 @@ from sqlalchemy import select from configs import dify_config from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.embedding_base import Embeddings from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import Document +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 69adac522d..16a5588024 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -6,8 +6,8 @@ from typing import Any from sqlalchemy import func, select from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.rag.models.document import AttachmentDocument, Document +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 0efe19a57c..6d1b65a055 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -9,9 +9,9 @@ from sqlalchemy.exc import IntegrityError from configs import dify_config from core.entities.embedding_type import EmbeddingInputType from core.model_manager import ModelInstance -from core.model_runtime.entities.model_entities import ModelPropertyKey -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.rag.embedding.embedding_base import Embeddings +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py index 95b197c874..c8f9d29012 100644 --- a/api/core/rag/index_processor/index_processor.py +++ b/api/core/rag/index_processor/index_processor.py @@ -9,8 +9,8 @@ from flask import current_app from sqlalchemy import delete, func, select from core.db.session_factory import session_factory -from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError -from core.workflow.repositories.index_processor_protocol import Preview, PreviewItem, QaPreview +from dify_graph.nodes.knowledge_index.exc import KnowledgeIndexNodeError +from dify_graph.repositories.index_processor_protocol import Preview, PreviewItem, QaPreview from models.dataset import Dataset, Document, DocumentSegment from .index_processor_factory import IndexProcessorFactory diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index df5c89a522..9c21dad488 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -12,15 +12,6 @@ from core.app.llm import deduct_llm_quota from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.entities.message_entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentUnionTypes, - TextPromptMessageContent, - UserPromptMessage, -) -from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.provider_manager import ProviderManager from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword @@ -35,7 +26,16 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols -from core.workflow.file import File, FileTransferMethod, FileType, file_manager +from dify_graph.file import File, FileTransferMethod, FileType, file_manager +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentUnionTypes, + TextPromptMessageContent, + UserPromptMessage, +) +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType from extensions.ext_database import db from factories.file_factory import build_from_mapping from libs import helper diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 48639bf4c8..dc3b771406 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -4,7 +4,7 @@ from typing import Any from pydantic import BaseModel, Field -from core.workflow.file import File +from dify_graph.file import File class ChildDocument(BaseModel): diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 690e780921..fcb14ffc52 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,12 +1,12 @@ import base64 from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.rerank_entities import RerankResult from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.rerank_entities import RerankResult from extensions.ext_database import db from extensions.ext_storage import storage from models.model import UploadFile diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 18020608cb..7edd05d2d1 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -4,7 +4,6 @@ from collections import Counter import numpy as np from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.index_processor.constant.doc_type import DocType @@ -12,6 +11,7 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner +from dify_graph.model_runtime.entities.model_entities import ModelType class WeightRerankRunner(BaseRerankRunner): diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 459d7bed95..b56ff9edef 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -25,10 +25,6 @@ from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from core.model_runtime.entities.model_entities import ModelFeature, ModelType -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time @@ -60,9 +56,13 @@ from core.rag.retrieval.template_prompts import ( ) from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.nodes.knowledge_retrieval import exc -from core.workflow.repositories.rag_retrieval_protocol import ( +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dify_graph.nodes.knowledge_retrieval import exc +from dify_graph.repositories.rag_retrieval_protocol import ( KnowledgeRetrievalRequest, Source, SourceChildChunk, diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index 5f3e1a8cae..23a2ac8386 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -2,8 +2,8 @@ from typing import Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from dify_graph.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage class FunctionCallMultiDatasetRouter: diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index fa2007122d..ea110fa0a7 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -4,12 +4,12 @@ from typing import Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.llm import deduct_llm_quota from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:""" diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index b65cb14d8e..7a00e8a886 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -7,7 +7,6 @@ import re from typing import Any from core.model_manager import ModelInstance -from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer from core.rag.splitter.text_splitter import ( TS, Collection, @@ -16,6 +15,7 @@ from core.rag.splitter.text_splitter import ( Set, Union, ) +from dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index c7f5942f5f..57764574d7 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -11,8 +11,8 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from core.workflow.entities.workflow_execution import WorkflowExecution -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.entities.workflow_execution import WorkflowExecution +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 9b8e45b1eb..650cf79550 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -12,8 +12,8 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution -from core.workflow.repositories.workflow_node_execution_repository import ( +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution +from dify_graph.repositories.workflow_node_execution_repository import ( OrderConfig, WorkflowNodeExecutionRepository, ) diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index 02fcabab5d..dc9f8c96bf 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -11,8 +11,8 @@ from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from configs import dify_config -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 0e04c56e0e..6607a87032 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -4,10 +4,11 @@ from collections.abc import Mapping, Sequence from datetime import datetime from typing import Any -from sqlalchemy import Engine, select -from sqlalchemy.orm import Session, selectinload, sessionmaker +from sqlalchemy import select +from sqlalchemy.orm import Session, selectinload -from core.workflow.nodes.human_input.entities import ( +from core.db.session_factory import session_factory +from dify_graph.nodes.human_input.entities import ( DeliveryChannelConfig, EmailDeliveryMethod, EmailRecipients, @@ -17,12 +18,12 @@ from core.workflow.nodes.human_input.entities import ( MemberRecipient, WebAppDeliveryMethod, ) -from core.workflow.nodes.human_input.enums import ( +from dify_graph.nodes.human_input.enums import ( DeliveryMethodType, HumanInputFormKind, HumanInputFormStatus, ) -from core.workflow.repositories.human_input_form_repository import ( +from dify_graph.repositories.human_input_form_repository import ( FormCreateParams, FormNotFoundError, HumanInputFormEntity, @@ -198,12 +199,9 @@ class _InvalidTimeoutStatusError(ValueError): class HumanInputFormRepositoryImpl: def __init__( self, - session_factory: sessionmaker | Engine, + *, tenant_id: str, ): - if isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory) - self._session_factory = session_factory self._tenant_id = tenant_id def _delivery_method_to_model( @@ -217,7 +215,7 @@ class HumanInputFormRepositoryImpl: id=delivery_id, form_id=form_id, delivery_method_type=delivery_method.type, - delivery_config_id=delivery_method.id, + delivery_config_id=str(delivery_method.id), channel_payload=delivery_method.model_dump_json(), ) recipients: list[HumanInputFormRecipient] = [] @@ -343,7 +341,7 @@ class HumanInputFormRepositoryImpl: def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: form_config: HumanInputNodeData = params.form_config - with self._session_factory(expire_on_commit=False) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): # Generate unique form ID form_id = str(uuidv7()) start_time = naive_utc_now() @@ -435,7 +433,7 @@ class HumanInputFormRepositoryImpl: HumanInputForm.node_id == node_id, HumanInputForm.tenant_id == self._tenant_id, ) - with self._session_factory(expire_on_commit=False) as session: + with session_factory.create_session() as session: form_model: HumanInputForm | None = session.scalars(form_query).first() if form_model is None: return None @@ -448,18 +446,13 @@ class HumanInputFormRepositoryImpl: class HumanInputFormSubmissionRepository: """Repository for fetching and submitting human input forms.""" - def __init__(self, session_factory: sessionmaker | Engine): - if isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory) - self._session_factory = session_factory - def get_by_token(self, form_token: str) -> HumanInputFormRecord | None: query = ( select(HumanInputFormRecipient) .options(selectinload(HumanInputFormRecipient.form)) .where(HumanInputFormRecipient.access_token == form_token) ) - with self._session_factory(expire_on_commit=False) as session: + with session_factory.create_session() as session: recipient_model = session.scalars(query).first() if recipient_model is None or recipient_model.form is None: return None @@ -478,7 +471,7 @@ class HumanInputFormSubmissionRepository: HumanInputFormRecipient.recipient_type == recipient_type, ) ) - with self._session_factory(expire_on_commit=False) as session: + with session_factory.create_session() as session: recipient_model = session.scalars(query).first() if recipient_model is None or recipient_model.form is None: return None @@ -494,7 +487,7 @@ class HumanInputFormSubmissionRepository: submission_user_id: str | None, submission_end_user_id: str | None, ) -> HumanInputFormRecord: - with self._session_factory(expire_on_commit=False) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): form_model = session.get(HumanInputForm, form_id) if form_model is None: raise FormNotFoundError(f"form not found, id={form_id}") @@ -524,7 +517,7 @@ class HumanInputFormSubmissionRepository: timeout_status: HumanInputFormStatus, reason: str | None = None, ) -> HumanInputFormRecord: - with self._session_factory(expire_on_commit=False) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): form_model = session.get(HumanInputForm, form_id) if form_model is None: raise FormNotFoundError(f"form not found, id={form_id}") diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 9091a3190b..649e2f7358 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -9,10 +9,10 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from core.workflow.entities import WorkflowExecution -from core.workflow.enums import WorkflowExecutionStatus, WorkflowType -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from dify_graph.entities import WorkflowExecution +from dify_graph.enums import WorkflowExecutionStatus, WorkflowType +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 1c2c7ef426..d1da1d8bd4 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -17,11 +17,11 @@ from sqlalchemy.orm import sessionmaker from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_attempt from configs import dify_config -from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities import WorkflowNodeExecution -from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.ext_storage import storage from libs.helper import extract_tenant_id from libs.uuid_utils import uuidv7 diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index c8048888b1..e68b498f88 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: # pragma: no cover from models.model import File from core.model_runtime.entities.message_entities import PromptMessageTool + from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( ToolEntity, diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index 2c1e9fb555..dacc49c746 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -3,13 +3,13 @@ from collections.abc import Generator from typing import Any from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from core.workflow.file.enums import FileType -from core.workflow.file.file_manager import download +from dify_graph.file.enums import FileType +from dify_graph.file.file_manager import download +from dify_graph.model_runtime.entities.model_entities import ModelType from services.model_provider_service import ModelProviderService diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index 5009f7ac21..7818bff0ab 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -3,11 +3,11 @@ from collections.abc import Generator from typing import Any from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from services.model_provider_service import ModelProviderService diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 51b0407886..00f5931088 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -1,11 +1,11 @@ from __future__ import annotations -from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils +from dify_graph.model_runtime.entities.llm_entities import LLMResult +from dify_graph.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language and you can quickly aimed at the main point of an webpage and reproduce it in your own words but diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index afa2ddffed..c6a84e27c6 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -13,7 +13,7 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError -from core.workflow.file.file_manager import download +from dify_graph.file.file_manager import download API_TOOL_DEFAULT_TIMEOUT = ( int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")), diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 218ffafd55..2545290b57 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -5,11 +5,11 @@ from typing import Any, Literal from pydantic import BaseModel, Field, field_validator from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType +from dify_graph.model_runtime.utils.encoders import jsonable_encoder class ToolApiEntity(BaseModel): diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 1d439323f2..9025ff6ef1 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -17,11 +17,11 @@ from core.mcp.types import ( TextContent, TextResourceContents, ) -from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError +from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata logger = logging.getLogger(__name__) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index de476f6461..0f0eacbdc4 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -31,8 +31,8 @@ from core.tools.errors import ( ) from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value from core.tools.workflow_as_tool.tool import WorkflowTool -from core.workflow.file import FileType -from core.workflow.file.models import FileTransferMethod +from dify_graph.file import FileType +from dify_graph.file.models import FileTransferMethod from extensions.ext_database import db from models.enums import CreatorUserRole from models.model import Message, MessageFile diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index ca0dc27f3d..83e4e53418 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -243,7 +243,7 @@ class ToolFileManager: # init tool_file_parser -from core.workflow.file.tool_file_parser import set_tool_file_manager_factory +from dify_graph.file.tool_file_parser import set_tool_file_manager_factory def _factory() -> ToolFileManager: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 5dae773841..3938bd0ed7 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -24,20 +24,19 @@ from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController -from core.workflow.runtime.variable_pool import VariablePool +from dify_graph.runtime.variable_pool import VariablePool from extensions.ext_database import db from models.provider_ids import ToolProviderID from services.enterprise.plugin_manager_service import PluginCredentialType from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: - from core.workflow.nodes.tool.entities import ToolEntity + from dify_graph.nodes.tool.entities import ToolEntity from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source from core.helper.position_helper import is_filtered -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import Tool from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -58,11 +57,12 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService if TYPE_CHECKING: - from core.workflow.nodes.tool.entities import ToolEntity + from dify_graph.nodes.tool.entities import ToolEntity logger = logging.getLogger(__name__) @@ -179,7 +179,6 @@ class ToolManager: :return: the tool """ - if provider_type == ToolProviderType.BUILT_IN: # check if the builtin tool need credentials provider_controller = cls.get_builtin_provider(provider_id, tenant_id) @@ -1017,8 +1016,8 @@ class ToolManager: """ Convert tool parameters type """ - from core.workflow.nodes.tool.entities import ToolNodeData - from core.workflow.nodes.tool.exc import ToolParameterError + from dify_graph.nodes.tool.entities import ToolNodeData + from dify_graph.nodes.tool.exc import ToolParameterError runtime_parameters = {} for parameter in parameters: diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 20e10be075..3dbbbe6563 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -7,13 +7,13 @@ from sqlalchemy import select from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 622cdcf73b..6fc5fead2d 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -10,7 +10,7 @@ import pytz from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager -from core.workflow.file import File, FileTransferMethod, FileType +from dify_graph.file import File, FileTransferMethod, FileType from libs.login import current_user from models import Account diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index e7fba09359..8f958563bd 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -9,18 +9,18 @@ from decimal import Decimal from typing import cast from core.model_manager import ModelManager -from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import PromptMessage -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.errors.invoke import ( +from dify_graph.model_runtime.entities.llm_entities import LLMResult +from dify_graph.model_runtime.entities.message_entities import PromptMessage +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from dify_graph.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, InvokeRateLimitError, InvokeServerUnavailableError, ) -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from models.tools import ToolModelInvoke diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 8e8c5e9c6a..d8ce53083b 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -3,9 +3,9 @@ from typing import Any from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from core.workflow.enums import NodeType -from core.workflow.nodes.base.entities import OutputVariableEntity -from core.workflow.variables.input_entities import VariableEntity +from dify_graph.enums import NodeType +from dify_graph.nodes.base.entities import OutputVariableEntity +from dify_graph.variables.input_entities import VariableEntity class WorkflowToolConfigurationUtils: diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 56faccb407..d73012375d 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -22,7 +22,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.tool import WorkflowTool -from core.workflow.variables.input_entities import VariableEntity, VariableEntityType +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from extensions.ext_database import db from models.account import Account from models.model import App, AppMode diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index b2606009a6..9b9aa7a741 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -8,7 +8,6 @@ from typing import Any, cast from sqlalchemy import select from core.db.session_factory import session_factory -from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( @@ -18,7 +17,8 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.errors import ToolInvokeError -from core.workflow.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from factories.file_factory import build_from_mapping from models import Account, Tenant from models.model import App, EndUser diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py index bd1ff4ebfe..9b7b3de614 100644 --- a/api/core/trigger/debug/event_selectors.py +++ b/api/core/trigger/debug/event_selectors.py @@ -19,9 +19,9 @@ from core.trigger.debug.events import ( build_plugin_pool_key, build_webhook_pool_key, ) -from core.workflow.enums import NodeType -from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData -from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig +from dify_graph.enums import NodeType +from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData +from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig from extensions.ext_redis import redis_client from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.schedule_utils import calculate_next_run_at diff --git a/api/core/workflow/__init__.py b/api/core/workflow/__init__.py index e69de29bb2..57c2ef3d10 100644 --- a/api/core/workflow/__init__.py +++ b/api/core/workflow/__init__.py @@ -0,0 +1,4 @@ +from .node_factory import DifyNodeFactory +from .workflow_entry import WorkflowEntry + +__all__ = ["DifyNodeFactory", "WorkflowEntry"] diff --git a/api/core/app/workflow/node_factory.py b/api/core/workflow/node_factory.py similarity index 80% rename from api/core/app/workflow/node_factory.py rename to api/core/workflow/node_factory.py index 9a56f0fb0d..4cbee08a65 100644 --- a/api/core/app/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Session from typing_extensions import override from configs import dify_config +from core.app.entities.app_invoke_entities import DifyRunContext from core.app.llm.model_access import build_dify_model_access from core.datasource.datasource_manager import DatasourceManager from core.helper.code_executor.code_executor import ( @@ -15,44 +16,47 @@ from core.helper.code_executor.code_executor import ( from core.helper.ssrf_proxy import ssrf_proxy from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.memory import PromptMessageMemory -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.rag.index_processor.index_processor import IndexProcessor from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.summary_index.summary_index import SummaryIndex +from core.repositories.human_input_repository import HumanInputFormRepositoryImpl from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities.graph_config import NodeConfigDict -from core.workflow.enums import NodeType, SystemVariableKey -from core.workflow.file.file_manager import file_manager -from core.workflow.graph.graph import NodeFactory -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.code.code_node import CodeNode, WorkflowCodeExecutor -from core.workflow.nodes.code.entities import CodeLanguage -from core.workflow.nodes.code.limits import CodeNodeLimits -from core.workflow.nodes.datasource import DatasourceNode -from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig -from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config -from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from core.workflow.nodes.llm.entities import ModelConfig -from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from core.workflow.nodes.llm.node import LLMNode -from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING -from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode -from core.workflow.nodes.template_transform.template_renderer import ( +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import NodeType, SystemVariableKey +from dify_graph.file.file_manager import file_manager +from dify_graph.graph.graph import NodeFactory +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.memory import PromptMessageMemory +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.code.code_node import CodeNode, WorkflowCodeExecutor +from dify_graph.nodes.code.entities import CodeLanguage +from dify_graph.nodes.code.limits import CodeNodeLimits +from dify_graph.nodes.datasource import DatasourceNode +from dify_graph.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig +from dify_graph.nodes.http_request import HttpRequestNode, build_http_request_config +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode +from dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from dify_graph.nodes.llm.entities import ModelConfig +from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError +from dify_graph.nodes.llm.node import LLMNode +from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING +from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from dify_graph.nodes.question_classifier.question_classifier_node import QuestionClassifierNode +from dify_graph.nodes.template_transform.template_renderer import ( CodeExecutorJinja2TemplateRenderer, ) -from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from core.workflow.variables.segments import StringSegment +from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode +from dify_graph.variables.segments import StringSegment from extensions.ext_database import db from models.model import Conversation if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState def fetch_memory( @@ -108,6 +112,7 @@ class DifyNodeFactory(NodeFactory): ) -> None: self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state + self._dify_context = self._resolve_dify_context(graph_init_params.run_context) self._code_executor: WorkflowCodeExecutor = DefaultWorkflowCodeExecutor() self._code_limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, @@ -119,7 +124,7 @@ class DifyNodeFactory(NodeFactory): max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, ) - self._template_renderer = CodeExecutorJinja2TemplateRenderer() + self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor) self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH self._http_request_http_client = ssrf_proxy self._http_request_tool_file_manager_factory = ToolFileManager @@ -139,7 +144,16 @@ class DifyNodeFactory(NodeFactory): ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES, ) - self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(graph_init_params.tenant_id) + self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id) + + @staticmethod + def _resolve_dify_context(run_context: Mapping[str, Any]) -> DifyRunContext: + raw_ctx = run_context.get(DIFY_RUN_CONTEXT_KEY) + if raw_ctx is None: + raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") + if isinstance(raw_ctx, DifyRunContext): + return raw_ctx + return DifyRunContext.model_validate(raw_ctx) @override def create_node(self, node_config: NodeConfigDict) -> Node: @@ -205,6 +219,15 @@ class DifyNodeFactory(NodeFactory): file_manager=self._http_request_file_manager, ) + if node_type == NodeType.HUMAN_INPUT: + return HumanInputNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + form_repository=HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id), + ) + if node_type == NodeType.KNOWLEDGE_INDEX: return KnowledgeIndexNode( id=node_id, @@ -254,6 +277,7 @@ class DifyNodeFactory(NodeFactory): graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, unstructured_api_config=self._document_extractor_unstructured_api_config, + http_client=self._http_request_http_client, ) if node_type == NodeType.QUESTION_CLASSIFIER: @@ -344,7 +368,7 @@ class DifyNodeFactory(NodeFactory): ) return fetch_memory( conversation_id=conversation_id, - app_id=self.graph_init_params.app_id, + app_id=self._dify_context.app_id, node_data_memory=node_memory, model_instance=model_instance, ) diff --git a/api/core/workflow/nodes/__init__.py b/api/core/workflow/nodes/__init__.py deleted file mode 100644 index 82a37acbfa..0000000000 --- a/api/core/workflow/nodes/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from core.workflow.enums import NodeType - -__all__ = ["NodeType"] diff --git a/api/core/workflow/nodes/command/node.py b/api/core/workflow/nodes/command/node.py index e24c003e4e..0bcee4613c 100644 --- a/api/core/workflow/nodes/command/node.py +++ b/api/core/workflow/nodes/command/node.py @@ -2,16 +2,17 @@ import logging from collections.abc import Mapping, Sequence from typing import Any +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.nodes.base.entities import VariableSelector +from core.workflow.nodes.base.node import Node +from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser + from core.sandbox import sandbox_debug from core.sandbox.bash.session import SANDBOX_READY_TIMEOUT from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError from core.virtual_environment.__base.helpers import submit_command, with_connection -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base import variable_template_parser -from core.workflow.nodes.base.entities import VariableSelector -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser from core.workflow.nodes.command.entities import CommandNodeData from core.workflow.nodes.command.exc import CommandExecutionError diff --git a/api/core/workflow/nodes/file_upload/node.py b/api/core/workflow/nodes/file_upload/node.py index 79e1db26ad..c97d5fe609 100644 --- a/api/core/workflow/nodes/file_upload/node.py +++ b/api/core/workflow/nodes/file_upload/node.py @@ -5,15 +5,16 @@ from collections.abc import Mapping, Sequence from pathlib import PurePosixPath from typing import Any, cast +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.nodes.base.node import Node +from core.workflow.variables.segments import ArrayStringSegment, FileSegment + from core.sandbox.bash.session import SANDBOX_READY_TIMEOUT from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError from core.virtual_environment.__base.helpers import pipeline -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.file import File, FileTransferMethod from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node from core.workflow.variables import ArrayFileSegment -from core.workflow.variables.segments import ArrayStringSegment, FileSegment from core.zip_sandbox import SandboxDownloadItem from .entities import FileUploadNodeData diff --git a/api/core/workflow/nodes/trigger_schedule/__init__.py b/api/core/workflow/nodes/trigger_schedule/__init__.py deleted file mode 100644 index 6773bae502..0000000000 --- a/api/core/workflow/nodes/trigger_schedule/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode - -__all__ = ["TriggerScheduleNode"] diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 2c9cebabc3..284c0619f0 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -5,36 +5,96 @@ from typing import Any, cast from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.app.workflow.layers.observability import ObservabilityLayer -from core.app.workflow.node_factory import DifyNodeFactory from core.sandbox import Sandbox -from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID -from core.workflow.entities import GraphInitParams -from core.workflow.entities.graph_config import NodeConfigData, NodeConfigDict -from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.file.models import File -from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine, GraphEngineConfig -from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer -from core.workflow.graph_engine.protocols.command_channel import CommandChannel -from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent -from core.workflow.nodes import NodeType -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigData, NodeConfigDict +from dify_graph.errors import WorkflowNodeRunFailedError +from dify_graph.file.models import File +from dify_graph.graph import Graph +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_engine.protocols.command_channel import CommandChannel +from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent +from dify_graph.nodes import NodeType +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from extensions.otel.runtime import is_instrument_flag_enabled from factories import file_factory -from models.enums import UserFrom from models.workflow import Workflow logger = logging.getLogger(__name__) +class _WorkflowChildEngineBuilder: + @staticmethod + def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None: + """ + Return whether `graph_config["nodes"]` contains the given node id. + + Returns `None` when the nodes payload shape is unexpected, so graph-level + validation can surface the original configuration error. + """ + nodes = graph_config.get("nodes") + if not isinstance(nodes, list): + return None + + for node in nodes: + if not isinstance(node, Mapping): + return None + current_id = node.get("id") + if isinstance(current_id, str) and current_id == node_id: + return True + return False + + def build_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> GraphEngine: + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + has_root_node = self._has_node_id(graph_config=graph_config, node_id=root_node_id) + if has_root_node is False: + raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") + + child_graph = Graph.init( + graph_config=graph_config, + node_factory=node_factory, + root_node_id=root_node_id, + ) + + child_engine = GraphEngine( + workflow_id=workflow_id, + graph=child_graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig(), + child_engine_builder=self, + ) + child_engine.layer(LLMQuotaLayer()) + for layer in layers: + child_engine.layer(cast(GraphEngineLayer, layer)) + return child_engine + + class WorkflowEntry: def __init__( self, @@ -78,6 +138,7 @@ class WorkflowEntry: command_channel = InMemoryChannel() self.command_channel = command_channel + self._child_engine_builder = _WorkflowChildEngineBuilder() self.graph_engine = GraphEngine( workflow_id=workflow_id, graph=graph, @@ -89,6 +150,7 @@ class WorkflowEntry: scale_up_threshold=dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD, scale_down_idle_time=dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME, ), + child_engine_builder=self._child_engine_builder, ) # Add debug logging layer when in debug mode @@ -156,13 +218,15 @@ class WorkflowEntry: # init graph init params and runtime state graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, workflow_id=workflow.id, graph_config=workflow.graph_dict, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context=build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ), call_depth=0, ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) @@ -298,13 +362,15 @@ class WorkflowEntry: # init graph init params and runtime state graph_init_params = GraphInitParams( - tenant_id=tenant_id, - app_id="", workflow_id="", graph_config=graph_dict, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context=build_dify_run_context( + tenant_id=tenant_id, + app_id="", + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ), call_depth=0, ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/core/workflow/README.md b/api/dify_graph/README.md similarity index 98% rename from api/core/workflow/README.md rename to api/dify_graph/README.md index 9a39f976a6..09c4f5afdc 100644 --- a/api/core/workflow/README.md +++ b/api/dify_graph/README.md @@ -114,7 +114,7 @@ The codebase enforces strict layering via import-linter: 1. Inherit from `BaseNode` or appropriate base class 1. Implement `_run()` method 1. Register in `nodes/node_mapping.py` -1. Add tests in `tests/unit_tests/core/workflow/nodes/` +1. Add tests in `tests/unit_tests/dify_graph/nodes/` ### Implementing a Custom Layer diff --git a/api/core/model_runtime/__init__.py b/api/dify_graph/__init__.py similarity index 100% rename from api/core/model_runtime/__init__.py rename to api/dify_graph/__init__.py diff --git a/api/core/workflow/constants.py b/api/dify_graph/constants.py similarity index 100% rename from api/core/workflow/constants.py rename to api/dify_graph/constants.py diff --git a/api/core/workflow/context/__init__.py b/api/dify_graph/context/__init__.py similarity index 84% rename from api/core/workflow/context/__init__.py rename to api/dify_graph/context/__init__.py index fd60917617..103f526bec 100644 --- a/api/core/workflow/context/__init__.py +++ b/api/dify_graph/context/__init__.py @@ -5,7 +5,7 @@ This package provides Flask-independent context management for workflow execution in multi-threaded environments. """ -from core.workflow.context.execution_context import ( +from dify_graph.context.execution_context import ( AppContext, ContextProviderNotFoundError, ExecutionContext, @@ -17,6 +17,7 @@ from core.workflow.context.execution_context import ( register_context_capturer, reset_context_provider, ) +from dify_graph.context.models import SandboxContext __all__ = [ "AppContext", @@ -24,6 +25,7 @@ __all__ = [ "ExecutionContext", "IExecutionContext", "NullAppContext", + "SandboxContext", "capture_current_context", "read_context", "register_context", diff --git a/api/core/workflow/context/execution_context.py b/api/dify_graph/context/execution_context.py similarity index 100% rename from api/core/workflow/context/execution_context.py rename to api/dify_graph/context/execution_context.py diff --git a/api/core/workflow/context/models.py b/api/dify_graph/context/models.py similarity index 100% rename from api/core/workflow/context/models.py rename to api/dify_graph/context/models.py diff --git a/api/core/workflow/conversation_variable_updater.py b/api/dify_graph/conversation_variable_updater.py similarity index 96% rename from api/core/workflow/conversation_variable_updater.py rename to api/dify_graph/conversation_variable_updater.py index 6bfb2b2880..17b19f2502 100644 --- a/api/core/workflow/conversation_variable_updater.py +++ b/api/dify_graph/conversation_variable_updater.py @@ -1,7 +1,7 @@ import abc from typing import Protocol -from core.workflow.variables import VariableBase +from dify_graph.variables import VariableBase class ConversationVariableUpdater(Protocol): diff --git a/api/core/workflow/entities/__init__.py b/api/dify_graph/entities/__init__.py similarity index 100% rename from api/core/workflow/entities/__init__.py rename to api/dify_graph/entities/__init__.py diff --git a/api/core/workflow/entities/agent.py b/api/dify_graph/entities/agent.py similarity index 100% rename from api/core/workflow/entities/agent.py rename to api/dify_graph/entities/agent.py diff --git a/api/core/workflow/entities/graph_config.py b/api/dify_graph/entities/graph_config.py similarity index 100% rename from api/core/workflow/entities/graph_config.py rename to api/dify_graph/entities/graph_config.py diff --git a/api/core/workflow/entities/graph_init_params.py b/api/dify_graph/entities/graph_init_params.py similarity index 62% rename from api/core/workflow/entities/graph_init_params.py rename to api/dify_graph/entities/graph_init_params.py index ff224a28d1..f785d58a52 100644 --- a/api/core/workflow/entities/graph_init_params.py +++ b/api/dify_graph/entities/graph_init_params.py @@ -3,6 +3,8 @@ from typing import Any from pydantic import BaseModel, Field +DIFY_RUN_CONTEXT_KEY = "_dify" + class GraphInitParams(BaseModel): """GraphInitParams encapsulates the configurations and contextual information @@ -16,15 +18,7 @@ class GraphInitParams(BaseModel): """ # init params - tenant_id: str = Field(..., description="tenant / workspace id") - app_id: str = Field(..., description="app id") workflow_id: str = Field(..., description="workflow id") graph_config: Mapping[str, Any] = Field(..., description="graph config") - user_id: str = Field(..., description="user id") - user_from: str = Field( - ..., description="user from, account or end-user" - ) # Should be UserFrom enum: 'account' | 'end-user' - invoke_from: str = Field( - ..., description="invoke from, service-api, web-app, explore or debugger" - ) # Should be InvokeFrom enum: 'service-api' | 'web-app' | 'explore' | 'debugger' + run_context: Mapping[str, Any] = Field(..., description="runtime context") call_depth: int = Field(..., description="call depth") diff --git a/api/core/workflow/entities/pause_reason.py b/api/dify_graph/entities/pause_reason.py similarity index 96% rename from api/core/workflow/entities/pause_reason.py rename to api/dify_graph/entities/pause_reason.py index 147f56e8be..86d8c8ca16 100644 --- a/api/core/workflow/entities/pause_reason.py +++ b/api/dify_graph/entities/pause_reason.py @@ -4,7 +4,7 @@ from typing import Annotated, Any, Literal, TypeAlias from pydantic import BaseModel, Field -from core.workflow.nodes.human_input.entities import FormInput, UserAction +from dify_graph.nodes.human_input.entities import FormInput, UserAction class PauseReasonType(StrEnum): diff --git a/api/core/workflow/entities/tool_entities.py b/api/dify_graph/entities/tool_entities.py similarity index 98% rename from api/core/workflow/entities/tool_entities.py rename to api/dify_graph/entities/tool_entities.py index f67bc63769..45916e0d5d 100644 --- a/api/core/workflow/entities/tool_entities.py +++ b/api/dify_graph/entities/tool_entities.py @@ -3,7 +3,7 @@ from typing import Any from pydantic import BaseModel, Field -from core.workflow.file import File +from dify_graph.file import File class ToolResultStatus(StrEnum): diff --git a/api/core/workflow/entities/workflow_execution.py b/api/dify_graph/entities/workflow_execution.py similarity index 96% rename from api/core/workflow/entities/workflow_execution.py rename to api/dify_graph/entities/workflow_execution.py index 1b3fb36f1f..459ac46415 100644 --- a/api/core/workflow/entities/workflow_execution.py +++ b/api/dify_graph/entities/workflow_execution.py @@ -13,7 +13,7 @@ from typing import Any from pydantic import BaseModel, Field -from core.workflow.enums import WorkflowExecutionStatus, WorkflowType +from dify_graph.enums import WorkflowExecutionStatus, WorkflowType from libs.datetime_utils import naive_utc_now diff --git a/api/core/workflow/entities/workflow_node_execution.py b/api/dify_graph/entities/workflow_node_execution.py similarity index 98% rename from api/core/workflow/entities/workflow_node_execution.py rename to api/dify_graph/entities/workflow_node_execution.py index 4abc9c068d..9dd04e331b 100644 --- a/api/core/workflow/entities/workflow_node_execution.py +++ b/api/dify_graph/entities/workflow_node_execution.py @@ -12,7 +12,7 @@ from typing import Any from pydantic import BaseModel, Field, PrivateAttr -from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus class WorkflowNodeExecution(BaseModel): diff --git a/api/core/workflow/entities/workflow_start_reason.py b/api/dify_graph/entities/workflow_start_reason.py similarity index 100% rename from api/core/workflow/entities/workflow_start_reason.py rename to api/dify_graph/entities/workflow_start_reason.py diff --git a/api/core/workflow/enums.py b/api/dify_graph/enums.py similarity index 100% rename from api/core/workflow/enums.py rename to api/dify_graph/enums.py diff --git a/api/core/workflow/errors.py b/api/dify_graph/errors.py similarity index 88% rename from api/core/workflow/errors.py rename to api/dify_graph/errors.py index 5bf1faee5d..463d17713e 100644 --- a/api/core/workflow/errors.py +++ b/api/dify_graph/errors.py @@ -1,4 +1,4 @@ -from core.workflow.nodes.base.node import Node +from dify_graph.nodes.base.node import Node class WorkflowNodeRunFailedError(Exception): diff --git a/api/core/workflow/file/__init__.py b/api/dify_graph/file/__init__.py similarity index 100% rename from api/core/workflow/file/__init__.py rename to api/dify_graph/file/__init__.py diff --git a/api/core/workflow/file/constants.py b/api/dify_graph/file/constants.py similarity index 100% rename from api/core/workflow/file/constants.py rename to api/dify_graph/file/constants.py diff --git a/api/core/workflow/file/enums.py b/api/dify_graph/file/enums.py similarity index 100% rename from api/core/workflow/file/enums.py rename to api/dify_graph/file/enums.py diff --git a/api/core/workflow/file/file_manager.py b/api/dify_graph/file/file_manager.py similarity index 98% rename from api/core/workflow/file/file_manager.py rename to api/dify_graph/file/file_manager.py index 0f4579e684..8fa7f52b88 100644 --- a/api/core/workflow/file/file_manager.py +++ b/api/dify_graph/file/file_manager.py @@ -5,14 +5,14 @@ import logging from collections.abc import Mapping from configs import dify_config -from core.model_runtime.entities import ( +from dify_graph.model_runtime.entities import ( AudioPromptMessageContent, DocumentPromptMessageContent, ImagePromptMessageContent, TextPromptMessageContent, VideoPromptMessageContent, ) -from core.model_runtime.entities.message_entities import ( +from dify_graph.model_runtime.entities.message_entities import ( MultiModalPromptMessageContent, PromptMessageContentUnionTypes, ) diff --git a/api/core/workflow/file/helpers.py b/api/dify_graph/file/helpers.py similarity index 100% rename from api/core/workflow/file/helpers.py rename to api/dify_graph/file/helpers.py diff --git a/api/core/workflow/file/models.py b/api/dify_graph/file/models.py similarity index 98% rename from api/core/workflow/file/models.py rename to api/dify_graph/file/models.py index cd7d3edde8..db12d4f57a 100644 --- a/api/core/workflow/file/models.py +++ b/api/dify_graph/file/models.py @@ -5,7 +5,7 @@ from typing import Any from pydantic import BaseModel, Field, model_validator -from core.model_runtime.entities.message_entities import ImagePromptMessageContent +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from . import helpers from .constants import FILE_MODEL_IDENTITY diff --git a/api/core/workflow/file/protocols.py b/api/dify_graph/file/protocols.py similarity index 94% rename from api/core/workflow/file/protocols.py rename to api/dify_graph/file/protocols.py index 8d923148e0..24cbb42735 100644 --- a/api/core/workflow/file/protocols.py +++ b/api/dify_graph/file/protocols.py @@ -14,7 +14,7 @@ class HttpResponseProtocol(Protocol): class WorkflowFileRuntimeProtocol(Protocol): - """Runtime dependencies required by ``core.workflow.file``. + """Runtime dependencies required by ``dify_graph.file``. Implementations are expected to be provided by integration layers (for example, ``core.app.workflow.file_runtime``) so the workflow package avoids importing diff --git a/api/core/workflow/file/runtime.py b/api/dify_graph/file/runtime.py similarity index 100% rename from api/core/workflow/file/runtime.py rename to api/dify_graph/file/runtime.py diff --git a/api/core/workflow/file/tool_file_parser.py b/api/dify_graph/file/tool_file_parser.py similarity index 100% rename from api/core/workflow/file/tool_file_parser.py rename to api/dify_graph/file/tool_file_parser.py diff --git a/api/core/workflow/graph/__init__.py b/api/dify_graph/graph/__init__.py similarity index 100% rename from api/core/workflow/graph/__init__.py rename to api/dify_graph/graph/__init__.py diff --git a/api/core/workflow/graph/edge.py b/api/dify_graph/graph/edge.py similarity index 91% rename from api/core/workflow/graph/edge.py rename to api/dify_graph/graph/edge.py index 1d57747dbb..f4f67ea6be 100644 --- a/api/core/workflow/graph/edge.py +++ b/api/dify_graph/graph/edge.py @@ -1,7 +1,7 @@ import uuid from dataclasses import dataclass, field -from core.workflow.enums import NodeState +from dify_graph.enums import NodeState @dataclass diff --git a/api/core/workflow/graph/graph.py b/api/dify_graph/graph/graph.py similarity index 98% rename from api/core/workflow/graph/graph.py rename to api/dify_graph/graph/graph.py index b6f577d193..747f8d9e30 100644 --- a/api/core/workflow/graph/graph.py +++ b/api/dify_graph/graph/graph.py @@ -7,9 +7,9 @@ from typing import Protocol, cast, final from pydantic import TypeAdapter -from core.workflow.entities.graph_config import NodeConfigDict -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType -from core.workflow.nodes.base.node import Node +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType +from dify_graph.nodes.base.node import Node from libs.typing import is_str from .edge import Edge diff --git a/api/core/workflow/graph/graph_template.py b/api/dify_graph/graph/graph_template.py similarity index 100% rename from api/core/workflow/graph/graph_template.py rename to api/dify_graph/graph/graph_template.py diff --git a/api/core/workflow/graph/validation.py b/api/dify_graph/graph/validation.py similarity index 98% rename from api/core/workflow/graph/validation.py rename to api/dify_graph/graph/validation.py index 41b4fdfa60..6840bcfed2 100644 --- a/api/core/workflow/graph/validation.py +++ b/api/dify_graph/graph/validation.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Protocol -from core.workflow.enums import NodeExecutionType, NodeType +from dify_graph.enums import NodeExecutionType, NodeType if TYPE_CHECKING: from .graph import Graph diff --git a/api/core/workflow/graph_engine/__init__.py b/api/dify_graph/graph_engine/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/__init__.py rename to api/dify_graph/graph_engine/__init__.py diff --git a/api/core/workflow/graph_engine/_engine_utils.py b/api/dify_graph/graph_engine/_engine_utils.py similarity index 100% rename from api/core/workflow/graph_engine/_engine_utils.py rename to api/dify_graph/graph_engine/_engine_utils.py diff --git a/api/core/workflow/graph_engine/command_channels/README.md b/api/dify_graph/graph_engine/command_channels/README.md similarity index 100% rename from api/core/workflow/graph_engine/command_channels/README.md rename to api/dify_graph/graph_engine/command_channels/README.md diff --git a/api/core/workflow/graph_engine/command_channels/__init__.py b/api/dify_graph/graph_engine/command_channels/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/command_channels/__init__.py rename to api/dify_graph/graph_engine/command_channels/__init__.py diff --git a/api/core/workflow/graph_engine/command_channels/in_memory_channel.py b/api/dify_graph/graph_engine/command_channels/in_memory_channel.py similarity index 100% rename from api/core/workflow/graph_engine/command_channels/in_memory_channel.py rename to api/dify_graph/graph_engine/command_channels/in_memory_channel.py diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/dify_graph/graph_engine/command_channels/redis_channel.py similarity index 100% rename from api/core/workflow/graph_engine/command_channels/redis_channel.py rename to api/dify_graph/graph_engine/command_channels/redis_channel.py diff --git a/api/core/workflow/graph_engine/command_processing/__init__.py b/api/dify_graph/graph_engine/command_processing/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/command_processing/__init__.py rename to api/dify_graph/graph_engine/command_processing/__init__.py diff --git a/api/core/workflow/graph_engine/command_processing/command_handlers.py b/api/dify_graph/graph_engine/command_processing/command_handlers.py similarity index 94% rename from api/core/workflow/graph_engine/command_processing/command_handlers.py rename to api/dify_graph/graph_engine/command_processing/command_handlers.py index cfe856d9e8..eefd0c366b 100644 --- a/api/core/workflow/graph_engine/command_processing/command_handlers.py +++ b/api/dify_graph/graph_engine/command_processing/command_handlers.py @@ -3,8 +3,8 @@ from typing import final from typing_extensions import override -from core.workflow.entities.pause_reason import SchedulingPause -from core.workflow.runtime import VariablePool +from dify_graph.entities.pause_reason import SchedulingPause +from dify_graph.runtime import VariablePool from ..domain.graph_execution import GraphExecution from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand diff --git a/api/core/workflow/graph_engine/command_processing/command_processor.py b/api/dify_graph/graph_engine/command_processing/command_processor.py similarity index 100% rename from api/core/workflow/graph_engine/command_processing/command_processor.py rename to api/dify_graph/graph_engine/command_processing/command_processor.py diff --git a/api/core/workflow/graph_engine/config.py b/api/dify_graph/graph_engine/config.py similarity index 100% rename from api/core/workflow/graph_engine/config.py rename to api/dify_graph/graph_engine/config.py diff --git a/api/core/workflow/graph_engine/domain/__init__.py b/api/dify_graph/graph_engine/domain/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/domain/__init__.py rename to api/dify_graph/graph_engine/domain/__init__.py diff --git a/api/core/workflow/graph_engine/domain/graph_execution.py b/api/dify_graph/graph_engine/domain/graph_execution.py similarity index 97% rename from api/core/workflow/graph_engine/domain/graph_execution.py rename to api/dify_graph/graph_engine/domain/graph_execution.py index 3ba6e5e37c..0ee4a9f9a7 100644 --- a/api/core/workflow/graph_engine/domain/graph_execution.py +++ b/api/dify_graph/graph_engine/domain/graph_execution.py @@ -8,9 +8,9 @@ from typing import Literal from pydantic import BaseModel, Field -from core.workflow.entities.pause_reason import PauseReason -from core.workflow.enums import NodeState -from core.workflow.runtime.graph_runtime_state import GraphExecutionProtocol +from dify_graph.entities.pause_reason import PauseReason +from dify_graph.enums import NodeState +from dify_graph.runtime.graph_runtime_state import GraphExecutionProtocol from .node_execution import NodeExecution diff --git a/api/core/workflow/graph_engine/domain/node_execution.py b/api/dify_graph/graph_engine/domain/node_execution.py similarity index 96% rename from api/core/workflow/graph_engine/domain/node_execution.py rename to api/dify_graph/graph_engine/domain/node_execution.py index 85700caa3a..ae8f9a5e50 100644 --- a/api/core/workflow/graph_engine/domain/node_execution.py +++ b/api/dify_graph/graph_engine/domain/node_execution.py @@ -4,7 +4,7 @@ NodeExecution entity representing a node's execution state. from dataclasses import dataclass -from core.workflow.enums import NodeState +from dify_graph.enums import NodeState @dataclass diff --git a/api/core/model_runtime/callbacks/__init__.py b/api/dify_graph/graph_engine/entities/__init__.py similarity index 100% rename from api/core/model_runtime/callbacks/__init__.py rename to api/dify_graph/graph_engine/entities/__init__.py diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/dify_graph/graph_engine/entities/commands.py similarity index 96% rename from api/core/workflow/graph_engine/entities/commands.py rename to api/dify_graph/graph_engine/entities/commands.py index 7e7b65247b..c56845cfc4 100644 --- a/api/core/workflow/graph_engine/entities/commands.py +++ b/api/dify_graph/graph_engine/entities/commands.py @@ -11,7 +11,7 @@ from typing import Any from pydantic import BaseModel, Field -from core.workflow.variables.variables import Variable +from dify_graph.variables.variables import Variable class CommandType(StrEnum): diff --git a/api/core/workflow/graph_engine/error_handler.py b/api/dify_graph/graph_engine/error_handler.py similarity index 97% rename from api/core/workflow/graph_engine/error_handler.py rename to api/dify_graph/graph_engine/error_handler.py index 62e144c12a..d4ee2922ec 100644 --- a/api/core/workflow/graph_engine/error_handler.py +++ b/api/dify_graph/graph_engine/error_handler.py @@ -6,21 +6,21 @@ import logging import time from typing import TYPE_CHECKING, final -from core.workflow.enums import ( +from dify_graph.enums import ( ErrorStrategy as ErrorStrategyEnum, ) -from core.workflow.enums import ( +from dify_graph.enums import ( WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.graph import Graph -from core.workflow.graph_events import ( +from dify_graph.graph import Graph +from dify_graph.graph_events import ( GraphNodeEventBase, NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunRetryEvent, ) -from core.workflow.node_events import NodeRunResult +from dify_graph.node_events import NodeRunResult if TYPE_CHECKING: from .domain import GraphExecution diff --git a/api/core/workflow/graph_engine/event_management/__init__.py b/api/dify_graph/graph_engine/event_management/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/event_management/__init__.py rename to api/dify_graph/graph_engine/event_management/__init__.py diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/dify_graph/graph_engine/event_management/event_handlers.py similarity index 98% rename from api/core/workflow/graph_engine/event_management/event_handlers.py rename to api/dify_graph/graph_engine/event_management/event_handlers.py index 865d951f88..62e613c846 100644 --- a/api/core/workflow/graph_engine/event_management/event_handlers.py +++ b/api/dify_graph/graph_engine/event_management/event_handlers.py @@ -7,10 +7,9 @@ from collections.abc import Mapping from functools import singledispatchmethod from typing import TYPE_CHECKING, final -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState -from core.workflow.graph import Graph -from core.workflow.graph_events import ( +from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState +from dify_graph.graph import Graph +from dify_graph.graph_events import ( GraphNodeEventBase, NodeRunAgentLogEvent, NodeRunExceptionEvent, @@ -30,7 +29,8 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.runtime import GraphRuntimeState +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.runtime import GraphRuntimeState from ..domain.graph_execution import GraphExecution from ..response_coordinator import ResponseStreamCoordinator diff --git a/api/core/workflow/graph_engine/event_management/event_manager.py b/api/dify_graph/graph_engine/event_management/event_manager.py similarity index 98% rename from api/core/workflow/graph_engine/event_management/event_manager.py rename to api/dify_graph/graph_engine/event_management/event_manager.py index ae2e659543..616f621c3e 100644 --- a/api/core/workflow/graph_engine/event_management/event_manager.py +++ b/api/dify_graph/graph_engine/event_management/event_manager.py @@ -9,7 +9,7 @@ from collections.abc import Generator from contextlib import contextmanager from typing import final -from core.workflow.graph_events import GraphEngineEvent +from dify_graph.graph_events import GraphEngineEvent from ..layers.base import GraphEngineLayer diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/dify_graph/graph_engine/graph_engine.py similarity index 88% rename from api/core/workflow/graph_engine/graph_engine.py rename to api/dify_graph/graph_engine/graph_engine.py index 7c46fc2239..ea98a46b06 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/dify_graph/graph_engine/graph_engine.py @@ -9,14 +9,14 @@ from __future__ import annotations import logging import queue -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, cast, final -from core.workflow.context import capture_current_context -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import NodeExecutionType -from core.workflow.graph import Graph -from core.workflow.graph_events import ( +from dify_graph.context import capture_current_context +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.enums import NodeExecutionType +from dify_graph.graph import Graph +from dify_graph.graph_events import ( GraphEngineEvent, GraphNodeEventBase, GraphRunAbortedEvent, @@ -26,10 +26,11 @@ from core.workflow.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, ) -from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper +from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper +from dify_graph.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol if TYPE_CHECKING: # pragma: no cover - used only for static analysis - from core.workflow.runtime.graph_runtime_state import GraphProtocol + from dify_graph.runtime.graph_runtime_state import GraphProtocol from .command_processing import ( AbortCommandHandler, @@ -49,8 +50,9 @@ from .protocols.command_channel import CommandChannel from .worker_management import WorkerPool if TYPE_CHECKING: - from core.workflow.graph_engine.domain.graph_execution import GraphExecution - from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator + from dify_graph.entities import GraphInitParams + from dify_graph.graph_engine.domain.graph_execution import GraphExecution + from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator logger = logging.getLogger(__name__) @@ -74,6 +76,7 @@ class GraphEngine: graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel, config: GraphEngineConfig = _DEFAULT_CONFIG, + child_engine_builder: ChildGraphEngineBuilderProtocol | None = None, ) -> None: """Initialize the graph engine with all subsystems and dependencies.""" @@ -83,6 +86,9 @@ class GraphEngine: self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) self._command_channel = command_channel self._config = config + self._child_engine_builder = child_engine_builder + if child_engine_builder is not None: + self._graph_runtime_state.bind_child_engine_builder(child_engine_builder) # Graph execution tracks the overall execution state self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution) @@ -214,6 +220,25 @@ class GraphEngine: self._bind_layer_context(layer) return self + def create_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: dict[str, object] | Mapping[str, object], + root_node_id: str, + layers: list[GraphEngineLayer] | tuple[GraphEngineLayer, ...] = (), + ) -> GraphEngine: + return self._graph_runtime_state.create_child_engine( + workflow_id=workflow_id, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + graph_config=graph_config, + root_node_id=root_node_id, + layers=layers, + ) + def run(self) -> Generator[GraphEngineEvent, None, None]: """ Execute the graph using the modular architecture. diff --git a/api/core/workflow/graph_engine/graph_state_manager.py b/api/dify_graph/graph_engine/graph_state_manager.py similarity index 98% rename from api/core/workflow/graph_engine/graph_state_manager.py rename to api/dify_graph/graph_engine/graph_state_manager.py index d9773645c3..922a968435 100644 --- a/api/core/workflow/graph_engine/graph_state_manager.py +++ b/api/dify_graph/graph_engine/graph_state_manager.py @@ -6,8 +6,8 @@ import threading from collections.abc import Sequence from typing import TypedDict, final -from core.workflow.enums import NodeState -from core.workflow.graph import Edge, Graph +from dify_graph.enums import NodeState +from dify_graph.graph import Edge, Graph from .ready_queue import ReadyQueue diff --git a/api/core/workflow/graph_engine/graph_traversal/__init__.py b/api/dify_graph/graph_engine/graph_traversal/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/graph_traversal/__init__.py rename to api/dify_graph/graph_engine/graph_traversal/__init__.py diff --git a/api/core/workflow/graph_engine/graph_traversal/edge_processor.py b/api/dify_graph/graph_engine/graph_traversal/edge_processor.py similarity index 97% rename from api/core/workflow/graph_engine/graph_traversal/edge_processor.py rename to api/dify_graph/graph_engine/graph_traversal/edge_processor.py index 9bd0f86fbf..c4625a8ff7 100644 --- a/api/core/workflow/graph_engine/graph_traversal/edge_processor.py +++ b/api/dify_graph/graph_engine/graph_traversal/edge_processor.py @@ -5,9 +5,9 @@ Edge processing logic for graph traversal. from collections.abc import Sequence from typing import TYPE_CHECKING, final -from core.workflow.enums import NodeExecutionType -from core.workflow.graph import Edge, Graph -from core.workflow.graph_events import NodeRunStreamChunkEvent +from dify_graph.enums import NodeExecutionType +from dify_graph.graph import Edge, Graph +from dify_graph.graph_events import NodeRunStreamChunkEvent from ..graph_state_manager import GraphStateManager from ..response_coordinator import ResponseStreamCoordinator diff --git a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py b/api/dify_graph/graph_engine/graph_traversal/skip_propagator.py similarity index 98% rename from api/core/workflow/graph_engine/graph_traversal/skip_propagator.py rename to api/dify_graph/graph_engine/graph_traversal/skip_propagator.py index b9c9243963..76445bccd2 100644 --- a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py +++ b/api/dify_graph/graph_engine/graph_traversal/skip_propagator.py @@ -5,7 +5,7 @@ Skip state propagation through the graph. from collections.abc import Sequence from typing import final -from core.workflow.graph import Edge, Graph +from dify_graph.graph import Edge, Graph from ..graph_state_manager import GraphStateManager diff --git a/api/core/workflow/graph_engine/layers/README.md b/api/dify_graph/graph_engine/layers/README.md similarity index 100% rename from api/core/workflow/graph_engine/layers/README.md rename to api/dify_graph/graph_engine/layers/README.md diff --git a/api/core/workflow/graph_engine/layers/__init__.py b/api/dify_graph/graph_engine/layers/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/layers/__init__.py rename to api/dify_graph/graph_engine/layers/__init__.py diff --git a/api/core/workflow/graph_engine/layers/base.py b/api/dify_graph/graph_engine/layers/base.py similarity index 94% rename from api/core/workflow/graph_engine/layers/base.py rename to api/dify_graph/graph_engine/layers/base.py index ff4a483aed..890336c1ca 100644 --- a/api/core/workflow/graph_engine/layers/base.py +++ b/api/dify_graph/graph_engine/layers/base.py @@ -7,10 +7,10 @@ intercept and respond to GraphEngine events. from abc import ABC, abstractmethod -from core.workflow.graph_engine.protocols.command_channel import CommandChannel -from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase -from core.workflow.nodes.base.node import Node -from core.workflow.runtime import ReadOnlyGraphRuntimeState +from dify_graph.graph_engine.protocols.command_channel import CommandChannel +from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase +from dify_graph.nodes.base.node import Node +from dify_graph.runtime import ReadOnlyGraphRuntimeState class GraphEngineLayerNotInitializedError(Exception): diff --git a/api/core/workflow/graph_engine/layers/debug_logging.py b/api/dify_graph/graph_engine/layers/debug_logging.py similarity index 99% rename from api/core/workflow/graph_engine/layers/debug_logging.py rename to api/dify_graph/graph_engine/layers/debug_logging.py index e0402cd09c..1af2e2db9e 100644 --- a/api/core/workflow/graph_engine/layers/debug_logging.py +++ b/api/dify_graph/graph_engine/layers/debug_logging.py @@ -11,7 +11,7 @@ from typing import Any, final from typing_extensions import override -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphEngineEvent, GraphRunAbortedEvent, GraphRunFailedEvent, diff --git a/api/core/workflow/graph_engine/layers/execution_limits.py b/api/dify_graph/graph_engine/layers/execution_limits.py similarity index 94% rename from api/core/workflow/graph_engine/layers/execution_limits.py rename to api/dify_graph/graph_engine/layers/execution_limits.py index a2d36d142d..48ba5608d9 100644 --- a/api/core/workflow/graph_engine/layers/execution_limits.py +++ b/api/dify_graph/graph_engine/layers/execution_limits.py @@ -15,13 +15,13 @@ from typing import final from typing_extensions import override -from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType -from core.workflow.graph_engine.layers import GraphEngineLayer -from core.workflow.graph_events import ( +from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType +from dify_graph.graph_engine.layers import GraphEngineLayer +from dify_graph.graph_events import ( GraphEngineEvent, NodeRunStartedEvent, ) -from core.workflow.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent +from dify_graph.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent class LimitType(StrEnum): diff --git a/api/core/workflow/graph_engine/manager.py b/api/dify_graph/graph_engine/manager.py similarity index 94% rename from api/core/workflow/graph_engine/manager.py rename to api/dify_graph/graph_engine/manager.py index 36f1612af0..955c149069 100644 --- a/api/core/workflow/graph_engine/manager.py +++ b/api/dify_graph/graph_engine/manager.py @@ -10,8 +10,8 @@ import logging from collections.abc import Sequence from typing import final -from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol -from core.workflow.graph_engine.entities.commands import ( +from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol +from dify_graph.graph_engine.entities.commands import ( AbortCommand, GraphEngineCommand, PauseCommand, diff --git a/api/core/workflow/graph_engine/orchestration/__init__.py b/api/dify_graph/graph_engine/orchestration/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/orchestration/__init__.py rename to api/dify_graph/graph_engine/orchestration/__init__.py diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/dify_graph/graph_engine/orchestration/dispatcher.py similarity index 99% rename from api/core/workflow/graph_engine/orchestration/dispatcher.py rename to api/dify_graph/graph_engine/orchestration/dispatcher.py index 76dd1c7768..f8aaf20b2f 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/dify_graph/graph_engine/orchestration/dispatcher.py @@ -8,7 +8,7 @@ import threading import time from typing import TYPE_CHECKING, final -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphNodeEventBase, NodeRunExceptionEvent, NodeRunFailedEvent, diff --git a/api/core/workflow/graph_engine/orchestration/execution_coordinator.py b/api/dify_graph/graph_engine/orchestration/execution_coordinator.py similarity index 100% rename from api/core/workflow/graph_engine/orchestration/execution_coordinator.py rename to api/dify_graph/graph_engine/orchestration/execution_coordinator.py diff --git a/api/core/workflow/graph_engine/protocols/command_channel.py b/api/dify_graph/graph_engine/protocols/command_channel.py similarity index 100% rename from api/core/workflow/graph_engine/protocols/command_channel.py rename to api/dify_graph/graph_engine/protocols/command_channel.py diff --git a/api/core/workflow/graph_engine/ready_queue/__init__.py b/api/dify_graph/graph_engine/ready_queue/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/ready_queue/__init__.py rename to api/dify_graph/graph_engine/ready_queue/__init__.py diff --git a/api/core/workflow/graph_engine/ready_queue/factory.py b/api/dify_graph/graph_engine/ready_queue/factory.py similarity index 100% rename from api/core/workflow/graph_engine/ready_queue/factory.py rename to api/dify_graph/graph_engine/ready_queue/factory.py diff --git a/api/core/workflow/graph_engine/ready_queue/in_memory.py b/api/dify_graph/graph_engine/ready_queue/in_memory.py similarity index 100% rename from api/core/workflow/graph_engine/ready_queue/in_memory.py rename to api/dify_graph/graph_engine/ready_queue/in_memory.py diff --git a/api/core/workflow/graph_engine/ready_queue/protocol.py b/api/dify_graph/graph_engine/ready_queue/protocol.py similarity index 100% rename from api/core/workflow/graph_engine/ready_queue/protocol.py rename to api/dify_graph/graph_engine/ready_queue/protocol.py diff --git a/api/core/workflow/graph_engine/response_coordinator/__init__.py b/api/dify_graph/graph_engine/response_coordinator/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/response_coordinator/__init__.py rename to api/dify_graph/graph_engine/response_coordinator/__init__.py diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/dify_graph/graph_engine/response_coordinator/coordinator.py similarity index 98% rename from api/core/workflow/graph_engine/response_coordinator/coordinator.py rename to api/dify_graph/graph_engine/response_coordinator/coordinator.py index 443b80ac7b..610bda64b0 100644 --- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py +++ b/api/dify_graph/graph_engine/response_coordinator/coordinator.py @@ -14,17 +14,12 @@ from uuid import uuid4 from pydantic import BaseModel, Field -from core.workflow.enums import NodeExecutionType, NodeState -from core.workflow.graph_events import ( - ChunkType, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - ToolCall, - ToolResult, -) -from core.workflow.nodes.base.template import TextSegment, VariableSegment -from core.workflow.runtime import VariablePool -from core.workflow.runtime.graph_runtime_state import GraphProtocol +from dify_graph.entities import ToolCall, ToolResult +from dify_graph.enums import NodeExecutionType, NodeState +from dify_graph.graph_events import ChunkType, NodeRunStreamChunkEvent, NodeRunSucceededEvent +from dify_graph.nodes.base.template import TextSegment, VariableSegment +from dify_graph.runtime import VariablePool +from dify_graph.runtime.graph_runtime_state import GraphProtocol from .path import Path from .session import ResponseSession diff --git a/api/core/workflow/graph_engine/response_coordinator/path.py b/api/dify_graph/graph_engine/response_coordinator/path.py similarity index 100% rename from api/core/workflow/graph_engine/response_coordinator/path.py rename to api/dify_graph/graph_engine/response_coordinator/path.py diff --git a/api/core/workflow/graph_engine/response_coordinator/session.py b/api/dify_graph/graph_engine/response_coordinator/session.py similarity index 85% rename from api/core/workflow/graph_engine/response_coordinator/session.py rename to api/dify_graph/graph_engine/response_coordinator/session.py index 5e4fada7d9..0548e88d93 100644 --- a/api/core/workflow/graph_engine/response_coordinator/session.py +++ b/api/dify_graph/graph_engine/response_coordinator/session.py @@ -9,11 +9,11 @@ from __future__ import annotations from dataclasses import dataclass -from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.base.template import Template -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.knowledge_index import KnowledgeIndexNode -from core.workflow.runtime.graph_runtime_state import NodeProtocol +from dify_graph.nodes.answer.answer_node import AnswerNode +from dify_graph.nodes.base.template import Template +from dify_graph.nodes.end.end_node import EndNode +from dify_graph.nodes.knowledge_index import KnowledgeIndexNode +from dify_graph.runtime.graph_runtime_state import NodeProtocol @dataclass diff --git a/api/core/workflow/graph_engine/worker.py b/api/dify_graph/graph_engine/worker.py similarity index 95% rename from api/core/workflow/graph_engine/worker.py rename to api/dify_graph/graph_engine/worker.py index 9e218f6fa6..5c5d0fe5b9 100644 --- a/api/core/workflow/graph_engine/worker.py +++ b/api/dify_graph/graph_engine/worker.py @@ -14,11 +14,11 @@ from typing import TYPE_CHECKING, final from typing_extensions import override -from core.workflow.context import IExecutionContext -from core.workflow.graph import Graph -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event -from core.workflow.nodes.base.node import Node +from dify_graph.context import IExecutionContext +from dify_graph.graph import Graph +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event +from dify_graph.nodes.base.node import Node from .ready_queue import ReadyQueue diff --git a/api/core/workflow/graph_engine/worker_management/__init__.py b/api/dify_graph/graph_engine/worker_management/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/worker_management/__init__.py rename to api/dify_graph/graph_engine/worker_management/__init__.py diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/dify_graph/graph_engine/worker_management/worker_pool.py similarity index 98% rename from api/core/workflow/graph_engine/worker_management/worker_pool.py rename to api/dify_graph/graph_engine/worker_management/worker_pool.py index 2c14f53746..cc93087783 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_pool.py +++ b/api/dify_graph/graph_engine/worker_management/worker_pool.py @@ -10,9 +10,9 @@ import queue import threading from typing import final -from core.workflow.context import IExecutionContext -from core.workflow.graph import Graph -from core.workflow.graph_events import GraphNodeEventBase +from dify_graph.context import IExecutionContext +from dify_graph.graph import Graph +from dify_graph.graph_events import GraphNodeEventBase from ..config import GraphEngineConfig from ..layers.base import GraphEngineLayer diff --git a/api/core/workflow/graph_events/__init__.py b/api/dify_graph/graph_events/__init__.py similarity index 100% rename from api/core/workflow/graph_events/__init__.py rename to api/dify_graph/graph_events/__init__.py diff --git a/api/core/workflow/graph_events/agent.py b/api/dify_graph/graph_events/agent.py similarity index 100% rename from api/core/workflow/graph_events/agent.py rename to api/dify_graph/graph_events/agent.py diff --git a/api/core/workflow/graph_events/base.py b/api/dify_graph/graph_events/base.py similarity index 90% rename from api/core/workflow/graph_events/base.py rename to api/dify_graph/graph_events/base.py index c5807f7cc1..5ddf5bf4bf 100644 --- a/api/core/workflow/graph_events/base.py +++ b/api/dify_graph/graph_events/base.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, Field -from core.workflow.enums import NodeType -from core.workflow.node_events import NodeRunResult +from dify_graph.enums import NodeType +from dify_graph.node_events import NodeRunResult class GraphEngineEvent(BaseModel): diff --git a/api/core/workflow/graph_events/graph.py b/api/dify_graph/graph_events/graph.py similarity index 90% rename from api/core/workflow/graph_events/graph.py rename to api/dify_graph/graph_events/graph.py index f46526bcab..f4aaba64d6 100644 --- a/api/core/workflow/graph_events/graph.py +++ b/api/dify_graph/graph_events/graph.py @@ -1,8 +1,8 @@ from pydantic import Field -from core.workflow.entities.pause_reason import PauseReason -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph_events import BaseGraphEvent +from dify_graph.entities.pause_reason import PauseReason +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.graph_events import BaseGraphEvent class GraphRunStartedEvent(BaseGraphEvent): diff --git a/api/core/workflow/graph_events/human_input.py b/api/dify_graph/graph_events/human_input.py similarity index 100% rename from api/core/workflow/graph_events/human_input.py rename to api/dify_graph/graph_events/human_input.py diff --git a/api/core/workflow/graph_events/iteration.py b/api/dify_graph/graph_events/iteration.py similarity index 100% rename from api/core/workflow/graph_events/iteration.py rename to api/dify_graph/graph_events/iteration.py diff --git a/api/core/workflow/graph_events/loop.py b/api/dify_graph/graph_events/loop.py similarity index 100% rename from api/core/workflow/graph_events/loop.py rename to api/dify_graph/graph_events/loop.py diff --git a/api/core/workflow/graph_events/node.py b/api/dify_graph/graph_events/node.py similarity index 97% rename from api/core/workflow/graph_events/node.py rename to api/dify_graph/graph_events/node.py index e6a392a974..e09bf5c706 100644 --- a/api/core/workflow/graph_events/node.py +++ b/api/dify_graph/graph_events/node.py @@ -5,8 +5,8 @@ from enum import StrEnum from pydantic import Field from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult -from core.workflow.entities.pause_reason import PauseReason +from dify_graph.entities import AgentNodeStrategyInit, ToolCall, ToolResult +from dify_graph.entities.pause_reason import PauseReason from .base import GraphNodeEventBase diff --git a/api/core/model_runtime/README.md b/api/dify_graph/model_runtime/README.md similarity index 100% rename from api/core/model_runtime/README.md rename to api/dify_graph/model_runtime/README.md diff --git a/api/core/model_runtime/README_CN.md b/api/dify_graph/model_runtime/README_CN.md similarity index 100% rename from api/core/model_runtime/README_CN.md rename to api/dify_graph/model_runtime/README_CN.md diff --git a/api/core/model_runtime/errors/__init__.py b/api/dify_graph/model_runtime/__init__.py similarity index 100% rename from api/core/model_runtime/errors/__init__.py rename to api/dify_graph/model_runtime/__init__.py diff --git a/api/core/model_runtime/model_providers/__base/__init__.py b/api/dify_graph/model_runtime/callbacks/__init__.py similarity index 100% rename from api/core/model_runtime/model_providers/__base/__init__.py rename to api/dify_graph/model_runtime/callbacks/__init__.py diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/dify_graph/model_runtime/callbacks/base_callback.py similarity index 94% rename from api/core/model_runtime/callbacks/base_callback.py rename to api/dify_graph/model_runtime/callbacks/base_callback.py index a745a91510..20faf3d6cd 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/dify_graph/model_runtime/callbacks/base_callback.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod from collections.abc import Sequence -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel _TEXT_COLOR_MAPPING = { "blue": "36;1", diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/dify_graph/model_runtime/callbacks/logging_callback.py similarity index 94% rename from api/core/model_runtime/callbacks/logging_callback.py rename to api/dify_graph/model_runtime/callbacks/logging_callback.py index b366fcc57b..49b9ab27eb 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/dify_graph/model_runtime/callbacks/logging_callback.py @@ -4,10 +4,10 @@ import sys from collections.abc import Sequence from typing import cast -from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.callbacks.base_callback import Callback +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/entities/__init__.py b/api/dify_graph/model_runtime/entities/__init__.py similarity index 100% rename from api/core/model_runtime/entities/__init__.py rename to api/dify_graph/model_runtime/entities/__init__.py diff --git a/api/core/model_runtime/entities/common_entities.py b/api/dify_graph/model_runtime/entities/common_entities.py similarity index 100% rename from api/core/model_runtime/entities/common_entities.py rename to api/dify_graph/model_runtime/entities/common_entities.py diff --git a/api/core/model_runtime/entities/defaults.py b/api/dify_graph/model_runtime/entities/defaults.py similarity index 98% rename from api/core/model_runtime/entities/defaults.py rename to api/dify_graph/model_runtime/entities/defaults.py index 51c9c51257..53b732e5c6 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/dify_graph/model_runtime/entities/defaults.py @@ -1,4 +1,4 @@ -from core.model_runtime.entities.model_entities import DefaultParameterName +from dify_graph.model_runtime.entities.model_entities import DefaultParameterName PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { DefaultParameterName.TEMPERATURE: { diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/dify_graph/model_runtime/entities/llm_entities.py similarity index 97% rename from api/core/model_runtime/entities/llm_entities.py rename to api/dify_graph/model_runtime/entities/llm_entities.py index 2c7c421eed..eec682a2ae 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/dify_graph/model_runtime/entities/llm_entities.py @@ -7,8 +7,8 @@ from typing import Any, TypedDict, Union from pydantic import BaseModel, Field -from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage -from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo +from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage +from dify_graph.model_runtime.entities.model_entities import ModelUsage, PriceInfo class LLMMode(StrEnum): diff --git a/api/core/model_runtime/entities/message_entities.py b/api/dify_graph/model_runtime/entities/message_entities.py similarity index 100% rename from api/core/model_runtime/entities/message_entities.py rename to api/dify_graph/model_runtime/entities/message_entities.py diff --git a/api/core/model_runtime/entities/model_entities.py b/api/dify_graph/model_runtime/entities/model_entities.py similarity index 98% rename from api/core/model_runtime/entities/model_entities.py rename to api/dify_graph/model_runtime/entities/model_entities.py index 19194d162c..fbcde6740a 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/dify_graph/model_runtime/entities/model_entities.py @@ -6,7 +6,7 @@ from typing import Any from pydantic import BaseModel, ConfigDict, model_validator -from core.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.common_entities import I18nObject class ModelType(StrEnum): diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/dify_graph/model_runtime/entities/provider_entities.py similarity index 95% rename from api/core/model_runtime/entities/provider_entities.py rename to api/dify_graph/model_runtime/entities/provider_entities.py index 2d88751668..97a99ea7ce 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/dify_graph/model_runtime/entities/provider_entities.py @@ -3,8 +3,8 @@ from enum import StrEnum, auto from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import AIModelEntity, ModelType +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType class ConfigurateMethod(StrEnum): diff --git a/api/core/model_runtime/entities/rerank_entities.py b/api/dify_graph/model_runtime/entities/rerank_entities.py similarity index 100% rename from api/core/model_runtime/entities/rerank_entities.py rename to api/dify_graph/model_runtime/entities/rerank_entities.py diff --git a/api/core/model_runtime/entities/text_embedding_entities.py b/api/dify_graph/model_runtime/entities/text_embedding_entities.py similarity index 89% rename from api/core/model_runtime/entities/text_embedding_entities.py rename to api/dify_graph/model_runtime/entities/text_embedding_entities.py index 854c448250..a0210c169d 100644 --- a/api/core/model_runtime/entities/text_embedding_entities.py +++ b/api/dify_graph/model_runtime/entities/text_embedding_entities.py @@ -2,7 +2,7 @@ from decimal import Decimal from pydantic import BaseModel -from core.model_runtime.entities.model_entities import ModelUsage +from dify_graph.model_runtime.entities.model_entities import ModelUsage class EmbeddingUsage(ModelUsage): diff --git a/api/core/model_runtime/model_providers/__init__.py b/api/dify_graph/model_runtime/errors/__init__.py similarity index 100% rename from api/core/model_runtime/model_providers/__init__.py rename to api/dify_graph/model_runtime/errors/__init__.py diff --git a/api/core/model_runtime/errors/invoke.py b/api/dify_graph/model_runtime/errors/invoke.py similarity index 100% rename from api/core/model_runtime/errors/invoke.py rename to api/dify_graph/model_runtime/errors/invoke.py diff --git a/api/core/model_runtime/errors/validate.py b/api/dify_graph/model_runtime/errors/validate.py similarity index 100% rename from api/core/model_runtime/errors/validate.py rename to api/dify_graph/model_runtime/errors/validate.py diff --git a/api/core/model_runtime/memory/__init__.py b/api/dify_graph/model_runtime/memory/__init__.py similarity index 100% rename from api/core/model_runtime/memory/__init__.py rename to api/dify_graph/model_runtime/memory/__init__.py diff --git a/api/core/model_runtime/memory/prompt_message_memory.py b/api/dify_graph/model_runtime/memory/prompt_message_memory.py similarity index 89% rename from api/core/model_runtime/memory/prompt_message_memory.py rename to api/dify_graph/model_runtime/memory/prompt_message_memory.py index 4491ddfd05..a76a7faf71 100644 --- a/api/core/model_runtime/memory/prompt_message_memory.py +++ b/api/dify_graph/model_runtime/memory/prompt_message_memory.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Sequence from typing import Protocol -from core.model_runtime.entities import PromptMessage +from dify_graph.model_runtime.entities import PromptMessage DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000 diff --git a/api/core/model_runtime/schema_validators/__init__.py b/api/dify_graph/model_runtime/model_providers/__base/__init__.py similarity index 100% rename from api/core/model_runtime/schema_validators/__init__.py rename to api/dify_graph/model_runtime/model_providers/__base/__init__.py diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/dify_graph/model_runtime/model_providers/__base/ai_model.py similarity index 97% rename from api/core/model_runtime/model_providers/__base/ai_model.py rename to api/dify_graph/model_runtime/model_providers/__base/ai_model.py index c3e50eaddd..ac7ae9925b 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/ai_model.py @@ -6,9 +6,10 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationError from redis import RedisError from configs import dify_config -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE -from core.model_runtime.entities.model_entities import ( +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE +from dify_graph.model_runtime.entities.model_entities import ( AIModelEntity, DefaultParameterName, ModelType, @@ -16,7 +17,7 @@ from core.model_runtime.entities.model_entities import ( PriceInfo, PriceType, ) -from core.model_runtime.errors.invoke import ( +from dify_graph.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, @@ -24,7 +25,6 @@ from core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py similarity index 98% rename from api/core/model_runtime/model_providers/__base/large_language_model.py rename to api/dify_graph/model_runtime/model_providers/__base/large_language_model.py index c32ab0879e..bf864ca227 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py @@ -7,21 +7,21 @@ from typing import Union from pydantic import ConfigDict from configs import dify_config -from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.callbacks.logging_callback import LoggingCallback -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage -from core.model_runtime.entities.message_entities import ( +from dify_graph.model_runtime.callbacks.base_callback import Callback +from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageContentUnionTypes, PromptMessageTool, TextPromptMessageContent, ) -from core.model_runtime.entities.model_entities import ( +from dify_graph.model_runtime.entities.model_entities import ( ModelType, PriceType, ) -from core.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/__base/moderation_model.py b/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py similarity index 89% rename from api/core/model_runtime/model_providers/__base/moderation_model.py rename to api/dify_graph/model_runtime/model_providers/__base/moderation_model.py index 7aff0184f4..5fa3d1634b 100644 --- a/api/core/model_runtime/model_providers/__base/moderation_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py @@ -2,8 +2,8 @@ import time from pydantic import ConfigDict -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel class ModerationModel(AIModel): diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py similarity index 92% rename from api/core/model_runtime/model_providers/__base/rerank_model.py rename to api/dify_graph/model_runtime/model_providers/__base/rerank_model.py index 0a576b832a..5da2b84b95 100644 --- a/api/core/model_runtime/model_providers/__base/rerank_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py @@ -1,6 +1,6 @@ -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.rerank_entities import RerankResult -from core.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.rerank_entities import RerankResult +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel class RerankModel(AIModel): diff --git a/api/core/model_runtime/model_providers/__base/speech2text_model.py b/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py similarity index 88% rename from api/core/model_runtime/model_providers/__base/speech2text_model.py rename to api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py index 9d3bf13e79..e69069a85d 100644 --- a/api/core/model_runtime/model_providers/__base/speech2text_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py @@ -2,8 +2,8 @@ from typing import IO from pydantic import ConfigDict -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel class Speech2TextModel(AIModel): diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py similarity index 94% rename from api/core/model_runtime/model_providers/__base/text_embedding_model.py rename to api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py index 4c902e2c11..3438da2ada 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py @@ -1,9 +1,9 @@ from pydantic import ConfigDict from core.entities.embedding_type import EmbeddingInputType -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.entities.text_embedding_entities import EmbeddingResult -from core.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel class TextEmbeddingModel(AIModel): diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py b/api/dify_graph/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py similarity index 100% rename from api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py rename to api/dify_graph/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/dify_graph/model_runtime/model_providers/__base/tts_model.py similarity index 94% rename from api/core/model_runtime/model_providers/__base/tts_model.py rename to api/dify_graph/model_runtime/model_providers/__base/tts_model.py index a83c8be37c..0656529f22 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/tts_model.py @@ -3,8 +3,8 @@ from collections.abc import Iterable from pydantic import ConfigDict -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/utils/__init__.py b/api/dify_graph/model_runtime/model_providers/__init__.py similarity index 100% rename from api/core/model_runtime/utils/__init__.py rename to api/dify_graph/model_runtime/model_providers/__init__.py diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/dify_graph/model_runtime/model_providers/_position.yaml similarity index 100% rename from api/core/model_runtime/model_providers/_position.yaml rename to api/dify_graph/model_runtime/model_providers/_position.yaml diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/dify_graph/model_runtime/model_providers/model_provider_factory.py similarity index 93% rename from api/core/model_runtime/model_providers/model_provider_factory.py rename to api/dify_graph/model_runtime/model_providers/model_provider_factory.py index 9cfc6889ac..e168fc11d1 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/dify_graph/model_runtime/model_providers/model_provider_factory.py @@ -10,18 +10,20 @@ from redis import RedisError import contexts from configs import dify_config -from core.model_runtime.entities.model_entities import AIModelEntity, ModelType -from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity -from core.model_runtime.model_providers.__base.ai_model import AIModel -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.__base.moderation_model import ModerationModel -from core.model_runtime.model_providers.__base.rerank_model import RerankModel -from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.__base.tts_model import TTSModel -from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator -from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType +from dify_graph.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel +from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel +from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel +from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator +from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import ( + ProviderCredentialSchemaValidator, +) from extensions.ext_redis import redis_client from models.provider_ids import ModelProviderID diff --git a/api/core/workflow/graph_engine/entities/__init__.py b/api/dify_graph/model_runtime/schema_validators/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/entities/__init__.py rename to api/dify_graph/model_runtime/schema_validators/__init__.py diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/dify_graph/model_runtime/schema_validators/common_validator.py similarity index 97% rename from api/core/model_runtime/schema_validators/common_validator.py rename to api/dify_graph/model_runtime/schema_validators/common_validator.py index 2caedeaf48..04cdb8e4f7 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/dify_graph/model_runtime/schema_validators/common_validator.py @@ -1,6 +1,6 @@ from typing import Union, cast -from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType +from dify_graph.model_runtime.entities.provider_entities import CredentialFormSchema, FormType class CommonValidator: diff --git a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py b/api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py similarity index 78% rename from api/core/model_runtime/schema_validators/model_credential_schema_validator.py rename to api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py index 0ac935ca31..a97796e98f 100644 --- a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py +++ b/api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py @@ -1,6 +1,6 @@ -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ModelCredentialSchema -from core.model_runtime.schema_validators.common_validator import CommonValidator +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ModelCredentialSchema +from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator class ModelCredentialSchemaValidator(CommonValidator): diff --git a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py b/api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py similarity index 79% rename from api/core/model_runtime/schema_validators/provider_credential_schema_validator.py rename to api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py index 06350f92a9..2fed75a76c 100644 --- a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py +++ b/api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py @@ -1,5 +1,5 @@ -from core.model_runtime.entities.provider_entities import ProviderCredentialSchema -from core.model_runtime.schema_validators.common_validator import CommonValidator +from dify_graph.model_runtime.entities.provider_entities import ProviderCredentialSchema +from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator class ProviderCredentialSchemaValidator(CommonValidator): diff --git a/api/core/workflow/nodes/answer/__init__.py b/api/dify_graph/model_runtime/utils/__init__.py similarity index 100% rename from api/core/workflow/nodes/answer/__init__.py rename to api/dify_graph/model_runtime/utils/__init__.py diff --git a/api/core/model_runtime/utils/encoders.py b/api/dify_graph/model_runtime/utils/encoders.py similarity index 100% rename from api/core/model_runtime/utils/encoders.py rename to api/dify_graph/model_runtime/utils/encoders.py diff --git a/api/core/workflow/node_events/__init__.py b/api/dify_graph/node_events/__init__.py similarity index 100% rename from api/core/workflow/node_events/__init__.py rename to api/dify_graph/node_events/__init__.py diff --git a/api/core/workflow/node_events/agent.py b/api/dify_graph/node_events/agent.py similarity index 100% rename from api/core/workflow/node_events/agent.py rename to api/dify_graph/node_events/agent.py diff --git a/api/core/workflow/node_events/base.py b/api/dify_graph/node_events/base.py similarity index 86% rename from api/core/workflow/node_events/base.py rename to api/dify_graph/node_events/base.py index 7fec47e21f..2f6259ae7d 100644 --- a/api/core/workflow/node_events/base.py +++ b/api/dify_graph/node_events/base.py @@ -3,8 +3,8 @@ from typing import Any from pydantic import BaseModel, Field -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.model_runtime.entities.llm_entities import LLMUsage class NodeEventBase(BaseModel): diff --git a/api/core/workflow/node_events/iteration.py b/api/dify_graph/node_events/iteration.py similarity index 100% rename from api/core/workflow/node_events/iteration.py rename to api/dify_graph/node_events/iteration.py diff --git a/api/core/workflow/node_events/loop.py b/api/dify_graph/node_events/loop.py similarity index 100% rename from api/core/workflow/node_events/loop.py rename to api/dify_graph/node_events/loop.py diff --git a/api/core/workflow/node_events/node.py b/api/dify_graph/node_events/node.py similarity index 94% rename from api/core/workflow/node_events/node.py rename to api/dify_graph/node_events/node.py index 3101e2f534..f4432cacab 100644 --- a/api/core/workflow/node_events/node.py +++ b/api/dify_graph/node_events/node.py @@ -4,12 +4,12 @@ from enum import StrEnum from pydantic import Field -from core.model_runtime.entities.llm_entities import LLMUsage from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.workflow.entities import ToolCall, ToolResult -from core.workflow.entities.pause_reason import PauseReason -from core.workflow.file import File -from core.workflow.node_events import NodeRunResult +from dify_graph.entities import ToolCall, ToolResult +from dify_graph.entities.pause_reason import PauseReason +from dify_graph.file import File +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import NodeRunResult from .base import NodeEventBase diff --git a/api/dify_graph/nodes/__init__.py b/api/dify_graph/nodes/__init__.py new file mode 100644 index 0000000000..d113ad5e70 --- /dev/null +++ b/api/dify_graph/nodes/__init__.py @@ -0,0 +1,3 @@ +from dify_graph.enums import NodeType + +__all__ = ["NodeType"] diff --git a/api/core/workflow/nodes/agent/__init__.py b/api/dify_graph/nodes/agent/__init__.py similarity index 100% rename from api/core/workflow/nodes/agent/__init__.py rename to api/dify_graph/nodes/agent/__init__.py diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/dify_graph/nodes/agent/agent_node.py similarity index 95% rename from api/core/workflow/nodes/agent/agent_node.py rename to api/dify_graph/nodes/agent/agent_node.py index ee2769eecb..fa3e1b4d7c 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/dify_graph/nodes/agent/agent_node.py @@ -15,15 +15,6 @@ from core.memory.base import BaseMemory from core.memory.node_token_buffer_memory import NodeTokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - ToolPromptMessage, - UserPromptMessage, -) -from core.model_runtime.entities.model_entities import AIModelEntity, ModelType -from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import MemoryMode from core.provider_manager import ProviderManager from core.tools.entities.tool_entities import ( @@ -34,25 +25,34 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.workflow.enums import ( +from dify_graph.enums import ( NodeType, SystemVariableKey, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.file import File, FileTransferMethod -from core.workflow.node_events import ( +from dify_graph.file import File, FileTransferMethod +from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.node_events import ( AgentLogEvent, NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent, ) -from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.runtime import VariablePool -from core.workflow.variables.segments import ArrayFileSegment, StringSegment +from dify_graph.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.runtime import VariablePool +from dify_graph.variables.segments import ArrayFileSegment, StringSegment from extensions.ext_database import db from factories import file_factory from factories.agent_factory import get_plugin_agent_strategy @@ -89,9 +89,11 @@ class AgentNode(Node[AgentNodeData]): def _run(self) -> Generator[NodeEventBase, None, None]: from core.plugin.impl.exc import PluginDaemonClientSideError + dify_ctx = self.require_dify_context() + try: strategy = get_plugin_agent_strategy( - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, agent_strategy_provider_name=self.node_data.agent_strategy_provider_name, agent_strategy_name=self.node_data.agent_strategy_name, ) @@ -129,8 +131,8 @@ class AgentNode(Node[AgentNodeData]): try: message_stream = strategy.invoke( params=parameters, - user_id=self.user_id, - app_id=self.app_id, + user_id=dify_ctx.user_id, + app_id=dify_ctx.app_id, conversation_id=conversation_id.text if conversation_id else None, credentials=credentials, ) @@ -156,8 +158,8 @@ class AgentNode(Node[AgentNodeData]): "agent_strategy": self.node_data.agent_strategy_name, }, parameters_for_log=parameters_for_log, - user_id=self.user_id, - tenant_id=self.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, node_type=self.node_type, node_id=self._node_id, node_execution_id=self.id, @@ -296,8 +298,13 @@ class AgentNode(Node[AgentNodeData]): runtime_variable_pool: VariablePool | None = None if node_data.version != "1" or node_data.tool_node_version is not None: runtime_variable_pool = variable_pool + dify_ctx = self.require_dify_context() tool_runtime = ToolManager.get_agent_tool_runtime( - self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool + dify_ctx.tenant_id, + dify_ctx.app_id, + entity, + dify_ctx.invoke_from, + runtime_variable_pool, ) if tool_runtime.entity.description: tool_runtime.entity.description.llm = ( @@ -409,7 +416,8 @@ class AgentNode(Node[AgentNodeData]): from core.plugin.impl.plugin import PluginInstaller manager = PluginInstaller() - plugins = manager.list_plugins(self.tenant_id) + dify_ctx = self.require_dify_context() + plugins = manager.list_plugins(dify_ctx.tenant_id) try: current_plugin = next( plugin @@ -442,21 +450,19 @@ class AgentNode(Node[AgentNodeData]): return None conversation_id = conversation_id_variable.value - # Return appropriate memory type based on mode + dify_ctx = self.require_dify_context() if memory_config.mode == MemoryMode.NODE: - # Node-level memory (Chatflow only) return NodeTokenBufferMemory( - app_id=self.app_id, + app_id=dify_ctx.app_id, conversation_id=conversation_id, node_id=self._node_id, - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, model_instance=model_instance, ) else: - # Conversation-level memory (default) with Session(db.engine, expire_on_commit=False) as session: stmt = select(Conversation).where( - Conversation.app_id == self.app_id, Conversation.id == conversation_id + Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id ) conversation = session.scalar(stmt) if not conversation: @@ -464,9 +470,10 @@ class AgentNode(Node[AgentNodeData]): return TokenBufferMemory(conversation=conversation, model_instance=model_instance) def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: + dify_ctx = self.require_dify_context() provider_manager = ProviderManager() provider_model_bundle = provider_manager.get_provider_model_bundle( - tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM + tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM ) model_name = value.get("model", "") model_credentials = provider_model_bundle.configuration.get_current_credentials( @@ -475,7 +482,7 @@ class AgentNode(Node[AgentNodeData]): provider_name = provider_model_bundle.configuration.provider.provider model_type_instance = provider_model_bundle.model_type_instance model_instance = ModelManager().get_model_instance( - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, provider=provider_name, model_type=ModelType(value.get("model_type", "")), model=model_name, @@ -510,9 +517,10 @@ class AgentNode(Node[AgentNodeData]): Fetch memory instance for saving node memory. This is a simplified version that doesn't require model_instance. """ - from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType + from core.model_manager import ModelManager + node_data = self.node_data if not node_data.memory: return None diff --git a/api/core/workflow/nodes/agent/entities.py b/api/dify_graph/nodes/agent/entities.py similarity index 95% rename from api/core/workflow/nodes/agent/entities.py rename to api/dify_graph/nodes/agent/entities.py index 985ee5eef2..9124420f01 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/dify_graph/nodes/agent/entities.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector -from core.workflow.nodes.base.entities import BaseNodeData +from dify_graph.nodes.base.entities import BaseNodeData class AgentNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/agent/exc.py b/api/dify_graph/nodes/agent/exc.py similarity index 100% rename from api/core/workflow/nodes/agent/exc.py rename to api/dify_graph/nodes/agent/exc.py diff --git a/api/core/workflow/nodes/end/__init__.py b/api/dify_graph/nodes/answer/__init__.py similarity index 100% rename from api/core/workflow/nodes/end/__init__.py rename to api/dify_graph/nodes/answer/__init__.py diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/dify_graph/nodes/answer/answer_node.py similarity index 83% rename from api/core/workflow/nodes/answer/answer_node.py rename to api/dify_graph/nodes/answer/answer_node.py index 388447368e..d07b9c8062 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/dify_graph/nodes/answer/answer_node.py @@ -1,13 +1,13 @@ from collections.abc import Mapping, Sequence from typing import Any -from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.answer.entities import AnswerNodeData -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.base.template import Template -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.variables import ArrayFileSegment, FileSegment, Segment +from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.answer.entities import AnswerNodeData +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.base.template import Template +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.variables import ArrayFileSegment, FileSegment, Segment class AnswerNode(Node[AnswerNodeData]): diff --git a/api/core/workflow/nodes/answer/entities.py b/api/dify_graph/nodes/answer/entities.py similarity index 97% rename from api/core/workflow/nodes/answer/entities.py rename to api/dify_graph/nodes/answer/entities.py index 850ff14880..06927cd71e 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/dify_graph/nodes/answer/entities.py @@ -3,7 +3,7 @@ from enum import StrEnum, auto from pydantic import BaseModel, Field -from core.workflow.nodes.base import BaseNodeData +from dify_graph.nodes.base import BaseNodeData class AnswerNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/base/__init__.py b/api/dify_graph/nodes/base/__init__.py similarity index 100% rename from api/core/workflow/nodes/base/__init__.py rename to api/dify_graph/nodes/base/__init__.py diff --git a/api/core/workflow/nodes/base/entities.py b/api/dify_graph/nodes/base/entities.py similarity index 99% rename from api/core/workflow/nodes/base/entities.py rename to api/dify_graph/nodes/base/entities.py index fbe7d2c48d..7d6dffe9e2 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/dify_graph/nodes/base/entities.py @@ -9,7 +9,7 @@ from typing import Any, Union from pydantic import BaseModel, field_validator, model_validator -from core.workflow.enums import ErrorStrategy +from dify_graph.enums import ErrorStrategy from .exc import DefaultValueTypeError diff --git a/api/core/workflow/nodes/base/exc.py b/api/dify_graph/nodes/base/exc.py similarity index 100% rename from api/core/workflow/nodes/base/exc.py rename to api/dify_graph/nodes/base/exc.py diff --git a/api/core/workflow/nodes/base/node.py b/api/dify_graph/nodes/base/node.py similarity index 91% rename from api/core/workflow/nodes/base/node.py rename to api/dify_graph/nodes/base/node.py index 12eda2c3ce..6ea67a02b7 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/dify_graph/nodes/base/node.py @@ -8,13 +8,19 @@ from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence from functools import singledispatchmethod from types import MappingProxyType -from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin +from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph_events import ( +from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import ( + ErrorStrategy, + NodeExecutionType, + NodeState, + NodeType, + WorkflowNodeExecutionStatus, +) +from dify_graph.graph_events import ( GraphNodeEventBase, NodeRunAgentLogEvent, NodeRunFailedEvent, @@ -34,7 +40,7 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.node_events import ( +from dify_graph.node_events import ( AgentLogEvent, HumanInputFormFilledEvent, HumanInputFormTimeoutEvent, @@ -56,17 +62,34 @@ from core.workflow.node_events import ( ToolCallChunkEvent, ToolResultChunkEvent, ) -from core.workflow.runtime import GraphRuntimeState +from dify_graph.runtime import GraphRuntimeState from libs.datetime_utils import naive_utc_now -from models.enums import UserFrom from .entities import BaseNodeData, RetryConfig NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData) +_MISSING_RUN_CONTEXT_VALUE = object() logger = logging.getLogger(__name__) +class DifyRunContextProtocol(Protocol): + tenant_id: str + app_id: str + user_id: str + user_from: Any + invoke_from: Any + + +class _MappingDifyRunContext: + def __init__(self, mapping: Mapping[str, Any]) -> None: + self.tenant_id = str(mapping["tenant_id"]) + self.app_id = str(mapping["app_id"]) + self.user_id = str(mapping["user_id"]) + self.user_from = mapping["user_from"] + self.invoke_from = mapping["invoke_from"] + + class Node(Generic[NodeDataT]): """BaseNode serves as the foundational class for all node implementations. @@ -159,7 +182,7 @@ class Node(Generic[NodeDataT]): # Skip base class itself if cls is Node: return - # Only register production node implementations defined under core.workflow.nodes.* + # Only register production node implementations defined under dify_graph.nodes.* # This prevents test helper subclasses from polluting the global registry and # accidentally overriding real node types (e.g., a test Answer node). module_name = getattr(cls, "__module__", "") @@ -167,7 +190,7 @@ class Node(Generic[NodeDataT]): node_type = cls.node_type version = cls.version() bucket = Node._registry.setdefault(node_type, {}) - if module_name.startswith("core.workflow.nodes."): + if module_name.startswith("dify_graph.nodes."): # Production node definitions take precedence and may override bucket[version] = cls # type: ignore[index] else: @@ -226,14 +249,10 @@ class Node(Generic[NodeDataT]): graph_runtime_state: GraphRuntimeState, ) -> None: self._graph_init_params = graph_init_params + self._run_context = MappingProxyType(dict(graph_init_params.run_context)) self.id = id - self.tenant_id = graph_init_params.tenant_id - self.app_id = graph_init_params.app_id self.workflow_id = graph_init_params.workflow_id self.graph_config = graph_init_params.graph_config - self.user_id = graph_init_params.user_id - self.user_from = UserFrom(graph_init_params.user_from) - self.invoke_from = InvokeFrom(graph_init_params.invoke_from) self.workflow_call_depth = graph_init_params.call_depth self.graph_runtime_state = graph_runtime_state self.state: NodeState = NodeState.UNKNOWN # node execution state @@ -262,6 +281,38 @@ class Node(Generic[NodeDataT]): def graph_init_params(self) -> GraphInitParams: return self._graph_init_params + @property + def run_context(self) -> Mapping[str, Any]: + return self._run_context + + def get_run_context_value(self, key: str, default: Any = None) -> Any: + return self._run_context.get(key, default) + + def require_run_context_value(self, key: str) -> Any: + value = self.get_run_context_value(key, _MISSING_RUN_CONTEXT_VALUE) + if value is _MISSING_RUN_CONTEXT_VALUE: + raise ValueError(f"run_context missing required key: {key}") + return value + + def require_dify_context(self) -> DifyRunContextProtocol: + raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY) + if raw_ctx is None: + raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") + + if isinstance(raw_ctx, Mapping): + missing_keys = [ + key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx + ] + if missing_keys: + raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}") + return _MappingDifyRunContext(raw_ctx) + + for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"): + if not hasattr(raw_ctx, attr): + raise TypeError(f"invalid dify context object, missing attribute: {attr}") + + return cast(DifyRunContextProtocol, raw_ctx) + @property def execution_id(self) -> str: return self._node_execution_id @@ -378,13 +429,13 @@ class Node(Generic[NodeDataT]): ) # === FIXME(-LAN-): Needs to refactor. - from core.workflow.nodes.tool.tool_node import ToolNode + from dify_graph.nodes.tool.tool_node import ToolNode if isinstance(self, ToolNode): start_event.provider_id = getattr(self.node_data, "provider_id", "") start_event.provider_type = getattr(self.node_data, "provider_type", "") - from core.workflow.nodes.datasource.datasource_node import DatasourceNode + from dify_graph.nodes.datasource.datasource_node import DatasourceNode if isinstance(self, DatasourceNode): plugin_id = getattr(self.node_data, "plugin_id", "") @@ -393,7 +444,7 @@ class Node(Generic[NodeDataT]): start_event.provider_id = f"{plugin_id}/{provider_name}" start_event.provider_type = getattr(self.node_data, "provider_type", "") - from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode + from dify_graph.nodes.trigger_plugin.trigger_event_node import TriggerEventNode if isinstance(self, TriggerEventNode): start_event.provider_id = getattr(self.node_data, "provider_id", "") @@ -401,8 +452,8 @@ class Node(Generic[NodeDataT]): from typing import cast - from core.workflow.nodes.agent.agent_node import AgentNode - from core.workflow.nodes.agent.entities import AgentNodeData + from dify_graph.nodes.agent.agent_node import AgentNode + from dify_graph.nodes.agent.entities import AgentNodeData if isinstance(self, AgentNode): start_event.agent_strategy = AgentNodeStrategyInit( @@ -533,22 +584,22 @@ class Node(Generic[NodeDataT]): # NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`. # # If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING` - # in `api/core/workflow/nodes/__init__.py`. + # in `api/dify_graph/nodes/__init__.py`. raise NotImplementedError("subclasses of BaseNode must implement `version` method.") @classmethod def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]: """Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry. - Import all modules under core.workflow.nodes so subclasses register themselves on import. + Import all modules under dify_graph.nodes so subclasses register themselves on import. Then we return a readonly view of the registry to avoid accidental mutation. """ # Import all node modules to ensure they are loaded (thus registered) - import core.workflow.nodes as _nodes_pkg + import dify_graph.nodes as _nodes_pkg for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."): # Avoid importing modules that depend on the registry to prevent circular imports. - if _modname == "core.workflow.nodes.node_mapping": + if _modname == "dify_graph.nodes.node_mapping": continue importlib.import_module(_modname) diff --git a/api/core/workflow/nodes/base/template.py b/api/dify_graph/nodes/base/template.py similarity index 98% rename from api/core/workflow/nodes/base/template.py rename to api/dify_graph/nodes/base/template.py index 81f4b9f6fb..5976e808e3 100644 --- a/api/core/workflow/nodes/base/template.py +++ b/api/dify_graph/nodes/base/template.py @@ -11,7 +11,7 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Union -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser @dataclass(frozen=True) diff --git a/api/core/workflow/nodes/base/usage_tracking_mixin.py b/api/dify_graph/nodes/base/usage_tracking_mixin.py similarity index 89% rename from api/core/workflow/nodes/base/usage_tracking_mixin.py rename to api/dify_graph/nodes/base/usage_tracking_mixin.py index d9a0ef8972..bd49419fd3 100644 --- a/api/core/workflow/nodes/base/usage_tracking_mixin.py +++ b/api/dify_graph/nodes/base/usage_tracking_mixin.py @@ -1,5 +1,5 @@ -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.runtime import GraphRuntimeState +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.runtime import GraphRuntimeState class LLMUsageTrackingMixin: diff --git a/api/core/workflow/nodes/base/variable_template_parser.py b/api/dify_graph/nodes/base/variable_template_parser.py similarity index 100% rename from api/core/workflow/nodes/base/variable_template_parser.py rename to api/dify_graph/nodes/base/variable_template_parser.py diff --git a/api/core/workflow/nodes/code/__init__.py b/api/dify_graph/nodes/code/__init__.py similarity index 100% rename from api/core/workflow/nodes/code/__init__.py rename to api/dify_graph/nodes/code/__init__.py diff --git a/api/core/workflow/nodes/code/code_node.py b/api/dify_graph/nodes/code/code_node.py similarity index 97% rename from api/core/workflow/nodes/code/code_node.py rename to api/dify_graph/nodes/code/code_node.py index 7b1cbfcfea..83e72deea9 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/dify_graph/nodes/code/code_node.py @@ -3,13 +3,13 @@ from decimal import Decimal from textwrap import dedent from typing import TYPE_CHECKING, Any, Protocol, cast -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.code.entities import CodeLanguage, CodeNodeData -from core.workflow.nodes.code.limits import CodeNodeLimits -from core.workflow.variables.segments import ArrayFileSegment -from core.workflow.variables.types import SegmentType +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData +from dify_graph.nodes.code.limits import CodeNodeLimits +from dify_graph.variables.segments import ArrayFileSegment +from dify_graph.variables.types import SegmentType from .exc import ( CodeNodeError, @@ -18,8 +18,8 @@ from .exc import ( ) if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState class WorkflowCodeExecutor(Protocol): diff --git a/api/core/workflow/nodes/code/entities.py b/api/dify_graph/nodes/code/entities.py similarity index 88% rename from api/core/workflow/nodes/code/entities.py rename to api/dify_graph/nodes/code/entities.py index 8b73b89e2f..9e161c29d0 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/dify_graph/nodes/code/entities.py @@ -3,9 +3,9 @@ from typing import Annotated, Literal from pydantic import AfterValidator, BaseModel -from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.base.entities import VariableSelector -from core.workflow.variables.types import SegmentType +from dify_graph.nodes.base import BaseNodeData +from dify_graph.nodes.base.entities import VariableSelector +from dify_graph.variables.types import SegmentType class CodeLanguage(StrEnum): diff --git a/api/core/workflow/nodes/code/exc.py b/api/dify_graph/nodes/code/exc.py similarity index 100% rename from api/core/workflow/nodes/code/exc.py rename to api/dify_graph/nodes/code/exc.py diff --git a/api/core/workflow/nodes/code/limits.py b/api/dify_graph/nodes/code/limits.py similarity index 100% rename from api/core/workflow/nodes/code/limits.py rename to api/dify_graph/nodes/code/limits.py diff --git a/api/core/workflow/nodes/datasource/__init__.py b/api/dify_graph/nodes/datasource/__init__.py similarity index 100% rename from api/core/workflow/nodes/datasource/__init__.py rename to api/dify_graph/nodes/datasource/__init__.py diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/dify_graph/nodes/datasource/datasource_node.py similarity index 91% rename from api/core/workflow/nodes/datasource/datasource_node.py rename to api/dify_graph/nodes/datasource/datasource_node.py index 17f8bcb2db..b97394744e 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/dify_graph/nodes/datasource/datasource_node.py @@ -3,12 +3,12 @@ from typing import TYPE_CHECKING, Any from core.datasource.entities.datasource_entities import DatasourceProviderType from core.plugin.impl.exc import PluginDaemonClientSideError -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey -from core.workflow.node_events import NodeRunResult, StreamCompletedEvent -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.repositories.datasource_manager_protocol import ( +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey +from dify_graph.node_events import NodeRunResult, StreamCompletedEvent +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.repositories.datasource_manager_protocol import ( DatasourceManagerProtocol, DatasourceParameter, OnlineDriveDownloadFileParam, @@ -19,8 +19,8 @@ from .entities import DatasourceNodeData from .exc import DatasourceNodeError if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState class DatasourceNode(Node[DatasourceNodeData]): @@ -52,6 +52,7 @@ class DatasourceNode(Node[DatasourceNodeData]): Run the datasource node """ + dify_ctx = self.require_dify_context() node_data = self.node_data variable_pool = self.graph_runtime_state.variable_pool datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE]) @@ -75,7 +76,7 @@ class DatasourceNode(Node[DatasourceNodeData]): datasource_info["icon"] = self.datasource_manager.get_icon_url( provider_id=provider_id, datasource_name=node_data.datasource_name or "", - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, datasource_type=datasource_type.value, ) @@ -104,11 +105,11 @@ class DatasourceNode(Node[DatasourceNodeData]): yield from self.datasource_manager.stream_node_events( node_id=self._node_id, - user_id=self.user_id, + user_id=dify_ctx.user_id, datasource_name=node_data.datasource_name or "", datasource_type=datasource_type.value, provider_id=provider_id, - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, provider=node_data.provider_name, plugin_id=node_data.plugin_id, credential_id=credential_id, @@ -136,7 +137,7 @@ class DatasourceNode(Node[DatasourceNodeData]): raise DatasourceNodeError("File is not exist") file_info = self.datasource_manager.get_upload_file_by_id( - file_id=related_id, tenant_id=self.tenant_id + file_id=related_id, tenant_id=dify_ctx.tenant_id ) variable_pool.add([self._node_id, "file"], file_info) # variable_pool.add([self.node_id, "file"], file_info.to_dict()) diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/dify_graph/nodes/datasource/entities.py similarity index 96% rename from api/core/workflow/nodes/datasource/entities.py rename to api/dify_graph/nodes/datasource/entities.py index 4802d3ed98..ba49e65f31 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/dify_graph/nodes/datasource/entities.py @@ -3,7 +3,7 @@ from typing import Any, Literal, Union from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo -from core.workflow.nodes.base.entities import BaseNodeData +from dify_graph.nodes.base.entities import BaseNodeData class DatasourceEntity(BaseModel): diff --git a/api/core/workflow/nodes/datasource/exc.py b/api/dify_graph/nodes/datasource/exc.py similarity index 100% rename from api/core/workflow/nodes/datasource/exc.py rename to api/dify_graph/nodes/datasource/exc.py diff --git a/api/core/workflow/nodes/document_extractor/__init__.py b/api/dify_graph/nodes/document_extractor/__init__.py similarity index 100% rename from api/core/workflow/nodes/document_extractor/__init__.py rename to api/dify_graph/nodes/document_extractor/__init__.py diff --git a/api/core/workflow/nodes/document_extractor/entities.py b/api/dify_graph/nodes/document_extractor/entities.py similarity index 84% rename from api/core/workflow/nodes/document_extractor/entities.py rename to api/dify_graph/nodes/document_extractor/entities.py index db05bbf4fe..f4949d0df8 100644 --- a/api/core/workflow/nodes/document_extractor/entities.py +++ b/api/dify_graph/nodes/document_extractor/entities.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from dataclasses import dataclass -from core.workflow.nodes.base import BaseNodeData +from dify_graph.nodes.base import BaseNodeData class DocumentExtractorNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/document_extractor/exc.py b/api/dify_graph/nodes/document_extractor/exc.py similarity index 100% rename from api/core/workflow/nodes/document_extractor/exc.py rename to api/dify_graph/nodes/document_extractor/exc.py diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/dify_graph/nodes/document_extractor/node.py similarity index 95% rename from api/core/workflow/nodes/document_extractor/node.py rename to api/dify_graph/nodes/document_extractor/node.py index 59be4c54ef..5945e57926 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/dify_graph/nodes/document_extractor/node.py @@ -20,13 +20,13 @@ from docx.oxml.text.paragraph import CT_P from docx.table import Table from docx.text.paragraph import Paragraph -from core.helper import ssrf_proxy -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.file import File, FileTransferMethod, file_manager -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.variables import ArrayFileSegment -from core.workflow.variables.segments import ArrayStringSegment, FileSegment +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.file import File, FileTransferMethod, file_manager +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.protocols import HttpClientProtocol +from dify_graph.variables import ArrayFileSegment +from dify_graph.variables.segments import ArrayStringSegment, FileSegment from .entities import DocumentExtractorNodeData, UnstructuredApiConfig from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError @@ -34,8 +34,8 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, logger = logging.getLogger(__name__) if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState class DocumentExtractorNode(Node[DocumentExtractorNodeData]): @@ -58,6 +58,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): graph_runtime_state: "GraphRuntimeState", *, unstructured_api_config: UnstructuredApiConfig | None = None, + http_client: HttpClientProtocol, ) -> None: super().__init__( id=id, @@ -66,6 +67,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): graph_runtime_state=graph_runtime_state, ) self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig() + self._http_client = http_client def _run(self): variable_selector = self.node_data.variable_selector @@ -85,7 +87,9 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): try: if isinstance(value, list): extracted_text_list = [ - _extract_text_from_file(file, unstructured_api_config=self._unstructured_api_config) + _extract_text_from_file( + self._http_client, file, unstructured_api_config=self._unstructured_api_config + ) for file in value ] return NodeRunResult( @@ -95,7 +99,9 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): outputs={"text": ArrayStringSegment(value=extracted_text_list)}, ) elif isinstance(value, File): - extracted_text = _extract_text_from_file(value, unstructured_api_config=self._unstructured_api_config) + extracted_text = _extract_text_from_file( + self._http_client, value, unstructured_api_config=self._unstructured_api_config + ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, @@ -439,13 +445,13 @@ def _extract_text_from_docx(file_content: bytes) -> str: raise TextExtractionError(f"Failed to extract text from DOCX: {str(e)}") from e -def _download_file_content(file: File) -> bytes: +def _download_file_content(http_client: HttpClientProtocol, file: File) -> bytes: """Download the content of a file based on its transfer method.""" try: if file.transfer_method == FileTransferMethod.REMOTE_URL: if file.remote_url is None: raise FileDownloadError("Missing URL for remote file") - response = ssrf_proxy.get(file.remote_url) + response = http_client.get(file.remote_url) response.raise_for_status() return response.content else: @@ -454,8 +460,10 @@ def _download_file_content(file: File) -> bytes: raise FileDownloadError(f"Error downloading file: {str(e)}") from e -def _extract_text_from_file(file: File, *, unstructured_api_config: UnstructuredApiConfig) -> str: - file_content = _download_file_content(file) +def _extract_text_from_file( + http_client: HttpClientProtocol, file: File, *, unstructured_api_config: UnstructuredApiConfig +) -> str: + file_content = _download_file_content(http_client, file) if file.extension: extracted_text = _extract_text_by_file_extension( file_content=file_content, diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/dify_graph/nodes/end/__init__.py similarity index 100% rename from api/core/workflow/nodes/variable_assigner/__init__.py rename to api/dify_graph/nodes/end/__init__.py diff --git a/api/core/workflow/nodes/end/end_node.py b/api/dify_graph/nodes/end/end_node.py similarity index 81% rename from api/core/workflow/nodes/end/end_node.py rename to api/dify_graph/nodes/end/end_node.py index 2efcb4f418..7aa526b85b 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/dify_graph/nodes/end/end_node.py @@ -1,8 +1,8 @@ -from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.base.template import Template -from core.workflow.nodes.end.entities import EndNodeData +from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.base.template import Template +from dify_graph.nodes.end.entities import EndNodeData class EndNode(Node[EndNodeData]): diff --git a/api/core/workflow/nodes/end/entities.py b/api/dify_graph/nodes/end/entities.py similarity index 87% rename from api/core/workflow/nodes/end/entities.py rename to api/dify_graph/nodes/end/entities.py index 87a221b5f6..a410087214 100644 --- a/api/core/workflow/nodes/end/entities.py +++ b/api/dify_graph/nodes/end/entities.py @@ -1,6 +1,6 @@ from pydantic import BaseModel, Field -from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity +from dify_graph.nodes.base.entities import BaseNodeData, OutputVariableEntity class EndNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/http_request/__init__.py b/api/dify_graph/nodes/http_request/__init__.py similarity index 100% rename from api/core/workflow/nodes/http_request/__init__.py rename to api/dify_graph/nodes/http_request/__init__.py diff --git a/api/core/workflow/nodes/http_request/config.py b/api/dify_graph/nodes/http_request/config.py similarity index 100% rename from api/core/workflow/nodes/http_request/config.py rename to api/dify_graph/nodes/http_request/config.py diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/dify_graph/nodes/http_request/entities.py similarity index 99% rename from api/core/workflow/nodes/http_request/entities.py rename to api/dify_graph/nodes/http_request/entities.py index 0eda20f485..a5564689f8 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/dify_graph/nodes/http_request/entities.py @@ -8,7 +8,7 @@ import charset_normalizer import httpx from pydantic import BaseModel, Field, ValidationInfo, field_validator -from core.workflow.nodes.base import BaseNodeData +from dify_graph.nodes.base import BaseNodeData HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config" diff --git a/api/core/workflow/nodes/http_request/exc.py b/api/dify_graph/nodes/http_request/exc.py similarity index 100% rename from api/core/workflow/nodes/http_request/exc.py rename to api/dify_graph/nodes/http_request/exc.py diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/dify_graph/nodes/http_request/executor.py similarity index 99% rename from api/core/workflow/nodes/http_request/executor.py rename to api/dify_graph/nodes/http_request/executor.py index de14c8c517..892b0fc688 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/dify_graph/nodes/http_request/executor.py @@ -10,9 +10,9 @@ from urllib.parse import urlencode, urlparse import httpx from json_repair import repair_json -from core.workflow.file.enums import FileTransferMethod -from core.workflow.runtime import VariablePool -from core.workflow.variables.segments import ArrayFileSegment, FileSegment +from dify_graph.file.enums import FileTransferMethod +from dify_graph.runtime import VariablePool +from dify_graph.variables.segments import ArrayFileSegment, FileSegment from ..protocols import FileManagerProtocol, HttpClientProtocol from .entities import ( diff --git a/api/core/workflow/nodes/http_request/node.py b/api/dify_graph/nodes/http_request/node.py similarity index 92% rename from api/core/workflow/nodes/http_request/node.py rename to api/dify_graph/nodes/http_request/node.py index 11458db758..2e48d5502a 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/dify_graph/nodes/http_request/node.py @@ -3,15 +3,15 @@ import mimetypes from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.file import File, FileTransferMethod -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base import variable_template_parser -from core.workflow.nodes.base.entities import VariableSelector -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.http_request.executor import Executor -from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol, ToolFileManagerProtocol -from core.workflow.variables.segments import ArrayFileSegment +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.file import File, FileTransferMethod +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base import variable_template_parser +from dify_graph.nodes.base.entities import VariableSelector +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.http_request.executor import Executor +from dify_graph.nodes.protocols import FileManagerProtocol, HttpClientProtocol, ToolFileManagerProtocol +from dify_graph.variables.segments import ArrayFileSegment from factories import file_factory from .config import build_http_request_config, resolve_http_request_config @@ -27,8 +27,8 @@ from .exc import HttpRequestNodeError, RequestBodyError logger = logging.getLogger(__name__) if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState class HttpRequestNode(Node[HttpRequestNodeData]): @@ -212,6 +212,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): """ Extract files from response by checking both Content-Type header and URL """ + dify_ctx = self.require_dify_context() files: list[File] = [] is_file = response.is_file content_type = response.content_type @@ -236,8 +237,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]): tool_file_manager = self._tool_file_manager_factory() tool_file = tool_file_manager.create_file_by_raw( - user_id=self.user_id, - tenant_id=self.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, conversation_id=None, file_binary=content, mimetype=mime_type, @@ -249,7 +250,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): } file = file_factory.build_from_mapping( mapping=mapping, - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, ) files.append(file) diff --git a/api/core/workflow/nodes/human_input/__init__.py b/api/dify_graph/nodes/human_input/__init__.py similarity index 100% rename from api/core/workflow/nodes/human_input/__init__.py rename to api/dify_graph/nodes/human_input/__init__.py diff --git a/api/core/workflow/nodes/human_input/entities.py b/api/dify_graph/nodes/human_input/entities.py similarity index 98% rename from api/core/workflow/nodes/human_input/entities.py rename to api/dify_graph/nodes/human_input/entities.py index a4473dfa7d..5616949dcc 100644 --- a/api/core/workflow/nodes/human_input/entities.py +++ b/api/dify_graph/nodes/human_input/entities.py @@ -10,10 +10,10 @@ from typing import Annotated, Any, ClassVar, Literal, Self from pydantic import BaseModel, Field, field_validator, model_validator -from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.runtime import VariablePool -from core.workflow.variables.consts import SELECTORS_LENGTH +from dify_graph.nodes.base import BaseNodeData +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.runtime import VariablePool +from dify_graph.variables.consts import SELECTORS_LENGTH from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit diff --git a/api/core/workflow/nodes/human_input/enums.py b/api/dify_graph/nodes/human_input/enums.py similarity index 100% rename from api/core/workflow/nodes/human_input/enums.py rename to api/dify_graph/nodes/human_input/enums.py diff --git a/api/core/workflow/nodes/human_input/human_input_node.py b/api/dify_graph/nodes/human_input/human_input_node.py similarity index 87% rename from api/core/workflow/nodes/human_input/human_input_node.py rename to api/dify_graph/nodes/human_input/human_input_node.py index 1d7522ea25..03c2d17b1d 100644 --- a/api/core/workflow/nodes/human_input/human_input_node.py +++ b/api/dify_graph/nodes/human_input/human_input_node.py @@ -3,37 +3,36 @@ import logging from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.app.entities.app_invoke_entities import InvokeFrom -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import ( +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus +from dify_graph.node_events import ( HumanInputFormFilledEvent, HumanInputFormTimeoutEvent, NodeRunResult, PauseRequestedEvent, ) -from core.workflow.node_events.base import NodeEventBase -from core.workflow.node_events.node import StreamCompletedEvent -from core.workflow.nodes.base.node import Node -from core.workflow.repositories.human_input_form_repository import ( +from dify_graph.node_events.base import NodeEventBase +from dify_graph.node_events.node import StreamCompletedEvent +from dify_graph.nodes.base.node import Node +from dify_graph.repositories.human_input_form_repository import ( FormCreateParams, HumanInputFormEntity, HumanInputFormRepository, ) -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter -from extensions.ext_database import db +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.datetime_utils import naive_utc_now from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType if TYPE_CHECKING: - from core.workflow.entities.graph_init_params import GraphInitParams - from core.workflow.runtime.graph_runtime_state import GraphRuntimeState + from dify_graph.entities.graph_init_params import GraphInitParams + from dify_graph.runtime.graph_runtime_state import GraphRuntimeState _SELECTED_BRANCH_KEY = "selected_branch" +_INVOKE_FROM_DEBUGGER = "debugger" +_INVOKE_FROM_EXPLORE = "explore" logger = logging.getLogger(__name__) @@ -67,7 +66,7 @@ class HumanInputNode(Node[HumanInputNodeData]): config: Mapping[str, Any], graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - form_repository: HumanInputFormRepository | None = None, + form_repository: HumanInputFormRepository, ) -> None: super().__init__( id=id, @@ -75,11 +74,6 @@ class HumanInputNode(Node[HumanInputNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - if form_repository is None: - form_repository = HumanInputFormRepositoryImpl( - session_factory=db.engine, - tenant_id=self.tenant_id, - ) self._form_repository = form_repository @classmethod @@ -163,30 +157,39 @@ class HumanInputNode(Node[HumanInputNodeData]): return resolved_defaults def _should_require_console_recipient(self) -> bool: - if self.invoke_from == InvokeFrom.DEBUGGER: + invoke_from = self._invoke_from_value() + if invoke_from == _INVOKE_FROM_DEBUGGER: return True - if self.invoke_from == InvokeFrom.EXPLORE: + if invoke_from == _INVOKE_FROM_EXPLORE: return self._node_data.is_webapp_enabled() return False def _display_in_ui(self) -> bool: - if self.invoke_from == InvokeFrom.DEBUGGER: + if self._invoke_from_value() == _INVOKE_FROM_DEBUGGER: return True return self._node_data.is_webapp_enabled() def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]: + dify_ctx = self.require_dify_context() + invoke_from = self._invoke_from_value() enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled] - if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}: + if invoke_from in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}: enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP] return [ apply_debug_email_recipient( method, - enabled=self.invoke_from == InvokeFrom.DEBUGGER, - user_id=self.user_id or "", + enabled=invoke_from == _INVOKE_FROM_DEBUGGER, + user_id=dify_ctx.user_id, ) for method in enabled_methods ] + def _invoke_from_value(self) -> str: + invoke_from = self.require_dify_context().invoke_from + if isinstance(invoke_from, str): + return invoke_from + return str(getattr(invoke_from, "value", invoke_from)) + def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired: node_data = self._node_data resolved_default_values = self.resolve_default_values() @@ -220,10 +223,11 @@ class HumanInputNode(Node[HumanInputNodeData]): """ repo = self._form_repository form = repo.get_form(self._workflow_execution_id, self.id) + dify_ctx = self.require_dify_context() if form is None: display_in_ui = self._display_in_ui() params = FormCreateParams( - app_id=self.app_id, + app_id=dify_ctx.app_id, workflow_execution_id=self._workflow_execution_id, node_id=self.id, form_config=self._node_data, @@ -233,7 +237,9 @@ class HumanInputNode(Node[HumanInputNodeData]): resolved_default_values=self.resolve_default_values(), console_recipient_required=self._should_require_console_recipient(), console_creator_account_id=( - self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None + dify_ctx.user_id + if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE} + else None ), backstage_recipient_required=True, ) diff --git a/api/core/workflow/nodes/if_else/__init__.py b/api/dify_graph/nodes/if_else/__init__.py similarity index 100% rename from api/core/workflow/nodes/if_else/__init__.py rename to api/dify_graph/nodes/if_else/__init__.py diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/dify_graph/nodes/if_else/entities.py similarity index 82% rename from api/core/workflow/nodes/if_else/entities.py rename to api/dify_graph/nodes/if_else/entities.py index b22bd6f508..4733944039 100644 --- a/api/core/workflow/nodes/if_else/entities.py +++ b/api/dify_graph/nodes/if_else/entities.py @@ -2,8 +2,8 @@ from typing import Literal from pydantic import BaseModel, Field -from core.workflow.nodes.base import BaseNodeData -from core.workflow.utils.condition.entities import Condition +from dify_graph.nodes.base import BaseNodeData +from dify_graph.utils.condition.entities import Condition class IfElseNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/dify_graph/nodes/if_else/if_else_node.py similarity index 90% rename from api/core/workflow/nodes/if_else/if_else_node.py rename to api/dify_graph/nodes/if_else/if_else_node.py index cda5f1dd42..3c5a33e2b7 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/dify_graph/nodes/if_else/if_else_node.py @@ -3,13 +3,13 @@ from typing import Any, Literal from typing_extensions import deprecated -from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.if_else.entities import IfElseNodeData -from core.workflow.runtime import VariablePool -from core.workflow.utils.condition.entities import Condition -from core.workflow.utils.condition.processor import ConditionProcessor +from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.if_else.entities import IfElseNodeData +from dify_graph.runtime import VariablePool +from dify_graph.utils.condition.entities import Condition +from dify_graph.utils.condition.processor import ConditionProcessor class IfElseNode(Node[IfElseNodeData]): diff --git a/api/core/workflow/nodes/iteration/__init__.py b/api/dify_graph/nodes/iteration/__init__.py similarity index 100% rename from api/core/workflow/nodes/iteration/__init__.py rename to api/dify_graph/nodes/iteration/__init__.py diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/dify_graph/nodes/iteration/entities.py similarity index 94% rename from api/core/workflow/nodes/iteration/entities.py rename to api/dify_graph/nodes/iteration/entities.py index 63a41ec755..a31b05463e 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/dify_graph/nodes/iteration/entities.py @@ -3,7 +3,7 @@ from typing import Any from pydantic import Field -from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData +from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData class ErrorHandleMode(StrEnum): diff --git a/api/core/workflow/nodes/iteration/exc.py b/api/dify_graph/nodes/iteration/exc.py similarity index 100% rename from api/core/workflow/nodes/iteration/exc.py rename to api/dify_graph/nodes/iteration/exc.py diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/dify_graph/nodes/iteration/iteration_node.py similarity index 90% rename from api/core/workflow/nodes/iteration/iteration_node.py rename to api/dify_graph/nodes/iteration/iteration_node.py index 54b0561dd8..6d26cbfce4 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/dify_graph/nodes/iteration/iteration_node.py @@ -6,21 +6,21 @@ from typing import TYPE_CHECKING, Any, NewType, cast from typing_extensions import TypeIs -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.enums import ( +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID +from dify_graph.enums import ( NodeExecutionType, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphNodeEventBase, GraphRunFailedEvent, GraphRunPartialSucceededEvent, GraphRunSucceededEvent, ) -from core.workflow.node_events import ( +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import ( IterationFailedEvent, IterationNextEvent, IterationStartedEvent, @@ -29,13 +29,13 @@ from core.workflow.node_events import ( NodeRunResult, StreamCompletedEvent, ) -from core.workflow.nodes.base import LLMUsageTrackingMixin -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from core.workflow.runtime import VariablePool -from core.workflow.variables import IntegerVariable, NoneSegment -from core.workflow.variables.segments import ArrayAnySegment, ArraySegment -from core.workflow.variables.variables import Variable +from dify_graph.nodes.base import LLMUsageTrackingMixin +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from dify_graph.runtime import VariablePool +from dify_graph.variables import IntegerVariable, NoneSegment +from dify_graph.variables.segments import ArrayAnySegment, ArraySegment +from dify_graph.variables.variables import Variable from libs.datetime_utils import naive_utc_now from .exc import ( @@ -48,8 +48,8 @@ from .exc import ( ) if TYPE_CHECKING: - from core.workflow.context import IExecutionContext - from core.workflow.graph_engine import GraphEngine + from dify_graph.context import IExecutionContext + from dify_graph.graph_engine import GraphEngine logger = logging.getLogger(__name__) @@ -337,7 +337,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): def _capture_execution_context(self) -> "IExecutionContext": """Capture current execution context for parallel iterations.""" - from core.workflow.context import capture_current_context + from dify_graph.context import capture_current_context return capture_current_context() @@ -488,7 +488,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): # variable selector to variable mapping try: # Get node class - from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING + from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING node_type = NodeType(sub_node_config.get("data", {}).get("type")) if node_type not in NODE_TYPE_CLASSES_MAPPING: @@ -587,24 +587,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return def _create_graph_engine(self, index: int, item: object): - # Import dependencies - from core.app.workflow.layers.llm_quota import LLMQuotaLayer - from core.app.workflow.node_factory import DifyNodeFactory - from core.workflow.entities import GraphInitParams - from core.workflow.graph import Graph - from core.workflow.graph_engine import GraphEngine, GraphEngineConfig - from core.workflow.graph_engine.command_channels import InMemoryChannel - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState - # Create GraphInitParams from node attributes + # Create GraphInitParams for child graph execution. graph_init_params = GraphInitParams( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, graph_config=self.graph_config, - user_id=self.user_id, - user_from=self.user_from.value, - invoke_from=self.invoke_from.value, + run_context=self.run_context, call_depth=self.workflow_call_depth, ) # Create a deep copy of the variable pool for each iteration @@ -621,28 +611,17 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): total_tokens=0, node_run_steps=0, ) + root_node_id = self.node_data.start_node_id + if root_node_id is None: + raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") - # Create a new node factory with the new GraphRuntimeState - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy - ) - - # Initialize the iteration graph with the new node factory - iteration_graph = Graph.init( - graph_config=self.graph_config, node_factory=node_factory, root_node_id=self.node_data.start_node_id - ) - - if not iteration_graph: - raise IterationGraphNotFoundError("iteration graph not found") - - # Create a new GraphEngine for this iteration - graph_engine = GraphEngine( - workflow_id=self.workflow_id, - graph=iteration_graph, - graph_runtime_state=graph_runtime_state_copy, - command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs - config=GraphEngineConfig(), - ) - graph_engine.layer(LLMQuotaLayer()) - - return graph_engine + try: + return self.graph_runtime_state.create_child_engine( + workflow_id=self.workflow_id, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state_copy, + graph_config=self.graph_config, + root_node_id=root_node_id, + ) + except ChildGraphNotFoundError as exc: + raise IterationGraphNotFoundError("iteration graph not found") from exc diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/dify_graph/nodes/iteration/iteration_start_node.py similarity index 60% rename from api/core/workflow/nodes/iteration/iteration_start_node.py rename to api/dify_graph/nodes/iteration/iteration_start_node.py index 30d9fccbfd..2e1f555ed2 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/dify_graph/nodes/iteration/iteration_start_node.py @@ -1,7 +1,7 @@ -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.iteration.entities import IterationStartNodeData +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.iteration.entities import IterationStartNodeData class IterationStartNode(Node[IterationStartNodeData]): diff --git a/api/core/workflow/nodes/knowledge_index/__init__.py b/api/dify_graph/nodes/knowledge_index/__init__.py similarity index 100% rename from api/core/workflow/nodes/knowledge_index/__init__.py rename to api/dify_graph/nodes/knowledge_index/__init__.py diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/dify_graph/nodes/knowledge_index/entities.py similarity index 98% rename from api/core/workflow/nodes/knowledge_index/entities.py rename to api/dify_graph/nodes/knowledge_index/entities.py index bfeb9b5b79..493b5eadd8 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/dify_graph/nodes/knowledge_index/entities.py @@ -3,7 +3,7 @@ from typing import Literal, Union from pydantic import BaseModel from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.workflow.nodes.base import BaseNodeData +from dify_graph.nodes.base import BaseNodeData class RerankingModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/exc.py b/api/dify_graph/nodes/knowledge_index/exc.py similarity index 100% rename from api/core/workflow/nodes/knowledge_index/exc.py rename to api/dify_graph/nodes/knowledge_index/exc.py diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py similarity index 87% rename from api/core/workflow/nodes/knowledge_index/knowledge_index_node.py rename to api/dify_graph/nodes/knowledge_index/knowledge_index_node.py index 8fb5b99454..eeb4f3c229 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py @@ -2,14 +2,13 @@ import logging from collections.abc import Mapping from typing import TYPE_CHECKING, Any -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.base.template import Template -from core.workflow.repositories.index_processor_protocol import IndexProcessorProtocol -from core.workflow.repositories.summary_index_service_protocol import SummaryIndexServiceProtocol +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.base.template import Template +from dify_graph.repositories.index_processor_protocol import IndexProcessorProtocol +from dify_graph.repositories.summary_index_service_protocol import SummaryIndexServiceProtocol from .entities import KnowledgeIndexNodeData from .exc import ( @@ -17,10 +16,11 @@ from .exc import ( ) if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState logger = logging.getLogger(__name__) +_INVOKE_FROM_DEBUGGER = "debugger" class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): @@ -59,7 +59,8 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): if not variable: raise KnowledgeIndexNodeError("Index chunk variable is required.") invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) - is_preview = invoke_from.value == InvokeFrom.DEBUGGER if invoke_from else False + invoke_from_value = str(invoke_from.value) if invoke_from else None + is_preview = invoke_from_value == _INVOKE_FROM_DEBUGGER chunks = variable.value variables = {"chunks": chunks} diff --git a/api/core/workflow/nodes/knowledge_retrieval/__init__.py b/api/dify_graph/nodes/knowledge_retrieval/__init__.py similarity index 100% rename from api/core/workflow/nodes/knowledge_retrieval/__init__.py rename to api/dify_graph/nodes/knowledge_retrieval/__init__.py diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/dify_graph/nodes/knowledge_retrieval/entities.py similarity index 96% rename from api/core/workflow/nodes/knowledge_retrieval/entities.py rename to api/dify_graph/nodes/knowledge_retrieval/entities.py index 86bb2495e7..c3059897c7 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/dify_graph/nodes/knowledge_retrieval/entities.py @@ -3,8 +3,8 @@ from typing import Literal from pydantic import BaseModel, Field -from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig +from dify_graph.nodes.base import BaseNodeData +from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig class RerankingModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_retrieval/exc.py b/api/dify_graph/nodes/knowledge_retrieval/exc.py similarity index 100% rename from api/core/workflow/nodes/knowledge_retrieval/exc.py rename to api/dify_graph/nodes/knowledge_retrieval/exc.py diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py similarity index 89% rename from api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py rename to api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 0cfd39e485..d84dda42d6 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -3,25 +3,25 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities import GraphInitParams -from core.workflow.enums import ( +from dify_graph.entities import GraphInitParams +from dify_graph.enums import ( NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base import LLMUsageTrackingMixin -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver -from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source -from core.workflow.variables import ( +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base import LLMUsageTrackingMixin +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver +from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source +from dify_graph.variables import ( ArrayFileSegment, FileSegment, StringSegment, ) -from core.workflow.variables.segments import ArrayObjectSegment +from dify_graph.variables.segments import ArrayObjectSegment from .entities import KnowledgeRetrievalNodeData from .exc import ( @@ -30,8 +30,8 @@ from .exc import ( ) if TYPE_CHECKING: - from core.workflow.file.models import File - from core.workflow.runtime import GraphRuntimeState + from dify_graph.file.models import File + from dify_graph.runtime import GraphRuntimeState logger = logging.getLogger(__name__) @@ -66,9 +66,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD self._rag_retrieval = rag_retrieval if llm_file_saver is None: + dify_ctx = self.require_dify_context() llm_file_saver = FileSaverImpl( - user_id=graph_init_params.user_id, - tenant_id=graph_init_params.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, ) self._llm_file_saver = llm_file_saver @@ -160,6 +161,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def _fetch_dataset_retriever( self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any] ) -> tuple[list[Source], LLMUsage]: + dify_ctx = self.require_dify_context() dataset_ids = node_data.dataset_ids query = variables.get("query") attachments = variables.get("attachments") @@ -176,10 +178,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD model = node_data.single_retrieval_config.model retrieval_resource_list = self._rag_retrieval.knowledge_retrieval( request=KnowledgeRetrievalRequest( - tenant_id=self.tenant_id, - user_id=self.user_id, - app_id=self.app_id, - user_from=self.user_from.value, + tenant_id=dify_ctx.tenant_id, + user_id=dify_ctx.user_id, + app_id=dify_ctx.app_id, + user_from=dify_ctx.user_from.value, dataset_ids=dataset_ids, retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value, completion_params=model.completion_params, @@ -229,10 +231,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD retrieval_resource_list = self._rag_retrieval.knowledge_retrieval( request=KnowledgeRetrievalRequest( - app_id=self.app_id, - tenant_id=self.tenant_id, - user_id=self.user_id, - user_from=self.user_from.value, + app_id=dify_ctx.app_id, + tenant_id=dify_ctx.tenant_id, + user_id=dify_ctx.user_id, + user_from=dify_ctx.user_from.value, dataset_ids=dataset_ids, query=query, retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value, diff --git a/api/core/workflow/nodes/knowledge_retrieval/template_prompts.py b/api/dify_graph/nodes/knowledge_retrieval/template_prompts.py similarity index 100% rename from api/core/workflow/nodes/knowledge_retrieval/template_prompts.py rename to api/dify_graph/nodes/knowledge_retrieval/template_prompts.py diff --git a/api/core/workflow/nodes/list_operator/__init__.py b/api/dify_graph/nodes/list_operator/__init__.py similarity index 100% rename from api/core/workflow/nodes/list_operator/__init__.py rename to api/dify_graph/nodes/list_operator/__init__.py diff --git a/api/core/workflow/nodes/list_operator/entities.py b/api/dify_graph/nodes/list_operator/entities.py similarity index 96% rename from api/core/workflow/nodes/list_operator/entities.py rename to api/dify_graph/nodes/list_operator/entities.py index e51a91f07f..0fdd85f210 100644 --- a/api/core/workflow/nodes/list_operator/entities.py +++ b/api/dify_graph/nodes/list_operator/entities.py @@ -3,7 +3,7 @@ from enum import StrEnum from pydantic import BaseModel, Field -from core.workflow.nodes.base import BaseNodeData +from dify_graph.nodes.base import BaseNodeData class FilterOperator(StrEnum): diff --git a/api/core/workflow/nodes/list_operator/exc.py b/api/dify_graph/nodes/list_operator/exc.py similarity index 100% rename from api/core/workflow/nodes/list_operator/exc.py rename to api/dify_graph/nodes/list_operator/exc.py diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/dify_graph/nodes/list_operator/node.py similarity index 97% rename from api/core/workflow/nodes/list_operator/node.py rename to api/dify_graph/nodes/list_operator/node.py index d9ef16fbe7..d2fdadc29c 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/dify_graph/nodes/list_operator/node.py @@ -1,12 +1,12 @@ from collections.abc import Callable, Sequence from typing import Any, TypeAlias, TypeVar -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.file import File -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment -from core.workflow.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.file import File +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment +from dify_graph.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment from .entities import FilterOperator, ListOperatorNodeData, Order from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError diff --git a/api/core/workflow/nodes/llm/__init__.py b/api/dify_graph/nodes/llm/__init__.py similarity index 100% rename from api/core/workflow/nodes/llm/__init__.py rename to api/dify_graph/nodes/llm/__init__.py diff --git a/api/core/workflow/nodes/llm/entities.py b/api/dify_graph/nodes/llm/entities.py similarity index 97% rename from api/core/workflow/nodes/llm/entities.py rename to api/dify_graph/nodes/llm/entities.py index bda7a344e8..f4212f04d7 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/dify_graph/nodes/llm/entities.py @@ -5,15 +5,15 @@ from typing import Annotated, Any, Literal, TypeAlias from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator from core.agent.entities import AgentLog, AgentResult -from core.model_runtime.entities import ImagePromptMessageContent, LLMMode -from core.model_runtime.entities.llm_entities import LLMUsage from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.tools.entities.tool_entities import ToolProviderType -from core.workflow.entities import ToolCall, ToolCallResult -from core.workflow.file import File -from core.workflow.node_events import AgentLogEvent -from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.base.entities import VariableSelector +from dify_graph.entities import ToolCall, ToolCallResult +from dify_graph.file import File +from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import AgentLogEvent +from dify_graph.nodes.base import BaseNodeData +from dify_graph.nodes.base.entities import VariableSelector class ModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/llm/exc.py b/api/dify_graph/nodes/llm/exc.py similarity index 100% rename from api/core/workflow/nodes/llm/exc.py rename to api/dify_graph/nodes/llm/exc.py diff --git a/api/core/workflow/nodes/llm/file_saver.py b/api/dify_graph/nodes/llm/file_saver.py similarity index 98% rename from api/core/workflow/nodes/llm/file_saver.py rename to api/dify_graph/nodes/llm/file_saver.py index 3c06ab7d81..b4f64f4093 100644 --- a/api/core/workflow/nodes/llm/file_saver.py +++ b/api/dify_graph/nodes/llm/file_saver.py @@ -7,7 +7,7 @@ from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE from core.helper import ssrf_proxy from core.tools.signature import sign_tool_file from core.tools.tool_file_manager import ToolFileManager -from core.workflow.file import File, FileTransferMethod, FileType +from dify_graph.file import File, FileTransferMethod, FileType from extensions.ext_database import db as global_db diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/dify_graph/nodes/llm/llm_utils.py similarity index 93% rename from api/core/workflow/nodes/llm/llm_utils.py rename to api/dify_graph/nodes/llm/llm_utils.py index c107734b81..81d7d8cd73 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/dify_graph/nodes/llm/llm_utils.py @@ -7,8 +7,11 @@ from sqlalchemy.orm import Session from core.memory import NodeTokenBufferMemory, TokenBufferMemory from core.memory.base import BaseMemory from core.model_manager import ModelInstance -from core.model_runtime.entities import PromptMessageRole -from core.model_runtime.entities.message_entities import ( +from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode +from dify_graph.enums import SystemVariableKey +from dify_graph.file.models import File +from dify_graph.model_runtime.entities import PromptMessageRole +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, MultiModalPromptMessageContent, @@ -17,15 +20,12 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, ToolPromptMessage, ) -from core.model_runtime.entities.model_entities import AIModelEntity -from core.model_runtime.memory import PromptMessageMemory -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode -from core.workflow.enums import SystemVariableKey -from core.workflow.file.models import File -from core.workflow.nodes.llm.entities import LLMGenerationData -from core.workflow.runtime import VariablePool -from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment +from dify_graph.model_runtime.entities.model_entities import AIModelEntity +from dify_graph.model_runtime.memory import PromptMessageMemory +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dify_graph.nodes.llm.entities import LLMGenerationData +from dify_graph.runtime import VariablePool +from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment from .exc import InvalidVariableTypeError diff --git a/api/core/workflow/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py similarity index 98% rename from api/core/workflow/nodes/llm/node.py rename to api/dify_graph/nodes/llm/node.py index d24807f2a1..9c9386e28f 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/dify_graph/nodes/llm/node.py @@ -29,30 +29,6 @@ from core.llm_generator.output_parser.structured_output import ( ) from core.memory.base import BaseMemory from core.model_manager import ModelInstance -from core.model_runtime.entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentType, - TextPromptMessageContent, -) -from core.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, - LLMStructuredOutput, - LLMUsage, -) -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessageContentUnionTypes, - PromptMessageRole, - SystemPromptMessage, - UserPromptMessage, -) -from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from core.model_runtime.memory import PromptMessageMemory -from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.rag.entities.citation_metadata import RetrievalSourceMetadata @@ -68,17 +44,41 @@ from core.tools.__base.tool import Tool from core.tools.signature import sign_tool_file, sign_upload_file from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_manager import ToolManager -from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus -from core.workflow.entities.tool_entities import ToolCallResult -from core.workflow.enums import ( +from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from dify_graph.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus +from dify_graph.entities.tool_entities import ToolCallResult +from dify_graph.enums import ( NodeType, SystemVariableKey, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.file import File, FileTransferMethod, FileType, file_manager -from core.workflow.node_events import ( +from dify_graph.file import File, FileTransferMethod, FileType, file_manager +from dify_graph.model_runtime.entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + TextPromptMessageContent, +) +from dify_graph.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, + LLMStructuredOutput, + LLMUsage, +) +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageContentUnionTypes, + PromptMessageRole, + SystemPromptMessage, + UserPromptMessage, +) +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from dify_graph.model_runtime.memory import PromptMessageMemory +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.node_events import ( AgentLogEvent, ModelInvokeCompletedEvent, NodeEventBase, @@ -90,13 +90,13 @@ from core.workflow.node_events import ( ToolCallChunkEvent, ToolResultChunkEvent, ) -from core.workflow.node_events.node import ChunkType, ThoughtEndChunkEvent, ThoughtStartChunkEvent -from core.workflow.nodes.base.entities import VariableSelector -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory -from core.workflow.runtime import VariablePool -from core.workflow.variables import ( +from dify_graph.node_events.node import ChunkType, ThoughtEndChunkEvent, ThoughtStartChunkEvent +from dify_graph.nodes.base.entities import VariableSelector +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.runtime import VariablePool +from dify_graph.variables import ( ArrayFileSegment, ArrayPromptMessageSegment, ArraySegment, @@ -139,8 +139,8 @@ from .exc import ( from .file_saver import FileSaverImpl, LLMFileSaver if TYPE_CHECKING: - from core.workflow.file.models import File - from core.workflow.runtime import GraphRuntimeState + from dify_graph.file.models import File + from dify_graph.runtime import GraphRuntimeState logger = logging.getLogger(__name__) @@ -189,9 +189,10 @@ class LLMNode(Node[LLMNodeData]): self._memory = memory if llm_file_saver is None: + dify_ctx = self.require_dify_context() llm_file_saver = FileSaverImpl( - user_id=graph_init_params.user_id, - tenant_id=graph_init_params.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, ) self._llm_file_saver = llm_file_saver @@ -308,6 +309,21 @@ class LLMNode(Node[LLMNodeData]): sandbox=self.graph_runtime_state.sandbox, ) + # handle invoke result + generator = LLMNode.invoke_llm( + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=stop, + user_id=self.require_dify_context().user_id, + structured_output_enabled=self.node_data.structured_output_enabled, + structured_output=self.node_data.structured_output, + file_saver=self._llm_file_saver, + file_outputs=self._file_outputs, + node_id=self._node_id, + node_type=self.node_type, + reasoning_format=self.node_data.reasoning_format, + ) + # Variables for outputs generation_data: LLMGenerationData | None = None structured_output: LLMStructuredOutput | None = None @@ -1234,7 +1250,7 @@ class LLMNode(Node[LLMNodeData]): filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=self.tenant_id, + tenant_id=self.require_dify_context().tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, diff --git a/api/core/workflow/nodes/llm/protocols.py b/api/dify_graph/nodes/llm/protocols.py similarity index 100% rename from api/core/workflow/nodes/llm/protocols.py rename to api/dify_graph/nodes/llm/protocols.py diff --git a/api/core/workflow/nodes/loop/__init__.py b/api/dify_graph/nodes/loop/__init__.py similarity index 100% rename from api/core/workflow/nodes/loop/__init__.py rename to api/dify_graph/nodes/loop/__init__.py diff --git a/api/core/workflow/nodes/loop/entities.py b/api/dify_graph/nodes/loop/entities.py similarity index 91% rename from api/core/workflow/nodes/loop/entities.py rename to api/dify_graph/nodes/loop/entities.py index 4090f27799..b4a8518048 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/dify_graph/nodes/loop/entities.py @@ -3,9 +3,9 @@ from typing import Annotated, Any, Literal from pydantic import AfterValidator, BaseModel, Field, field_validator -from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData -from core.workflow.utils.condition.entities import Condition -from core.workflow.variables.types import SegmentType +from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData +from dify_graph.utils.condition.entities import Condition +from dify_graph.variables.types import SegmentType _VALID_VAR_TYPE = frozenset( [ diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/dify_graph/nodes/loop/loop_end_node.py similarity index 59% rename from api/core/workflow/nodes/loop/loop_end_node.py rename to api/dify_graph/nodes/loop/loop_end_node.py index 1e3e317b53..73ac5da927 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/dify_graph/nodes/loop/loop_end_node.py @@ -1,7 +1,7 @@ -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.loop.entities import LoopEndNodeData +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.loop.entities import LoopEndNodeData class LoopEndNode(Node[LoopEndNodeData]): diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/dify_graph/nodes/loop/loop_node.py similarity index 89% rename from api/core/workflow/nodes/loop/loop_node.py rename to api/dify_graph/nodes/loop/loop_node.py index 40ec0cf8b1..8279f0fc66 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/dify_graph/nodes/loop/loop_node.py @@ -5,19 +5,19 @@ from collections.abc import Callable, Generator, Mapping, Sequence from datetime import datetime from typing import TYPE_CHECKING, Any, Literal, cast -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.enums import ( +from dify_graph.enums import ( NodeExecutionType, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphNodeEventBase, GraphRunFailedEvent, NodeRunSucceededEvent, ) -from core.workflow.node_events import ( +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import ( LoopFailedEvent, LoopNextEvent, LoopStartedEvent, @@ -26,16 +26,16 @@ from core.workflow.node_events import ( NodeRunResult, StreamCompletedEvent, ) -from core.workflow.nodes.base import LLMUsageTrackingMixin -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData -from core.workflow.utils.condition.processor import ConditionProcessor -from core.workflow.variables import Segment, SegmentType +from dify_graph.nodes.base import LLMUsageTrackingMixin +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData +from dify_graph.utils.condition.processor import ConditionProcessor +from dify_graph.variables import Segment, SegmentType from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable from libs.datetime_utils import naive_utc_now if TYPE_CHECKING: - from core.workflow.graph_engine import GraphEngine + from dify_graph.graph_engine import GraphEngine logger = logging.getLogger(__name__) @@ -318,7 +318,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): # variable selector to variable mapping try: # Get node class - from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING + from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING node_type = NodeType(sub_node_config.get("data", {}).get("type")) if node_type not in NODE_TYPE_CLASSES_MAPPING: @@ -412,24 +412,14 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): return build_segment_with_type(var_type, value) def _create_graph_engine(self, start_at: datetime, root_node_id: str): - # Import dependencies - from core.app.workflow.layers.llm_quota import LLMQuotaLayer - from core.app.workflow.node_factory import DifyNodeFactory - from core.workflow.entities import GraphInitParams - from core.workflow.graph import Graph - from core.workflow.graph_engine import GraphEngine, GraphEngineConfig - from core.workflow.graph_engine.command_channels import InMemoryChannel - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState - # Create GraphInitParams from node attributes + # Create GraphInitParams for child graph execution. graph_init_params = GraphInitParams( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, graph_config=self.graph_config, - user_id=self.user_id, - user_from=self.user_from.value, - invoke_from=self.invoke_from.value, + run_context=self.run_context, call_depth=self.workflow_call_depth, ) @@ -439,22 +429,10 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): start_at=start_at.timestamp(), ) - # Create a new node factory with the new GraphRuntimeState - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy - ) - - # Initialize the loop graph with the new node factory - loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id) - - # Create a new GraphEngine for this iteration - graph_engine = GraphEngine( + return self.graph_runtime_state.create_child_engine( workflow_id=self.workflow_id, - graph=loop_graph, + graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy, - command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs - config=GraphEngineConfig(), + graph_config=self.graph_config, + root_node_id=root_node_id, ) - graph_engine.layer(LLMQuotaLayer()) - - return graph_engine diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/dify_graph/nodes/loop/loop_start_node.py similarity index 59% rename from api/core/workflow/nodes/loop/loop_start_node.py rename to api/dify_graph/nodes/loop/loop_start_node.py index 95bb5c4018..f469c8286e 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/dify_graph/nodes/loop/loop_start_node.py @@ -1,7 +1,7 @@ -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.loop.entities import LoopStartNodeData +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.loop.entities import LoopStartNodeData class LoopStartNode(Node[LoopStartNodeData]): diff --git a/api/core/workflow/nodes/node_mapping.py b/api/dify_graph/nodes/node_mapping.py similarity index 65% rename from api/core/workflow/nodes/node_mapping.py rename to api/dify_graph/nodes/node_mapping.py index 85df543a2a..8e5405f1aa 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/dify_graph/nodes/node_mapping.py @@ -1,9 +1,9 @@ from collections.abc import Mapping -from core.workflow.enums import NodeType -from core.workflow.nodes.base.node import Node +from dify_graph.enums import NodeType +from dify_graph.nodes.base.node import Node LATEST_VERSION = "latest" -# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks core.workflow.nodes +# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks dify_graph.nodes NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping() diff --git a/api/core/workflow/nodes/parameter_extractor/__init__.py b/api/dify_graph/nodes/parameter_extractor/__init__.py similarity index 100% rename from api/core/workflow/nodes/parameter_extractor/__init__.py rename to api/dify_graph/nodes/parameter_extractor/__init__.py diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/dify_graph/nodes/parameter_extractor/entities.py similarity index 95% rename from api/core/workflow/nodes/parameter_extractor/entities.py rename to api/dify_graph/nodes/parameter_extractor/entities.py index 90d78ae429..3b042710f9 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/dify_graph/nodes/parameter_extractor/entities.py @@ -8,9 +8,9 @@ from pydantic import ( ) from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig -from core.workflow.variables.types import SegmentType +from dify_graph.nodes.base import BaseNodeData +from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig +from dify_graph.variables.types import SegmentType _OLD_BOOL_TYPE_NAME = "bool" _OLD_SELECT_TYPE_NAME = "select" diff --git a/api/core/workflow/nodes/parameter_extractor/exc.py b/api/dify_graph/nodes/parameter_extractor/exc.py similarity index 97% rename from api/core/workflow/nodes/parameter_extractor/exc.py rename to api/dify_graph/nodes/parameter_extractor/exc.py index 5a58780575..c25b809d1c 100644 --- a/api/core/workflow/nodes/parameter_extractor/exc.py +++ b/api/dify_graph/nodes/parameter_extractor/exc.py @@ -1,6 +1,6 @@ from typing import Any -from core.workflow.variables.types import SegmentType +from dify_graph.variables.types import SegmentType class ParameterExtractorNodeError(ValueError): diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py similarity index 96% rename from api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py rename to api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py index 0c8b122e1c..0e6cc772ca 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py @@ -6,9 +6,19 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, cast from core.model_manager import ModelInstance -from core.model_runtime.entities import ImagePromptMessageContent -from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.entities.message_entities import ( +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate +from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from dify_graph.enums import ( + NodeType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from dify_graph.file import File +from dify_graph.model_runtime.entities import ImagePromptMessageContent +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageRole, @@ -16,26 +26,16 @@ from core.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from core.model_runtime.memory import PromptMessageMemory -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.utils.encoders import jsonable_encoder -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.enums import ( - NodeType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from core.workflow.file import File -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base import variable_template_parser -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.llm import llm_utils -from core.workflow.runtime import VariablePool -from core.workflow.variables.types import ArrayValidation, SegmentType +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from dify_graph.model_runtime.memory import PromptMessageMemory +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base import variable_template_parser +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.llm import llm_utils +from dify_graph.runtime import VariablePool +from dify_graph.variables.types import ArrayValidation, SegmentType from factories.variable_factory import build_segment_with_type from .entities import ParameterExtractorNodeData @@ -64,9 +64,9 @@ from .prompts import ( logger = logging.getLogger(__name__) if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory + from dify_graph.runtime import GraphRuntimeState def extract_json(text): @@ -302,7 +302,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): tools=tools, stop=list(stop), stream=False, - user=self.user_id, + user=self.require_dify_context().user_id, ) # handle invoke result diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/dify_graph/nodes/parameter_extractor/prompts.py similarity index 100% rename from api/core/workflow/nodes/parameter_extractor/prompts.py rename to api/dify_graph/nodes/parameter_extractor/prompts.py diff --git a/api/core/workflow/nodes/protocols.py b/api/dify_graph/nodes/protocols.py similarity index 96% rename from api/core/workflow/nodes/protocols.py rename to api/dify_graph/nodes/protocols.py index fda524d701..cc007150f1 100644 --- a/api/core/workflow/nodes/protocols.py +++ b/api/dify_graph/nodes/protocols.py @@ -2,7 +2,7 @@ from typing import Any, Protocol import httpx -from core.workflow.file import File +from dify_graph.file import File class HttpClientProtocol(Protocol): diff --git a/api/core/workflow/nodes/question_classifier/__init__.py b/api/dify_graph/nodes/question_classifier/__init__.py similarity index 100% rename from api/core/workflow/nodes/question_classifier/__init__.py rename to api/dify_graph/nodes/question_classifier/__init__.py diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/dify_graph/nodes/question_classifier/entities.py similarity index 87% rename from api/core/workflow/nodes/question_classifier/entities.py rename to api/dify_graph/nodes/question_classifier/entities.py index edde30708a..03e0a0ac53 100644 --- a/api/core/workflow/nodes/question_classifier/entities.py +++ b/api/dify_graph/nodes/question_classifier/entities.py @@ -1,8 +1,8 @@ from pydantic import BaseModel, Field from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.llm import ModelConfig, VisionConfig +from dify_graph.nodes.base import BaseNodeData +from dify_graph.nodes.llm import ModelConfig, VisionConfig class ClassConfig(BaseModel): diff --git a/api/core/workflow/nodes/question_classifier/exc.py b/api/dify_graph/nodes/question_classifier/exc.py similarity index 100% rename from api/core/workflow/nodes/question_classifier/exc.py rename to api/dify_graph/nodes/question_classifier/exc.py diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py similarity index 92% rename from api/core/workflow/nodes/question_classifier/question_classifier_node.py rename to api/dify_graph/nodes/question_classifier/question_classifier_node.py index c83534cf35..860db05c84 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/dify_graph/nodes/question_classifier/question_classifier_node.py @@ -4,30 +4,30 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any from core.model_manager import ModelInstance -from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole -from core.model_runtime.memory import PromptMessageMemory -from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.entities import GraphInitParams -from core.workflow.enums import ( +from dify_graph.entities import GraphInitParams +from dify_graph.enums import ( NodeExecutionType, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult -from core.workflow.nodes.base.entities import VariableSelector -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.nodes.llm import ( +from dify_graph.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole +from dify_graph.model_runtime.memory import PromptMessageMemory +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.node_events import ModelInvokeCompletedEvent, NodeRunResult +from dify_graph.nodes.base.entities import VariableSelector +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.nodes.llm import ( LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils, ) -from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver -from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory from libs.json_in_md_parser import parse_and_check_json_markdown from .entities import QuestionClassifierNodeData @@ -43,8 +43,8 @@ from .template_prompts import ( ) if TYPE_CHECKING: - from core.workflow.file.models import File - from core.workflow.runtime import GraphRuntimeState + from dify_graph.file.models import File + from dify_graph.runtime import GraphRuntimeState class QuestionClassifierNode(Node[QuestionClassifierNodeData]): @@ -86,9 +86,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): self._memory = memory if llm_file_saver is None: + dify_ctx = self.require_dify_context() llm_file_saver = FileSaverImpl( - user_id=graph_init_params.user_id, - tenant_id=graph_init_params.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, ) self._llm_file_saver = llm_file_saver @@ -159,8 +160,9 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, - user_id=self.user_id, - structured_output_schema=None, + user_id=self.require_dify_context().user_id, + structured_output_enabled=False, + structured_output=None, file_saver=self._llm_file_saver, file_outputs=self._file_outputs, node_id=self._node_id, diff --git a/api/core/workflow/nodes/question_classifier/template_prompts.py b/api/dify_graph/nodes/question_classifier/template_prompts.py similarity index 100% rename from api/core/workflow/nodes/question_classifier/template_prompts.py rename to api/dify_graph/nodes/question_classifier/template_prompts.py diff --git a/api/core/workflow/nodes/start/__init__.py b/api/dify_graph/nodes/start/__init__.py similarity index 100% rename from api/core/workflow/nodes/start/__init__.py rename to api/dify_graph/nodes/start/__init__.py diff --git a/api/core/workflow/nodes/start/entities.py b/api/dify_graph/nodes/start/entities.py similarity index 64% rename from api/core/workflow/nodes/start/entities.py rename to api/dify_graph/nodes/start/entities.py index 3a99e2cbc2..0df832740e 100644 --- a/api/core/workflow/nodes/start/entities.py +++ b/api/dify_graph/nodes/start/entities.py @@ -2,8 +2,8 @@ from collections.abc import Sequence from pydantic import Field -from core.workflow.nodes.base import BaseNodeData -from core.workflow.variables.input_entities import VariableEntity +from dify_graph.nodes.base import BaseNodeData +from dify_graph.variables.input_entities import VariableEntity class StartNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/start/start_node.py b/api/dify_graph/nodes/start/start_node.py similarity index 84% rename from api/core/workflow/nodes/start/start_node.py rename to api/dify_graph/nodes/start/start_node.py index 4e5545d330..c09ead0124 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/dify_graph/nodes/start/start_node.py @@ -2,12 +2,12 @@ from typing import Any from jsonschema import Draft7Validator, ValidationError -from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.variables.input_entities import VariableEntityType +from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.variables.input_entities import VariableEntityType class StartNode(Node[StartNodeData]): diff --git a/api/core/workflow/nodes/template_transform/__init__.py b/api/dify_graph/nodes/template_transform/__init__.py similarity index 100% rename from api/core/workflow/nodes/template_transform/__init__.py rename to api/dify_graph/nodes/template_transform/__init__.py diff --git a/api/core/workflow/nodes/template_transform/entities.py b/api/dify_graph/nodes/template_transform/entities.py similarity index 57% rename from api/core/workflow/nodes/template_transform/entities.py rename to api/dify_graph/nodes/template_transform/entities.py index efb7a72f59..123fd41f81 100644 --- a/api/core/workflow/nodes/template_transform/entities.py +++ b/api/dify_graph/nodes/template_transform/entities.py @@ -1,5 +1,5 @@ -from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.base.entities import VariableSelector +from dify_graph.nodes.base import BaseNodeData +from dify_graph.nodes.base.entities import VariableSelector class TemplateTransformNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/template_transform/template_renderer.py b/api/dify_graph/nodes/template_transform/template_renderer.py similarity index 62% rename from api/core/workflow/nodes/template_transform/template_renderer.py rename to api/dify_graph/nodes/template_transform/template_renderer.py index a5f06bf2bb..9b679d4497 100644 --- a/api/core/workflow/nodes/template_transform/template_renderer.py +++ b/api/dify_graph/nodes/template_transform/template_renderer.py @@ -3,7 +3,8 @@ from __future__ import annotations from collections.abc import Mapping from typing import Any, Protocol -from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage +from dify_graph.nodes.code.code_node import WorkflowCodeExecutor +from dify_graph.nodes.code.entities import CodeLanguage class TemplateRenderError(ValueError): @@ -21,18 +22,18 @@ class Jinja2TemplateRenderer(Protocol): class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): """Adapter that renders Jinja2 templates via CodeExecutor.""" - _code_executor: type[CodeExecutor] + _code_executor: WorkflowCodeExecutor - def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None: - self._code_executor = code_executor or CodeExecutor + def __init__(self, code_executor: WorkflowCodeExecutor) -> None: + self._code_executor = code_executor def render_template(self, template: str, variables: Mapping[str, Any]) -> str: try: - result = self._code_executor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, code=template, inputs=variables - ) - except CodeExecutionError as exc: - raise TemplateRenderError(str(exc)) from exc + result = self._code_executor.execute(language=CodeLanguage.JINJA2, code=template, inputs=variables) + except Exception as exc: + if self._code_executor.is_execution_error(exc): + raise TemplateRenderError(str(exc)) from exc + raise rendered = result.get("result") if not isinstance(rendered, str): diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/dify_graph/nodes/template_transform/template_transform_node.py similarity index 83% rename from api/core/workflow/nodes/template_transform/template_transform_node.py rename to api/dify_graph/nodes/template_transform/template_transform_node.py index 3dc8afd9be..367442e997 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/dify_graph/nodes/template_transform/template_transform_node.py @@ -1,19 +1,18 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData -from core.workflow.nodes.template_transform.template_renderer import ( - CodeExecutorJinja2TemplateRenderer, +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData +from dify_graph.nodes.template_transform.template_renderer import ( Jinja2TemplateRenderer, TemplateRenderError, ) if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000 @@ -30,7 +29,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, - template_renderer: Jinja2TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer, max_output_length: int | None = None, ) -> None: super().__init__( @@ -39,7 +38,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer() + self._template_renderer = template_renderer if max_output_length is not None and max_output_length <= 0: raise ValueError("max_output_length must be a positive integer") diff --git a/api/core/workflow/nodes/tool/__init__.py b/api/dify_graph/nodes/tool/__init__.py similarity index 100% rename from api/core/workflow/nodes/tool/__init__.py rename to api/dify_graph/nodes/tool/__init__.py diff --git a/api/core/workflow/nodes/tool/entities.py b/api/dify_graph/nodes/tool/entities.py similarity index 98% rename from api/core/workflow/nodes/tool/entities.py rename to api/dify_graph/nodes/tool/entities.py index 031cc73dc8..cd690fff04 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/dify_graph/nodes/tool/entities.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, field_validator, model_validator from pydantic_core.core_schema import ValidationInfo from core.tools.entities.tool_entities import ToolProviderType -from core.workflow.nodes.base.entities import BaseNodeData +from dify_graph.nodes.base.entities import BaseNodeData # Pattern to match mention format: {{@node.context@}}instruction MENTION_VALUE_PATTERN = re.compile(r"^\{\{@([a-zA-Z0-9_]+)\.context@\}\}(.*)$", re.DOTALL) diff --git a/api/core/workflow/nodes/tool/exc.py b/api/dify_graph/nodes/tool/exc.py similarity index 100% rename from api/core/workflow/nodes/tool/exc.py rename to api/dify_graph/nodes/tool/exc.py diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py similarity index 96% rename from api/core/workflow/nodes/tool/tool_node.py rename to api/dify_graph/nodes/tool/tool_node.py index e3f60c34c0..6a6ffa6432 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/dify_graph/nodes/tool/tool_node.py @@ -8,24 +8,24 @@ logger = logging.getLogger(__name__) from sqlalchemy.orm import Session from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.model_runtime.entities.llm_entities import LLMUsage from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.workflow.enums import ( +from dify_graph.enums import ( NodeType, SystemVariableKey, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.file import File, FileTransferMethod -from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment -from core.workflow.variables.variables import ArrayAnyVariable +from dify_graph.file import File, FileTransferMethod +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment +from dify_graph.variables.variables import ArrayAnyVariable from extensions.ext_database import db from factories import file_factory from models import ToolFile @@ -39,7 +39,7 @@ from .exc import ( ) if TYPE_CHECKING: - from core.workflow.runtime import VariablePool + from dify_graph.runtime import VariablePool class ToolNode(Node[ToolNodeData]): @@ -59,6 +59,8 @@ class ToolNode(Node[ToolNodeData]): """ from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError + dify_ctx = self.require_dify_context() + # fetch tool icon tool_info = { "provider_type": self.node_data.provider_type.value, @@ -78,7 +80,12 @@ class ToolNode(Node[ToolNodeData]): if self.node_data.version != "1" or self.node_data.tool_node_version is not None: variable_pool = self.graph_runtime_state.variable_pool tool_runtime = ToolManager.get_workflow_tool_runtime( - self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool + dify_ctx.tenant_id, + dify_ctx.app_id, + self._node_id, + self.node_data, + dify_ctx.invoke_from, + variable_pool, ) except ToolNodeError as e: yield StreamCompletedEvent( @@ -112,10 +119,10 @@ class ToolNode(Node[ToolNodeData]): message_stream = ToolEngine.generic_invoke( tool=tool_runtime, tool_parameters=parameters, - user_id=self.user_id, + user_id=dify_ctx.user_id, workflow_tool_callback=DifyWorkflowCallbackHandler(), workflow_call_depth=self.workflow_call_depth, - app_id=self.app_id, + app_id=dify_ctx.app_id, conversation_id=conversation_id.text if conversation_id else None, ) except ToolNodeError as e: @@ -136,8 +143,8 @@ class ToolNode(Node[ToolNodeData]): messages=message_stream, tool_info=tool_info, parameters_for_log=parameters_for_log, - user_id=self.user_id, - tenant_id=self.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, node_id=self._node_id, tool_runtime=tool_runtime, ) diff --git a/api/core/workflow/nodes/trigger_plugin/__init__.py b/api/dify_graph/nodes/trigger_plugin/__init__.py similarity index 100% rename from api/core/workflow/nodes/trigger_plugin/__init__.py rename to api/dify_graph/nodes/trigger_plugin/__init__.py diff --git a/api/core/workflow/nodes/trigger_plugin/entities.py b/api/dify_graph/nodes/trigger_plugin/entities.py similarity index 95% rename from api/core/workflow/nodes/trigger_plugin/entities.py rename to api/dify_graph/nodes/trigger_plugin/entities.py index 6c53acee4f..75d10ecaa4 100644 --- a/api/core/workflow/nodes/trigger_plugin/entities.py +++ b/api/dify_graph/nodes/trigger_plugin/entities.py @@ -4,8 +4,8 @@ from typing import Any, Literal, Union from pydantic import BaseModel, Field, ValidationInfo, field_validator from core.trigger.entities.entities import EventParameter -from core.workflow.nodes.base.entities import BaseNodeData -from core.workflow.nodes.trigger_plugin.exc import TriggerEventParameterError +from dify_graph.nodes.base.entities import BaseNodeData +from dify_graph.nodes.trigger_plugin.exc import TriggerEventParameterError class TriggerEventNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/trigger_plugin/exc.py b/api/dify_graph/nodes/trigger_plugin/exc.py similarity index 100% rename from api/core/workflow/nodes/trigger_plugin/exc.py rename to api/dify_graph/nodes/trigger_plugin/exc.py diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/dify_graph/nodes/trigger_plugin/trigger_event_node.py similarity index 85% rename from api/core/workflow/nodes/trigger_plugin/trigger_event_node.py rename to api/dify_graph/nodes/trigger_plugin/trigger_event_node.py index e11cb30a7f..b4f1116f7e 100644 --- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/dify_graph/nodes/trigger_plugin/trigger_event_node.py @@ -1,10 +1,10 @@ from collections.abc import Mapping -from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.enums import NodeExecutionType, NodeType -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node +from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.enums import NodeExecutionType, NodeType +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node from .entities import TriggerEventNodeData diff --git a/api/dify_graph/nodes/trigger_schedule/__init__.py b/api/dify_graph/nodes/trigger_schedule/__init__.py new file mode 100644 index 0000000000..c9b3ae6a0d --- /dev/null +++ b/api/dify_graph/nodes/trigger_schedule/__init__.py @@ -0,0 +1,3 @@ +from dify_graph.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode + +__all__ = ["TriggerScheduleNode"] diff --git a/api/core/workflow/nodes/trigger_schedule/entities.py b/api/dify_graph/nodes/trigger_schedule/entities.py similarity index 97% rename from api/core/workflow/nodes/trigger_schedule/entities.py rename to api/dify_graph/nodes/trigger_schedule/entities.py index a515d02d55..6daadc7666 100644 --- a/api/core/workflow/nodes/trigger_schedule/entities.py +++ b/api/dify_graph/nodes/trigger_schedule/entities.py @@ -2,7 +2,7 @@ from typing import Literal, Union from pydantic import BaseModel, Field -from core.workflow.nodes.base import BaseNodeData +from dify_graph.nodes.base import BaseNodeData class TriggerScheduleNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/trigger_schedule/exc.py b/api/dify_graph/nodes/trigger_schedule/exc.py similarity index 90% rename from api/core/workflow/nodes/trigger_schedule/exc.py rename to api/dify_graph/nodes/trigger_schedule/exc.py index 2f99880ff1..caea6241e4 100644 --- a/api/core/workflow/nodes/trigger_schedule/exc.py +++ b/api/dify_graph/nodes/trigger_schedule/exc.py @@ -1,4 +1,4 @@ -from core.workflow.nodes.base.exc import BaseNodeError +from dify_graph.nodes.base.exc import BaseNodeError class ScheduleNodeError(BaseNodeError): diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py b/api/dify_graph/nodes/trigger_schedule/trigger_schedule_node.py similarity index 77% rename from api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py rename to api/dify_graph/nodes/trigger_schedule/trigger_schedule_node.py index fb5c8a4dce..7e92eb3f4f 100644 --- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/dify_graph/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,11 +1,11 @@ from collections.abc import Mapping -from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import NodeExecutionType, NodeType -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.trigger_schedule.entities import TriggerScheduleNodeData +from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.enums import NodeExecutionType, NodeType +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.trigger_schedule.entities import TriggerScheduleNodeData class TriggerScheduleNode(Node[TriggerScheduleNodeData]): diff --git a/api/core/workflow/nodes/trigger_webhook/__init__.py b/api/dify_graph/nodes/trigger_webhook/__init__.py similarity index 100% rename from api/core/workflow/nodes/trigger_webhook/__init__.py rename to api/dify_graph/nodes/trigger_webhook/__init__.py diff --git a/api/core/workflow/nodes/trigger_webhook/entities.py b/api/dify_graph/nodes/trigger_webhook/entities.py similarity index 97% rename from api/core/workflow/nodes/trigger_webhook/entities.py rename to api/dify_graph/nodes/trigger_webhook/entities.py index 1011e60b43..fa36aeabd3 100644 --- a/api/core/workflow/nodes/trigger_webhook/entities.py +++ b/api/dify_graph/nodes/trigger_webhook/entities.py @@ -4,7 +4,7 @@ from typing import Literal from pydantic import BaseModel, Field, field_validator -from core.workflow.nodes.base import BaseNodeData +from dify_graph.nodes.base import BaseNodeData class Method(StrEnum): diff --git a/api/core/workflow/nodes/trigger_webhook/exc.py b/api/dify_graph/nodes/trigger_webhook/exc.py similarity index 86% rename from api/core/workflow/nodes/trigger_webhook/exc.py rename to api/dify_graph/nodes/trigger_webhook/exc.py index dc2239c287..853b2456c5 100644 --- a/api/core/workflow/nodes/trigger_webhook/exc.py +++ b/api/dify_graph/nodes/trigger_webhook/exc.py @@ -1,4 +1,4 @@ -from core.workflow.nodes.base.exc import BaseNodeError +from dify_graph.nodes.base.exc import BaseNodeError class WebhookNodeError(BaseNodeError): diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/dify_graph/nodes/trigger_webhook/node.py similarity index 92% rename from api/core/workflow/nodes/trigger_webhook/node.py rename to api/dify_graph/nodes/trigger_webhook/node.py index 9f6046c11a..e466541908 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/dify_graph/nodes/trigger_webhook/node.py @@ -2,14 +2,14 @@ import logging from collections.abc import Mapping from typing import Any -from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import NodeExecutionType, NodeType -from core.workflow.file import FileTransferMethod -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.variables.types import SegmentType -from core.workflow.variables.variables import FileVariable +from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.enums import NodeExecutionType, NodeType +from dify_graph.file import FileTransferMethod +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.variables.types import SegmentType +from dify_graph.variables.variables import FileVariable from factories import file_factory from factories.variable_factory import build_segment_with_type @@ -69,6 +69,7 @@ class TriggerWebhookNode(Node[WebhookData]): ) def generate_file_var(self, param_name: str, file: dict): + dify_ctx = self.require_dify_context() related_id = file.get("related_id") transfer_method_value = file.get("transfer_method") if transfer_method_value: @@ -84,7 +85,7 @@ class TriggerWebhookNode(Node[WebhookData]): try: file_obj = file_factory.build_from_mapping( mapping=file, - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, ) file_segment = build_segment_with_type(SegmentType.FILE, file_obj) return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name]) diff --git a/api/core/workflow/nodes/variable_aggregator/__init__.py b/api/dify_graph/nodes/variable_aggregator/__init__.py similarity index 100% rename from api/core/workflow/nodes/variable_aggregator/__init__.py rename to api/dify_graph/nodes/variable_aggregator/__init__.py diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/dify_graph/nodes/variable_aggregator/entities.py similarity index 83% rename from api/core/workflow/nodes/variable_aggregator/entities.py rename to api/dify_graph/nodes/variable_aggregator/entities.py index febbf1d1d6..5f7c1dbe93 100644 --- a/api/core/workflow/nodes/variable_aggregator/entities.py +++ b/api/dify_graph/nodes/variable_aggregator/entities.py @@ -1,7 +1,7 @@ from pydantic import BaseModel -from core.workflow.nodes.base import BaseNodeData -from core.workflow.variables.types import SegmentType +from dify_graph.nodes.base import BaseNodeData +from dify_graph.variables.types import SegmentType class AdvancedSettings(BaseModel): diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py similarity index 81% rename from api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py rename to api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py index 762b7dab07..98ab8105fe 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,10 +1,10 @@ from collections.abc import Mapping -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.variable_aggregator.entities import VariableAggregatorNodeData -from core.workflow.variables.segments import Segment +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.variable_aggregator.entities import VariableAggregatorNodeData +from dify_graph.variables.segments import Segment class VariableAggregatorNode(Node[VariableAggregatorNodeData]): diff --git a/api/core/workflow/nodes/variable_assigner/common/__init__.py b/api/dify_graph/nodes/variable_assigner/__init__.py similarity index 100% rename from api/core/workflow/nodes/variable_assigner/common/__init__.py rename to api/dify_graph/nodes/variable_assigner/__init__.py diff --git a/api/core/workflow/utils/__init__.py b/api/dify_graph/nodes/variable_assigner/common/__init__.py similarity index 100% rename from api/core/workflow/utils/__init__.py rename to api/dify_graph/nodes/variable_assigner/common/__init__.py diff --git a/api/core/workflow/nodes/variable_assigner/common/exc.py b/api/dify_graph/nodes/variable_assigner/common/exc.py similarity index 100% rename from api/core/workflow/nodes/variable_assigner/common/exc.py rename to api/dify_graph/nodes/variable_assigner/common/exc.py diff --git a/api/core/workflow/nodes/variable_assigner/common/helpers.py b/api/dify_graph/nodes/variable_assigner/common/helpers.py similarity index 90% rename from api/core/workflow/nodes/variable_assigner/common/helpers.py rename to api/dify_graph/nodes/variable_assigner/common/helpers.py index 37fde9d1b0..f0b22904a9 100644 --- a/api/core/workflow/nodes/variable_assigner/common/helpers.py +++ b/api/dify_graph/nodes/variable_assigner/common/helpers.py @@ -3,9 +3,9 @@ from typing import Any, TypeVar from pydantic import BaseModel -from core.workflow.variables import Segment -from core.workflow.variables.consts import SELECTORS_LENGTH -from core.workflow.variables.types import SegmentType +from dify_graph.variables import Segment +from dify_graph.variables.consts import SELECTORS_LENGTH +from dify_graph.variables.types import SegmentType # Use double underscore (`__`) prefix for internal variables # to minimize risk of collision with user-defined variable names. diff --git a/api/core/workflow/nodes/variable_assigner/v1/__init__.py b/api/dify_graph/nodes/variable_assigner/v1/__init__.py similarity index 100% rename from api/core/workflow/nodes/variable_assigner/v1/__init__.py rename to api/dify_graph/nodes/variable_assigner/v1/__init__.py diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/dify_graph/nodes/variable_assigner/v1/node.py similarity index 88% rename from api/core/workflow/nodes/variable_assigner/v1/node.py rename to api/dify_graph/nodes/variable_assigner/v1/node.py index b987949541..1aa7042b02 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/dify_graph/nodes/variable_assigner/v1/node.py @@ -1,19 +1,19 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.entities import GraphInitParams -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.variable_assigner.common import helpers as common_helpers -from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from core.workflow.variables import SegmentType, VariableBase +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID +from dify_graph.entities import GraphInitParams +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.variable_assigner.common import helpers as common_helpers +from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from dify_graph.variables import SegmentType, VariableBase from .node_data import VariableAssignerData, WriteMode if TYPE_CHECKING: - from core.workflow.runtime import GraphRuntimeState + from dify_graph.runtime import GraphRuntimeState class VariableAssignerNode(Node[VariableAssignerData]): diff --git a/api/core/workflow/nodes/variable_assigner/v1/node_data.py b/api/dify_graph/nodes/variable_assigner/v1/node_data.py similarity index 86% rename from api/core/workflow/nodes/variable_assigner/v1/node_data.py rename to api/dify_graph/nodes/variable_assigner/v1/node_data.py index 9734d64712..11e8f93f35 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node_data.py +++ b/api/dify_graph/nodes/variable_assigner/v1/node_data.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from enum import StrEnum -from core.workflow.nodes.base import BaseNodeData +from dify_graph.nodes.base import BaseNodeData class WriteMode(StrEnum): diff --git a/api/core/workflow/nodes/variable_assigner/v2/__init__.py b/api/dify_graph/nodes/variable_assigner/v2/__init__.py similarity index 100% rename from api/core/workflow/nodes/variable_assigner/v2/__init__.py rename to api/dify_graph/nodes/variable_assigner/v2/__init__.py diff --git a/api/core/workflow/nodes/variable_assigner/v2/entities.py b/api/dify_graph/nodes/variable_assigner/v2/entities.py similarity index 94% rename from api/core/workflow/nodes/variable_assigner/v2/entities.py rename to api/dify_graph/nodes/variable_assigner/v2/entities.py index 2955730289..5f9211d600 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/entities.py +++ b/api/dify_graph/nodes/variable_assigner/v2/entities.py @@ -3,7 +3,7 @@ from typing import Any from pydantic import BaseModel, Field -from core.workflow.nodes.base import BaseNodeData +from dify_graph.nodes.base import BaseNodeData from .enums import InputType, Operation diff --git a/api/core/workflow/nodes/variable_assigner/v2/enums.py b/api/dify_graph/nodes/variable_assigner/v2/enums.py similarity index 100% rename from api/core/workflow/nodes/variable_assigner/v2/enums.py rename to api/dify_graph/nodes/variable_assigner/v2/enums.py diff --git a/api/core/workflow/nodes/variable_assigner/v2/exc.py b/api/dify_graph/nodes/variable_assigner/v2/exc.py similarity index 93% rename from api/core/workflow/nodes/variable_assigner/v2/exc.py rename to api/dify_graph/nodes/variable_assigner/v2/exc.py index 05173b3ca1..c50aab8668 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/exc.py +++ b/api/dify_graph/nodes/variable_assigner/v2/exc.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from typing import Any -from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError from .enums import InputType, Operation diff --git a/api/core/workflow/nodes/variable_assigner/v2/helpers.py b/api/dify_graph/nodes/variable_assigner/v2/helpers.py similarity index 98% rename from api/core/workflow/nodes/variable_assigner/v2/helpers.py rename to api/dify_graph/nodes/variable_assigner/v2/helpers.py index ce3fe9620c..38c69cbe3c 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/helpers.py +++ b/api/dify_graph/nodes/variable_assigner/v2/helpers.py @@ -1,6 +1,6 @@ from typing import Any -from core.workflow.variables import SegmentType +from dify_graph.variables import SegmentType from .enums import Operation diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/dify_graph/nodes/variable_assigner/v2/node.py similarity index 93% rename from api/core/workflow/nodes/variable_assigner/v2/node.py rename to api/dify_graph/nodes/variable_assigner/v2/node.py index 0d4c3d2774..7753382cd0 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/dify_graph/nodes/variable_assigner/v2/node.py @@ -2,14 +2,14 @@ import json from collections.abc import Mapping, MutableMapping, Sequence from typing import TYPE_CHECKING, Any -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.variable_assigner.common import helpers as common_helpers -from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from core.workflow.variables import SegmentType, VariableBase -from core.workflow.variables.consts import SELECTORS_LENGTH +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.variable_assigner.common import helpers as common_helpers +from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from dify_graph.variables import SegmentType, VariableBase +from dify_graph.variables.consts import SELECTORS_LENGTH from . import helpers from .entities import VariableAssignerNodeData, VariableOperationItem @@ -23,8 +23,8 @@ from .exc import ( ) if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): diff --git a/api/core/workflow/repositories/__init__.py b/api/dify_graph/repositories/__init__.py similarity index 69% rename from api/core/workflow/repositories/__init__.py rename to api/dify_graph/repositories/__init__.py index a778151baa..ef70eb09cc 100644 --- a/api/core/workflow/repositories/__init__.py +++ b/api/dify_graph/repositories/__init__.py @@ -6,7 +6,7 @@ for accessing and manipulating data, regardless of the underlying storage mechanism. """ -from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository __all__ = [ "OrderConfig", diff --git a/api/core/workflow/repositories/datasource_manager_protocol.py b/api/dify_graph/repositories/datasource_manager_protocol.py similarity index 91% rename from api/core/workflow/repositories/datasource_manager_protocol.py rename to api/dify_graph/repositories/datasource_manager_protocol.py index 4acf486bef..fbe2016d3c 100644 --- a/api/core/workflow/repositories/datasource_manager_protocol.py +++ b/api/dify_graph/repositories/datasource_manager_protocol.py @@ -3,8 +3,8 @@ from typing import Any, Protocol from pydantic import BaseModel -from core.workflow.file import File -from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent +from dify_graph.file import File +from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent class DatasourceParameter(BaseModel): diff --git a/api/core/workflow/repositories/draft_variable_repository.py b/api/dify_graph/repositories/draft_variable_repository.py similarity index 95% rename from api/core/workflow/repositories/draft_variable_repository.py rename to api/dify_graph/repositories/draft_variable_repository.py index 66ef714c16..b2ebfacffd 100644 --- a/api/core/workflow/repositories/draft_variable_repository.py +++ b/api/dify_graph/repositories/draft_variable_repository.py @@ -6,7 +6,7 @@ from typing import Any, Protocol from sqlalchemy.orm import Session -from core.workflow.enums import NodeType +from dify_graph.enums import NodeType class DraftVariableSaver(Protocol): diff --git a/api/core/workflow/repositories/human_input_form_repository.py b/api/dify_graph/repositories/human_input_form_repository.py similarity index 96% rename from api/core/workflow/repositories/human_input_form_repository.py rename to api/dify_graph/repositories/human_input_form_repository.py index efde59c6fd..88966831cb 100644 --- a/api/core/workflow/repositories/human_input_form_repository.py +++ b/api/dify_graph/repositories/human_input_form_repository.py @@ -4,8 +4,8 @@ from collections.abc import Mapping, Sequence from datetime import datetime from typing import Any, Protocol -from core.workflow.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from dify_graph.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData +from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus class HumanInputError(Exception): diff --git a/api/core/workflow/repositories/index_processor_protocol.py b/api/dify_graph/repositories/index_processor_protocol.py similarity index 100% rename from api/core/workflow/repositories/index_processor_protocol.py rename to api/dify_graph/repositories/index_processor_protocol.py diff --git a/api/core/workflow/repositories/rag_retrieval_protocol.py b/api/dify_graph/repositories/rag_retrieval_protocol.py similarity index 96% rename from api/core/workflow/repositories/rag_retrieval_protocol.py rename to api/dify_graph/repositories/rag_retrieval_protocol.py index f91cecb694..5f3d38167e 100644 --- a/api/core/workflow/repositories/rag_retrieval_protocol.py +++ b/api/dify_graph/repositories/rag_retrieval_protocol.py @@ -2,9 +2,9 @@ from typing import Any, Literal, Protocol from pydantic import BaseModel, Field -from core.model_runtime.entities import LLMUsage -from core.workflow.nodes.knowledge_retrieval.entities import MetadataFilteringCondition -from core.workflow.nodes.llm.entities import ModelConfig +from dify_graph.model_runtime.entities import LLMUsage +from dify_graph.nodes.knowledge_retrieval.entities import MetadataFilteringCondition +from dify_graph.nodes.llm.entities import ModelConfig class SourceChildChunk(BaseModel): diff --git a/api/core/workflow/repositories/summary_index_service_protocol.py b/api/dify_graph/repositories/summary_index_service_protocol.py similarity index 100% rename from api/core/workflow/repositories/summary_index_service_protocol.py rename to api/dify_graph/repositories/summary_index_service_protocol.py diff --git a/api/core/workflow/repositories/workflow_execution_repository.py b/api/dify_graph/repositories/workflow_execution_repository.py similarity index 95% rename from api/core/workflow/repositories/workflow_execution_repository.py rename to api/dify_graph/repositories/workflow_execution_repository.py index d9ce591db8..ef83f07649 100644 --- a/api/core/workflow/repositories/workflow_execution_repository.py +++ b/api/dify_graph/repositories/workflow_execution_repository.py @@ -1,6 +1,6 @@ from typing import Protocol -from core.workflow.entities import WorkflowExecution +from dify_graph.entities import WorkflowExecution class WorkflowExecutionRepository(Protocol): diff --git a/api/core/workflow/repositories/workflow_node_execution_repository.py b/api/dify_graph/repositories/workflow_node_execution_repository.py similarity index 97% rename from api/core/workflow/repositories/workflow_node_execution_repository.py rename to api/dify_graph/repositories/workflow_node_execution_repository.py index 43b41ff6b8..e6c1c3e497 100644 --- a/api/core/workflow/repositories/workflow_node_execution_repository.py +++ b/api/dify_graph/repositories/workflow_node_execution_repository.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Literal, Protocol -from core.workflow.entities import WorkflowNodeExecution +from dify_graph.entities import WorkflowNodeExecution @dataclass diff --git a/api/core/workflow/runtime/__init__.py b/api/dify_graph/runtime/__init__.py similarity index 64% rename from api/core/workflow/runtime/__init__.py rename to api/dify_graph/runtime/__init__.py index 10014c7182..adca07e59a 100644 --- a/api/core/workflow/runtime/__init__.py +++ b/api/dify_graph/runtime/__init__.py @@ -1,9 +1,17 @@ -from .graph_runtime_state import GraphRuntimeState +from .graph_runtime_state import ( + ChildEngineBuilderNotConfiguredError, + ChildEngineError, + ChildGraphNotFoundError, + GraphRuntimeState, +) from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper from .variable_pool import VariablePool, VariableValue __all__ = [ + "ChildEngineBuilderNotConfiguredError", + "ChildEngineError", + "ChildGraphNotFoundError", "GraphRuntimeState", "ReadOnlyGraphRuntimeState", "ReadOnlyGraphRuntimeStateWrapper", diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/dify_graph/runtime/graph_runtime_state.py similarity index 91% rename from api/core/workflow/runtime/graph_runtime_state.py rename to api/dify_graph/runtime/graph_runtime_state.py index 1e8dba750f..0fb3a54ce8 100644 --- a/api/core/workflow/runtime/graph_runtime_state.py +++ b/api/dify_graph/runtime/graph_runtime_state.py @@ -10,13 +10,14 @@ from typing import TYPE_CHECKING, Any, ClassVar, Protocol from pydantic import BaseModel, Field from pydantic.json import pydantic_encoder -from core.model_runtime.entities.llm_entities import LLMUsage from core.sandbox.sandbox import Sandbox -from core.workflow.enums import NodeExecutionType, NodeState, NodeType -from core.workflow.runtime.variable_pool import VariablePool +from dify_graph.enums import NodeExecutionType, NodeState, NodeType +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.runtime.variable_pool import VariablePool if TYPE_CHECKING: - from core.workflow.entities.pause_reason import PauseReason + from dify_graph.entities import GraphInitParams + from dify_graph.entities.pause_reason import PauseReason class ReadyQueueProtocol(Protocol): @@ -136,6 +137,31 @@ class GraphProtocol(Protocol): def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ... +class ChildGraphEngineBuilderProtocol(Protocol): + def build_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> Any: ... + + +class ChildEngineError(ValueError): + """Base error type for child-engine creation failures.""" + + +class ChildEngineBuilderNotConfiguredError(ChildEngineError): + """Raised when child-engine creation is requested without a bound builder.""" + + +class ChildGraphNotFoundError(ChildEngineError): + """Raised when the requested child graph entry point cannot be resolved.""" + + class _GraphStateSnapshot(BaseModel): """Serializable graph state snapshot for node/edge states.""" @@ -210,6 +236,7 @@ class GraphRuntimeState: self._pending_graph_execution_workflow_id: str | None = None self._paused_nodes: set[str] = set() self._deferred_nodes: set[str] = set() + self._child_engine_builder: ChildGraphEngineBuilderProtocol | None = None # Node and edges states needed to be restored into # graph object. @@ -253,6 +280,31 @@ class GraphRuntimeState: if self._graph is not None: _ = self.response_coordinator + def bind_child_engine_builder(self, builder: ChildGraphEngineBuilderProtocol) -> None: + self._child_engine_builder = builder + + def create_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> Any: + if self._child_engine_builder is None: + raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.") + + return self._child_engine_builder.build_child_engine( + workflow_id=workflow_id, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + graph_config=graph_config, + root_node_id=root_node_id, + layers=layers, + ) + # ------------------------------------------------------------------ # Primary collaborators # ------------------------------------------------------------------ @@ -446,13 +498,13 @@ class GraphRuntimeState: # ------------------------------------------------------------------ def _build_ready_queue(self) -> ReadyQueueProtocol: # Import lazily to avoid breaching architecture boundaries enforced by import-linter. - module = importlib.import_module("core.workflow.graph_engine.ready_queue") + module = importlib.import_module("dify_graph.graph_engine.ready_queue") in_memory_cls = module.InMemoryReadyQueue return in_memory_cls() def _build_graph_execution(self) -> GraphExecutionProtocol: # Lazily import to keep the runtime domain decoupled from graph_engine modules. - module = importlib.import_module("core.workflow.graph_engine.domain.graph_execution") + module = importlib.import_module("dify_graph.graph_engine.domain.graph_execution") graph_execution_cls = module.GraphExecution workflow_id = self._pending_graph_execution_workflow_id or "" self._pending_graph_execution_workflow_id = None @@ -460,7 +512,7 @@ class GraphRuntimeState: def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol: # Lazily import to keep the runtime domain decoupled from graph_engine modules. - module = importlib.import_module("core.workflow.graph_engine.response_coordinator") + module = importlib.import_module("dify_graph.graph_engine.response_coordinator") coordinator_cls = module.ResponseStreamCoordinator return coordinator_cls(variable_pool=self.variable_pool, graph=graph) diff --git a/api/core/workflow/runtime/graph_runtime_state_protocol.py b/api/dify_graph/runtime/graph_runtime_state_protocol.py similarity index 92% rename from api/core/workflow/runtime/graph_runtime_state_protocol.py rename to api/dify_graph/runtime/graph_runtime_state_protocol.py index 75b1a170ec..6109325012 100644 --- a/api/core/workflow/runtime/graph_runtime_state_protocol.py +++ b/api/dify_graph/runtime/graph_runtime_state_protocol.py @@ -1,9 +1,9 @@ from collections.abc import Mapping, Sequence from typing import Any, Protocol -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.system_variable import SystemVariableReadOnlyView -from core.workflow.variables.segments import Segment +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.system_variable import SystemVariableReadOnlyView +from dify_graph.variables.segments import Segment class ReadOnlyVariablePool(Protocol): diff --git a/api/core/workflow/runtime/read_only_wrappers.py b/api/dify_graph/runtime/read_only_wrappers.py similarity index 93% rename from api/core/workflow/runtime/read_only_wrappers.py rename to api/dify_graph/runtime/read_only_wrappers.py index e7162811d5..cbda4dcbe4 100644 --- a/api/core/workflow/runtime/read_only_wrappers.py +++ b/api/dify_graph/runtime/read_only_wrappers.py @@ -4,9 +4,9 @@ from collections.abc import Mapping, Sequence from copy import deepcopy from typing import Any -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.system_variable import SystemVariableReadOnlyView -from core.workflow.variables.segments import Segment +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.system_variable import SystemVariableReadOnlyView +from dify_graph.variables.segments import Segment from .graph_runtime_state import GraphRuntimeState from .variable_pool import VariablePool diff --git a/api/core/workflow/runtime/variable_pool.py b/api/dify_graph/runtime/variable_pool.py similarity index 96% rename from api/core/workflow/runtime/variable_pool.py rename to api/dify_graph/runtime/variable_pool.py index bedb52ae5f..4602d690e2 100644 --- a/api/core/workflow/runtime/variable_pool.py +++ b/api/dify_graph/runtime/variable_pool.py @@ -8,18 +8,18 @@ from typing import Annotated, Any, Union, cast from pydantic import BaseModel, Field -from core.workflow.constants import ( +from dify_graph.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, RAG_PIPELINE_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) -from core.workflow.file import File, FileAttribute, file_manager -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import Segment, SegmentGroup, VariableBase -from core.workflow.variables.consts import SELECTORS_LENGTH -from core.workflow.variables.segments import FileSegment, ObjectSegment -from core.workflow.variables.variables import RAGPipelineVariableInput, Variable +from dify_graph.file import File, FileAttribute, file_manager +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import Segment, SegmentGroup, VariableBase +from dify_graph.variables.consts import SELECTORS_LENGTH +from dify_graph.variables.segments import FileSegment, ObjectSegment +from dify_graph.variables.variables import RAGPipelineVariableInput, Variable from factories import variable_factory VariableValue = Union[str, int, float, dict[str, object], list[object], File] diff --git a/api/core/workflow/system_variable.py b/api/dify_graph/system_variable.py similarity index 98% rename from api/core/workflow/system_variable.py rename to api/dify_graph/system_variable.py index 4144f79b8a..cc5deda892 100644 --- a/api/core/workflow/system_variable.py +++ b/api/dify_graph/system_variable.py @@ -7,8 +7,8 @@ from uuid import uuid4 from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator -from core.workflow.enums import SystemVariableKey -from core.workflow.file.models import File +from dify_graph.enums import SystemVariableKey +from dify_graph.file.models import File class SystemVariable(BaseModel): diff --git a/api/core/workflow/utils/condition/__init__.py b/api/dify_graph/utils/__init__.py similarity index 100% rename from api/core/workflow/utils/condition/__init__.py rename to api/dify_graph/utils/__init__.py diff --git a/api/dify_graph/utils/condition/__init__.py b/api/dify_graph/utils/condition/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/utils/condition/entities.py b/api/dify_graph/utils/condition/entities.py similarity index 100% rename from api/core/workflow/utils/condition/entities.py rename to api/dify_graph/utils/condition/entities.py diff --git a/api/core/workflow/utils/condition/processor.py b/api/dify_graph/utils/condition/processor.py similarity index 98% rename from api/core/workflow/utils/condition/processor.py rename to api/dify_graph/utils/condition/processor.py index 4e635cc2f2..dea72d96c2 100644 --- a/api/core/workflow/utils/condition/processor.py +++ b/api/dify_graph/utils/condition/processor.py @@ -2,10 +2,10 @@ import json from collections.abc import Mapping, Sequence from typing import Literal, NamedTuple -from core.workflow.file import FileAttribute, file_manager -from core.workflow.runtime import VariablePool -from core.workflow.variables import ArrayFileSegment -from core.workflow.variables.segments import ArrayBooleanSegment, BooleanSegment +from dify_graph.file import FileAttribute, file_manager +from dify_graph.runtime import VariablePool +from dify_graph.variables import ArrayFileSegment +from dify_graph.variables.segments import ArrayBooleanSegment, BooleanSegment from .entities import Condition, SubCondition, SupportedComparisonOperator diff --git a/api/core/workflow/variable_loader.py b/api/dify_graph/variable_loader.py similarity index 95% rename from api/core/workflow/variable_loader.py rename to api/dify_graph/variable_loader.py index dfa4ce2e75..d263450334 100644 --- a/api/core/workflow/variable_loader.py +++ b/api/dify_graph/variable_loader.py @@ -2,9 +2,9 @@ import abc from collections.abc import Mapping, Sequence from typing import Any, Protocol -from core.workflow.runtime import VariablePool -from core.workflow.variables import VariableBase -from core.workflow.variables.consts import SELECTORS_LENGTH +from dify_graph.runtime import VariablePool +from dify_graph.variables import VariableBase +from dify_graph.variables.consts import SELECTORS_LENGTH class VariableLoader(Protocol): diff --git a/api/core/workflow/variables/__init__.py b/api/dify_graph/variables/__init__.py similarity index 100% rename from api/core/workflow/variables/__init__.py rename to api/dify_graph/variables/__init__.py diff --git a/api/core/workflow/variables/consts.py b/api/dify_graph/variables/consts.py similarity index 100% rename from api/core/workflow/variables/consts.py rename to api/dify_graph/variables/consts.py diff --git a/api/core/workflow/variables/exc.py b/api/dify_graph/variables/exc.py similarity index 100% rename from api/core/workflow/variables/exc.py rename to api/dify_graph/variables/exc.py diff --git a/api/core/workflow/variables/input_entities.py b/api/dify_graph/variables/input_entities.py similarity index 97% rename from api/core/workflow/variables/input_entities.py rename to api/dify_graph/variables/input_entities.py index 9a42012f0a..e6a68ea359 100644 --- a/api/core/workflow/variables/input_entities.py +++ b/api/dify_graph/variables/input_entities.py @@ -5,7 +5,7 @@ from typing import Any from jsonschema import Draft7Validator, SchemaError from pydantic import BaseModel, Field, field_validator -from core.workflow.file import FileTransferMethod, FileType +from dify_graph.file import FileTransferMethod, FileType class VariableEntityType(StrEnum): diff --git a/api/core/workflow/variables/segment_group.py b/api/dify_graph/variables/segment_group.py similarity index 100% rename from api/core/workflow/variables/segment_group.py rename to api/dify_graph/variables/segment_group.py diff --git a/api/core/workflow/variables/segments.py b/api/dify_graph/variables/segments.py similarity index 98% rename from api/core/workflow/variables/segments.py rename to api/dify_graph/variables/segments.py index f266d92e2e..8060fb573f 100644 --- a/api/core/workflow/variables/segments.py +++ b/api/dify_graph/variables/segments.py @@ -5,8 +5,8 @@ from typing import Annotated, Any, TypeAlias from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator -from core.model_runtime.entities import PromptMessage -from core.workflow.file import File +from dify_graph.file import File +from dify_graph.model_runtime.entities import PromptMessage from .types import SegmentType diff --git a/api/core/workflow/variables/types.py b/api/dify_graph/variables/types.py similarity index 99% rename from api/core/workflow/variables/types.py rename to api/dify_graph/variables/types.py index 0f979dcf25..b295edd6e2 100644 --- a/api/core/workflow/variables/types.py +++ b/api/dify_graph/variables/types.py @@ -4,7 +4,7 @@ from collections.abc import Mapping from enum import StrEnum from typing import TYPE_CHECKING, Any -from core.workflow.file.models import File +from dify_graph.file.models import File if TYPE_CHECKING: pass diff --git a/api/core/workflow/variables/utils.py b/api/dify_graph/variables/utils.py similarity index 100% rename from api/core/workflow/variables/utils.py rename to api/dify_graph/variables/utils.py diff --git a/api/core/workflow/variables/variables.py b/api/dify_graph/variables/variables.py similarity index 100% rename from api/core/workflow/variables/variables.py rename to api/dify_graph/variables/variables.py diff --git a/api/core/workflow/workflow_type_encoder.py b/api/dify_graph/workflow_type_encoder.py similarity index 95% rename from api/core/workflow/workflow_type_encoder.py rename to api/dify_graph/workflow_type_encoder.py index a192b884f7..3dd846b3cb 100644 --- a/api/core/workflow/workflow_type_encoder.py +++ b/api/dify_graph/workflow_type_encoder.py @@ -4,8 +4,8 @@ from typing import Any, overload from pydantic import BaseModel -from core.workflow.file.models import File -from core.workflow.variables import Segment +from dify_graph.file.models import File +from dify_graph.variables import Segment class WorkflowRuntimeTypeConverter: diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index bac2fbef47..5c02a16a7d 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -2,8 +2,8 @@ import logging from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager -from core.workflow.nodes import NodeType -from core.workflow.nodes.tool.entities import ToolEntity +from dify_graph.nodes import NodeType +from dify_graph.nodes.tool.entities import ToolEntity from events.app_event import app_draft_workflow_was_synced logger = logging.getLogger(__name__) diff --git a/api/events/event_handlers/sync_workflow_schedule_when_app_published.py b/api/events/event_handlers/sync_workflow_schedule_when_app_published.py index 168513fc04..90f562d167 100644 --- a/api/events/event_handlers/sync_workflow_schedule_when_app_published.py +++ b/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @@ -4,7 +4,7 @@ from typing import cast from sqlalchemy import select from sqlalchemy.orm import Session -from core.workflow.nodes.trigger_schedule.entities import SchedulePlanUpdate +from dify_graph.nodes.trigger_schedule.entities import SchedulePlanUpdate from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db from models import AppMode, Workflow, WorkflowSchedulePlan diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 53e0065f6e..8da33d03b9 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -2,8 +2,8 @@ from typing import cast from sqlalchemy import select -from core.workflow.nodes import NodeType -from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from dify_graph.nodes import NodeType +from dify_graph.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db from models.dataset import AppDatasetJoin diff --git a/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py index 430514ada2..fd211a3e55 100644 --- a/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @@ -3,7 +3,7 @@ from typing import cast from sqlalchemy import select from sqlalchemy.orm import Session -from core.workflow.nodes import NodeType +from dify_graph.nodes import NodeType from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db from models import AppMode diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 658e6a0738..26262484f9 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -18,6 +18,7 @@ from dify_app import DifyApp from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel +from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel if TYPE_CHECKING: from redis.lock import Lock @@ -181,13 +182,18 @@ def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, sentinel_hosts = [(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")] + sentinel_kwargs = { + "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, + "username": dify_config.REDIS_SENTINEL_USERNAME, + "password": dify_config.REDIS_SENTINEL_PASSWORD, + } + + if dify_config.REDIS_MAX_CONNECTIONS: + sentinel_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS + sentinel = Sentinel( sentinel_hosts, - sentinel_kwargs={ - "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, - "username": dify_config.REDIS_SENTINEL_USERNAME, - "password": dify_config.REDIS_SENTINEL_PASSWORD, - }, + sentinel_kwargs=sentinel_kwargs, ) master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) @@ -204,12 +210,15 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]: for node in dify_config.REDIS_CLUSTERS.split(",") ] - cluster: RedisCluster = RedisCluster( - startup_nodes=nodes, - password=dify_config.REDIS_CLUSTERS_PASSWORD, - protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL, - cache_config=_get_cache_configuration(), - ) + cluster_kwargs: dict[str, Any] = { + "startup_nodes": nodes, + "password": dify_config.REDIS_CLUSTERS_PASSWORD, + "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, + "cache_config": _get_cache_configuration(), + } + if dify_config.REDIS_MAX_CONNECTIONS: + cluster_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS + cluster: RedisCluster = RedisCluster(**cluster_kwargs) return cluster @@ -225,6 +234,9 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis } ) + if dify_config.REDIS_MAX_CONNECTIONS: + redis_params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS + if ssl_kwargs: redis_params.update(ssl_kwargs) @@ -234,9 +246,17 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster: + max_conns = dify_config.REDIS_MAX_CONNECTIONS if use_clusters: - return RedisCluster.from_url(pubsub_url) - return redis.Redis.from_url(pubsub_url) + if max_conns: + return RedisCluster.from_url(pubsub_url, max_connections=max_conns) + else: + return RedisCluster.from_url(pubsub_url) + + if max_conns: + return redis.Redis.from_url(pubsub_url, max_connections=max_conns) + else: + return redis.Redis.from_url(pubsub_url) def init_app(app: DifyApp): @@ -269,6 +289,11 @@ def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol: assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here." if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded": return ShardedRedisBroadcastChannel(_pubsub_redis_client) + if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams": + return StreamsBroadcastChannel( + _pubsub_redis_client, + retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS, + ) return RedisBroadcastChannel(_pubsub_redis_client) diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index c3aa8edf80..9a34acb0c1 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -10,7 +10,7 @@ def init_app(app: DifyApp): from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException - from core.model_runtime.errors.invoke import InvokeRateLimitError + from dify_graph.model_runtime.errors.invoke import InvokeRateLimitError def before_send(event, hint): if "exc_info" in hint: diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index 817c8b0448..7ee4638e77 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -13,7 +13,7 @@ from typing import Any from sqlalchemy.orm import sessionmaker -from core.workflow.enums import WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index 9928879a7b..c58aa6adbb 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -8,9 +8,9 @@ from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from core.workflow.entities import WorkflowExecution -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from dify_graph.entities import WorkflowExecution +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore from libs.helper import extract_tenant_id from models import ( diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index 4897171b12..bd1c08d96e 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -16,13 +16,13 @@ from typing import Any, Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from core.model_runtime.utils.encoders import jsonable_encoder from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.entities import WorkflowNodeExecution -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.enums import NodeType -from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.enums import NodeType +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index 66d1c977d6..fc84147e01 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -9,11 +9,11 @@ from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel -from core.workflow.enums import NodeType -from core.workflow.file.models import File -from core.workflow.graph_events import GraphNodeEventBase -from core.workflow.nodes.base.node import Node -from core.workflow.variables import Segment +from dify_graph.enums import NodeType +from dify_graph.file.models import File +from dify_graph.graph_events import GraphNodeEventBase +from dify_graph.nodes.base.node import Node +from dify_graph.variables import Segment from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py index 8556974080..3da9a9e97d 100644 --- a/api/extensions/otel/parser/llm.py +++ b/api/extensions/otel/parser/llm.py @@ -8,8 +8,8 @@ from typing import Any from opentelemetry.trace import Span -from core.workflow.graph_events import GraphNodeEventBase -from core.workflow.nodes.base.node import Node +from dify_graph.graph_events import GraphNodeEventBase +from dify_graph.nodes.base.node import Node from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import LLMAttributes diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py index 82cb865b8b..dd658b250b 100644 --- a/api/extensions/otel/parser/retrieval.py +++ b/api/extensions/otel/parser/retrieval.py @@ -8,9 +8,9 @@ from typing import Any from opentelemetry.trace import Span -from core.workflow.graph_events import GraphNodeEventBase -from core.workflow.nodes.base.node import Node -from core.workflow.variables import Segment +from dify_graph.graph_events import GraphNodeEventBase +from dify_graph.nodes.base.node import Node +from dify_graph.variables import Segment from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import RetrieverAttributes diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py index b99180722b..f4e6a18b4d 100644 --- a/api/extensions/otel/parser/tool.py +++ b/api/extensions/otel/parser/tool.py @@ -4,10 +4,10 @@ Parser for tool nodes that captures tool-specific metadata. from opentelemetry.trace import Span -from core.workflow.enums import WorkflowNodeExecutionMetadataKey -from core.workflow.graph_events import GraphNodeEventBase -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.tool.entities import ToolNodeData +from dify_graph.enums import WorkflowNodeExecutionMetadataKey +from dify_graph.graph_events import GraphNodeEventBase +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.tool.entities import ToolNodeData from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import ToolAttributes diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index f534f9e79a..e594a66a38 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -14,7 +14,7 @@ from werkzeug.http import parse_options_header from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS from core.helper import ssrf_proxy -from core.workflow.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers +from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers from extensions.ext_database import db from models import MessageFile, ToolFile, UploadFile diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index a0a1812f4e..2ec1e31c8b 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -3,21 +3,21 @@ from typing import Any, cast from uuid import uuid4 from configs import dify_config -from core.model_runtime.entities import PromptMessage -from core.model_runtime.entities.message_entities import ( +from dify_graph.constants import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, +) +from dify_graph.file import File +from dify_graph.model_runtime.entities import PromptMessage +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessageRole, SystemPromptMessage, ToolPromptMessage, UserPromptMessage, ) -from core.workflow.constants import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, -) -from core.workflow.file import File -from core.workflow.variables.exc import VariableError -from core.workflow.variables.segments import ( +from dify_graph.variables.exc import VariableError +from dify_graph.variables.segments import ( ArrayAnySegment, ArrayBooleanSegment, ArrayFileSegment, @@ -35,8 +35,8 @@ from core.workflow.variables.segments import ( Segment, StringSegment, ) -from core.workflow.variables.types import SegmentType -from core.workflow.variables.variables import ( +from dify_graph.variables.types import SegmentType +from dify_graph.variables.variables import ( ArrayAnyVariable, ArrayBooleanVariable, ArrayFileVariable, diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py index 461c163e2f..ac7c5376fb 100644 --- a/api/fields/_value_type_serializer.py +++ b/api/fields/_value_type_serializer.py @@ -1,7 +1,7 @@ from typing import TypedDict -from core.workflow.variables.segments import Segment -from core.workflow.variables.types import SegmentType +from dify_graph.variables.segments import Segment +from dify_graph.variables.types import SegmentType class _VarTypedDict(TypedDict, total=False): diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 9876b1aba6..54f787c2d5 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -5,7 +5,7 @@ from typing import Any, TypeAlias from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from core.workflow.file import File +from dify_graph.file import File JSONValue: TypeAlias = Any diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 29b9e40242..7ee628726b 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -5,7 +5,7 @@ from datetime import datetime from flask_restx import fields from pydantic import BaseModel, ConfigDict, computed_field, field_validator -from core.workflow.file import helpers as file_helpers +from dify_graph.file import helpers as file_helpers simple_account_fields = { "id": fields.String, diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 38d04f2435..91c8c788d6 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -7,7 +7,7 @@ from uuid import uuid4 from pydantic import BaseModel, ConfigDict, Field, field_validator from core.entities.execution_extra_content import ExecutionExtraContentDomainModel -from core.workflow.file import File +from dify_graph.file import File from fields.conversation_fields import AgentThought, JSONValue, MessageFile JSONValueType: TypeAlias = JSONValue diff --git a/api/fields/raws.py b/api/fields/raws.py index 33b47ba2c3..318dedc25c 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,6 +1,6 @@ from flask_restx import fields -from core.workflow.file import File +from dify_graph.file import File class FilesContainedField(fields.Raw): diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 019949e105..7ce2139687 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,7 +1,7 @@ from flask_restx import fields from core.helper import encrypter -from core.workflow.variables import SecretVariable, SegmentType, VariableBase +from dify_graph.variables import SecretVariable, SegmentType, VariableBase from fields.member_fields import simple_account_fields from libs.helper import TimestampField diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py new file mode 100644 index 0000000000..d6ec5504ca --- /dev/null +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import logging +import queue +import threading +from collections.abc import Iterator +from typing import Self + +from libs.broadcast_channel.channel import Producer, Subscriber, Subscription +from libs.broadcast_channel.exc import SubscriptionClosedError +from redis import Redis, RedisCluster + +logger = logging.getLogger(__name__) + + +class StreamsBroadcastChannel: + """ + Redis Streams based broadcast channel implementation. + + Characteristics: + - At-least-once delivery for late subscribers within the stream retention window. + - Each topic is stored as a dedicated Redis Stream key. + - The stream key expires `retention_seconds` after the last event is published (to bound storage). + """ + + def __init__(self, redis_client: Redis | RedisCluster, *, retention_seconds: int = 600): + self._client = redis_client + self._retention_seconds = max(int(retention_seconds or 0), 0) + + def topic(self, topic: str) -> StreamsTopic: + return StreamsTopic(self._client, topic, retention_seconds=self._retention_seconds) + + +class StreamsTopic: + def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600): + self._client = redis_client + self._topic = topic + self._key = f"stream:{topic}" + self._retention_seconds = retention_seconds + self.max_length = 5000 + + def as_producer(self) -> Producer: + return self + + def publish(self, payload: bytes) -> None: + self._client.xadd(self._key, {b"data": payload}, maxlen=self.max_length) + if self._retention_seconds > 0: + try: + self._client.expire(self._key, self._retention_seconds) + except Exception as e: + logger.warning("Failed to set expire for stream key %s: %s", self._key, e, exc_info=True) + + def as_subscriber(self) -> Subscriber: + return self + + def subscribe(self) -> Subscription: + return _StreamsSubscription(self._client, self._key) + + +class _StreamsSubscription(Subscription): + _SENTINEL = object() + + def __init__(self, client: Redis | RedisCluster, key: str): + self._client = client + self._key = key + self._closed = threading.Event() + self._last_id = "0-0" + self._queue: queue.Queue[object] = queue.Queue() + self._start_lock = threading.Lock() + self._listener: threading.Thread | None = None + + def _listen(self) -> None: + try: + while not self._closed.is_set(): + streams = self._client.xread({self._key: self._last_id}, block=1000, count=100) + + if not streams: + continue + + for _key, entries in streams: + for entry_id, fields in entries: + data = None + if isinstance(fields, dict): + data = fields.get(b"data") + data_bytes: bytes | None = None + if isinstance(data, str): + data_bytes = data.encode() + elif isinstance(data, (bytes, bytearray)): + data_bytes = bytes(data) + if data_bytes is not None: + self._queue.put_nowait(data_bytes) + self._last_id = entry_id + finally: + self._queue.put_nowait(self._SENTINEL) + self._listener = None + + def _start_if_needed(self) -> None: + if self._listener is not None: + return + # Ensure only one listener thread is created under concurrent calls + with self._start_lock: + if self._listener is not None or self._closed.is_set(): + return + self._listener = threading.Thread( + target=self._listen, + name=f"redis-streams-sub-{self._key}", + daemon=True, + ) + self._listener.start() + + def __iter__(self) -> Iterator[bytes]: + # Iterator delegates to receive with timeout; stops on closure. + self._start_if_needed() + while not self._closed.is_set(): + item = self.receive(timeout=1) + if item is not None: + yield item + + def receive(self, timeout: float | None = 0.1) -> bytes | None: + if self._closed.is_set(): + raise SubscriptionClosedError("The Redis streams subscription is closed") + self._start_if_needed() + + try: + if timeout is None: + item = self._queue.get() + else: + item = self._queue.get(timeout=timeout) + except queue.Empty: + return None + + if item is self._SENTINEL or self._closed.is_set(): + raise SubscriptionClosedError("The Redis streams subscription is closed") + assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue" + return bytes(item) + + def close(self) -> None: + if self._closed.is_set(): + return + self._closed.set() + listener = self._listener + if listener is not None: + listener.join(timeout=2.0) + if listener.is_alive(): + logger.warning( + "Streams subscription listener for key %s did not stop within timeout; keeping reference.", + self._key, + ) + else: + self._listener = None + + # Context manager helpers + def __enter__(self) -> Self: + self._start_if_needed() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> bool | None: + self.close() + return None diff --git a/api/libs/helper.py b/api/libs/helper.py index 206bb8fd81..6151eb0940 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -21,8 +21,8 @@ from pydantic.functional_validators import AfterValidator from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator -from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.file import helpers as file_helpers +from dify_graph.file import helpers as file_helpers +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_redis import redis_client if TYPE_CHECKING: diff --git a/api/migrations/versions/2026_02_26_1336-e288952f2994_add_partial_indexes_on_conversations_.py b/api/migrations/versions/2026_02_26_1336-e288952f2994_add_partial_indexes_on_conversations_.py new file mode 100644 index 0000000000..ed794178b3 --- /dev/null +++ b/api/migrations/versions/2026_02_26_1336-e288952f2994_add_partial_indexes_on_conversations_.py @@ -0,0 +1,37 @@ +"""add partial indexes on conversations for app_id with created_at and updated_at + +Revision ID: e288952f2994 +Revises: fce013ca180e +Create Date: 2026-02-26 13:36:45.928922 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = 'e288952f2994' +down_revision = 'fce013ca180e' +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.create_index( + 'conversation_app_created_at_idx', + ['app_id', sa.literal_column('created_at DESC')], + unique=False, + postgresql_where=sa.text('is_deleted IS false'), + ) + batch_op.create_index( + 'conversation_app_updated_at_idx', + ['app_id', sa.literal_column('updated_at DESC')], + unique=False, + postgresql_where=sa.text('is_deleted IS false'), + ) + + +def downgrade(): + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.drop_index('conversation_app_updated_at_idx') + batch_op.drop_index('conversation_app_created_at_idx') diff --git a/api/models/__init__.py b/api/models/__init__.py index 6b9d509482..c5dbb250a2 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -36,7 +36,6 @@ from .enums import ( AppTriggerStatus, AppTriggerType, CreatorUserRole, - UserFrom, WorkflowRunTriggeredFrom, WorkflowTriggerStatus, ) @@ -218,7 +217,6 @@ __all__ = [ "TriggerOAuthTenantClient", "TriggerSubscription", "UploadFile", - "UserFrom", "Whitelist", "Workflow", "WorkflowAppLog", diff --git a/api/models/dataset.py b/api/models/dataset.py index e7da2961bc..4ef39fcde1 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -19,6 +19,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.constant.query_type import QueryType from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file @@ -51,6 +52,7 @@ class Dataset(Base): INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None] PROVIDER_LIST = ["vendor", "external", None] + DOC_FORM_LIST = [member.value for member in IndexStructureType] id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) tenant_id: Mapped[str] = mapped_column(StringUUID) diff --git a/api/models/enums.py b/api/models/enums.py index 2bc61120ce..ed6236209f 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -1,6 +1,6 @@ from enum import StrEnum -from core.workflow.enums import NodeType +from dify_graph.enums import NodeType class CreatorUserRole(StrEnum): @@ -8,11 +8,6 @@ class CreatorUserRole(StrEnum): END_USER = "end_user" -class UserFrom(StrEnum): - ACCOUNT = "account" - END_USER = "end-user" - - class WorkflowRunTriggeredFrom(StrEnum): DEBUGGING = "debugging" APP_RUN = "app-run" # webapp / service api diff --git a/api/models/human_input.py b/api/models/human_input.py index 5208461de1..709cc8fe61 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -6,7 +6,7 @@ import sqlalchemy as sa from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship -from core.workflow.nodes.human_input.enums import ( +from dify_graph.nodes.human_input.enums import ( DeliveryMethodType, HumanInputFormKind, HumanInputFormStatus, diff --git a/api/models/model.py b/api/models/model.py index fa538d107a..30a09d7e73 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -19,9 +19,9 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.file import FILE_MODEL_IDENTITY, File, FileTransferMethod -from core.workflow.file import helpers as file_helpers +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from dify_graph.file import helpers as file_helpers from libs.helper import generate_string # type: ignore[import-not-found] from libs.uuid_utils import uuidv7 @@ -713,6 +713,18 @@ class Conversation(Base): __table_args__ = ( sa.PrimaryKeyConstraint("id", name="conversation_pkey"), sa.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), + sa.Index( + "conversation_app_created_at_idx", + "app_id", + sa.text("created_at DESC"), + postgresql_where=sa.text("is_deleted IS false"), + ), + sa.Index( + "conversation_app_updated_at_idx", + "app_id", + sa.text("updated_at DESC"), + postgresql_where=sa.text("is_deleted IS false"), + ), ) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) diff --git a/api/models/workflow.py b/api/models/workflow.py index b4eed0caef..b445f0eee2 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -22,17 +22,17 @@ from sqlalchemy import ( from sqlalchemy.orm import Mapped, declared_attr, mapped_column from typing_extensions import deprecated -from core.workflow.constants import ( +from dify_graph.constants import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) -from core.workflow.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from core.workflow.enums import NodeType, WorkflowExecutionStatus -from core.workflow.file.constants import maybe_file_object -from core.workflow.file.models import File -from core.workflow.variables import utils as variable_utils -from core.workflow.variables.variables import FloatVariable, IntegerVariable, StringVariable +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from dify_graph.enums import NodeType, WorkflowExecutionStatus +from dify_graph.file.constants import maybe_file_object +from dify_graph.file.models import File +from dify_graph.variables import utils as variable_utils +from dify_graph.variables.variables import FloatVariable, IntegerVariable, StringVariable from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type from libs.datetime_utils import naive_utc_now @@ -46,7 +46,7 @@ if TYPE_CHECKING: from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter -from core.workflow.variables import SecretVariable, Segment, SegmentType, VariableBase +from dify_graph.variables import SecretVariable, Segment, SegmentType, VariableBase from factories import variable_factory from libs import helper @@ -379,7 +379,7 @@ class Workflow(Base): # bug "selected": false, } - For specific node type, refer to `core.workflow.nodes` + For specific node type, refer to `dify_graph.nodes` """ graph_dict = self.graph_dict if "nodes" not in graph_dict: @@ -1383,7 +1383,7 @@ class WorkflowDraftVariable(Base): # From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than # 80 chars. # - # ref: api/core/workflow/entities/variable_pool.py:18 + # ref: api/dify_graph/entities/variable_pool.py:18 name: Mapped[str] = mapped_column(sa.String(255), nullable=False) description: Mapped[str] = mapped_column( sa.String(255), diff --git a/api/pyproject.toml b/api/pyproject.toml index 8dc256c3fe..d90a9b7db9 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -246,7 +246,7 @@ module = [ "configs.middleware.cache.redis_pubsub_config", "extensions.ext_redis", "tasks.workflow_execution_tasks", - "core.workflow.nodes.base.node", + "dify_graph.nodes.base.node", "services.human_input_delivery_test_service", "core.app.apps.advanced_chat.app_generator", "controllers.console.human_input_form", diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 6446eb0d6e..2fa065bcc8 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -16,7 +16,7 @@ from typing import Protocol from sqlalchemy.orm import Session -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index ffa87b209f..a96c4acb31 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -40,9 +40,9 @@ from typing import Protocol from sqlalchemy.orm import Session -from core.workflow.entities.pause_reason import PauseReason -from core.workflow.enums import WorkflowType -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.entities.pause_reason import PauseReason +from dify_graph.enums import WorkflowType +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun diff --git a/api/repositories/entities/workflow_pause.py b/api/repositories/entities/workflow_pause.py index a3c4039aaa..be28b7e613 100644 --- a/api/repositories/entities/workflow_pause.py +++ b/api/repositories/entities/workflow_pause.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from datetime import datetime -from core.workflow.entities.pause_reason import PauseReason +from dify_graph.entities.pause_reason import PauseReason class WorkflowPauseEntity(ABC): diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 6c696b6478..2266c2e646 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -14,7 +14,7 @@ from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker -from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload from repositories.api_workflow_node_execution_repository import ( DifyAPIWorkflowNodeExecutionRepository, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 5ba7a7e7e8..fdd3e123e4 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -33,9 +33,9 @@ from sqlalchemy import and_, delete, func, null, or_, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker -from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from core.workflow.enums import WorkflowExecutionStatus, WorkflowType -from core.workflow.nodes.human_input.entities import FormDefinition +from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from dify_graph.enums import WorkflowExecutionStatus, WorkflowType +from dify_graph.nodes.human_input.entities import FormDefinition from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py index 5a2c0ea46f..508db22eb0 100644 --- a/api/repositories/sqlalchemy_execution_extra_content_repository.py +++ b/api/repositories/sqlalchemy_execution_extra_content_repository.py @@ -18,9 +18,9 @@ from core.entities.execution_extra_content import ( from core.entities.execution_extra_content import ( HumanInputContent as HumanInputContentDomainModel, ) -from core.workflow.nodes.human_input.entities import FormDefinition -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.human_input.entities import FormDefinition +from dify_graph.nodes.human_input.enums import HumanInputFormStatus +from dify_graph.nodes.human_input.human_input_node import HumanInputNode from models.execution_extra_content import ( ExecutionExtraContent as ExecutionExtraContentModel, ) diff --git a/api/services/account_service.py b/api/services/account_service.py index 648b5e834f..f0eac2a522 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -74,6 +74,16 @@ from tasks.mail_reset_password_task import ( logger = logging.getLogger(__name__) +def _try_join_enterprise_default_workspace(account_id: str) -> None: + """Best-effort join to enterprise default workspace.""" + if not dify_config.ENTERPRISE_ENABLED: + return + + from services.enterprise.enterprise_service import try_join_default_workspace + + try_join_default_workspace(account_id) + + class TokenPair(BaseModel): access_token: str refresh_token: str @@ -287,13 +297,14 @@ class AccountService: email=email, name=name, interface_language=interface_language, password=password ) - TenantService.create_owner_tenant_if_not_exist(account=account) + try: + TenantService.create_owner_tenant_if_not_exist(account=account) + except Exception: + # Enterprise-only side-effect should run independently from personal workspace creation. + _try_join_enterprise_default_workspace(str(account.id)) + raise - # Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace). - if dify_config.ENTERPRISE_ENABLED: - from services.enterprise.enterprise_service import try_join_default_workspace - - try_join_default_workspace(str(account.id)) + _try_join_enterprise_default_workspace(str(account.id)) return account @@ -1407,18 +1418,18 @@ class RegisterService: and create_workspace_required and FeatureService.get_system_features().license.workspaces.is_available() ): - tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role="owner") - account.current_tenant = tenant - tenant_was_created.send(tenant) + try: + tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + TenantService.create_tenant_member(tenant, account, role="owner") + account.current_tenant = tenant + tenant_was_created.send(tenant) + except Exception: + _try_join_enterprise_default_workspace(str(account.id)) + raise db.session.commit() - # Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace). - if dify_config.ENTERPRISE_ENABLED: - from services.enterprise.enterprise_service import try_join_default_workspace - - try_join_default_workspace(str(account.id)) + _try_join_enterprise_default_workspace(str(account.id)) except WorkSpaceNotAllowedCreateError: db.session.rollback() logger.exception("Register failed") diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 9400362605..5790c8b9ec 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -18,15 +18,15 @@ from sqlalchemy.orm import Session from configs import dify_config from core.helper import ssrf_proxy -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import PluginDependency -from core.workflow.enums import NodeType -from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData -from core.workflow.nodes.llm.entities import LLMNodeData -from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData -from core.workflow.nodes.tool.entities import ToolNodeData -from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode +from dify_graph.enums import NodeType +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from dify_graph.nodes.llm.entities import LLMNodeData +from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData +from dify_graph.nodes.tool.entities import ToolNodeData +from dify_graph.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_redis import redis_client from factories import variable_factory diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 31003cb8f7..40013f2b66 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -38,6 +38,13 @@ if TYPE_CHECKING: class AppGenerateService: @staticmethod def _build_streaming_task_on_subscribe(start_task: Callable[[], None]) -> Callable[[], None]: + """ + Build a subscription callback that coordinates when the background task starts. + + - streams transport: start immediately (events are durable; late subscribers can replay). + - pubsub/sharded transport: start on first subscribe, with a short fallback timer so the task + still runs if the client never connects. + """ started = False lock = threading.Lock() @@ -54,10 +61,18 @@ class AppGenerateService: started = True return True - # XXX(QuantumGhost): dirty hacks to avoid a race between publisher and SSE subscriber. - # The Celery task may publish the first event before the API side actually subscribes, - # causing an "at most once" drop with Redis Pub/Sub. We start the task on subscribe, - # but also use a short fallback timer so the task still runs if the client never consumes. + channel_type = dify_config.PUBSUB_REDIS_CHANNEL_TYPE + if channel_type == "streams": + # With Redis Streams, we can safely start right away; consumers can read past events. + _try_start() + + # Keep return type Callable[[], None] consistent while allowing an extra (no-op) call. + def _on_subscribe_streams() -> None: + _try_start() + + return _on_subscribe_streams + + # Pub/Sub modes (at-most-once): subscribe-gated start with a tiny fallback. timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start) timer.daemon = True timer.start() diff --git a/api/services/app_service.py b/api/services/app_service.py index e57253f8b6..ce6826ef5c 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -10,10 +10,10 @@ from constants.model_template import default_app_templates from core.agent.entities import AgentToolEntity from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from events.app_event import app_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now diff --git a/api/services/app_task_service.py b/api/services/app_task_service.py index 5ae1fba2e8..d556230044 100644 --- a/api/services/app_task_service.py +++ b/api/services/app_task_service.py @@ -7,7 +7,7 @@ new GraphEngine command channel mechanism. from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.graph_engine.manager import GraphEngineManager +from dify_graph.graph_engine.manager import GraphEngineManager from extensions.ext_redis import redis_client from models.model import AppMode diff --git a/api/services/audio_service.py b/api/services/audio_service.py index a95361cebd..1b698fad17 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -8,7 +8,7 @@ from werkzeug.datastructures import FileStorage from constants import AUDIO_EXTENSIONS from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.enums import MessageStatus from models.model import App, AppMode, Message diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index aefc34fcae..0e0eab00ad 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -10,7 +10,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_storage import storage diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 4c87150cf7..566c27c0f3 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -10,7 +10,7 @@ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator -from core.workflow.variables.types import SegmentType +from dify_graph.variables.types import SegmentType from extensions.ext_database import db from factories import variable_factory from libs.datetime_utils import naive_utc_now diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index b0012d6f6a..f00e3fe01e 100644 --- a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,7 +1,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker -from core.workflow.variables.variables import VariableBase +from dify_graph.variables.variables import VariableBase from models import ConversationVariable diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 35b20f7601..3a7d483a9d 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -20,12 +20,12 @@ from core.db.session_factory import session_factory from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.helper.name_generator import generate_incremental_name from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelFeature, ModelType -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.workflow.file import helpers as file_helpers +from dify_graph.file import helpers as file_helpers +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType +from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from enums.cloud_plan import CloudPlan from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index eeb14072bd..95a50f0512 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -10,11 +10,11 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.helper import encrypter from core.helper.name_generator import generate_incremental_name from core.helper.provider_cache import NoOpProviderCredentialCache -from core.model_runtime.entities.provider_entities import FormType from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.impl.oauth import OAuthHandler from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter +from dify_graph.model_runtime.entities.provider_entities import FormType from extensions.ext_database import db from extensions.ext_redis import redis_client from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 8dc5b93501..66309f0e59 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -1,8 +1,9 @@ from enum import StrEnum from typing import Literal -from pydantic import BaseModel +from pydantic import BaseModel, field_validator +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -127,6 +128,18 @@ class KnowledgeConfig(BaseModel): name: str | None = None is_multimodal: bool = False + @field_validator("doc_form") + @classmethod + def validate_doc_form(cls, value: str) -> str: + valid_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] + if value not in valid_forms: + raise ValueError("Invalid doc_form.") + return value + class SegmentCreateArgs(BaseModel): content: str | None = None diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index a29d848ac5..9dd595f516 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -15,9 +15,9 @@ from core.entities.provider_entities import ( QuotaConfiguration, UnaddedModelConfiguration, ) -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ( +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ( ConfigurateMethod, ModelCredentialSchema, ProviderCredentialSchema, diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 65dd41af43..4cf42b7f44 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -9,7 +9,7 @@ from sqlalchemy import select from constants import HIDDEN_VALUE from core.helper import ssrf_proxy from core.rag.entities.metadata_entities import MetadataCondition -from core.workflow.nodes.http_request.exc import InvalidHttpMethodError +from dify_graph.nodes.http_request.exc import InvalidHttpMethodError from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import ( diff --git a/api/services/file_service.py b/api/services/file_service.py index da99a66bb9..e08b78bf4c 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -20,7 +20,7 @@ from constants import ( VIDEO_EXTENSIONS, ) from core.rag.extractor.extract_processor import ExtractProcessor -from core.workflow.file import helpers as file_helpers +from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 8cbf3a25c3..c00c76a826 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -4,12 +4,12 @@ import time from typing import Any from core.app.app_config.entities import ModelConfig -from core.model_runtime.entities import LLMMode from core.rag.datasource.retrieval_service import RetrievalService from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod +from dify_graph.model_runtime.entities import LLMMode from extensions.ext_database import db from models import Account from models.dataset import Dataset, DatasetQuery diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index ff37ff098f..7b43c49686 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -8,14 +8,14 @@ from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker from configs import dify_config -from core.workflow.nodes.human_input.entities import ( +from dify_graph.nodes.human_input.entities import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, ExternalRecipient, MemberRecipient, ) -from core.workflow.runtime import VariablePool +from dify_graph.runtime import VariablePool from extensions.ext_database import db from extensions.ext_mail import mail from libs.email_template_renderer import render_email_template diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index 87816643f6..2e74c50963 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -11,12 +11,12 @@ from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) -from core.workflow.nodes.human_input.entities import ( +from dify_graph.nodes.human_input.entities import ( FormDefinition, HumanInputSubmissionValidationError, validate_human_input_submission, ) -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.exception import BaseHTTPException from models.human_input import RecipientType @@ -130,7 +130,7 @@ class HumanInputService: if isinstance(session_factory, Engine): session_factory = sessionmaker(bind=session_factory) self._session_factory = session_factory - self._form_repository = form_repository or HumanInputFormSubmissionRepository(session_factory) + self._form_repository = form_repository or HumanInputFormSubmissionRepository() def get_form_by_token(self, form_token: str) -> Form | None: record = self._form_repository.get_by_token(form_token) diff --git a/api/services/message_service.py b/api/services/message_service.py index ce699e79d4..789b6c2f8c 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -9,10 +9,10 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 69da3bfb79..2133dc5b3a 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -10,13 +10,13 @@ from core.entities.provider_configuration import ProviderConfiguration from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_manager import LBModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ( +from core.provider_manager import ProviderManager +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ( ModelCredentialSchema, ProviderCredentialSchema, ) -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from core.provider_manager import ProviderManager +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index edd1004b82..0ddd6b9b1a 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,9 +1,9 @@ import logging from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity -from core.model_runtime.entities.model_entities import ModelType, ParameterRule -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.provider_manager import ProviderManager +from dify_graph.model_runtime.entities.model_entities import ModelType, ParameterRule +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ProviderType from services.entities.model_provider_entities import ( CustomConfigurationResponse, diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index c0f9e4f323..ce745a4679 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -36,23 +36,23 @@ from core.rag.entities.event import ( ) from core.repositories.factory import DifyCoreRepositoryFactory from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.entities.workflow_node_execution import ( +from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from core.workflow.enums import ErrorStrategy, NodeType, SystemVariableKey -from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent -from core.workflow.graph_events.base import GraphNodeEventBase -from core.workflow.node_events.base import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING -from core.workflow.repositories.workflow_node_execution_repository import OrderConfig -from core.workflow.runtime import VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables.variables import VariableBase -from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.enums import ErrorStrategy, NodeType, SystemVariableKey +from dify_graph.errors import WorkflowNodeRunFailedError +from dify_graph.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent +from dify_graph.graph_events.base import GraphNodeEventBase +from dify_graph.node_events.base import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING +from dify_graph.repositories.workflow_node_execution_repository import OrderConfig +from dify_graph.runtime import VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables.variables import VariableBase from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index be1ce834f6..58bb4b7c90 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -21,15 +21,15 @@ from sqlalchemy.orm import Session from core.helper import ssrf_proxy from core.helper.name_generator import generate_incremental_name -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import PluginDependency -from core.workflow.enums import NodeType -from core.workflow.nodes.datasource.entities import DatasourceNodeData -from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData -from core.workflow.nodes.llm.entities import LLMNodeData -from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData -from core.workflow.nodes.tool.entities import ToolNodeData +from dify_graph.enums import NodeType +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.nodes.datasource.entities import DatasourceNodeData +from dify_graph.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from dify_graph.nodes.llm.entities import LLMNodeData +from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData +from dify_graph.nodes.tool.entities import ToolNodeData from extensions.ext_redis import redis_client from factories import variable_factory from models import Account diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py index ea5cbb7740..00a2144800 100644 --- a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py +++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py @@ -31,7 +31,7 @@ from sqlalchemy import inspect from sqlalchemy.orm import Session, sessionmaker from configs import dify_config -from core.workflow.enums import WorkflowType +from dify_graph.enums import WorkflowType from enums.cloud_plan import CloudPlan from extensions.ext_database import db from libs.archive_storage import ( diff --git a/api/services/skill_service.py b/api/services/skill_service.py index df8f66ac75..347295139d 100644 --- a/api/services/skill_service.py +++ b/api/services/skill_service.py @@ -18,6 +18,8 @@ from collections.abc import Mapping from functools import reduce from typing import Any, cast +from core.workflow.enums import NodeType + from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode from core.sandbox.entities.config import AppAssets from core.skill.assembler import SkillBundleAssembler, SkillDocumentAssembler @@ -26,7 +28,6 @@ from core.skill.entities.skill_document import SkillDocument from core.skill.entities.skill_metadata import SkillMetadata from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency from core.skill.skill_manager import SkillManager -from core.workflow.enums import NodeType from models.model import App from services.app_asset_service import AppAssetService diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 7c03ceed5b..eb78be8f88 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -10,11 +10,11 @@ from sqlalchemy.orm import Session from core.db.session_factory import session_factory from core.model_manager import ModelManager -from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import Document +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities.model_entities import ModelType from libs import helper from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index c32157919b..dc883f0daa 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -7,7 +7,6 @@ from httpx import get from sqlalchemy import select from core.entities.provider_entities import ProviderConfig -from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_runtime import ToolRuntime from core.tools.custom_tool.provider import ApiToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity @@ -21,6 +20,7 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from models.tools import ApiToolProvider from services.tools.tools_transform_service import ToolTransformService diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index ff0b276f77..101b2fe5a2 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -5,7 +5,6 @@ from datetime import datetime from sqlalchemy import or_, select from sqlalchemy.orm import Session -from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_provider import ToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration @@ -13,6 +12,7 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from models.model import App from models.tools import WorkflowToolProvider diff --git a/api/services/trigger/schedule_service.py b/api/services/trigger/schedule_service.py index b49d14f860..8389ccbb34 100644 --- a/api/services/trigger/schedule_service.py +++ b/api/services/trigger/schedule_service.py @@ -7,9 +7,9 @@ from typing import Any from sqlalchemy import select from sqlalchemy.orm import Session -from core.workflow.nodes import NodeType -from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig -from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError +from dify_graph.nodes import NodeType +from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig +from dify_graph.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h from models.account import Account, TenantAccountJoin from models.trigger import WorkflowSchedulePlan diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index 7f12c2e19c..f1f0d0ea84 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -16,8 +16,8 @@ from core.trigger.debug.events import PluginTriggerDebugEvent from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription -from core.workflow.enums import NodeType -from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData +from dify_graph.enums import NodeType +from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData from extensions.ext_database import db from extensions.ext_redis import redis_client from models.model import App diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 75a1350e60..285645edce 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -16,9 +16,9 @@ from werkzeug.exceptions import RequestEntityTooLarge from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.tool_file_manager import ToolFileManager -from core.workflow.enums import NodeType -from core.workflow.file.models import FileTransferMethod -from core.workflow.variables.types import SegmentType +from dify_graph.enums import NodeType +from dify_graph.file.models import FileTransferMethod +from dify_graph.variables.types import SegmentType from enums.quota_type import QuotaType from extensions.ext_database import db from extensions.ext_redis import redis_client diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index e641f68ca7..9cfdf55eda 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -6,10 +6,10 @@ from collections.abc import Mapping from typing import Any, Generic, TypeAlias, TypeVar, overload from configs import dify_config -from core.model_runtime.entities import PromptMessage -from core.workflow.file.models import File -from core.workflow.nodes.variable_assigner.common.helpers import UpdatedVariable -from core.workflow.variables.segments import ( +from dify_graph.file.models import File +from dify_graph.model_runtime.entities import PromptMessage +from dify_graph.nodes.variable_assigner.common.helpers import UpdatedVariable +from dify_graph.variables.segments import ( ArrayFileSegment, ArraySegment, BooleanSegment, @@ -21,7 +21,7 @@ from core.workflow.variables.segments import ( Segment, StringSegment, ) -from core.workflow.variables.utils import dumps_with_segments +from dify_graph.variables.utils import dumps_with_segments _MAX_DEPTH = 100 diff --git a/api/services/vector_service.py b/api/services/vector_service.py index f1fa33cb75..73bb46b797 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,7 +1,6 @@ import logging from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType @@ -9,6 +8,7 @@ from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models import UploadFile from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding diff --git a/api/services/workflow/nested_node_graph_service.py b/api/services/workflow/nested_node_graph_service.py index c30aab4331..fac463e0b1 100644 --- a/api/services/workflow/nested_node_graph_service.py +++ b/api/services/workflow/nested_node_graph_service.py @@ -7,10 +7,10 @@ extracting values from list[PromptMessage] variables. from typing import Any +from core.workflow.enums import NodeType from sqlalchemy.orm import Session from core.model_runtime.entities import LLMMode -from core.workflow.enums import NodeType from services.model_provider_service import ModelProviderService from services.workflow.entities import NestedNodeGraphRequest, NestedNodeGraphResponse, NestedNodeParameterSchema diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 5527c108a2..0153046acc 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -13,13 +13,13 @@ from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManage from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.helper import encrypter -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.file.models import FileUploadConfig -from core.workflow.nodes import NodeType -from core.workflow.variables.input_entities import VariableEntity +from dify_graph.file.models import FileUploadConfig +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.nodes import NodeType +from dify_graph.variables.input_entities import VariableEntity from events.app_event import app_was_created from extensions.ext_database import db from models import Account diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index efc76c33bc..7147fe1eab 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -6,8 +6,8 @@ from typing import Any from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session -from core.workflow.enums import WorkflowExecutionStatus -from models import Account, App, EndUser, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun +from dify_graph.enums import WorkflowExecutionStatus +from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog from services.plugin.plugin_service import PluginService @@ -132,7 +132,14 @@ class WorkflowAppService: ), ) if created_by_account: - account = session.scalar(select(Account).where(Account.email == created_by_account)) + account = session.scalar( + select(Account) + .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) + .where( + Account.email == created_by_account, + TenantAccountJoin.tenant_id == app_model.tenant_id, + ) + ) if not account: raise ValueError(f"Account not found: {created_by_account}") diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 18ad6c5c16..b6f6fc5490 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -14,20 +14,20 @@ from sqlalchemy.sql.expression import and_, or_ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.enums import SystemVariableKey -from core.workflow.file.models import File -from core.workflow.nodes import NodeType -from core.workflow.nodes.variable_assigner.common.helpers import get_updated_variables -from core.workflow.variable_loader import VariableLoader -from core.workflow.variables import Segment, StringSegment, VariableBase -from core.workflow.variables.consts import SELECTORS_LENGTH -from core.workflow.variables.segments import ( +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from dify_graph.enums import SystemVariableKey +from dify_graph.file.models import File +from dify_graph.nodes import NodeType +from dify_graph.nodes.variable_assigner.common.helpers import get_updated_variables +from dify_graph.variable_loader import VariableLoader +from dify_graph.variables import Segment, StringSegment, VariableBase +from dify_graph.variables.consts import SELECTORS_LENGTH +from dify_graph.variables.segments import ( ArrayFileSegment, FileSegment, ) -from core.workflow.variables.types import SegmentType -from core.workflow.variables.utils import dumps_with_segments +from dify_graph.variables.types import SegmentType +from dify_graph.variables.utils import dumps_with_segments from extensions.ext_storage import storage from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable @@ -70,7 +70,7 @@ class UpdateNotSupportedError(WorkflowDraftVariableError): class DraftVarLoader(VariableLoader): # This implements the VariableLoader interface for loading draft variables. # - # ref: core.workflow.variable_loader.VariableLoader + # ref: dify_graph.variable_loader.VariableLoader # Database engine used for loading variables. _engine: Engine diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py index 09037a92ce..8f323ebb8b 100644 --- a/api/services/workflow_event_snapshot_service.py +++ b/api/services/workflow_event_snapshot_service.py @@ -22,10 +22,10 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from core.workflow.entities import WorkflowStartReason -from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from core.workflow.runtime import GraphRuntimeState -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from dify_graph.entities import WorkflowStartReason +from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from dify_graph.runtime import GraphRuntimeState +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from models.model import AppMode, Message from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index b347917bb0..9bc1371895 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -11,37 +11,37 @@ from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.repositories import DifyCoreRepositoryFactory from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.workflow.entities import GraphInitParams, WorkflowNodeExecution -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.file import File -from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes import NodeType -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from core.workflow.nodes.human_input.entities import ( +from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.entities import GraphInitParams, WorkflowNodeExecution +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.errors import WorkflowNodeRunFailedError +from dify_graph.file import File +from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes import NodeType +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from dify_graph.nodes.human_input.entities import ( DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient, validate_human_input_submission, ) -from core.workflow.nodes.human_input.enums import HumanInputFormKind -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.repositories.human_input_form_repository import FormCreateParams -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variable_loader import load_into_variable_pool -from core.workflow.variables import VariableBase -from core.workflow.variables.input_entities import VariableEntityType -from core.workflow.variables.variables import Variable -from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.nodes.human_input.enums import HumanInputFormKind +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.repositories.human_input_form_repository import FormCreateParams +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variable_loader import load_into_variable_pool +from dify_graph.variables import VariableBase +from dify_graph.variables.input_entities import VariableEntityType +from dify_graph.variables.variables import Variable from enums.cloud_plan import CloudPlan from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db @@ -49,7 +49,6 @@ from extensions.ext_storage import storage from factories.file_factory import build_from_mapping, build_from_mappings from libs.datetime_utils import naive_utc_now from models import Account -from models.enums import UserFrom from models.human_input import HumanInputFormRecipient, RecipientType from models.model import App, AppMode from models.tools import WorkflowToolProvider @@ -515,8 +514,8 @@ class WorkflowService: """ try: from core.model_manager import ModelManager - from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager + from dify_graph.model_runtime.entities.model_entities import ModelType # Get model instance to validate provider+model combination model_manager = ModelManager() @@ -635,8 +634,8 @@ class WorkflowService: :return: True if load balancing is enabled, False otherwise """ try: - from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager + from dify_graph.model_runtime.entities.model_entities import ModelType # Get provider configurations provider_manager = ProviderManager() @@ -1108,7 +1107,7 @@ class WorkflowService: rendered_content: str, resolved_default_values: Mapping[str, Any], ) -> tuple[str, list[DeliveryTestEmailRecipient]]: - repo = HumanInputFormRepositoryImpl(session_factory=db.engine, tenant_id=app_model.tenant_id) + repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id) params = FormCreateParams( app_id=app_model.id, workflow_execution_id=None, @@ -1156,13 +1155,15 @@ class WorkflowService: variable_pool: VariablePool, ) -> HumanInputNode: graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, workflow_id=workflow.id, graph_config=workflow.graph_dict, - user_id=account.id, - user_from=UserFrom.ACCOUNT.value, - invoke_from=InvokeFrom.DEBUGGER.value, + run_context=build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + user_id=account.id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ), call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -1174,6 +1175,7 @@ class WorkflowService: config=node_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, + form_repository=HumanInputFormRepositoryImpl(tenant_id=workflow.tenant_id), ) return node @@ -1457,7 +1459,7 @@ class WorkflowService: Raises: ValueError: If the node data format is invalid """ - from core.workflow.nodes.human_input.entities import HumanInputNodeData + from dify_graph.nodes.human_input.entities import HumanInputNodeData try: HumanInputNodeData.model_validate(node_data) diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index e58d334f41..174aa50343 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -21,7 +21,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.repositories import DifyCoreRepositoryFactory -from core.workflow.runtime import GraphRuntimeState +from dify_graph.runtime import GraphRuntimeState from extensions.ext_database import db from libs.flask_utils import set_login_user from models.account import Account @@ -321,7 +321,13 @@ def _resume_app_execution(payload: dict[str, Any]) -> None: return message = session.scalar( - select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(Message.created_at.desc()) + select(Message) + .where( + Message.conversation_id == conversation.id, + Message.workflow_run_id == workflow_run_id, + ) + .order_by(Message.created_at.desc()) + .limit(1) ) if message is None: logger.warning("Message not found for workflow run %s", workflow_run_id) diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index cc96542d4b..d247cf5cf7 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -21,7 +21,7 @@ from core.app.layers.timeslice_layer import TimeSliceLayer from core.app.layers.trigger_post_layer import TriggerPostLayer from core.db.session_factory import session_factory from core.repositories import DifyCoreRepositoryFactory -from core.workflow.runtime import GraphRuntimeState +from dify_graph.runtime import GraphRuntimeState from extensions.ext_database import db from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index f69f17b16d..49dee00919 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -11,7 +11,7 @@ from sqlalchemy import func from core.db.session_factory import session_factory from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py index 5413a33d6a..dd3b6a4530 100644 --- a/api/tasks/human_input_timeout_tasks.py +++ b/api/tasks/human_input_timeout_tasks.py @@ -7,8 +7,8 @@ from sqlalchemy.orm import sessionmaker from configs import dify_config from core.repositories.human_input_repository import HumanInputFormSubmissionRepository -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from extensions.ext_database import db from extensions.ext_storage import storage from libs.datetime_utils import ensure_naive_utc, naive_utc_now @@ -58,7 +58,7 @@ def check_and_handle_human_input_timeouts(limit: int = 100) -> None: """Scan for expired human input forms and resume or end workflows.""" session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - form_repo = HumanInputFormSubmissionRepository(session_factory) + form_repo = HumanInputFormSubmissionRepository() service = HumanInputService(session_factory, form_repository=form_repo) now = naive_utc_now() global_timeout_seconds = dify_config.HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index d1cd0fbadc..bded4cea2b 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -11,8 +11,8 @@ from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailDeliveryMethod -from core.workflow.runtime import GraphRuntimeState, VariablePool +from dify_graph.nodes.human_input.entities import EmailDeliveryConfig, EmailDeliveryMethod +from dify_graph.runtime import GraphRuntimeState, VariablePool from extensions.ext_database import db from extensions.ext_mail import mail from models.human_input import ( diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index d18ea2c23c..d06b8c980b 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -25,8 +25,8 @@ from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool from core.trigger.entities.entities import TriggerProviderEntity from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager -from core.workflow.enums import NodeType, WorkflowExecutionStatus -from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData +from dify_graph.enums import NodeType, WorkflowExecutionStatus +from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData from enums.quota_type import QuotaType, unlimited from models.enums import ( AppTriggerType, diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index 3b3c6e5313..db8721e90b 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -12,8 +12,8 @@ from celery import shared_task from sqlalchemy import select from core.db.session_factory import session_factory -from core.workflow.entities.workflow_execution import WorkflowExecution -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from dify_graph.entities.workflow_execution import WorkflowExecution +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index b30a4ff15b..3f607dc55e 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -12,10 +12,10 @@ from celery import shared_task from sqlalchemy import select from core.db.session_factory import session_factory -from core.workflow.entities.workflow_node_execution import ( +from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, ) -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/tasks/workflow_schedule_tasks.py b/api/tasks/workflow_schedule_tasks.py index 8c64d3ab27..ced7ef973b 100644 --- a/api/tasks/workflow_schedule_tasks.py +++ b/api/tasks/workflow_schedule_tasks.py @@ -3,7 +3,7 @@ import logging from celery import shared_task from core.db.session_factory import session_factory -from core.workflow.nodes.trigger_schedule.exc import ( +from dify_graph.nodes.trigger_schedule.exc import ( ScheduleExecutionError, ScheduleNotFoundError, TenantOwnerNotFoundError, diff --git a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py index 003bb356e5..4fdbb7d9f3 100644 --- a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py +++ b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py @@ -2,7 +2,7 @@ from collections.abc import Generator from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage -from core.workflow.node_events import StreamCompletedEvent +from dify_graph.node_events import StreamCompletedEvent def _gen_var_stream() -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index 909d6377ce..c043c7dc10 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,6 +1,6 @@ -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult, StreamCompletedEvent -from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult, StreamCompletedEvent +from dify_graph.nodes.datasource.datasource_node import DatasourceNode class _Seg: diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py index 16a66bc3f1..b4e3a0e4de 100644 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -6,7 +6,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session -from core.workflow.file import File, FileTransferMethod, FileType +from dify_graph.file import File, FileTransferMethod, FileType from extensions.ext_database import db from factories.file_factory import StorageKeyLoader from models import ToolFile, UploadFile diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index 5012defdad..4e184c93fd 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -4,20 +4,27 @@ from collections.abc import Generator, Sequence from decimal import Decimal from json import dumps +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.model import PluginModelClient + # import monkeypatch -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool -from core.model_runtime.entities.model_entities import ( +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.llm_entities import ( + LLMMode, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMUsage, +) +from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import ( AIModelEntity, FetchFrom, ModelFeature, ModelPropertyKey, ModelType, ) -from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from core.plugin.impl.model import PluginModelClient +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity class MockModelClass(PluginModelClient): diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index 5faa002fff..7c4dcda2dc 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -6,11 +6,11 @@ import pytest from sqlalchemy import delete from sqlalchemy.orm import Session -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.nodes import NodeType -from core.workflow.variables.segments import StringSegment -from core.workflow.variables.types import SegmentType -from core.workflow.variables.variables import StringVariable +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from dify_graph.nodes import NodeType +from dify_graph.variables.segments import StringSegment +from dify_graph.variables.types import SegmentType +from dify_graph.variables.variables import StringVariable from extensions.ext_database import db from extensions.ext_storage import storage from factories.variable_factory import build_segment diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index a259ccb2b9..988313e68d 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -5,7 +5,7 @@ import pytest from sqlalchemy import delete from core.db.session_factory import session_factory -from core.workflow.variables.segments import StringSegment +from dify_graph.variables.segments import StringSegment from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile @@ -191,7 +191,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration: @pytest.fixture def setup_offload_test_data(self, app_and_tenant): tenant, app = app_and_tenant - from core.workflow.variables.types import SegmentType + from dify_graph.variables.types import SegmentType from libs.datetime_utils import naive_utc_now with session_factory.create_session() as session: @@ -422,7 +422,7 @@ class TestDeleteDraftVariablesSessionCommit: @pytest.fixture def setup_offload_test_data(self, app_and_tenant): """Create test data with offload files for session commit tests.""" - from core.workflow.variables.types import SegmentType + from dify_graph.variables.types import SegmentType from libs.datetime_utils import naive_utc_now tenant, app = app_and_tenant diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py index cdecdf41d2..5b0f86fed1 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -4,8 +4,8 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ProviderType diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index e0ea14b789..f8b7f95493 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -4,18 +4,17 @@ import uuid import pytest from configs import dify_config -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.code.limits import CodeNodeLimits -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.graph import Graph +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.code.code_node import CodeNode +from dify_graph.nodes.code.limits import CodeNodeLimits +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock +from tests.workflow_test_utils import build_test_graph_init_params CODE_MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH @@ -32,11 +31,11 @@ def init_code_node(code_config: dict): "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, code_config], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index e0f2363799..f691113511 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -5,19 +5,18 @@ from urllib.parse import urlencode import pytest from configs import dify_config -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.file.file_manager import file_manager -from core.workflow.graph import Graph -from core.workflow.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.file.file_manager import file_manager +from dify_graph.graph import Graph +from dify_graph.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock +from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -42,11 +41,11 @@ def init_http_node(config: dict): "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -190,15 +189,15 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" - from core.workflow.nodes.http_request.entities import ( + from dify_graph.nodes.http_request.entities import ( HttpRequestNodeAuthorization, HttpRequestNodeData, HttpRequestNodeTimeout, ) - from core.workflow.nodes.http_request.exc import AuthorizationConfigError - from core.workflow.nodes.http_request.executor import Executor - from core.workflow.runtime import VariablePool - from core.workflow.system_variable import SystemVariable + from dify_graph.nodes.http_request.exc import AuthorizationConfigError + from dify_graph.nodes.http_request.executor import Executor + from dify_graph.runtime import VariablePool + from dify_graph.system_variable import SystemVariable # Create variable pool variable_pool = VariablePool( @@ -686,11 +685,11 @@ def test_nested_object_variable_selector(setup_http_mock): ], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index b5b0fb5334..b4779ebcdd 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -4,18 +4,17 @@ import uuid from collections.abc import Generator from unittest.mock import MagicMock, patch -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.llm_generator.output_parser.structured_output import _parse_structured_output from core.model_manager import ModelInstance -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.node_events import StreamCompletedEvent -from core.workflow.nodes.llm.node import LLMNode -from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.node_events import StreamCompletedEvent +from dify_graph.nodes.llm.node import LLMNode +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from extensions.ext_database import db -from models.enums import UserFrom +from tests.workflow_test_utils import build_test_graph_init_params """FOR MOCK FIXTURES, DO NOT REMOVE""" @@ -38,11 +37,11 @@ def init_llm_node(config: dict) -> LLMNode: workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d" user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e" - init_params = GraphInitParams( - tenant_id=tenant_id, - app_id=app_id, + init_params = build_test_graph_init_params( workflow_id=workflow_id, graph_config=graph_config, + tenant_id=tenant_id, + app_id=app_id, user_id=user_id, user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -113,8 +112,8 @@ def test_execute_llm(): from decimal import Decimal from unittest.mock import MagicMock - from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage - from core.model_runtime.entities.message_entities import AssistantPromptMessage + from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage + from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage # Create mock model instance mock_model_instance = MagicMock(spec=ModelInstance) @@ -158,7 +157,7 @@ def test_execute_llm(): # Mock fetch_prompt_messages to avoid database calls def mock_fetch_prompt_messages_1(**_kwargs): - from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage + from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage return [ SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."), @@ -229,8 +228,8 @@ def test_execute_llm_with_jinja2(): from decimal import Decimal from unittest.mock import MagicMock - from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage - from core.model_runtime.entities.message_entities import AssistantPromptMessage + from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage + from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage # Create mock model instance mock_model_instance = MagicMock(spec=ModelInstance) @@ -274,7 +273,7 @@ def test_execute_llm_with_jinja2(): # Mock fetch_prompt_messages to avoid database calls def mock_fetch_prompt_messages_2(**_kwargs): - from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage + from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage return [ SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."), diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 773074e92d..62d9af0196 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -3,18 +3,17 @@ import time import uuid from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance -from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory -from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.model_runtime.entities import AssistantPromptMessage, UserPromptMessage +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from extensions.ext_database import db -from models.enums import UserFrom from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance +from tests.workflow_test_utils import build_test_graph_init_params """FOR MOCK FIXTURES, DO NOT REMOVE""" from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock @@ -44,11 +43,11 @@ def init_parameter_extractor_node(config: dict, memory=None): "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index bc03ce1b96..970e2cae00 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,22 +1,30 @@ import time import uuid -import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom -from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.graph import Graph +from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError +from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params -@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) -def test_execute_code(setup_code_executor_mock): +class _SimpleJinja2Renderer: + """Minimal Jinja2-based renderer for integration tests (no code executor).""" + + def render_template(self, template: str, variables: dict[str, object]) -> str: + from jinja2 import Template + + try: + return Template(template).render(**variables) + except Exception as exc: + raise TemplateRenderError(str(exc)) from exc + + +def test_execute_template_transform(): code = """{{args2}}""" config = { "id": "1", @@ -45,11 +53,11 @@ def test_execute_code(setup_code_executor_mock): "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -68,19 +76,21 @@ def test_execute_code(setup_code_executor_mock): graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - # Create node factory + # Create node factory (graph init path still works regardless of renderer choice below) node_factory = DifyNodeFactory( graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + assert graph is not None node = TemplateTransformNode( id=str(uuid.uuid4()), config=config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, + template_renderer=_SimpleJinja2Renderer(), ) # execute node diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index cfbef52c93..f70bf46979 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -2,17 +2,16 @@ import time import uuid from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.configuration import ToolParameterConfigurationManager -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.node_events import StreamCompletedEvent -from core.workflow.nodes.tool.tool_node import ToolNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.graph import Graph +from dify_graph.node_events import StreamCompletedEvent +from dify_graph.nodes.tool.tool_node import ToolNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params def init_tool_node(config: dict): @@ -27,11 +26,11 @@ def init_tool_node(config: dict): "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py index 7fad603a6d..6f2e008d44 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py @@ -8,7 +8,7 @@ from sqlalchemy.orm import Session from configs import dify_config from constants import HEADER_NAME_CSRF_TOKEN -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from libs.token import _real_cookie_name, generate_csrf_token from models import Account, DifySetup, Tenant, TenantAccountJoin diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index dcf31aeca7..96fb7ea293 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -31,16 +31,16 @@ from core.app.layers.pause_state_persist_layer import ( PauseStatePersistenceLayer, WorkflowResumptionContext, ) -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.entities.pause_reason import SchedulingPause -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.graph_engine.entities.commands import GraphEngineCommand -from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from core.workflow.graph_events.graph import GraphRunPausedEvent -from core.workflow.runtime.graph_runtime_state import GraphRuntimeState -from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState -from core.workflow.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper -from core.workflow.runtime.variable_pool import SystemVariable, VariablePool +from dify_graph.entities.pause_reason import SchedulingPause +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.graph_engine.entities.commands import GraphEngineCommand +from dify_graph.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from dify_graph.graph_events.graph import GraphRunPausedEvent +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.runtime.graph_runtime_state import GraphRuntimeState +from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState +from dify_graph.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper +from dify_graph.runtime.variable_pool import SystemVariable, VariablePool from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models import Account @@ -544,7 +544,7 @@ class TestPauseStatePersistenceLayerTestContainers: layer.initialize(graph_runtime_state, command_channel) # Import other event types - from core.workflow.graph_events.graph import ( + from dify_graph.graph_events.graph import ( GraphRunFailedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index 4e6cc620ac..e5d3655771 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -5,7 +5,7 @@ import pytest from faker import Faker from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest +from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest from models.dataset import Dataset, Document from services.account_service import AccountService, TenantService diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 079e4934bb..9d0fad4b12 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -8,7 +8,7 @@ from sqlalchemy import Engine, select from sqlalchemy.orm import Session from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.workflow.nodes.human_input.entities import ( +from dify_graph.nodes.human_input.entities import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, @@ -20,7 +20,7 @@ from core.workflow.nodes.human_input.entities import ( UserAction, WebAppDeliveryMethod, ) -from core.workflow.repositories.human_input_form_repository import FormCreateParams +from dify_graph.repositories.human_input_form_repository import FormCreateParams from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.human_input import ( EmailExternalRecipientPayload, @@ -100,7 +100,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["member1@example.com", "member2@example.com"], ) - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) params = _build_form_params( delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])], ) @@ -129,7 +129,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["primary@example.com", "secondary@example.com"], ) - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) params = _build_form_params( delivery_methods=[ _build_email_delivery( @@ -173,7 +173,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["prefill@example.com"], ) - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) resolved_values = {"greeting": "Hello!"} params = FormCreateParams( app_id=str(uuid4()), @@ -210,7 +210,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["ui@example.com"], ) - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) params = FormCreateParams( app_id=str(uuid4()), workflow_execution_id=str(uuid4()), diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index 06d55177eb..9733735df3 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -12,27 +12,27 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowType -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.enums import WorkflowType +from dify_graph.graph import Graph +from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from dify_graph.graph_engine.graph_engine import GraphEngine +from dify_graph.nodes.end.end_node import EndNode +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction +from dify_graph.nodes.human_input.enums import HumanInputFormStatus +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from models import Account from models.account import Tenant, TenantAccountJoin, TenantAccountRole from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.model import App, AppMode, IconType from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun +from tests.workflow_test_utils import build_test_graph_init_params def _mock_form_repository_without_submission() -> HumanInputFormRepository: @@ -87,11 +87,11 @@ def _build_graph( form_repository: HumanInputFormRepository, ) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - params = GraphInitParams( - tenant_id=tenant_id, - app_id=app_id, + params = build_test_graph_init_params( workflow_id=workflow_id, graph_config=graph_config, + tenant_id=tenant_id, + app_id=app_id, user_id=user_id, user_from="account", invoke_from="debugger", diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index 3568a8b070..cb7cd37a3f 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -6,7 +6,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session -from core.workflow.file import File, FileTransferMethod, FileType +from dify_graph.file import File, FileTransferMethod, FileType from extensions.ext_database import db from factories.file_factory import StorageKeyLoader from models import ToolFile, UploadFile diff --git a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py index 19d7772c39..573f84cb0b 100644 --- a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py +++ b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py @@ -5,7 +5,7 @@ from datetime import datetime, timedelta from decimal import Decimal from uuid import uuid4 -from core.workflow.nodes.human_input.entities import FormDefinition, UserAction +from dify_graph.nodes.human_input.entities import FormDefinition, UserAction from models.account import Account, Tenant, TenantAccountJoin from models.execution_extra_content import HumanInputContent from models.human_input import HumanInputForm, HumanInputFormStatus diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py index 556c029b24..458862b0ec 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py @@ -8,7 +8,7 @@ from uuid import uuid4 from sqlalchemy import Engine, delete from sqlalchemy.orm import Session, sessionmaker -from core.workflow.enums import WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 05a868c0c2..76e586e65f 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -11,9 +11,9 @@ import pytest from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session, sessionmaker -from core.workflow.entities import WorkflowExecution -from core.workflow.entities.pause_reason import PauseReasonType -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.entities import WorkflowExecution +from dify_graph.entities.pause_reason import PauseReasonType +from dify_graph.enums import WorkflowExecutionStatus from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom diff --git a/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py b/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py index 73df2d9ed9..191c161613 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py +++ b/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py @@ -9,8 +9,8 @@ from itertools import starmap from uuid import uuid4 import pytest +from sqlalchemy.orm import Session -from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from services.dataset_service import DatasetCollectionBindingService @@ -28,6 +28,7 @@ class DatasetCollectionBindingTestDataFactory: @staticmethod def create_collection_binding( + db_session_with_containers: Session, provider_name: str = "openai", model_name: str = "text-embedding-ada-002", collection_name: str = "collection-abc", @@ -51,8 +52,8 @@ class DatasetCollectionBindingTestDataFactory: collection_name=collection_name, type=collection_type, ) - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() return binding @@ -64,7 +65,7 @@ class TestDatasetCollectionBindingServiceGetBinding: including various provider/model combinations, collection types, and edge cases. """ - def test_get_dataset_collection_binding_existing_binding_success(self, db_session_with_containers): + def test_get_dataset_collection_binding_existing_binding_success(self, db_session_with_containers: Session): """ Test successful retrieval of an existing collection binding. @@ -77,6 +78,7 @@ class TestDatasetCollectionBindingServiceGetBinding: model_name = "text-embedding-ada-002" collection_type = "dataset" existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name=provider_name, model_name=model_name, collection_name="existing-collection", @@ -92,7 +94,7 @@ class TestDatasetCollectionBindingServiceGetBinding: assert result.id == existing_binding.id assert result.collection_name == "existing-collection" - def test_get_dataset_collection_binding_create_new_binding_success(self, db_session_with_containers): + def test_get_dataset_collection_binding_create_new_binding_success(self, db_session_with_containers: Session): """ Test successful creation of a new collection binding when none exists. @@ -116,7 +118,7 @@ class TestDatasetCollectionBindingServiceGetBinding: assert result.type == collection_type assert result.collection_name is not None - def test_get_dataset_collection_binding_different_collection_type(self, db_session_with_containers): + def test_get_dataset_collection_binding_different_collection_type(self, db_session_with_containers: Session): """Test get_dataset_collection_binding with different collection type.""" # Arrange provider_name = "openai" @@ -133,7 +135,7 @@ class TestDatasetCollectionBindingServiceGetBinding: assert result.provider_name == provider_name assert result.model_name == model_name - def test_get_dataset_collection_binding_default_collection_type(self, db_session_with_containers): + def test_get_dataset_collection_binding_default_collection_type(self, db_session_with_containers: Session): """Test get_dataset_collection_binding with default collection type parameter.""" # Arrange provider_name = "openai" @@ -147,7 +149,9 @@ class TestDatasetCollectionBindingServiceGetBinding: assert result.provider_name == provider_name assert result.model_name == model_name - def test_get_dataset_collection_binding_different_provider_model_combination(self, db_session_with_containers): + def test_get_dataset_collection_binding_different_provider_model_combination( + self, db_session_with_containers: Session + ): """Test get_dataset_collection_binding with various provider/model combinations.""" # Arrange combinations = [ @@ -174,10 +178,11 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: including successful retrieval and error handling for missing bindings. """ - def test_get_dataset_collection_binding_by_id_and_type_success(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_success(self, db_session_with_containers: Session): """Test successful retrieval of collection binding by ID and type.""" # Arrange binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", @@ -194,7 +199,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: assert result.collection_name == "test-collection" assert result.type == "dataset" - def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers: Session): """Test error handling when collection binding is not found by ID and type.""" # Arrange non_existent_id = str(uuid4()) @@ -203,10 +208,13 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: with pytest.raises(ValueError, match="Dataset collection binding not found"): DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(non_existent_id, "dataset") - def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_different_collection_type( + self, db_session_with_containers: Session + ): """Test retrieval by ID and type with different collection type.""" # Arrange binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", @@ -222,10 +230,13 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: assert result.id == binding.id assert result.type == "custom_type" - def test_get_dataset_collection_binding_by_id_and_type_default_collection_type(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_default_collection_type( + self, db_session_with_containers: Session + ): """Test retrieval by ID with default collection type.""" # Arrange binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", @@ -239,10 +250,11 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: assert result.id == binding.id assert result.type == "dataset" - def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers: Session): """Test error when binding exists but with wrong collection type.""" # Arrange binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", diff --git a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py index 9871ef37e6..4b98bddd26 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py +++ b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py @@ -10,9 +10,9 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound -from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum from models.model import App @@ -27,6 +27,7 @@ class DatasetUpdateDeleteTestDataFactory: @staticmethod def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.NORMAL, tenant: Tenant | None = None, ) -> tuple[Account, Tenant]: @@ -37,13 +38,13 @@ class DatasetUpdateDeleteTestDataFactory: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() if tenant is None: tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() join = TenantAccountJoin( tenant_id=tenant.id, @@ -51,14 +52,15 @@ class DatasetUpdateDeleteTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account, tenant @staticmethod def create_dataset( + db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test Dataset", @@ -78,12 +80,12 @@ class DatasetUpdateDeleteTestDataFactory: retrieval_model={"top_k": 2}, enable_api=enable_api, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod - def create_app(tenant_id: str, created_by: str, name: str = "Test App") -> App: + def create_app(db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test App") -> App: """Create a real app for AppDatasetJoin.""" app = App( tenant_id=tenant_id, @@ -96,16 +98,16 @@ class DatasetUpdateDeleteTestDataFactory: enable_api=True, created_by=created_by, ) - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app @staticmethod - def create_app_dataset_join(app_id: str, dataset_id: str) -> AppDatasetJoin: + def create_app_dataset_join(db_session_with_containers: Session, app_id: str, dataset_id: str) -> AppDatasetJoin: """Create a real AppDatasetJoin record.""" join = AppDatasetJoin(app_id=app_id, dataset_id=dataset_id) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() return join @@ -114,7 +116,7 @@ class TestDatasetServiceDeleteDataset: Comprehensive integration tests for DatasetService.delete_dataset method. """ - def test_delete_dataset_success(self, db_session_with_containers): + def test_delete_dataset_success(self, db_session_with_containers: Session): """ Test successful deletion of a dataset. @@ -130,8 +132,10 @@ class TestDatasetServiceDeleteDataset: - Method returns True """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) # Act with patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted: @@ -139,10 +143,10 @@ class TestDatasetServiceDeleteDataset: # Assert assert result is True - assert db.session.get(Dataset, dataset.id) is None + assert db_session_with_containers.get(Dataset, dataset.id) is None mock_dataset_was_deleted.send.assert_called_once_with(dataset) - def test_delete_dataset_not_found(self, db_session_with_containers): + def test_delete_dataset_not_found(self, db_session_with_containers: Session): """ Test handling when dataset is not found. @@ -156,7 +160,9 @@ class TestDatasetServiceDeleteDataset: - No database operations are performed """ # Arrange - owner, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + owner, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) dataset_id = str(uuid4()) # Act @@ -165,7 +171,7 @@ class TestDatasetServiceDeleteDataset: # Assert assert result is False - def test_delete_dataset_permission_denied_error(self, db_session_with_containers): + def test_delete_dataset_permission_denied_error(self, db_session_with_containers: Session): """ Test error handling when user lacks permission. @@ -178,19 +184,22 @@ class TestDatasetServiceDeleteDataset: - No database operations are performed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) normal_user, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL, tenant=tenant, ) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) # Act & Assert with pytest.raises(NoPermissionError): DatasetService.delete_dataset(dataset.id, normal_user) # Verify no deletion was attempted - assert db.session.get(Dataset, dataset.id) is not None + assert db_session_with_containers.get(Dataset, dataset.id) is not None class TestDatasetServiceDatasetUseCheck: @@ -198,7 +207,7 @@ class TestDatasetServiceDatasetUseCheck: Comprehensive integration tests for DatasetService.dataset_use_check method. """ - def test_dataset_use_check_in_use(self, db_session_with_containers): + def test_dataset_use_check_in_use(self, db_session_with_containers: Session): """ Test detection when dataset is in use. @@ -211,10 +220,12 @@ class TestDatasetServiceDatasetUseCheck: - Database query is executed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) - app = DatasetUpdateDeleteTestDataFactory.create_app(tenant.id, owner.id) - DatasetUpdateDeleteTestDataFactory.create_app_dataset_join(app.id, dataset.id) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + app = DatasetUpdateDeleteTestDataFactory.create_app(db_session_with_containers, tenant.id, owner.id) + DatasetUpdateDeleteTestDataFactory.create_app_dataset_join(db_session_with_containers, app.id, dataset.id) # Act result = DatasetService.dataset_use_check(dataset.id) @@ -222,7 +233,7 @@ class TestDatasetServiceDatasetUseCheck: # Assert assert result is True - def test_dataset_use_check_not_in_use(self, db_session_with_containers): + def test_dataset_use_check_not_in_use(self, db_session_with_containers: Session): """ Test detection when dataset is not in use. @@ -235,8 +246,10 @@ class TestDatasetServiceDatasetUseCheck: - Database query is executed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) # Act result = DatasetService.dataset_use_check(dataset.id) @@ -250,7 +263,7 @@ class TestDatasetServiceUpdateDatasetApiStatus: Comprehensive integration tests for DatasetService.update_dataset_api_status method. """ - def test_update_dataset_api_status_enable_success(self, db_session_with_containers): + def test_update_dataset_api_status_enable_success(self, db_session_with_containers: Session): """ Test successful enabling of dataset API access. @@ -264,8 +277,12 @@ class TestDatasetServiceUpdateDatasetApiStatus: - Transaction is committed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=False) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset( + db_session_with_containers, tenant.id, owner.id, enable_api=False + ) current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) # Act @@ -276,12 +293,12 @@ class TestDatasetServiceUpdateDatasetApiStatus: DatasetService.update_dataset_api_status(dataset.id, True) # Assert - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.enable_api is True assert dataset.updated_by == owner.id assert dataset.updated_at == current_time - def test_update_dataset_api_status_disable_success(self, db_session_with_containers): + def test_update_dataset_api_status_disable_success(self, db_session_with_containers: Session): """ Test successful disabling of dataset API access. @@ -295,8 +312,12 @@ class TestDatasetServiceUpdateDatasetApiStatus: - Transaction is committed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=True) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset( + db_session_with_containers, tenant.id, owner.id, enable_api=True + ) current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) # Act @@ -307,11 +328,11 @@ class TestDatasetServiceUpdateDatasetApiStatus: DatasetService.update_dataset_api_status(dataset.id, False) # Assert - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.enable_api is False assert dataset.updated_by == owner.id - def test_update_dataset_api_status_not_found_error(self, db_session_with_containers): + def test_update_dataset_api_status_not_found_error(self, db_session_with_containers: Session): """ Test error handling when dataset is not found. @@ -330,7 +351,7 @@ class TestDatasetServiceUpdateDatasetApiStatus: with pytest.raises(NotFound, match="Dataset not found"): DatasetService.update_dataset_api_status(dataset_id, True) - def test_update_dataset_api_status_missing_current_user_error(self, db_session_with_containers): + def test_update_dataset_api_status_missing_current_user_error(self, db_session_with_containers: Session): """ Test error handling when current_user is missing. @@ -343,8 +364,12 @@ class TestDatasetServiceUpdateDatasetApiStatus: - No updates are committed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=False) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset( + db_session_with_containers, tenant.id, owner.id, enable_api=False + ) # Act & Assert with ( @@ -354,6 +379,6 @@ class TestDatasetServiceUpdateDatasetApiStatus: DatasetService.update_dataset_api_status(dataset.id, True) # Verify no commit was attempted - db.session.rollback() - db.session.refresh(dataset) + db_session_with_containers.rollback() + db_session_with_containers.refresh(dataset) assert dataset.enable_api is False diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 606e7e0b57..8595f5bf14 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -45,7 +46,7 @@ class TestAccountService: "passport_service": mock_passport_service, } - def test_create_account_and_login(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_account_and_login(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account creation and login with correct password. """ @@ -70,7 +71,9 @@ class TestAccountService: logged_in = AccountService.authenticate(email, password) assert logged_in.id == account.id - def test_create_account_without_password(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_account_without_password( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account creation without password (for OAuth users). """ @@ -92,7 +95,7 @@ class TestAccountService: assert account.password_salt is None def test_create_account_password_invalid_new_password( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account create with invalid new password format. @@ -113,7 +116,9 @@ class TestAccountService: password="invalid_new_password", ) - def test_create_account_registration_disabled(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_account_registration_disabled( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account creation when registration is disabled. """ @@ -131,7 +136,9 @@ class TestAccountService: password=fake.password(length=12), ) - def test_create_account_email_in_freeze(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_account_email_in_freeze( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account creation when email is in freeze period. """ @@ -154,7 +161,9 @@ class TestAccountService: dify_config.BILLING_ENABLED = False # Reset config for other tests - def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_account_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test authentication with non-existent account. """ @@ -164,7 +173,7 @@ class TestAccountService: with pytest.raises(AccountPasswordError): AccountService.authenticate(email, password) - def test_authenticate_banned_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_banned_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test authentication with banned account. """ @@ -186,14 +195,13 @@ class TestAccountService: # Ban the account account.status = AccountStatus.BANNED - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(AccountLoginError): AccountService.authenticate(email, password) - def test_authenticate_wrong_password(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_wrong_password(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test authentication with wrong password. """ @@ -217,7 +225,9 @@ class TestAccountService: with pytest.raises(AccountPasswordError): AccountService.authenticate(email, wrong_password) - def test_authenticate_with_invite_token(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_with_invite_token( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test authentication with invite token to set password for account without password. """ @@ -249,7 +259,7 @@ class TestAccountService: assert authenticated_account.password_salt is not None def test_authenticate_pending_account_activation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test authentication activates pending account. @@ -270,16 +280,17 @@ class TestAccountService: password=password, ) account.status = AccountStatus.PENDING - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Authenticate should activate the account authenticated_account = AccountService.authenticate(email, password) assert authenticated_account.status == AccountStatus.ACTIVE assert authenticated_account.initialized_at is not None - def test_update_account_password_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_account_password_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful password update. """ @@ -308,7 +319,7 @@ class TestAccountService: assert authenticated_account.id == account.id def test_update_account_password_wrong_current_password( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test password update with wrong current password. @@ -335,7 +346,7 @@ class TestAccountService: AccountService.update_account_password(account, wrong_password, new_password) def test_update_account_password_invalid_new_password( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test password update with invalid new password format. @@ -360,7 +371,7 @@ class TestAccountService: with pytest.raises(ValueError): # Password validation error AccountService.update_account_password(account, old_password, "123") - def test_create_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account creation with automatic tenant creation. """ @@ -387,14 +398,13 @@ class TestAccountService: assert account.email == email # Verify tenant was created and linked - from extensions.ext_database import db - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" def test_create_account_and_tenant_workspace_creation_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account creation when workspace creation is disabled. @@ -419,7 +429,7 @@ class TestAccountService: ) def test_create_account_and_tenant_workspace_limit_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account creation when workspace limit is exceeded. @@ -446,7 +456,9 @@ class TestAccountService: password=password, ) - def test_link_account_integrate_new_provider(self, db_session_with_containers, mock_external_service_dependencies): + def test_link_account_integrate_new_provider( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test linking account with new OAuth provider. """ @@ -469,15 +481,18 @@ class TestAccountService: AccountService.link_account_integrate("new-google", "google_open_id_123", account) # Verify integration was created - from extensions.ext_database import db from models import AccountIntegrate - integration = db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider="new-google").first() + integration = ( + db_session_with_containers.query(AccountIntegrate) + .filter_by(account_id=account.id, provider="new-google") + .first() + ) assert integration is not None assert integration.open_id == "google_open_id_123" def test_link_account_integrate_existing_provider( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test linking account with existing provider (should update). @@ -504,15 +519,16 @@ class TestAccountService: AccountService.link_account_integrate("exists-google", "google_open_id_456", account) # Verify integration was updated - from extensions.ext_database import db from models import AccountIntegrate integration = ( - db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider="exists-google").first() + db_session_with_containers.query(AccountIntegrate) + .filter_by(account_id=account.id, provider="exists-google") + .first() ) assert integration.open_id == "google_open_id_456" - def test_close_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_close_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test closing an account. """ @@ -536,12 +552,11 @@ class TestAccountService: AccountService.close_account(account) # Verify account status changed - from extensions.ext_database import db - db.session.refresh(account) + db_session_with_containers.refresh(account) assert account.status == AccountStatus.CLOSED - def test_update_account_fields(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_account_fields(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test updating account fields. """ @@ -568,7 +583,9 @@ class TestAccountService: assert updated_account.name == updated_name assert updated_account.interface_theme == "dark" - def test_update_account_invalid_field(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_account_invalid_field( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test updating account with invalid field. """ @@ -591,7 +608,7 @@ class TestAccountService: with pytest.raises(AttributeError): AccountService.update_account(account, invalid_field="value") - def test_update_login_info(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_login_info(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test updating login information. """ @@ -616,13 +633,12 @@ class TestAccountService: AccountService.update_login_info(account, ip_address=ip_address) # Verify login info was updated - from extensions.ext_database import db - db.session.refresh(account) + db_session_with_containers.refresh(account) assert account.last_login_ip == ip_address assert account.last_login_at is not None - def test_login_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_login_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful login with token generation. """ @@ -659,7 +675,9 @@ class TestAccountService: assert call_args["iss"] is not None assert call_args["sub"] == "Console API Passport" - def test_login_pending_account_activation(self, db_session_with_containers, mock_external_service_dependencies): + def test_login_pending_account_activation( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test login activates pending account. """ @@ -680,17 +698,16 @@ class TestAccountService: password=password, ) account.status = AccountStatus.PENDING - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Login should activate the account token_pair = AccountService.login(account) - db.session.refresh(account) + db_session_with_containers.refresh(account) assert account.status == AccountStatus.ACTIVE - def test_logout(self, db_session_with_containers, mock_external_service_dependencies): + def test_logout(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test logout functionality. """ @@ -723,7 +740,7 @@ class TestAccountService: refresh_token_key = f"account_refresh_token:{account.id}" assert redis_client.get(refresh_token_key) is None - def test_refresh_token_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_refresh_token_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful token refresh. """ @@ -757,7 +774,7 @@ class TestAccountService: assert new_token_pair.access_token == "new_mock_access_token" assert new_token_pair.refresh_token != initial_token_pair.refresh_token - def test_refresh_token_invalid_token(self, db_session_with_containers, mock_external_service_dependencies): + def test_refresh_token_invalid_token(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test refresh token with invalid token. """ @@ -766,7 +783,9 @@ class TestAccountService: with pytest.raises(ValueError, match="Invalid refresh token"): AccountService.refresh_token(invalid_token) - def test_refresh_token_invalid_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_refresh_token_invalid_account( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test refresh token with valid token but invalid account. """ @@ -791,16 +810,15 @@ class TestAccountService: token_pair = AccountService.login(account) # Delete account - from extensions.ext_database import db - db.session.delete(account) - db.session.commit() + db_session_with_containers.delete(account) + db_session_with_containers.commit() # Try to refresh token with deleted account with pytest.raises(ValueError, match="Invalid account"): AccountService.refresh_token(token_pair.refresh_token) - def test_load_user_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_load_user_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test loading user by ID successfully. """ @@ -830,7 +848,7 @@ class TestAccountService: assert loaded_user.id == account.id assert loaded_user.email == account.email - def test_load_user_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_load_user_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test loading non-existent user. """ @@ -839,7 +857,7 @@ class TestAccountService: loaded_user = AccountService.load_user(non_existent_user_id) assert loaded_user is None - def test_load_user_banned_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_load_user_banned_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test loading banned user raises Unauthorized. """ @@ -861,14 +879,13 @@ class TestAccountService: # Ban the account account.status = AccountStatus.BANNED - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(Unauthorized): # Unauthorized exception AccountService.load_user(account.id) - def test_get_account_jwt_token(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_account_jwt_token(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test JWT token generation for account. """ @@ -902,7 +919,7 @@ class TestAccountService: assert call_args["iss"] is not None assert call_args["sub"] == "Console API Passport" - def test_load_logged_in_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_load_logged_in_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test loading logged in account by ID. """ @@ -931,7 +948,9 @@ class TestAccountService: assert loaded_account is not None assert loaded_account.id == account.id - def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting user through email successfully. """ @@ -957,7 +976,9 @@ class TestAccountService: assert found_user is not None assert found_user.id == account.id - def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting user through non-existent email. """ @@ -968,7 +989,7 @@ class TestAccountService: assert found_user is None def test_get_user_through_email_banned_account( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting banned user through email raises Unauthorized. @@ -991,14 +1012,15 @@ class TestAccountService: # Ban the account account.status = AccountStatus.BANNED - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(Unauthorized): # Unauthorized exception AccountService.get_user_through_email(email) - def test_get_user_through_email_in_freeze(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_in_freeze( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting user through email that is in freeze period. """ @@ -1014,7 +1036,7 @@ class TestAccountService: # Reset config dify_config.BILLING_ENABLED = False - def test_delete_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account deletion (should add task to queue and sync to enterprise). """ @@ -1050,7 +1072,7 @@ class TestAccountService: mock_delete_task.delay.assert_called_once_with(account.id) def test_generate_account_deletion_verification_code( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generating account deletion verification code. @@ -1079,7 +1101,9 @@ class TestAccountService: assert len(code) == 6 assert code.isdigit() - def test_verify_account_deletion_code_valid(self, db_session_with_containers, mock_external_service_dependencies): + def test_verify_account_deletion_code_valid( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test verifying valid account deletion code. """ @@ -1106,7 +1130,9 @@ class TestAccountService: is_valid = AccountService.verify_account_deletion_code(token, code) assert is_valid is True - def test_verify_account_deletion_code_invalid(self, db_session_with_containers, mock_external_service_dependencies): + def test_verify_account_deletion_code_invalid( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test verifying invalid account deletion code. """ @@ -1135,7 +1161,7 @@ class TestAccountService: assert is_valid is False def test_verify_account_deletion_code_invalid_token( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test verifying account deletion code with invalid token. @@ -1167,7 +1193,7 @@ class TestTenantService: "billing_service": mock_billing_service, } - def test_create_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tenant creation with default settings. """ @@ -1187,7 +1213,7 @@ class TestTenantService: assert tenant.encrypt_public_key is not None def test_create_tenant_workspace_creation_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant creation when workspace creation is disabled. @@ -1202,7 +1228,9 @@ class TestTenantService: with pytest.raises(NotAllowedCreateWorkspace): # NotAllowedCreateWorkspace exception TenantService.create_tenant(name=tenant_name) - def test_create_tenant_with_custom_name(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_with_custom_name( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tenant creation with custom name and setup flag. """ @@ -1221,7 +1249,9 @@ class TestTenantService: assert tenant.status == "normal" assert tenant.encrypt_public_key is not None - def test_create_tenant_member_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_member_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful tenant member creation. """ @@ -1251,7 +1281,9 @@ class TestTenantService: assert tenant_member.account_id == account.id assert tenant_member.role == "admin" - def test_create_tenant_member_duplicate_owner(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_member_duplicate_owner( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test creating duplicate owner for a tenant (should fail). """ @@ -1290,7 +1322,9 @@ class TestTenantService: with pytest.raises(Exception, match="Tenant already has an owner"): TenantService.create_tenant_member(tenant, account2, role="owner") - def test_create_tenant_member_existing_member(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_member_existing_member( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test updating role for existing tenant member. """ @@ -1323,7 +1357,7 @@ class TestTenantService: assert tenant_member2.account_id == tenant_member1.account_id assert tenant_member2.role == "editor" - def test_get_join_tenants_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_join_tenants_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting join tenants for an account. """ @@ -1361,7 +1395,7 @@ class TestTenantService: assert tenant2_name in tenant_names def test_get_current_tenant_by_account_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting current tenant by account successfully. @@ -1388,9 +1422,8 @@ class TestTenantService: # Add account to tenant and set as current TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Get current tenant current_tenant = TenantService.get_current_tenant_by_account(account) @@ -1400,7 +1433,7 @@ class TestTenantService: assert current_tenant.role == "owner" def test_get_current_tenant_by_account_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting current tenant when account has no current tenant. @@ -1426,7 +1459,7 @@ class TestTenantService: with pytest.raises((AttributeError, TenantNotFoundError)): TenantService.get_current_tenant_by_account(account) - def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_switch_tenant_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tenant switching. """ @@ -1457,18 +1490,17 @@ class TestTenantService: # Set initial current tenant account.current_tenant = tenant1 - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Switch to second tenant TenantService.switch_tenant(account, tenant2.id) # Verify tenant was switched - db.session.refresh(account) + db_session_with_containers.refresh(account) assert account.current_tenant_id == tenant2.id - def test_switch_tenant_no_tenant_id(self, db_session_with_containers, mock_external_service_dependencies): + def test_switch_tenant_no_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test tenant switching without providing tenant ID. """ @@ -1493,7 +1525,9 @@ class TestTenantService: with pytest.raises(ValueError, match="Tenant ID must be provided"): TenantService.switch_tenant(account, None) - def test_switch_tenant_account_not_member(self, db_session_with_containers, mock_external_service_dependencies): + def test_switch_tenant_account_not_member( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test switching to a tenant where account is not a member. """ @@ -1520,7 +1554,7 @@ class TestTenantService: with pytest.raises(Exception, match="Tenant not found or account is not a member of the tenant"): TenantService.switch_tenant(account, tenant.id) - def test_has_roles_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_has_roles_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test checking if tenant has specific roles. """ @@ -1570,7 +1604,7 @@ class TestTenantService: has_normal = TenantService.has_roles(tenant, [TenantAccountRole.NORMAL]) assert has_normal is False - def test_has_roles_invalid_role_type(self, db_session_with_containers, mock_external_service_dependencies): + def test_has_roles_invalid_role_type(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test checking roles with invalid role type. """ @@ -1589,7 +1623,7 @@ class TestTenantService: with pytest.raises(ValueError, match="all roles must be TenantAccountRole"): TenantService.has_roles(tenant, [invalid_role]) - def test_get_user_role_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_role_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting user role in a tenant. """ @@ -1620,7 +1654,9 @@ class TestTenantService: assert user_role == "editor" - def test_check_member_permission_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_member_permission_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test checking member permission successfully. """ @@ -1660,7 +1696,7 @@ class TestTenantService: TenantService.check_member_permission(tenant, owner_account, member_account, "add") def test_check_member_permission_invalid_action( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test checking member permission with invalid action. @@ -1692,7 +1728,9 @@ class TestTenantService: with pytest.raises(Exception, match="Invalid action"): TenantService.check_member_permission(tenant, account, None, invalid_action) - def test_check_member_permission_operate_self(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_member_permission_operate_self( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test checking member permission when trying to operate self. """ @@ -1722,7 +1760,9 @@ class TestTenantService: with pytest.raises(Exception, match="Cannot operate self"): TenantService.check_member_permission(tenant, account, account, "remove") - def test_remove_member_from_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_remove_member_from_tenant_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful member removal from tenant (should sync to enterprise). """ @@ -1770,16 +1810,17 @@ class TestTenantService: ) # Verify member was removed - from extensions.ext_database import db from models.account import TenantAccountJoin member_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=member_account.id) + .first() ) assert member_join is None def test_remove_member_from_tenant_operate_self( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test removing member when trying to operate self. @@ -1810,7 +1851,9 @@ class TestTenantService: with pytest.raises(Exception, match="Cannot operate self"): TenantService.remove_member_from_tenant(tenant, account, account) - def test_remove_member_from_tenant_not_member(self, db_session_with_containers, mock_external_service_dependencies): + def test_remove_member_from_tenant_not_member( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test removing member who is not in the tenant. """ @@ -1849,7 +1892,7 @@ class TestTenantService: with pytest.raises(Exception, match="Member not in tenant"): TenantService.remove_member_from_tenant(tenant, non_member_account, owner_account) - def test_update_member_role_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_member_role_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful member role update. """ @@ -1889,15 +1932,16 @@ class TestTenantService: TenantService.update_member_role(tenant, member_account, "admin", owner_account) # Verify role was updated - from extensions.ext_database import db from models.account import TenantAccountJoin member_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=member_account.id) + .first() ) assert member_join.role == "admin" - def test_update_member_role_to_owner(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_member_role_to_owner(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test updating member role to owner (should change current owner to admin). """ @@ -1937,19 +1981,24 @@ class TestTenantService: TenantService.update_member_role(tenant, member_account, "owner", owner_account) # Verify roles were updated correctly - from extensions.ext_database import db from models.account import TenantAccountJoin owner_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=owner_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=owner_account.id) + .first() ) member_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=member_account.id) + .first() ) assert owner_join.role == "admin" assert member_join.role == "owner" - def test_update_member_role_already_assigned(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_member_role_already_assigned( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test updating member role to already assigned role. """ @@ -1989,7 +2038,7 @@ class TestTenantService: with pytest.raises(Exception, match="The provided role is already assigned to the member"): TenantService.update_member_role(tenant, member_account, "admin", owner_account) - def test_get_tenant_count_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tenant_count_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting tenant count successfully. """ @@ -2014,7 +2063,7 @@ class TestTenantService: assert tenant_count >= 3 def test_create_owner_tenant_if_not_exist_new_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating owner tenant for new user without existing tenants. @@ -2044,17 +2093,16 @@ class TestTenantService: TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" assert account.current_tenant is not None assert account.current_tenant.name == workspace_name def test_create_owner_tenant_if_not_exist_existing_tenant( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating owner tenant when user already has a tenant. @@ -2083,20 +2131,19 @@ class TestTenantService: existing_tenant = TenantService.create_tenant(name=existing_tenant_name) TenantService.create_tenant_member(existing_tenant, account, role="owner") account.current_tenant = existing_tenant - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Try to create owner tenant again (should not create new one) TenantService.create_owner_tenant_if_not_exist(account, name=new_workspace_name) # Verify no new tenant was created - tenant_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).all() + tenant_joins = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).all() assert len(tenant_joins) == 1 assert account.current_tenant.id == existing_tenant.id def test_create_owner_tenant_if_not_exist_workspace_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating owner tenant when workspace creation is disabled. @@ -2123,7 +2170,7 @@ class TestTenantService: with pytest.raises(WorkSpaceNotAllowedCreateError): # WorkSpaceNotAllowedCreateError exception TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) - def test_get_tenant_members_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tenant_members_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting tenant members successfully. """ @@ -2187,7 +2234,9 @@ class TestTenantService: elif member.email == normal_email: assert member.role == "normal" - def test_get_dataset_operator_members_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_dataset_operator_members_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting dataset operator members successfully. """ @@ -2240,7 +2289,7 @@ class TestTenantService: assert dataset_operators[0].email == operator_email assert dataset_operators[0].role == "dataset_operator" - def test_get_custom_config_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_custom_config_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting custom config successfully. """ @@ -2259,9 +2308,8 @@ class TestTenantService: # Set custom config custom_config = {"theme": theme, "language": language, "feature_flags": {"beta": True}} tenant.custom_config_dict = custom_config - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Get custom config retrieved_config = TenantService.get_custom_config(tenant.id) @@ -2296,7 +2344,7 @@ class TestRegisterService: "passport_service": mock_passport_service, } - def test_setup_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_setup_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful system setup with account creation and tenant setup. """ @@ -2309,11 +2357,10 @@ class TestRegisterService: mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False - from extensions.ext_database import db from models.model import DifySetup - db.session.query(DifySetup).delete() - db.session.commit() + db_session_with_containers.query(DifySetup).delete() + db_session_with_containers.commit() # Execute setup RegisterService.setup( @@ -2327,7 +2374,7 @@ class TestRegisterService: # Verify account was created from models import Account - account = db.session.query(Account).filter_by(email=admin_email).first() + account = db_session_with_containers.query(Account).filter_by(email=admin_email).first() assert account is not None assert account.name == admin_name assert account.last_login_ip == ip_address @@ -2335,17 +2382,17 @@ class TestRegisterService: assert account.status == "active" # Verify DifySetup was created - dify_setup = db.session.query(DifySetup).first() + dify_setup = db_session_with_containers.query(DifySetup).first() assert dify_setup is not None # Verify tenant was created and linked from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" - def test_setup_failure_rollback(self, db_session_with_containers, mock_external_service_dependencies): + def test_setup_failure_rollback(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test setup failure with proper rollback of all created entities. """ @@ -2373,21 +2420,20 @@ class TestRegisterService: ) # Verify no entities were created (rollback worked) - from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin from models.model import DifySetup - account = db.session.query(Account).filter_by(email=admin_email).first() - tenant_count = db.session.query(Tenant).count() - tenant_join_count = db.session.query(TenantAccountJoin).count() - dify_setup_count = db.session.query(DifySetup).count() + account = db_session_with_containers.query(Account).filter_by(email=admin_email).first() + tenant_count = db_session_with_containers.query(Tenant).count() + tenant_join_count = db_session_with_containers.query(TenantAccountJoin).count() + dify_setup_count = db_session_with_containers.query(DifySetup).count() assert account is None assert tenant_count == 0 assert tenant_join_count == 0 assert dify_setup_count == 0 - def test_register_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful account registration with workspace creation. """ @@ -2421,16 +2467,15 @@ class TestRegisterService: assert account.initialized_at is not None # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" assert account.current_tenant is not None assert account.current_tenant.name == f"{name}'s Workspace" - def test_register_with_oauth(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_with_oauth(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account registration with OAuth integration. """ @@ -2467,14 +2512,19 @@ class TestRegisterService: assert account.initialized_at is not None # Verify OAuth integration was created - from extensions.ext_database import db from models import AccountIntegrate - integration = db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first() + integration = ( + db_session_with_containers.query(AccountIntegrate) + .filter_by(account_id=account.id, provider=provider) + .first() + ) assert integration is not None assert integration.open_id == open_id - def test_register_with_pending_status(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_with_pending_status( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account registration with pending status. """ @@ -2511,14 +2561,15 @@ class TestRegisterService: assert account.initialized_at is not None # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" - def test_register_workspace_creation_disabled(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_workspace_creation_disabled( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account registration when workspace creation is disabled. """ @@ -2549,13 +2600,14 @@ class TestRegisterService: assert account.initialized_at is not None # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is None - def test_register_workspace_limit_exceeded(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_workspace_limit_exceeded( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account registration when workspace limit is exceeded. """ @@ -2589,13 +2641,12 @@ class TestRegisterService: assert account.initialized_at is not None # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is None - def test_register_without_workspace(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_without_workspace(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account registration without workspace creation. """ @@ -2624,13 +2675,14 @@ class TestRegisterService: assert account.initialized_at is not None # Verify no tenant was created - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is None - def test_invite_new_member_new_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_invite_new_member_new_account( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test inviting a new member who doesn't have an account yet. """ @@ -2682,22 +2734,25 @@ class TestRegisterService: mock_send_mail.delay.assert_called_once() # Verify new account was created with pending status - from extensions.ext_database import db from models import Account, TenantAccountJoin - new_account = db.session.query(Account).filter_by(email=new_member_email).first() + new_account = db_session_with_containers.query(Account).filter_by(email=new_member_email).first() assert new_account is not None assert new_account.name == new_member_email.split("@")[0] # Default name from email assert new_account.status == "pending" # Verify tenant member was created tenant_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=new_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=new_account.id) + .first() ) assert tenant_join is not None assert tenant_join.role == "normal" - def test_invite_new_member_existing_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_invite_new_member_existing_account( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test inviting an existing member who is not in the tenant yet. """ @@ -2749,16 +2804,19 @@ class TestRegisterService: mock_send_mail.delay.assert_not_called() # Verify tenant member was created for existing account - from extensions.ext_database import db from models.account import TenantAccountJoin tenant_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=existing_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=existing_account.id) + .first() ) assert tenant_join is not None assert tenant_join.role == "admin" - def test_invite_new_member_existing_member(self, db_session_with_containers, mock_external_service_dependencies): + def test_invite_new_member_existing_member( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test inviting a member who is already in the tenant with pending status. """ @@ -2793,9 +2851,8 @@ class TestRegisterService: password=existing_pending_member_password, ) existing_account.status = "pending" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Add existing account to tenant TenantService.create_tenant_member(tenant, existing_account, role="normal") @@ -2820,7 +2877,9 @@ class TestRegisterService: # Verify email task was called mock_send_mail.delay.assert_called_once() - def test_invite_new_member_no_inviter(self, db_session_with_containers, mock_external_service_dependencies): + def test_invite_new_member_no_inviter( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test inviting a member without providing an inviter. """ @@ -2846,7 +2905,7 @@ class TestRegisterService: ) def test_invite_new_member_account_already_in_tenant( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test inviting a member who is already in the tenant with active status. @@ -2882,9 +2941,8 @@ class TestRegisterService: password=already_in_tenant_password, ) existing_account.status = "active" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Add existing account to tenant TenantService.create_tenant_member(tenant, existing_account, role="normal") @@ -2899,7 +2957,9 @@ class TestRegisterService: inviter=inviter, ) - def test_generate_invite_token_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_invite_token_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation of invite token. """ @@ -2943,7 +3003,7 @@ class TestRegisterService: assert invitation_data["email"] == account.email assert invitation_data["workspace_id"] == tenant.id - def test_is_valid_invite_token_valid(self, db_session_with_containers, mock_external_service_dependencies): + def test_is_valid_invite_token_valid(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test validation of valid invite token. """ @@ -2974,7 +3034,9 @@ class TestRegisterService: # Verify token is valid assert is_valid is True - def test_is_valid_invite_token_invalid(self, db_session_with_containers, mock_external_service_dependencies): + def test_is_valid_invite_token_invalid( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation of invalid invite token. """ @@ -2987,7 +3049,7 @@ class TestRegisterService: assert is_valid is False def test_revoke_token_with_workspace_and_email( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test revoking token with workspace ID and email. @@ -3030,7 +3092,7 @@ class TestRegisterService: assert redis_client.get(token_key) is not None def test_revoke_token_without_workspace_and_email( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test revoking token without workspace ID and email. @@ -3073,7 +3135,7 @@ class TestRegisterService: assert redis_client.get(token_key) is None def test_get_invitation_if_token_valid_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with valid token. @@ -3122,7 +3184,7 @@ class TestRegisterService: assert result["data"]["workspace_id"] == tenant.id def test_get_invitation_if_token_valid_invalid_token( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with invalid token. @@ -3142,7 +3204,7 @@ class TestRegisterService: assert result is None def test_get_invitation_if_token_valid_invalid_tenant( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with invalid tenant. @@ -3192,7 +3254,7 @@ class TestRegisterService: redis_client.delete(token_key) def test_get_invitation_if_token_valid_account_mismatch( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with account ID mismatch. @@ -3242,7 +3304,7 @@ class TestRegisterService: redis_client.delete(token_key) def test_get_invitation_if_token_valid_tenant_not_normal( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with tenant not in normal status. @@ -3269,9 +3331,8 @@ class TestRegisterService: # Change tenant status to non-normal tenant.status = "suspended" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Create a real token from extensions.ext_redis import redis_client @@ -3300,7 +3361,7 @@ class TestRegisterService: redis_client.delete(token_key) def test_get_invitation_by_token_with_workspace_and_email( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation by token with workspace ID and email. @@ -3339,7 +3400,7 @@ class TestRegisterService: redis_client.delete(cache_key) def test_get_invitation_by_token_without_workspace_and_email( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation by token without workspace ID and email. @@ -3372,7 +3433,7 @@ class TestRegisterService: # Clean up redis_client.delete(token_key) - def test_get_invitation_token_key(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_invitation_token_key(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting invitation token key. """ diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index e7cc140582..45839fd463 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, create_autospec, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.plugin.impl.exc import PluginDaemonClientSideError from models import Account @@ -87,7 +88,7 @@ class TestAgentService: "account_feature_service": mock_account_feature_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -133,13 +134,12 @@ class TestAgentService: # Update the app model config to set agent_mode for agent-chat mode if app.mode == "agent-chat" and app.app_model_config: app.app_model_config.agent_mode = json.dumps({"enabled": True, "strategy": "react", "tools": []}) - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() return app, account - def _create_test_conversation_and_message(self, db_session_with_containers, app, account): + def _create_test_conversation_and_message(self, db_session_with_containers: Session, app, account): """ Helper method to create a test conversation and message with agent thoughts. @@ -153,8 +153,6 @@ class TestAgentService: """ fake = Faker() - from extensions.ext_database import db - # Create conversation conversation = Conversation( id=fake.uuid4(), @@ -167,8 +165,8 @@ class TestAgentService: mode="chat", from_source="api", ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create app model config app_model_config = AppModelConfig( @@ -180,12 +178,12 @@ class TestAgentService: agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), ) app_model_config.id = fake.uuid4() - db.session.add(app_model_config) - db.session.commit() + db_session_with_containers.add(app_model_config) + db_session_with_containers.commit() # Update conversation with app model config conversation.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() # Create message message = Message( @@ -206,12 +204,12 @@ class TestAgentService: currency="USD", from_source="api", ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return conversation, message - def _create_test_agent_thoughts(self, db_session_with_containers, message): + def _create_test_agent_thoughts(self, db_session_with_containers: Session, message): """ Helper method to create test agent thoughts for a message. @@ -224,8 +222,6 @@ class TestAgentService: """ fake = Faker() - from extensions.ext_database import db - agent_thoughts = [] # Create first agent thought @@ -251,7 +247,7 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(thought1) + db_session_with_containers.add(thought1) agent_thoughts.append(thought1) # Create second agent thought @@ -277,14 +273,14 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(thought2) + db_session_with_containers.add(thought2) agent_thoughts.append(thought2) - db.session.commit() + db_session_with_containers.commit() return agent_thoughts - def test_get_agent_logs_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of agent logs with complete data. """ @@ -344,7 +340,7 @@ class TestAgentService: assert dataset_tool_call["tool_icon"] == "" # dataset-retrieval tools have empty icon def test_get_agent_logs_conversation_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when conversation is not found. @@ -358,7 +354,9 @@ class TestAgentService: with pytest.raises(ValueError, match="Conversation not found"): AgentService.get_agent_logs(app, fake.uuid4(), fake.uuid4()) - def test_get_agent_logs_message_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_message_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when message is not found. """ @@ -372,7 +370,9 @@ class TestAgentService: with pytest.raises(ValueError, match="Message not found"): AgentService.get_agent_logs(app, str(conversation.id), fake.uuid4()) - def test_get_agent_logs_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_end_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval when conversation is from end user. """ @@ -381,8 +381,6 @@ class TestAgentService: # Create test data app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create end user end_user = EndUser( id=fake.uuid4(), @@ -393,8 +391,8 @@ class TestAgentService: session_id=fake.uuid4(), name=fake.name(), ) - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() # Create conversation with end user conversation = Conversation( @@ -408,8 +406,8 @@ class TestAgentService: mode="chat", from_source="api", ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create app model config app_model_config = AppModelConfig( @@ -421,12 +419,12 @@ class TestAgentService: agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), ) app_model_config.id = fake.uuid4() - db.session.add(app_model_config) - db.session.commit() + db_session_with_containers.add(app_model_config) + db_session_with_containers.commit() # Update conversation with app model config conversation.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() # Create message message = Message( @@ -447,8 +445,8 @@ class TestAgentService: currency="USD", from_source="api", ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -457,7 +455,9 @@ class TestAgentService: assert result is not None assert result["meta"]["executor"] == end_user.name - def test_get_agent_logs_with_unknown_executor(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_unknown_executor( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval when executor is unknown. """ @@ -466,8 +466,6 @@ class TestAgentService: # Create test data app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create conversation with non-existent account conversation = Conversation( id=fake.uuid4(), @@ -480,8 +478,8 @@ class TestAgentService: mode="chat", from_source="api", ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create app model config app_model_config = AppModelConfig( @@ -493,12 +491,12 @@ class TestAgentService: agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), ) app_model_config.id = fake.uuid4() - db.session.add(app_model_config) - db.session.commit() + db_session_with_containers.add(app_model_config) + db_session_with_containers.commit() # Update conversation with app model config conversation.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() # Create message message = Message( @@ -519,8 +517,8 @@ class TestAgentService: currency="USD", from_source="api", ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -529,7 +527,9 @@ class TestAgentService: assert result is not None assert result["meta"]["executor"] == "Unknown" - def test_get_agent_logs_with_tool_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_tool_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval with tool errors. """ @@ -539,8 +539,6 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from extensions.ext_database import db - # Create agent thought with tool error thought_with_error = MessageAgentThought( message_id=message.id, @@ -564,8 +562,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(thought_with_error) - db.session.commit() + db_session_with_containers.add(thought_with_error) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -580,7 +578,7 @@ class TestAgentService: assert tool_call["error"] == "Tool execution failed" def test_get_agent_logs_without_agent_thoughts( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test agent logs retrieval when message has no agent thoughts. @@ -600,7 +598,7 @@ class TestAgentService: assert len(result["iterations"]) == 0 def test_get_agent_logs_app_model_config_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when app model config is not found. @@ -610,11 +608,9 @@ class TestAgentService: # Create test data app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Remove app model config to test error handling app.app_model_config_id = None - db.session.commit() + db_session_with_containers.commit() # Create conversation without app model config conversation = Conversation( @@ -629,8 +625,8 @@ class TestAgentService: from_source="api", app_model_config_id=None, # Explicitly set to None ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create message message = Message( @@ -651,15 +647,15 @@ class TestAgentService: currency="USD", from_source="api", ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() # Execute the method under test with pytest.raises(ValueError, match="App model config not found"): AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) def test_get_agent_logs_agent_config_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when agent config is not found. @@ -677,7 +673,9 @@ class TestAgentService: with pytest.raises(ValueError, match="Agent config not found"): AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) - def test_list_agent_providers_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_list_agent_providers_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful listing of agent providers. """ @@ -698,7 +696,7 @@ class TestAgentService: mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value mock_plugin_client.fetch_agent_strategy_providers.assert_called_once_with(str(app.tenant_id)) - def test_get_agent_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of specific agent provider. """ @@ -720,7 +718,9 @@ class TestAgentService: mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value mock_plugin_client.fetch_agent_strategy_provider.assert_called_once_with(str(app.tenant_id), provider_name) - def test_get_agent_provider_plugin_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_provider_plugin_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when plugin daemon client raises an error. """ @@ -741,7 +741,7 @@ class TestAgentService: AgentService.get_agent_provider(str(account.id), str(app.tenant_id), provider_name) def test_get_agent_logs_with_complex_tool_data( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test agent logs retrieval with complex tool data and multiple tools. @@ -752,8 +752,6 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from extensions.ext_database import db - # Create agent thought with multiple tools complex_thought = MessageAgentThought( message_id=message.id, @@ -799,8 +797,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(complex_thought) - db.session.commit() + db_session_with_containers.add(complex_thought) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -831,7 +829,7 @@ class TestAgentService: assert tool_calls[2]["status"] == "success" assert tool_calls[2]["tool_icon"] == "" # dataset-retrieval tools have empty icon - def test_get_agent_logs_with_files(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_files(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test agent logs retrieval with message files and agent thought files. """ @@ -841,8 +839,7 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from core.workflow.file import FileTransferMethod, FileType - from extensions.ext_database import db + from dify_graph.file import FileTransferMethod, FileType from models.enums import CreatorUserRole # Add files to message @@ -867,9 +864,9 @@ class TestAgentService: created_by_role=CreatorUserRole.ACCOUNT, created_by=message.from_account_id, ) - db.session.add(message_file1) - db.session.add(message_file2) - db.session.commit() + db_session_with_containers.add(message_file1) + db_session_with_containers.add(message_file2) + db_session_with_containers.commit() # Create agent thought with files thought_with_files = MessageAgentThought( @@ -895,8 +892,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(thought_with_files) - db.session.commit() + db_session_with_containers.add(thought_with_files) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -912,7 +909,7 @@ class TestAgentService: assert "file2" in iterations[0]["files"] def test_get_agent_logs_with_different_timezone( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test agent logs retrieval with different timezone settings. @@ -938,7 +935,9 @@ class TestAgentService: assert "T" in start_time # ISO format assert "+08:00" in start_time or "Z" in start_time # Timezone offset - def test_get_agent_logs_with_empty_tool_data(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_empty_tool_data( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval with empty tool data. """ @@ -948,8 +947,6 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from extensions.ext_database import db - # Create agent thought with empty tool data empty_thought = MessageAgentThought( message_id=message.id, @@ -964,8 +961,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(empty_thought) - db.session.commit() + db_session_with_containers.add(empty_thought) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -979,7 +976,9 @@ class TestAgentService: tool_calls = iterations[0]["tool_calls"] assert len(tool_calls) == 0 # No tools to process - def test_get_agent_logs_with_malformed_json(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_malformed_json( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval with malformed JSON data in tool fields. """ @@ -989,8 +988,6 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from extensions.ext_database import db - # Create agent thought with malformed JSON malformed_thought = MessageAgentThought( message_id=message.id, @@ -1005,8 +1002,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(malformed_thought) - db.session.commit() + db_session_with_containers.add(malformed_thought) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index 4f5190e533..004d643955 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from models import Account @@ -52,7 +53,7 @@ class TestAnnotationService: "current_user": mock_user, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -115,11 +116,10 @@ class TestAnnotationService: tenant_id, ) - def _create_test_conversation(self, app, account, fake): + def _create_test_conversation(self, db_session_with_containers: Session, app, account, fake): """ Helper method to create a test conversation with all required fields. """ - from extensions.ext_database import db from models.model import Conversation conversation = Conversation( @@ -141,17 +141,16 @@ class TestAnnotationService: from_account_id=account.id, ) - db.session.add(conversation) - db.session.flush() + db_session_with_containers.add(conversation) + db_session_with_containers.flush() return conversation - def _create_test_message(self, app, conversation, account, fake): + def _create_test_message(self, db_session_with_containers: Session, app, conversation, account, fake): """ Helper method to create a test message with all required fields. """ import json - from extensions.ext_database import db from models.model import Message message = Message( @@ -180,12 +179,12 @@ class TestAnnotationService: from_account_id=account.id, ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return message def test_insert_app_annotation_directly_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful direct insertion of app annotation. @@ -211,9 +210,8 @@ class TestAnnotationService: assert annotation.id is not None # Verify annotation was saved to database - from extensions.ext_database import db - db.session.refresh(annotation) + db_session_with_containers.refresh(annotation) assert annotation.id is not None # Verify add_annotation_to_index_task was called (when annotation setting exists) @@ -221,7 +219,7 @@ class TestAnnotationService: mock_external_service_dependencies["add_task"].delay.assert_not_called() def test_insert_app_annotation_directly_requires_question( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Question must be provided when inserting annotations directly. @@ -238,7 +236,7 @@ class TestAnnotationService: AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) def test_insert_app_annotation_directly_app_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test direct insertion of app annotation when app is not found. @@ -260,7 +258,7 @@ class TestAnnotationService: AppAnnotationService.insert_app_annotation_directly(annotation_args, non_existent_app_id) def test_update_app_annotation_directly_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful direct update of app annotation. @@ -298,7 +296,7 @@ class TestAnnotationService: mock_external_service_dependencies["update_task"].delay.assert_not_called() def test_up_insert_app_annotation_from_message_new( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating new annotation from message. @@ -307,8 +305,8 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message first - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Setup annotation data with message_id annotation_args = { @@ -333,7 +331,7 @@ class TestAnnotationService: mock_external_service_dependencies["add_task"].delay.assert_not_called() def test_up_insert_app_annotation_from_message_update( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test updating existing annotation from message. @@ -342,8 +340,8 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message first - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create initial annotation initial_args = { @@ -373,7 +371,7 @@ class TestAnnotationService: mock_external_service_dependencies["add_task"].delay.assert_not_called() def test_up_insert_app_annotation_from_message_app_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating annotation from message when app is not found. @@ -395,7 +393,7 @@ class TestAnnotationService: AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, non_existent_app_id) def test_get_annotation_list_by_app_id_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of annotation list by app ID. @@ -428,7 +426,7 @@ class TestAnnotationService: assert annotation.account_id == account.id def test_get_annotation_list_by_app_id_with_keyword( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test retrieval of annotation list with keyword search. @@ -462,7 +460,7 @@ class TestAnnotationService: assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content def test_get_annotation_list_by_app_id_with_special_characters_in_keyword( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): r""" Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention. @@ -534,7 +532,7 @@ class TestAnnotationService: assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list) def test_get_annotation_list_by_app_id_app_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test retrieval of annotation list when app is not found. @@ -549,7 +547,9 @@ class TestAnnotationService: with pytest.raises(NotFound, match="App not found"): AppAnnotationService.get_annotation_list_by_app_id(non_existent_app_id, page=1, limit=10, keyword="") - def test_delete_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_app_annotation_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful deletion of app annotation. """ @@ -568,16 +568,19 @@ class TestAnnotationService: AppAnnotationService.delete_app_annotation(app.id, annotation_id) # Verify annotation was deleted - from extensions.ext_database import db - deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + deleted_annotation = ( + db_session_with_containers.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + ) assert deleted_annotation is None # Verify delete_annotation_index_task was called (when annotation setting exists) # Note: In this test, no annotation setting exists, so task should not be called mock_external_service_dependencies["delete_task"].delay.assert_not_called() - def test_delete_app_annotation_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_app_annotation_app_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test deletion of app annotation when app is not found. """ @@ -593,7 +596,7 @@ class TestAnnotationService: AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id) def test_delete_app_annotation_annotation_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test deletion of app annotation when annotation is not found. @@ -606,7 +609,9 @@ class TestAnnotationService: with pytest.raises(NotFound, match="Annotation not found"): AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id) - def test_enable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_app_annotation_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful enabling of app annotation. """ @@ -632,7 +637,9 @@ class TestAnnotationService: # Verify task was called mock_external_service_dependencies["enable_task"].delay.assert_called_once() - def test_disable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_disable_app_annotation_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful disabling of app annotation. """ @@ -651,7 +658,9 @@ class TestAnnotationService: # Verify task was called mock_external_service_dependencies["disable_task"].delay.assert_called_once() - def test_enable_app_annotation_cached_job(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_app_annotation_cached_job( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test enabling app annotation when job is already cached. """ @@ -685,7 +694,9 @@ class TestAnnotationService: # Clean up redis_client.delete(enable_app_annotation_key) - def test_get_annotation_hit_histories_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_annotation_hit_histories_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of annotation hit histories. """ @@ -728,7 +739,9 @@ class TestAnnotationService: assert history.app_id == app.id assert history.account_id == account.id - def test_add_annotation_history_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_add_annotation_history_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful addition of annotation history. """ @@ -763,16 +776,15 @@ class TestAnnotationService: ) # Verify hit count was incremented - from extensions.ext_database import db - db.session.refresh(annotation) + db_session_with_containers.refresh(annotation) assert annotation.hit_count == initial_hit_count + 1 # Verify history was created from models.model import AppAnnotationHitHistory history = ( - db.session.query(AppAnnotationHitHistory) + db_session_with_containers.query(AppAnnotationHitHistory) .where( AppAnnotationHitHistory.annotation_id == annotation.id, AppAnnotationHitHistory.message_id == message_id ) @@ -786,7 +798,9 @@ class TestAnnotationService: assert history.score == score assert history.source == "console" - def test_get_annotation_by_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_annotation_by_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of annotation by ID. """ @@ -811,7 +825,9 @@ class TestAnnotationService: assert retrieved_annotation.content == annotation_args["answer"] assert retrieved_annotation.account_id == account.id - def test_batch_import_app_annotations_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_batch_import_app_annotations_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful batch import of app annotations. """ @@ -854,7 +870,7 @@ class TestAnnotationService: mock_external_service_dependencies["batch_import_task"].delay.assert_called_once() def test_batch_import_app_annotations_empty_file( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test batch import with empty CSV file. @@ -889,7 +905,7 @@ class TestAnnotationService: assert "empty" in result["error_msg"].lower() def test_batch_import_app_annotations_quota_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test batch import when quota is exceeded. @@ -935,7 +951,7 @@ class TestAnnotationService: assert "limit" in result["error_msg"].lower() def test_get_app_annotation_setting_by_app_id_enabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting enabled app annotation setting by app ID. @@ -944,7 +960,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -956,8 +971,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -967,8 +982,8 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Get annotation setting result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) @@ -981,7 +996,7 @@ class TestAnnotationService: assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" def test_get_app_annotation_setting_by_app_id_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting disabled app annotation setting by app ID. @@ -996,7 +1011,7 @@ class TestAnnotationService: assert result["enabled"] is False def test_update_app_annotation_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful update of app annotation setting. @@ -1005,7 +1020,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1017,8 +1031,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1028,8 +1042,8 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Update annotation setting update_args = { @@ -1046,11 +1060,11 @@ class TestAnnotationService: assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" # Verify database was updated - db.session.refresh(annotation_setting) + db_session_with_containers.refresh(annotation_setting) assert annotation_setting.score_threshold == 0.9 def test_export_annotation_list_by_app_id_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful export of annotation list by app ID. @@ -1083,7 +1097,7 @@ class TestAnnotationService: assert annotation.created_at <= exported_annotations[i - 1].created_at def test_export_annotation_list_by_app_id_app_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test export of annotation list when app is not found. @@ -1099,7 +1113,7 @@ class TestAnnotationService: AppAnnotationService.export_annotation_list_by_app_id(non_existent_app_id) def test_insert_app_annotation_directly_with_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful direct insertion of app annotation with annotation setting enabled. @@ -1108,7 +1122,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1120,8 +1133,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1131,8 +1144,8 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Setup annotation data annotation_args = { @@ -1161,7 +1174,7 @@ class TestAnnotationService: assert call_args[4] == collection_binding.id # collection_binding_id def test_update_app_annotation_directly_with_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful direct update of app annotation with annotation setting enabled. @@ -1170,7 +1183,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1182,8 +1194,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1193,8 +1205,8 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # First, create an annotation original_args = { @@ -1234,7 +1246,7 @@ class TestAnnotationService: assert call_args[4] == collection_binding.id # collection_binding_id def test_delete_app_annotation_with_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful deletion of app annotation with annotation setting enabled. @@ -1243,7 +1255,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1255,8 +1266,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1267,8 +1278,8 @@ class TestAnnotationService: updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Create an annotation first annotation_args = { @@ -1285,7 +1296,9 @@ class TestAnnotationService: AppAnnotationService.delete_app_annotation(app.id, annotation_id) # Verify annotation was deleted - deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + deleted_annotation = ( + db_session_with_containers.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + ) assert deleted_annotation is None # Verify delete_annotation_index_task was called @@ -1297,7 +1310,7 @@ class TestAnnotationService: assert call_args[3] == collection_binding.id # collection_binding_id def test_up_insert_app_annotation_from_message_with_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating annotation from message with annotation setting enabled. @@ -1306,7 +1319,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1318,8 +1330,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1329,12 +1341,12 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Create a conversation and message first - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Setup annotation data with message_id annotation_args = { diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py index 8c8be2e670..b8bf8543bc 100644 --- a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models.api_based_extension import APIBasedExtension from services.account_service import AccountService, TenantService @@ -31,7 +32,7 @@ class TestAPIBasedExtensionService: "requestor_instance": mock_requestor_instance, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -61,7 +62,7 @@ class TestAPIBasedExtensionService: return account, tenant - def test_save_extension_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful saving of API-based extension. """ @@ -90,15 +91,16 @@ class TestAPIBasedExtensionService: assert saved_extension.created_at is not None # Verify extension was saved to database - from extensions.ext_database import db - db.session.refresh(saved_extension) + db_session_with_containers.refresh(saved_extension) assert saved_extension.id is not None # Verify ping connection was called mock_external_service_dependencies["requestor_instance"].request.assert_called_once() - def test_save_extension_validation_errors(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_validation_errors( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation errors when saving extension with invalid data. """ @@ -132,7 +134,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="api_key must not be empty"): APIBasedExtensionService.save(extension_data) - def test_get_all_by_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_all_by_tenant_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of all extensions by tenant ID. """ @@ -169,7 +173,7 @@ class TestAPIBasedExtensionService: # Verify descending order (newer first) assert extension.created_at <= extension_list[i - 1].created_at - def test_get_with_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_with_tenant_id_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of extension by tenant ID and extension ID. """ @@ -200,7 +204,9 @@ class TestAPIBasedExtensionService: assert retrieved_extension.api_key == extension_data.api_key # Should be decrypted assert retrieved_extension.created_at is not None - def test_get_with_tenant_id_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_with_tenant_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of extension when extension is not found. """ @@ -214,7 +220,7 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="API based extension is not found"): APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id) - def test_delete_extension_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful deletion of extension. """ @@ -238,12 +244,15 @@ class TestAPIBasedExtensionService: APIBasedExtensionService.delete(created_extension) # Verify extension was deleted - from extensions.ext_database import db - deleted_extension = db.session.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first() + deleted_extension = ( + db_session_with_containers.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first() + ) assert deleted_extension is None - def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_duplicate_name( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation error when saving extension with duplicate name. """ @@ -272,7 +281,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="name must be unique, it is already existed"): APIBasedExtensionService.save(extension_data2) - def test_save_extension_update_existing(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_update_existing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful update of existing extension. """ @@ -329,7 +340,9 @@ class TestAPIBasedExtensionService: assert retrieved_extension.api_endpoint == new_endpoint assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved - def test_save_extension_connection_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_connection_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test connection error when saving extension with invalid endpoint. """ @@ -356,7 +369,7 @@ class TestAPIBasedExtensionService: APIBasedExtensionService.save(extension_data) def test_save_extension_invalid_api_key_length( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test validation error when saving extension with API key that is too short. @@ -378,7 +391,7 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="api_key must be at least 5 characters"): APIBasedExtensionService.save(extension_data) - def test_save_extension_empty_fields(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_empty_fields(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test validation errors when saving extension with empty required fields. """ @@ -412,7 +425,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="api_key must not be empty"): APIBasedExtensionService.save(extension_data) - def test_get_all_by_tenant_id_empty_list(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_all_by_tenant_id_empty_list( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of extensions when no extensions exist for tenant. """ @@ -428,7 +443,9 @@ class TestAPIBasedExtensionService: assert len(extension_list) == 0 assert extension_list == [] - def test_save_extension_invalid_ping_response(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_invalid_ping_response( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation error when ping response is invalid. """ @@ -452,7 +469,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="{'result': 'invalid'}"): APIBasedExtensionService.save(extension_data) - def test_save_extension_missing_ping_result(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_missing_ping_result( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation error when ping response is missing result field. """ @@ -476,7 +495,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="{'status': 'ok'}"): APIBasedExtensionService.save(extension_data) - def test_get_with_tenant_id_wrong_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_with_tenant_id_wrong_tenant( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of extension when tenant ID doesn't match. """ diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 8544d23cdf..787a99f3e8 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -3,6 +3,7 @@ from unittest.mock import ANY, MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from models.model import EndUser @@ -118,7 +119,9 @@ class TestAppGenerateService: "global_dify_config": mock_global_dify_config, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies, mode="chat"): + def _create_test_app_and_account( + self, db_session_with_containers: Session, mock_external_service_dependencies, mode="chat" + ): """ Helper method to create a test app and account for testing. @@ -169,7 +172,7 @@ class TestAppGenerateService: return app, account - def _create_test_workflow(self, db_session_with_containers, app): + def _create_test_workflow(self, db_session_with_containers: Session, app): """ Helper method to create a test workflow for testing. @@ -191,14 +194,14 @@ class TestAppGenerateService: status="published", ) - from extensions.ext_database import db - - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() return workflow - def test_generate_completion_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_completion_mode_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation for completion mode app. """ @@ -226,7 +229,7 @@ class TestAppGenerateService: mock_external_service_dependencies["completion_generator"].return_value.generate.assert_called_once() mock_external_service_dependencies["completion_generator"].convert_to_event_stream.assert_called_once() - def test_generate_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_chat_mode_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful generation for chat mode app. """ @@ -250,7 +253,9 @@ class TestAppGenerateService: mock_external_service_dependencies["chat_generator"].return_value.generate.assert_called_once() mock_external_service_dependencies["chat_generator"].convert_to_event_stream.assert_called_once() - def test_generate_agent_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_agent_chat_mode_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation for agent chat mode app. """ @@ -274,7 +279,9 @@ class TestAppGenerateService: mock_external_service_dependencies["agent_chat_generator"].return_value.generate.assert_called_once() mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once() - def test_generate_advanced_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_advanced_chat_mode_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation for advanced chat mode app. """ @@ -300,7 +307,9 @@ class TestAppGenerateService: "advanced_chat_generator" ].return_value.convert_to_event_stream.assert_called_once() - def test_generate_workflow_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_workflow_mode_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation for workflow mode app. """ @@ -324,7 +333,9 @@ class TestAppGenerateService: mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once() mock_external_service_dependencies["workflow_generator"].convert_to_event_stream.assert_called_once() - def test_generate_with_specific_workflow_id(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_specific_workflow_id( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with a specific workflow ID. """ @@ -355,7 +366,9 @@ class TestAppGenerateService: "workflow_service" ].return_value.get_published_workflow_by_id.assert_called_once() - def test_generate_with_debugger_invoke_from(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_debugger_invoke_from( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with debugger invoke from. """ @@ -378,7 +391,9 @@ class TestAppGenerateService: # Verify draft workflow was fetched for debugger mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once() - def test_generate_with_non_streaming_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_non_streaming_mode( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with non-streaming mode. """ @@ -401,7 +416,7 @@ class TestAppGenerateService: # Verify rate limit exit was called for non-streaming mode mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once() - def test_generate_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test generation with EndUser instead of Account. """ @@ -421,10 +436,8 @@ class TestAppGenerateService: session_id=fake.uuid4(), ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() # Setup test arguments args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} @@ -438,7 +451,7 @@ class TestAppGenerateService: assert result == ["test_response"] def test_generate_with_billing_enabled_sandbox_plan( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation with billing enabled and sandbox plan. @@ -466,7 +479,9 @@ class TestAppGenerateService: # Verify billing service was called to consume quota mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once() - def test_generate_with_invalid_app_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_invalid_app_mode( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with invalid app mode. """ @@ -491,7 +506,7 @@ class TestAppGenerateService: assert "Invalid app mode" in str(exc_info.value) def test_generate_with_workflow_id_format_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation with invalid workflow ID format. @@ -518,7 +533,7 @@ class TestAppGenerateService: assert "Invalid workflow_id format" in str(exc_info.value) def test_generate_with_workflow_not_found_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation when workflow is not found. @@ -552,7 +567,7 @@ class TestAppGenerateService: assert f"Workflow not found with id: {workflow_id}" in str(exc_info.value) def test_generate_with_workflow_not_initialized_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation when workflow is not initialized for debugger. @@ -578,7 +593,7 @@ class TestAppGenerateService: assert "Workflow not initialized" in str(exc_info.value) def test_generate_with_workflow_not_published_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation when workflow is not published for non-debugger. @@ -604,7 +619,7 @@ class TestAppGenerateService: assert "Workflow not published" in str(exc_info.value) def test_generate_single_iteration_advanced_chat_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful single iteration generation for advanced chat mode. @@ -631,7 +646,7 @@ class TestAppGenerateService: ].return_value.single_iteration_generate.assert_called_once() def test_generate_single_iteration_workflow_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful single iteration generation for workflow mode. @@ -658,7 +673,7 @@ class TestAppGenerateService: ].return_value.single_iteration_generate.assert_called_once() def test_generate_single_iteration_invalid_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test single iteration generation with invalid app mode. @@ -681,7 +696,7 @@ class TestAppGenerateService: assert "Invalid app mode" in str(exc_info.value) def test_generate_single_loop_advanced_chat_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful single loop generation for advanced chat mode. @@ -708,7 +723,7 @@ class TestAppGenerateService: ].return_value.single_loop_generate.assert_called_once() def test_generate_single_loop_workflow_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful single loop generation for workflow mode. @@ -732,7 +747,9 @@ class TestAppGenerateService: # Verify workflow generator was called mock_external_service_dependencies["workflow_generator"].return_value.single_loop_generate.assert_called_once() - def test_generate_single_loop_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_single_loop_invalid_mode( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test single loop generation with invalid app mode. """ @@ -753,7 +770,9 @@ class TestAppGenerateService: # Verify error message assert "Invalid app mode" in str(exc_info.value) - def test_generate_more_like_this_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_more_like_this_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful more like this generation. """ @@ -778,7 +797,7 @@ class TestAppGenerateService: ].return_value.generate_more_like_this.assert_called_once() def test_generate_more_like_this_with_end_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test more like this generation with EndUser. @@ -799,10 +818,8 @@ class TestAppGenerateService: session_id=fake.uuid4(), ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() message_id = fake.uuid4() @@ -815,7 +832,7 @@ class TestAppGenerateService: assert result == ["more_like_this_response"] def test_get_max_active_requests_with_app_limit( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting max active requests with app-specific limit. @@ -835,7 +852,7 @@ class TestAppGenerateService: assert result == 10 def test_get_max_active_requests_with_config_limit( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting max active requests with config limit being smaller. @@ -856,7 +873,7 @@ class TestAppGenerateService: assert result <= 100 def test_get_max_active_requests_with_zero_limits( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting max active requests with zero limits (infinite). @@ -875,7 +892,9 @@ class TestAppGenerateService: # Verify the result (should return config limit when app limit is 0) assert result == 100 # dify_config.APP_MAX_ACTIVE_REQUESTS - def test_generate_with_exception_cleanup(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_exception_cleanup( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test that rate limit exit is called when an exception occurs. """ @@ -904,7 +923,9 @@ class TestAppGenerateService: # Verify rate limit exit was called for cleanup mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once() - def test_generate_with_agent_mode_detection(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_agent_mode_detection( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with agent mode detection based on app configuration. """ @@ -932,7 +953,7 @@ class TestAppGenerateService: mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once() def test_generate_with_different_invoke_from_values( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation with different invoke from values. @@ -962,7 +983,7 @@ class TestAppGenerateService: # Verify the result assert result == ["test_response"] - def test_generate_with_complex_args(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_complex_args(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test generation with complex arguments including files and external trace ID. """ diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index 745d6c97b0..fc3b20aaae 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from constants.model_template import default_app_templates from models import Account @@ -44,7 +45,7 @@ class TestAppService: "account_feature_service": mock_account_feature_service, } - def test_create_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app creation with basic parameters. """ @@ -98,7 +99,9 @@ class TestAppService: assert app.is_public is False assert app.is_universal is False - def test_create_app_with_different_modes(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_app_with_different_modes( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app creation with different app modes. """ @@ -141,7 +144,7 @@ class TestAppService: assert app.tenant_id == tenant.id assert app.created_by == account.id - def test_get_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app retrieval. """ @@ -189,7 +192,7 @@ class TestAppService: assert retrieved_app.tenant_id == created_app.tenant_id assert retrieved_app.created_by == created_app.created_by - def test_get_paginate_apps_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_apps_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful paginated app list retrieval. """ @@ -243,7 +246,9 @@ class TestAppService: assert app.tenant_id == tenant.id assert app.mode == "chat" - def test_get_paginate_apps_with_filters(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_apps_with_filters( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test paginated app list with various filters. """ @@ -316,7 +321,9 @@ class TestAppService: my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args) assert len(my_apps.items) == 1 - def test_get_paginate_apps_with_tag_filters(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_apps_with_tag_filters( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test paginated app list with tag filters. """ @@ -386,7 +393,7 @@ class TestAppService: # Should return None when no apps match tag filter assert paginated_apps is None - def test_update_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app update with all fields. """ @@ -455,7 +462,7 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_name_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app name update. """ @@ -508,7 +515,7 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_icon_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_icon_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app icon update. """ @@ -565,7 +572,9 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_site_status_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_site_status_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful app site status update. """ @@ -623,7 +632,9 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_api_status_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_api_status_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful app API status update. """ @@ -681,7 +692,9 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_site_status_no_change(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_site_status_no_change( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app site status update when status doesn't change. """ @@ -732,7 +745,7 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_delete_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app deletion. """ @@ -778,12 +791,13 @@ class TestAppService: mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id) # Verify app was deleted from database - from extensions.ext_database import db - deleted_app = db.session.query(App).filter_by(id=app_id).first() + deleted_app = db_session_with_containers.query(App).filter_by(id=app_id).first() assert deleted_app is None - def test_delete_app_with_related_data(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_app_with_related_data( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app deletion with related data cleanup. """ @@ -839,12 +853,11 @@ class TestAppService: mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id) # Verify app was deleted from database - from extensions.ext_database import db - deleted_app = db.session.query(App).filter_by(id=app_id).first() + deleted_app = db_session_with_containers.query(App).filter_by(id=app_id).first() assert deleted_app is None - def test_get_app_meta_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_meta_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app metadata retrieval. """ @@ -883,7 +896,7 @@ class TestAppService: assert "tool_icons" in app_meta # Note: get_app_meta currently only returns tool_icons - def test_get_app_code_by_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_code_by_id_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app code retrieval by app ID. """ @@ -923,7 +936,7 @@ class TestAppService: assert app_code is not None assert len(app_code) > 0 - def test_get_app_id_by_code_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_id_by_code_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app ID retrieval by app code. """ @@ -963,10 +976,9 @@ class TestAppService: site.status = "normal" site.default_language = "en-US" site.customize_token_strategy = "uuid" - from extensions.ext_database import db - db.session.add(site) - db.session.commit() + db_session_with_containers.add(site) + db_session_with_containers.commit() # Get app ID by code app_id = AppService.get_app_id_by_code(site.code) @@ -974,7 +986,7 @@ class TestAppService: # Verify app ID was retrieved correctly assert app_id == app.id - def test_create_app_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_app_invalid_mode(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test app creation with invalid mode. """ @@ -1010,7 +1022,7 @@ class TestAppService: app_service.create_app(tenant.id, app_args, account) def test_get_apps_with_special_characters_in_name( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): r""" Test app retrieval with special characters in name search to verify SQL injection prevention. diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index f05c47913e..102c1a1eb5 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -9,14 +9,15 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session -from core.model_runtime.entities.model_entities import ModelType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from extensions.ext_database import db +from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole -from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings +from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import RerankingModel, RetrievalModel +from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity from services.errors.dataset import DatasetNameDuplicateError @@ -24,7 +25,9 @@ class DatasetServiceIntegrationDataFactory: """Factory for creating real database entities used by integration tests.""" @staticmethod - def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]: + def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER + ) -> tuple[Account, Tenant]: """Create an account and tenant, then bind the account as current tenant member.""" account = Account( email=f"{uuid4()}@example.com", @@ -33,8 +36,8 @@ class DatasetServiceIntegrationDataFactory: status="active", ) tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") - db.session.add_all([account, tenant]) - db.session.flush() + db_session_with_containers.add_all([account, tenant]) + db_session_with_containers.flush() join = TenantAccountJoin( tenant_id=tenant.id, @@ -42,8 +45,8 @@ class DatasetServiceIntegrationDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.flush() + db_session_with_containers.add(join) + db_session_with_containers.flush() # Keep tenant context on the in-memory user without opening a separate session. account.role = role @@ -52,6 +55,7 @@ class DatasetServiceIntegrationDataFactory: @staticmethod def create_dataset( + db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test Dataset", @@ -81,12 +85,14 @@ class DatasetServiceIntegrationDataFactory: collection_binding_id=collection_binding_id, chunk_structure=chunk_structure, ) - db.session.add(dataset) - db.session.flush() + db_session_with_containers.add(dataset) + db_session_with_containers.flush() return dataset @staticmethod - def create_document(dataset: Dataset, created_by: str, name: str = "doc.txt") -> Document: + def create_document( + db_session_with_containers: Session, dataset: Dataset, created_by: str, name: str = "doc.txt" + ) -> Document: """Create a document row belonging to the given dataset.""" document = Document( tenant_id=dataset.tenant_id, @@ -101,8 +107,8 @@ class DatasetServiceIntegrationDataFactory: indexing_status="completed", doc_form="text_model", ) - db.session.add(document) - db.session.flush() + db_session_with_containers.add(document) + db_session_with_containers.flush() return document @staticmethod @@ -117,10 +123,10 @@ class DatasetServiceIntegrationDataFactory: class TestDatasetServiceCreateDataset: """Integration coverage for DatasetService.create_empty_dataset.""" - def test_create_internal_dataset_basic_success(self, db_session_with_containers): + def test_create_internal_dataset_basic_success(self, db_session_with_containers: Session): """Create a basic internal dataset with minimal configuration.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) # Act result = DatasetService.create_empty_dataset( @@ -132,17 +138,17 @@ class TestDatasetServiceCreateDataset: ) # Assert - created_dataset = db.session.get(Dataset, result.id) + created_dataset = db_session_with_containers.get(Dataset, result.id) assert created_dataset is not None assert created_dataset.provider == "vendor" assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME assert created_dataset.embedding_model_provider is None assert created_dataset.embedding_model is None - def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers): + def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers: Session): """Create an internal dataset with economy indexing and no embedding model.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) # Act result = DatasetService.create_empty_dataset( @@ -154,15 +160,15 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.indexing_technique == "economy" assert result.embedding_model_provider is None assert result.embedding_model is None - def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers): + def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers: Session): """Create a high-quality dataset and persist embedding model settings.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() # Act @@ -178,7 +184,7 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.indexing_technique == "high_quality" assert result.embedding_model_provider == embedding_model.provider assert result.embedding_model == embedding_model.model_name @@ -187,11 +193,12 @@ class TestDatasetServiceCreateDataset: model_type=ModelType.TEXT_EMBEDDING, ) - def test_create_dataset_duplicate_name_error(self, db_session_with_containers): + def test_create_dataset_duplicate_name_error(self, db_session_with_containers: Session): """Raise duplicate-name error when the same tenant already has the name.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Duplicate Dataset", @@ -208,10 +215,10 @@ class TestDatasetServiceCreateDataset: account=account, ) - def test_create_external_dataset_success(self, db_session_with_containers): + def test_create_external_dataset_success(self, db_session_with_containers: Session): """Create an external dataset and persist external knowledge binding.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) external_knowledge_api_id = str(uuid4()) external_knowledge_id = "knowledge-123" @@ -230,16 +237,16 @@ class TestDatasetServiceCreateDataset: ) # Assert - binding = db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first() + binding = db_session_with_containers.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first() assert result.provider == "external" assert binding is not None assert binding.external_knowledge_id == external_knowledge_id assert binding.external_knowledge_api_id == external_knowledge_api_id - def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers): + def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers: Session): """Create a high-quality dataset with retrieval/reranking settings.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() retrieval_model = RetrievalModel( search_method=RetrievalMethod.SEMANTIC_SEARCH, @@ -270,24 +277,299 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.retrieval_model == retrieval_model.model_dump() mock_check_reranking.assert_called_once_with(tenant.id, "cohere", "rerank-english-v2.0") + def test_create_internal_dataset_with_high_quality_indexing_custom_embedding( + self, db_session_with_containers: Session + ): + """Create high-quality dataset with explicitly configured embedding model.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + embedding_provider = "openai" + embedding_model_name = "text-embedding-3-small" + embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model( + provider=embedding_provider, model_name=embedding_model_name + ) + + # Act + with ( + patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding, + ): + mock_model_manager.return_value.get_model_instance.return_value = embedding_model + + result = DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="Custom Embedding Dataset", + description=None, + indexing_technique="high_quality", + account=account, + embedding_model_provider=embedding_provider, + embedding_model_name=embedding_model_name, + ) + + # Assert + db_session_with_containers.refresh(result) + assert result.indexing_technique == "high_quality" + assert result.embedding_model_provider == embedding_provider + assert result.embedding_model == embedding_model_name + mock_check_embedding.assert_called_once_with(tenant.id, embedding_provider, embedding_model_name) + mock_model_manager.return_value.get_model_instance.assert_called_once_with( + tenant_id=tenant.id, + provider=embedding_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=embedding_model_name, + ) + + def test_create_internal_dataset_with_retrieval_model(self, db_session_with_containers: Session): + """Persist retrieval model settings when creating an internal dataset.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + retrieval_model = RetrievalModel( + search_method=RetrievalMethod.SEMANTIC_SEARCH, + reranking_enable=False, + top_k=2, + score_threshold_enabled=True, + score_threshold=0.0, + ) + + # Act + result = DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="Retrieval Model Dataset", + description=None, + indexing_technique=None, + account=account, + retrieval_model=retrieval_model, + ) + + # Assert + db_session_with_containers.refresh(result) + assert result.retrieval_model == retrieval_model.model_dump() + + def test_create_internal_dataset_with_custom_permission(self, db_session_with_containers: Session): + """Persist canonical custom permission when creating an internal dataset.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + + # Act + result = DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="Custom Permission Dataset", + description=None, + indexing_technique=None, + account=account, + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + # Assert + db_session_with_containers.refresh(result) + assert result.permission == DatasetPermissionEnum.ALL_TEAM + + def test_create_external_dataset_missing_api_id_error(self, db_session_with_containers: Session): + """Raise error when external API template does not exist.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + external_knowledge_api_id = str(uuid4()) + + # Act / Assert + with patch("services.dataset_service.ExternalDatasetService.get_external_knowledge_api") as mock_get_api: + mock_get_api.return_value = None + with pytest.raises(ValueError, match=r"External API template not found\.?"): + DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="External Missing API Dataset", + description=None, + indexing_technique=None, + account=account, + provider="external", + external_knowledge_api_id=external_knowledge_api_id, + external_knowledge_id="knowledge-123", + ) + + def test_create_external_dataset_missing_knowledge_id_error(self, db_session_with_containers: Session): + """Raise error when external knowledge id is missing for external dataset creation.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + external_knowledge_api_id = str(uuid4()) + + # Act / Assert + with patch("services.dataset_service.ExternalDatasetService.get_external_knowledge_api") as mock_get_api: + mock_get_api.return_value = Mock(id=external_knowledge_api_id) + with pytest.raises(ValueError, match="external_knowledge_id is required"): + DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="External Missing Knowledge Dataset", + description=None, + indexing_technique=None, + account=account, + provider="external", + external_knowledge_api_id=external_knowledge_api_id, + external_knowledge_id=None, + ) + + +class TestDatasetServiceCreateRagPipelineDataset: + """Integration coverage for DatasetService.create_empty_rag_pipeline_dataset.""" + + def test_create_rag_pipeline_dataset_with_name_success(self, db_session_with_containers: Session): + """Create rag-pipeline dataset and pipeline rows when a name is provided.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") + entity = RagPipelineDatasetCreateEntity( + name="RAG Pipeline Dataset", + description="RAG Pipeline Description", + icon_info=icon_info, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + # Act + with patch("services.dataset_service.current_user", account): + result = DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + ) + + # Assert + created_dataset = db_session_with_containers.get(Dataset, result.id) + created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id) + assert created_dataset is not None + assert created_dataset.name == entity.name + assert created_dataset.runtime_mode == "rag_pipeline" + assert created_dataset.created_by == account.id + assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME + assert created_pipeline is not None + assert created_pipeline.name == entity.name + assert created_pipeline.created_by == account.id + + def test_create_rag_pipeline_dataset_with_auto_generated_name(self, db_session_with_containers: Session): + """Create rag-pipeline dataset with generated incremental name when input name is empty.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + generated_name = "Untitled 1" + icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") + entity = RagPipelineDatasetCreateEntity( + name="", + description="", + icon_info=icon_info, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + # Act + with ( + patch("services.dataset_service.current_user", account), + patch("services.dataset_service.generate_incremental_name") as mock_generate_name, + ): + mock_generate_name.return_value = generated_name + result = DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + ) + + # Assert + db_session_with_containers.refresh(result) + created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id) + assert result.name == generated_name + assert created_pipeline is not None + assert created_pipeline.name == generated_name + mock_generate_name.assert_called_once() + + def test_create_rag_pipeline_dataset_duplicate_name_error(self, db_session_with_containers: Session): + """Raise duplicate-name error when rag-pipeline dataset name already exists.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + duplicate_name = "Duplicate RAG Dataset" + DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=account.id, + name=duplicate_name, + indexing_technique=None, + ) + db_session_with_containers.commit() + icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") + entity = RagPipelineDatasetCreateEntity( + name=duplicate_name, + description="", + icon_info=icon_info, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + # Act / Assert + with ( + patch("services.dataset_service.current_user", account), + pytest.raises(DatasetNameDuplicateError, match=f"Dataset with name {duplicate_name} already exists"), + ): + DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + ) + + def test_create_rag_pipeline_dataset_with_custom_permission(self, db_session_with_containers: Session): + """Persist canonical custom permission for rag-pipeline dataset creation.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") + entity = RagPipelineDatasetCreateEntity( + name="Custom Permission RAG Dataset", + description="", + icon_info=icon_info, + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + # Act + with patch("services.dataset_service.current_user", account): + result = DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + ) + + # Assert + db_session_with_containers.refresh(result) + assert result.permission == DatasetPermissionEnum.ALL_TEAM + + def test_create_rag_pipeline_dataset_with_icon_info(self, db_session_with_containers: Session): + """Persist icon metadata when creating rag-pipeline dataset.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + icon_info = IconInfo( + icon="📚", + icon_background="#E8F5E9", + icon_type="emoji", + icon_url="https://example.com/icon.png", + ) + entity = RagPipelineDatasetCreateEntity( + name="Icon Info RAG Dataset", + description="", + icon_info=icon_info, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + # Act + with patch("services.dataset_service.current_user", account): + result = DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + ) + + # Assert + db_session_with_containers.refresh(result) + assert result.icon_info == icon_info.model_dump() + class TestDatasetServiceUpdateAndDeleteDataset: """Integration coverage for SQL-backed update and delete behavior.""" - def test_update_dataset_duplicate_name_error(self, db_session_with_containers): + def test_update_dataset_duplicate_name_error(self, db_session_with_containers: Session): """Reject update when target name already exists within the same tenant.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) source_dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Source Dataset", ) DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Existing Dataset", @@ -297,17 +579,20 @@ class TestDatasetServiceUpdateAndDeleteDataset: with pytest.raises(ValueError, match="Dataset name already exists"): DatasetService.update_dataset(source_dataset.id, {"name": "Existing Dataset"}, account) - def test_delete_dataset_with_documents_success(self, db_session_with_containers): + def test_delete_dataset_with_documents_success(self, db_session_with_containers: Session): """Delete a dataset that already has documents.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, indexing_technique="high_quality", chunk_structure="text_model", ) - DatasetServiceIntegrationDataFactory.create_document(dataset=dataset, created_by=account.id) + DatasetServiceIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, created_by=account.id + ) # Act with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal: @@ -315,14 +600,15 @@ class TestDatasetServiceUpdateAndDeleteDataset: # Assert assert result is True - assert db.session.get(Dataset, dataset.id) is None + assert db_session_with_containers.get(Dataset, dataset.id) is None dataset_deleted_signal.send.assert_called_once_with(dataset) - def test_delete_empty_dataset_success(self, db_session_with_containers): + def test_delete_empty_dataset_success(self, db_session_with_containers: Session): """Delete a dataset that has no documents and no indexing technique.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, indexing_technique=None, @@ -335,14 +621,15 @@ class TestDatasetServiceUpdateAndDeleteDataset: # Assert assert result is True - assert db.session.get(Dataset, dataset.id) is None + assert db_session_with_containers.get(Dataset, dataset.id) is None dataset_deleted_signal.send.assert_called_once_with(dataset) - def test_delete_dataset_with_partial_none_values(self, db_session_with_containers): + def test_delete_dataset_with_partial_none_values(self, db_session_with_containers: Session): """Delete dataset when indexing_technique is None but doc_form path still exists.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, indexing_technique=None, @@ -355,17 +642,17 @@ class TestDatasetServiceUpdateAndDeleteDataset: # Assert assert result is True - assert db.session.get(Dataset, dataset.id) is None + assert db_session_with_containers.get(Dataset, dataset.id) is None dataset_deleted_signal.send.assert_called_once_with(dataset) class TestDatasetServiceRetrievalConfiguration: """Integration coverage for retrieval configuration persistence.""" - def test_get_dataset_retrieval_configuration(self, db_session_with_containers): + def test_get_dataset_retrieval_configuration(self, db_session_with_containers: Session): """Return retrieval configuration that is persisted in SQL.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) retrieval_model = { "search_method": "semantic_search", "top_k": 5, @@ -373,6 +660,7 @@ class TestDatasetServiceRetrievalConfiguration: "reranking_enable": True, } dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, retrieval_model=retrieval_model, @@ -387,11 +675,12 @@ class TestDatasetServiceRetrievalConfiguration: assert result.retrieval_model["search_method"] == "semantic_search" assert result.retrieval_model["top_k"] == 5 - def test_update_dataset_retrieval_configuration(self, db_session_with_containers): + def test_update_dataset_retrieval_configuration(self, db_session_with_containers: Session): """Persist retrieval configuration updates through DatasetService.update_dataset.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, indexing_technique="high_quality", @@ -413,6 +702,6 @@ class TestDatasetServiceRetrievalConfiguration: result = DatasetService.update_dataset(dataset.id, update_data, account) # Assert - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert result.id == dataset.id assert dataset.retrieval_model == update_data["retrieval_model"] diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py new file mode 100644 index 0000000000..322b67d373 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py @@ -0,0 +1,693 @@ +"""Integration tests for DocumentService.batch_update_document_status. + +This suite validates SQL-backed batch status updates with testcontainers. +It keeps database access real and only patches non-DB side effects. +""" + +import datetime +import json +from dataclasses import dataclass +from unittest.mock import call, patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from models.dataset import Dataset, Document +from services.dataset_service import DocumentService +from services.errors.document import DocumentIndexingError + +FIXED_TIME = datetime.datetime(2023, 1, 1, 12, 0, 0) + + +@dataclass +class UserDouble: + """Minimal user object for batch update operations.""" + + id: str + + +class DocumentBatchUpdateIntegrationDataFactory: + """Factory for creating persisted entities used in integration tests.""" + + @staticmethod + def create_dataset( + db_session_with_containers: Session, + dataset_id: str | None = None, + tenant_id: str | None = None, + name: str = "Test Dataset", + created_by: str | None = None, + ) -> Dataset: + """Create and persist a dataset.""" + dataset = Dataset( + tenant_id=tenant_id or str(uuid4()), + name=name, + data_source_type="upload_file", + created_by=created_by or str(uuid4()), + ) + if dataset_id: + dataset.id = dataset_id + + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + @staticmethod + def create_document( + db_session_with_containers: Session, + dataset: Dataset, + document_id: str | None = None, + name: str = "test_document.pdf", + enabled: bool = True, + archived: bool = False, + indexing_status: str = "completed", + completed_at: datetime.datetime | None = None, + position: int = 1, + created_by: str | None = None, + commit: bool = True, + **kwargs, + ) -> Document: + """Create a document bound to the given dataset and persist it.""" + document = Document( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=position, + data_source_type="upload_file", + data_source_info=json.dumps({"upload_file_id": str(uuid4())}), + batch=f"batch-{uuid4()}", + name=name, + created_from="web", + created_by=created_by or str(uuid4()), + doc_form="text_model", + ) + document.id = document_id or str(uuid4()) + document.enabled = enabled + document.archived = archived + document.indexing_status = indexing_status + document.completed_at = ( + completed_at if completed_at is not None else (FIXED_TIME if indexing_status == "completed" else None) + ) + + for key, value in kwargs.items(): + setattr(document, key, value) + + db_session_with_containers.add(document) + if commit: + db_session_with_containers.commit() + return document + + @staticmethod + def create_multiple_documents( + db_session_with_containers: Session, + dataset: Dataset, + document_ids: list[str], + enabled: bool = True, + archived: bool = False, + indexing_status: str = "completed", + ) -> list[Document]: + """Create and persist multiple documents for one dataset in a single transaction.""" + documents: list[Document] = [] + for index, doc_id in enumerate(document_ids, start=1): + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, + dataset=dataset, + document_id=doc_id, + name=f"document_{doc_id}.pdf", + enabled=enabled, + archived=archived, + indexing_status=indexing_status, + position=index, + commit=False, + ) + documents.append(document) + db_session_with_containers.commit() + return documents + + @staticmethod + def create_user(user_id: str | None = None) -> UserDouble: + """Create a lightweight user for update metadata fields.""" + return UserDouble(id=user_id or str(uuid4())) + + +class TestDatasetServiceBatchUpdateDocumentStatus: + """Integration coverage for batch document status updates.""" + + @pytest.fixture + def patched_dependencies(self): + """Patch non-DB collaborators only.""" + with ( + patch("services.dataset_service.redis_client") as redis_client, + patch("services.dataset_service.add_document_to_index_task") as add_task, + patch("services.dataset_service.remove_document_from_index_task") as remove_task, + patch("services.dataset_service.naive_utc_now") as naive_utc_now, + ): + naive_utc_now.return_value = FIXED_TIME + redis_client.get.return_value = None + yield { + "redis_client": redis_client, + "add_task": add_task, + "remove_task": remove_task, + "naive_utc_now": naive_utc_now, + } + + def _assert_document_enabled(self, document: Document, current_time: datetime.datetime): + """Verify enabled-state fields after action=enable.""" + assert document.enabled is True + assert document.disabled_at is None + assert document.disabled_by is None + assert document.updated_at == current_time + + def _assert_document_disabled(self, document: Document, user_id: str, current_time: datetime.datetime): + """Verify disabled-state fields after action=disable.""" + assert document.enabled is False + assert document.disabled_at == current_time + assert document.disabled_by == user_id + assert document.updated_at == current_time + + def _assert_document_archived(self, document: Document, user_id: str, current_time: datetime.datetime): + """Verify archived-state fields after action=archive.""" + assert document.archived is True + assert document.archived_at == current_time + assert document.archived_by == user_id + assert document.updated_at == current_time + + def _assert_document_unarchived(self, document: Document): + """Verify unarchived-state fields after action=un_archive.""" + assert document.archived is False + assert document.archived_at is None + assert document.archived_by is None + + def test_batch_update_enable_documents_success(self, db_session_with_containers: Session, patched_dependencies): + """Enable disabled documents and trigger indexing side effects.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document_ids = [str(uuid4()), str(uuid4())] + disabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents( + db_session_with_containers, + dataset=dataset, + document_ids=document_ids, + enabled=False, + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=document_ids, action="enable", user=user + ) + + # Assert + for document in disabled_docs: + db_session_with_containers.refresh(document) + self._assert_document_enabled(document, FIXED_TIME) + + expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids] + expected_setex_calls = [call(f"document_{doc_id}_indexing", 600, 1) for doc_id in document_ids] + expected_add_calls = [call(doc_id) for doc_id in document_ids] + patched_dependencies["redis_client"].get.assert_has_calls(expected_get_calls) + patched_dependencies["redis_client"].setex.assert_has_calls(expected_setex_calls) + patched_dependencies["add_task"].delay.assert_has_calls(expected_add_calls) + + def test_batch_update_enable_already_enabled_document_skipped( + self, db_session_with_containers: Session, patched_dependencies + ): + """Skip enable operation for already-enabled documents.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="enable", + user=user, + ) + + # Assert + db_session_with_containers.refresh(document) + assert document.enabled is True + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["add_task"].delay.assert_not_called() + + def test_batch_update_disable_documents_success(self, db_session_with_containers: Session, patched_dependencies): + """Disable completed documents and trigger remove-index tasks.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document_ids = [str(uuid4()), str(uuid4())] + enabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents( + db_session_with_containers, + dataset=dataset, + document_ids=document_ids, + enabled=True, + indexing_status="completed", + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=document_ids, + action="disable", + user=user, + ) + + # Assert + for document in enabled_docs: + db_session_with_containers.refresh(document) + self._assert_document_disabled(document, user.id, FIXED_TIME) + + expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids] + expected_setex_calls = [call(f"document_{doc_id}_indexing", 600, 1) for doc_id in document_ids] + expected_remove_calls = [call(doc_id) for doc_id in document_ids] + patched_dependencies["redis_client"].get.assert_has_calls(expected_get_calls) + patched_dependencies["redis_client"].setex.assert_has_calls(expected_setex_calls) + patched_dependencies["remove_task"].delay.assert_has_calls(expected_remove_calls) + + def test_batch_update_disable_already_disabled_document_skipped( + self, db_session_with_containers: Session, patched_dependencies + ): + """Skip disable operation for already-disabled documents.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, + dataset=dataset, + enabled=False, + indexing_status="completed", + completed_at=FIXED_TIME, + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[disabled_doc.id], + action="disable", + user=user, + ) + + # Assert + db_session_with_containers.refresh(disabled_doc) + assert disabled_doc.enabled is False + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["remove_task"].delay.assert_not_called() + + def test_batch_update_disable_non_completed_document_error( + self, db_session_with_containers: Session, patched_dependencies + ): + """Raise error when disabling a non-completed document.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + non_completed_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, + dataset=dataset, + enabled=True, + indexing_status="indexing", + completed_at=None, + ) + + # Act / Assert + with pytest.raises(DocumentIndexingError, match="is not completed"): + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[non_completed_doc.id], + action="disable", + user=user, + ) + + def test_batch_update_archive_documents_success(self, db_session_with_containers: Session, patched_dependencies): + """Archive enabled documents and trigger remove-index task.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, archived=False + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="archive", + user=user, + ) + + # Assert + db_session_with_containers.refresh(document) + self._assert_document_archived(document, user.id, FIXED_TIME) + patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing") + patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1) + patched_dependencies["remove_task"].delay.assert_called_once_with(document.id) + + def test_batch_update_archive_already_archived_document_skipped( + self, db_session_with_containers: Session, patched_dependencies + ): + """Skip archive operation for already-archived documents.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, archived=True + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="archive", + user=user, + ) + + # Assert + db_session_with_containers.refresh(document) + assert document.archived is True + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["remove_task"].delay.assert_not_called() + + def test_batch_update_archive_disabled_document_no_index_removal( + self, db_session_with_containers: Session, patched_dependencies + ): + """Archive disabled document without index-removal side effects.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False, archived=False + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="archive", + user=user, + ) + + # Assert + db_session_with_containers.refresh(document) + self._assert_document_archived(document, user.id, FIXED_TIME) + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["remove_task"].delay.assert_not_called() + + def test_batch_update_unarchive_documents_success(self, db_session_with_containers: Session, patched_dependencies): + """Unarchive enabled documents and trigger add-index task.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, archived=True + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="un_archive", + user=user, + ) + + # Assert + db_session_with_containers.refresh(document) + self._assert_document_unarchived(document) + assert document.updated_at == FIXED_TIME + patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing") + patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1) + patched_dependencies["add_task"].delay.assert_called_once_with(document.id) + + def test_batch_update_unarchive_already_unarchived_document_skipped( + self, db_session_with_containers: Session, patched_dependencies + ): + """Skip unarchive operation for already-unarchived documents.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, archived=False + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="un_archive", + user=user, + ) + + # Assert + db_session_with_containers.refresh(document) + assert document.archived is False + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["add_task"].delay.assert_not_called() + + def test_batch_update_unarchive_disabled_document_no_index_addition( + self, db_session_with_containers: Session, patched_dependencies + ): + """Unarchive disabled document without index-add side effects.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False, archived=True + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="un_archive", + user=user, + ) + + # Assert + db_session_with_containers.refresh(document) + self._assert_document_unarchived(document) + assert document.updated_at == FIXED_TIME + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["add_task"].delay.assert_not_called() + + def test_batch_update_document_indexing_error_redis_cache_hit( + self, db_session_with_containers: Session, patched_dependencies + ): + """Raise DocumentIndexingError when redis indicates active indexing.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, + dataset=dataset, + name="test_document.pdf", + enabled=True, + ) + patched_dependencies["redis_client"].get.return_value = "indexing" + + # Act / Assert + with pytest.raises(DocumentIndexingError, match="is being indexed") as exc_info: + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="enable", + user=user, + ) + + assert "test_document.pdf" in str(exc_info.value) + patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing") + + def test_batch_update_async_task_error_handling(self, db_session_with_containers: Session, patched_dependencies): + """Persist DB update, then propagate async task error.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False + ) + patched_dependencies["add_task"].delay.side_effect = Exception("Celery task error") + + # Act / Assert + with pytest.raises(Exception, match="Celery task error"): + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="enable", + user=user, + ) + + db_session_with_containers.refresh(document) + self._assert_document_enabled(document, FIXED_TIME) + patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1) + + def test_batch_update_empty_document_list(self, db_session_with_containers: Session, patched_dependencies): + """Return early when document_ids is empty.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + + # Act + result = DocumentService.batch_update_document_status( + dataset=dataset, document_ids=[], action="enable", user=user + ) + + # Assert + assert result is None + patched_dependencies["redis_client"].get.assert_not_called() + patched_dependencies["redis_client"].setex.assert_not_called() + + def test_batch_update_document_not_found_skipped(self, db_session_with_containers: Session, patched_dependencies): + """Skip IDs that do not map to existing dataset documents.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + missing_document_id = str(uuid4()) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[missing_document_id], + action="enable", + user=user, + ) + + # Assert + patched_dependencies["redis_client"].get.assert_not_called() + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["add_task"].delay.assert_not_called() + + def test_batch_update_mixed_document_states_and_actions( + self, db_session_with_containers: Session, patched_dependencies + ): + """Process only the applicable document in a mixed-state enable batch.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False + ) + enabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, + dataset=dataset, + enabled=True, + position=2, + ) + archived_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, + dataset=dataset, + enabled=True, + archived=True, + position=3, + ) + document_ids = [disabled_doc.id, enabled_doc.id, archived_doc.id] + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=document_ids, + action="enable", + user=user, + ) + + # Assert + db_session_with_containers.refresh(disabled_doc) + db_session_with_containers.refresh(enabled_doc) + db_session_with_containers.refresh(archived_doc) + self._assert_document_enabled(disabled_doc, FIXED_TIME) + assert enabled_doc.enabled is True + assert archived_doc.enabled is True + + patched_dependencies["redis_client"].setex.assert_called_once_with( + f"document_{disabled_doc.id}_indexing", + 600, + 1, + ) + patched_dependencies["add_task"].delay.assert_called_once_with(disabled_doc.id) + + def test_batch_update_large_document_list_performance( + self, db_session_with_containers: Session, patched_dependencies + ): + """Handle large document lists with consistent updates and side effects.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document_ids = [str(uuid4()) for _ in range(100)] + documents = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents( + db_session_with_containers, + dataset=dataset, + document_ids=document_ids, + enabled=False, + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=document_ids, + action="enable", + user=user, + ) + + # Assert + for document in documents: + db_session_with_containers.refresh(document) + self._assert_document_enabled(document, FIXED_TIME) + + assert patched_dependencies["redis_client"].setex.call_count == len(document_ids) + assert patched_dependencies["add_task"].delay.call_count == len(document_ids) + + expected_setex_calls = [call(f"document_{doc_id}_indexing", 600, 1) for doc_id in document_ids] + expected_task_calls = [call(doc_id) for doc_id in document_ids] + patched_dependencies["redis_client"].setex.assert_has_calls(expected_setex_calls) + patched_dependencies["add_task"].delay.assert_has_calls(expected_task_calls) + + def test_batch_update_mixed_document_states_complex_scenario( + self, db_session_with_containers: Session, patched_dependencies + ): + """Process a complex mixed-state batch and update only eligible records.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + doc1 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False + ) + doc2 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, position=2 + ) + doc3 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, position=3 + ) + doc4 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, position=4 + ) + doc5 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, + dataset=dataset, + enabled=True, + archived=True, + position=5, + ) + missing_id = str(uuid4()) + + document_ids = [doc1.id, doc2.id, doc3.id, doc4.id, doc5.id, missing_id] + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=document_ids, + action="enable", + user=user, + ) + + # Assert + db_session_with_containers.refresh(doc1) + db_session_with_containers.refresh(doc2) + db_session_with_containers.refresh(doc3) + db_session_with_containers.refresh(doc4) + db_session_with_containers.refresh(doc5) + self._assert_document_enabled(doc1, FIXED_TIME) + assert doc2.enabled is True + assert doc3.enabled is True + assert doc4.enabled is True + assert doc5.enabled is True + + patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{doc1.id}_indexing", 600, 1) + patched_dependencies["add_task"].delay.assert_called_once_with(doc1.id) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py index 6effe795e2..e78894fcae 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py @@ -10,7 +10,8 @@ Tests the retrieval of document segments with pagination and filtering: from uuid import uuid4 -from extensions.ext_database import db +from sqlalchemy.orm import Session + from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment from services.dataset_service import SegmentService @@ -23,6 +24,7 @@ class SegmentServiceTestDataFactory: @staticmethod def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER, tenant: Tenant | None = None, ) -> tuple[Account, Tenant]: @@ -33,13 +35,13 @@ class SegmentServiceTestDataFactory: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() if tenant is None: tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() join = TenantAccountJoin( tenant_id=tenant.id, @@ -47,14 +49,14 @@ class SegmentServiceTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account, tenant @staticmethod - def create_dataset(tenant_id: str, created_by: str) -> Dataset: + def create_dataset(db_session_with_containers: Session, tenant_id: str, created_by: str) -> Dataset: """Create a real dataset.""" dataset = Dataset( tenant_id=tenant_id, @@ -67,12 +69,14 @@ class SegmentServiceTestDataFactory: provider="vendor", retrieval_model={"top_k": 2}, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod - def create_document(tenant_id: str, dataset_id: str, created_by: str) -> Document: + def create_document( + db_session_with_containers: Session, tenant_id: str, dataset_id: str, created_by: str + ) -> Document: """Create a real document.""" document = Document( tenant_id=tenant_id, @@ -84,12 +88,13 @@ class SegmentServiceTestDataFactory: created_from="api", created_by=created_by, ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document @staticmethod def create_segment( + db_session_with_containers: Session, tenant_id: str, dataset_id: str, document_id: str, @@ -112,8 +117,8 @@ class SegmentServiceTestDataFactory: tokens=tokens, created_by=created_by, ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() return segment @@ -130,7 +135,7 @@ class TestSegmentServiceGetSegments: - Combined filters """ - def test_get_segments_basic_pagination(self, db_session_with_containers): + def test_get_segments_basic_pagination(self, db_session_with_containers: Session): """ Test basic pagination functionality. @@ -140,11 +145,14 @@ class TestSegmentServiceGetSegments: - Returns segments and total count """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) segment1 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -153,6 +161,7 @@ class TestSegmentServiceGetSegments: content="First segment", ) segment2 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -170,7 +179,7 @@ class TestSegmentServiceGetSegments: assert items[0].id == segment1.id assert items[1].id == segment2.id - def test_get_segments_with_status_filter(self, db_session_with_containers): + def test_get_segments_with_status_filter(self, db_session_with_containers: Session): """ Test filtering by status list. @@ -179,11 +188,14 @@ class TestSegmentServiceGetSegments: - Only segments with matching status are returned """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -192,6 +204,7 @@ class TestSegmentServiceGetSegments: status="completed", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -200,6 +213,7 @@ class TestSegmentServiceGetSegments: status="indexing", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -219,7 +233,7 @@ class TestSegmentServiceGetSegments: statuses = {item.status for item in items} assert statuses == {"completed", "indexing"} - def test_get_segments_with_empty_status_list(self, db_session_with_containers): + def test_get_segments_with_empty_status_list(self, db_session_with_containers: Session): """ Test with empty status list. @@ -228,11 +242,14 @@ class TestSegmentServiceGetSegments: - No status filter is applied to avoid WHERE false condition """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -241,6 +258,7 @@ class TestSegmentServiceGetSegments: status="completed", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -256,7 +274,7 @@ class TestSegmentServiceGetSegments: assert len(items) == 2 assert total == 2 - def test_get_segments_with_keyword_search(self, db_session_with_containers): + def test_get_segments_with_keyword_search(self, db_session_with_containers: Session): """ Test keyword search functionality. @@ -265,11 +283,14 @@ class TestSegmentServiceGetSegments: - Search pattern includes wildcards (%keyword%) """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -278,6 +299,7 @@ class TestSegmentServiceGetSegments: content="This contains search term in the middle", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -294,7 +316,7 @@ class TestSegmentServiceGetSegments: assert total == 1 assert "search term" in items[0].content - def test_get_segments_ordering_by_position_and_id(self, db_session_with_containers): + def test_get_segments_ordering_by_position_and_id(self, db_session_with_containers: Session): """ Test ordering by position and id. @@ -304,12 +326,15 @@ class TestSegmentServiceGetSegments: - This prevents duplicate data across pages when positions are not unique """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) # Create segments with different positions seg_pos2 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -318,6 +343,7 @@ class TestSegmentServiceGetSegments: content="Position 2", ) seg_pos1 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -326,6 +352,7 @@ class TestSegmentServiceGetSegments: content="Position 1", ) seg_pos3 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -344,7 +371,7 @@ class TestSegmentServiceGetSegments: assert items[1].id == seg_pos2.id assert items[2].id == seg_pos3.id - def test_get_segments_empty_results(self, db_session_with_containers): + def test_get_segments_empty_results(self, db_session_with_containers: Session): """ Test when no segments match the criteria. @@ -353,7 +380,7 @@ class TestSegmentServiceGetSegments: - Total count is 0 """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) non_existent_doc_id = str(uuid4()) # Act @@ -363,7 +390,7 @@ class TestSegmentServiceGetSegments: assert items == [] assert total == 0 - def test_get_segments_combined_filters(self, db_session_with_containers): + def test_get_segments_combined_filters(self, db_session_with_containers: Session): """ Test with multiple filters combined. @@ -372,12 +399,15 @@ class TestSegmentServiceGetSegments: - Status list and keyword search both applied """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) # Create segments with various statuses and content SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -387,6 +417,7 @@ class TestSegmentServiceGetSegments: content="This is important information", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -396,6 +427,7 @@ class TestSegmentServiceGetSegments: content="This is also important", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -421,7 +453,7 @@ class TestSegmentServiceGetSegments: assert items[0].status == "completed" assert "important" in items[0].content - def test_get_segments_with_none_status_list(self, db_session_with_containers): + def test_get_segments_with_none_status_list(self, db_session_with_containers: Session): """ Test with None status list. @@ -430,11 +462,14 @@ class TestSegmentServiceGetSegments: - No status filter is applied """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -443,6 +478,7 @@ class TestSegmentServiceGetSegments: status="completed", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -462,7 +498,7 @@ class TestSegmentServiceGetSegments: assert len(items) == 2 assert total == 2 - def test_get_segments_pagination_max_per_page_limit(self, db_session_with_containers): + def test_get_segments_pagination_max_per_page_limit(self, db_session_with_containers: Session): """ Test that max_per_page is correctly set to 100. @@ -471,13 +507,16 @@ class TestSegmentServiceGetSegments: - This prevents excessive page sizes """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) # Create 105 segments to exceed max_per_page of 100 for i in range(105): SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index f605a286ed..8bd994937a 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -13,7 +13,8 @@ This test suite covers: import json from uuid import uuid4 -from extensions.ext_database import db +from sqlalchemy.orm import Session + from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -31,7 +32,9 @@ class DatasetRetrievalTestDataFactory: """Factory class for creating database-backed test data for dataset retrieval integration tests.""" @staticmethod - def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.NORMAL) -> tuple[Account, Tenant]: + def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.NORMAL + ) -> tuple[Account, Tenant]: """Create an account and tenant with the specified role.""" account = Account( email=f"{uuid4()}@example.com", @@ -43,8 +46,8 @@ class DatasetRetrievalTestDataFactory: name=f"tenant-{uuid4()}", status="normal", ) - db.session.add_all([account, tenant]) - db.session.flush() + db_session_with_containers.add_all([account, tenant]) + db_session_with_containers.flush() join = TenantAccountJoin( tenant_id=tenant.id, @@ -52,14 +55,16 @@ class DatasetRetrievalTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account, tenant @staticmethod - def create_account_in_tenant(tenant: Tenant, role: TenantAccountRole = TenantAccountRole.OWNER) -> Account: + def create_account_in_tenant( + db_session_with_containers: Session, tenant: Tenant, role: TenantAccountRole = TenantAccountRole.OWNER + ) -> Account: """Create an account and add it to an existing tenant.""" account = Account( email=f"{uuid4()}@example.com", @@ -67,8 +72,8 @@ class DatasetRetrievalTestDataFactory: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.flush() + db_session_with_containers.add(account) + db_session_with_containers.flush() join = TenantAccountJoin( tenant_id=tenant.id, @@ -76,14 +81,15 @@ class DatasetRetrievalTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account @staticmethod def create_dataset( + db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test Dataset", @@ -101,12 +107,14 @@ class DatasetRetrievalTestDataFactory: provider="vendor", retrieval_model={"top_k": 2}, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod - def create_dataset_permission(dataset_id: str, tenant_id: str, account_id: str) -> DatasetPermission: + def create_dataset_permission( + db_session_with_containers: Session, dataset_id: str, tenant_id: str, account_id: str + ) -> DatasetPermission: """Create a dataset permission.""" permission = DatasetPermission( dataset_id=dataset_id, @@ -114,12 +122,14 @@ class DatasetRetrievalTestDataFactory: account_id=account_id, has_permission=True, ) - db.session.add(permission) - db.session.commit() + db_session_with_containers.add(permission) + db_session_with_containers.commit() return permission @staticmethod - def create_process_rule(dataset_id: str, created_by: str, mode: str, rules: dict) -> DatasetProcessRule: + def create_process_rule( + db_session_with_containers: Session, dataset_id: str, created_by: str, mode: str, rules: dict + ) -> DatasetProcessRule: """Create a dataset process rule.""" process_rule = DatasetProcessRule( dataset_id=dataset_id, @@ -127,12 +137,14 @@ class DatasetRetrievalTestDataFactory: mode=mode, rules=json.dumps(rules), ) - db.session.add(process_rule) - db.session.commit() + db_session_with_containers.add(process_rule) + db_session_with_containers.commit() return process_rule @staticmethod - def create_dataset_query(dataset_id: str, created_by: str, content: str) -> DatasetQuery: + def create_dataset_query( + db_session_with_containers: Session, dataset_id: str, created_by: str, content: str + ) -> DatasetQuery: """Create a dataset query.""" dataset_query = DatasetQuery( dataset_id=dataset_id, @@ -142,23 +154,23 @@ class DatasetRetrievalTestDataFactory: created_by_role="account", created_by=created_by, ) - db.session.add(dataset_query) - db.session.commit() + db_session_with_containers.add(dataset_query) + db_session_with_containers.commit() return dataset_query @staticmethod - def create_app_dataset_join(dataset_id: str) -> AppDatasetJoin: + def create_app_dataset_join(db_session_with_containers: Session, dataset_id: str) -> AppDatasetJoin: """Create an app-dataset join.""" join = AppDatasetJoin( app_id=str(uuid4()), dataset_id=dataset_id, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() return join @staticmethod - def create_tag_binding(tenant_id: str, created_by: str, target_id: str) -> Tag: + def create_tag_binding(db_session_with_containers: Session, tenant_id: str, created_by: str, target_id: str) -> Tag: """Create a knowledge tag and bind it to the target dataset.""" tag = Tag( tenant_id=tenant_id, @@ -166,8 +178,8 @@ class DatasetRetrievalTestDataFactory: name=f"tag-{uuid4()}", created_by=created_by, ) - db.session.add(tag) - db.session.flush() + db_session_with_containers.add(tag) + db_session_with_containers.flush() binding = TagBinding( tenant_id=tenant_id, @@ -175,8 +187,8 @@ class DatasetRetrievalTestDataFactory: target_id=target_id, created_by=created_by, ) - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() return tag @@ -195,15 +207,16 @@ class TestDatasetServiceGetDatasets: # ==================== Basic Retrieval Tests ==================== - def test_get_datasets_basic_pagination(self, db_session_with_containers): + def test_get_datasets_basic_pagination(self, db_session_with_containers: Session): """Test basic pagination without user or filters.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 for i in range(5): DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name=f"Dataset {i}", @@ -217,21 +230,23 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 5 assert total == 5 - def test_get_datasets_with_search(self, db_session_with_containers): + def test_get_datasets_with_search(self, db_session_with_containers: Session): """Test get_datasets with search keyword.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 search = "test" DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Test Dataset", permission=DatasetPermissionEnum.ALL_TEAM, ) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Another Dataset", @@ -245,26 +260,32 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_with_tag_filtering(self, db_session_with_containers): + def test_get_datasets_with_tag_filtering(self, db_session_with_containers: Session): """Test get_datasets with tag_ids filtering.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 dataset_1 = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, permission=DatasetPermissionEnum.ALL_TEAM, ) dataset_2 = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, permission=DatasetPermissionEnum.ALL_TEAM, ) - tag_1 = DatasetRetrievalTestDataFactory.create_tag_binding(tenant.id, account.id, dataset_1.id) - tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding(tenant.id, account.id, dataset_2.id) + tag_1 = DatasetRetrievalTestDataFactory.create_tag_binding( + db_session_with_containers, tenant.id, account.id, dataset_1.id + ) + tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding( + db_session_with_containers, tenant.id, account.id, dataset_2.id + ) tag_ids = [tag_1.id, tag_2.id] # Act @@ -274,16 +295,17 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 2 assert total == 2 - def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers): + def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers: Session): """Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 tag_ids = [] for i in range(3): DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name=f"dataset-{i}", @@ -300,19 +322,21 @@ class TestDatasetServiceGetDatasets: # ==================== Permission-Based Filtering Tests ==================== - def test_get_datasets_without_user_shows_only_all_team(self, db_session_with_containers): + def test_get_datasets_without_user_shows_only_all_team(self, db_session_with_containers: Session): """Test that without user, only ALL_TEAM datasets are shown.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, permission=DatasetPermissionEnum.ALL_TEAM, ) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, permission=DatasetPermissionEnum.ONLY_ME, @@ -325,15 +349,18 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_owner_with_include_all(self, db_session_with_containers): + def test_get_datasets_owner_with_include_all(self, db_session_with_containers: Session): """Test that OWNER with include_all=True sees all datasets.""" # Arrange - owner, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + owner, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) for i, permission in enumerate( [DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM] ): DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, name=f"dataset-{i}", @@ -353,12 +380,15 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 3 assert total == 3 - def test_get_datasets_normal_user_only_me_permission(self, db_session_with_containers): + def test_get_datasets_normal_user_only_me_permission(self, db_session_with_containers: Session): """Test that normal user sees ONLY_ME datasets they created.""" # Arrange - user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL + ) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, permission=DatasetPermissionEnum.ONLY_ME, @@ -371,13 +401,18 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_normal_user_all_team_permission(self, db_session_with_containers): + def test_get_datasets_normal_user_all_team_permission(self, db_session_with_containers: Session): """Test that normal user sees ALL_TEAM datasets.""" # Arrange - user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) - owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) + user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL + ) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant( + db_session_with_containers, tenant, role=TenantAccountRole.OWNER + ) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, permission=DatasetPermissionEnum.ALL_TEAM, @@ -390,18 +425,25 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_normal_user_partial_team_with_permission(self, db_session_with_containers): + def test_get_datasets_normal_user_partial_team_with_permission(self, db_session_with_containers: Session): """Test that normal user sees PARTIAL_TEAM datasets they have permission for.""" # Arrange - user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) - owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) + user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL + ) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant( + db_session_with_containers, tenant, role=TenantAccountRole.OWNER + ) dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, permission=DatasetPermissionEnum.PARTIAL_TEAM, ) - DatasetRetrievalTestDataFactory.create_dataset_permission(dataset.id, tenant.id, user.id) + DatasetRetrievalTestDataFactory.create_dataset_permission( + db_session_with_containers, dataset.id, tenant.id, user.id + ) # Act datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user) @@ -410,20 +452,25 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_dataset_operator_with_permissions(self, db_session_with_containers): + def test_get_datasets_dataset_operator_with_permissions(self, db_session_with_containers: Session): """Test that DATASET_OPERATOR only sees datasets they have explicit permission for.""" # Arrange operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( - role=TenantAccountRole.DATASET_OPERATOR + db_session_with_containers, role=TenantAccountRole.DATASET_OPERATOR + ) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant( + db_session_with_containers, tenant, role=TenantAccountRole.OWNER ) - owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, permission=DatasetPermissionEnum.ONLY_ME, ) - DatasetRetrievalTestDataFactory.create_dataset_permission(dataset.id, tenant.id, operator.id) + DatasetRetrievalTestDataFactory.create_dataset_permission( + db_session_with_containers, dataset.id, tenant.id, operator.id + ) # Act datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator) @@ -432,14 +479,17 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_dataset_operator_without_permissions(self, db_session_with_containers): + def test_get_datasets_dataset_operator_without_permissions(self, db_session_with_containers: Session): """Test that DATASET_OPERATOR without permissions returns empty result.""" # Arrange operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( - role=TenantAccountRole.DATASET_OPERATOR + db_session_with_containers, role=TenantAccountRole.DATASET_OPERATOR + ) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant( + db_session_with_containers, tenant, role=TenantAccountRole.OWNER ) - owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, permission=DatasetPermissionEnum.ALL_TEAM, @@ -456,11 +506,13 @@ class TestDatasetServiceGetDatasets: class TestDatasetServiceGetDataset: """Comprehensive integration tests for DatasetService.get_dataset method.""" - def test_get_dataset_success(self, db_session_with_containers): + def test_get_dataset_success(self, db_session_with_containers: Session): """Test successful retrieval of a single dataset.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) # Act result = DatasetService.get_dataset(dataset.id) @@ -469,7 +521,7 @@ class TestDatasetServiceGetDataset: assert result is not None assert result.id == dataset.id - def test_get_dataset_not_found(self, db_session_with_containers): + def test_get_dataset_not_found(self, db_session_with_containers: Session): """Test retrieval when dataset doesn't exist.""" # Arrange dataset_id = str(uuid4()) @@ -484,12 +536,15 @@ class TestDatasetServiceGetDataset: class TestDatasetServiceGetDatasetsByIds: """Comprehensive integration tests for DatasetService.get_datasets_by_ids method.""" - def test_get_datasets_by_ids_success(self, db_session_with_containers): + def test_get_datasets_by_ids_success(self, db_session_with_containers: Session): """Test successful bulk retrieval of datasets by IDs.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) datasets = [ - DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) for _ in range(3) + DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) + for _ in range(3) ] dataset_ids = [dataset.id for dataset in datasets] @@ -501,7 +556,7 @@ class TestDatasetServiceGetDatasetsByIds: assert total == 3 assert all(dataset.id in dataset_ids for dataset in result_datasets) - def test_get_datasets_by_ids_empty_list(self, db_session_with_containers): + def test_get_datasets_by_ids_empty_list(self, db_session_with_containers: Session): """Test get_datasets_by_ids with empty list returns empty result.""" # Arrange tenant_id = str(uuid4()) @@ -514,7 +569,7 @@ class TestDatasetServiceGetDatasetsByIds: assert datasets == [] assert total == 0 - def test_get_datasets_by_ids_none_list(self, db_session_with_containers): + def test_get_datasets_by_ids_none_list(self, db_session_with_containers: Session): """Test get_datasets_by_ids with None returns empty result.""" # Arrange tenant_id = str(uuid4()) @@ -530,17 +585,20 @@ class TestDatasetServiceGetDatasetsByIds: class TestDatasetServiceGetProcessRules: """Comprehensive integration tests for DatasetService.get_process_rules method.""" - def test_get_process_rules_with_existing_rule(self, db_session_with_containers): + def test_get_process_rules_with_existing_rule(self, db_session_with_containers: Session): """Test retrieval of process rules when rule exists.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) rules_data = { "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}], "segmentation": {"delimiter": "\n", "max_tokens": 500}, } DatasetRetrievalTestDataFactory.create_process_rule( + db_session_with_containers, dataset_id=dataset.id, created_by=account.id, mode="custom", @@ -554,11 +612,13 @@ class TestDatasetServiceGetProcessRules: assert result["mode"] == "custom" assert result["rules"] == rules_data - def test_get_process_rules_without_existing_rule(self, db_session_with_containers): + def test_get_process_rules_without_existing_rule(self, db_session_with_containers: Session): """Test retrieval of process rules when no rule exists (returns defaults).""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) # Act result = DatasetService.get_process_rules(dataset.id) @@ -572,16 +632,19 @@ class TestDatasetServiceGetProcessRules: class TestDatasetServiceGetDatasetQueries: """Comprehensive integration tests for DatasetService.get_dataset_queries method.""" - def test_get_dataset_queries_success(self, db_session_with_containers): + def test_get_dataset_queries_success(self, db_session_with_containers: Session): """Test successful retrieval of dataset queries.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) page = 1 per_page = 20 for i in range(3): DatasetRetrievalTestDataFactory.create_dataset_query( + db_session_with_containers, dataset_id=dataset.id, created_by=account.id, content=f"query-{i}", @@ -595,11 +658,13 @@ class TestDatasetServiceGetDatasetQueries: assert total == 3 assert all(query.dataset_id == dataset.id for query in queries) - def test_get_dataset_queries_empty_result(self, db_session_with_containers): + def test_get_dataset_queries_empty_result(self, db_session_with_containers: Session): """Test retrieval when no queries exist.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) page = 1 per_page = 20 @@ -614,14 +679,16 @@ class TestDatasetServiceGetDatasetQueries: class TestDatasetServiceGetRelatedApps: """Comprehensive integration tests for DatasetService.get_related_apps method.""" - def test_get_related_apps_success(self, db_session_with_containers): + def test_get_related_apps_success(self, db_session_with_containers: Session): """Test successful retrieval of related apps.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) for _ in range(2): - DatasetRetrievalTestDataFactory.create_app_dataset_join(dataset.id) + DatasetRetrievalTestDataFactory.create_app_dataset_join(db_session_with_containers, dataset.id) # Act result = DatasetService.get_related_apps(dataset.id) @@ -630,11 +697,13 @@ class TestDatasetServiceGetRelatedApps: assert len(result) == 2 assert all(join.dataset_id == dataset.id for join in result) - def test_get_related_apps_empty_result(self, db_session_with_containers): + def test_get_related_apps_empty_result(self, db_session_with_containers: Session): """Test retrieval when no related apps exist.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) # Act result = DatasetService.get_related_apps(dataset.id) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index f6d9dfddae..ebaa3b4637 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -2,9 +2,9 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session -from core.model_runtime.entities.model_entities import ModelType -from extensions.ext_database import db +from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, ExternalKnowledgeBindings from services.dataset_service import DatasetService @@ -15,7 +15,9 @@ class DatasetUpdateTestDataFactory: """Factory class for creating real test data for dataset update integration tests.""" @staticmethod - def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]: + def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER + ) -> tuple[Account, Tenant]: """Create a real account and tenant with the given role.""" account = Account( email=f"{uuid4()}@example.com", @@ -23,12 +25,12 @@ class DatasetUpdateTestDataFactory: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant(name=f"tenant-{account.id}", status="normal") - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() join = TenantAccountJoin( tenant_id=tenant.id, @@ -36,14 +38,15 @@ class DatasetUpdateTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account, tenant @staticmethod def create_dataset( + db_session_with_containers: Session, tenant_id: str, created_by: str, provider: str = "vendor", @@ -71,12 +74,13 @@ class DatasetUpdateTestDataFactory: embedding_model=embedding_model, collection_binding_id=collection_binding_id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod def create_external_binding( + db_session_with_containers: Session, tenant_id: str, dataset_id: str, created_by: str, @@ -93,8 +97,8 @@ class DatasetUpdateTestDataFactory: external_knowledge_id=external_knowledge_id, external_knowledge_api_id=external_knowledge_api_id, ) - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() return binding @@ -112,10 +116,11 @@ class TestDatasetServiceUpdateDataset: # ==================== External Dataset Tests ==================== - def test_update_external_dataset_success(self, db_session_with_containers): + def test_update_external_dataset_success(self, db_session_with_containers: Session): """Test successful update of external dataset.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="external", @@ -124,12 +129,13 @@ class TestDatasetServiceUpdateDataset: retrieval_model="old_model", ) binding = DatasetUpdateTestDataFactory.create_external_binding( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, created_by=user.id, ) binding_id = binding.id - db.session.expunge(binding) + db_session_with_containers.expunge(binding) update_data = { "name": "new_name", @@ -142,8 +148,8 @@ class TestDatasetServiceUpdateDataset: result = DatasetService.update_dataset(dataset.id, update_data, user) - db.session.refresh(dataset) - updated_binding = db.session.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first() + db_session_with_containers.refresh(dataset) + updated_binding = db_session_with_containers.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first() assert dataset.name == "new_name" assert dataset.description == "new_description" @@ -153,15 +159,17 @@ class TestDatasetServiceUpdateDataset: assert updated_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"] assert result.id == dataset.id - def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers): + def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers: Session): """Test error when external knowledge id is missing.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="external", ) DatasetUpdateTestDataFactory.create_external_binding( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, created_by=user.id, @@ -173,17 +181,19 @@ class TestDatasetServiceUpdateDataset: DatasetService.update_dataset(dataset.id, update_data, user) assert "External knowledge id is required" in str(context.value) - db.session.rollback() + db_session_with_containers.rollback() - def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers): + def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers: Session): """Test error when external knowledge api id is missing.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="external", ) DatasetUpdateTestDataFactory.create_external_binding( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, created_by=user.id, @@ -195,12 +205,13 @@ class TestDatasetServiceUpdateDataset: DatasetService.update_dataset(dataset.id, update_data, user) assert "External knowledge api id is required" in str(context.value) - db.session.rollback() + db_session_with_containers.rollback() - def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers): + def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers: Session): """Test error when external knowledge binding is not found.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="external", @@ -216,15 +227,16 @@ class TestDatasetServiceUpdateDataset: DatasetService.update_dataset(dataset.id, update_data, user) assert "External knowledge binding not found" in str(context.value) - db.session.rollback() + db_session_with_containers.rollback() # ==================== Internal Dataset Basic Tests ==================== - def test_update_internal_dataset_basic_success(self, db_session_with_containers): + def test_update_internal_dataset_basic_success(self, db_session_with_containers: Session): """Test successful update of internal dataset with basic fields.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -244,7 +256,7 @@ class TestDatasetServiceUpdateDataset: } result = DatasetService.update_dataset(dataset.id, update_data, user) - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" assert dataset.description == "new_description" @@ -254,11 +266,12 @@ class TestDatasetServiceUpdateDataset: assert dataset.embedding_model == "text-embedding-ada-002" assert result.id == dataset.id - def test_update_internal_dataset_filter_none_values(self, db_session_with_containers): + def test_update_internal_dataset_filter_none_values(self, db_session_with_containers: Session): """Test that None values are filtered out except for description field.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -278,7 +291,7 @@ class TestDatasetServiceUpdateDataset: } result = DatasetService.update_dataset(dataset.id, update_data, user) - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" assert dataset.description is None @@ -289,11 +302,12 @@ class TestDatasetServiceUpdateDataset: # ==================== Indexing Technique Switch Tests ==================== - def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers): + def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers: Session): """Test updating internal dataset indexing technique to economy.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -312,7 +326,7 @@ class TestDatasetServiceUpdateDataset: result = DatasetService.update_dataset(dataset.id, update_data, user) mock_task.delay.assert_called_once_with(dataset.id, "remove") - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.indexing_technique == "economy" assert dataset.embedding_model is None assert dataset.embedding_model_provider is None @@ -320,10 +334,11 @@ class TestDatasetServiceUpdateDataset: assert dataset.retrieval_model == "new_model" assert result.id == dataset.id - def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers): + def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers: Session): """Test updating internal dataset indexing technique to high_quality.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -366,7 +381,7 @@ class TestDatasetServiceUpdateDataset: mock_get_binding.assert_called_once_with("openai", "text-embedding-ada-002") mock_task.delay.assert_called_once_with(dataset.id, "add") - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.indexing_technique == "high_quality" assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" @@ -380,9 +395,10 @@ class TestDatasetServiceUpdateDataset: self, db_session_with_containers ): """Test preserving embedding settings when indexing technique remains unchanged.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -399,7 +415,7 @@ class TestDatasetServiceUpdateDataset: } result = DatasetService.update_dataset(dataset.id, update_data, user) - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" assert dataset.indexing_technique == "high_quality" @@ -409,11 +425,12 @@ class TestDatasetServiceUpdateDataset: assert dataset.retrieval_model == "new_model" assert result.id == dataset.id - def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers): + def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers: Session): """Test updating internal dataset with new embedding model.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -465,7 +482,7 @@ class TestDatasetServiceUpdateDataset: regenerate_vectors_only=True, ) - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.embedding_model == "text-embedding-3-small" assert dataset.embedding_model_provider == "openai" assert dataset.collection_binding_id == binding.id @@ -474,9 +491,9 @@ class TestDatasetServiceUpdateDataset: # ==================== Error Handling Tests ==================== - def test_update_dataset_not_found_error(self, db_session_with_containers): + def test_update_dataset_not_found_error(self, db_session_with_containers: Session): """Test error when dataset is not found.""" - user, _ = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) update_data = {"name": "new_name"} with pytest.raises(ValueError) as context: @@ -484,11 +501,16 @@ class TestDatasetServiceUpdateDataset: assert "Dataset not found" in str(context.value) - def test_update_dataset_permission_error(self, db_session_with_containers): + def test_update_dataset_permission_error(self, db_session_with_containers: Session): """Test error when user doesn't have permission.""" - owner, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + owner, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL + ) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, provider="vendor", @@ -500,10 +522,11 @@ class TestDatasetServiceUpdateDataset: with pytest.raises(NoPermissionError): DatasetService.update_dataset(dataset.id, update_data, outsider) - def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers): + def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers: Session): """Test error when embedding model is not available.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index 546292109e..5f86cb2ae9 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -7,7 +7,7 @@ from uuid import uuid4 from sqlalchemy import select -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.enums import WorkflowExecutionStatus from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowArchiveLog, WorkflowRun from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py new file mode 100644 index 0000000000..f641da6576 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py @@ -0,0 +1,252 @@ +"""Container-backed integration tests for DocumentService.rename_document real SQL paths.""" + +import datetime +import json +from unittest.mock import create_autospec, patch +from uuid import uuid4 + +import pytest + +from models import Account +from models.dataset import Dataset, Document +from models.enums import CreatorUserRole +from models.model import UploadFile +from services.dataset_service import DocumentService + +FIXED_UPLOAD_CREATED_AT = datetime.datetime(2024, 1, 1, 0, 0, 0) + + +@pytest.fixture +def mock_env(): + """Patch only non-SQL dependency used by rename_document: current_user context.""" + with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user: + current_user.current_tenant_id = str(uuid4()) + current_user.id = str(uuid4()) + yield {"current_user": current_user} + + +def make_dataset(db_session_with_containers, dataset_id=None, tenant_id=None, built_in_field_enabled=False): + """Persist a dataset row for rename_document integration scenarios.""" + dataset_id = dataset_id or str(uuid4()) + tenant_id = tenant_id or str(uuid4()) + + dataset = Dataset( + tenant_id=tenant_id, + name=f"dataset-{uuid4()}", + data_source_type="upload_file", + created_by=str(uuid4()), + ) + dataset.id = dataset_id + dataset.built_in_field_enabled = built_in_field_enabled + + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + +def make_document( + db_session_with_containers, + document_id=None, + dataset_id=None, + tenant_id=None, + name="Old Name", + data_source_info=None, + doc_metadata=None, +): + """Persist a document row used by rename_document integration scenarios.""" + document_id = document_id or str(uuid4()) + dataset_id = dataset_id or str(uuid4()) + tenant_id = tenant_id or str(uuid4()) + + doc = Document( + tenant_id=tenant_id, + dataset_id=dataset_id, + position=1, + data_source_type="upload_file", + data_source_info=json.dumps(data_source_info or {}), + batch=f"batch-{uuid4()}", + name=name, + created_from="web", + created_by=str(uuid4()), + doc_form="text_model", + ) + doc.id = document_id + doc.indexing_status = "completed" + doc.doc_metadata = dict(doc_metadata or {}) + + db_session_with_containers.add(doc) + db_session_with_containers.commit() + return doc + + +def make_upload_file(db_session_with_containers, tenant_id: str, file_id: str, name: str): + """Persist an upload file row referenced by document.data_source_info.""" + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type="local", + key=f"uploads/{uuid4()}", + name=name, + size=128, + extension="pdf", + mime_type="application/pdf", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid4()), + created_at=FIXED_UPLOAD_CREATED_AT, + used=False, + ) + upload_file.id = file_id + + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() + return upload_file + + +def test_rename_document_success(db_session_with_containers, mock_env): + """Rename succeeds and returns the renamed document identity by id.""" + # Arrange + dataset_id = str(uuid4()) + document_id = str(uuid4()) + new_name = "New Document Name" + dataset = make_dataset(db_session_with_containers, dataset_id, mock_env["current_user"].current_tenant_id) + document = make_document( + db_session_with_containers, + document_id=document_id, + dataset_id=dataset_id, + tenant_id=mock_env["current_user"].current_tenant_id, + ) + + # Act + result = DocumentService.rename_document(dataset.id, document_id, new_name) + + # Assert + db_session_with_containers.refresh(document) + assert result.id == document.id + assert document.name == new_name + + +def test_rename_document_with_built_in_fields(db_session_with_containers, mock_env): + """Built-in document_name metadata is updated while existing metadata keys are preserved.""" + # Arrange + dataset_id = str(uuid4()) + document_id = str(uuid4()) + new_name = "Renamed" + dataset = make_dataset( + db_session_with_containers, + dataset_id, + mock_env["current_user"].current_tenant_id, + built_in_field_enabled=True, + ) + document = make_document( + db_session_with_containers, + document_id=document_id, + dataset_id=dataset.id, + tenant_id=mock_env["current_user"].current_tenant_id, + doc_metadata={"foo": "bar"}, + ) + + # Act + DocumentService.rename_document(dataset.id, document.id, new_name) + + # Assert + db_session_with_containers.refresh(document) + assert document.name == new_name + assert document.doc_metadata["document_name"] == new_name + assert document.doc_metadata["foo"] == "bar" + + +def test_rename_document_updates_upload_file_when_present(db_session_with_containers, mock_env): + """Rename propagates to UploadFile.name when upload_file_id is present in data_source_info.""" + # Arrange + dataset_id = str(uuid4()) + document_id = str(uuid4()) + file_id = str(uuid4()) + new_name = "Renamed" + dataset = make_dataset(db_session_with_containers, dataset_id, mock_env["current_user"].current_tenant_id) + document = make_document( + db_session_with_containers, + document_id=document_id, + dataset_id=dataset.id, + tenant_id=mock_env["current_user"].current_tenant_id, + data_source_info={"upload_file_id": file_id}, + ) + upload_file = make_upload_file( + db_session_with_containers, + tenant_id=mock_env["current_user"].current_tenant_id, + file_id=file_id, + name="old.pdf", + ) + + # Act + DocumentService.rename_document(dataset.id, document.id, new_name) + + # Assert + db_session_with_containers.refresh(document) + db_session_with_containers.refresh(upload_file) + assert document.name == new_name + assert upload_file.name == new_name + + +def test_rename_document_does_not_update_upload_file_when_missing_id(db_session_with_containers, mock_env): + """Rename does not update UploadFile when data_source_info lacks upload_file_id.""" + # Arrange + dataset_id = str(uuid4()) + document_id = str(uuid4()) + new_name = "Another Name" + dataset = make_dataset(db_session_with_containers, dataset_id, mock_env["current_user"].current_tenant_id) + document = make_document( + db_session_with_containers, + document_id=document_id, + dataset_id=dataset.id, + tenant_id=mock_env["current_user"].current_tenant_id, + data_source_info={"url": "https://example.com"}, + ) + untouched_file = make_upload_file( + db_session_with_containers, + tenant_id=mock_env["current_user"].current_tenant_id, + file_id=str(uuid4()), + name="untouched.pdf", + ) + + # Act + DocumentService.rename_document(dataset.id, document.id, new_name) + + # Assert + db_session_with_containers.refresh(document) + db_session_with_containers.refresh(untouched_file) + assert document.name == new_name + assert untouched_file.name == "untouched.pdf" + + +def test_rename_document_dataset_not_found(db_session_with_containers, mock_env): + """Rename raises Dataset not found when dataset id does not exist.""" + # Arrange + missing_dataset_id = str(uuid4()) + + # Act / Assert + with pytest.raises(ValueError, match="Dataset not found"): + DocumentService.rename_document(missing_dataset_id, str(uuid4()), "x") + + +def test_rename_document_not_found(db_session_with_containers, mock_env): + """Rename raises Document not found when document id is absent in the dataset.""" + # Arrange + dataset = make_dataset(db_session_with_containers, str(uuid4()), mock_env["current_user"].current_tenant_id) + + # Act / Assert + with pytest.raises(ValueError, match="Document not found"): + DocumentService.rename_document(dataset.id, str(uuid4()), "x") + + +def test_rename_document_permission_denied_when_tenant_mismatch(db_session_with_containers, mock_env): + """Rename raises No permission when document tenant differs from current_user tenant.""" + # Arrange + dataset = make_dataset(db_session_with_containers, str(uuid4()), mock_env["current_user"].current_tenant_id) + document = make_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=str(uuid4()), + ) + + # Act / Assert + with pytest.raises(ValueError, match="No permission"): + DocumentService.rename_document(dataset.id, document.id, "x") diff --git a/api/tests/test_containers_integration_tests/services/test_file_service.py b/api/tests/test_containers_integration_tests/services/test_file_service.py index 93516a0030..6712fe8454 100644 --- a/api/tests/test_containers_integration_tests/services/test_file_service.py +++ b/api/tests/test_containers_integration_tests/services/test_file_service.py @@ -5,6 +5,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker from sqlalchemy import Engine +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from configs import dify_config @@ -19,7 +20,7 @@ class TestFileService: """Integration tests for FileService using testcontainers.""" @pytest.fixture - def engine(self, db_session_with_containers): + def engine(self, db_session_with_containers: Session): bind = db_session_with_containers.get_bind() assert isinstance(bind, Engine) return bind @@ -46,7 +47,7 @@ class TestFileService: "extract_processor": mock_extract_processor, } - def _create_test_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account for testing. @@ -67,18 +68,16 @@ class TestFileService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join from models.account import TenantAccountJoin, TenantAccountRole @@ -89,15 +88,15 @@ class TestFileService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account - def _create_test_end_user(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test end user for testing. @@ -118,14 +117,14 @@ class TestFileService: session_id=fake.uuid4(), ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() return end_user - def _create_test_upload_file(self, db_session_with_containers, mock_external_service_dependencies, account): + def _create_test_upload_file( + self, db_session_with_containers: Session, mock_external_service_dependencies, account + ): """ Helper method to create a test upload file for testing. @@ -155,15 +154,13 @@ class TestFileService: source_url="", ) - from extensions.ext_database import db - - db.session.add(upload_file) - db.session.commit() + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() return upload_file # Test upload_file method - def test_upload_file_success(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_success(self, db_session_with_containers: Session, engine, mock_external_service_dependencies): """ Test successful file upload with valid parameters. """ @@ -196,7 +193,9 @@ class TestFileService: assert upload_file.id is not None - def test_upload_file_with_end_user(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_with_end_user( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with end user instead of account. """ @@ -219,7 +218,7 @@ class TestFileService: assert upload_file.created_by_role == CreatorUserRole.END_USER def test_upload_file_with_datasets_source( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with datasets source parameter. @@ -244,7 +243,7 @@ class TestFileService: assert upload_file.source_url == "https://example.com/source" def test_upload_file_invalid_filename_characters( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with invalid filename characters. @@ -265,7 +264,7 @@ class TestFileService: ) def test_upload_file_filename_too_long( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with filename that exceeds length limit. @@ -295,7 +294,7 @@ class TestFileService: assert len(base_name) <= 200 def test_upload_file_datasets_unsupported_type( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload for datasets with unsupported file type. @@ -316,7 +315,9 @@ class TestFileService: source="datasets", ) - def test_upload_file_too_large(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_too_large( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with file size exceeding limit. """ @@ -338,7 +339,7 @@ class TestFileService: # Test is_file_size_within_limit method def test_is_file_size_within_limit_image_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for image files within limit. @@ -351,7 +352,7 @@ class TestFileService: assert result is True def test_is_file_size_within_limit_video_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for video files within limit. @@ -364,7 +365,7 @@ class TestFileService: assert result is True def test_is_file_size_within_limit_audio_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for audio files within limit. @@ -377,7 +378,7 @@ class TestFileService: assert result is True def test_is_file_size_within_limit_document_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for document files within limit. @@ -390,7 +391,7 @@ class TestFileService: assert result is True def test_is_file_size_within_limit_image_exceeded( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for image files exceeding limit. @@ -403,7 +404,7 @@ class TestFileService: assert result is False def test_is_file_size_within_limit_unknown_extension( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for unknown file extension. @@ -416,7 +417,7 @@ class TestFileService: assert result is True # Test upload_text method - def test_upload_text_success(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_text_success(self, db_session_with_containers: Session, engine, mock_external_service_dependencies): """ Test successful text upload. """ @@ -447,7 +448,9 @@ class TestFileService: # Verify storage was called mock_external_service_dependencies["storage"].save.assert_called_once() - def test_upload_text_name_too_long(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_text_name_too_long( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test text upload with name that exceeds length limit. """ @@ -472,7 +475,9 @@ class TestFileService: assert upload_file.name == "a" * 200 # Test get_file_preview method - def test_get_file_preview_success(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_get_file_preview_success( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test successful file preview generation. """ @@ -484,9 +489,8 @@ class TestFileService: # Update file to have document extension upload_file.extension = "pdf" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() result = FileService(engine).get_file_preview(file_id=upload_file.id) @@ -494,7 +498,7 @@ class TestFileService: mock_external_service_dependencies["extract_processor"].load_from_upload_file.assert_called_once() def test_get_file_preview_file_not_found( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file preview with non-existent file. @@ -506,7 +510,7 @@ class TestFileService: FileService(engine).get_file_preview(file_id=non_existent_id) def test_get_file_preview_unsupported_file_type( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file preview with unsupported file type. @@ -519,15 +523,14 @@ class TestFileService: # Update file to have non-document extension upload_file.extension = "jpg" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(UnsupportedFileTypeError): FileService(engine).get_file_preview(file_id=upload_file.id) def test_get_file_preview_text_truncation( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file preview with text that exceeds preview limit. @@ -540,9 +543,8 @@ class TestFileService: # Update file to have document extension upload_file.extension = "pdf" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Mock long text content long_text = "x" * 5000 # Longer than PREVIEW_WORDS_LIMIT @@ -554,7 +556,9 @@ class TestFileService: assert result == "x" * 3000 # Test get_image_preview method - def test_get_image_preview_success(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_get_image_preview_success( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test successful image preview generation. """ @@ -566,9 +570,8 @@ class TestFileService: # Update file to have image extension upload_file.extension = "jpg" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() timestamp = "1234567890" nonce = "test_nonce" @@ -586,7 +589,7 @@ class TestFileService: mock_external_service_dependencies["file_helpers"].verify_image_signature.assert_called_once() def test_get_image_preview_invalid_signature( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test image preview with invalid signature. @@ -613,7 +616,7 @@ class TestFileService: ) def test_get_image_preview_file_not_found( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test image preview with non-existent file. @@ -634,7 +637,7 @@ class TestFileService: ) def test_get_image_preview_unsupported_file_type( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test image preview with non-image file type. @@ -647,9 +650,8 @@ class TestFileService: # Update file to have non-image extension upload_file.extension = "pdf" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() timestamp = "1234567890" nonce = "test_nonce" @@ -665,7 +667,7 @@ class TestFileService: # Test get_file_generator_by_file_id method def test_get_file_generator_by_file_id_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test successful file generator retrieval. @@ -692,7 +694,7 @@ class TestFileService: mock_external_service_dependencies["file_helpers"].verify_file_signature.assert_called_once() def test_get_file_generator_by_file_id_invalid_signature( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file generator retrieval with invalid signature. @@ -719,7 +721,7 @@ class TestFileService: ) def test_get_file_generator_by_file_id_file_not_found( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file generator retrieval with non-existent file. @@ -741,7 +743,7 @@ class TestFileService: # Test get_public_image_preview method def test_get_public_image_preview_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test successful public image preview generation. @@ -754,9 +756,8 @@ class TestFileService: # Update file to have image extension upload_file.extension = "jpg" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() generator, mime_type = FileService(engine).get_public_image_preview(file_id=upload_file.id) @@ -765,7 +766,7 @@ class TestFileService: mock_external_service_dependencies["storage"].load.assert_called_once() def test_get_public_image_preview_file_not_found( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test public image preview with non-existent file. @@ -777,7 +778,7 @@ class TestFileService: FileService(engine).get_public_image_preview(file_id=non_existent_id) def test_get_public_image_preview_unsupported_file_type( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test public image preview with non-image file type. @@ -790,15 +791,16 @@ class TestFileService: # Update file to have non-image extension upload_file.extension = "pdf" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(UnsupportedFileTypeError): FileService(engine).get_public_image_preview(file_id=upload_file.id) # Test edge cases and boundary conditions - def test_upload_file_empty_content(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_empty_content( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with empty content. """ @@ -820,7 +822,7 @@ class TestFileService: assert upload_file.size == 0 def test_upload_file_special_characters_in_name( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with special characters in filename (but valid ones). @@ -843,7 +845,7 @@ class TestFileService: assert upload_file.name == filename def test_upload_file_different_case_extensions( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with different case extensions. @@ -865,7 +867,9 @@ class TestFileService: assert upload_file is not None assert upload_file.extension == "pdf" # Should be converted to lowercase - def test_upload_text_empty_text(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_text_empty_text( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test text upload with empty text. """ @@ -888,7 +892,9 @@ class TestFileService: assert upload_file is not None assert upload_file.size == 0 - def test_file_size_limits_edge_cases(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_file_size_limits_edge_cases( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file size limits with edge case values. """ @@ -908,7 +914,9 @@ class TestFileService: result = FileService(engine).is_file_size_within_limit(extension=extension, file_size=file_size) assert result is False - def test_upload_file_with_source_url(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_with_source_url( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with source URL that gets overridden by signed URL. """ @@ -946,7 +954,7 @@ class TestFileService: # Test file extension blacklist def test_upload_file_blocked_extension( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with blocked extension. @@ -969,7 +977,7 @@ class TestFileService: ) def test_upload_file_blocked_extension_case_insensitive( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with blocked extension (case insensitive). @@ -992,7 +1000,9 @@ class TestFileService: user=account, ) - def test_upload_file_not_in_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_not_in_blacklist( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with extension not in blacklist. """ @@ -1016,7 +1026,9 @@ class TestFileService: assert upload_file.name == filename assert upload_file.extension == "pdf" - def test_upload_file_empty_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_empty_blacklist( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with empty blacklist (default behavior). """ @@ -1041,7 +1053,7 @@ class TestFileService: assert upload_file.extension == "sh" def test_upload_file_multiple_blocked_extensions( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with multiple blocked extensions. @@ -1066,7 +1078,7 @@ class TestFileService: ) def test_upload_file_no_extension_with_blacklist( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with no extension when blacklist is configured. diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index 9c978f830f..08f99cf55a 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -4,8 +4,8 @@ from unittest.mock import MagicMock import pytest -from core.workflow.enums import NodeType -from core.workflow.nodes.human_input.entities import ( +from dify_graph.enums import NodeType +from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index ece6de6cdf..19a684a58a 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models.model import MessageFeedback from services.app_service import AppService @@ -69,7 +70,7 @@ class TestMessageService: # "current_user": mock_current_user, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -127,11 +128,10 @@ class TestMessageService: # mock_external_service_dependencies["current_user"].id = account_id # mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id - def _create_test_conversation(self, app, account, fake): + def _create_test_conversation(self, db_session_with_containers: Session, app, account, fake): """ Helper method to create a test conversation with all required fields. """ - from extensions.ext_database import db from models.model import Conversation conversation = Conversation( @@ -153,17 +153,16 @@ class TestMessageService: from_account_id=account.id, ) - db.session.add(conversation) - db.session.flush() + db_session_with_containers.add(conversation) + db_session_with_containers.flush() return conversation - def _create_test_message(self, app, conversation, account, fake): + def _create_test_message(self, db_session_with_containers: Session, app, conversation, account, fake): """ Helper method to create a test message with all required fields. """ import json - from extensions.ext_database import db from models.model import Message message = Message( @@ -192,11 +191,13 @@ class TestMessageService: from_account_id=account.id, ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return message - def test_pagination_by_first_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_first_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful pagination by first ID. """ @@ -204,10 +205,10 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and multiple messages - conversation = self._create_test_conversation(app, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) messages = [] for i in range(5): - message = self._create_test_message(app, conversation, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) messages.append(message) # Test pagination by first ID @@ -228,7 +229,9 @@ class TestMessageService: # Verify messages are in ascending order assert result.data[0].created_at <= result.data[1].created_at - def test_pagination_by_first_id_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_first_id_no_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test pagination by first ID when no user is provided. """ @@ -246,7 +249,7 @@ class TestMessageService: assert result.has_more is False def test_pagination_by_first_id_no_conversation_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by first ID when no conversation ID is provided. @@ -265,7 +268,7 @@ class TestMessageService: assert result.has_more is False def test_pagination_by_first_id_invalid_first_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by first ID with invalid first_id. @@ -274,8 +277,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test pagination with invalid first_id with pytest.raises(FirstMessageNotExistsError): @@ -287,7 +290,9 @@ class TestMessageService: limit=10, ) - def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful pagination by last ID. """ @@ -295,10 +300,10 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and multiple messages - conversation = self._create_test_conversation(app, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) messages = [] for i in range(5): - message = self._create_test_message(app, conversation, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) messages.append(message) # Test pagination by last ID @@ -319,7 +324,7 @@ class TestMessageService: assert result.data[0].created_at >= result.data[1].created_at def test_pagination_by_last_id_with_include_ids( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by last ID with include_ids filter. @@ -328,10 +333,10 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and multiple messages - conversation = self._create_test_conversation(app, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) messages = [] for i in range(5): - message = self._create_test_message(app, conversation, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) messages.append(message) # Test pagination with include_ids @@ -347,7 +352,9 @@ class TestMessageService: for message in result.data: assert message.id in include_ids - def test_pagination_by_last_id_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_no_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test pagination by last ID when no user is provided. """ @@ -363,7 +370,7 @@ class TestMessageService: assert result.has_more is False def test_pagination_by_last_id_invalid_last_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by last ID with invalid last_id. @@ -372,8 +379,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test pagination with invalid last_id with pytest.raises(LastMessageNotExistsError): @@ -385,7 +392,7 @@ class TestMessageService: conversation_id=conversation.id, ) - def test_create_feedback_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_feedback_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful creation of feedback. """ @@ -393,8 +400,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create feedback rating = "like" @@ -413,7 +420,7 @@ class TestMessageService: assert feedback.from_account_id == account.id assert feedback.from_end_user_id is None - def test_create_feedback_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_feedback_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test creating feedback when no user is provided. """ @@ -421,8 +428,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test creating feedback with no user with pytest.raises(ValueError, match="user cannot be None"): @@ -430,7 +437,9 @@ class TestMessageService: app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100) ) - def test_create_feedback_update_existing(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_feedback_update_existing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test updating existing feedback. """ @@ -438,8 +447,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create initial feedback initial_rating = "like" @@ -462,7 +471,9 @@ class TestMessageService: assert updated_feedback.rating != initial_rating assert updated_feedback.content != initial_content - def test_create_feedback_delete_existing(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_feedback_delete_existing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test deleting existing feedback by setting rating to None. """ @@ -470,8 +481,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create initial feedback feedback = MessageService.create_feedback( @@ -482,13 +493,14 @@ class TestMessageService: MessageService.create_feedback(app_model=app, message_id=message.id, user=account, rating=None, content=None) # Verify feedback was deleted - from extensions.ext_database import db - deleted_feedback = db.session.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first() + deleted_feedback = ( + db_session_with_containers.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first() + ) assert deleted_feedback is None def test_create_feedback_no_rating_when_not_exists( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating feedback with no rating when feedback doesn't exist. @@ -497,8 +509,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test creating feedback with no rating when no feedback exists with pytest.raises(ValueError, match="rating cannot be None when feedback not exists"): @@ -506,7 +518,9 @@ class TestMessageService: app_model=app, message_id=message.id, user=account, rating=None, content=None ) - def test_get_all_messages_feedbacks_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_all_messages_feedbacks_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of all message feedbacks. """ @@ -516,8 +530,8 @@ class TestMessageService: # Create multiple conversations and messages with feedbacks feedbacks = [] for i in range(3): - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) feedback = MessageService.create_feedback( app_model=app, @@ -539,7 +553,7 @@ class TestMessageService: assert result[i]["created_at"] >= result[i + 1]["created_at"] def test_get_all_messages_feedbacks_pagination( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination of message feedbacks. @@ -549,8 +563,8 @@ class TestMessageService: # Create multiple conversations and messages with feedbacks for i in range(5): - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) MessageService.create_feedback( app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}" @@ -569,7 +583,7 @@ class TestMessageService: page_2_ids = {feedback["id"] for feedback in result_page_2} assert len(page_1_ids.intersection(page_2_ids)) == 0 - def test_get_message_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_message_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of message. """ @@ -577,8 +591,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Get message retrieved_message = MessageService.get_message(app_model=app, user=account, message_id=message.id) @@ -590,7 +604,7 @@ class TestMessageService: assert retrieved_message.from_source == "console" assert retrieved_message.from_account_id == account.id - def test_get_message_not_exists(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_message_not_exists(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting message that doesn't exist. """ @@ -601,7 +615,7 @@ class TestMessageService: with pytest.raises(MessageNotExistsError): MessageService.get_message(app_model=app, user=account, message_id=fake.uuid4()) - def test_get_message_wrong_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_message_wrong_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting message with wrong user (different account). """ @@ -609,8 +623,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create another account from services.account_service import AccountService, TenantService @@ -628,7 +642,7 @@ class TestMessageService: MessageService.get_message(app_model=app, user=other_account, message_id=message.id) def test_get_suggested_questions_after_answer_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful generation of suggested questions after answer. @@ -637,8 +651,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Mock the LLMGenerator to return specific questions mock_questions = ["What is AI?", "How does machine learning work?", "Tell me about neural networks"] @@ -665,7 +679,7 @@ class TestMessageService: mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once() def test_get_suggested_questions_after_answer_no_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting suggested questions when no user is provided. @@ -674,8 +688,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test getting suggested questions with no user from core.app.entities.app_invoke_entities import InvokeFrom @@ -686,7 +700,7 @@ class TestMessageService: ) def test_get_suggested_questions_after_answer_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting suggested questions when feature is disabled. @@ -695,8 +709,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Mock the feature to be disabled mock_external_service_dependencies[ @@ -712,7 +726,7 @@ class TestMessageService: ) def test_get_suggested_questions_after_answer_no_workflow( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting suggested questions when no workflow exists. @@ -721,8 +735,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Mock no workflow mock_external_service_dependencies["workflow_service"].return_value.get_published_workflow.return_value = None @@ -738,7 +752,7 @@ class TestMessageService: assert result == [] def test_get_suggested_questions_after_answer_debugger_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting suggested questions in debugger mode. @@ -747,8 +761,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Mock questions mock_questions = ["Debug question 1", "Debug question 2"] diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 5b6db64c09..6fe40c0744 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -6,9 +6,9 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.model import ( @@ -40,25 +40,25 @@ class TestMessagesCleanServiceIntegration: PLAN_CACHE_KEY_PREFIX = BillingService._PLAN_CACHE_KEY_PREFIX # "tenant_plan:" @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before and after each test to ensure isolation.""" yield # Clear all test data in correct order (respecting foreign key constraints) - db.session.query(DatasetRetrieverResource).delete() - db.session.query(AppAnnotationHitHistory).delete() - db.session.query(SavedMessage).delete() - db.session.query(MessageFile).delete() - db.session.query(MessageAgentThought).delete() - db.session.query(MessageChain).delete() - db.session.query(MessageAnnotation).delete() - db.session.query(MessageFeedback).delete() - db.session.query(Message).delete() - db.session.query(Conversation).delete() - db.session.query(App).delete() - db.session.query(TenantAccountJoin).delete() - db.session.query(Tenant).delete() - db.session.query(Account).delete() - db.session.commit() + db_session_with_containers.query(DatasetRetrieverResource).delete() + db_session_with_containers.query(AppAnnotationHitHistory).delete() + db_session_with_containers.query(SavedMessage).delete() + db_session_with_containers.query(MessageFile).delete() + db_session_with_containers.query(MessageAgentThought).delete() + db_session_with_containers.query(MessageChain).delete() + db_session_with_containers.query(MessageAnnotation).delete() + db_session_with_containers.query(MessageFeedback).delete() + db_session_with_containers.query(Message).delete() + db_session_with_containers.query(Conversation).delete() + db_session_with_containers.query(App).delete() + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() @pytest.fixture(autouse=True) def cleanup_redis(self): @@ -100,7 +100,7 @@ class TestMessagesCleanServiceIntegration: with patch("services.retention.conversation.messages_clean_policy.dify_config.BILLING_ENABLED", False): yield - def _create_account_and_tenant(self, plan: str = CloudPlan.SANDBOX): + def _create_account_and_tenant(self, db_session_with_containers: Session, plan: str = CloudPlan.SANDBOX): """Helper to create account and tenant.""" fake = Faker() @@ -110,28 +110,28 @@ class TestMessagesCleanServiceIntegration: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.flush() + db_session_with_containers.add(account) + db_session_with_containers.flush() tenant = Tenant( name=fake.company(), plan=str(plan), status="normal", ) - db.session.add(tenant) - db.session.flush() + db_session_with_containers.add(tenant) + db_session_with_containers.flush() tenant_account_join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole.OWNER, ) - db.session.add(tenant_account_join) - db.session.commit() + db_session_with_containers.add(tenant_account_join) + db_session_with_containers.commit() return account, tenant - def _create_app(self, tenant, account): + def _create_app(self, db_session_with_containers: Session, tenant, account): """Helper to create an app.""" fake = Faker() @@ -149,12 +149,12 @@ class TestMessagesCleanServiceIntegration: created_by=account.id, updated_by=account.id, ) - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app - def _create_conversation(self, app): + def _create_conversation(self, db_session_with_containers: Session, app): """Helper to create a conversation.""" conversation = Conversation( app_id=app.id, @@ -168,12 +168,14 @@ class TestMessagesCleanServiceIntegration: from_source="api", from_end_user_id=str(uuid.uuid4()), ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() return conversation - def _create_message(self, app, conversation, created_at=None, with_relations=True): + def _create_message( + self, db_session_with_containers: Session, app, conversation, created_at=None, with_relations=True + ): """Helper to create a message with optional related records.""" if created_at is None: created_at = datetime.datetime.now() @@ -197,16 +199,16 @@ class TestMessagesCleanServiceIntegration: from_account_id=conversation.from_end_user_id, created_at=created_at, ) - db.session.add(message) - db.session.flush() + db_session_with_containers.add(message) + db_session_with_containers.flush() if with_relations: - self._create_message_relations(message) + self._create_message_relations(db_session_with_containers, message) - db.session.commit() + db_session_with_containers.commit() return message - def _create_message_relations(self, message): + def _create_message_relations(self, db_session_with_containers: Session, message): """Helper to create all message-related records.""" # MessageFeedback feedback = MessageFeedback( @@ -217,7 +219,7 @@ class TestMessagesCleanServiceIntegration: from_source="api", from_end_user_id=str(uuid.uuid4()), ) - db.session.add(feedback) + db_session_with_containers.add(feedback) # MessageAnnotation annotation = MessageAnnotation( @@ -228,7 +230,7 @@ class TestMessagesCleanServiceIntegration: content="Test annotation", account_id=message.from_account_id, ) - db.session.add(annotation) + db_session_with_containers.add(annotation) # MessageChain chain = MessageChain( @@ -237,8 +239,8 @@ class TestMessagesCleanServiceIntegration: input=json.dumps({"test": "input"}), output=json.dumps({"test": "output"}), ) - db.session.add(chain) - db.session.flush() + db_session_with_containers.add(chain) + db_session_with_containers.flush() # MessageFile file = MessageFile( @@ -250,7 +252,7 @@ class TestMessagesCleanServiceIntegration: created_by_role="end_user", created_by=str(uuid.uuid4()), ) - db.session.add(file) + db_session_with_containers.add(file) # SavedMessage saved = SavedMessage( @@ -259,9 +261,9 @@ class TestMessagesCleanServiceIntegration: created_by_role="end_user", created_by=str(uuid.uuid4()), ) - db.session.add(saved) + db_session_with_containers.add(saved) - db.session.flush() + db_session_with_containers.flush() # AppAnnotationHitHistory hit = AppAnnotationHitHistory( @@ -275,7 +277,7 @@ class TestMessagesCleanServiceIntegration: annotation_question="Test annotation question", annotation_content="Test annotation content", ) - db.session.add(hit) + db_session_with_containers.add(hit) # DatasetRetrieverResource resource = DatasetRetrieverResource( @@ -296,25 +298,29 @@ class TestMessagesCleanServiceIntegration: retriever_from="dataset", created_by=message.from_account_id, ) - db.session.add(resource) + db_session_with_containers.add(resource) def test_billing_disabled_deletes_all_messages_in_time_range( - self, db_session_with_containers, mock_billing_disabled + self, db_session_with_containers: Session, mock_billing_disabled ): """Test that BillingDisabledPolicy deletes all messages within time range regardless of tenant plan.""" # Arrange - Create tenant with messages (plan doesn't matter for billing disabled) - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create messages: in-range (should be deleted) and out-of-range (should be kept) in_range_date = datetime.datetime(2024, 1, 15, 12, 0, 0) out_of_range_date = datetime.datetime(2024, 1, 25, 12, 0, 0) - in_range_msg = self._create_message(app, conv, created_at=in_range_date, with_relations=True) + in_range_msg = self._create_message( + db_session_with_containers, app, conv, created_at=in_range_date, with_relations=True + ) in_range_msg_id = in_range_msg.id - out_of_range_msg = self._create_message(app, conv, created_at=out_of_range_date, with_relations=True) + out_of_range_msg = self._create_message( + db_session_with_containers, app, conv, created_at=out_of_range_date, with_relations=True + ) out_of_range_msg_id = out_of_range_msg.id # Act - create_message_clean_policy should return BillingDisabledPolicy @@ -336,17 +342,34 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 1 # In-range message deleted - assert db.session.query(Message).where(Message.id == in_range_msg_id).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id == in_range_msg_id).count() == 0 # Out-of-range message kept - assert db.session.query(Message).where(Message.id == out_of_range_msg_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == out_of_range_msg_id).count() == 1 # Related records of in-range message deleted - assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == in_range_msg_id).count() == 0 - assert db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == in_range_msg_id).count() == 0 + assert ( + db_session_with_containers.query(MessageFeedback) + .where(MessageFeedback.message_id == in_range_msg_id) + .count() + == 0 + ) + assert ( + db_session_with_containers.query(MessageAnnotation) + .where(MessageAnnotation.message_id == in_range_msg_id) + .count() + == 0 + ) # Related records of out-of-range message kept - assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == out_of_range_msg_id).count() == 1 + assert ( + db_session_with_containers.query(MessageFeedback) + .where(MessageFeedback.message_id == out_of_range_msg_id) + .count() + == 1 + ) - def test_no_messages_returns_empty_stats(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_no_messages_returns_empty_stats( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test cleaning when there are no messages to delete (B1).""" # Arrange end_before = datetime.datetime.now() - datetime.timedelta(days=30) @@ -371,36 +394,42 @@ class TestMessagesCleanServiceIntegration: assert stats["filtered_messages"] == 0 assert stats["total_deleted"] == 0 - def test_mixed_sandbox_and_paid_tenants(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_mixed_sandbox_and_paid_tenants( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test cleaning with mixed sandbox and paid tenants (B2).""" # Arrange - Create sandbox tenants with expired messages sandbox_tenants = [] sandbox_message_ids = [] for i in range(2): - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) sandbox_tenants.append(tenant) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create 3 expired messages per sandbox tenant expired_date = datetime.datetime.now() - datetime.timedelta(days=35) for j in range(3): - msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=j) + ) sandbox_message_ids.append(msg.id) # Create paid tenants with expired messages (should NOT be deleted) paid_tenants = [] paid_message_ids = [] for i in range(2): - account, tenant = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.PROFESSIONAL) paid_tenants.append(tenant) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create 2 expired messages per paid tenant expired_date = datetime.datetime.now() - datetime.timedelta(days=35) for j in range(2): - msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=j) + ) paid_message_ids.append(msg.id) # Mock billing service - return plan and expiration_date @@ -442,29 +471,39 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 6 # Only sandbox messages should be deleted - assert db.session.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0 # Paid messages should remain - assert db.session.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4 + assert db_session_with_containers.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4 # Related records of sandbox messages should be deleted - assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(sandbox_message_ids)).count() == 0 assert ( - db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(sandbox_message_ids)).count() + db_session_with_containers.query(MessageFeedback) + .where(MessageFeedback.message_id.in_(sandbox_message_ids)) + .count() + == 0 + ) + assert ( + db_session_with_containers.query(MessageAnnotation) + .where(MessageAnnotation.message_id.in_(sandbox_message_ids)) + .count() == 0 ) - def test_cursor_pagination_multiple_batches(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_cursor_pagination_multiple_batches( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test cursor pagination works correctly across multiple batches (B3).""" # Arrange - Create sandbox tenant with messages that will span multiple batches - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create 10 expired messages with different timestamps base_date = datetime.datetime.now() - datetime.timedelta(days=35) message_ids = [] for i in range(10): msg = self._create_message( + db_session_with_containers, app, conv, created_at=base_date + datetime.timedelta(hours=i), @@ -498,20 +537,22 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 10 # All messages should be deleted - assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id.in_(message_ids)).count() == 0 - def test_dry_run_does_not_delete(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_dry_run_does_not_delete(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist): """Test dry_run mode does not delete messages (B4).""" # Arrange - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create expired messages expired_date = datetime.datetime.now() - datetime.timedelta(days=35) message_ids = [] for i in range(3): - msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=i) + ) message_ids.append(msg.id) with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: @@ -540,21 +581,26 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 0 # But NOT deleted # All messages should still exist - assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 3 + assert db_session_with_containers.query(Message).where(Message.id.in_(message_ids)).count() == 3 # Related records should also still exist - assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() == 3 + assert ( + db_session_with_containers.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() + == 3 + ) - def test_partial_plan_data_safe_default(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_partial_plan_data_safe_default( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test when billing returns partial data, unknown tenants are preserved (B5).""" # Arrange - Create 3 tenants tenants_data = [] for i in range(3): - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg = self._create_message(app, conv, created_at=expired_date) + msg = self._create_message(db_session_with_containers, app, conv, created_at=expired_date) tenants_data.append( { @@ -600,28 +646,30 @@ class TestMessagesCleanServiceIntegration: # Check which messages were deleted assert ( - db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0 + db_session_with_containers.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0 ) # Sandbox tenant's message deleted assert ( - db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 + db_session_with_containers.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 ) # Professional tenant's message preserved assert ( - db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1 + db_session_with_containers.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1 ) # Unknown tenant's message preserved (safe default) - def test_empty_plan_data_skips_deletion(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_empty_plan_data_skips_deletion( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test when billing returns empty data, skip deletion entirely (B6).""" # Arrange - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg = self._create_message(app, conv, created_at=expired_date) + msg = self._create_message(db_session_with_containers, app, conv, created_at=expired_date) msg_id = msg.id - db.session.commit() + db_session_with_containers.commit() # Mock billing service to return empty data (simulating failure/no data scenario) with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: @@ -644,17 +692,20 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 0 # Message should still exist (safe default - don't delete if plan is unknown) - assert db.session.query(Message).where(Message.id == msg_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == msg_id).count() == 1 - def test_time_range_boundary_behavior(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_time_range_boundary_behavior( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test that messages are correctly filtered by [start_from, end_before) time range (B7).""" # Arrange - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create messages: before range, in range, after range msg_before = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 1, 12, 0, 0), # Before start_from @@ -663,6 +714,7 @@ class TestMessagesCleanServiceIntegration: msg_before_id = msg_before.id msg_at_start = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 10, 12, 0, 0), # At start_from (inclusive) @@ -671,6 +723,7 @@ class TestMessagesCleanServiceIntegration: msg_at_start_id = msg_at_start.id msg_in_range = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 15, 12, 0, 0), # In range @@ -679,6 +732,7 @@ class TestMessagesCleanServiceIntegration: msg_in_range_id = msg_in_range.id msg_at_end = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 20, 12, 0, 0), # At end_before (exclusive) @@ -687,6 +741,7 @@ class TestMessagesCleanServiceIntegration: msg_at_end_id = msg_at_end.id msg_after = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 25, 12, 0, 0), # After end_before @@ -694,7 +749,7 @@ class TestMessagesCleanServiceIntegration: ) msg_after_id = msg_after.id - db.session.commit() + db_session_with_containers.commit() # Mock billing service with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: @@ -722,17 +777,17 @@ class TestMessagesCleanServiceIntegration: # Verify specific messages using stored IDs # Before range, kept - assert db.session.query(Message).where(Message.id == msg_before_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == msg_before_id).count() == 1 # At start (inclusive), deleted - assert db.session.query(Message).where(Message.id == msg_at_start_id).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id == msg_at_start_id).count() == 0 # In range, deleted - assert db.session.query(Message).where(Message.id == msg_in_range_id).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id == msg_in_range_id).count() == 0 # At end (exclusive), kept - assert db.session.query(Message).where(Message.id == msg_at_end_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == msg_at_end_id).count() == 1 # After range, kept - assert db.session.query(Message).where(Message.id == msg_after_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == msg_after_id).count() == 1 - def test_grace_period_scenarios(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_grace_period_scenarios(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist): """Test cleaning with different graceful period scenarios (B8).""" # Arrange - Create 5 different tenants with different plan and expiration scenarios now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) @@ -740,50 +795,60 @@ class TestMessagesCleanServiceIntegration: # Scenario 1: Sandbox plan with expiration within graceful period (5 days ago) # Should NOT be deleted - account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app1 = self._create_app(tenant1, account1) - conv1 = self._create_conversation(app1) + account1, tenant1 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app1 = self._create_app(db_session_with_containers, tenant1, account1) + conv1 = self._create_conversation(db_session_with_containers, app1) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) + msg1 = self._create_message( + db_session_with_containers, app1, conv1, created_at=expired_date, with_relations=False + ) msg1_id = msg1.id expired_5_days_ago = now_timestamp - (5 * 24 * 60 * 60) # Within grace period # Scenario 2: Sandbox plan with expiration beyond graceful period (10 days ago) # Should be deleted - account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app2 = self._create_app(tenant2, account2) - conv2 = self._create_conversation(app2) - msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) + account2, tenant2 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app2 = self._create_app(db_session_with_containers, tenant2, account2) + conv2 = self._create_conversation(db_session_with_containers, app2) + msg2 = self._create_message( + db_session_with_containers, app2, conv2, created_at=expired_date, with_relations=False + ) msg2_id = msg2.id expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Beyond grace period # Scenario 3: Sandbox plan with expiration_date = -1 (no previous subscription) # Should be deleted - account3, tenant3 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app3 = self._create_app(tenant3, account3) - conv3 = self._create_conversation(app3) - msg3 = self._create_message(app3, conv3, created_at=expired_date, with_relations=False) + account3, tenant3 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app3 = self._create_app(db_session_with_containers, tenant3, account3) + conv3 = self._create_conversation(db_session_with_containers, app3) + msg3 = self._create_message( + db_session_with_containers, app3, conv3, created_at=expired_date, with_relations=False + ) msg3_id = msg3.id # Scenario 4: Non-sandbox plan (professional) with no expiration (future date) # Should NOT be deleted - account4, tenant4 = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL) - app4 = self._create_app(tenant4, account4) - conv4 = self._create_conversation(app4) - msg4 = self._create_message(app4, conv4, created_at=expired_date, with_relations=False) + account4, tenant4 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.PROFESSIONAL) + app4 = self._create_app(db_session_with_containers, tenant4, account4) + conv4 = self._create_conversation(db_session_with_containers, app4) + msg4 = self._create_message( + db_session_with_containers, app4, conv4, created_at=expired_date, with_relations=False + ) msg4_id = msg4.id future_expiration = now_timestamp + (365 * 24 * 60 * 60) # Active for 1 year # Scenario 5: Sandbox plan with expiration exactly at grace period boundary (8 days ago) # Should NOT be deleted (boundary is exclusive: > graceful_period) - account5, tenant5 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app5 = self._create_app(tenant5, account5) - conv5 = self._create_conversation(app5) - msg5 = self._create_message(app5, conv5, created_at=expired_date, with_relations=False) + account5, tenant5 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app5 = self._create_app(db_session_with_containers, tenant5, account5) + conv5 = self._create_conversation(db_session_with_containers, app5) + msg5 = self._create_message( + db_session_with_containers, app5, conv5, created_at=expired_date, with_relations=False + ) msg5_id = msg5.id expired_exactly_8_days_ago = now_timestamp - (8 * 24 * 60 * 60) # Exactly at boundary - db.session.commit() + db_session_with_containers.commit() # Mock billing service with all scenarios plan_map = { @@ -832,23 +897,31 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 2 # Verify each scenario using saved IDs - assert db.session.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept - assert db.session.query(Message).where(Message.id == msg2_id).count() == 0 # Beyond grace, deleted - assert db.session.query(Message).where(Message.id == msg3_id).count() == 0 # No subscription, deleted - assert db.session.query(Message).where(Message.id == msg4_id).count() == 1 # Professional plan, kept - assert db.session.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept + assert db_session_with_containers.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept + assert ( + db_session_with_containers.query(Message).where(Message.id == msg2_id).count() == 0 + ) # Beyond grace, deleted + assert ( + db_session_with_containers.query(Message).where(Message.id == msg3_id).count() == 0 + ) # No subscription, deleted + assert ( + db_session_with_containers.query(Message).where(Message.id == msg4_id).count() == 1 + ) # Professional plan, kept + assert db_session_with_containers.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept - def test_tenant_whitelist(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_tenant_whitelist(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist): """Test that whitelisted tenants' messages are not deleted (B9).""" # Arrange - Create 3 sandbox tenants with expired messages tenants_data = [] for i in range(3): - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg = self._create_message(app, conv, created_at=expired_date, with_relations=False) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date, with_relations=False + ) tenants_data.append( { @@ -897,27 +970,33 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 1 # Verify tenant0's message still exists (whitelisted) - assert db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1 # Verify tenant1's message still exists (whitelisted) - assert db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 # Verify tenant2's message was deleted (not whitelisted) - assert db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0 - def test_from_days_cleans_old_messages(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_from_days_cleans_old_messages( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test from_days correctly cleans messages older than N days (B11).""" # Arrange - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create old messages (should be deleted - older than 30 days) old_date = datetime.datetime.now() - datetime.timedelta(days=45) old_msg_ids = [] for i in range(3): msg = self._create_message( - app, conv, created_at=old_date - datetime.timedelta(hours=i), with_relations=False + db_session_with_containers, + app, + conv, + created_at=old_date - datetime.timedelta(hours=i), + with_relations=False, ) old_msg_ids.append(msg.id) @@ -926,11 +1005,15 @@ class TestMessagesCleanServiceIntegration: recent_msg_ids = [] for i in range(2): msg = self._create_message( - app, conv, created_at=recent_date - datetime.timedelta(hours=i), with_relations=False + db_session_with_containers, + app, + conv, + created_at=recent_date - datetime.timedelta(hours=i), + with_relations=False, ) recent_msg_ids.append(msg.id) - db.session.commit() + db_session_with_containers.commit() with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: mock_billing.return_value = { @@ -955,30 +1038,34 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 3 # Old messages deleted - assert db.session.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0 # Recent messages kept - assert db.session.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2 + assert db_session_with_containers.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2 def test_whitelist_precedence_over_grace_period( - self, db_session_with_containers, mock_billing_enabled, mock_whitelist + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist ): """Test that whitelist takes precedence over grace period logic.""" # Arrange - Create 2 sandbox tenants now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) # Tenant1: whitelisted, expired beyond grace period - account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app1 = self._create_app(tenant1, account1) - conv1 = self._create_conversation(app1) + account1, tenant1 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app1 = self._create_app(db_session_with_containers, tenant1, account1) + conv1 = self._create_conversation(db_session_with_containers, app1) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) + msg1 = self._create_message( + db_session_with_containers, app1, conv1, created_at=expired_date, with_relations=False + ) expired_30_days_ago = now_timestamp - (30 * 24 * 60 * 60) # Well beyond 21-day grace # Tenant2: not whitelisted, within grace period - account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app2 = self._create_app(tenant2, account2) - conv2 = self._create_conversation(app2) - msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) + account2, tenant2 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app2 = self._create_app(db_session_with_containers, tenant2, account2) + conv2 = self._create_conversation(db_session_with_containers, app2) + msg2 = self._create_message( + db_session_with_containers, app2, conv2, created_at=expired_date, with_relations=False + ) expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Within 21-day grace # Mock billing service @@ -1019,22 +1106,26 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 0 # Verify both messages still exist - assert db.session.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted - assert db.session.query(Message).where(Message.id == msg2.id).count() == 1 # Within grace period + assert db_session_with_containers.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted + assert ( + db_session_with_containers.query(Message).where(Message.id == msg2.id).count() == 1 + ) # Within grace period def test_empty_whitelist_deletes_eligible_messages( - self, db_session_with_containers, mock_billing_enabled, mock_whitelist + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist ): """Test that empty whitelist behaves as no whitelist (all eligible messages deleted).""" # Arrange - Create sandbox tenant with expired messages - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) msg_ids = [] for i in range(3): - msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=i) + ) msg_ids.append(msg.id) # Mock billing service @@ -1068,4 +1159,4 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 3 # Verify all messages were deleted - assert db.session.query(Message).where(Message.id.in_(msg_ids)).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id.in_(msg_ids)).count() == 0 diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index e04725627b..694dc1c1b9 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.built_in_field import BuiltInField from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -32,7 +33,7 @@ class TestMetadataService: "document_service": mock_document_service, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -53,18 +54,16 @@ class TestMetadataService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -73,15 +72,17 @@ class TestMetadataService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, account, tenant): + def _create_test_dataset( + self, db_session_with_containers: Session, mock_external_service_dependencies, account, tenant + ): """ Helper method to create a test dataset for testing. @@ -105,14 +106,14 @@ class TestMetadataService: built_in_field_enabled=False, ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset - def _create_test_document(self, db_session_with_containers, mock_external_service_dependencies, dataset, account): + def _create_test_document( + self, db_session_with_containers: Session, mock_external_service_dependencies, dataset, account + ): """ Helper method to create a test document for testing. @@ -141,14 +142,12 @@ class TestMetadataService: doc_language="en", ) - from extensions.ext_database import db - - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document - def test_create_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful metadata creation with valid parameters. """ @@ -178,13 +177,14 @@ class TestMetadataService: assert result.created_by == account.id # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.created_at is not None - def test_create_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_metadata_name_too_long( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata creation fails when name exceeds 255 characters. """ @@ -207,7 +207,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): MetadataService.create_metadata(dataset.id, metadata_args) - def test_create_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_metadata_name_already_exists( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata creation fails when name already exists in the same dataset. """ @@ -235,7 +237,7 @@ class TestMetadataService: MetadataService.create_metadata(dataset.id, second_metadata_args) def test_create_metadata_name_conflicts_with_built_in_field( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata creation fails when name conflicts with built-in field names. @@ -260,7 +262,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): MetadataService.create_metadata(dataset.id, metadata_args) - def test_update_metadata_name_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_metadata_name_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful metadata name update with valid parameters. """ @@ -291,12 +295,13 @@ class TestMetadataService: assert result.updated_at is not None # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.name == new_name - def test_update_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_metadata_name_too_long( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata name update fails when new name exceeds 255 characters. """ @@ -323,7 +328,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): MetadataService.update_metadata_name(dataset.id, metadata.id, long_name) - def test_update_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_metadata_name_already_exists( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata name update fails when new name already exists in the same dataset. """ @@ -351,7 +358,7 @@ class TestMetadataService: MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata") def test_update_metadata_name_conflicts_with_built_in_field( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata name update fails when new name conflicts with built-in field names. @@ -378,7 +385,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name) - def test_update_metadata_name_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_metadata_name_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata name update fails when metadata ID does not exist. """ @@ -406,7 +415,7 @@ class TestMetadataService: # Assert: Verify the method returns None when metadata is not found assert result is None - def test_delete_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful metadata deletion with valid parameters. """ @@ -434,12 +443,11 @@ class TestMetadataService: assert result.id == metadata.id # Verify metadata was deleted from database - from extensions.ext_database import db - deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first() + deleted_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(id=metadata.id).first() assert deleted_metadata is None - def test_delete_metadata_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_metadata_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test metadata deletion fails when metadata ID does not exist. """ @@ -467,7 +475,7 @@ class TestMetadataService: assert result is None def test_delete_metadata_with_document_bindings( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata deletion successfully removes document metadata bindings. @@ -500,15 +508,13 @@ class TestMetadataService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() # Set document metadata document.doc_metadata = {"test_metadata": "test_value"} - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Act: Execute the method under test result = MetadataService.delete_metadata(dataset.id, metadata.id) @@ -517,13 +523,13 @@ class TestMetadataService: assert result is not None # Verify metadata was deleted from database - deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first() + deleted_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(id=metadata.id).first() assert deleted_metadata is None # Note: The service attempts to update document metadata but may not succeed # due to mock configuration. The main functionality (metadata deletion) is verified. - def test_get_built_in_fields_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_built_in_fields_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of built-in metadata fields. """ @@ -548,7 +554,9 @@ class TestMetadataService: assert "string" in field_types assert "time" in field_types - def test_enable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_built_in_field_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful enabling of built-in fields for a dataset. """ @@ -579,16 +587,15 @@ class TestMetadataService: MetadataService.enable_built_in_field(dataset) # Assert: Verify the expected outcomes - from extensions.ext_database import db - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is True # Note: Document metadata update depends on DocumentService mock working correctly # The main functionality (enabling built-in fields) is verified def test_enable_built_in_field_already_enabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test enabling built-in fields when they are already enabled. @@ -607,10 +614,9 @@ class TestMetadataService: # Enable built-in fields first dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Mock DocumentService.get_working_documents_by_dataset_id mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] @@ -619,11 +625,11 @@ class TestMetadataService: MetadataService.enable_built_in_field(dataset) # Assert: Verify the method returns early without changes - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is True def test_enable_built_in_field_with_no_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test enabling built-in fields for a dataset with no documents. @@ -647,12 +653,13 @@ class TestMetadataService: MetadataService.enable_built_in_field(dataset) # Assert: Verify the expected outcomes - from extensions.ext_database import db - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is True - def test_disable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_disable_built_in_field_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful disabling of built-in fields for a dataset. """ @@ -673,10 +680,9 @@ class TestMetadataService: # Enable built-in fields first dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Set document metadata with built-in fields document.doc_metadata = { @@ -686,8 +692,8 @@ class TestMetadataService: BuiltInField.last_update_date: 1234567890.0, BuiltInField.source: "test_source", } - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Mock DocumentService.get_working_documents_by_dataset_id mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [ @@ -698,14 +704,14 @@ class TestMetadataService: MetadataService.disable_built_in_field(dataset) # Assert: Verify the expected outcomes - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is False # Note: Document metadata update depends on DocumentService mock working correctly # The main functionality (disabling built-in fields) is verified def test_disable_built_in_field_already_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test disabling built-in fields when they are already disabled. @@ -732,13 +738,12 @@ class TestMetadataService: MetadataService.disable_built_in_field(dataset) # Assert: Verify the method returns early without changes - from extensions.ext_database import db - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is False def test_disable_built_in_field_with_no_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test disabling built-in fields for a dataset with no documents. @@ -757,10 +762,9 @@ class TestMetadataService: # Enable built-in fields first dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Mock DocumentService.get_working_documents_by_dataset_id to return empty list mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] @@ -769,10 +773,12 @@ class TestMetadataService: MetadataService.disable_built_in_field(dataset) # Assert: Verify the expected outcomes - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is False - def test_update_documents_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_documents_metadata_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful update of documents metadata. """ @@ -815,24 +821,25 @@ class TestMetadataService: MetadataService.update_documents_metadata(dataset, operation_data) # Assert: Verify the expected outcomes - from extensions.ext_database import db # Verify document metadata was updated - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.doc_metadata is not None assert "test_metadata" in document.doc_metadata assert document.doc_metadata["test_metadata"] == "test_value" # Verify metadata binding was created binding = ( - db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata.id, document_id=document.id).first() + db_session_with_containers.query(DatasetMetadataBinding) + .filter_by(metadata_id=metadata.id, document_id=document.id) + .first() ) assert binding is not None assert binding.tenant_id == tenant.id assert binding.dataset_id == dataset.id def test_update_documents_metadata_with_built_in_fields_enabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test update of documents metadata when built-in fields are enabled. @@ -850,10 +857,9 @@ class TestMetadataService: # Enable built-in fields dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Setup mocks mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id @@ -884,7 +890,7 @@ class TestMetadataService: # Assert: Verify the expected outcomes # Verify document metadata was updated with both custom and built-in fields - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.doc_metadata is not None assert "test_metadata" in document.doc_metadata assert document.doc_metadata["test_metadata"] == "test_value" @@ -893,7 +899,7 @@ class TestMetadataService: # The main functionality (custom metadata update) is verified def test_update_documents_metadata_document_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test update of documents metadata when document is not found. @@ -936,7 +942,7 @@ class TestMetadataService: MetadataService.update_documents_metadata(dataset, operation_data) def test_knowledge_base_metadata_lock_check_dataset_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata lock check for dataset operations. @@ -959,7 +965,7 @@ class TestMetadataService: assert call_args[0][0] == f"dataset_metadata_lock_{dataset_id}" def test_knowledge_base_metadata_lock_check_document_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata lock check for document operations. @@ -982,7 +988,7 @@ class TestMetadataService: assert call_args[0][0] == f"document_metadata_lock_{document_id}" def test_knowledge_base_metadata_lock_check_lock_exists( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata lock check when lock already exists. @@ -999,7 +1005,7 @@ class TestMetadataService: MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) def test_knowledge_base_metadata_lock_check_document_lock_exists( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata lock check when document lock already exists. @@ -1013,7 +1019,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Another document metadata operation is running, please wait a moment."): MetadataService.knowledge_base_metadata_lock_check(None, document_id) - def test_get_dataset_metadatas_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_dataset_metadatas_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of dataset metadata information. """ @@ -1046,10 +1054,8 @@ class TestMetadataService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() # Act: Execute the method under test result = MetadataService.get_dataset_metadatas(dataset) @@ -1071,7 +1077,7 @@ class TestMetadataService: assert result["built_in_field_enabled"] is False def test_get_dataset_metadatas_with_built_in_fields_enabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test retrieval of dataset metadata when built-in fields are enabled. @@ -1086,10 +1092,9 @@ class TestMetadataService: # Enable built-in fields dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Setup mocks mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id @@ -1114,7 +1119,9 @@ class TestMetadataService: # Verify built-in field status assert result["built_in_field_enabled"] is True - def test_get_dataset_metadatas_no_metadata(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_dataset_metadatas_no_metadata( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of dataset metadata when no metadata exists. """ diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index 7c8472e819..989df42499 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from models.account import TenantAccountJoin, TenantAccountRole from models.model import Account, Tenant @@ -67,7 +68,7 @@ class TestModelLoadBalancingService: "credential_schema": mock_credential_schema, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -88,18 +89,16 @@ class TestModelLoadBalancingService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -108,8 +107,8 @@ class TestModelLoadBalancingService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -117,7 +116,7 @@ class TestModelLoadBalancingService: return account, tenant def _create_test_provider_and_setting( - self, db_session_with_containers, tenant_id, mock_external_service_dependencies + self, db_session_with_containers: Session, tenant_id, mock_external_service_dependencies ): """ Helper method to create a test provider and provider model setting. @@ -132,8 +131,6 @@ class TestModelLoadBalancingService: """ fake = Faker() - from extensions.ext_database import db - # Create provider provider = Provider( tenant_id=tenant_id, @@ -141,8 +138,8 @@ class TestModelLoadBalancingService: provider_type="custom", is_valid=True, ) - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Create provider model setting provider_model_setting = ProviderModelSetting( @@ -153,12 +150,14 @@ class TestModelLoadBalancingService: enabled=True, load_balancing_enabled=False, ) - db.session.add(provider_model_setting) - db.session.commit() + db_session_with_containers.add(provider_model_setting) + db_session_with_containers.commit() return provider, provider_model_setting - def test_enable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_model_load_balancing_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful model load balancing enablement. @@ -193,14 +192,15 @@ class TestModelLoadBalancingService: assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value # Verify database state - from extensions.ext_database import db - db.session.refresh(provider) - db.session.refresh(provider_model_setting) + db_session_with_containers.refresh(provider) + db_session_with_containers.refresh(provider_model_setting) assert provider.id is not None assert provider_model_setting.id is not None - def test_disable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_disable_model_load_balancing_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful model load balancing disablement. @@ -235,15 +235,14 @@ class TestModelLoadBalancingService: assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value # Verify database state - from extensions.ext_database import db - db.session.refresh(provider) - db.session.refresh(provider_model_setting) + db_session_with_containers.refresh(provider) + db_session_with_containers.refresh(provider_model_setting) assert provider.id is not None assert provider_model_setting.id is not None def test_enable_model_load_balancing_provider_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when provider does not exist. @@ -275,11 +274,12 @@ class TestModelLoadBalancingService: assert "Provider nonexistent_provider does not exist." in str(exc_info.value) # Verify no database state changes occurred - from extensions.ext_database import db - db.session.rollback() + db_session_with_containers.rollback() - def test_get_load_balancing_configs_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_load_balancing_configs_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of load balancing configurations. @@ -298,7 +298,6 @@ class TestModelLoadBalancingService: ) # Create load balancing config - from extensions.ext_database import db load_balancing_config = LoadBalancingModelConfig( tenant_id=tenant.id, @@ -309,11 +308,11 @@ class TestModelLoadBalancingService: encrypted_config='{"api_key": "test_key"}', enabled=True, ) - db.session.add(load_balancing_config) - db.session.commit() + db_session_with_containers.add(load_balancing_config) + db_session_with_containers.commit() # Verify the config was created - db.session.refresh(load_balancing_config) + db_session_with_containers.refresh(load_balancing_config) assert load_balancing_config.id is not None # Setup mocks for get_load_balancing_configs method @@ -358,11 +357,11 @@ class TestModelLoadBalancingService: assert configs[0]["ttl"] == 0 # Verify database state - db.session.refresh(load_balancing_config) + db_session_with_containers.refresh(load_balancing_config) assert load_balancing_config.id is not None def test_get_load_balancing_configs_provider_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when provider does not exist in get_load_balancing_configs. @@ -394,12 +393,11 @@ class TestModelLoadBalancingService: assert "Provider nonexistent_provider does not exist." in str(exc_info.value) # Verify no database state changes occurred - from extensions.ext_database import db - db.session.rollback() + db_session_with_containers.rollback() def test_get_load_balancing_configs_with_inherit_config( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test load balancing configs retrieval with inherit configuration. @@ -419,7 +417,6 @@ class TestModelLoadBalancingService: ) # Create load balancing config - from extensions.ext_database import db load_balancing_config = LoadBalancingModelConfig( tenant_id=tenant.id, @@ -430,8 +427,8 @@ class TestModelLoadBalancingService: encrypted_config='{"api_key": "test_key"}', enabled=True, ) - db.session.add(load_balancing_config) - db.session.commit() + db_session_with_containers.add(load_balancing_config) + db_session_with_containers.commit() # Setup mocks for inherit config scenario mock_provider_config = mock_external_service_dependencies["provider_config"] @@ -467,11 +464,11 @@ class TestModelLoadBalancingService: assert configs[1]["name"] == "config1" # Verify database state - db.session.refresh(load_balancing_config) + db_session_with_containers.refresh(load_balancing_config) assert load_balancing_config.id is not None # Verify inherit config was created in database - inherit_configs = db.session.scalars( + inherit_configs = db_session_with_containers.scalars( select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__") ).all() assert len(inherit_configs) == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index f7044f7d45..6afc5aa43c 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -2,9 +2,10 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.entities.model_entities import ModelStatus -from core.model_runtime.entities.model_entities import FetchFrom, ModelType +from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.provider import Provider, ProviderModel, ProviderModelSetting, ProviderType from services.model_provider_service import ModelProviderService @@ -29,7 +30,7 @@ class TestModelProviderService: "model_provider_factory": mock_model_provider_factory, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -50,18 +51,16 @@ class TestModelProviderService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -70,8 +69,8 @@ class TestModelProviderService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -80,7 +79,7 @@ class TestModelProviderService: def _create_test_provider( self, - db_session_with_containers, + db_session_with_containers: Session, mock_external_service_dependencies, tenant_id: str, provider_name: str = "openai", @@ -109,16 +108,14 @@ class TestModelProviderService: quota_used=0, ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() return provider def _create_test_provider_model( self, - db_session_with_containers, + db_session_with_containers: Session, mock_external_service_dependencies, tenant_id: str, provider_name: str, @@ -149,16 +146,14 @@ class TestModelProviderService: is_valid=True, ) - from extensions.ext_database import db - - db.session.add(provider_model) - db.session.commit() + db_session_with_containers.add(provider_model) + db_session_with_containers.commit() return provider_model def _create_test_provider_model_setting( self, - db_session_with_containers, + db_session_with_containers: Session, mock_external_service_dependencies, tenant_id: str, provider_name: str, @@ -190,14 +185,12 @@ class TestModelProviderService: load_balancing_enabled=False, ) - from extensions.ext_database import db - - db.session.add(provider_model_setting) - db.session.commit() + db_session_with_containers.add(provider_model_setting) + db_session_with_containers.commit() return provider_model_setting - def test_get_provider_list_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_provider_list_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful provider list retrieval. @@ -275,7 +268,7 @@ class TestModelProviderService: mock_provider_config.is_custom_configuration_available.assert_called_once() def test_get_provider_list_with_model_type_filter( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test provider list retrieval with model type filtering. @@ -374,7 +367,9 @@ class TestModelProviderService: assert result[0].provider == "cohere" assert ModelType.TEXT_EMBEDDING in result[0].supported_model_types - def test_get_models_by_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_models_by_provider_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of models by provider. @@ -407,8 +402,8 @@ class TestModelProviderService: # Create mock models from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity - from core.model_runtime.entities.common_entities import I18nObject - from core.model_runtime.entities.provider_entities import ProviderEntity + from dify_graph.model_runtime.entities.common_entities import I18nObject + from dify_graph.model_runtime.entities.provider_entities import ProviderEntity # Create real model objects instead of mocks provider_entity_1 = SimpleModelProviderEntity( @@ -485,7 +480,9 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_configurations.get_models.assert_called_once_with(provider="openai") - def test_get_provider_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_provider_credentials_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of provider credentials. @@ -543,7 +540,7 @@ class TestModelProviderService: mock_method.assert_called_once_with(tenant.id, "openai") def test_provider_credentials_validate_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful validation of provider credentials. @@ -585,7 +582,7 @@ class TestModelProviderService: mock_provider_configuration.validate_provider_credentials.assert_called_once_with(test_credentials) def test_provider_credentials_validate_invalid_provider( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test validation failure for non-existent provider. @@ -617,7 +614,7 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) def test_get_default_model_of_model_type_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of default model for a specific model type. @@ -643,7 +640,7 @@ class TestModelProviderService: # Create mock default model response from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity - from core.model_runtime.entities.common_entities import I18nObject + from dify_graph.model_runtime.entities.common_entities import I18nObject mock_default_model = DefaultModelEntity( model="gpt-3.5-turbo", @@ -673,7 +670,7 @@ class TestModelProviderService: mock_provider_manager.get_default_model.assert_called_once_with(tenant_id=tenant.id, model_type=ModelType.LLM) def test_update_default_model_of_model_type_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful update of default model for a specific model type. @@ -706,7 +703,9 @@ class TestModelProviderService: tenant_id=tenant.id, model_type=ModelType.LLM, provider="openai", model="gpt-4" ) - def test_get_model_provider_icon_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_model_provider_icon_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of model provider icon. @@ -743,7 +742,9 @@ class TestModelProviderService: # Verify mock interactions mock_model_provider_factory.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US") - def test_switch_preferred_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_switch_preferred_provider_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful switching of preferred provider type. @@ -779,7 +780,7 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_provider_configuration.switch_preferred_provider_type.assert_called_once() - def test_enable_model_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_model_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful enabling of a model. @@ -815,7 +816,9 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_provider_configuration.enable_model.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4") - def test_get_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_model_credentials_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of model credentials. @@ -872,7 +875,9 @@ class TestModelProviderService: # Verify the method was called with correct parameters mock_method.assert_called_once_with(tenant.id, "openai", "llm", "gpt-4", None) - def test_model_credentials_validate_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_model_credentials_validate_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful validation of model credentials. @@ -914,7 +919,9 @@ class TestModelProviderService: model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials ) - def test_save_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_model_credentials_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful saving of model credentials. @@ -955,7 +962,9 @@ class TestModelProviderService: model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials, credential_name="testname" ) - def test_remove_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_remove_model_credentials_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful removal of model credentials. @@ -993,7 +1002,9 @@ class TestModelProviderService: model_type=ModelType.LLM, model="gpt-4", credential_id="5540007c-b988-46e0-b1c7-9b5fb9f330d6" ) - def test_get_models_by_model_type_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_models_by_model_type_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of models by model type. @@ -1070,7 +1081,9 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) - def test_get_model_parameter_rules_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_model_parameter_rules_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of model parameter rules. @@ -1137,7 +1150,7 @@ class TestModelProviderService: ) def test_get_model_parameter_rules_no_credentials( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test parameter rules retrieval when no credentials are available. @@ -1181,7 +1194,7 @@ class TestModelProviderService: ) def test_get_model_parameter_rules_provider_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test parameter rules retrieval when provider does not exist. diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index 9e6b9837ae..e3ec1d1df3 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models.model import EndUser, Message from models.web import SavedMessage @@ -38,7 +39,7 @@ class TestSavedMessageService: "message_service": mock_message_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -85,7 +86,7 @@ class TestSavedMessageService: return app, account - def _create_test_end_user(self, db_session_with_containers, app): + def _create_test_end_user(self, db_session_with_containers: Session, app): """ Helper method to create a test end user for testing. @@ -108,14 +109,12 @@ class TestSavedMessageService: is_anonymous=False, ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() return end_user - def _create_test_message(self, db_session_with_containers, app, user): + def _create_test_message(self, db_session_with_containers: Session, app, user): """ Helper method to create a test message for testing. @@ -143,10 +142,8 @@ class TestSavedMessageService: mode="chat", ) - from extensions.ext_database import db - - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create message message = Message( @@ -168,13 +165,13 @@ class TestSavedMessageService: status="success", ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return message def test_pagination_by_last_id_success_with_account_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful pagination by last ID with account user. @@ -207,10 +204,8 @@ class TestSavedMessageService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add_all([saved_message1, saved_message2]) - db.session.commit() + db_session_with_containers.add_all([saved_message1, saved_message2]) + db_session_with_containers.commit() # Mock MessageService.pagination_by_last_id return value from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -240,15 +235,15 @@ class TestSavedMessageService: assert actual_include_ids == expected_include_ids # Verify database state - db.session.refresh(saved_message1) - db.session.refresh(saved_message2) + db_session_with_containers.refresh(saved_message1) + db_session_with_containers.refresh(saved_message2) assert saved_message1.id is not None assert saved_message2.id is not None assert saved_message1.created_by_role == "account" assert saved_message2.created_by_role == "account" def test_pagination_by_last_id_success_with_end_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful pagination by last ID with end user. @@ -282,10 +277,8 @@ class TestSavedMessageService: created_by=end_user.id, ) - from extensions.ext_database import db - - db.session.add_all([saved_message1, saved_message2]) - db.session.commit() + db_session_with_containers.add_all([saved_message1, saved_message2]) + db_session_with_containers.commit() # Mock MessageService.pagination_by_last_id return value from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -317,14 +310,16 @@ class TestSavedMessageService: assert actual_include_ids == expected_include_ids # Verify database state - db.session.refresh(saved_message1) - db.session.refresh(saved_message2) + db_session_with_containers.refresh(saved_message1) + db_session_with_containers.refresh(saved_message2) assert saved_message1.id is not None assert saved_message2.id is not None assert saved_message1.created_by_role == "end_user" assert saved_message2.created_by_role == "end_user" - def test_save_success_with_new_message(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_success_with_new_message( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful save of a new message. @@ -347,10 +342,9 @@ class TestSavedMessageService: # Assert: Verify the expected outcomes # Check if saved message was created in database - from extensions.ext_database import db saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -373,10 +367,12 @@ class TestSavedMessageService: ) # Verify database state - db.session.refresh(saved_message) + db_session_with_containers.refresh(saved_message) assert saved_message.id is not None - def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_error_no_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when no user is provided. @@ -396,12 +392,11 @@ class TestSavedMessageService: assert "User is required" in str(exc_info.value) # Verify no database operations were performed - from extensions.ext_database import db - saved_messages = db.session.query(SavedMessage).all() + saved_messages = db_session_with_containers.query(SavedMessage).all() assert len(saved_messages) == 0 - def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test error handling when saving message with no user. @@ -422,10 +417,9 @@ class TestSavedMessageService: assert result is None # Verify no saved message was created - from extensions.ext_database import db saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -435,7 +429,9 @@ class TestSavedMessageService: assert saved_message is None - def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_success_existing_message( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful deletion of an existing saved message. @@ -457,14 +453,12 @@ class TestSavedMessageService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(saved_message) - db.session.commit() + db_session_with_containers.add(saved_message) + db_session_with_containers.commit() # Verify saved message exists assert ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -481,7 +475,7 @@ class TestSavedMessageService: # Assert: Verify the expected outcomes # Check if saved message was deleted from database deleted_saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -494,11 +488,13 @@ class TestSavedMessageService: assert deleted_saved_message is None # Verify database state - db.session.commit() + db_session_with_containers.commit() # The message should still exist, only the saved_message should be deleted - assert db.session.query(Message).where(Message.id == message.id).first() is not None + assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None - def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_error_no_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when no user is provided. @@ -522,7 +518,7 @@ class TestSavedMessageService: # Instead, we verify that the error was properly raised pass - def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test error handling when saving message with no user. @@ -543,10 +539,9 @@ class TestSavedMessageService: assert result is None # Verify no saved message was created - from extensions.ext_database import db saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -556,7 +551,9 @@ class TestSavedMessageService: assert saved_message is None - def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_success_existing_message( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful deletion of an existing saved message. @@ -578,14 +575,12 @@ class TestSavedMessageService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(saved_message) - db.session.commit() + db_session_with_containers.add(saved_message) + db_session_with_containers.commit() # Verify saved message exists assert ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -602,7 +597,7 @@ class TestSavedMessageService: # Assert: Verify the expected outcomes # Check if saved message was deleted from database deleted_saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -615,6 +610,6 @@ class TestSavedMessageService: assert deleted_saved_message is None # Verify database state - db.session.commit() + db_session_with_containers.commit() # The message should still exist, only the saved_message should be deleted - assert db.session.query(Message).where(Message.id == message.id).first() is not None + assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index e8c7f17e0b..597ba6b75b 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -4,6 +4,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -29,7 +30,7 @@ class TestTagService: "current_user": mock_current_user, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -50,18 +51,16 @@ class TestTagService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -70,8 +69,8 @@ class TestTagService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -82,7 +81,7 @@ class TestTagService: return account, tenant - def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, tenant_id): + def _create_test_dataset(self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id): """ Helper method to create a test dataset for testing. @@ -107,14 +106,12 @@ class TestTagService: created_by=mock_external_service_dependencies["current_user"].id, ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset - def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant_id): + def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id): """ Helper method to create a test app for testing. @@ -141,15 +138,13 @@ class TestTagService: created_by=mock_external_service_dependencies["current_user"].id, ) - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app def _create_test_tags( - self, db_session_with_containers, mock_external_service_dependencies, tenant_id, tag_type, count=3 + self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id, tag_type, count=3 ): """ Helper method to create test tags for testing. @@ -176,16 +171,14 @@ class TestTagService: ) tags.append(tag) - from extensions.ext_database import db - for tag in tags: - db.session.add(tag) - db.session.commit() + db_session_with_containers.add(tag) + db_session_with_containers.commit() return tags def _create_test_tag_bindings( - self, db_session_with_containers, mock_external_service_dependencies, tags, target_id, tenant_id + self, db_session_with_containers: Session, mock_external_service_dependencies, tags, target_id, tenant_id ): """ Helper method to create test tag bindings for testing. @@ -211,15 +204,13 @@ class TestTagService: ) tag_bindings.append(tag_binding) - from extensions.ext_database import db - for tag_binding in tag_bindings: - db.session.add(tag_binding) - db.session.commit() + db_session_with_containers.add(tag_binding) + db_session_with_containers.commit() return tag_bindings - def test_get_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of tags with binding count. @@ -270,7 +261,9 @@ class TestTagService: # The ordering is handled by the database, we just verify the results are returned assert len(result) == 3 - def test_get_tags_with_keyword_filter(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_with_keyword_filter( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag retrieval with keyword filtering. @@ -291,12 +284,11 @@ class TestTagService: ) # Update tag names to make them searchable - from extensions.ext_database import db tags[0].name = "python_development" tags[1].name = "machine_learning" tags[2].name = "web_development" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test with keyword filter result = TagService.get_tags("app", tenant.id, keyword="development") @@ -314,7 +306,7 @@ class TestTagService: assert len(result_no_match) == 0 def test_get_tags_with_special_characters_in_keyword( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): r""" Test tag retrieval with special characters in keyword to verify SQL injection prevention. @@ -330,8 +322,6 @@ class TestTagService: db_session_with_containers, mock_external_service_dependencies ) - from extensions.ext_database import db - # Create tags with special characters in names tag_with_percent = Tag( name="50% discount", @@ -340,7 +330,7 @@ class TestTagService: created_by=account.id, ) tag_with_percent.id = str(uuid.uuid4()) - db.session.add(tag_with_percent) + db_session_with_containers.add(tag_with_percent) tag_with_underscore = Tag( name="test_data_tag", @@ -349,7 +339,7 @@ class TestTagService: created_by=account.id, ) tag_with_underscore.id = str(uuid.uuid4()) - db.session.add(tag_with_underscore) + db_session_with_containers.add(tag_with_underscore) tag_with_backslash = Tag( name="path\\to\\tag", @@ -358,7 +348,7 @@ class TestTagService: created_by=account.id, ) tag_with_backslash.id = str(uuid.uuid4()) - db.session.add(tag_with_backslash) + db_session_with_containers.add(tag_with_backslash) # Create tag that should NOT match tag_no_match = Tag( @@ -368,9 +358,9 @@ class TestTagService: created_by=account.id, ) tag_no_match.id = str(uuid.uuid4()) - db.session.add(tag_no_match) + db_session_with_containers.add(tag_no_match) - db.session.commit() + db_session_with_containers.commit() # Act & Assert: Test 1 - Search with % character result = TagService.get_tags("app", tenant.id, keyword="50%") @@ -392,7 +382,7 @@ class TestTagService: assert len(result) == 1 assert all("50%" in item.name for item in result) - def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_empty_result(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test tag retrieval when no tags exist. @@ -414,7 +404,9 @@ class TestTagService: assert len(result) == 0 assert isinstance(result, list) - def test_get_target_ids_by_tag_ids_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_target_ids_by_tag_ids_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of target IDs by tag IDs. @@ -469,7 +461,7 @@ class TestTagService: assert second_dataset_count == 1 def test_get_target_ids_by_tag_ids_empty_tag_ids( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test target ID retrieval with empty tag IDs list. @@ -493,7 +485,7 @@ class TestTagService: assert isinstance(result, list) def test_get_target_ids_by_tag_ids_no_matching_tags( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test target ID retrieval when no tags match the criteria. @@ -521,7 +513,7 @@ class TestTagService: assert len(result) == 0 assert isinstance(result, list) - def test_get_tag_by_tag_name_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tag_by_tag_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of tags by tag name. @@ -542,11 +534,10 @@ class TestTagService: ) # Update tag names to make them searchable - from extensions.ext_database import db tags[0].name = "python_tag" tags[1].name = "ml_tag" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag") @@ -558,7 +549,9 @@ class TestTagService: assert result[0].type == "app" assert result[0].tenant_id == tenant.id - def test_get_tag_by_tag_name_no_matches(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tag_by_tag_name_no_matches( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag retrieval by name when no matches exist. @@ -580,7 +573,9 @@ class TestTagService: assert len(result) == 0 assert isinstance(result, list) - def test_get_tag_by_tag_name_empty_parameters(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tag_by_tag_name_empty_parameters( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag retrieval by name with empty parameters. @@ -605,7 +600,9 @@ class TestTagService: assert result_empty_name is not None assert len(result_empty_name) == 0 - def test_get_tags_by_target_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_by_target_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of tags by target ID. @@ -644,7 +641,9 @@ class TestTagService: assert tag.tenant_id == tenant.id assert tag.id in [t.id for t in tags] - def test_get_tags_by_target_id_no_bindings(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_by_target_id_no_bindings( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag retrieval by target ID when no tags are bound. @@ -669,7 +668,7 @@ class TestTagService: assert len(result) == 0 assert isinstance(result, list) - def test_save_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag creation. @@ -698,17 +697,18 @@ class TestTagService: assert result.id is not None # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None # Verify tag was actually saved to database - saved_tag = db.session.query(Tag).where(Tag.id == result.id).first() + saved_tag = db_session_with_containers.query(Tag).where(Tag.id == result.id).first() assert saved_tag is not None assert saved_tag.name == "test_tag_name" - def test_save_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tags_duplicate_name_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag creation with duplicate name. @@ -731,7 +731,7 @@ class TestTagService: TagService.save_tags(tag_args) assert "Tag name already exists" in str(exc_info.value) - def test_update_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag update. @@ -763,17 +763,16 @@ class TestTagService: assert result.id == tag.id # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.name == "updated_name" # Verify tag was actually updated in database - updated_tag = db.session.query(Tag).where(Tag.id == tag.id).first() + updated_tag = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first() assert updated_tag is not None assert updated_tag.name == "updated_name" - def test_update_tags_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_tags_not_found_error(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test tag update for non-existent tag. @@ -799,7 +798,9 @@ class TestTagService: TagService.update_tags(update_args, non_existent_tag_id) assert "Tag not found" in str(exc_info.value) - def test_update_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_tags_duplicate_name_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag update with duplicate name. @@ -828,7 +829,9 @@ class TestTagService: TagService.update_tags(update_args, tag2.id) assert "Tag name already exists" in str(exc_info.value) - def test_get_tag_binding_count_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tag_binding_count_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of tag binding count. @@ -863,7 +866,7 @@ class TestTagService: assert result_tag_without_bindings == 0 def test_get_tag_binding_count_non_existent_tag( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test binding count retrieval for non-existent tag. @@ -889,7 +892,7 @@ class TestTagService: # Assert: Verify the expected outcomes assert result == 0 - def test_delete_tag_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_tag_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag deletion. @@ -916,12 +919,11 @@ class TestTagService: ) # Verify tag and binding exist before deletion - from extensions.ext_database import db - tag_before = db.session.query(Tag).where(Tag.id == tag.id).first() + tag_before = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first() assert tag_before is not None - binding_before = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first() + binding_before = db_session_with_containers.query(TagBinding).where(TagBinding.tag_id == tag.id).first() assert binding_before is not None # Act: Execute the method under test @@ -929,14 +931,14 @@ class TestTagService: # Assert: Verify the expected outcomes # Verify tag was deleted - tag_after = db.session.query(Tag).where(Tag.id == tag.id).first() + tag_after = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first() assert tag_after is None # Verify tag binding was deleted - binding_after = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first() + binding_after = db_session_with_containers.query(TagBinding).where(TagBinding.tag_id == tag.id).first() assert binding_after is None - def test_delete_tag_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_tag_not_found_error(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test tag deletion for non-existent tag. @@ -960,7 +962,7 @@ class TestTagService: TagService.delete_tag(non_existent_tag_id) assert "Tag not found" in str(exc_info.value) - def test_save_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag binding creation. @@ -988,12 +990,11 @@ class TestTagService: TagService.save_tag_binding(binding_args) # Assert: Verify the expected outcomes - from extensions.ext_database import db # Verify tag bindings were created for tag in tags: binding = ( - db.session.query(TagBinding) + db_session_with_containers.query(TagBinding) .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) .first() ) @@ -1001,7 +1002,9 @@ class TestTagService: assert binding.tenant_id == tenant.id assert binding.created_by == account.id - def test_save_tag_binding_duplicate_handling(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tag_binding_duplicate_handling( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag binding creation with duplicate bindings. @@ -1032,15 +1035,16 @@ class TestTagService: TagService.save_tag_binding(binding_args) # Assert: Verify the expected outcomes - from extensions.ext_database import db # Verify only one binding exists - bindings = db.session.scalars( + bindings = db_session_with_containers.scalars( select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) ).all() assert len(bindings) == 1 - def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tag_binding_invalid_target_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag binding creation with invalid target type. @@ -1071,7 +1075,7 @@ class TestTagService: TagService.save_tag_binding(binding_args) assert "Invalid binding type" in str(exc_info.value) - def test_delete_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag binding deletion. @@ -1098,10 +1102,11 @@ class TestTagService: ) # Verify binding exists before deletion - from extensions.ext_database import db binding_before = ( - db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first() + db_session_with_containers.query(TagBinding) + .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) + .first() ) assert binding_before is not None @@ -1112,12 +1117,14 @@ class TestTagService: # Assert: Verify the expected outcomes # Verify tag binding was deleted binding_after = ( - db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first() + db_session_with_containers.query(TagBinding) + .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) + .first() ) assert binding_after is None def test_delete_tag_binding_non_existent_binding( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tag binding deletion for non-existent binding. @@ -1145,15 +1152,14 @@ class TestTagService: # Assert: Verify the expected outcomes # No error should be raised, and database state should remain unchanged - from extensions.ext_database import db - bindings = db.session.scalars( + bindings = db_session_with_containers.scalars( select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) ).all() assert len(bindings) == 0 def test_check_target_exists_knowledge_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful target existence check for knowledge type. @@ -1179,7 +1185,7 @@ class TestTagService: # No exception should be raised for existing dataset def test_check_target_exists_knowledge_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test target existence check for non-existent knowledge dataset. @@ -1204,7 +1210,9 @@ class TestTagService: TagService.check_target_exists("knowledge", non_existent_dataset_id) assert "Dataset not found" in str(exc_info.value) - def test_check_target_exists_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_target_exists_app_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful target existence check for app type. @@ -1228,7 +1236,9 @@ class TestTagService: # Assert: Verify the expected outcomes # No exception should be raised for existing app - def test_check_target_exists_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_target_exists_app_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test target existence check for non-existent app. @@ -1252,7 +1262,9 @@ class TestTagService: TagService.check_target_exists("app", non_existent_app_id) assert "App not found" in str(exc_info.value) - def test_check_target_exists_invalid_type(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_target_exists_invalid_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test target existence check for invalid type. diff --git a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py index 5315960d73..912aa3dd2f 100644 --- a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py @@ -2,11 +2,11 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.plugin.entities.plugin_daemon import CredentialType from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity -from extensions.ext_database import db from models.provider_ids import TriggerProviderID from models.trigger import TriggerSubscription from services.trigger.trigger_provider_service import TriggerProviderService @@ -47,7 +47,7 @@ class TestTriggerProviderService: "account_feature_service": mock_account_feature_service, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -84,7 +84,7 @@ class TestTriggerProviderService: def _create_test_subscription( self, - db_session_with_containers, + db_session_with_containers: Session, tenant_id, user_id, provider_id, @@ -135,14 +135,14 @@ class TestTriggerProviderService: expires_at=-1, ) - db.session.add(subscription) - db.session.commit() - db.session.refresh(subscription) + db_session_with_containers.add(subscription) + db_session_with_containers.commit() + db_session_with_containers.refresh(subscription) return subscription def test_rebuild_trigger_subscription_success_with_merged_credentials( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful rebuild with credential merging (HIDDEN_VALUE handling). @@ -217,7 +217,7 @@ class TestTriggerProviderService: assert subscribe_credentials["api_secret"] == "new-secret-value" # New value # Verify database state was updated - db.session.refresh(subscription) + db_session_with_containers.refresh(subscription) assert subscription.name == "updated_name" assert subscription.parameters == {"param1": "updated_value"} @@ -244,7 +244,7 @@ class TestTriggerProviderService: ) def test_rebuild_trigger_subscription_with_all_new_credentials( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test rebuild when all credentials are new (no HIDDEN_VALUE). @@ -304,7 +304,7 @@ class TestTriggerProviderService: assert subscribe_credentials["api_secret"] == "completely-new-secret" def test_rebuild_trigger_subscription_with_all_hidden_values( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing). @@ -363,7 +363,7 @@ class TestTriggerProviderService: assert subscribe_credentials["api_secret"] == original_credentials["api_secret"] def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original. @@ -422,7 +422,7 @@ class TestTriggerProviderService: assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE def test_rebuild_trigger_subscription_rollback_on_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that transaction is rolled back on error. @@ -470,12 +470,12 @@ class TestTriggerProviderService: ) # Verify subscription state was not changed (rolled back) - db.session.refresh(subscription) + db_session_with_containers.refresh(subscription) assert subscription.name == original_name assert subscription.parameters == original_parameters def test_rebuild_trigger_subscription_subscription_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error when subscription is not found. @@ -501,7 +501,7 @@ class TestTriggerProviderService: ) def test_rebuild_trigger_subscription_name_uniqueness_check( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that name uniqueness is checked when updating name. diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index bbbf48ede9..f1e8c152f1 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from models import Account @@ -45,7 +46,7 @@ class TestWebConversationService: "account_feature_service": mock_account_feature_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -90,7 +91,7 @@ class TestWebConversationService: return app, account - def _create_test_end_user(self, db_session_with_containers, app): + def _create_test_end_user(self, db_session_with_containers: Session, app): """ Helper method to create a test end user for testing. @@ -111,14 +112,12 @@ class TestWebConversationService: tenant_id=app.tenant_id, ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() return end_user - def _create_test_conversation(self, db_session_with_containers, app, user, fake): + def _create_test_conversation(self, db_session_with_containers: Session, app, user, fake): """ Helper method to create a test conversation for testing. @@ -152,14 +151,14 @@ class TestWebConversationService: is_deleted=False, ) - from extensions.ext_database import db - - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() return conversation - def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful pagination by last ID with basic parameters. """ @@ -194,7 +193,7 @@ class TestWebConversationService: assert result.data[1].updated_at >= result.data[2].updated_at def test_pagination_by_last_id_with_pinned_filter( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by last ID with pinned conversation filter. @@ -222,11 +221,9 @@ class TestWebConversationService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(pinned_conversation1) - db.session.add(pinned_conversation2) - db.session.commit() + db_session_with_containers.add(pinned_conversation1) + db_session_with_containers.add(pinned_conversation2) + db_session_with_containers.commit() # Test pagination with pinned filter result = WebConversationService.pagination_by_last_id( @@ -251,7 +248,7 @@ class TestWebConversationService: assert set(returned_ids) == set(expected_ids) def test_pagination_by_last_id_with_unpinned_filter( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by last ID with unpinned conversation filter. @@ -273,10 +270,8 @@ class TestWebConversationService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(pinned_conversation) - db.session.commit() + db_session_with_containers.add(pinned_conversation) + db_session_with_containers.commit() # Test pagination with unpinned filter result = WebConversationService.pagination_by_last_id( @@ -303,7 +298,7 @@ class TestWebConversationService: expected_unpinned_ids = [conv.id for conv in conversations[1:]] assert set(returned_ids) == set(expected_unpinned_ids) - def test_pin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_pin_conversation_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful pinning of a conversation. """ @@ -317,10 +312,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, account) # Verify the conversation was pinned - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -336,7 +330,9 @@ class TestWebConversationService: assert pinned_conversation.created_by_role == "account" assert pinned_conversation.created_by == account.id - def test_pin_conversation_already_pinned(self, db_session_with_containers, mock_external_service_dependencies): + def test_pin_conversation_already_pinned( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test pinning a conversation that is already pinned (should not create duplicate). """ @@ -353,9 +349,8 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, account) # Verify only one pinned conversation record exists - from extensions.ext_database import db - pinned_conversations = db.session.scalars( + pinned_conversations = db_session_with_containers.scalars( select(PinnedConversation).where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -366,7 +361,9 @@ class TestWebConversationService: assert len(pinned_conversations) == 1 - def test_pin_conversation_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pin_conversation_with_end_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test pinning a conversation with an end user. """ @@ -383,10 +380,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, end_user) # Verify the conversation was pinned - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -402,7 +398,7 @@ class TestWebConversationService: assert pinned_conversation.created_by_role == "end_user" assert pinned_conversation.created_by == end_user.id - def test_unpin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_unpin_conversation_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful unpinning of a conversation. """ @@ -416,10 +412,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, account) # Verify it was pinned - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -436,7 +431,7 @@ class TestWebConversationService: # Verify it was unpinned pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -448,7 +443,9 @@ class TestWebConversationService: assert pinned_conversation is None - def test_unpin_conversation_not_pinned(self, db_session_with_containers, mock_external_service_dependencies): + def test_unpin_conversation_not_pinned( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test unpinning a conversation that is not pinned (should not cause error). """ @@ -462,10 +459,9 @@ class TestWebConversationService: WebConversationService.unpin(app, conversation.id, account) # Verify no pinned conversation record exists - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -478,7 +474,7 @@ class TestWebConversationService: assert pinned_conversation is None def test_pagination_by_last_id_user_required_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that pagination_by_last_id raises ValueError when user is None. @@ -499,7 +495,7 @@ class TestWebConversationService: sort_by="-updated_at", ) - def test_pin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies): + def test_pin_conversation_user_none(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test that pin method returns early when user is None. """ @@ -513,10 +509,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, None) # Verify no pinned conversation was created - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -526,7 +521,9 @@ class TestWebConversationService: assert pinned_conversation is None - def test_unpin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies): + def test_unpin_conversation_user_none( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test that unpin method returns early when user is None. """ @@ -540,10 +537,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, account) # Verify it was pinned - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -560,7 +556,7 @@ class TestWebConversationService: # Verify the conversation is still pinned pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index d1c566e477..9a1595d266 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized from libs.password import hash_password @@ -45,7 +46,7 @@ class TestWebAppAuthService: "enterprise_service": mock_enterprise_service, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -68,18 +69,16 @@ class TestWebAppAuthService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -88,15 +87,17 @@ class TestWebAppAuthService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_account_with_password(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_with_password( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Helper method to create a test account with password for testing. @@ -131,18 +132,16 @@ class TestWebAppAuthService: account.password = base64.b64encode(password_hash).decode() account.password_salt = base64.b64encode(salt).decode() - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -151,15 +150,17 @@ class TestWebAppAuthService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant, password - def _create_test_app_and_site(self, db_session_with_containers, mock_external_service_dependencies, tenant): + def _create_test_app_and_site( + self, db_session_with_containers: Session, mock_external_service_dependencies, tenant + ): """ Helper method to create a test app and site for testing. @@ -188,10 +189,8 @@ class TestWebAppAuthService: enable_api=True, ) - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() # Create site site = Site( @@ -203,12 +202,12 @@ class TestWebAppAuthService: status="normal", customize_token_strategy="not_allow", ) - db.session.add(site) - db.session.commit() + db_session_with_containers.add(site) + db_session_with_containers.commit() return app, site - def test_authenticate_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful authentication with valid email and password. @@ -233,14 +232,15 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.password is not None assert result.password_salt is not None - def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_account_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test authentication with non-existent email. @@ -262,7 +262,7 @@ class TestWebAppAuthService: with pytest.raises(AccountNotFoundError): WebAppAuthService.authenticate(non_existent_email, "any_password") - def test_authenticate_account_banned(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_account_banned(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test authentication with banned account. @@ -292,10 +292,8 @@ class TestWebAppAuthService: account.password = base64.b64encode(password_hash).decode() account.password_salt = base64.b64encode(salt).decode() - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Act & Assert: Verify proper error handling with pytest.raises(AccountLoginError) as exc_info: @@ -303,7 +301,9 @@ class TestWebAppAuthService: assert "Account is banned." in str(exc_info.value) - def test_authenticate_invalid_password(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_invalid_password( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test authentication with invalid password. @@ -323,7 +323,7 @@ class TestWebAppAuthService: assert "Invalid email or password." in str(exc_info.value) def test_authenticate_account_without_password( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test authentication for account without password. @@ -345,10 +345,8 @@ class TestWebAppAuthService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Act & Assert: Verify proper error handling with pytest.raises(AccountPasswordError) as exc_info: @@ -356,7 +354,7 @@ class TestWebAppAuthService: assert "Invalid email or password." in str(exc_info.value) - def test_login_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_login_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful login and JWT token generation. @@ -388,7 +386,9 @@ class TestWebAppAuthService: assert call_args["auth_type"] == "internal" assert "exp" in call_args - def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful user retrieval through email. @@ -413,12 +413,13 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None - def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test user retrieval with non-existent email. @@ -435,7 +436,9 @@ class TestWebAppAuthService: # Assert: Verify proper handling assert result is None - def test_get_user_through_email_banned(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_banned( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test user retrieval with banned account. @@ -456,10 +459,8 @@ class TestWebAppAuthService: status=AccountStatus.BANNED, ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Act & Assert: Verify proper error handling with pytest.raises(Unauthorized) as exc_info: @@ -468,7 +469,7 @@ class TestWebAppAuthService: assert "Account is banned." in str(exc_info.value) def test_send_email_code_login_email_with_account( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test sending email code login email with account. @@ -509,7 +510,7 @@ class TestWebAppAuthService: assert "code" in mail_call_args[1] def test_send_email_code_login_email_with_email_only( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test sending email code login email with email only. @@ -549,7 +550,7 @@ class TestWebAppAuthService: assert "code" in mail_call_args[1] def test_send_email_code_login_email_no_email_provided( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test sending email code login email without providing email. @@ -566,7 +567,9 @@ class TestWebAppAuthService: assert "Email must be provided." in str(exc_info.value) - def test_get_email_code_login_data_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_email_code_login_data_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of email code login data. @@ -593,7 +596,9 @@ class TestWebAppAuthService: "mock_token", "email_code_login" ) - def test_get_email_code_login_data_no_data(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_email_code_login_data_no_data( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test email code login data retrieval when no data exists. @@ -617,7 +622,7 @@ class TestWebAppAuthService: ) def test_revoke_email_code_login_token_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful revocation of email code login token. @@ -636,7 +641,7 @@ class TestWebAppAuthService: "mock_token", "email_code_login" ) - def test_create_end_user_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_end_user_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful end user creation. @@ -668,14 +673,15 @@ class TestWebAppAuthService: assert result.external_user_id == "enterpriseuser" # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.created_at is not None assert result.updated_at is not None - def test_create_end_user_site_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_end_user_site_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test end user creation with non-existent site code. @@ -693,7 +699,9 @@ class TestWebAppAuthService: assert "Site not found." in str(exc_info.value) - def test_create_end_user_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_end_user_app_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test end user creation when app is not found. @@ -708,10 +716,8 @@ class TestWebAppAuthService: status="normal", ) - from extensions.ext_database import db - - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() site = Site( app_id="00000000-0000-0000-0000-000000000000", @@ -722,8 +728,8 @@ class TestWebAppAuthService: status="normal", customize_token_strategy="not_allow", ) - db.session.add(site) - db.session.commit() + db_session_with_containers.add(site) + db_session_with_containers.commit() # Act & Assert: Verify proper error handling with pytest.raises(NotFound) as exc_info: @@ -732,7 +738,7 @@ class TestWebAppAuthService: assert "App not found." in str(exc_info.value) def test_is_app_require_permission_check_with_access_mode_private( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test permission check requirement for private access mode. @@ -751,7 +757,7 @@ class TestWebAppAuthService: assert result is True def test_is_app_require_permission_check_with_access_mode_public( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test permission check requirement for public access mode. @@ -770,7 +776,7 @@ class TestWebAppAuthService: assert result is False def test_is_app_require_permission_check_with_app_code( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test permission check requirement using app code. @@ -796,7 +802,7 @@ class TestWebAppAuthService: ].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with("mock_app_id") def test_is_app_require_permission_check_no_parameters( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test permission check requirement with no parameters. @@ -814,7 +820,7 @@ class TestWebAppAuthService: assert "Either app_code or app_id must be provided." in str(exc_info.value) def test_get_app_auth_type_with_access_mode_public( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test app authentication type for public access mode. @@ -833,7 +839,7 @@ class TestWebAppAuthService: assert result == WebAppAuthType.PUBLIC def test_get_app_auth_type_with_access_mode_private( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test app authentication type for private access mode. @@ -851,7 +857,9 @@ class TestWebAppAuthService: # Assert: Verify correct result assert result == WebAppAuthType.INTERNAL - def test_get_app_auth_type_with_app_code(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_auth_type_with_app_code( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app authentication type using app code. @@ -878,7 +886,9 @@ class TestWebAppAuthService: "enterprise_service" ].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with(app_id="mock_app_id") - def test_get_app_auth_type_no_parameters(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_auth_type_no_parameters( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app authentication type with no parameters. diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 040fb826e1..a3440b6b67 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -5,8 +5,9 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session -from core.workflow.entities.workflow_execution import WorkflowExecutionStatus +from dify_graph.entities.workflow_execution import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun from models.enums import CreatorUserRole from services.account_service import AccountService, TenantService @@ -48,7 +49,7 @@ class TestWorkflowAppService: "account_feature_service": mock_account_feature_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -96,7 +97,7 @@ class TestWorkflowAppService: return app, account - def _create_test_tenant_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_tenant_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test tenant and account for testing. @@ -126,7 +127,7 @@ class TestWorkflowAppService: return tenant, account - def _create_test_app(self, db_session_with_containers, tenant, account): + def _create_test_app(self, db_session_with_containers: Session, tenant, account): """ Helper method to create a test app for testing. @@ -160,7 +161,7 @@ class TestWorkflowAppService: return app - def _create_test_workflow_data(self, db_session_with_containers, app, account): + def _create_test_workflow_data(self, db_session_with_containers: Session, app, account): """ Helper method to create test workflow data for testing. @@ -174,8 +175,6 @@ class TestWorkflowAppService: """ fake = Faker() - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -188,8 +187,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create workflow run workflow_run = WorkflowRun( @@ -212,8 +211,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC), finished_at=datetime.now(UTC), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() # Create workflow app log workflow_app_log = WorkflowAppLog( @@ -227,13 +226,13 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() return workflow, workflow_run, workflow_app_log def test_get_paginate_workflow_app_logs_basic_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful pagination of workflow app logs with basic parameters. @@ -268,13 +267,12 @@ class TestWorkflowAppService: assert log_entry.workflow_run_id == workflow_run.id # Verify database state - from extensions.ext_database import db - db.session.refresh(workflow_app_log) + db_session_with_containers.refresh(workflow_app_log) assert workflow_app_log.id is not None def test_get_paginate_workflow_app_logs_with_keyword_search( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with keyword search functionality. @@ -287,11 +285,10 @@ class TestWorkflowAppService: ) # Update workflow run with searchable content - from extensions.ext_database import db workflow_run.inputs = json.dumps({"search_term": "test_keyword", "input2": "other_value"}) workflow_run.outputs = json.dumps({"result": "test_keyword_found", "status": "success"}) - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test with keyword search service = WorkflowAppService() @@ -317,7 +314,7 @@ class TestWorkflowAppService: assert len(result_no_match["data"]) == 0 def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): r""" Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention. @@ -332,8 +329,6 @@ class TestWorkflowAppService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account) - from extensions.ext_database import db - service = WorkflowAppService() # Test 1: Search with % character @@ -353,8 +348,8 @@ class TestWorkflowAppService: created_by=account.id, created_at=datetime.now(UTC), ) - db.session.add(workflow_run_1) - db.session.flush() + db_session_with_containers.add(workflow_run_1) + db_session_with_containers.flush() workflow_app_log_1 = WorkflowAppLog( tenant_id=app.tenant_id, @@ -367,8 +362,8 @@ class TestWorkflowAppService: ) workflow_app_log_1.id = str(uuid.uuid4()) workflow_app_log_1.created_at = datetime.now(UTC) - db.session.add(workflow_app_log_1) - db.session.commit() + db_session_with_containers.add(workflow_app_log_1) + db_session_with_containers.commit() result = service.get_paginate_workflow_app_logs( session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 @@ -395,8 +390,8 @@ class TestWorkflowAppService: created_by=account.id, created_at=datetime.now(UTC), ) - db.session.add(workflow_run_2) - db.session.flush() + db_session_with_containers.add(workflow_run_2) + db_session_with_containers.flush() workflow_app_log_2 = WorkflowAppLog( tenant_id=app.tenant_id, @@ -409,8 +404,8 @@ class TestWorkflowAppService: ) workflow_app_log_2.id = str(uuid.uuid4()) workflow_app_log_2.created_at = datetime.now(UTC) - db.session.add(workflow_app_log_2) - db.session.commit() + db_session_with_containers.add(workflow_app_log_2) + db_session_with_containers.commit() result = service.get_paginate_workflow_app_logs( session=db_session_with_containers, app_model=app, keyword="test_data", page=1, limit=20 @@ -437,8 +432,8 @@ class TestWorkflowAppService: created_by=account.id, created_at=datetime.now(UTC), ) - db.session.add(workflow_run_4) - db.session.flush() + db_session_with_containers.add(workflow_run_4) + db_session_with_containers.flush() workflow_app_log_4 = WorkflowAppLog( tenant_id=app.tenant_id, @@ -451,8 +446,8 @@ class TestWorkflowAppService: ) workflow_app_log_4.id = str(uuid.uuid4()) workflow_app_log_4.created_at = datetime.now(UTC) - db.session.add(workflow_app_log_4) - db.session.commit() + db_session_with_containers.add(workflow_app_log_4) + db_session_with_containers.commit() result = service.get_paginate_workflow_app_logs( session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 @@ -467,7 +462,7 @@ class TestWorkflowAppService: assert workflow_run_4.id not in found_run_ids def test_get_paginate_workflow_app_logs_with_status_filter( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with status filtering. @@ -476,8 +471,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -490,8 +483,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create workflow runs with different statuses statuses = ["succeeded", "failed", "running", "stopped"] @@ -519,8 +512,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status != "running" else None, ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() workflow_app_log = WorkflowAppLog( tenant_id=app.tenant_id, @@ -533,8 +526,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() workflow_runs.append(workflow_run) workflow_app_logs.append(workflow_app_log) @@ -568,7 +561,7 @@ class TestWorkflowAppService: assert result_running["data"][0].workflow_run.status == "running" def test_get_paginate_workflow_app_logs_with_time_filtering( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with time-based filtering. @@ -577,8 +570,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -591,8 +582,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create workflow runs with different timestamps base_time = datetime.now(UTC) @@ -627,8 +618,8 @@ class TestWorkflowAppService: created_at=timestamp, finished_at=timestamp + timedelta(minutes=1), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() workflow_app_log = WorkflowAppLog( tenant_id=app.tenant_id, @@ -641,8 +632,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = timestamp - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() workflow_runs.append(workflow_run) workflow_app_logs.append(workflow_app_log) @@ -682,7 +673,7 @@ class TestWorkflowAppService: assert result_range["total"] == 2 # Should get logs from 2 hours ago and 1 hour ago def test_get_paginate_workflow_app_logs_with_pagination( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with different page sizes and limits. @@ -691,8 +682,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -705,8 +694,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create 25 workflow runs and logs total_logs = 25 @@ -734,8 +723,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() workflow_app_log = WorkflowAppLog( tenant_id=app.tenant_id, @@ -748,8 +737,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() workflow_runs.append(workflow_run) workflow_app_logs.append(workflow_app_log) @@ -798,7 +787,7 @@ class TestWorkflowAppService: assert len(result_large_limit["data"]) == total_logs def test_get_paginate_workflow_app_logs_with_user_role_filtering( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with user role and session filtering. @@ -807,8 +796,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -821,8 +808,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create end user end_user = EndUser( @@ -835,8 +822,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC), updated_at=datetime.now(UTC), ) - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() # Create workflow runs and logs for both account and end user workflow_runs = [] @@ -864,8 +851,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() workflow_app_log = WorkflowAppLog( tenant_id=app.tenant_id, @@ -878,8 +865,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() workflow_runs.append(workflow_run) workflow_app_logs.append(workflow_app_log) @@ -906,8 +893,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC) + timedelta(minutes=i + 10), finished_at=datetime.now(UTC) + timedelta(minutes=i + 11), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() workflow_app_log = WorkflowAppLog( tenant_id=app.tenant_id, @@ -920,8 +907,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i + 10) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() workflow_runs.append(workflow_run) workflow_app_logs.append(workflow_app_log) @@ -994,7 +981,7 @@ class TestWorkflowAppService: assert "Account not found" in str(exc_info.value) def test_get_paginate_workflow_app_logs_with_uuid_keyword_search( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with UUID keyword search functionality. @@ -1003,8 +990,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -1017,8 +1002,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create workflow run with specific UUID workflow_run_id = str(uuid.uuid4()) @@ -1042,8 +1027,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC), finished_at=datetime.now(UTC) + timedelta(minutes=1), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() # Create workflow app log workflow_app_log = WorkflowAppLog( @@ -1057,8 +1042,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() # Act & Assert: Test UUID keyword search service = WorkflowAppService() @@ -1085,7 +1070,7 @@ class TestWorkflowAppService: assert result_invalid_uuid["total"] == 0 def test_get_paginate_workflow_app_logs_with_edge_cases( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with edge cases and boundary conditions. @@ -1094,8 +1079,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -1108,8 +1091,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create workflow run with edge case data workflow_run = WorkflowRun( @@ -1132,8 +1115,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC), finished_at=datetime.now(UTC), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() # Create workflow app log workflow_app_log = WorkflowAppLog( @@ -1147,8 +1130,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() # Act & Assert: Test edge cases service = WorkflowAppService() @@ -1185,7 +1168,7 @@ class TestWorkflowAppService: assert result_high_page["has_more"] is False def test_get_paginate_workflow_app_logs_with_empty_results( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with empty results and no data scenarios. @@ -1252,7 +1235,7 @@ class TestWorkflowAppService: assert "Account not found" in str(exc_info.value) def test_get_paginate_workflow_app_logs_with_complex_query_combinations( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with complex query combinations. @@ -1352,7 +1335,7 @@ class TestWorkflowAppService: assert len(result_time_status_limit["data"]) <= 2 def test_get_paginate_workflow_app_logs_with_large_dataset_performance( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with large dataset for performance validation. @@ -1444,7 +1427,7 @@ class TestWorkflowAppService: assert result_last_page["page"] == 3 def test_get_paginate_workflow_app_logs_with_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with proper tenant isolation. diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 1f91b40963..ab409deb89 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -1,8 +1,9 @@ import pytest from faker import Faker +from sqlalchemy.orm import Session -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.variables.segments import StringSegment +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from dify_graph.variables.segments import StringSegment from models import App, Workflow from models.enums import DraftVariableType from models.workflow import WorkflowDraftVariable @@ -44,7 +45,7 @@ class TestWorkflowDraftVariableService: # WorkflowDraftVariableService doesn't have external dependencies that need mocking return {} - def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, fake=None): + def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, fake=None): """ Helper method to create a test app with realistic data for testing. @@ -75,13 +76,11 @@ class TestWorkflowDraftVariableService: app.created_by = fake.uuid4() app.updated_by = app.created_by - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app - def _create_test_workflow(self, db_session_with_containers, app, fake=None): + def _create_test_workflow(self, db_session_with_containers: Session, app, fake=None): """ Helper method to create a test workflow associated with an app. @@ -110,15 +109,14 @@ class TestWorkflowDraftVariableService: conversation_variables=[], rag_pipeline_variables=[], ) - from extensions.ext_database import db - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() return workflow def _create_test_variable( self, - db_session_with_containers, + db_session_with_containers: Session, app_id, node_id, name, @@ -174,13 +172,12 @@ class TestWorkflowDraftVariableService: visible=True, editable=True, ) - from extensions.ext_database import db - db.session.add(variable) - db.session.commit() + db_session_with_containers.add(variable) + db_session_with_containers.commit() return variable - def test_get_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting a single variable by ID successfully. @@ -202,7 +199,7 @@ class TestWorkflowDraftVariableService: assert retrieved_variable.app_id == app.id assert retrieved_variable.get_value().value == test_value.value - def test_get_variable_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_variable_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting a variable that doesn't exist. @@ -217,7 +214,7 @@ class TestWorkflowDraftVariableService: assert retrieved_variable is None def test_get_draft_variables_by_selectors_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting variables by selectors successfully. @@ -268,7 +265,7 @@ class TestWorkflowDraftVariableService: assert var.get_value().value == var3_value.value def test_list_variables_without_values_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test listing variables without values successfully with pagination. @@ -300,7 +297,7 @@ class TestWorkflowDraftVariableService: assert var.name is not None assert var.app_id == app.id - def test_list_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_list_node_variables_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test listing variables for a specific node successfully. @@ -352,7 +349,9 @@ class TestWorkflowDraftVariableService: assert "var2" in var_names assert "var3" not in var_names - def test_list_conversation_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_list_conversation_variables_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test listing conversation variables successfully. @@ -393,7 +392,7 @@ class TestWorkflowDraftVariableService: assert "conv_var2" in var_names assert "sys_var" not in var_names - def test_update_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test updating a variable's name and value successfully. @@ -418,14 +417,15 @@ class TestWorkflowDraftVariableService: assert updated_variable.name == "new_name" assert updated_variable.get_value().value == new_value.value assert updated_variable.last_edited_at is not None - from extensions.ext_database import db - db.session.refresh(variable) + db_session_with_containers.refresh(variable) assert variable.name == "new_name" assert variable.get_value().value == new_value.value assert variable.last_edited_at is not None - def test_update_variable_not_editable(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_variable_not_editable( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test that updating a non-editable variable raises an exception. @@ -445,17 +445,18 @@ class TestWorkflowDraftVariableService: node_execution_id=fake.uuid4(), editable=False, # Set as non-editable ) - from extensions.ext_database import db - db.session.add(variable) - db.session.commit() + db_session_with_containers.add(variable) + db_session_with_containers.commit() service = WorkflowDraftVariableService(db_session_with_containers) with pytest.raises(UpdateNotSupportedError) as exc_info: service.update_variable(variable, name="new_name", value=new_value) assert "variable not support updating" in str(exc_info.value) assert variable.id in str(exc_info.value) - def test_reset_conversation_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_reset_conversation_variable_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test resetting conversation variable successfully. @@ -467,7 +468,7 @@ class TestWorkflowDraftVariableService: fake = Faker() app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) - from core.workflow.variables.variables import StringVariable + from dify_graph.variables.variables import StringVariable conv_var = StringVariable( id=fake.uuid4(), @@ -476,9 +477,8 @@ class TestWorkflowDraftVariableService: selector=[CONVERSATION_VARIABLE_NODE_ID, "test_conv_var"], ) workflow.conversation_variables = [conv_var] - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() modified_value = StringSegment(value=fake.word()) variable = self._create_test_variable( db_session_with_containers, @@ -489,17 +489,17 @@ class TestWorkflowDraftVariableService: fake=fake, ) variable.last_edited_at = fake.date_time() - db.session.commit() + db_session_with_containers.commit() service = WorkflowDraftVariableService(db_session_with_containers) reset_variable = service.reset_variable(workflow, variable) assert reset_variable is not None assert reset_variable.get_value().value == "default_value" assert reset_variable.last_edited_at is None - db.session.refresh(variable) + db_session_with_containers.refresh(variable) assert variable.get_value().value == "default_value" assert variable.last_edited_at is None - def test_delete_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test deleting a single variable successfully. @@ -513,14 +513,15 @@ class TestWorkflowDraftVariableService: variable = self._create_test_variable( db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake ) - from extensions.ext_database import db - assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None + assert db_session_with_containers.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None service = WorkflowDraftVariableService(db_session_with_containers) service.delete_variable(variable) - assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None + assert db_session_with_containers.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None - def test_delete_workflow_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_workflow_variables_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test deleting all variables for a workflow successfully. @@ -550,20 +551,25 @@ class TestWorkflowDraftVariableService: other_value, fake=fake, ) - from extensions.ext_database import db - app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() - other_app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + app_variables = db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() + other_app_variables = ( + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + ) assert len(app_variables) == 3 assert len(other_app_variables) == 1 service = WorkflowDraftVariableService(db_session_with_containers) service.delete_workflow_variables(app.id) - app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() - other_app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + app_variables_after = db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() + other_app_variables_after = ( + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + ) assert len(app_variables_after) == 0 assert len(other_app_variables_after) == 1 - def test_delete_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_node_variables_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test deleting all variables for a specific node successfully. @@ -605,14 +611,15 @@ class TestWorkflowDraftVariableService: conv_value, fake=fake, ) - from extensions.ext_database import db - target_node_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() + target_node_variables = ( + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() + ) other_node_variables = ( - db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() ) conv_variables = ( - db.session.query(WorkflowDraftVariable) + db_session_with_containers.query(WorkflowDraftVariable) .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) .all() ) @@ -622,13 +629,13 @@ class TestWorkflowDraftVariableService: service = WorkflowDraftVariableService(db_session_with_containers) service.delete_node_variables(app.id, node_id) target_node_variables_after = ( - db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() ) other_node_variables_after = ( - db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() ) conv_variables_after = ( - db.session.query(WorkflowDraftVariable) + db_session_with_containers.query(WorkflowDraftVariable) .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) .all() ) @@ -637,7 +644,7 @@ class TestWorkflowDraftVariableService: assert len(conv_variables_after) == 1 def test_prefill_conversation_variable_default_values_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test prefill conversation variable default values successfully. @@ -650,7 +657,7 @@ class TestWorkflowDraftVariableService: fake = Faker() app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) - from core.workflow.variables.variables import StringVariable + from dify_graph.variables.variables import StringVariable conv_var1 = StringVariable( id=fake.uuid4(), @@ -665,13 +672,12 @@ class TestWorkflowDraftVariableService: selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var2"], ) workflow.conversation_variables = [conv_var1, conv_var2] - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() service = WorkflowDraftVariableService(db_session_with_containers) service.prefill_conversation_variable_default_values(workflow) draft_variables = ( - db.session.query(WorkflowDraftVariable) + db_session_with_containers.query(WorkflowDraftVariable) .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) .all() ) @@ -686,7 +692,7 @@ class TestWorkflowDraftVariableService: assert var.get_variable_type() == DraftVariableType.CONVERSATION def test_get_conversation_id_from_draft_variable_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting conversation ID from draft variable successfully. @@ -713,7 +719,7 @@ class TestWorkflowDraftVariableService: assert retrieved_conv_id == conversation_id def test_get_conversation_id_from_draft_variable_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting conversation ID when it doesn't exist. @@ -728,7 +734,9 @@ class TestWorkflowDraftVariableService: retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id) assert retrieved_conv_id is None - def test_list_system_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_list_system_variables_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test listing system variables successfully. @@ -775,7 +783,9 @@ class TestWorkflowDraftVariableService: assert "sys_var2" in var_names assert "conv_var" not in var_names - def test_get_variable_by_name_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_variable_by_name_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting variables by name successfully for different types. @@ -822,7 +832,9 @@ class TestWorkflowDraftVariableService: assert retrieved_node_var.name == "test_node_var" assert retrieved_node_var.node_id == "test_node" - def test_get_variable_by_name_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_variable_by_name_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting variables by name when they don't exist. diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index 3a88081db3..38ef3975b7 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -5,6 +5,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models.enums import CreatorUserRole from models.model import ( @@ -48,7 +49,7 @@ class TestWorkflowRunService: "account_feature_service": mock_account_feature_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -94,7 +95,7 @@ class TestWorkflowRunService: return app, account def _create_test_workflow_run( - self, db_session_with_containers, app, account, triggered_from="debugging", offset_minutes=0 + self, db_session_with_containers: Session, app, account, triggered_from="debugging", offset_minutes=0 ): """ Helper method to create a test workflow run for testing. @@ -110,8 +111,6 @@ class TestWorkflowRunService: """ fake = Faker() - from extensions.ext_database import db - # Create workflow run with offset timestamp base_time = datetime.now(UTC) created_time = base_time - timedelta(minutes=offset_minutes) @@ -136,12 +135,12 @@ class TestWorkflowRunService: finished_at=created_time, ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() return workflow_run - def _create_test_message(self, db_session_with_containers, app, account, workflow_run): + def _create_test_message(self, db_session_with_containers: Session, app, account, workflow_run): """ Helper method to create a test message for testing. @@ -156,8 +155,6 @@ class TestWorkflowRunService: """ fake = Faker() - from extensions.ext_database import db - # Create conversation first (required for message) from models.model import Conversation @@ -170,8 +167,8 @@ class TestWorkflowRunService: from_source=CreatorUserRole.ACCOUNT, from_account_id=account.id, ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create message message = Message() @@ -193,12 +190,14 @@ class TestWorkflowRunService: message.workflow_run_id = workflow_run.id message.inputs = {"input": "test input"} - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return message - def test_get_paginate_workflow_runs_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_workflow_runs_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful pagination of workflow runs with debugging trigger. @@ -239,7 +238,7 @@ class TestWorkflowRunService: assert workflow_run.tenant_id == app.tenant_id def test_get_paginate_workflow_runs_with_last_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination of workflow runs with last_id parameter. @@ -282,7 +281,7 @@ class TestWorkflowRunService: assert workflow_run.tenant_id == app.tenant_id def test_get_paginate_workflow_runs_default_limit( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination of workflow runs with default limit. @@ -320,7 +319,7 @@ class TestWorkflowRunService: assert workflow_run_result.tenant_id == app.tenant_id def test_get_paginate_advanced_chat_workflow_runs_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful pagination of advanced chat workflow runs with message information. @@ -365,7 +364,7 @@ class TestWorkflowRunService: assert workflow_run.app_id == app.id assert workflow_run.tenant_id == app.tenant_id - def test_get_workflow_run_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_workflow_run_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of workflow run by ID. @@ -395,7 +394,7 @@ class TestWorkflowRunService: assert result.type == "chat" assert result.version == "1.0.0" - def test_get_workflow_run_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_workflow_run_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test workflow run retrieval when run ID does not exist. @@ -419,7 +418,7 @@ class TestWorkflowRunService: assert result is None def test_get_workflow_run_node_executions_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of workflow run node executions. @@ -438,7 +437,6 @@ class TestWorkflowRunService: workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging") # Create node executions - from extensions.ext_database import db from models.workflow import WorkflowNodeExecutionModel node_executions = [] @@ -462,7 +460,7 @@ class TestWorkflowRunService: created_by=account.id, created_at=datetime.now(UTC), ) - db.session.add(node_execution) + db_session_with_containers.add(node_execution) node_executions.append(node_execution) paused_node_execution = WorkflowNodeExecutionModel( @@ -484,9 +482,9 @@ class TestWorkflowRunService: created_by=account.id, created_at=datetime.now(UTC), ) - db.session.add(paused_node_execution) + db_session_with_containers.add(paused_node_execution) - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test workflow_run_service = WorkflowRunService() @@ -509,7 +507,7 @@ class TestWorkflowRunService: assert node_execution.node_id.startswith("node_") def test_get_workflow_run_node_executions_empty( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting node executions for a workflow run with no executions. @@ -560,7 +558,7 @@ class TestWorkflowRunService: assert len(result) == 0 def test_get_workflow_run_node_executions_invalid_workflow_run_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting node executions with invalid workflow run ID. @@ -611,7 +609,7 @@ class TestWorkflowRunService: assert len(result) == 0 def test_get_workflow_run_node_executions_database_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting node executions when database encounters an error. @@ -662,7 +660,7 @@ class TestWorkflowRunService: ) def test_get_workflow_run_node_executions_end_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test node execution retrieval for end user. @@ -680,7 +678,6 @@ class TestWorkflowRunService: workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging") # Create end user - from extensions.ext_database import db from models.model import EndUser end_user = EndUser( @@ -692,8 +689,8 @@ class TestWorkflowRunService: external_user_id=str(uuid.uuid4()), name=fake.name(), ) - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() # Create node execution from models.workflow import WorkflowNodeExecutionModel @@ -717,8 +714,8 @@ class TestWorkflowRunService: created_by=end_user.id, created_at=datetime.now(UTC), ) - db.session.add(node_execution) - db.session.commit() + db_session_with_containers.add(node_execution) + db_session_with_containers.commit() # Act: Execute the method under test workflow_run_service = WorkflowRunService() diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index c29cda9a73..bfb23bac68 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -10,6 +10,7 @@ from unittest.mock import MagicMock import pytest from faker import Faker +from sqlalchemy.orm import Session from models import Account, App, Workflow from models.model import AppMode @@ -32,7 +33,7 @@ class TestWorkflowService: and realistic testing environment with actual database interactions. """ - def _create_test_account(self, db_session_with_containers, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake=None): """ Helper method to create a test account with realistic data. @@ -67,18 +68,16 @@ class TestWorkflowService: tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at - from extensions.ext_database import db - - db.session.add(tenant) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.add(account) + db_session_with_containers.commit() # Set the current tenant for the account account.current_tenant = tenant return account - def _create_test_app(self, db_session_with_containers, fake=None): + def _create_test_app(self, db_session_with_containers: Session, fake=None): """ Helper method to create a test app with realistic data. @@ -106,13 +105,11 @@ class TestWorkflowService: ) app.updated_by = app.created_by - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app - def _create_test_workflow(self, db_session_with_containers, app, account, fake=None): + def _create_test_workflow(self, db_session_with_containers: Session, app, account, fake=None): """ Helper method to create a test workflow associated with an app. @@ -141,13 +138,11 @@ class TestWorkflowService: conversation_variables=[], ) - from extensions.ext_database import db - - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() return workflow - def test_get_node_last_run_success(self, db_session_with_containers): + def test_get_node_last_run_success(self, db_session_with_containers: Session): """ Test successful retrieval of the most recent execution for a specific node. @@ -180,10 +175,8 @@ class TestWorkflowService: node_execution.created_by = account.id # Required field node_execution.created_at = fake.date_time_this_year() - from extensions.ext_database import db - - db.session.add(node_execution) - db.session.commit() + db_session_with_containers.add(node_execution) + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -196,7 +189,7 @@ class TestWorkflowService: assert result.workflow_id == workflow.id assert result.status == "succeeded" - def test_get_node_last_run_not_found(self, db_session_with_containers): + def test_get_node_last_run_not_found(self, db_session_with_containers: Session): """ Test retrieval when no execution record exists for the specified node. @@ -217,7 +210,7 @@ class TestWorkflowService: # Assert assert result is None - def test_is_workflow_exist_true(self, db_session_with_containers): + def test_is_workflow_exist_true(self, db_session_with_containers: Session): """ Test workflow existence check when a draft workflow exists. @@ -238,7 +231,7 @@ class TestWorkflowService: # Assert assert result is True - def test_is_workflow_exist_false(self, db_session_with_containers): + def test_is_workflow_exist_false(self, db_session_with_containers: Session): """ Test workflow existence check when no draft workflow exists. @@ -258,7 +251,7 @@ class TestWorkflowService: # Assert assert result is False - def test_get_draft_workflow_success(self, db_session_with_containers): + def test_get_draft_workflow_success(self, db_session_with_containers: Session): """ Test successful retrieval of a draft workflow. @@ -284,7 +277,7 @@ class TestWorkflowService: assert result.app_id == app.id assert result.tenant_id == app.tenant_id - def test_get_draft_workflow_not_found(self, db_session_with_containers): + def test_get_draft_workflow_not_found(self, db_session_with_containers: Session): """ Test draft workflow retrieval when no draft workflow exists. @@ -304,7 +297,7 @@ class TestWorkflowService: # Assert assert result is None - def test_get_published_workflow_by_id_success(self, db_session_with_containers): + def test_get_published_workflow_by_id_success(self, db_session_with_containers: Session): """ Test successful retrieval of a published workflow by ID. @@ -321,9 +314,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = "2024.01.01.001" # Published version - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -336,7 +327,7 @@ class TestWorkflowService: assert result.version != Workflow.VERSION_DRAFT assert result.app_id == app.id - def test_get_published_workflow_by_id_draft_error(self, db_session_with_containers): + def test_get_published_workflow_by_id_draft_error(self, db_session_with_containers: Session): """ Test error when trying to retrieve a draft workflow as published. @@ -359,7 +350,7 @@ class TestWorkflowService: with pytest.raises(IsDraftWorkflowError): workflow_service.get_published_workflow_by_id(app, workflow.id) - def test_get_published_workflow_by_id_not_found(self, db_session_with_containers): + def test_get_published_workflow_by_id_not_found(self, db_session_with_containers: Session): """ Test retrieval when no workflow exists with the specified ID. @@ -379,7 +370,7 @@ class TestWorkflowService: # Assert assert result is None - def test_get_published_workflow_success(self, db_session_with_containers): + def test_get_published_workflow_success(self, db_session_with_containers: Session): """ Test successful retrieval of the current published workflow for an app. @@ -395,10 +386,8 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = "2024.01.01.001" # Published version - from extensions.ext_database import db - app.workflow_id = workflow.id - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -411,7 +400,7 @@ class TestWorkflowService: assert result.version != Workflow.VERSION_DRAFT assert result.app_id == app.id - def test_get_published_workflow_no_workflow_id(self, db_session_with_containers): + def test_get_published_workflow_no_workflow_id(self, db_session_with_containers: Session): """ Test retrieval when app has no associated workflow ID. @@ -431,7 +420,7 @@ class TestWorkflowService: # Assert assert result is None - def test_get_all_published_workflow_pagination(self, db_session_with_containers): + def test_get_all_published_workflow_pagination(self, db_session_with_containers: Session): """ Test pagination of published workflows. @@ -455,15 +444,13 @@ class TestWorkflowService: # Set the app's workflow_id to the first workflow app.workflow_id = workflows[0].id - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() # Act - First page result_workflows, has_more = workflow_service.get_all_published_workflow( - session=db.session, + session=db_session_with_containers, app_model=app, page=1, limit=3, @@ -476,7 +463,7 @@ class TestWorkflowService: # Act - Second page result_workflows, has_more = workflow_service.get_all_published_workflow( - session=db.session, + session=db_session_with_containers, app_model=app, page=2, limit=3, @@ -487,7 +474,7 @@ class TestWorkflowService: assert len(result_workflows) == 2 assert has_more is False - def test_get_all_published_workflow_user_filter(self, db_session_with_containers): + def test_get_all_published_workflow_user_filter(self, db_session_with_containers: Session): """ Test filtering published workflows by user. @@ -513,22 +500,20 @@ class TestWorkflowService: # Set the app's workflow_id to the first workflow app.workflow_id = workflow1.id - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() # Act - Filter by account1 result_workflows, has_more = workflow_service.get_all_published_workflow( - session=db.session, app_model=app, page=1, limit=10, user_id=account1.id + session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=account1.id ) # Assert assert len(result_workflows) == 1 assert result_workflows[0].created_by == account1.id - def test_get_all_published_workflow_named_only_filter(self, db_session_with_containers): + def test_get_all_published_workflow_named_only_filter(self, db_session_with_containers: Session): """ Test filtering published workflows to show only named workflows. @@ -557,22 +542,20 @@ class TestWorkflowService: # Set the app's workflow_id to the first workflow app.workflow_id = workflow1.id - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() # Act - Filter named only result_workflows, has_more = workflow_service.get_all_published_workflow( - session=db.session, app_model=app, page=1, limit=10, user_id=None, named_only=True + session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=None, named_only=True ) # Assert assert len(result_workflows) == 2 assert all(wf.marked_name for wf in result_workflows) - def test_sync_draft_workflow_create_new(self, db_session_with_containers): + def test_sync_draft_workflow_create_new(self, db_session_with_containers: Session): """ Test creating a new draft workflow through sync operation. @@ -624,7 +607,7 @@ class TestWorkflowService: assert result.features == json.dumps(features) assert result.created_by == account.id - def test_sync_draft_workflow_update_existing(self, db_session_with_containers): + def test_sync_draft_workflow_update_existing(self, db_session_with_containers: Session): """ Test updating an existing draft workflow through sync operation. @@ -688,7 +671,7 @@ class TestWorkflowService: assert result.features == json.dumps(new_features) assert result.updated_by == account.id - def test_sync_draft_workflow_hash_mismatch_error(self, db_session_with_containers): + def test_sync_draft_workflow_hash_mismatch_error(self, db_session_with_containers: Session): """ Test error when sync is attempted with mismatched hash. @@ -738,7 +721,7 @@ class TestWorkflowService: conversation_variables=conversation_variables, ) - def test_publish_workflow_success(self, db_session_with_containers): + def test_publish_workflow_success(self, db_session_with_containers: Session): """ Test successful workflow publishing. @@ -755,9 +738,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = Workflow.VERSION_DRAFT - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -777,7 +758,7 @@ class TestWorkflowService: assert len(result.version) > 10 # Should be a reasonable timestamp length assert result.created_by == account.id - def test_publish_workflow_no_draft_error(self, db_session_with_containers): + def test_publish_workflow_no_draft_error(self, db_session_with_containers: Session): """ Test error when publishing workflow without draft. @@ -797,7 +778,7 @@ class TestWorkflowService: with pytest.raises(ValueError, match="No valid workflow found"): workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account) - def test_publish_workflow_already_published_error(self, db_session_with_containers): + def test_publish_workflow_already_published_error(self, db_session_with_containers: Session): """ Test error when publishing already published workflow. @@ -813,9 +794,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = "2024.01.01.001" # Already published - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -823,7 +802,7 @@ class TestWorkflowService: with pytest.raises(ValueError, match="No valid workflow found"): workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account) - def test_get_default_block_configs(self, db_session_with_containers): + def test_get_default_block_configs(self, db_session_with_containers: Session): """ Test retrieval of default block configurations for all node types. @@ -847,7 +826,7 @@ class TestWorkflowService: assert isinstance(config, dict) # The structure can vary, so we just check it's a dict - def test_get_default_block_config_specific_type(self, db_session_with_containers): + def test_get_default_block_config_specific_type(self, db_session_with_containers: Session): """ Test retrieval of default block configuration for a specific node type. @@ -867,7 +846,7 @@ class TestWorkflowService: # This is acceptable behavior assert result is None or isinstance(result, dict) - def test_get_default_block_config_invalid_type(self, db_session_with_containers): + def test_get_default_block_config_invalid_type(self, db_session_with_containers: Session): """ Test retrieval of default block configuration for invalid node type. @@ -887,7 +866,7 @@ class TestWorkflowService: # It's also acceptable for the service to raise a ValueError for invalid types pass - def test_get_default_block_config_with_filters(self, db_session_with_containers): + def test_get_default_block_config_with_filters(self, db_session_with_containers: Session): """ Test retrieval of default block configuration with filters. @@ -907,7 +886,7 @@ class TestWorkflowService: # Result might be None if filters don't match, but should not raise error assert result is None or isinstance(result, dict) - def test_convert_to_workflow_chat_mode_success(self, db_session_with_containers): + def test_convert_to_workflow_chat_mode_success(self, db_session_with_containers: Session): """ Test successful conversion from chat mode app to workflow mode. @@ -944,11 +923,9 @@ class TestWorkflowService: ) app_model_config.id = fake.uuid4() - from extensions.ext_database import db - - db.session.add(app_model_config) + db_session_with_containers.add(app_model_config) app.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() conversion_args = { @@ -969,7 +946,7 @@ class TestWorkflowService: assert result.icon_type == conversion_args["icon_type"] assert result.icon_background == conversion_args["icon_background"] - def test_convert_to_workflow_completion_mode_success(self, db_session_with_containers): + def test_convert_to_workflow_completion_mode_success(self, db_session_with_containers: Session): """ Test successful conversion from completion mode app to workflow mode. @@ -1006,11 +983,9 @@ class TestWorkflowService: ) app_model_config.id = fake.uuid4() - from extensions.ext_database import db - - db.session.add(app_model_config) + db_session_with_containers.add(app_model_config) app.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() conversion_args = { @@ -1031,7 +1006,7 @@ class TestWorkflowService: assert result.icon_type == conversion_args["icon_type"] assert result.icon_background == conversion_args["icon_background"] - def test_convert_to_workflow_unsupported_mode_error(self, db_session_with_containers): + def test_convert_to_workflow_unsupported_mode_error(self, db_session_with_containers: Session): """ Test error when attempting to convert unsupported app mode. @@ -1046,9 +1021,7 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) app.mode = AppMode.WORKFLOW - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() conversion_args = {"name": "Test"} @@ -1057,7 +1030,7 @@ class TestWorkflowService: with pytest.raises(ValueError, match="Current App mode: workflow is not supported convert to workflow"): workflow_service.convert_to_workflow(app_model=app, account=account, args=conversion_args) - def test_validate_features_structure_advanced_chat(self, db_session_with_containers): + def test_validate_features_structure_advanced_chat(self, db_session_with_containers: Session): """ Test feature structure validation for advanced chat mode apps. @@ -1069,9 +1042,7 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) app.mode = AppMode.ADVANCED_CHAT - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() features = { @@ -1088,7 +1059,7 @@ class TestWorkflowService: # The exact behavior depends on the AdvancedChatAppConfigManager implementation assert result is not None or isinstance(result, dict) - def test_validate_features_structure_workflow(self, db_session_with_containers): + def test_validate_features_structure_workflow(self, db_session_with_containers: Session): """ Test feature structure validation for workflow mode apps. @@ -1100,9 +1071,7 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) app.mode = AppMode.WORKFLOW - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() features = {"workflow_config": {"max_steps": 10, "timeout": 300}} @@ -1115,7 +1084,7 @@ class TestWorkflowService: # The exact behavior depends on the WorkflowAppConfigManager implementation assert result is not None or isinstance(result, dict) - def test_validate_features_structure_invalid_mode(self, db_session_with_containers): + def test_validate_features_structure_invalid_mode(self, db_session_with_containers: Session): """ Test error when validating features for invalid app mode. @@ -1127,9 +1096,7 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) app.mode = "invalid_mode" # Invalid mode - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() features = {"test": "value"} @@ -1138,7 +1105,7 @@ class TestWorkflowService: with pytest.raises(ValueError, match="Invalid app mode: invalid_mode"): workflow_service.validate_features_structure(app_model=app, features=features) - def test_update_workflow_success(self, db_session_with_containers): + def test_update_workflow_success(self, db_session_with_containers: Session): """ Test successful workflow update with allowed fields. @@ -1152,16 +1119,14 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() update_data = {"marked_name": "Updated Workflow Name", "marked_comment": "Updated workflow comment"} # Act result = workflow_service.update_workflow( - session=db.session, + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id, account_id=account.id, @@ -1174,7 +1139,7 @@ class TestWorkflowService: assert result.marked_comment == update_data["marked_comment"] assert result.updated_by == account.id - def test_update_workflow_not_found(self, db_session_with_containers): + def test_update_workflow_not_found(self, db_session_with_containers: Session): """ Test workflow update when workflow doesn't exist. @@ -1186,15 +1151,13 @@ class TestWorkflowService: account = self._create_test_account(db_session_with_containers, fake) app = self._create_test_app(db_session_with_containers, fake) - from extensions.ext_database import db - workflow_service = WorkflowService() non_existent_workflow_id = fake.uuid4() update_data = {"marked_name": "Test"} # Act result = workflow_service.update_workflow( - session=db.session, + session=db_session_with_containers, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id, account_id=account.id, @@ -1204,7 +1167,7 @@ class TestWorkflowService: # Assert assert result is None - def test_update_workflow_ignores_disallowed_fields(self, db_session_with_containers): + def test_update_workflow_ignores_disallowed_fields(self, db_session_with_containers: Session): """ Test that workflow update ignores disallowed fields. @@ -1218,9 +1181,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) original_name = workflow.marked_name - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() update_data = { @@ -1231,7 +1192,7 @@ class TestWorkflowService: # Act result = workflow_service.update_workflow( - session=db.session, + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id, account_id=account.id, @@ -1245,7 +1206,7 @@ class TestWorkflowService: assert result.graph == workflow.graph assert result.features == workflow.features - def test_delete_workflow_success(self, db_session_with_containers): + def test_delete_workflow_success(self, db_session_with_containers: Session): """ Test successful workflow deletion. @@ -1262,25 +1223,23 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = "2024.01.01.001" # Published version - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() # Act result = workflow_service.delete_workflow( - session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id ) # Assert assert result is True # Verify workflow is actually deleted - deleted_workflow = db.session.query(Workflow).filter_by(id=workflow.id).first() + deleted_workflow = db_session_with_containers.query(Workflow).filter_by(id=workflow.id).first() assert deleted_workflow is None - def test_delete_workflow_draft_error(self, db_session_with_containers): + def test_delete_workflow_draft_error(self, db_session_with_containers: Session): """ Test error when attempting to delete a draft workflow. @@ -1296,9 +1255,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) # Keep as draft version - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -1306,9 +1263,11 @@ class TestWorkflowService: from services.errors.workflow_service import DraftWorkflowDeletionError with pytest.raises(DraftWorkflowDeletionError, match="Cannot delete draft workflow versions"): - workflow_service.delete_workflow(session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id) + workflow_service.delete_workflow( + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id + ) - def test_delete_workflow_in_use_error(self, db_session_with_containers): + def test_delete_workflow_in_use_error(self, db_session_with_containers: Session): """ Test error when attempting to delete a workflow that's in use by an app. @@ -1327,9 +1286,7 @@ class TestWorkflowService: # Associate workflow with app app.workflow_id = workflow.id - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -1337,9 +1294,11 @@ class TestWorkflowService: from services.errors.workflow_service import WorkflowInUseError with pytest.raises(WorkflowInUseError, match="Cannot delete workflow that is currently in use by app"): - workflow_service.delete_workflow(session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id) + workflow_service.delete_workflow( + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id + ) - def test_delete_workflow_not_found_error(self, db_session_with_containers): + def test_delete_workflow_not_found_error(self, db_session_with_containers: Session): """ Test error when attempting to delete a non-existent workflow. @@ -1351,17 +1310,15 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) non_existent_workflow_id = fake.uuid4() - from extensions.ext_database import db - workflow_service = WorkflowService() # Act & Assert with pytest.raises(ValueError, match=f"Workflow with ID {non_existent_workflow_id} not found"): workflow_service.delete_workflow( - session=db.session, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id + session=db_session_with_containers, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id ) - def test_run_free_workflow_node_success(self, db_session_with_containers): + def test_run_free_workflow_node_success(self, db_session_with_containers: Session): """ Test successful execution of a free workflow node. @@ -1393,8 +1350,8 @@ class TestWorkflowService: from unittest.mock import patch - from core.app.workflow.node_factory import DifyNodeFactory from core.model_manager import ModelInstance + from core.workflow.node_factory import DifyNodeFactory # Act with patch.object( @@ -1413,7 +1370,7 @@ class TestWorkflowService: assert result.workflow_id == "" # No workflow ID for free nodes assert result.index == 1 - def test_run_free_workflow_node_with_complex_inputs(self, db_session_with_containers): + def test_run_free_workflow_node_with_complex_inputs(self, db_session_with_containers: Session): """ Test execution of a free workflow node with complex input data. @@ -1454,7 +1411,7 @@ class TestWorkflowService: error_msg = str(exc_info.value).lower() assert any(keyword in error_msg for keyword in ["start", "not supported", "external"]) - def test_handle_node_run_result_success(self, db_session_with_containers): + def test_handle_node_run_result_success(self, db_session_with_containers: Session): """ Test successful handling of node run results. @@ -1472,10 +1429,10 @@ class TestWorkflowService: import uuid from datetime import datetime - from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus - from core.workflow.graph_events import NodeRunSucceededEvent - from core.workflow.node_events import NodeRunResult - from core.workflow.nodes.base.node import Node + from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus + from dify_graph.graph_events import NodeRunSucceededEvent + from dify_graph.node_events import NodeRunResult + from dify_graph.nodes.base.node import Node # Create mock node mock_node = MagicMock(spec=Node) @@ -1517,19 +1474,19 @@ class TestWorkflowService: # Assert assert result is not None assert result.node_id == node_id - from core.workflow.enums import NodeType + from dify_graph.enums import NodeType assert result.node_type == NodeType.START # Should match the mock node type assert result.title == "Test Node" # Import the enum for comparison - from core.workflow.enums import WorkflowNodeExecutionStatus + from dify_graph.enums import WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.inputs is not None assert result.outputs is not None assert result.process_data is not None - def test_handle_node_run_result_failure(self, db_session_with_containers): + def test_handle_node_run_result_failure(self, db_session_with_containers: Session): """ Test handling of failed node run results. @@ -1547,10 +1504,10 @@ class TestWorkflowService: import uuid from datetime import datetime - from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus - from core.workflow.graph_events import NodeRunFailedEvent - from core.workflow.node_events import NodeRunResult - from core.workflow.nodes.base.node import Node + from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus + from dify_graph.graph_events import NodeRunFailedEvent + from dify_graph.node_events import NodeRunResult + from dify_graph.nodes.base.node import Node # Create mock node mock_node = MagicMock(spec=Node) @@ -1592,13 +1549,13 @@ class TestWorkflowService: assert result is not None assert result.node_id == node_id # Import the enum for comparison - from core.workflow.enums import WorkflowNodeExecutionStatus + from dify_graph.enums import WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.FAILED assert result.error is not None assert "Test error message" in str(result.error) - def test_handle_node_run_result_continue_on_error(self, db_session_with_containers): + def test_handle_node_run_result_continue_on_error(self, db_session_with_containers: Session): """ Test handling of node run results with continue_on_error strategy. @@ -1616,10 +1573,10 @@ class TestWorkflowService: import uuid from datetime import datetime - from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus - from core.workflow.graph_events import NodeRunFailedEvent - from core.workflow.node_events import NodeRunResult - from core.workflow.nodes.base.node import Node + from dify_graph.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus + from dify_graph.graph_events import NodeRunFailedEvent + from dify_graph.node_events import NodeRunResult + from dify_graph.nodes.base.node import Node # Create mock node with continue_on_error mock_node = MagicMock(spec=Node) @@ -1662,7 +1619,7 @@ class TestWorkflowService: assert result is not None assert result.node_id == node_id # Import the enum for comparison - from core.workflow.enums import WorkflowNodeExecutionStatus + from dify_graph.enums import WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.EXCEPTION # Should be EXCEPTION, not FAILED assert result.outputs is not None diff --git a/api/tests/test_containers_integration_tests/services/test_workspace_service.py b/api/tests/test_containers_integration_tests/services/test_workspace_service.py index 4249642bc9..92dec24c7d 100644 --- a/api/tests/test_containers_integration_tests/services/test_workspace_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workspace_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from services.workspace_service import WorkspaceService @@ -29,7 +30,7 @@ class TestWorkspaceService: "dify_config": mock_dify_config, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -50,10 +51,8 @@ class TestWorkspaceService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant tenant = Tenant( @@ -62,8 +61,8 @@ class TestWorkspaceService: plan="basic", custom_config='{"replace_webapp_logo": true, "remove_webapp_brand": false}', ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join with owner role join = TenantAccountJoin( @@ -72,15 +71,15 @@ class TestWorkspaceService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def test_get_tenant_info_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tenant_info_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of tenant information with all features enabled. @@ -121,13 +120,12 @@ class TestWorkspaceService: assert "replace_webapp_logo" in result["custom_config"] # Verify database state - from extensions.ext_database import db - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_without_custom_config( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval when custom config features are disabled. @@ -167,13 +165,12 @@ class TestWorkspaceService: assert "custom_config" not in result # Verify database state - from extensions.ext_database import db - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_normal_user_role( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval for normal user role without privileged features. @@ -191,11 +188,14 @@ class TestWorkspaceService: ) # Update the join to have normal role - from extensions.ext_database import db - join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + join = ( + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=account.id) + .first() + ) join.role = TenantAccountRole.NORMAL - db.session.commit() + db_session_with_containers.commit() # Setup mocks for feature service mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -220,11 +220,11 @@ class TestWorkspaceService: assert "custom_config" not in result # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_admin_role_and_logo_replacement( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval for admin role with logo replacement enabled. @@ -242,11 +242,14 @@ class TestWorkspaceService: ) # Update the join to have admin role - from extensions.ext_database import db - join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + join = ( + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=account.id) + .first() + ) join.role = TenantAccountRole.ADMIN - db.session.commit() + db_session_with_containers.commit() # Setup mocks for feature service and tenant service mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -268,10 +271,12 @@ class TestWorkspaceService: assert "replace_webapp_logo" in result["custom_config"] # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None - def test_get_tenant_info_with_tenant_none(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tenant_info_with_tenant_none( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tenant info retrieval when tenant parameter is None. @@ -290,7 +295,7 @@ class TestWorkspaceService: assert result is None def test_get_tenant_info_with_custom_config_variations( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval with various custom config configurations. @@ -323,10 +328,8 @@ class TestWorkspaceService: # Update tenant custom config import json - from extensions.ext_database import db - tenant.custom_config = json.dumps(config) - db.session.commit() + db_session_with_containers.commit() # Setup mocks mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -353,11 +356,11 @@ class TestWorkspaceService: assert result["custom_config"]["remove_webapp_brand"] == config["remove_webapp_brand"] # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_editor_role_and_limited_permissions( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval for editor role with limited permissions. @@ -375,11 +378,14 @@ class TestWorkspaceService: ) # Update the join to have editor role - from extensions.ext_database import db - join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + join = ( + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=account.id) + .first() + ) join.role = TenantAccountRole.EDITOR - db.session.commit() + db_session_with_containers.commit() # Setup mocks for feature service and tenant service mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -400,11 +406,11 @@ class TestWorkspaceService: assert "custom_config" not in result # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_dataset_operator_role( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval for dataset operator role. @@ -422,11 +428,14 @@ class TestWorkspaceService: ) # Update the join to have dataset operator role - from extensions.ext_database import db - join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + join = ( + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=account.id) + .first() + ) join.role = TenantAccountRole.DATASET_OPERATOR - db.session.commit() + db_session_with_containers.commit() # Setup mocks for feature service and tenant service mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -447,11 +456,11 @@ class TestWorkspaceService: assert "custom_config" not in result # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_complex_custom_config_scenarios( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval with complex custom config scenarios. @@ -491,10 +500,8 @@ class TestWorkspaceService: # Update tenant custom config import json - from extensions.ext_database import db - tenant.custom_config = json.dumps(config) - db.session.commit() + db_session_with_containers.commit() # Setup mocks mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -525,5 +532,5 @@ class TestWorkspaceService: assert result["custom_config"]["remove_webapp_brand"] is False # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index 2ff71ea6ea..bffdca623a 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest from faker import Faker from pydantic import TypeAdapter, ValidationError +from sqlalchemy.orm import Session from core.tools.entities.tool_entities import ApiProviderSchemaType from models import Account, Tenant @@ -34,7 +35,7 @@ class TestApiToolManageService: "provider_controller": mock_provider_controller, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -55,18 +56,16 @@ class TestApiToolManageService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join from models.account import TenantAccountJoin, TenantAccountRole @@ -77,8 +76,8 @@ class TestApiToolManageService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -118,7 +117,7 @@ class TestApiToolManageService: """ def test_parser_api_schema_success( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful parsing of API schema. @@ -163,7 +162,7 @@ class TestApiToolManageService: assert api_key_value_field["default"] == "" def test_parser_api_schema_invalid_schema( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test parsing of invalid API schema. @@ -183,7 +182,7 @@ class TestApiToolManageService: assert "invalid schema" in str(exc_info.value) def test_parser_api_schema_malformed_json( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test parsing of malformed JSON schema. @@ -203,7 +202,7 @@ class TestApiToolManageService: assert "invalid schema" in str(exc_info.value) def test_convert_schema_to_tool_bundles_success( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of schema to tool bundles. @@ -233,7 +232,7 @@ class TestApiToolManageService: assert tool_bundle.operation_id == "testOperation" def test_convert_schema_to_tool_bundles_with_extra_info( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of schema to tool bundles with extra info. @@ -259,7 +258,7 @@ class TestApiToolManageService: assert isinstance(schema_type, str) def test_convert_schema_to_tool_bundles_invalid_schema( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test conversion of invalid schema to tool bundles. @@ -279,7 +278,7 @@ class TestApiToolManageService: assert "invalid schema" in str(exc_info.value) def test_create_api_tool_provider_success( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful creation of API tool provider. @@ -324,10 +323,9 @@ class TestApiToolManageService: assert result == {"result": "success"} # Verify database state - from extensions.ext_database import db provider = ( - db.session.query(ApiToolProvider) + db_session_with_containers.query(ApiToolProvider) .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) .first() ) @@ -347,7 +345,7 @@ class TestApiToolManageService: mock_external_service_dependencies["provider_controller"].load_bundled_tools.assert_called_once() def test_create_api_tool_provider_duplicate_name( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creation of API tool provider with duplicate name. @@ -404,7 +402,7 @@ class TestApiToolManageService: assert f"provider {provider_name} already exists" in str(exc_info.value) def test_create_api_tool_provider_invalid_schema_type( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creation of API tool provider with invalid schema type. @@ -436,7 +434,7 @@ class TestApiToolManageService: assert "validation error" in str(exc_info.value) def test_create_api_tool_provider_missing_auth_type( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creation of API tool provider with missing auth type. @@ -479,7 +477,7 @@ class TestApiToolManageService: assert "auth_type is required" in str(exc_info.value) def test_create_api_tool_provider_with_api_key_auth( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful creation of API tool provider with API key authentication. @@ -522,10 +520,9 @@ class TestApiToolManageService: assert result == {"result": "success"} # Verify database state - from extensions.ext_database import db provider = ( - db.session.query(ApiToolProvider) + db_session_with_containers.query(ApiToolProvider) .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) .first() ) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py index 6cae83ac37..0f2e3980af 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.tools.entities.tool_entities import ToolProviderType from models import Account, Tenant @@ -41,7 +42,7 @@ class TestMCPToolManageService: "tool_transform_service": mock_tool_transform_service, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -62,18 +63,16 @@ class TestMCPToolManageService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join from models.account import TenantAccountJoin, TenantAccountRole @@ -84,8 +83,8 @@ class TestMCPToolManageService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -93,7 +92,7 @@ class TestMCPToolManageService: return account, tenant def _create_test_mcp_provider( - self, db_session_with_containers, mock_external_service_dependencies, tenant_id, user_id + self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id, user_id ): """ Helper method to create a test MCP tool provider for testing. @@ -124,15 +123,13 @@ class TestMCPToolManageService: sse_read_timeout=300.0, ) - from extensions.ext_database import db - - db.session.add(mcp_provider) - db.session.commit() + db_session_with_containers.add(mcp_provider) + db_session_with_containers.commit() return mcp_provider def test_get_mcp_provider_by_provider_id_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of MCP provider by provider ID. @@ -153,9 +150,8 @@ class TestMCPToolManageService: ) # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.get_provider(provider_id=mcp_provider.id, tenant_id=tenant.id) # Assert: Verify the expected outcomes @@ -166,12 +162,12 @@ class TestMCPToolManageService: assert result.user_id == account.id # Verify database state - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.server_identifier == mcp_provider.server_identifier def test_get_mcp_provider_by_provider_id_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when MCP provider is not found by provider ID. @@ -190,14 +186,13 @@ class TestMCPToolManageService: non_existent_id = str(fake.uuid4()) # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.get_provider(provider_id=non_existent_id, tenant_id=tenant.id) def test_get_mcp_provider_by_provider_id_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant isolation when retrieving MCP provider by provider ID. @@ -223,14 +218,13 @@ class TestMCPToolManageService: ) # Act & Assert: Verify tenant isolation - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.get_provider(provider_id=mcp_provider1.id, tenant_id=tenant2.id) def test_get_mcp_provider_by_server_identifier_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of MCP provider by server identifier. @@ -251,9 +245,8 @@ class TestMCPToolManageService: ) # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.get_provider(server_identifier=mcp_provider.server_identifier, tenant_id=tenant.id) # Assert: Verify the expected outcomes @@ -264,12 +257,12 @@ class TestMCPToolManageService: assert result.user_id == account.id # Verify database state - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.name == mcp_provider.name def test_get_mcp_provider_by_server_identifier_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when MCP provider is not found by server identifier. @@ -288,14 +281,13 @@ class TestMCPToolManageService: non_existent_identifier = str(fake.uuid4()) # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.get_provider(server_identifier=non_existent_identifier, tenant_id=tenant.id) def test_get_mcp_provider_by_server_identifier_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant isolation when retrieving MCP provider by server identifier. @@ -321,13 +313,12 @@ class TestMCPToolManageService: ) # Act & Assert: Verify tenant isolation - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.get_provider(server_identifier=mcp_provider1.server_identifier, tenant_id=tenant2.id) - def test_create_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_mcp_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful creation of MCP provider. @@ -365,9 +356,8 @@ class TestMCPToolManageService: # Act: Execute the method under test from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.create_provider( tenant_id=tenant.id, name="Test MCP Provider", @@ -389,10 +379,9 @@ class TestMCPToolManageService: assert result.type == ToolProviderType.MCP # Verify database state - from extensions.ext_database import db created_provider = ( - db.session.query(MCPToolProvider) + db_session_with_containers.query(MCPToolProvider) .filter(MCPToolProvider.tenant_id == tenant.id, MCPToolProvider.name == "Test MCP Provider") .first() ) @@ -410,7 +399,9 @@ class TestMCPToolManageService: ) mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_called_once() - def test_create_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_mcp_provider_duplicate_name( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when creating MCP provider with duplicate name. @@ -427,9 +418,8 @@ class TestMCPToolManageService: # Create first provider from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.create_provider( tenant_id=tenant.id, name="Test MCP Provider", @@ -463,7 +453,7 @@ class TestMCPToolManageService: ) def test_create_mcp_provider_duplicate_server_url( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when creating MCP provider with duplicate server URL. @@ -481,9 +471,8 @@ class TestMCPToolManageService: # Create first provider from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.create_provider( tenant_id=tenant.id, name="Test MCP Provider 1", @@ -517,7 +506,7 @@ class TestMCPToolManageService: ) def test_create_mcp_provider_duplicate_server_identifier( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when creating MCP provider with duplicate server identifier. @@ -535,9 +524,8 @@ class TestMCPToolManageService: # Create first provider from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.create_provider( tenant_id=tenant.id, name="Test MCP Provider 1", @@ -570,7 +558,7 @@ class TestMCPToolManageService: ), ) - def test_retrieve_mcp_tools_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_retrieve_mcp_tools_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of MCP tools for a tenant. @@ -602,9 +590,7 @@ class TestMCPToolManageService: ) provider3.name = "Gamma Provider" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Setup mock for transformation service from core.tools.entities.api_entities import ToolProviderApiEntity @@ -647,9 +633,8 @@ class TestMCPToolManageService: ] # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.list_providers(tenant_id=tenant.id, for_list=True) # Assert: Verify the expected outcomes @@ -666,7 +651,9 @@ class TestMCPToolManageService: mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.call_count == 3 ) - def test_retrieve_mcp_tools_empty_list(self, db_session_with_containers, mock_external_service_dependencies): + def test_retrieve_mcp_tools_empty_list( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of MCP tools when tenant has no providers. @@ -684,9 +671,8 @@ class TestMCPToolManageService: # No MCP providers created for this tenant # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.list_providers(tenant_id=tenant.id, for_list=False) # Assert: Verify the expected outcomes @@ -697,7 +683,9 @@ class TestMCPToolManageService: # Verify no transformation service calls for empty list mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_not_called() - def test_retrieve_mcp_tools_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies): + def test_retrieve_mcp_tools_tenant_isolation( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tenant isolation when retrieving MCP tools. @@ -756,9 +744,8 @@ class TestMCPToolManageService: ] # Act: Execute the method under test for both tenants - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result1 = service.list_providers(tenant_id=tenant1.id, for_list=True) result2 = service.list_providers(tenant_id=tenant2.id, for_list=True) @@ -769,7 +756,7 @@ class TestMCPToolManageService: assert result2[0].id == provider2.id def test_list_mcp_tool_from_remote_server_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful listing of MCP tools from remote server. @@ -797,9 +784,7 @@ class TestMCPToolManageService: mcp_provider.authed = True # Provider must be authenticated to list tools mcp_provider.tools = "[]" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the decryption process at the rsa level to avoid key file issues with patch("libs.rsa.decrypt") as mock_decrypt: @@ -821,9 +806,8 @@ class TestMCPToolManageService: mock_client_instance.list_tools.return_value = mock_tools # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) # Assert: Verify the expected outcomes @@ -834,7 +818,7 @@ class TestMCPToolManageService: # Note: server_url is mocked, so we skip that assertion to avoid encryption issues # Verify database state was updated - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is True assert mcp_provider.tools != "[]" assert mcp_provider.updated_at is not None @@ -844,7 +828,7 @@ class TestMCPToolManageService: mock_mcp_client.assert_called_once() def test_list_mcp_tool_from_remote_server_auth_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when MCP server requires authentication. @@ -871,9 +855,7 @@ class TestMCPToolManageService: mcp_provider.authed = False mcp_provider.tools = "[]" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the decryption process at the rsa level to avoid key file issues with patch("libs.rsa.decrypt") as mock_decrypt: @@ -887,19 +869,18 @@ class TestMCPToolManageService: mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required") # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="Please auth the tool first"): service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) # Verify database state was not changed - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is False assert mcp_provider.tools == "[]" def test_list_mcp_tool_from_remote_server_connection_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when MCP server connection fails. @@ -926,9 +907,7 @@ class TestMCPToolManageService: mcp_provider.authed = True # Provider must be authenticated to test connection errors mcp_provider.tools = "[]" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the decryption process at the rsa level to avoid key file issues with patch("libs.rsa.decrypt") as mock_decrypt: @@ -942,18 +921,17 @@ class TestMCPToolManageService: mock_client_instance.list_tools.side_effect = MCPError("Connection failed") # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="Failed to connect to MCP server: Connection failed"): service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) # Verify database state was not changed - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is True # Provider remains authenticated assert mcp_provider.tools == "[]" - def test_delete_mcp_tool_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_mcp_tool_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful deletion of MCP tool. @@ -974,20 +952,19 @@ class TestMCPToolManageService: ) # Verify provider exists - from extensions.ext_database import db - assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None + assert db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None # Act: Execute the method under test - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.delete_provider(tenant_id=tenant.id, provider_id=mcp_provider.id) # Assert: Verify the expected outcomes # Provider should be deleted from database - deleted_provider = db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() + deleted_provider = db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() assert deleted_provider is None - def test_delete_mcp_tool_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_mcp_tool_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test error handling when deleting non-existent MCP tool. @@ -1005,13 +982,14 @@ class TestMCPToolManageService: non_existent_id = str(fake.uuid4()) # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.delete_provider(tenant_id=tenant.id, provider_id=non_existent_id) - def test_delete_mcp_tool_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_mcp_tool_tenant_isolation( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tenant isolation when deleting MCP tool. @@ -1036,18 +1014,16 @@ class TestMCPToolManageService: ) # Act & Assert: Verify tenant isolation - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.delete_provider(tenant_id=tenant2.id, provider_id=mcp_provider1.id) # Verify provider still exists in tenant1 - from extensions.ext_database import db - assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider1.id).first() is not None + assert db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider1.id).first() is not None - def test_update_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_mcp_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful update of MCP provider. @@ -1070,14 +1046,12 @@ class TestMCPToolManageService: original_name = mcp_provider.name original_icon = mcp_provider.icon - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test from core.entities.mcp_provider import MCPConfiguration - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.update_provider( tenant_id=tenant.id, provider_id=mcp_provider.id, @@ -1094,7 +1068,7 @@ class TestMCPToolManageService: ) # Assert: Verify the expected outcomes - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.name == "Updated MCP Provider" assert mcp_provider.server_identifier == "updated_identifier_123" assert mcp_provider.timeout == 45.0 @@ -1108,7 +1082,9 @@ class TestMCPToolManageService: assert icon_data["content"] == "🚀" assert icon_data["background"] == "#4ECDC4" - def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_mcp_provider_duplicate_name( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when updating MCP provider with duplicate name. @@ -1134,15 +1110,12 @@ class TestMCPToolManageService: ) provider2.name = "Second Provider" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Act & Assert: Verify proper error handling for duplicate name from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool First Provider already exists"): service.update_provider( tenant_id=tenant.id, @@ -1160,7 +1133,7 @@ class TestMCPToolManageService: ) def test_update_mcp_provider_credentials_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful update of MCP provider credentials. @@ -1185,9 +1158,7 @@ class TestMCPToolManageService: mcp_provider.authed = False mcp_provider.tools = "[]" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the provider controller and encryption with ( @@ -1202,9 +1173,8 @@ class TestMCPToolManageService: mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"} # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.update_provider_credentials( provider_id=mcp_provider.id, tenant_id=tenant.id, @@ -1213,7 +1183,7 @@ class TestMCPToolManageService: ) # Assert: Verify the expected outcomes - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is True assert mcp_provider.updated_at is not None @@ -1225,7 +1195,7 @@ class TestMCPToolManageService: assert "new_key" in credentials def test_update_mcp_provider_credentials_not_authed( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test update of MCP provider credentials when not authenticated. @@ -1249,9 +1219,7 @@ class TestMCPToolManageService: mcp_provider.authed = True mcp_provider.tools = '[{"name": "test_tool"}]' - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the provider controller and encryption with ( @@ -1266,9 +1234,8 @@ class TestMCPToolManageService: mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"} # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.update_provider_credentials( provider_id=mcp_provider.id, tenant_id=tenant.id, @@ -1277,12 +1244,14 @@ class TestMCPToolManageService: ) # Assert: Verify the expected outcomes - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is False assert mcp_provider.tools == "[]" assert mcp_provider.updated_at is not None - def test_re_connect_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_re_connect_mcp_provider_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful reconnection to MCP provider. @@ -1343,7 +1312,9 @@ class TestMCPToolManageService: sse_read_timeout=mcp_provider.sse_read_timeout, ) - def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_re_connect_mcp_provider_auth_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test reconnection to MCP provider when authentication fails. @@ -1385,7 +1356,7 @@ class TestMCPToolManageService: assert result.encrypted_credentials == "{}" def test_re_connect_mcp_provider_connection_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test reconnection to MCP provider when connection fails. diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index fa13790942..f3736333ea 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -2,6 +2,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject @@ -27,7 +28,7 @@ class TestToolTransformService: } def _create_test_tool_provider( - self, db_session_with_containers, mock_external_service_dependencies, provider_type="api" + self, db_session_with_containers: Session, mock_external_service_dependencies, provider_type="api" ): """ Helper method to create a test tool provider for testing. @@ -89,14 +90,12 @@ class TestToolTransformService: else: raise ValueError(f"Unknown provider type: {provider_type}") - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() return provider - def test_get_plugin_icon_url_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_plugin_icon_url_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful plugin icon URL generation. @@ -126,7 +125,7 @@ class TestToolTransformService: assert result == expected_url def test_get_plugin_icon_url_with_empty_console_url( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test plugin icon URL generation when CONSOLE_API_URL is empty. @@ -156,7 +155,7 @@ class TestToolTransformService: assert result == expected_url def test_get_tool_provider_icon_url_builtin_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful tool provider icon URL generation for builtin providers. @@ -194,7 +193,7 @@ class TestToolTransformService: assert result == expected_encoded def test_get_tool_provider_icon_url_api_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful tool provider icon URL generation for API providers. @@ -220,7 +219,7 @@ class TestToolTransformService: assert result["content"] == "🔧" def test_get_tool_provider_icon_url_api_invalid_json( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tool provider icon URL generation for API providers with invalid JSON. @@ -246,7 +245,7 @@ class TestToolTransformService: assert result["content"] == "😁" or result["content"] == "\ud83d\ude01" def test_get_tool_provider_icon_url_workflow_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful tool provider icon URL generation for workflow providers. @@ -271,7 +270,7 @@ class TestToolTransformService: assert result["content"] == "🔧" def test_get_tool_provider_icon_url_mcp_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful tool provider icon URL generation for MCP providers. @@ -296,7 +295,7 @@ class TestToolTransformService: assert result["content"] == "🔧" def test_get_tool_provider_icon_url_unknown_type( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tool provider icon URL generation for unknown provider types. @@ -317,7 +316,9 @@ class TestToolTransformService: # Assert: Verify the expected outcomes assert result == "" - def test_repack_provider_dict_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_repack_provider_dict_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful provider repacking with dictionary input. @@ -341,7 +342,9 @@ class TestToolTransformService: # Note: provider name may contain spaces that get URL encoded assert provider["name"].replace(" ", "%20") in provider["icon"] or provider["name"] in provider["icon"] - def test_repack_provider_entity_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_repack_provider_entity_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful provider repacking with ToolProviderApiEntity input. @@ -389,7 +392,7 @@ class TestToolTransformService: assert "test_icon_dark.png" in provider.icon_dark def test_repack_provider_entity_no_plugin_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful provider repacking with ToolProviderApiEntity input without plugin_id. @@ -435,7 +438,9 @@ class TestToolTransformService: assert provider.icon_dark["background"] == "#252525" assert provider.icon_dark["content"] == "🔧" - def test_repack_provider_entity_no_dark_icon(self, db_session_with_containers, mock_external_service_dependencies): + def test_repack_provider_entity_no_dark_icon( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test provider repacking with ToolProviderApiEntity input without dark icon. @@ -477,7 +482,7 @@ class TestToolTransformService: assert provider.icon_dark == "" def test_builtin_provider_to_user_provider_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of builtin provider to user provider. @@ -545,7 +550,7 @@ class TestToolTransformService: assert result.original_credentials == {"api_key": "decrypted_key"} def test_builtin_provider_to_user_provider_plugin_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of builtin provider to user provider with plugin. @@ -589,7 +594,7 @@ class TestToolTransformService: assert result.allow_delete is False def test_builtin_provider_to_user_provider_no_credentials( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test conversion of builtin provider to user provider without credentials. @@ -630,7 +635,9 @@ class TestToolTransformService: assert result.allow_delete is False assert result.masked_credentials == {"api_key": ""} - def test_api_provider_to_controller_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_api_provider_to_controller_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful conversion of API provider to controller. @@ -655,10 +662,8 @@ class TestToolTransformService: tools_str="[]", ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Act: Execute the method under test result = ToolTransformService.api_provider_to_controller(provider) @@ -669,7 +674,7 @@ class TestToolTransformService: # Additional assertions would depend on the actual controller implementation def test_api_provider_to_controller_api_key_query( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test conversion of API provider to controller with api_key_query auth type. @@ -693,10 +698,8 @@ class TestToolTransformService: tools_str="[]", ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Act: Execute the method under test result = ToolTransformService.api_provider_to_controller(provider) @@ -706,7 +709,7 @@ class TestToolTransformService: assert hasattr(result, "from_db") def test_api_provider_to_controller_backward_compatibility( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test conversion of API provider to controller with backward compatibility auth types. @@ -731,10 +734,8 @@ class TestToolTransformService: tools_str="[]", ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Act: Execute the method under test result = ToolTransformService.api_provider_to_controller(provider) @@ -744,7 +745,7 @@ class TestToolTransformService: assert hasattr(result, "from_db") def test_workflow_provider_to_controller_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of workflow provider to controller. @@ -769,10 +770,8 @@ class TestToolTransformService: parameter_configuration="[]", ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Mock the WorkflowToolProviderController.from_db method to avoid app dependency with patch("services.tools.tools_transform_service.WorkflowToolProviderController.from_db") as mock_from_db: diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 24fe5c4670..0b3c1112bd 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker from pydantic import ValidationError +from sqlalchemy.orm import Session from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError @@ -63,7 +64,7 @@ class TestWorkflowToolManageService: "tool_transform_service": mock_tool_transform_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -119,14 +120,12 @@ class TestWorkflowToolManageService: conversation_variables=[], ) - from extensions.ext_database import db - - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Update app to reference the workflow app.workflow_id = workflow.id - db.session.commit() + db_session_with_containers.commit() return app, account, workflow @@ -153,7 +152,9 @@ class TestWorkflowToolManageService: ), ] - def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_workflow_tool_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful workflow tool creation with valid parameters. @@ -198,11 +199,10 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} # Verify database state - from extensions.ext_database import db # Check if workflow tool provider was created created_tool_provider = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.app_id == app.id, @@ -230,7 +230,7 @@ class TestWorkflowToolManageService: ].workflow_provider_to_controller.assert_called_once() def test_create_workflow_tool_duplicate_name_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when name already exists. @@ -280,10 +280,9 @@ class TestWorkflowToolManageService: assert f"Tool with name {first_tool_name} or app_id {app.id} already exists" in str(exc_info.value) # Verify only one tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -293,7 +292,7 @@ class TestWorkflowToolManageService: assert tool_count == 1 def test_create_workflow_tool_invalid_app_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when app does not exist. @@ -331,10 +330,9 @@ class TestWorkflowToolManageService: assert f"App {non_existent_app_id} not found" in str(exc_info.value) # Verify no workflow tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -344,7 +342,7 @@ class TestWorkflowToolManageService: assert tool_count == 0 def test_create_workflow_tool_invalid_parameters_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when parameters are invalid. @@ -387,10 +385,9 @@ class TestWorkflowToolManageService: assert "validation error" in str(exc_info.value).lower() # Verify no workflow tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -400,7 +397,7 @@ class TestWorkflowToolManageService: assert tool_count == 0 def test_create_workflow_tool_duplicate_app_id_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when app_id already exists. @@ -450,10 +447,9 @@ class TestWorkflowToolManageService: assert f"Tool with name {second_tool_name} or app_id {app.id} already exists" in str(exc_info.value) # Verify only one tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -463,7 +459,7 @@ class TestWorkflowToolManageService: assert tool_count == 1 def test_create_workflow_tool_workflow_not_found_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when app has no workflow. @@ -481,10 +477,9 @@ class TestWorkflowToolManageService: ) # Remove workflow reference from app - from extensions.ext_database import db app.workflow_id = None - db.session.commit() + db_session_with_containers.commit() # Attempt to create workflow tool for app without workflow tool_parameters = self._create_test_workflow_tool_parameters() @@ -505,7 +500,7 @@ class TestWorkflowToolManageService: # Verify no workflow tool was created tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -515,7 +510,7 @@ class TestWorkflowToolManageService: assert tool_count == 0 def test_create_workflow_tool_human_input_node_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when workflow contains human input nodes. @@ -558,10 +553,8 @@ class TestWorkflowToolManageService: assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - from extensions.ext_database import db - tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -570,7 +563,9 @@ class TestWorkflowToolManageService: assert tool_count == 0 - def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_workflow_tool_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful workflow tool update with valid parameters. @@ -603,10 +598,9 @@ class TestWorkflowToolManageService: ) # Get the created tool - from extensions.ext_database import db created_tool = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.app_id == app.id, @@ -641,7 +635,7 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} # Verify database state was updated - db.session.refresh(created_tool) + db_session_with_containers.refresh(created_tool) assert created_tool is not None assert created_tool.name == updated_tool_name assert created_tool.label == updated_tool_label @@ -658,7 +652,7 @@ class TestWorkflowToolManageService: mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called() def test_update_workflow_tool_human_input_node_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool update fails when workflow contains human input nodes. @@ -689,10 +683,8 @@ class TestWorkflowToolManageService: parameters=initial_tool_parameters, ) - from extensions.ext_database import db - created_tool = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.app_id == app.id, @@ -712,7 +704,7 @@ class TestWorkflowToolManageService: ] } ) - db.session.commit() + db_session_with_containers.commit() with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: WorkflowToolManageService.update_workflow_tool( @@ -728,10 +720,12 @@ class TestWorkflowToolManageService: assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - db.session.refresh(created_tool) + db_session_with_containers.refresh(created_tool) assert created_tool.name == original_name - def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_workflow_tool_not_found_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test workflow tool update fails when tool does not exist. @@ -768,10 +762,9 @@ class TestWorkflowToolManageService: assert f"Tool {non_existent_tool_id} not found" in str(exc_info.value) # Verify no workflow tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -781,7 +774,7 @@ class TestWorkflowToolManageService: assert tool_count == 0 def test_update_workflow_tool_same_name_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool update succeeds when keeping the same name. @@ -813,10 +806,9 @@ class TestWorkflowToolManageService: ) # Get the created tool - from extensions.ext_database import db created_tool = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.app_id == app.id, @@ -840,12 +832,12 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} # Verify tool still exists with the same name - db.session.refresh(created_tool) + db_session_with_containers.refresh(created_tool) assert created_tool.name == first_tool_name assert created_tool.updated_at is not None def test_create_workflow_tool_with_file_parameter_default( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation with FILE parameter having a file object as default. @@ -916,7 +908,7 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} def test_create_workflow_tool_with_files_parameter_default( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation with FILES (Array[File]) parameter having file objects as default. @@ -991,7 +983,7 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} def test_create_workflow_tool_db_commit_before_validation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that database commit happens before validation, causing DB pollution on validation failure. @@ -1035,10 +1027,9 @@ class TestWorkflowToolManageService: # Verify the tool was NOT created in database # This is the expected behavior (no pollution) - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.name == tool_name, diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index 2ffb884b82..8c007877fd 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.app.app_config.entities import ( DatasetEntity, @@ -11,9 +12,9 @@ from core.app.app_config.entities import ( ModelConfigEntity, PromptTemplateEntity, ) -from core.model_runtime.entities.llm_entities import LLMMode from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.variables.input_entities import VariableEntity, VariableEntityType +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from models import Account, Tenant from models.api_based_extension import APIBasedExtension from models.model import App, AppMode, AppModelConfig @@ -79,7 +80,7 @@ class TestWorkflowConverter: mock_config.app_model_config_dict = {} return mock_config - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -100,18 +101,16 @@ class TestWorkflowConverter: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join from models.account import TenantAccountJoin, TenantAccountRole @@ -122,15 +121,17 @@ class TestWorkflowConverter: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant, account): + def _create_test_app( + self, db_session_with_containers: Session, mock_external_service_dependencies, tenant, account + ): """ Helper method to create a test app for testing. @@ -163,10 +164,8 @@ class TestWorkflowConverter: updated_by=account.id, ) - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() # Create app model config app_model_config = AppModelConfig( @@ -177,16 +176,16 @@ class TestWorkflowConverter: created_by=account.id, updated_by=account.id, ) - db.session.add(app_model_config) - db.session.commit() + db_session_with_containers.add(app_model_config) + db_session_with_containers.commit() # Link app model config to app app.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() return app - def test_convert_to_workflow_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_convert_to_workflow_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful conversion of app to workflow. @@ -225,19 +224,18 @@ class TestWorkflowConverter: assert new_app.created_by == account.id # Verify database state - from extensions.ext_database import db - db.session.refresh(new_app) + db_session_with_containers.refresh(new_app) assert new_app.id is not None # Verify workflow was created - workflow = db.session.query(Workflow).where(Workflow.app_id == new_app.id).first() + workflow = db_session_with_containers.query(Workflow).where(Workflow.app_id == new_app.id).first() assert workflow is not None assert workflow.tenant_id == app.tenant_id assert workflow.type == "chat" def test_convert_to_workflow_without_app_model_config_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when app model config is missing. @@ -270,16 +268,14 @@ class TestWorkflowConverter: updated_by=account.id, ) - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() # Act & Assert: Verify proper error handling workflow_converter = WorkflowConverter() # Check initial state - initial_workflow_count = db.session.query(Workflow).count() + initial_workflow_count = db_session_with_containers.query(Workflow).count() with pytest.raises(ValueError, match="App model config is required"): workflow_converter.convert_to_workflow( @@ -294,12 +290,12 @@ class TestWorkflowConverter: # Verify database state remains unchanged # The workflow creation happens in convert_app_model_config_to_workflow # which is called before the app_model_config check, so we need to clean up - db.session.rollback() - final_workflow_count = db.session.query(Workflow).count() + db_session_with_containers.rollback() + final_workflow_count = db_session_with_containers.query(Workflow).count() assert final_workflow_count == initial_workflow_count def test_convert_app_model_config_to_workflow_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of app model config to workflow. @@ -356,16 +352,17 @@ class TestWorkflowConverter: assert answer_node["id"] == "answer" # Verify database state - from extensions.ext_database import db - db.session.refresh(workflow) + db_session_with_containers.refresh(workflow) assert workflow.id is not None # Verify features were set features = json.loads(workflow._features) if workflow._features else {} assert isinstance(features, dict) - def test_convert_to_start_node_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_convert_to_start_node_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful conversion to start node. @@ -410,7 +407,9 @@ class TestWorkflowConverter: assert second_variable["label"] == "Number Input" assert second_variable["type"] == "number" - def test_convert_to_http_request_node_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_convert_to_http_request_node_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful conversion to HTTP request node. @@ -436,10 +435,8 @@ class TestWorkflowConverter: api_endpoint="https://api.example.com/test", ) - from extensions.ext_database import db - - db.session.add(api_based_extension) - db.session.commit() + db_session_with_containers.add(api_based_extension) + db_session_with_containers.commit() # Mock encrypter mock_external_service_dependencies["encrypter"].decrypt_token.return_value = "decrypted_api_key" @@ -489,7 +486,7 @@ class TestWorkflowConverter: assert external_data_variable_node_mapping["external_data"] == code_node["id"] def test_convert_to_knowledge_retrieval_node_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion to knowledge retrieval node. diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py index f3ba126706..af9e8d0b2c 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -4,7 +4,7 @@ from uuid import uuid4 from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker -from core.workflow.enums import WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 8bb536c34a..efeb29cf20 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -2,9 +2,9 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType -from extensions.ext_database import db from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment @@ -31,7 +31,9 @@ class TestAddDocumentToIndexTask: "index_processor": mock_processor, } - def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_dataset_and_document( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Helper method to create a test dataset and document for testing. @@ -51,15 +53,15 @@ class TestAddDocumentToIndexTask: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -68,8 +70,8 @@ class TestAddDocumentToIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create dataset dataset = Dataset( @@ -81,8 +83,8 @@ class TestAddDocumentToIndexTask: indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Create document document = Document( @@ -99,15 +101,15 @@ class TestAddDocumentToIndexTask: enabled=True, doc_form=IndexStructureType.PARAGRAPH_INDEX, ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property works correctly - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) return dataset, document - def _create_test_segments(self, db_session_with_containers, document, dataset): + def _create_test_segments(self, db_session_with_containers: Session, document, dataset): """ Helper method to create test document segments. @@ -138,13 +140,15 @@ class TestAddDocumentToIndexTask: status="completed", created_by=document.created_by, ) - db.session.add(segment) + db_session_with_containers.add(segment) segments.append(segment) - db.session.commit() + db_session_with_containers.commit() return segments - def test_add_document_to_index_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_add_document_to_index_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful document indexing with paragraph index type. @@ -180,9 +184,9 @@ class TestAddDocumentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify database state changes - db.session.refresh(document) + db_session_with_containers.refresh(document) for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True assert segment.disabled_at is None assert segment.disabled_by is None @@ -191,7 +195,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_with_different_index_type( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test document indexing with different index types. @@ -209,10 +213,10 @@ class TestAddDocumentToIndexTask: # Update document to use different index type document.doc_form = IndexStructureType.QA_INDEX - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property reflects the updated document - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) # Create segments segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -237,9 +241,9 @@ class TestAddDocumentToIndexTask: assert len(documents) == 3 # Verify database state changes - db.session.refresh(document) + db_session_with_containers.refresh(document) for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True assert segment.disabled_at is None assert segment.disabled_by is None @@ -248,7 +252,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_document_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent document. @@ -275,7 +279,7 @@ class TestAddDocumentToIndexTask: # because indexing_cache_key is not defined in that case def test_add_document_to_index_invalid_indexing_status( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of document with invalid indexing status. @@ -294,7 +298,7 @@ class TestAddDocumentToIndexTask: # Set invalid indexing status document.indexing_status = "processing" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the task add_document_to_index_task(document.id) @@ -304,7 +308,7 @@ class TestAddDocumentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_add_document_to_index_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling when document's dataset doesn't exist. @@ -326,14 +330,14 @@ class TestAddDocumentToIndexTask: redis_client.set(indexing_cache_key, "processing", ex=300) # Delete the dataset to simulate dataset not found scenario - db.session.delete(dataset) - db.session.commit() + db_session_with_containers.delete(dataset) + db_session_with_containers.commit() # Act: Execute the task add_document_to_index_task(document.id) # Assert: Verify error handling - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.enabled is False assert document.indexing_status == "error" assert document.error is not None @@ -348,7 +352,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_with_parent_child_structure( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test document indexing with parent-child structure. @@ -367,10 +371,10 @@ class TestAddDocumentToIndexTask: # Update document to use parent-child index type document.doc_form = IndexStructureType.PARENT_CHILD_INDEX - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property reflects the updated document - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) # Create segments with mock child chunks segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -413,9 +417,9 @@ class TestAddDocumentToIndexTask: assert len(doc.children) == 2 # Each document has 2 children # Verify database state changes - db.session.refresh(document) + db_session_with_containers.refresh(document) for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True assert segment.disabled_at is None assert segment.disabled_by is None @@ -424,7 +428,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_with_already_enabled_segments( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test document indexing when segments are already enabled. @@ -459,10 +463,10 @@ class TestAddDocumentToIndexTask: status="completed", created_by=document.created_by, ) - db.session.add(segment) + db_session_with_containers.add(segment) segments.append(segment) - db.session.commit() + db_session_with_containers.commit() # Set up Redis cache key indexing_cache_key = f"document_{document.id}_indexing" @@ -488,7 +492,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_auto_disable_log_deletion( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that auto disable logs are properly deleted during indexing. @@ -515,10 +519,10 @@ class TestAddDocumentToIndexTask: document_id=document.id, ) log_entry.id = str(fake.uuid4()) - db.session.add(log_entry) + db_session_with_containers.add(log_entry) auto_disable_logs.append(log_entry) - db.session.commit() + db_session_with_containers.commit() # Set up Redis cache key indexing_cache_key = f"document_{document.id}_indexing" @@ -526,7 +530,9 @@ class TestAddDocumentToIndexTask: # Verify logs exist before processing existing_logs = ( - db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id).all() + db_session_with_containers.query(DatasetAutoDisableLog) + .where(DatasetAutoDisableLog.document_id == document.id) + .all() ) assert len(existing_logs) == 2 @@ -535,7 +541,9 @@ class TestAddDocumentToIndexTask: # Assert: Verify auto disable logs were deleted remaining_logs = ( - db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id).all() + db_session_with_containers.query(DatasetAutoDisableLog) + .where(DatasetAutoDisableLog.document_id == document.id) + .all() ) assert len(remaining_logs) == 0 @@ -547,14 +555,14 @@ class TestAddDocumentToIndexTask: # Verify segments were enabled for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True # Verify redis cache was cleared assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_general_exception_handling( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test general exception handling during indexing process. @@ -584,7 +592,7 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify error handling - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.enabled is False assert document.indexing_status == "error" assert document.error is not None @@ -593,14 +601,14 @@ class TestAddDocumentToIndexTask: # Verify segments were not enabled due to error for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is False # Should remain disabled due to error # Verify redis cache was still cleared despite error assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_segment_filtering_edge_cases( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment filtering with various edge cases. @@ -638,7 +646,7 @@ class TestAddDocumentToIndexTask: status="completed", created_by=document.created_by, ) - db.session.add(segment1) + db_session_with_containers.add(segment1) segments.append(segment1) # Segment 2: Should be processed (enabled=True, status="completed") @@ -658,7 +666,7 @@ class TestAddDocumentToIndexTask: status="completed", created_by=document.created_by, ) - db.session.add(segment2) + db_session_with_containers.add(segment2) segments.append(segment2) # Segment 3: Should NOT be processed (enabled=False, status="processing") @@ -677,7 +685,7 @@ class TestAddDocumentToIndexTask: status="processing", # Not completed created_by=document.created_by, ) - db.session.add(segment3) + db_session_with_containers.add(segment3) segments.append(segment3) # Segment 4: Should be processed (enabled=False, status="completed") @@ -696,10 +704,10 @@ class TestAddDocumentToIndexTask: status="completed", created_by=document.created_by, ) - db.session.add(segment4) + db_session_with_containers.add(segment4) segments.append(segment4) - db.session.commit() + db_session_with_containers.commit() # Set up Redis cache key indexing_cache_key = f"document_{document.id}_indexing" @@ -728,11 +736,11 @@ class TestAddDocumentToIndexTask: assert documents[2].metadata["doc_id"] == "node_3" # segment4, position 3 # Verify database state changes - db.session.refresh(document) - db.session.refresh(segment1) - db.session.refresh(segment2) - db.session.refresh(segment3) - db.session.refresh(segment4) + db_session_with_containers.refresh(document) + db_session_with_containers.refresh(segment1) + db_session_with_containers.refresh(segment2) + db_session_with_containers.refresh(segment3) + db_session_with_containers.refresh(segment4) # All segments should be enabled because the task updates ALL segments for the document assert segment1.enabled is True @@ -744,7 +752,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_comprehensive_error_scenarios( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test comprehensive error scenarios and recovery. @@ -779,7 +787,7 @@ class TestAddDocumentToIndexTask: document.indexing_status = "completed" document.error = None document.disabled_at = None - db.session.commit() + db_session_with_containers.commit() # Set up Redis cache key indexing_cache_key = f"document_{document.id}_indexing" @@ -789,7 +797,7 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify consistent error handling - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.enabled is False, f"Document should be disabled for {error_name}" assert document.indexing_status == "error", f"Document status should be error for {error_name}" assert document.error is not None, f"Error should be recorded for {error_name}" @@ -798,7 +806,7 @@ class TestAddDocumentToIndexTask: # Verify segments remain disabled due to error for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is False, f"Segments should remain disabled for {error_name}" # Verify redis cache was still cleared despite error diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index f94c5b19e6..ec789418a8 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -11,8 +11,8 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -49,7 +49,7 @@ class TestBatchCleanDocumentTask: "get_image_ids": mock_get_image_ids, } - def _create_test_account(self, db_session_with_containers): + def _create_test_account(self, db_session_with_containers: Session): """ Helper method to create a test account for testing. @@ -69,16 +69,16 @@ class TestBatchCleanDocumentTask: status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -87,15 +87,15 @@ class TestBatchCleanDocumentTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account - def _create_test_dataset(self, db_session_with_containers, account): + def _create_test_dataset(self, db_session_with_containers: Session, account): """ Helper method to create a test dataset for testing. @@ -119,12 +119,12 @@ class TestBatchCleanDocumentTask: embedding_model_provider="openai", ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset - def _create_test_document(self, db_session_with_containers, dataset, account): + def _create_test_document(self, db_session_with_containers: Session, dataset, account): """ Helper method to create a test document for testing. @@ -153,12 +153,12 @@ class TestBatchCleanDocumentTask: doc_form="text_model", ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document - def _create_test_document_segment(self, db_session_with_containers, document, account): + def _create_test_document_segment(self, db_session_with_containers: Session, document, account): """ Helper method to create a test document segment for testing. @@ -186,12 +186,12 @@ class TestBatchCleanDocumentTask: status="completed", ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() return segment - def _create_test_upload_file(self, db_session_with_containers, account): + def _create_test_upload_file(self, db_session_with_containers: Session, account): """ Helper method to create a test upload file for testing. @@ -220,13 +220,13 @@ class TestBatchCleanDocumentTask: used=False, ) - db.session.add(upload_file) - db.session.commit() + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() return upload_file def test_batch_clean_document_task_successful_cleanup( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful cleanup of documents with segments and files. @@ -245,7 +245,7 @@ class TestBatchCleanDocumentTask: # Update document to reference the upload file document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_id = document.id @@ -261,18 +261,18 @@ class TestBatchCleanDocumentTask: # The task should have processed the segment and cleaned up the database # Verify database cleanup - db.session.commit() # Ensure all changes are committed + db_session_with_containers.commit() # Ensure all changes are committed # Check that segment is deleted - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that upload file is deleted - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_with_image_files( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup of documents containing image references. @@ -300,8 +300,8 @@ class TestBatchCleanDocumentTask: status="completed", ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() # Store original IDs for verification segment_id = segment.id @@ -313,17 +313,17 @@ class TestBatchCleanDocumentTask: ) # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Verify that the task completed successfully by checking the log output # The task should have processed the segment and cleaned up the database def test_batch_clean_document_task_no_segments( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup when document has no segments. @@ -339,7 +339,7 @@ class TestBatchCleanDocumentTask: # Update document to reference the upload file document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_id = document.id @@ -354,21 +354,21 @@ class TestBatchCleanDocumentTask: # Since there are no segments, the task should handle this gracefully # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup when dataset is not found. @@ -386,8 +386,8 @@ class TestBatchCleanDocumentTask: dataset_id = dataset.id # Delete the dataset to simulate not found scenario - db.session.delete(dataset) - db.session.commit() + db_session_with_containers.delete(dataset) + db_session_with_containers.commit() # Execute the task with non-existent dataset batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[]) @@ -399,14 +399,14 @@ class TestBatchCleanDocumentTask: mock_external_service_dependencies["storage"].delete.assert_not_called() # Verify that no database cleanup occurred - db.session.commit() + db_session_with_containers.commit() # Document should still exist since cleanup failed - existing_document = db.session.query(Document).filter_by(id=document_id).first() + existing_document = db_session_with_containers.query(Document).filter_by(id=document_id).first() assert existing_document is not None def test_batch_clean_document_task_storage_cleanup_failure( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup when storage operations fail. @@ -423,7 +423,7 @@ class TestBatchCleanDocumentTask: # Update document to reference the upload file document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_id = document.id @@ -442,18 +442,18 @@ class TestBatchCleanDocumentTask: # The task should continue processing even when storage operations fail # Verify database cleanup still occurred despite storage failure - db.session.commit() + db_session_with_containers.commit() # Check that segment is deleted from database - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that upload file is deleted from database - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_multiple_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup of multiple documents in a single batch operation. @@ -482,7 +482,7 @@ class TestBatchCleanDocumentTask: segments.append(segment) upload_files.append(upload_file) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_ids = [doc.id for doc in documents] @@ -498,20 +498,20 @@ class TestBatchCleanDocumentTask: # The task should process all documents and clean up all associated resources # Verify database cleanup for all resources - db.session.commit() + db_session_with_containers.commit() # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_different_doc_forms( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup with different document form types. @@ -527,12 +527,12 @@ class TestBatchCleanDocumentTask: for doc_form in doc_forms: dataset = self._create_test_dataset(db_session_with_containers, account) - db.session.commit() + db_session_with_containers.commit() document = self._create_test_document(db_session_with_containers, dataset, account) # Update document doc_form document.doc_form = doc_form - db.session.commit() + db_session_with_containers.commit() segment = self._create_test_document_segment(db_session_with_containers, document, account) @@ -549,20 +549,20 @@ class TestBatchCleanDocumentTask: # The task should handle different document forms correctly # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None except Exception as e: # If the task fails due to external service issues (e.g., plugin daemon), # we should still verify that the database state is consistent # This is a common scenario in test environments where external services may not be available - db.session.commit() + db_session_with_containers.commit() # Check if the segment still exists (task may have failed before deletion) - existing_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + existing_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() if existing_segment is not None: # If segment still exists, the task failed before deletion # This is acceptable in test environments with external service issues @@ -572,7 +572,7 @@ class TestBatchCleanDocumentTask: pass def test_batch_clean_document_task_large_batch_performance( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup performance with a large batch of documents. @@ -604,7 +604,7 @@ class TestBatchCleanDocumentTask: segments.append(segment) upload_files.append(upload_file) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_ids = [doc.id for doc in documents] @@ -629,20 +629,20 @@ class TestBatchCleanDocumentTask: # The task should handle large batches efficiently # Verify database cleanup for all resources - db.session.commit() + db_session_with_containers.commit() # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_integration_with_real_database( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test full integration with real database operations. @@ -683,12 +683,12 @@ class TestBatchCleanDocumentTask: # Add all to database for segment in segments: - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() # Verify initial state - assert db.session.query(DocumentSegment).filter_by(document_id=document.id).count() == 3 - assert db.session.query(UploadFile).filter_by(id=upload_file.id).first() is not None + assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).count() == 3 + assert db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).first() is not None # Store original IDs for verification document_id = document.id @@ -704,17 +704,17 @@ class TestBatchCleanDocumentTask: # The task should process all segments and clean up all associated resources # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that upload file is deleted - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None # Verify final database state - assert db.session.query(DocumentSegment).filter_by(document_id=document_id).count() == 0 - assert db.session.query(UploadFile).filter_by(id=file_id).first() is None + assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document_id).count() == 0 + assert db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 2156743c17..a2324979db 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -29,20 +30,19 @@ class TestBatchCreateSegmentToIndexTask: """Integration tests for batch_create_segment_to_index_task using testcontainers.""" @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before each test to ensure isolation.""" - from extensions.ext_database import db from extensions.ext_redis import redis_client # Clear all test data - db.session.query(DocumentSegment).delete() - db.session.query(Document).delete() - db.session.query(Dataset).delete() - db.session.query(UploadFile).delete() - db.session.query(TenantAccountJoin).delete() - db.session.query(Tenant).delete() - db.session.query(Account).delete() - db.session.commit() + db_session_with_containers.query(DocumentSegment).delete() + db_session_with_containers.query(Document).delete() + db_session_with_containers.query(Dataset).delete() + db_session_with_containers.query(UploadFile).delete() + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() # Clear Redis cache redis_client.flushdb() @@ -75,7 +75,7 @@ class TestBatchCreateSegmentToIndexTask: "embedding_model": mock_embedding_model, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create a test account and tenant for testing. @@ -95,18 +95,16 @@ class TestBatchCreateSegmentToIndexTask: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -115,15 +113,15 @@ class TestBatchCreateSegmentToIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_dataset(self, db_session_with_containers, account, tenant): + def _create_test_dataset(self, db_session_with_containers: Session, account, tenant): """ Helper method to create a test dataset for testing. @@ -148,14 +146,12 @@ class TestBatchCreateSegmentToIndexTask: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset - def _create_test_document(self, db_session_with_containers, account, tenant, dataset): + def _create_test_document(self, db_session_with_containers: Session, account, tenant, dataset): """ Helper method to create a test document for testing. @@ -186,14 +182,12 @@ class TestBatchCreateSegmentToIndexTask: word_count=0, ) - from extensions.ext_database import db - - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document - def _create_test_upload_file(self, db_session_with_containers, account, tenant): + def _create_test_upload_file(self, db_session_with_containers: Session, account, tenant): """ Helper method to create a test upload file for testing. @@ -221,10 +215,8 @@ class TestBatchCreateSegmentToIndexTask: used=False, ) - from extensions.ext_database import db - - db.session.add(upload_file) - db.session.commit() + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() return upload_file @@ -252,7 +244,7 @@ class TestBatchCreateSegmentToIndexTask: return csv_content def test_batch_create_segment_to_index_task_success_text_model( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful batch creation of segments for text model documents. @@ -293,11 +285,10 @@ class TestBatchCreateSegmentToIndexTask: ) # Verify results - from extensions.ext_database import db # Check that segments were created segments = ( - db.session.query(DocumentSegment) + db_session_with_containers.query(DocumentSegment) .filter_by(document_id=document.id) .order_by(DocumentSegment.position) .all() @@ -316,7 +307,7 @@ class TestBatchCreateSegmentToIndexTask: assert segment.answer is None # text_model doesn't have answers # Check that document word count was updated - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.word_count > 0 # Verify vector service was called @@ -331,7 +322,7 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"completed" def test_batch_create_segment_to_index_task_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when dataset does not exist. @@ -370,17 +361,16 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created (since dataset doesn't exist) - from extensions.ext_database import db - segments = db.session.query(DocumentSegment).all() + segments = db_session_with_containers.query(DocumentSegment).all() assert len(segments) == 0 # Verify no documents were modified - documents = db.session.query(Document).all() + documents = db_session_with_containers.query(Document).all() assert len(documents) == 0 def test_batch_create_segment_to_index_task_document_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when document does not exist. @@ -419,18 +409,17 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - from extensions.ext_database import db - segments = db.session.query(DocumentSegment).all() + segments = db_session_with_containers.query(DocumentSegment).all() assert len(segments) == 0 # Verify dataset remains unchanged (no segments were added to the dataset) - db.session.refresh(dataset) - segments_for_dataset = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + db_session_with_containers.refresh(dataset) + segments_for_dataset = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() assert len(segments_for_dataset) == 0 def test_batch_create_segment_to_index_task_document_not_available( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when document is not available for indexing. @@ -498,11 +487,9 @@ class TestBatchCreateSegmentToIndexTask: ), ] - from extensions.ext_database import db - for document in test_cases: - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Test each unavailable document for document in test_cases: @@ -524,11 +511,11 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - segments = db.session.query(DocumentSegment).filter_by(document_id=document.id).all() + segments = db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).all() assert len(segments) == 0 def test_batch_create_segment_to_index_task_upload_file_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when upload file does not exist. @@ -567,17 +554,16 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - from extensions.ext_database import db - segments = db.session.query(DocumentSegment).all() + segments = db_session_with_containers.query(DocumentSegment).all() assert len(segments) == 0 # Verify document remains unchanged - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.word_count == 0 def test_batch_create_segment_to_index_task_empty_csv_file( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when CSV file is empty. @@ -619,17 +605,16 @@ class TestBatchCreateSegmentToIndexTask: # Verify error handling # Since exception was raised, no segments should be created - from extensions.ext_database import db - segments = db.session.query(DocumentSegment).all() + segments = db_session_with_containers.query(DocumentSegment).all() assert len(segments) == 0 # Verify document remains unchanged - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.word_count == 0 def test_batch_create_segment_to_index_task_position_calculation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test proper position calculation for segments when existing segments exist. @@ -664,11 +649,9 @@ class TestBatchCreateSegmentToIndexTask: ) existing_segments.append(segment) - from extensions.ext_database import db - for segment in existing_segments: - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() # Create CSV content csv_content = self._create_test_csv_content("text_model") @@ -695,7 +678,7 @@ class TestBatchCreateSegmentToIndexTask: # Verify results # Check that new segments were created with correct positions all_segments = ( - db.session.query(DocumentSegment) + db_session_with_containers.query(DocumentSegment) .filter_by(document_id=document.id) .order_by(DocumentSegment.position) .all() @@ -716,7 +699,7 @@ class TestBatchCreateSegmentToIndexTask: assert segment.completed_at is not None # Check that document word count was updated - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.word_count > 0 # Verify vector service was called diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index cd99b2965f..8eb881258a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -37,7 +38,7 @@ class TestCleanDatasetTask: """Integration tests for clean_dataset_task using testcontainers.""" @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before each test to ensure isolation.""" from extensions.ext_redis import redis_client @@ -82,7 +83,7 @@ class TestCleanDatasetTask: "index_processor": mock_index_processor, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create a test account and tenant for testing. @@ -127,7 +128,7 @@ class TestCleanDatasetTask: return account, tenant - def _create_test_dataset(self, db_session_with_containers, account, tenant): + def _create_test_dataset(self, db_session_with_containers: Session, account, tenant): """ Helper method to create a test dataset for testing. @@ -157,7 +158,7 @@ class TestCleanDatasetTask: return dataset - def _create_test_document(self, db_session_with_containers, account, tenant, dataset): + def _create_test_document(self, db_session_with_containers: Session, account, tenant, dataset): """ Helper method to create a test document for testing. @@ -194,7 +195,7 @@ class TestCleanDatasetTask: return document - def _create_test_segment(self, db_session_with_containers, account, tenant, dataset, document): + def _create_test_segment(self, db_session_with_containers: Session, account, tenant, dataset, document): """ Helper method to create a test document segment for testing. @@ -230,7 +231,7 @@ class TestCleanDatasetTask: return segment - def _create_test_upload_file(self, db_session_with_containers, account, tenant): + def _create_test_upload_file(self, db_session_with_containers: Session, account, tenant): """ Helper method to create a test upload file for testing. @@ -264,7 +265,7 @@ class TestCleanDatasetTask: return upload_file def test_clean_dataset_task_success_basic_cleanup( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful basic dataset cleanup with minimal data. @@ -325,7 +326,7 @@ class TestCleanDatasetTask: mock_storage.delete.assert_not_called() def test_clean_dataset_task_success_with_documents_and_segments( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful dataset cleanup with documents and segments. @@ -433,7 +434,7 @@ class TestCleanDatasetTask: assert mock_storage.delete.call_count == 3 def test_clean_dataset_task_success_with_invalid_doc_form( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful dataset cleanup with invalid doc_form handling. @@ -493,7 +494,7 @@ class TestCleanDatasetTask: assert mock_factory.call_count == 4 def test_clean_dataset_task_error_handling_and_rollback( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling and rollback mechanism when database operations fail. @@ -542,7 +543,7 @@ class TestCleanDatasetTask: # This demonstrates the resilience of the cleanup process def test_clean_dataset_task_with_image_file_references( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test dataset cleanup with image file references in document segments. @@ -634,7 +635,7 @@ class TestCleanDatasetTask: mock_get_image_ids.assert_called_once() def test_clean_dataset_task_performance_with_large_dataset( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test dataset cleanup performance with large amounts of data. @@ -704,11 +705,9 @@ class TestCleanDatasetTask: binding.created_at = datetime.now() bindings.append(binding) - from extensions.ext_database import db - - db.session.add_all(metadata_items) - db.session.add_all(bindings) - db.session.commit() + db_session_with_containers.add_all(metadata_items) + db_session_with_containers.add_all(bindings) + db_session_with_containers.commit() # Measure cleanup performance import time @@ -772,7 +771,7 @@ class TestCleanDatasetTask: print(f"Average time per document: {cleanup_duration / len(documents):.3f} seconds") def test_clean_dataset_task_storage_exception_handling( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test dataset cleanup when storage operations fail. @@ -838,7 +837,7 @@ class TestCleanDatasetTask: # consistency in the database def test_clean_dataset_task_edge_cases_and_boundary_conditions( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test dataset cleanup with edge cases and boundary conditions. diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index 8785c948d1..ab9e5b639a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -13,8 +13,8 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session -from extensions.ext_database import db from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -34,7 +34,7 @@ class TestDisableSegmentFromIndexTask: mock_processor.clean.return_value = None yield mock_processor - def _create_test_account_and_tenant(self, db_session_with_containers) -> tuple[Account, Tenant]: + def _create_test_account_and_tenant(self, db_session_with_containers: Session) -> tuple[Account, Tenant]: """ Helper method to create a test account and tenant for testing. @@ -53,8 +53,8 @@ class TestDisableSegmentFromIndexTask: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant tenant = Tenant( @@ -62,8 +62,8 @@ class TestDisableSegmentFromIndexTask: status="normal", plan="basic", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join with owner role join = TenantAccountJoin( @@ -72,15 +72,15 @@ class TestDisableSegmentFromIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_dataset(self, tenant: Tenant, account: Account) -> Dataset: + def _create_test_dataset(self, db_session_with_containers: Session, tenant: Tenant, account: Account) -> Dataset: """ Helper method to create a test dataset. @@ -101,13 +101,18 @@ class TestDisableSegmentFromIndexTask: indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset def _create_test_document( - self, dataset: Dataset, tenant: Tenant, account: Account, doc_form: str = "text_model" + self, + db_session_with_containers: Session, + dataset: Dataset, + tenant: Tenant, + account: Account, + doc_form: str = "text_model", ) -> Document: """ Helper method to create a test document. @@ -140,13 +145,14 @@ class TestDisableSegmentFromIndexTask: tokens=500, completed_at=datetime.now(UTC), ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document def _create_test_segment( self, + db_session_with_containers: Session, document: Document, dataset: Dataset, tenant: Tenant, @@ -185,12 +191,12 @@ class TestDisableSegmentFromIndexTask: created_by=account.id, completed_at=datetime.now(UTC) if status == "completed" else None, ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() return segment - def test_disable_segment_success(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_success(self, db_session_with_containers: Session, mock_index_processor): """ Test successful segment disabling from index. @@ -202,9 +208,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Set up Redis cache indexing_cache_key = f"segment_{segment.id}_indexing" @@ -226,10 +232,10 @@ class TestDisableSegmentFromIndexTask: assert redis_client.get(indexing_cache_key) is None # Verify segment is still in database - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.id is not None - def test_disable_segment_not_found(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_not_found(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when segment is not found. @@ -251,7 +257,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_not_completed(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_not_completed(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when segment is not in completed status. @@ -262,9 +268,11 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data with non-completed segment account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account, status="indexing", enabled=True) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment( + db_session_with_containers, document, dataset, tenant, account, status="indexing", enabled=True + ) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -275,7 +283,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_no_dataset(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_no_dataset(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when segment has no associated dataset. @@ -286,13 +294,13 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Manually remove dataset association segment.dataset_id = "00000000-0000-0000-0000-000000000000" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -303,7 +311,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_no_document(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_no_document(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when segment has no associated document. @@ -314,13 +322,13 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Manually remove document association segment.document_id = "00000000-0000-0000-0000-000000000000" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -331,7 +339,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_document_disabled(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_document_disabled(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when document is disabled. @@ -342,12 +350,12 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data with disabled document account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) document.enabled = False - db.session.commit() + db_session_with_containers.commit() - segment = self._create_test_segment(document, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -358,7 +366,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_document_archived(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_document_archived(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when document is archived. @@ -369,12 +377,12 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data with archived document account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) document.archived = True - db.session.commit() + db_session_with_containers.commit() - segment = self._create_test_segment(document, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -385,7 +393,9 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_document_indexing_not_completed(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_document_indexing_not_completed( + self, db_session_with_containers: Session, mock_index_processor + ): """ Test handling when document indexing is not completed. @@ -396,12 +406,12 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data with incomplete indexing account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) document.indexing_status = "indexing" - db.session.commit() + db_session_with_containers.commit() - segment = self._create_test_segment(document, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -412,7 +422,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_index_processor_exception(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_index_processor_exception(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when index processor raises an exception. @@ -424,9 +434,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Set up Redis cache indexing_cache_key = f"segment_{segment.id}_indexing" @@ -449,13 +459,13 @@ class TestDisableSegmentFromIndexTask: assert call_args[0][1] == [segment.index_node_id] # Check index node IDs # Verify segment was re-enabled - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True # Verify Redis cache was still cleared assert redis_client.get(indexing_cache_key) is None - def test_disable_segment_different_doc_forms(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_different_doc_forms(self, db_session_with_containers: Session, mock_index_processor): """ Test disabling segments with different document forms. @@ -470,9 +480,11 @@ class TestDisableSegmentFromIndexTask: for doc_form in doc_forms: # Arrange: Create test data for each form account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account, doc_form=doc_form) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document( + db_session_with_containers, dataset, tenant, account, doc_form=doc_form + ) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Reset mock for each iteration mock_index_processor.reset_mock() @@ -489,7 +501,7 @@ class TestDisableSegmentFromIndexTask: assert call_args[0][0].id == dataset.id # Check dataset ID assert call_args[0][1] == [segment.index_node_id] # Check index node IDs - def test_disable_segment_redis_cache_handling(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_redis_cache_handling(self, db_session_with_containers: Session, mock_index_processor): """ Test Redis cache handling during segment disabling. @@ -500,9 +512,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Test with cache present indexing_cache_key = f"segment_{segment.id}_indexing" @@ -517,13 +529,13 @@ class TestDisableSegmentFromIndexTask: assert redis_client.get(indexing_cache_key) is None # Test with no cache present - segment2 = self._create_test_segment(document, dataset, tenant, account) + segment2 = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) result2 = disable_segment_from_index_task(segment2.id) # Assert: Verify task still works without cache assert result2 is None - def test_disable_segment_performance_timing(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_performance_timing(self, db_session_with_containers: Session, mock_index_processor): """ Test performance timing of segment disabling task. @@ -534,9 +546,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task and measure time start_time = time.perf_counter() @@ -548,7 +560,9 @@ class TestDisableSegmentFromIndexTask: execution_time = end_time - start_time assert execution_time < 5.0 # Should complete within 5 seconds - def test_disable_segment_database_session_management(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_database_session_management( + self, db_session_with_containers: Session, mock_index_processor + ): """ Test database session management during task execution. @@ -559,9 +573,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -570,10 +584,10 @@ class TestDisableSegmentFromIndexTask: assert result is None # Verify segment is still accessible (session was properly managed) - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.id is not None - def test_disable_segment_concurrent_execution(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_concurrent_execution(self, db_session_with_containers: Session, mock_index_processor): """ Test concurrent execution of segment disabling tasks. @@ -584,12 +598,12 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create multiple test segments account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) segments = [] for i in range(3): - segment = self._create_test_segment(document, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) segments.append(segment) # Act: Execute tasks concurrently (simulated) diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index a93a80e231..8f47b48ae2 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -9,6 +9,7 @@ The task is responsible for removing document segments from the search index whe from unittest.mock import MagicMock, patch from faker import Faker +from sqlalchemy.orm import Session from models import Account, Dataset, DocumentSegment from models import Document as DatasetDocument @@ -31,7 +32,7 @@ class TestDisableSegmentsFromIndexTask: and realistic testing environment with actual database interactions. """ - def _create_test_account(self, db_session_with_containers, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake=None): """ Helper method to create a test account with realistic data. @@ -79,7 +80,7 @@ class TestDisableSegmentsFromIndexTask: return account - def _create_test_dataset(self, db_session_with_containers, account, fake=None): + def _create_test_dataset(self, db_session_with_containers: Session, account, fake=None): """ Helper method to create a test dataset with realistic data. @@ -113,7 +114,7 @@ class TestDisableSegmentsFromIndexTask: return dataset - def _create_test_document(self, db_session_with_containers, dataset, account, fake=None): + def _create_test_document(self, db_session_with_containers: Session, dataset, account, fake=None): """ Helper method to create a test document with realistic data. @@ -158,7 +159,9 @@ class TestDisableSegmentsFromIndexTask: return document - def _create_test_segments(self, db_session_with_containers, document, dataset, account, count=3, fake=None): + def _create_test_segments( + self, db_session_with_containers: Session, document, dataset, account, count=3, fake=None + ): """ Helper method to create test document segments with realistic data. @@ -210,7 +213,7 @@ class TestDisableSegmentsFromIndexTask: return segments - def _create_dataset_process_rule(self, db_session_with_containers, dataset, fake=None): + def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake=None): """ Helper method to create a dataset process rule. @@ -239,14 +242,12 @@ class TestDisableSegmentsFromIndexTask: process_rule.created_by = dataset.created_by process_rule.updated_by = dataset.updated_by - from extensions.ext_database import db - - db.session.add(process_rule) - db.session.commit() + db_session_with_containers.add(process_rule) + db_session_with_containers.commit() return process_rule - def test_disable_segments_success(self, db_session_with_containers): + def test_disable_segments_success(self, db_session_with_containers: Session): """ Test successful disabling of segments from index. @@ -297,7 +298,7 @@ class TestDisableSegmentsFromIndexTask: expected_key = f"segment_{segment.id}_indexing" mock_redis.delete.assert_any_call(expected_key) - def test_disable_segments_dataset_not_found(self, db_session_with_containers): + def test_disable_segments_dataset_not_found(self, db_session_with_containers: Session): """ Test handling when dataset is not found. @@ -320,7 +321,7 @@ class TestDisableSegmentsFromIndexTask: # Redis should not be called when dataset is not found mock_redis.delete.assert_not_called() - def test_disable_segments_document_not_found(self, db_session_with_containers): + def test_disable_segments_document_not_found(self, db_session_with_containers: Session): """ Test handling when document is not found. @@ -344,7 +345,7 @@ class TestDisableSegmentsFromIndexTask: # Redis should not be called when document is not found mock_redis.delete.assert_not_called() - def test_disable_segments_document_invalid_status(self, db_session_with_containers): + def test_disable_segments_document_invalid_status(self, db_session_with_containers: Session): """ Test handling when document has invalid status for disabling. @@ -360,9 +361,8 @@ class TestDisableSegmentsFromIndexTask: # Test case 1: Document not enabled document.enabled = False - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() segment_ids = [segment.id for segment in segments] @@ -379,7 +379,7 @@ class TestDisableSegmentsFromIndexTask: # Test case 2: Document archived document.enabled = True document.archived = True - db.session.commit() + db_session_with_containers.commit() with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: # Act @@ -393,7 +393,7 @@ class TestDisableSegmentsFromIndexTask: document.enabled = True document.archived = False document.indexing_status = "indexing" - db.session.commit() + db_session_with_containers.commit() with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: # Act @@ -403,7 +403,7 @@ class TestDisableSegmentsFromIndexTask: assert result is None # Task should complete without returning a value mock_redis.delete.assert_not_called() - def test_disable_segments_no_segments_found(self, db_session_with_containers): + def test_disable_segments_no_segments_found(self, db_session_with_containers: Session): """ Test handling when no segments are found for the given IDs. @@ -430,7 +430,7 @@ class TestDisableSegmentsFromIndexTask: # Redis should not be called when no segments are found mock_redis.delete.assert_not_called() - def test_disable_segments_index_processor_error(self, db_session_with_containers): + def test_disable_segments_index_processor_error(self, db_session_with_containers: Session): """ Test handling when index processor encounters an error. @@ -464,13 +464,14 @@ class TestDisableSegmentsFromIndexTask: assert result is None # Task should complete without returning a value # Verify segments were rolled back to enabled state - from extensions.ext_database import db - db.session.refresh(segments[0]) - db.session.refresh(segments[1]) + db_session_with_containers.refresh(segments[0]) + db_session_with_containers.refresh(segments[1]) # Check that segments are re-enabled after error - updated_segments = db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all() + updated_segments = ( + db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all() + ) for segment in updated_segments: assert segment.enabled is True @@ -480,7 +481,7 @@ class TestDisableSegmentsFromIndexTask: # Verify Redis cache cleanup was still called assert mock_redis.delete.call_count == len(segments) - def test_disable_segments_with_different_doc_forms(self, db_session_with_containers): + def test_disable_segments_with_different_doc_forms(self, db_session_with_containers: Session): """ Test disabling segments with different document forms. @@ -503,9 +504,8 @@ class TestDisableSegmentsFromIndexTask: for doc_form in doc_forms: # Update document form document.doc_form = doc_form - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Mock the index processor factory with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: @@ -523,7 +523,7 @@ class TestDisableSegmentsFromIndexTask: assert result is None # Task should complete without returning a value mock_factory.assert_called_with(doc_form) - def test_disable_segments_performance_timing(self, db_session_with_containers): + def test_disable_segments_performance_timing(self, db_session_with_containers: Session): """ Test that the task properly measures and logs performance timing. @@ -568,7 +568,7 @@ class TestDisableSegmentsFromIndexTask: assert performance_log is not None assert "0.5" in performance_log # Should log the execution time - def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers): + def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers: Session): """ Test that Redis cache is properly cleaned up for all segments. @@ -610,7 +610,7 @@ class TestDisableSegmentsFromIndexTask: for expected_key in expected_keys: assert expected_key in actual_calls - def test_disable_segments_database_session_cleanup(self, db_session_with_containers): + def test_disable_segments_database_session_cleanup(self, db_session_with_containers: Session): """ Test that database session is properly closed after task execution. @@ -643,7 +643,7 @@ class TestDisableSegmentsFromIndexTask: assert result is None # Task should complete without returning a value # Session lifecycle is managed by context manager; no explicit close assertion - def test_disable_segments_empty_segment_ids(self, db_session_with_containers): + def test_disable_segments_empty_segment_ids(self, db_session_with_containers: Session): """ Test handling when empty segment IDs list is provided. @@ -669,7 +669,7 @@ class TestDisableSegmentsFromIndexTask: # Redis should not be called when no segments are provided mock_redis.delete.assert_not_called() - def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers): + def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers: Session): """ Test handling when some segment IDs are valid and others are invalid. diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index b2e1ce3b89..c61e37b1e9 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from core.indexing_runner import DocumentIsPausedError from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -282,7 +283,7 @@ class TestDuplicateDocumentIndexingTasks: return dataset, documents - def test_duplicate_document_indexing_task_success( + def _test_duplicate_document_indexing_task_success( self, db_session_with_containers, mock_external_service_dependencies ): """ @@ -324,7 +325,7 @@ class TestDuplicateDocumentIndexingTasks: processed_documents = call_args[0][0] # First argument should be documents list assert len(processed_documents) == 3 - def test_duplicate_document_indexing_task_with_segment_cleanup( + def _test_duplicate_document_indexing_task_with_segment_cleanup( self, db_session_with_containers, mock_external_service_dependencies ): """ @@ -374,7 +375,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner"].assert_called_once() mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() - def test_duplicate_document_indexing_task_dataset_not_found( + def _test_duplicate_document_indexing_task_dataset_not_found( self, db_session_with_containers, mock_external_service_dependencies ): """ @@ -445,7 +446,7 @@ class TestDuplicateDocumentIndexingTasks: processed_documents = call_args[0][0] # First argument should be documents list assert len(processed_documents) == 2 # Only existing documents - def test_duplicate_document_indexing_task_indexing_runner_exception( + def _test_duplicate_document_indexing_task_indexing_runner_exception( self, db_session_with_containers, mock_external_service_dependencies ): """ @@ -486,7 +487,7 @@ class TestDuplicateDocumentIndexingTasks: assert updated_document.indexing_status == "parsing" assert updated_document.processing_started_at is not None - def test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( + def _test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( self, db_session_with_containers, mock_external_service_dependencies ): """ @@ -549,7 +550,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify indexing runner was not called due to early validation error mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() - def test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded( + def _test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded( self, db_session_with_containers, mock_external_service_dependencies ): """ @@ -783,3 +784,90 @@ class TestDuplicateDocumentIndexingTasks: document_ids=document_ids, ) mock_queue.delete_task_key.assert_not_called() + + def test_successful_duplicate_document_indexing( + self, db_session_with_containers, mock_external_service_dependencies + ): + """Test successful duplicate document indexing flow.""" + self._test_duplicate_document_indexing_task_success( + db_session_with_containers, mock_external_service_dependencies + ) + + def test_duplicate_document_indexing_dataset_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """Test duplicate document indexing when dataset is not found.""" + self._test_duplicate_document_indexing_task_dataset_not_found( + db_session_with_containers, mock_external_service_dependencies + ) + + def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan( + self, db_session_with_containers, mock_external_service_dependencies + ): + """Test duplicate document indexing with billing enabled and sandbox plan.""" + self._test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( + db_session_with_containers, mock_external_service_dependencies + ) + + def test_duplicate_document_indexing_with_billing_limit_exceeded( + self, db_session_with_containers, mock_external_service_dependencies + ): + """Test duplicate document indexing when billing limit is exceeded.""" + self._test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded( + db_session_with_containers, mock_external_service_dependencies + ) + + def test_duplicate_document_indexing_runner_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """Test duplicate document indexing when IndexingRunner raises an error.""" + self._test_duplicate_document_indexing_task_indexing_runner_exception( + db_session_with_containers, mock_external_service_dependencies + ) + + def _test_duplicate_document_indexing_task_document_is_paused( + self, db_session_with_containers, mock_external_service_dependencies + ): + """Test duplicate document indexing when document is paused.""" + # Arrange + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + for document in documents: + document.is_paused = True + db_session_with_containers.add(document) + db_session_with_containers.commit() + + document_ids = [doc.id for doc in documents] + mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = DocumentIsPausedError( + "Document paused" + ) + + # Act + _duplicate_document_indexing_task(dataset.id, document_ids) + db_session_with_containers.expire_all() + + # Assert + for doc_id in document_ids: + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + assert updated_document.is_paused is True + assert updated_document.indexing_status == "parsing" + assert updated_document.display_status == "paused" + assert updated_document.processing_started_at is not None + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + def test_duplicate_document_indexing_document_is_paused( + self, db_session_with_containers, mock_external_service_dependencies + ): + """Test duplicate document indexing when document is paused.""" + self._test_duplicate_document_indexing_task_document_is_paused( + db_session_with_containers, mock_external_service_dependencies + ) + + def test_duplicate_document_indexing_cleans_old_segments( + self, db_session_with_containers, mock_external_service_dependencies + ): + """Test that duplicate document indexing cleans old segments.""" + self._test_duplicate_document_indexing_task_with_segment_cleanup( + db_session_with_containers, mock_external_service_dependencies + ) diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index b3d9e49b30..bc29395545 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -2,9 +2,9 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -31,7 +31,9 @@ class TestEnableSegmentsToIndexTask: "index_processor": mock_processor, } - def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_dataset_and_document( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Helper method to create a test dataset and document for testing. @@ -51,15 +53,15 @@ class TestEnableSegmentsToIndexTask: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -68,8 +70,8 @@ class TestEnableSegmentsToIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create dataset dataset = Dataset( @@ -81,8 +83,8 @@ class TestEnableSegmentsToIndexTask: indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Create document document = Document( @@ -99,16 +101,16 @@ class TestEnableSegmentsToIndexTask: enabled=True, doc_form=IndexStructureType.PARAGRAPH_INDEX, ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property works correctly - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) return dataset, document def _create_test_segments( - self, db_session_with_containers, document, dataset, count=3, enabled=False, status="completed" + self, db_session_with_containers: Session, document, dataset, count=3, enabled=False, status="completed" ): """ Helper method to create test document segments. @@ -144,14 +146,14 @@ class TestEnableSegmentsToIndexTask: status=status, created_by=document.created_by, ) - db.session.add(segment) + db_session_with_containers.add(segment) segments.append(segment) - db.session.commit() + db_session_with_containers.commit() return segments def test_enable_segments_to_index_with_different_index_type( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segments indexing with different index types. @@ -169,10 +171,10 @@ class TestEnableSegmentsToIndexTask: # Update document to use different index type document.doc_form = IndexStructureType.QA_INDEX - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property reflects the updated document - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) # Create segments segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -204,7 +206,7 @@ class TestEnableSegmentsToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_enable_segments_to_index_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent dataset. @@ -229,7 +231,7 @@ class TestEnableSegmentsToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_enable_segments_to_index_document_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent document. @@ -256,7 +258,7 @@ class TestEnableSegmentsToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_enable_segments_to_index_invalid_document_status( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of document with invalid status. @@ -284,12 +286,12 @@ class TestEnableSegmentsToIndexTask: document.enabled = True document.archived = False document.indexing_status = "completed" - db.session.commit() + db_session_with_containers.commit() # Set invalid status for attr, value in status_attrs.items(): setattr(document, attr, value) - db.session.commit() + db_session_with_containers.commit() # Create segments segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -304,11 +306,11 @@ class TestEnableSegmentsToIndexTask: # Clean up segments for next iteration for segment in segments: - db.session.delete(segment) - db.session.commit() + db_session_with_containers.delete(segment) + db_session_with_containers.commit() def test_enable_segments_to_index_segments_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling when no segments are found. @@ -338,7 +340,7 @@ class TestEnableSegmentsToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_enable_segments_to_index_with_parent_child_structure( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segments indexing with parent-child structure. @@ -357,10 +359,10 @@ class TestEnableSegmentsToIndexTask: # Update document to use parent-child index type document.doc_form = IndexStructureType.PARENT_CHILD_INDEX - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property reflects the updated document - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) # Create segments with mock child chunks segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -410,7 +412,7 @@ class TestEnableSegmentsToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_enable_segments_to_index_general_exception_handling( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test general exception handling during indexing process. @@ -443,7 +445,7 @@ class TestEnableSegmentsToIndexTask: # Assert: Verify error handling for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is False assert segment.status == "error" assert segment.error is not None diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py index 6c3a9ef20a..ff72232d12 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py @@ -2,8 +2,8 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session -from extensions.ext_database import db from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from tasks.mail_account_deletion_task import send_account_deletion_verification_code, send_deletion_success_task @@ -30,7 +30,7 @@ class TestMailAccountDeletionTask: "email_service": mock_email_service, } - def _create_test_account(self, db_session_with_containers): + def _create_test_account(self, db_session_with_containers: Session): """ Helper method to create a test account for testing. @@ -49,16 +49,16 @@ class TestMailAccountDeletionTask: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -67,12 +67,14 @@ class TestMailAccountDeletionTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() return account - def test_send_deletion_success_task_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_deletion_success_task_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful account deletion success email sending. @@ -109,7 +111,7 @@ class TestMailAccountDeletionTask: ) def test_send_deletion_success_task_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account deletion success email when mail service is not initialized. @@ -132,7 +134,7 @@ class TestMailAccountDeletionTask: mock_external_service_dependencies["email_service"].send_email.assert_not_called() def test_send_deletion_success_task_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account deletion success email when email service raises exception. @@ -154,7 +156,7 @@ class TestMailAccountDeletionTask: mock_external_service_dependencies["email_service"].send_email.assert_called_once() def test_send_account_deletion_verification_code_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful account deletion verification code email sending. @@ -193,7 +195,7 @@ class TestMailAccountDeletionTask: ) def test_send_account_deletion_verification_code_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account deletion verification code email when mail service is not initialized. @@ -217,7 +219,7 @@ class TestMailAccountDeletionTask: mock_external_service_dependencies["email_service"].send_email.assert_not_called() def test_send_account_deletion_verification_code_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account deletion verification code email when email service raises exception. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index 5fd6c56f7a..0876a39f82 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -9,8 +9,8 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.nodes.human_input.entities import ( +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, @@ -18,7 +18,7 @@ from core.workflow.nodes.human_input.entities import ( HumanInputNodeData, MemberRecipient, ) -from core.workflow.runtime import GraphRuntimeState, VariablePool +from dify_graph.runtime import GraphRuntimeState, VariablePool from extensions.ext_storage import storage from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom @@ -96,8 +96,7 @@ def _build_form(db_session_with_containers, tenant, account, *, app_id: str, wor delivery_methods=[delivery_method], ) - engine = db_session_with_containers.get_bind() - repo = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + repo = HumanInputFormRepositoryImpl(tenant_id=tenant.id) params = FormCreateParams( app_id=app_id, workflow_execution_id=workflow_execution_id, diff --git a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py index b9977b1fb6..ef7191299a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py +++ b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py @@ -4,11 +4,11 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity from core.rag.pipeline.queue import TenantIsolatedTaskQueue -from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Pipeline from models.workflow import Workflow @@ -52,7 +52,7 @@ class TestRagPipelineRunTasks: "delete_file": mock_delete_file, } - def _create_test_pipeline_and_workflow(self, db_session_with_containers): + def _create_test_pipeline_and_workflow(self, db_session_with_containers: Session): """ Helper method to create test pipeline and workflow for testing. @@ -71,15 +71,15 @@ class TestRagPipelineRunTasks: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -88,8 +88,8 @@ class TestRagPipelineRunTasks: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create workflow workflow = Workflow( @@ -107,8 +107,8 @@ class TestRagPipelineRunTasks: conversation_variables=[], rag_pipeline_variables=[], ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create pipeline pipeline = Pipeline( @@ -119,14 +119,14 @@ class TestRagPipelineRunTasks: created_by=account.id, ) pipeline.id = str(uuid.uuid4()) - db.session.add(pipeline) - db.session.commit() + db_session_with_containers.add(pipeline) + db_session_with_containers.commit() # Refresh entities to ensure they're properly loaded - db.session.refresh(account) - db.session.refresh(tenant) - db.session.refresh(workflow) - db.session.refresh(pipeline) + db_session_with_containers.refresh(account) + db_session_with_containers.refresh(tenant) + db_session_with_containers.refresh(workflow) + db_session_with_containers.refresh(pipeline) return account, tenant, pipeline, workflow @@ -209,7 +209,7 @@ class TestRagPipelineRunTasks: return json.dumps(entities_data) def test_priority_rag_pipeline_run_task_success( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test successful priority RAG pipeline run task execution. @@ -254,7 +254,7 @@ class TestRagPipelineRunTasks: assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) def test_rag_pipeline_run_task_success( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test successful regular RAG pipeline run task execution. @@ -299,7 +299,7 @@ class TestRagPipelineRunTasks: assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) def test_priority_rag_pipeline_run_task_with_waiting_tasks( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test priority RAG pipeline run task with waiting tasks in queue using real Redis. @@ -351,7 +351,7 @@ class TestRagPipelineRunTasks: assert len(remaining_tasks) == 1 # 2 original - 1 pulled = 1 remaining def test_rag_pipeline_run_task_legacy_compatibility( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test regular RAG pipeline run task with legacy Redis queue format for backward compatibility. @@ -419,7 +419,7 @@ class TestRagPipelineRunTasks: redis_client.delete(legacy_task_key) def test_rag_pipeline_run_task_with_waiting_tasks( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test regular RAG pipeline run task with waiting tasks in queue using real Redis. @@ -469,7 +469,7 @@ class TestRagPipelineRunTasks: assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining def test_priority_rag_pipeline_run_task_error_handling( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test error handling in priority RAG pipeline run task using real Redis. @@ -526,7 +526,7 @@ class TestRagPipelineRunTasks: assert len(remaining_tasks) == 0 def test_rag_pipeline_run_task_error_handling( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test error handling in regular RAG pipeline run task using real Redis. @@ -581,7 +581,7 @@ class TestRagPipelineRunTasks: assert len(remaining_tasks) == 0 def test_priority_rag_pipeline_run_task_tenant_isolation( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test tenant isolation in priority RAG pipeline run task using real Redis. @@ -648,7 +648,7 @@ class TestRagPipelineRunTasks: assert queue1._task_key != queue2._task_key def test_rag_pipeline_run_task_tenant_isolation( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test tenant isolation in regular RAG pipeline run task using real Redis. @@ -713,7 +713,7 @@ class TestRagPipelineRunTasks: assert queue1._task_key != queue2._task_key def test_run_single_rag_pipeline_task_success( - self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers ): """ Test successful run_single_rag_pipeline_task execution. @@ -748,7 +748,7 @@ class TestRagPipelineRunTasks: assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) def test_run_single_rag_pipeline_task_entity_validation_error( - self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers ): """ Test run_single_rag_pipeline_task with invalid entity data. @@ -793,7 +793,7 @@ class TestRagPipelineRunTasks: mock_pipeline_generator.assert_not_called() def test_run_single_rag_pipeline_task_database_entity_not_found( - self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers ): """ Test run_single_rag_pipeline_task with non-existent database entities. @@ -838,7 +838,7 @@ class TestRagPipelineRunTasks: mock_pipeline_generator.assert_not_called() def test_priority_rag_pipeline_run_task_file_not_found( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test priority RAG pipeline run task with non-existent file. @@ -888,7 +888,7 @@ class TestRagPipelineRunTasks: assert len(remaining_tasks) == 0 def test_rag_pipeline_run_task_file_not_found( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test regular RAG pipeline run task with non-existent file. diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index 8501a8e39b..182c9ef882 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -4,8 +4,8 @@ from unittest.mock import ANY, call, patch import pytest from core.db.session_factory import session_factory -from core.workflow.variables.segments import StringSegment -from core.workflow.variables.types import SegmentType +from dify_graph.variables.segments import StringSegment +from dify_graph.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from models import Tenant from models.enums import CreatorUserRole diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index 5f4f28cf4f..ca76fa0a4b 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -27,8 +27,8 @@ import pytest from sqlalchemy import delete, select from sqlalchemy.orm import Session, selectinload, sessionmaker -from core.workflow.entities import WorkflowExecution -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.entities import WorkflowExecution +from dify_graph.enums import WorkflowExecutionStatus from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models import Account diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index 604d68f257..7bfc6c9e13 100644 --- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -18,7 +18,7 @@ from core.trigger.debug import event_selectors from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key -from core.workflow.enums import NodeType +from dify_graph.enums import NodeType from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py index f9788e2e50..83601dc1b9 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -10,10 +10,10 @@ from flask import Flask from controllers.console import wraps as console_wraps from controllers.console.app import workflow_run as workflow_run_module from controllers.web.error import NotFoundError -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.nodes.human_input.entities import FormInput, UserAction -from core.workflow.nodes.human_input.enums import FormInputType +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.nodes.human_input.entities import FormInput, UserAction +from dify_graph.nodes.human_input.enums import FormInputType from libs import login as login_lib from models.account import Account, AccountStatus, TenantAccountRole from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index cf10182ad3..f34702a257 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -13,8 +13,8 @@ from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, _serialize_full_content, ) -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.variables.types import SegmentType +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from dify_graph.variables.types import SegmentType from factories.variable_factory import build_segment from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 @@ -310,8 +310,8 @@ def test_workflow_node_variables_fields(): def test_workflow_file_variable_with_signed_url(): """Test that File type variables include signed URLs in API responses.""" - from core.workflow.file.enums import FileTransferMethod, FileType - from core.workflow.file.models import File + from dify_graph.file.enums import FileTransferMethod, FileType + from dify_graph.file.models import File # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) test_file = File( @@ -368,8 +368,8 @@ def test_workflow_file_variable_with_signed_url(): def test_workflow_file_variable_remote_url(): """Test that File type variables with REMOTE_URL transfer method return the remote URL.""" - from core.workflow.file.enums import FileTransferMethod, FileType - from core.workflow.file.models import File + from dify_graph.file.enums import FileTransferMethod, FileType + from dify_graph.file.models import File # Create a File object with REMOTE_URL transfer method test_file = File( diff --git a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py index 59b6614d5e..f2e57eb65f 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py @@ -13,8 +13,8 @@ from flask import Flask from flask.views import MethodView from werkzeug.exceptions import Forbidden -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index b70e70105c..1923ab7fa7 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -29,7 +29,7 @@ from controllers.service_api.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py index c5b1cbc127..4e4482f704 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py @@ -34,7 +34,7 @@ from controllers.service_api.app.error import ( NotChatAppError, ) from core.errors.error import QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index 0eb3854c84..4eada73b82 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -35,7 +35,7 @@ from controllers.service_api.app.workflow import ( WorkflowTaskStopApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.enums import WorkflowExecutionStatus from models.model import App, AppMode from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError @@ -315,7 +315,7 @@ class TestWorkflowStopMechanism: def test_graph_engine_manager_has_send_stop_command(self): """Test GraphEngineManager has send_stop_command method.""" - from core.workflow.graph_engine.manager import GraphEngineManager + from dify_graph.graph_engine.manager import GraphEngineManager assert hasattr(GraphEngineManager, "send_stop_command") diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py index fcaa61a871..9e95f45a0a 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py @@ -1,7 +1,7 @@ from types import SimpleNamespace from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.enums import WorkflowExecutionStatus def test_workflow_run_status_field_with_enum() -> None: diff --git a/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py b/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py index 4a613e35b0..ba8c903f65 100644 --- a/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py +++ b/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py @@ -3,7 +3,7 @@ from collections.abc import Generator from core.agent.entities import AgentScratchpadUnit from core.agent.output_parser.cot_output_parser import CotAgentOutputParser -from core.model_runtime.entities.llm_entities import AssistantPromptMessage, LLMResultChunk, LLMResultChunkDelta +from dify_graph.model_runtime.entities.llm_entities import AssistantPromptMessage, LLMResultChunk, LLMResultChunkDelta def mock_llm_response(text) -> Generator[LLMResultChunk, None, None]: diff --git a/api/tests/unit_tests/core/agent/patterns/test_base.py b/api/tests/unit_tests/core/agent/patterns/test_base.py index b0e0d44940..3ab34b8d67 100644 --- a/api/tests/unit_tests/core/agent/patterns/test_base.py +++ b/api/tests/unit_tests/core/agent/patterns/test_base.py @@ -4,10 +4,10 @@ from decimal import Decimal from unittest.mock import MagicMock import pytest +from core.model_runtime.entities.llm_entities import LLMUsage from core.agent.entities import AgentLog, ExecutionContext from core.agent.patterns.base import AgentPattern -from core.model_runtime.entities.llm_entities import LLMUsage class ConcreteAgentPattern(AgentPattern): diff --git a/api/tests/unit_tests/core/agent/patterns/test_function_call.py b/api/tests/unit_tests/core/agent/patterns/test_function_call.py index 6b3600dbbf..3bffa9dfff 100644 --- a/api/tests/unit_tests/core/agent/patterns/test_function_call.py +++ b/api/tests/unit_tests/core/agent/patterns/test_function_call.py @@ -4,8 +4,6 @@ from decimal import Decimal from unittest.mock import MagicMock import pytest - -from core.agent.entities import AgentLog, ExecutionContext from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.message_entities import ( PromptMessageTool, @@ -13,6 +11,8 @@ from core.model_runtime.entities.message_entities import ( UserPromptMessage, ) +from core.agent.entities import AgentLog, ExecutionContext + @pytest.fixture def mock_model_instance(): diff --git a/api/tests/unit_tests/core/agent/patterns/test_strategy_factory.py b/api/tests/unit_tests/core/agent/patterns/test_strategy_factory.py index 07b9df2acf..ca5931cdca 100644 --- a/api/tests/unit_tests/core/agent/patterns/test_strategy_factory.py +++ b/api/tests/unit_tests/core/agent/patterns/test_strategy_factory.py @@ -3,12 +3,12 @@ from unittest.mock import MagicMock import pytest +from core.model_runtime.entities.model_entities import ModelFeature from core.agent.entities import AgentEntity, ExecutionContext from core.agent.patterns.function_call import FunctionCallStrategy from core.agent.patterns.react import ReActStrategy from core.agent.patterns.strategy_factory import StrategyFactory -from core.model_runtime.entities.model_entities import ModelFeature @pytest.fixture diff --git a/api/tests/unit_tests/core/agent/test_agent_app_runner.py b/api/tests/unit_tests/core/agent/test_agent_app_runner.py index 06584fa986..e2ff344260 100644 --- a/api/tests/unit_tests/core/agent/test_agent_app_runner.py +++ b/api/tests/unit_tests/core/agent/test_agent_app_runner.py @@ -4,10 +4,10 @@ from decimal import Decimal from unittest.mock import MagicMock, patch import pytest +from core.model_runtime.entities.llm_entities import LLMUsage from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentResult from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage -from core.model_runtime.entities.llm_entities import LLMUsage class TestOrganizePromptMessages: diff --git a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py index 9dddb18595..de99833aac 100644 --- a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py +++ b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py @@ -1,6 +1,6 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.model_runtime.entities.message_entities import ImagePromptMessageContent -from core.workflow.file.models import FileTransferMethod, FileUploadConfig, ImageConfig +from dify_graph.file.models import FileTransferMethod, FileUploadConfig, ImageConfig +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent def test_convert_with_vision(): diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 0ca54a2f4a..12ab587564 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom -from core.workflow.variables import SegmentType +from dify_graph.variables import SegmentType from factories import variable_factory from models import ConversationVariable, Workflow diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py index a94b5445f7..be773557f6 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py @@ -10,7 +10,7 @@ import pytest from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent -from core.workflow.entities.pause_reason import HumanInputRequired +from dify_graph.entities.pause_reason import HumanInputRequired from models.enums import MessageStatus from models.execution_extra_content import HumanInputContent from models.model import EndUser diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py index 1931e230b2..67b3777c40 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -9,8 +9,8 @@ from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueMessageFileEvent -from core.model_runtime.entities.message_entities import ImagePromptMessageContent -from core.workflow.file.enums import FileTransferMethod, FileType +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py index cd5ea8986a..b0789bbc1e 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py +++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py @@ -3,9 +3,9 @@ from types import SimpleNamespace import pytest from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport -from core.workflow.runtime import GraphRuntimeState -from core.workflow.runtime.variable_pool import VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.runtime import GraphRuntimeState +from dify_graph.runtime.variable_pool import VariablePool +from dify_graph.system_variable import SystemVariable def _make_state(workflow_run_id: str | None) -> GraphRuntimeState: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 5508a117c1..72430a3347 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -1,8 +1,8 @@ from collections.abc import Mapping, Sequence from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter -from core.workflow.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType -from core.workflow.variables.segments import ArrayFileSegment, FileSegment +from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from dify_graph.variables.segments import ArrayFileSegment, FileSegment class TestWorkflowResponseConverterFetchFilesFromVariableValue: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py index 1c36b4d12b..4ed7d73cd0 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -4,9 +4,9 @@ from types import SimpleNamespace from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable def _build_converter(): diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py index 0a9794e41c..5879e8fb9b 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -2,9 +2,9 @@ from types import SimpleNamespace from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable def _build_converter() -> WorkflowResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index d25bff92dc..69d476bd13 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -23,9 +23,9 @@ from core.app.entities.queue_entities import ( QueueNodeStartedEvent, QueueNodeSucceededEvent, ) -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import NodeType -from core.workflow.system_variable import SystemVariable +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.enums import NodeType +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from models import Account from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index 04c8696525..43a97ae098 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,7 +1,7 @@ import pytest from core.app.apps.base_app_generator import BaseAppGenerator -from core.workflow.variables.input_entities import VariableEntity, VariableEntityType +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType def test_validate_inputs_with_zero(): diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index 97c993928e..44af89601c 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -8,32 +8,32 @@ API_DIR = str(Path(__file__).resolve().parents[5]) if API_DIR not in sys.path: sys.path.insert(0, API_DIR) -import core.workflow.nodes.human_input.entities # noqa: F401 +import dify_graph.nodes.human_input.entities # noqa: F401 from core.app.apps.advanced_chat import app_generator as adv_app_gen_module from core.app.apps.workflow import app_generator as wf_app_gen_module from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.entities.pause_reason import SchedulingPause -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_events import ( +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.entities.pause_reason import SchedulingPause +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.graph import Graph +from dify_graph.graph_engine import GraphEngine +from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from dify_graph.graph_events import ( GraphEngineEvent, GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunSucceededEvent, ) -from core.workflow.node_events import NodeRunResult, PauseRequestedEvent -from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity, RetryConfig -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.node_events import NodeRunResult, PauseRequestedEvent +from dify_graph.nodes.base.entities import BaseNodeData, OutputVariableEntity, RetryConfig +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params if "core.ops.ops_trace_manager" not in sys.modules: ops_stub = ModuleType("core.ops.ops_trace_manager") @@ -142,11 +142,11 @@ def _build_graph_config(*, pause_on: str | None) -> dict[str, object]: def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> Graph: graph_config = _build_graph_config(pause_on=pause_on) - params = GraphInitParams( - tenant_id="tenant", - app_id="app", + params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="service-api", diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py index f4efb240c0..1388279221 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py @@ -4,8 +4,8 @@ import pytest from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.queue_entities import QueueWorkflowPausedEvent -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.graph_events.graph import GraphRunPausedEvent +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.graph_events.graph import GraphRunPausedEvent class _DummyQueueManager: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index f5903d28bd..2e0715e974 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -8,8 +8,8 @@ import pytest from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py index c30b925d88..65c6bd6654 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -10,12 +10,12 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueWorkflowPausedEvent from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph_events.graph import GraphRunPausedEvent -from core.workflow.nodes.human_input.entities import FormInput, UserAction -from core.workflow.nodes.human_input.enums import FormInputType -from core.workflow.system_variable import SystemVariable +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.graph_events.graph import GraphRunPausedEvent +from dify_graph.nodes.human_input.entities import FormInput, UserAction +from dify_graph.nodes.human_input.enums import FormInputType +from dify_graph.system_variable import SystemVariable from models.account import Account diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py index 32cb1ed47c..5b23e71035 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -7,9 +7,9 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import QueueWorkflowStartedEvent -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from models.account import Account from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py index d3ae577d0d..7d0e1d25f6 100644 --- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -3,16 +3,16 @@ from datetime import datetime from unittest.mock import Mock from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph_engine.protocols.command_channel import CommandChannel -from core.workflow.graph_events.node import NodeRunSucceededEvent -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.variable_assigner.common import helpers as common_helpers -from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import StringVariable -from core.workflow.variables.segments import Segment +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.graph_engine.protocols.command_channel import CommandChannel +from dify_graph.graph_events.node import NodeRunSucceededEvent +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.variable_assigner.common import helpers as common_helpers +from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import StringVariable +from dify_graph.variables.segments import Segment class MockReadOnlyVariablePool: diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 539f0cb581..035f0ee05c 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -13,17 +13,17 @@ from core.app.layers.pause_state_persist_layer import ( _AdvancedChatAppGenerateEntityWrapper, _WorkflowGenerateEntityWrapper, ) -from core.workflow.entities.pause_reason import SchedulingPause -from core.workflow.graph_engine.entities.commands import GraphEngineCommand -from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from core.workflow.graph_events.graph import ( +from dify_graph.entities.pause_reason import SchedulingPause +from dify_graph.graph_engine.entities.commands import GraphEngineCommand +from dify_graph.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from dify_graph.graph_events.graph import ( GraphRunFailedEvent, GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, ) -from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool -from core.workflow.variables.segments import Segment +from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool +from dify_graph.variables.segments import Segment from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py index 40f58c9ddf..13fbca6e26 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -25,9 +25,9 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher -from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from core.model_runtime.entities.message_entities import TextPromptMessageContent from core.ops.ops_trace_manager import TraceQueueManager +from dify_graph.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index 9ee1df8bdc..52c91fb8c9 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -3,8 +3,8 @@ from collections.abc import Generator from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/unit_tests/core/file/test_file_manager.py b/api/tests/unit_tests/core/file/test_file_manager.py index aa387c9aad..65707c74fc 100644 --- a/api/tests/unit_tests/core/file/test_file_manager.py +++ b/api/tests/unit_tests/core/file/test_file_manager.py @@ -3,13 +3,14 @@ from unittest.mock import patch from core.model_runtime.entities.message_entities import ImagePromptMessageContent -from core.workflow.file import File, FileTransferMethod, FileType from core.workflow.file.file_manager import ( _encode_file_ref, restore_multimodal_content, to_prompt_message_content, ) +from core.workflow.file import File, FileTransferMethod, FileType + class TestEncodeFileRef: """Tests for _encode_file_ref function.""" diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py index 4d4ccc2672..deebf41320 100644 --- a/api/tests/unit_tests/core/file/test_models.py +++ b/api/tests/unit_tests/core/file/test_models.py @@ -1,4 +1,4 @@ -from core.workflow.file import File, FileTransferMethod, FileType +from dify_graph.file import File, FileTransferMethod, FileType def test_file(): diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_file_ref.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_file_ref.py index c8545d88cf..f000b3d5ef 100644 --- a/api/tests/unit_tests/core/llm_generator/output_parser/test_file_ref.py +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_file_ref.py @@ -3,6 +3,7 @@ Unit tests for sandbox file path detection and conversion. """ import pytest +from core.workflow.variables.segments import ArrayFileSegment, FileSegment from core.llm_generator.output_parser.file_ref import ( FILE_PATH_DESCRIPTION_SUFFIX, @@ -13,7 +14,6 @@ from core.llm_generator.output_parser.file_ref import ( is_file_path_property, ) from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.variables.segments import ArrayFileSegment, FileSegment def _build_file(file_id: str) -> File: diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index 40a7700394..f982765b1a 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -18,7 +18,7 @@ from core.mcp.server.streamable_http import ( prepare_tool_arguments, process_mapping_response, ) -from core.workflow.variables.input_entities import VariableEntity, VariableEntityType +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py b/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py index 5fbdabceed..d42b7ca0d9 100644 --- a/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py +++ b/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py @@ -2,8 +2,8 @@ from unittest.mock import MagicMock, patch import pytest -from core.model_runtime.entities.message_entities import AssistantPromptMessage -from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call +from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage +from dify_graph.model_runtime.model_providers.__base.large_language_model import _increase_tool_call ToolCall = AssistantPromptMessage.ToolCall @@ -97,7 +97,9 @@ def test__increase_tool_call(): # case 4: mock_id_generator = MagicMock() mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4] - with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator): + with patch( + "dify_graph.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator + ): _run_case(INPUTS_CASE_4, EXPECTED_CASE_4) @@ -107,6 +109,6 @@ def test__increase_tool_call__no_id_no_name_first_delta_should_raise(): ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='"value"}')), ] actual: list[ToolCall] = [] - with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", MagicMock()): + with patch("dify_graph.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", MagicMock()): with pytest.raises(ValueError): _increase_tool_call(inputs, actual) diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py b/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py index 09d527cb12..8dcfd10ec6 100644 --- a/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py +++ b/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py @@ -1,10 +1,10 @@ -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import ( +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_plugin_result +from dify_graph.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_plugin_result def _make_chunk( diff --git a/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py b/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py index c10f7b89c3..4e435cb4c6 100644 --- a/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py +++ b/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py @@ -2,7 +2,7 @@ from decimal import Decimal -from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata +from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata class TestLLMUsage: diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py index 4f398ce66e..32389b4d64 100644 --- a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py +++ b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py @@ -1,7 +1,7 @@ from openinference.semconv.trace import OpenInferenceSpanKindValues from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind -from core.workflow.enums import NodeType +from dify_graph.enums import NodeType class TestGetNodeSpanKind: diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index 9e871fcb74..4f038d4a5b 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -19,14 +19,6 @@ import httpx import pytest from pydantic import BaseModel -from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.plugin.entities.plugin_daemon import ( CredentialType, PluginDaemonInnerError, @@ -44,6 +36,14 @@ from core.plugin.impl.exc import ( ) from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.tool import PluginToolManager +from dify_graph.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError class TestPluginRuntimeExecution: diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 1d25639343..3e184cbf21 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -5,16 +5,16 @@ import pytest from configs import dify_config from core.app.app_config.entities import ModelConfigEntity from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import ( +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessageRole, UserPromptMessage, ) -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.file import File, FileTransferMethod, FileType from models.model import Conversation @@ -142,7 +142,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) - with patch("core.workflow.file.file_manager.to_prompt_message_content", autospec=True) as mock_get_encoded_string: + with patch("dify_graph.file.file_manager.to_prompt_message_content", autospec=True) as mock_get_encoded_string: mock_get_encoded_string.return_value = ImagePromptMessageContent( url=str(files[0].remote_url), format="jpg", mime_type="image/jpg" ) diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index d157a41d2c..634703740c 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -5,14 +5,14 @@ from core.app.entities.app_invoke_entities import ( ) from core.entities.provider_configuration import ProviderModelBundle from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import ( +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, SystemPromptMessage, ToolPromptMessage, UserPromptMessage, ) -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_prompt_message.py b/api/tests/unit_tests/core/prompt/test_prompt_message.py index e5da51d733..4136816562 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_message.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_message.py @@ -1,4 +1,4 @@ -from core.model_runtime.entities.message_entities import ( +from dify_graph.model_runtime.entities.message_entities import ( ImagePromptMessageContent, TextPromptMessageContent, UserPromptMessage, diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 16896a0c6c..7976120547 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -2,10 +2,10 @@ # from core.app.app_config.entities import ModelConfigEntity # from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle -# from core.model_runtime.entities.message_entities import UserPromptMessage -# from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule -# from core.model_runtime.entities.provider_entities import ProviderEntity -# from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +# from dify_graph.model_runtime.entities.message_entities import UserPromptMessage +# from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule +# from dify_graph.model_runtime.entities.provider_entities import ProviderEntity +# from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel # from core.prompt.prompt_transform import PromptTransform diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index c822ecbe78..2ef66e8a96 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -2,8 +2,8 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage from core.prompt.simple_prompt_transform import SimplePromptTransform +from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage from models.model import AppMode, Conversation diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 63596bc320..6e71f0c61f 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -52,14 +52,14 @@ import pytest from sqlalchemy.exc import IntegrityError from core.entities.embedding_type import EmbeddingInputType -from core.model_runtime.entities.model_entities import ModelPropertyKey -from core.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage -from core.model_runtime.errors.invoke import ( +from core.rag.embedding.cached_embedding import CacheEmbedding +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage +from dify_graph.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError, ) -from core.rag.embedding.cached_embedding import CacheEmbedding from models.dataset import Embedding diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index c00fee8fe5..b011ade884 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -61,9 +61,9 @@ from core.indexing_runner import ( DocumentIsPausedError, IndexingRunner, ) -from core.model_runtime.entities.model_entities import ModelType from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import ChildDocument, Document +from dify_graph.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index e4597e7f8c..0e53482c51 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -17,13 +17,13 @@ from unittest.mock import MagicMock, Mock, patch import pytest from core.model_manager import ModelInstance -from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult from core.rag.models.document import Document from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.weight_rerank import WeightRerankRunner +from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult def create_mock_model_instance(): diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py index 4bc802dc23..682a451117 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py @@ -5,8 +5,8 @@ import pytest from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.workflow.nodes.knowledge_retrieval import exc -from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest +from dify_graph.nodes.knowledge_retrieval import exc +from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest from models.dataset import Dataset # ==================== Helper Functions ==================== diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py index e6d0371cd5..e7eecfa297 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -11,7 +11,7 @@ from uuid import uuid4 import pytest from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository -from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType +from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index f6211f4cca..b613573927 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -11,12 +11,12 @@ from uuid import uuid4 import pytest from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from core.workflow.entities.workflow_node_execution import ( +from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from core.workflow.enums import NodeType -from core.workflow.repositories.workflow_node_execution_repository import OrderConfig +from dify_graph.enums import NodeType +from dify_graph.repositories.workflow_node_execution_repository import OrderConfig from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py index 7f1e2c5e5b..fe9eed0307 100644 --- a/api/tests/unit_tests/core/repositories/test_factory.py +++ b/api/tests/unit_tests/core/repositories/test_factory.py @@ -12,8 +12,8 @@ from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 811ed2143b..9af4d12664 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -5,7 +5,6 @@ from __future__ import annotations import dataclasses from datetime import datetime from types import SimpleNamespace -from unittest.mock import MagicMock import pytest @@ -15,7 +14,7 @@ from core.repositories.human_input_repository import ( HumanInputFormSubmissionRepository, _WorkspaceMemberInfo, ) -from core.workflow.nodes.human_input.entities import ( +from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, @@ -24,7 +23,7 @@ from core.workflow.nodes.human_input.entities import ( MemberRecipient, UserAction, ) -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import ( EmailExternalRecipientPayload, @@ -35,7 +34,7 @@ from models.human_input import ( def _build_repository() -> HumanInputFormRepositoryImpl: - return HumanInputFormRepositoryImpl(session_factory=MagicMock(), tenant_id="tenant-id") + return HumanInputFormRepositoryImpl(tenant_id="tenant-id") def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleNamespace]: @@ -389,8 +388,21 @@ def _session_factory(session: _FakeSession): return _factory +def _patch_repo_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None: + """Patch repository's global session factory to return our fake session. + + The repositories under test now use a global session factory; patch its + create_session method so unit tests don't hit a real database. + """ + monkeypatch.setattr( + "core.repositories.human_input_repository.session_factory.create_session", + _session_factory(session), + raising=True, + ) + + class TestHumanInputFormRepositoryImplPublicMethods: - def test_get_form_returns_entity_and_recipients(self): + def test_get_form_returns_entity_and_recipients(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -408,7 +420,8 @@ class TestHumanInputFormRepositoryImplPublicMethods: access_token="token-123", ) session = _FakeSession(scalars_results=[form, [recipient]]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") entity = repo.get_form(form.workflow_run_id, form.node_id) @@ -418,13 +431,14 @@ class TestHumanInputFormRepositoryImplPublicMethods: assert len(entity.recipients) == 1 assert entity.recipients[0].token == "token-123" - def test_get_form_returns_none_when_missing(self): + def test_get_form_returns_none_when_missing(self, monkeypatch: pytest.MonkeyPatch): session = _FakeSession(scalars_results=[None]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") assert repo.get_form("run-1", "node-1") is None - def test_get_form_returns_unsubmitted_state(self): + def test_get_form_returns_unsubmitted_state(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -436,7 +450,8 @@ class TestHumanInputFormRepositoryImplPublicMethods: expiration_time=naive_utc_now(), ) session = _FakeSession(scalars_results=[form, []]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") entity = repo.get_form(form.workflow_run_id, form.node_id) @@ -445,7 +460,7 @@ class TestHumanInputFormRepositoryImplPublicMethods: assert entity.selected_action_id is None assert entity.submitted_data is None - def test_get_form_returns_submission_when_completed(self): + def test_get_form_returns_submission_when_completed(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -460,7 +475,8 @@ class TestHumanInputFormRepositoryImplPublicMethods: submitted_at=naive_utc_now(), ) session = _FakeSession(scalars_results=[form, []]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") entity = repo.get_form(form.workflow_run_id, form.node_id) @@ -471,7 +487,7 @@ class TestHumanInputFormRepositoryImplPublicMethods: class TestHumanInputFormSubmissionRepository: - def test_get_by_token_returns_record(self): + def test_get_by_token_returns_record(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -490,7 +506,8 @@ class TestHumanInputFormSubmissionRepository: form=form, ) session = _FakeSession(scalars_result=recipient) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() record = repo.get_by_token("token-123") @@ -499,7 +516,7 @@ class TestHumanInputFormSubmissionRepository: assert record.recipient_type == RecipientType.STANDALONE_WEB_APP assert record.submitted is False - def test_get_by_form_id_and_recipient_type_uses_recipient(self): + def test_get_by_form_id_and_recipient_type_uses_recipient(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -518,7 +535,8 @@ class TestHumanInputFormSubmissionRepository: form=form, ) session = _FakeSession(scalars_result=recipient) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() record = repo.get_by_form_id_and_recipient_type( form_id=form.id, @@ -553,7 +571,8 @@ class TestHumanInputFormSubmissionRepository: forms={form.id: form}, recipients={recipient.id: recipient}, ) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() record: HumanInputFormRecord = repo.mark_submitted( form_id=form.id, diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py index 07f28f162a..bae5bae06d 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py @@ -10,11 +10,11 @@ from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from core.workflow.entities.workflow_node_execution import ( +from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from core.workflow.enums import NodeType +from dify_graph.enums import NodeType from libs.datetime_utils import naive_utc_now from models import Account, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py index 485be90eae..c880b8d41b 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py @@ -16,11 +16,11 @@ from sqlalchemy import Engine from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from core.workflow.entities.workflow_node_execution import ( +from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from core.workflow.enums import NodeType +from dify_graph.enums import NodeType from models import Account, WorkflowNodeExecutionTriggeredFrom from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index b9c5fbd7d8..251d6fd25e 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -1,6 +1,6 @@ import json -from core.workflow.file import File, FileTransferMethod, FileType, FileUploadConfig +from dify_graph.file import File, FileTransferMethod, FileType, FileUploadConfig from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index 5a7547e85c..92e4b58473 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -6,7 +6,7 @@ from pytest_mock import MockerFixture from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.model_manager import LBModelManager -from core.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_redis import redis_client diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py index 636fac7a40..90ed1647aa 100644 --- a/api/tests/unit_tests/core/test_provider_configuration.py +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -12,9 +12,9 @@ from core.entities.provider_entities import ( RestrictModel, SystemConfiguration, ) -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ( +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, FormOption, diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 3163d53b87..3abfb8c9f8 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -2,8 +2,8 @@ import pytest from pytest_mock import MockerFixture from core.entities.provider_entities import ModelSettings -from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager +from dify_graph.model_runtime.entities.model_entities import ModelType from models.provider import LoadBalancingModelConfig, ProviderModelSetting diff --git a/api/tests/unit_tests/core/test_trigger_debug_event_selectors.py b/api/tests/unit_tests/core/test_trigger_debug_event_selectors.py index 2b508ca654..14b42adbbe 100644 --- a/api/tests/unit_tests/core/test_trigger_debug_event_selectors.py +++ b/api/tests/unit_tests/core/test_trigger_debug_event_selectors.py @@ -6,7 +6,7 @@ import pytest import pytz from core.trigger.debug import event_selectors -from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig +from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig class _DummyRedis: diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index a9af8bea1d..d47d4d6130 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -3,10 +3,10 @@ import dataclasses from pydantic import BaseModel from core.helper import encrypter -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.runtime import VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables.segments import ( +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.runtime import VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables.segments import ( ArrayAnySegment, ArrayFileSegment, ArrayNumberSegment, @@ -22,8 +22,8 @@ from core.workflow.variables.segments import ( StringSegment, get_segment_discriminator, ) -from core.workflow.variables.types import SegmentType -from core.workflow.variables.variables import ( +from dify_graph.variables.types import SegmentType +from dify_graph.variables.variables import ( ArrayAnyVariable, ArrayFileVariable, ArrayNumberVariable, diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index 51c279d4eb..8704e3a8e9 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -1,6 +1,6 @@ import pytest -from core.workflow.variables.types import ArrayValidation, SegmentType +from dify_graph.variables.types import ArrayValidation, SegmentType class TestSegmentTypeIsArrayType: diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py index 642ada03f1..c01b58d0db 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py +++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py @@ -10,10 +10,10 @@ from typing import Any import pytest -from core.workflow.file.enums import FileTransferMethod, FileType -from core.workflow.file.models import File -from core.workflow.variables.segment_group import SegmentGroup -from core.workflow.variables.segments import ( +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.file.models import File +from dify_graph.variables.segment_group import SegmentGroup +from dify_graph.variables.segments import ( ArrayFileSegment, BooleanSegment, FileSegment, @@ -22,7 +22,7 @@ from core.workflow.variables.segments import ( ObjectSegment, StringSegment, ) -from core.workflow.variables.types import ArrayValidation, SegmentType +from dify_graph.variables.types import ArrayValidation, SegmentType def create_test_file( diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index 6fc162e533..dd0fe2e65a 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from core.workflow.variables import ( +from dify_graph.variables import ( ArrayFileVariable, ArrayVariable, FloatVariable, @@ -11,7 +11,7 @@ from core.workflow.variables import ( SegmentType, StringVariable, ) -from core.workflow.variables.variables import VariableBase +from dify_graph.variables.variables import VariableBase def test_frozen_variables(): diff --git a/api/tests/unit_tests/core/workflow/context/test_execution_context.py b/api/tests/unit_tests/core/workflow/context/test_execution_context.py index 8dd669e17f..d09b8397c3 100644 --- a/api/tests/unit_tests/core/workflow/context/test_execution_context.py +++ b/api/tests/unit_tests/core/workflow/context/test_execution_context.py @@ -9,7 +9,7 @@ from unittest.mock import MagicMock import pytest from pydantic import BaseModel -from core.workflow.context.execution_context import ( +from dify_graph.context.execution_context import ( AppContext, ExecutionContext, ExecutionContextBuilder, @@ -286,7 +286,7 @@ class TestCaptureCurrentContext: def test_capture_current_context_returns_context(self): """Test that capture_current_context returns a valid context.""" - from core.workflow.context.execution_context import capture_current_context + from dify_graph.context.execution_context import capture_current_context result = capture_current_context() @@ -303,7 +303,7 @@ class TestCaptureCurrentContext: test_var = contextvars.ContextVar("capture_test_var") test_var.set("test_value_123") - from core.workflow.context.execution_context import capture_current_context + from dify_graph.context.execution_context import capture_current_context result = capture_current_context() @@ -313,12 +313,12 @@ class TestCaptureCurrentContext: class TestTenantScopedContextRegistry: def setup_method(self): - from core.workflow.context import reset_context_provider + from dify_graph.context import reset_context_provider reset_context_provider() def teardown_method(self): - from core.workflow.context import reset_context_provider + from dify_graph.context import reset_context_provider reset_context_provider() @@ -333,7 +333,7 @@ class TestTenantScopedContextRegistry: assert read_context("workflow.sandbox", tenant_id="t2").base_url == "http://t2" def test_missing_provider_raises_keyerror(self): - from core.workflow.context import ContextProviderNotFoundError + from dify_graph.context import ContextProviderNotFoundError with pytest.raises(ContextProviderNotFoundError): read_context("missing", tenant_id="unknown") diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 8d49394653..0df4927697 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -4,8 +4,8 @@ from unittest.mock import MagicMock, patch import pytest -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool class StubCoordinator: @@ -115,7 +115,7 @@ class TestGraphRuntimeState: queue = state.ready_queue - from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue + from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue assert isinstance(queue, InMemoryReadyQueue) @@ -124,7 +124,7 @@ class TestGraphRuntimeState: execution = state.graph_execution - from core.workflow.graph_engine.domain.graph_execution import GraphExecution + from dify_graph.graph_engine.domain.graph_execution import GraphExecution assert isinstance(execution, GraphExecution) assert execution.workflow_id == "" @@ -139,7 +139,7 @@ class TestGraphRuntimeState: mock_graph = MagicMock() with patch( - "core.workflow.graph_engine.response_coordinator.ResponseStreamCoordinator", autospec=True + "dify_graph.graph_engine.response_coordinator.ResponseStreamCoordinator", autospec=True ) as coordinator_cls: coordinator_instance = coordinator_cls.return_value state.configure(graph=mock_graph) diff --git a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py index 6144df06e0..158f7018b5 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py +++ b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py @@ -5,7 +5,7 @@ Tests for PauseReason discriminated union serialization/deserialization. import pytest from pydantic import BaseModel, ValidationError -from core.workflow.entities.pause_reason import ( +from dify_graph.entities.pause_reason import ( HumanInputRequired, PauseReason, SchedulingPause, diff --git a/api/tests/unit_tests/core/workflow/entities/test_template.py b/api/tests/unit_tests/core/workflow/entities/test_template.py index f3197ea282..2d4c7f7b77 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_template.py +++ b/api/tests/unit_tests/core/workflow/entities/test_template.py @@ -1,6 +1,6 @@ """Tests for template module.""" -from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment +from dify_graph.nodes.base.template import Template, TextSegment, VariableSegment class TestTemplate: diff --git a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py index d4254df319..6100ebede5 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py @@ -1,5 +1,5 @@ -from core.workflow.runtime import VariablePool -from core.workflow.variables.segments import ( +from dify_graph.runtime import VariablePool +from dify_graph.variables.segments import ( BooleanSegment, IntegerSegment, NoneSegment, diff --git a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py b/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py index a4b1189a1c..4035c1a871 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py +++ b/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py @@ -8,8 +8,8 @@ from typing import Any import pytest -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution -from core.workflow.enums import NodeType +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution +from dify_graph.enums import NodeType class TestWorkflowNodeExecutionProcessDataTruncation: diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph.py b/api/tests/unit_tests/core/workflow/graph/test_graph.py index 01b514ed7c..c46b9e51fd 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph.py @@ -2,10 +2,10 @@ from unittest.mock import Mock -from core.workflow.enums import NodeExecutionType, NodeState, NodeType -from core.workflow.graph.edge import Edge -from core.workflow.graph.graph import Graph -from core.workflow.nodes.base.node import Node +from dify_graph.enums import NodeExecutionType, NodeState, NodeType +from dify_graph.graph.edge import Edge +from dify_graph.graph.graph import Graph +from dify_graph.nodes.base.node import Node def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node: diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py b/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py index 15d1dcb48d..bd4a0f32e2 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py @@ -2,9 +2,9 @@ from unittest.mock import MagicMock import pytest -from core.workflow.enums import NodeType -from core.workflow.graph import Graph -from core.workflow.nodes.base.node import Node +from dify_graph.enums import NodeType +from dify_graph.graph import Graph +from dify_graph.nodes.base.node import Node def _make_node(node_id: str, node_type: NodeType = NodeType.START) -> Node: diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py index 6858120335..b93f18c5bd 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py @@ -4,15 +4,13 @@ from typing import Any import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.graph import Graph -from core.workflow.graph.validation import GraphValidationError -from core.workflow.nodes import NodeType -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.graph import Graph +from dify_graph.graph.validation import GraphValidationError +from dify_graph.nodes import NodeType +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params def _build_iteration_graph(node_id: str) -> dict[str, Any]: @@ -53,14 +51,14 @@ def _build_loop_graph(node_id: str) -> dict[str, Any]: def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory: - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + user_from="account", + invoke_from="debugger", call_depth=0, ) graph_runtime_state = GraphRuntimeState( diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py index 5716aae4c7..b98d56147e 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -6,16 +6,15 @@ from dataclasses import dataclass import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities import GraphInitParams -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType -from core.workflow.graph import Graph -from core.workflow.graph.validation import GraphValidationError -from core.workflow.nodes.base.entities import BaseNodeData -from core.workflow.nodes.base.node import Node -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom +from dify_graph.entities import GraphInitParams +from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeType +from dify_graph.graph import Graph +from dify_graph.graph.validation import GraphValidationError +from dify_graph.nodes.base.entities import BaseNodeData +from dify_graph.nodes.base.node import Node +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params class _TestNodeData(BaseNodeData): @@ -92,14 +91,14 @@ class _SimpleNodeFactory: @pytest.fixture def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]: graph_config: dict[str, object] = {"edges": [], "nodes": []} - init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + user_from="account", + invoke_from="service-api", call_depth=0, ) variable_pool = VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/README.md b/api/tests/unit_tests/core/workflow/graph_engine/README.md index 3fff4cf6a9..40ed61eb02 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/README.md +++ b/api/tests/unit_tests/core/workflow/graph_engine/README.md @@ -68,7 +68,7 @@ print(f"Success rate: {suite_result.success_rate:.1f}%") #### Event Sequence Validation ```python -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphRunStartedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, @@ -376,39 +376,39 @@ See `test_mock_example.py` for comprehensive examples including: ```bash # Run graph engine tests (includes property-based tests) -uv run pytest api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py # Run with specific test patterns -uv run pytest api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py -k "test_echo" +uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py -k "test_echo" # Run with verbose output -uv run pytest api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py -v +uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py -v ``` ### Mock System Tests ```bash # Run auto-mock system tests -uv run pytest api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py +uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_auto_mock_system.py # Run examples -uv run python api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py +uv run python api/tests/unit_tests/dify_graph/graph_engine/test_mock_example.py # Run simple validation -uv run python api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py +uv run python api/tests/unit_tests/dify_graph/graph_engine/test_mock_simple.py ``` ### All Tests ```bash # Run all graph engine tests -uv run pytest api/tests/unit_tests/core/workflow/graph_engine/ +uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ # Run with coverage -uv run pytest api/tests/unit_tests/core/workflow/graph_engine/ --cov=core.workflow.graph_engine +uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ --cov=dify_graph.graph_engine # Run in parallel -uv run pytest api/tests/unit_tests/core/workflow/graph_engine/ -n auto +uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ -n auto ``` ## Troubleshooting diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py index db9b977e4a..4dec618e49 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py @@ -3,15 +3,15 @@ import json from unittest.mock import MagicMock -from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.entities.commands import ( +from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel +from dify_graph.graph_engine.entities.commands import ( AbortCommand, CommandType, GraphEngineCommand, UpdateVariablesCommand, VariableUpdate, ) -from core.workflow.variables import IntegerVariable, StringVariable +from dify_graph.variables import IntegerVariable, StringVariable class TestRedisChannel: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py index 65bd3d87d4..011c80df96 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py @@ -2,18 +2,18 @@ from __future__ import annotations -from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.graph_engine.domain.graph_execution import GraphExecution -from core.workflow.graph_engine.event_management.event_handlers import EventHandler -from core.workflow.graph_engine.event_management.event_manager import EventManager -from core.workflow.graph_engine.graph_state_manager import GraphStateManager -from core.workflow.graph_engine.ready_queue.in_memory import InMemoryReadyQueue -from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator -from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import RetryConfig -from core.workflow.runtime import GraphRuntimeState, VariablePool +from dify_graph.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus +from dify_graph.graph import Graph +from dify_graph.graph_engine.domain.graph_execution import GraphExecution +from dify_graph.graph_engine.event_management.event_handlers import EventHandler +from dify_graph.graph_engine.event_management.event_manager import EventManager +from dify_graph.graph_engine.graph_state_manager import GraphStateManager +from dify_graph.graph_engine.ready_queue.in_memory import InMemoryReadyQueue +from dify_graph.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator +from dify_graph.graph_events import NodeRunRetryEvent, NodeRunStartedEvent +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.entities import RetryConfig +from dify_graph.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py index 15eac6b537..25494dc647 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py @@ -4,9 +4,9 @@ from __future__ import annotations import logging -from core.workflow.graph_engine.event_management.event_manager import EventManager -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events import GraphEngineEvent +from dify_graph.graph_engine.event_management.event_manager import EventManager +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import GraphEngineEvent class _FaultyLayer(GraphEngineLayer): diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py index 0019020ede..73d59ea4e9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py @@ -2,9 +2,9 @@ from unittest.mock import MagicMock, create_autospec -from core.workflow.graph import Edge, Graph -from core.workflow.graph_engine.graph_state_manager import GraphStateManager -from core.workflow.graph_engine.graph_traversal.skip_propagator import SkipPropagator +from dify_graph.graph import Edge, Graph +from dify_graph.graph_engine.graph_state_manager import GraphStateManager +from dify_graph.graph_engine.graph_traversal.skip_propagator import SkipPropagator class TestSkipPropagator: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py index 2ef23c7f0f..fc8133f5e1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py @@ -7,8 +7,8 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.repositories.human_input_form_repository import ( +from dify_graph.nodes.human_input.enums import HumanInputFormStatus +from dify_graph.repositories.human_input_form_repository import ( FormCreateParams, HumanInputFormEntity, HumanInputFormRecipientEntity, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py index 903800ce88..3d8de0a00d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -10,7 +10,7 @@ from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import set_tracer_provider -from core.workflow.enums import NodeType +from dify_graph.enums import NodeType @pytest.fixture @@ -63,7 +63,7 @@ def mock_llm_node(): def mock_tool_node(): """Create a mock Tool Node with tool-specific attributes.""" from core.tools.entities.tool_entities import ToolProviderType - from core.workflow.nodes.tool.entities import ToolNodeData + from dify_graph.nodes.tool.entities import ToolNodeData node = MagicMock() node.id = "test-tool-node-id" @@ -117,8 +117,8 @@ def mock_result_event(): """Create a mock result event with NodeRunResult.""" from datetime import datetime - from core.workflow.graph_events.node import NodeRunSucceededEvent - from core.workflow.node_events.base import NodeRunResult + from dify_graph.graph_events.node import NodeRunSucceededEvent + from dify_graph.node_events.base import NodeRunResult node_run_result = NodeRunResult( inputs={"query": "test query"}, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py index f1086c9936..db32527849 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py @@ -2,13 +2,13 @@ from __future__ import annotations import pytest -from core.workflow.graph_engine import GraphEngine, GraphEngineConfig -from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_engine.layers.base import ( +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_engine.layers.base import ( GraphEngineLayer, GraphEngineLayerNotInitializedError, ) -from core.workflow.graph_events import GraphEngineEvent +from dify_graph.graph_events import GraphEngineEvent from ..test_table_runner import WorkflowRunner diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 9a491d24e1..819fd67f9d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -4,11 +4,11 @@ from unittest.mock import MagicMock, patch from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.errors.error import QuotaExceededError -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph_engine.entities.commands import CommandType -from core.workflow.graph_events.node import NodeRunSucceededEvent -from core.workflow.node_events import NodeRunResult +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.graph_engine.entities.commands import CommandType +from dify_graph.graph_events.node import NodeRunSucceededEvent +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import NodeRunResult def _build_succeeded_event() -> NodeRunSucceededEvent: @@ -32,6 +32,7 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: node.execution_id = "execution-id" node.node_type = NodeType.LLM node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() result_event = _build_succeeded_event() @@ -52,6 +53,7 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: node.execution_id = "execution-id" node.node_type = NodeType.QUESTION_CLASSIFIER node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() result_event = _build_succeeded_event() @@ -72,6 +74,7 @@ def test_non_llm_node_is_ignored() -> None: node.execution_id = "execution-id" node.node_type = NodeType.START node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node._model_instance = object() result_event = _build_succeeded_event() @@ -88,6 +91,7 @@ def test_quota_error_is_handled_in_layer() -> None: node.execution_id = "execution-id" node.node_type = NodeType.LLM node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() result_event = _build_succeeded_event() @@ -109,6 +113,7 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: node.execution_id = "execution-id" node.node_type = NodeType.LLM node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py index ade846df28..b4a7cec494 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -16,7 +16,7 @@ import pytest from opentelemetry.trace import StatusCode from core.app.workflow.layers.observability import ObservabilityLayer -from core.workflow.enums import NodeType +from dify_graph.enums import NodeType class TestObservabilityLayerInitialization: @@ -144,7 +144,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node, mock_result_event ): """Test that LLM parser is used for LLM nodes and extracts LLM-specific attributes.""" - from core.workflow.node_events.base import NodeRunResult + from dify_graph.node_events.base import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={}, @@ -182,7 +182,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_retrieval_node, mock_result_event ): """Test that retrieval parser is used for retrieval nodes and extracts retrieval-specific attributes.""" - from core.workflow.node_events.base import NodeRunResult + from dify_graph.node_events.base import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={"query": "test query"}, @@ -210,7 +210,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node, mock_result_event ): """Test that result_event parameter allows parsers to extract inputs and outputs.""" - from core.workflow.node_events.base import NodeRunResult + from dify_graph.node_events.base import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={"input_key": "input_value"}, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py index c1fc4acd73..50d14ff48f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py @@ -5,18 +5,18 @@ from __future__ import annotations import queue from unittest import mock -from core.workflow.entities.pause_reason import SchedulingPause -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph_engine.event_management.event_handlers import EventHandler -from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher -from core.workflow.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator -from core.workflow.graph_events import ( +from dify_graph.entities.pause_reason import SchedulingPause +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.graph_engine.event_management.event_handlers import EventHandler +from dify_graph.graph_engine.orchestration.dispatcher import Dispatcher +from dify_graph.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator +from dify_graph.graph_events import ( GraphNodeEventBase, NodeRunPauseRequestedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, ) -from core.workflow.node_events import NodeRunResult +from dify_graph.node_events import NodeRunResult from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py index fd1e6fc6dc..7af6b26d87 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py @@ -1,4 +1,4 @@ -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py index b291f95e0f..f886ae1c2b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py @@ -7,7 +7,8 @@ for workflows containing nodes that require third-party services. import pytest -from core.workflow.enums import NodeType +from dify_graph.enums import NodeType +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from .test_table_runner import TableTestRunner, WorkflowTestCase @@ -199,22 +200,19 @@ def test_mock_config_builder(): def test_mock_factory_node_type_detection(): """Test that MockNodeFactory correctly identifies nodes to mock.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + from dify_graph.runtime import GraphRuntimeState, VariablePool from .test_mock_factory import MockNodeFactory - graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", + graph_init_params = build_test_graph_init_params( workflow_id="test", graph_config={}, + tenant_id="test", + app_id="test", user_id="test", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, ) graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), @@ -309,11 +307,9 @@ def test_workflow_without_auto_mock(): def test_register_custom_mock_node(): """Test registering a custom mock implementation for a node type.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from core.workflow.entities import GraphInitParams - from core.workflow.nodes.template_transform import TemplateTransformNode - from core.workflow.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + from dify_graph.nodes.template_transform import TemplateTransformNode + from dify_graph.runtime import GraphRuntimeState, VariablePool from .test_mock_factory import MockNodeFactory @@ -323,15 +319,14 @@ def test_register_custom_mock_node(): # Custom mock implementation pass - graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", + graph_init_params = build_test_graph_init_params( workflow_id="test", graph_config={}, + tenant_id="test", + app_id="test", user_id="test", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, ) graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py index b04643b78a..30acbdaf3d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py @@ -1,4 +1,4 @@ -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index 6c3700ea2b..765c4deba3 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -3,24 +3,23 @@ import time from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.graph_init_params import GraphInitParams -from core.workflow.entities.pause_reason import SchedulingPause -from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine, GraphEngineConfig -from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_engine.entities.commands import ( +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from dify_graph.entities.pause_reason import SchedulingPause +from dify_graph.graph import Graph +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_engine.entities.commands import ( AbortCommand, CommandType, PauseCommand, UpdateVariablesCommand, VariableUpdate, ) -from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.variables import IntegerVariable, StringVariable -from models.enums import UserFrom +from dify_graph.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.variables import IntegerVariable, StringVariable def test_abort_command(): @@ -41,13 +40,17 @@ def test_abort_command(): id="start", config={"id": "start", "data": {"title": "start", "variables": []}}, graph_init_params=GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ), graph_runtime_state=shared_runtime_state, @@ -99,7 +102,7 @@ def test_redis_channel_serialization(): mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipeline) mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None) - from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel + from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel # Create channel with a specific key channel = RedisChannel(mock_redis, channel_key="workflow:123:commands") @@ -151,13 +154,17 @@ def test_pause_command(): id="start", config={"id": "start", "data": {"title": "start", "variables": []}}, graph_init_params=GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ), graph_runtime_state=shared_runtime_state, @@ -207,13 +214,17 @@ def test_update_variables_command_updates_pool(): id="start", config={"id": "start", "data": {"title": "start", "variables": []}}, graph_init_params=GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ), graph_runtime_state=shared_runtime_state, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py index 96926797ec..3a9a0b18bc 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py @@ -7,7 +7,7 @@ This test suite validates the behavior of a workflow that: 3. Handles multiple answer nodes with different outputs """ -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py index ee944c8e3e..cde99196c8 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py @@ -6,10 +6,10 @@ This test validates that: - When blocking != 1: NodeRunStreamChunkEvent present (direct LLM to End output) """ -from core.workflow.enums import NodeType -from core.workflow.graph_engine import GraphEngine, GraphEngineConfig -from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_events import ( +from dify_graph.enums import NodeType +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_events import ( GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py index bf8034487c..b88c15ea2a 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py @@ -1,10 +1,10 @@ import queue from datetime import datetime -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher -from core.workflow.graph_events import NodeRunSucceededEvent -from core.workflow.node_events import NodeRunResult +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.graph_engine.orchestration.dispatcher import Dispatcher +from dify_graph.graph_events import NodeRunSucceededEvent +from dify_graph.node_events import NodeRunResult class StubExecutionCoordinator: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py b/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py index b1380cd6d2..c87dc75b95 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py @@ -6,7 +6,7 @@ field is missing from the output configuration, ensuring backward compatibility with older workflow definitions. """ -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py index 53de8908a8..35406997ed 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py @@ -4,11 +4,11 @@ from unittest.mock import MagicMock import pytest -from core.workflow.graph_engine.command_processing.command_processor import CommandProcessor -from core.workflow.graph_engine.domain.graph_execution import GraphExecution -from core.workflow.graph_engine.graph_state_manager import GraphStateManager -from core.workflow.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator -from core.workflow.graph_engine.worker_management.worker_pool import WorkerPool +from dify_graph.graph_engine.command_processing.command_processor import CommandProcessor +from dify_graph.graph_engine.domain.graph_execution import GraphExecution +from dify_graph.graph_engine.graph_state_manager import GraphStateManager +from dify_graph.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator +from dify_graph.graph_engine.worker_management.worker_pool import WorkerPool def _build_coordinator(graph_execution: GraphExecution) -> tuple[ExecutionCoordinator, MagicMock, MagicMock]: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 5a55d7086e..b9ae680f52 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -10,15 +10,15 @@ import time from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st -from core.workflow.enums import ErrorStrategy -from core.workflow.graph_engine import GraphEngine, GraphEngineConfig -from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_events import ( +from dify_graph.enums import ErrorStrategy +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_events import ( GraphRunPartialSucceededEvent, GraphRunStartedEvent, GraphRunSucceededEvent, ) -from core.workflow.nodes.base.entities import DefaultValue, DefaultValueType +from dify_graph.nodes.base.entities import DefaultValue, DefaultValueType # Import the test framework from the new module from .test_mock_config import MockConfigBuilder @@ -455,7 +455,7 @@ def test_if_else_workflow_property_diverse_inputs(query_input): # Tests for the Layer system def test_layer_system_basic(): """Test basic layer functionality with DebugLoggingLayer.""" - from core.workflow.graph_engine.layers import DebugLoggingLayer + from dify_graph.graph_engine.layers import DebugLoggingLayer runner = WorkflowRunner() @@ -495,7 +495,7 @@ def test_layer_system_basic(): def test_layer_chaining(): """Test chaining multiple layers.""" - from core.workflow.graph_engine.layers import DebugLoggingLayer, GraphEngineLayer + from dify_graph.graph_engine.layers import DebugLoggingLayer, GraphEngineLayer # Create a custom test layer class TestLayer(GraphEngineLayer): @@ -549,7 +549,7 @@ def test_layer_chaining(): def test_layer_error_handling(): """Test that layer errors don't crash the engine.""" - from core.workflow.graph_engine.layers import GraphEngineLayer + from dify_graph.graph_engine.layers import GraphEngineLayer # Create a layer that throws errors class FaultyLayer(GraphEngineLayer): @@ -591,7 +591,7 @@ def test_layer_error_handling(): def test_event_sequence_validation(): """Test the new event sequence validation feature.""" - from core.workflow.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent + from dify_graph.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent runner = TableTestRunner() @@ -678,7 +678,7 @@ def test_event_sequence_validation(): def test_event_sequence_validation_with_table_tests(): """Test event sequence validation with table-driven tests.""" - from core.workflow.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent + from dify_graph.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent runner = TableTestRunner() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py index 6385b0b91f..805e7dbbce 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py @@ -6,13 +6,13 @@ import json from collections import deque from unittest.mock import MagicMock -from core.workflow.enums import NodeExecutionType, NodeState, NodeType -from core.workflow.graph_engine.domain import GraphExecution -from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator -from core.workflow.graph_engine.response_coordinator.path import Path -from core.workflow.graph_engine.response_coordinator.session import ResponseSession -from core.workflow.graph_events import NodeRunStreamChunkEvent -from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment +from dify_graph.enums import NodeExecutionType, NodeState, NodeType +from dify_graph.graph_engine.domain import GraphExecution +from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator +from dify_graph.graph_engine.response_coordinator.path import Path +from dify_graph.graph_engine.response_coordinator.session import ResponseSession +from dify_graph.graph_events import NodeRunStreamChunkEvent +from dify_graph.nodes.base.template import Template, TextSegment, VariableSegment class CustomGraphExecutionError(Exception): diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py index 65d34c2009..d54f0be190 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py @@ -1,26 +1,27 @@ import time from collections.abc import Mapping -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.enums import NodeState -from core.workflow.graph import Graph -from core.workflow.graph_engine.graph_state_manager import GraphStateManager -from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.llm.entities import ( +from dify_graph.entities import GraphInitParams +from dify_graph.enums import NodeState +from dify_graph.graph import Graph +from dify_graph.graph_engine.graph_state_manager import GraphStateManager +from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.nodes.end.end_node import EndNode +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode @@ -73,11 +74,11 @@ def _build_llm_node( def _build_graph(runtime_state: GraphRuntimeState) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index b117b26b4c..538f53c603 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -4,10 +4,8 @@ from collections.abc import Iterable from unittest import mock from unittest.mock import MagicMock -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.graph import Graph -from core.workflow.graph_events import ( +from dify_graph.graph import Graph +from dify_graph.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, @@ -16,25 +14,27 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent -from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.llm.entities import ( +from dify_graph.graph_events.node import NodeRunHumanInputFormFilledEvent +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.nodes.base.entities import OutputVariableEntity, OutputVariableType +from dify_graph.nodes.end.end_node import EndNode +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode @@ -47,11 +47,11 @@ def _build_branching_graph( graph_runtime_state: GraphRuntimeState | None = None, ) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py index 45505909ea..36bba6deb6 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -3,10 +3,8 @@ import time from unittest import mock from unittest.mock import MagicMock -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.graph import Graph -from core.workflow.graph_events import ( +from dify_graph.graph import Graph +from dify_graph.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, @@ -15,25 +13,27 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent -from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.llm.entities import ( +from dify_graph.graph_events.node import NodeRunHumanInputFormFilledEvent +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.nodes.base.entities import OutputVariableEntity, OutputVariableType +from dify_graph.nodes.end.end_node import EndNode +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode @@ -46,11 +46,11 @@ def _build_llm_human_llm_graph( graph_runtime_state: GraphRuntimeState | None = None, ) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py index f33d37e8ff..8da179c15e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py @@ -1,34 +1,34 @@ import time from unittest import mock -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.graph import Graph -from core.workflow.graph_events import ( +from dify_graph.graph import Graph +from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.if_else.entities import IfElseNodeData -from core.workflow.nodes.if_else.if_else_node import IfElseNode -from core.workflow.nodes.llm.entities import ( +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.nodes.base.entities import OutputVariableEntity, OutputVariableType +from dify_graph.nodes.end.end_node import EndNode +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.if_else.entities import IfElseNodeData +from dify_graph.nodes.if_else.if_else_node import IfElseNode +from dify_graph.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.utils.condition.entities import Condition +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.utils.condition.entities import Condition +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode @@ -37,15 +37,10 @@ from .test_table_runner import TableTestRunner, WorkflowTestCase def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", + graph_init_params = build_test_graph_init_params( graph_config=graph_config, - user_id="user", user_from="account", invoke_from="debugger", - call_depth=0, ) variable_pool = VariablePool( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py index 3e21a5b44d..733fd53bc8 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py @@ -5,7 +5,7 @@ This test validates the behavior of a loop containing an answer node inside the loop that may produce output errors. """ -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunLoopNextEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py index d88c1d9f9e..6ff2722f78 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py @@ -1,4 +1,4 @@ -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunLoopNextEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py index 5ceb8dd7f7..6041c6ff30 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py @@ -11,7 +11,7 @@ from collections.abc import Callable from dataclasses import dataclass, field from typing import Any -from core.workflow.enums import NodeType +from dify_graph.enums import NodeType @dataclass diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index b862cbe89e..9f33a81985 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -8,9 +8,9 @@ requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request from collections.abc import Mapping from typing import TYPE_CHECKING, Any -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.enums import NodeType -from core.workflow.nodes.base.node import Node +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.enums import NodeType +from dify_graph.nodes.base.node import Node from .test_mock_nodes import ( MockAgentNode, @@ -28,8 +28,8 @@ from .test_mock_nodes import ( ) if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState from .test_mock_config import MockConfig diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py index aae4de9a27..eb449e6d75 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py @@ -5,31 +5,36 @@ Simple test to verify MockNodeFactory works with iteration nodes. import sys from pathlib import Path +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY + # Add api directory to path api_dir = Path(__file__).parent.parent.parent.parent.parent.parent sys.path.insert(0, str(api_dir)) -from core.workflow.enums import NodeType +from dify_graph.enums import NodeType from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory def test_mock_factory_registers_iteration_node(): """Test that MockNodeFactory has iteration node registered.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create a MockNodeFactory instance graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={"nodes": [], "edges": []}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -65,10 +70,9 @@ def test_mock_factory_registers_iteration_node(): def test_mock_iteration_node_preserves_config(): """Test that MockIterationNode preserves mock configuration.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode # Create mock config @@ -76,13 +80,17 @@ def test_mock_iteration_node_preserves_config(): # Create minimal graph init params graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={"nodes": [], "edges": []}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) @@ -127,10 +135,9 @@ def test_mock_iteration_node_preserves_config(): def test_mock_loop_node_preserves_config(): """Test that MockLoopNode preserves mock configuration.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode # Create mock config @@ -138,13 +145,17 @@ def test_mock_loop_node_preserves_config(): # Create minimal graph init params graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={"nodes": [], "edges": []}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 5aed463a45..3f458e9de9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -11,28 +11,44 @@ from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from core.workflow.nodes.agent import AgentNode -from core.workflow.nodes.code import CodeNode -from core.workflow.nodes.document_extractor import DocumentExtractorNode -from core.workflow.nodes.http_request import HttpRequestNode -from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode -from core.workflow.nodes.llm import LLMNode -from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory -from core.workflow.nodes.parameter_extractor import ParameterExtractorNode -from core.workflow.nodes.question_classifier import QuestionClassifierNode -from core.workflow.nodes.template_transform import TemplateTransformNode -from core.workflow.nodes.tool import ToolNode +from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from dify_graph.nodes.agent import AgentNode +from dify_graph.nodes.code import CodeNode +from dify_graph.nodes.document_extractor import DocumentExtractorNode +from dify_graph.nodes.http_request import HttpRequestNode +from dify_graph.nodes.knowledge_retrieval import KnowledgeRetrievalNode +from dify_graph.nodes.llm import LLMNode +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.parameter_extractor import ParameterExtractorNode +from dify_graph.nodes.question_classifier import QuestionClassifierNode +from dify_graph.nodes.template_transform import TemplateTransformNode +from dify_graph.nodes.template_transform.template_renderer import ( + Jinja2TemplateRenderer, + TemplateRenderError, +) +from dify_graph.nodes.tool import ToolNode if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState from .test_mock_config import MockConfig +class _TestJinja2Renderer(Jinja2TemplateRenderer): + """Simple Jinja2 renderer for tests (avoids code executor).""" + + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + from jinja2 import Template as _Jinja2Template + + try: + return _Jinja2Template(template).render(**variables) + except Exception as exc: # pragma: no cover - pass through as contract error + raise TemplateRenderError(str(exc)) from exc + + class MockNodeMixin: """Mixin providing common mock functionality.""" @@ -50,6 +66,10 @@ class MockNodeMixin: kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance)) + # Ensure TemplateTransformNode receives a renderer now required by constructor + if isinstance(self, TemplateTransformNode): + kwargs.setdefault("template_renderer", _TestJinja2Renderer()) + super().__init__( id=id, config=config, @@ -557,8 +577,8 @@ class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode): ) -from core.workflow.nodes.iteration import IterationNode -from core.workflow.nodes.loop import LoopNode +from dify_graph.nodes.iteration import IterationNode +from dify_graph.nodes.loop import LoopNode class MockIterationNode(MockNodeMixin, IterationNode): @@ -572,24 +592,20 @@ class MockIterationNode(MockNodeMixin, IterationNode): def _create_graph_engine(self, index: int, item: Any): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" # Import dependencies - from core.workflow.entities import GraphInitParams - from core.workflow.graph import Graph - from core.workflow.graph_engine import GraphEngine, GraphEngineConfig - from core.workflow.graph_engine.command_channels import InMemoryChannel - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.graph import Graph + from dify_graph.graph_engine import GraphEngine, GraphEngineConfig + from dify_graph.graph_engine.command_channels import InMemoryChannel + from dify_graph.runtime import GraphRuntimeState # Import our MockNodeFactory instead of DifyNodeFactory from .test_mock_factory import MockNodeFactory # Create GraphInitParams from node attributes graph_init_params = GraphInitParams( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, graph_config=self.graph_config, - user_id=self.user_id, - user_from=self.user_from.value, - invoke_from=self.invoke_from.value, + run_context=self.run_context, call_depth=self.workflow_call_depth, ) @@ -621,7 +637,7 @@ class MockIterationNode(MockNodeMixin, IterationNode): ) if not iteration_graph: - from core.workflow.nodes.iteration.exc import IterationGraphNotFoundError + from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError raise IterationGraphNotFoundError("iteration graph not found") @@ -648,24 +664,20 @@ class MockLoopNode(MockNodeMixin, LoopNode): def _create_graph_engine(self, start_at, root_node_id: str): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" # Import dependencies - from core.workflow.entities import GraphInitParams - from core.workflow.graph import Graph - from core.workflow.graph_engine import GraphEngine, GraphEngineConfig - from core.workflow.graph_engine.command_channels import InMemoryChannel - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.graph import Graph + from dify_graph.graph_engine import GraphEngine, GraphEngineConfig + from dify_graph.graph_engine.command_channels import InMemoryChannel + from dify_graph.runtime import GraphRuntimeState # Import our MockNodeFactory instead of DifyNodeFactory from .test_mock_factory import MockNodeFactory # Create GraphInitParams from node attributes graph_init_params = GraphInitParams( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, graph_config=self.graph_config, - user_id=self.user_id, - user_from=self.user_from.value, - invoke_from=self.invoke_from.value, + run_context=self.run_context, call_depth=self.workflow_call_depth, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py index 6c4178dfed..1550dca402 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py @@ -6,8 +6,9 @@ to ensure they work correctly with the TableTestRunner. """ from configs import dify_config -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.nodes.code.limits import CodeNodeLimits +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.nodes.code.limits import CodeNodeLimits from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockCodeNode, MockTemplateTransformNode @@ -39,18 +40,22 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_default_output(self): """Test that MockTemplateTransformNode processes templates with Jinja2.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -98,18 +103,22 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_custom_output(self): """Test that MockTemplateTransformNode returns custom configured output.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -158,18 +167,22 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_error_simulation(self): """Test that MockTemplateTransformNode can simulate errors.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -215,19 +228,23 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_with_variables(self): """Test that MockTemplateTransformNode processes templates with variables.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool - from core.workflow.variables import StringVariable + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool + from dify_graph.variables import StringVariable # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -281,18 +298,22 @@ class TestMockCodeNode: def test_mock_code_node_default_output(self): """Test that MockCodeNode returns default output.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -343,18 +364,22 @@ class TestMockCodeNode: def test_mock_code_node_with_output_schema(self): """Test that MockCodeNode generates outputs based on schema.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -413,18 +438,22 @@ class TestMockCodeNode: def test_mock_code_node_custom_output(self): """Test that MockCodeNode returns custom configured output.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -485,18 +514,22 @@ class TestMockNodeFactory: def test_code_and_template_nodes_mocked_by_default(self): """Test that CODE and TEMPLATE_TRANSFORM nodes are mocked by default (they require SSRF proxy).""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -526,18 +559,22 @@ class TestMockNodeFactory: def test_factory_creates_mock_template_transform_node(self): """Test that MockNodeFactory creates MockTemplateTransformNode for template-transform type.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -577,18 +614,22 @@ class TestMockNodeFactory: def test_factory_creates_mock_code_node(self): """Test that MockNodeFactory creates MockCodeNode for code type.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py index 1b781545f5..84d1444585 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py @@ -5,11 +5,13 @@ Simple test to validate the auto-mock system without external dependencies. import sys from pathlib import Path +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY + # Add api directory to path api_dir = Path(__file__).parent.parent.parent.parent.parent.parent sys.path.insert(0, str(api_dir)) -from core.workflow.enums import NodeType +from dify_graph.enums import NodeType from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory @@ -101,21 +103,24 @@ def test_node_mock_config(): def test_mock_factory_detection(): """Test MockNodeFactory node type detection.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool print("Testing MockNodeFactory detection...") graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -154,21 +159,24 @@ def test_mock_factory_detection(): def test_mock_factory_registration(): """Test registering and unregistering mock node types.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool print("Testing MockNodeFactory registration...") graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) graph_runtime_state = GraphRuntimeState( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index a6aab81f6c..e681b39cc7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,34 +4,34 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol -from core.workflow.entities import GraphInitParams -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.config import GraphEngineConfig -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.graph_events import ( +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.graph import Graph +from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from dify_graph.graph_engine.config import GraphEngineConfig +from dify_graph.graph_engine.graph_engine import GraphEngine +from dify_graph.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunSucceededEvent, ) -from core.workflow.nodes.base.entities import OutputVariableEntity -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( +from dify_graph.nodes.base.entities import OutputVariableEntity +from dify_graph.nodes.end.end_node import EndNode +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction +from dify_graph.nodes.human_input.enums import HumanInputFormStatus +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.repositories.human_input_form_repository import ( FormCreateParams, HumanInputFormEntity, HumanInputFormRepository, ) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params class PauseStateStore(Protocol): @@ -126,11 +126,11 @@ def _build_runtime_state() -> GraphRuntimeState: def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py index 62aa56fc57..60167c0441 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py @@ -4,41 +4,41 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.config import GraphEngineConfig -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.graph_events import ( +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.graph import Graph +from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from dify_graph.graph_engine.config import GraphEngineConfig +from dify_graph.graph_engine.graph_engine import GraphEngine +from dify_graph.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, NodeRunPauseRequestedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, ) -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.llm.entities import ( +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction +from dify_graph.nodes.human_input.enums import HumanInputFormStatus +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.repositories.human_input_form_repository import ( FormCreateParams, HumanInputFormEntity, HumanInputFormRepository, ) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig, NodeMockConfig from .test_mock_nodes import MockLLMNode @@ -129,11 +129,11 @@ def _build_runtime_state() -> GraphRuntimeState: def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index a93d03c87e..0ac9d6618d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -12,25 +12,24 @@ import time from unittest.mock import MagicMock, patch from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance -from core.workflow.entities import GraphInitParams -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine, GraphEngineConfig -from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_events import ( +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.graph import Graph +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_events import ( GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.node_events import NodeRunResult, StreamCompletedEvent -from core.workflow.nodes.llm.node import LLMNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom +from dify_graph.node_events import NodeRunResult, StreamCompletedEvent +from dify_graph.nodes.llm.node import LLMNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params from .test_table_runner import TableTestRunner @@ -87,11 +86,11 @@ def test_parallel_streaming_workflow(): graph_config = workflow_config.get("graph", {}) # Create graph initialization parameters - init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", + init_params = build_test_graph_init_params( workflow_id="test_workflow", graph_config=graph_config, + tenant_id="test_tenant", + app_id="test_app", user_id="test_user", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, @@ -100,8 +99,8 @@ def test_parallel_streaming_workflow(): # Create variable pool with system variables system_variables = SystemVariable( - user_id=init_params.user_id, - app_id=init_params.app_id, + user_id="test_user", + app_id="test_app", workflow_id=init_params.workflow_id, files=[], query="Tell me about yourself", # User query diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py index 156cfefcd6..7328ce443f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py @@ -4,42 +4,42 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.config import GraphEngineConfig -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.graph_events import ( +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.graph import Graph +from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from dify_graph.graph_engine.config import GraphEngineConfig +from dify_graph.graph_engine.graph_engine import GraphEngine +from dify_graph.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, ) -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.llm.entities import ( +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.nodes.end.end_node import EndNode +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction +from dify_graph.nodes.human_input.enums import HumanInputFormStatus +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.repositories.human_input_form_repository import ( FormCreateParams, HumanInputFormEntity, HumanInputFormRepository, ) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig, NodeMockConfig from .test_mock_nodes import MockLLMNode @@ -121,11 +121,11 @@ def _build_runtime_state() -> GraphRuntimeState: def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py index 700b3f4b8b..15a7de3c52 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py @@ -3,33 +3,33 @@ import time from typing import Any from unittest.mock import MagicMock -from core.workflow.entities import GraphInitParams -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.graph_events import ( +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.graph import Graph +from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from dify_graph.graph_engine.graph_engine import GraphEngine +from dify_graph.graph_events import ( GraphEngineEvent, GraphRunPausedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunSucceededEvent, ) -from core.workflow.graph_events.graph import GraphRunStartedEvent -from core.workflow.nodes.base.entities import OutputVariableEntity -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( +from dify_graph.graph_events.graph import GraphRunStartedEvent +from dify_graph.nodes.base.entities import OutputVariableEntity +from dify_graph.nodes.end.end_node import EndNode +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.repositories.human_input_form_repository import ( HumanInputFormEntity, HumanInputFormRepository, ) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params def _build_runtime_state() -> GraphRuntimeState: @@ -79,11 +79,11 @@ def _build_human_input_graph( form_repository: HumanInputFormRepository, ) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - params = GraphInitParams( - tenant_id="tenant", - app_id="app", + params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="service-api", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py index 0920940e51..9c84f42db6 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py @@ -12,9 +12,9 @@ import pytest import redis from core.app.apps.base_app_queue_manager import AppQueueManager -from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand -from core.workflow.graph_engine.manager import GraphEngineManager +from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel +from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand +from dify_graph.graph_engine.manager import GraphEngineManager class TestRedisStopIntegration: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py index 822b6a808f..4ba1e6ae0b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py @@ -7,14 +7,15 @@ from core.workflow.enums import NodeType from core.workflow.graph.graph import Graph from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator from core.workflow.graph_engine.response_coordinator.session import ResponseSession +from core.workflow.nodes.base.entities import BaseNodeData +from core.workflow.nodes.base.template import Template, VariableSegment + from core.workflow.graph_events import ( ChunkType, NodeRunStreamChunkEvent, ToolCall, ToolResult, ) -from core.workflow.nodes.base.entities import BaseNodeData -from core.workflow.nodes.base.template import Template, VariableSegment from core.workflow.runtime import VariablePool diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py index 99157a7c3e..4f1741d4fb 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py @@ -1,4 +1,4 @@ -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index 5cbb7cf36e..767a8f60ce 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -12,27 +12,29 @@ This module provides a robust table-driven testing framework with support for: import logging import time -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path -from typing import Any +from typing import Any, cast -from core.app.workflow.node_factory import DifyNodeFactory +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.yaml_utils import _load_yaml_file -from core.workflow.entities.graph_init_params import GraphInitParams -from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine, GraphEngineConfig -from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_events import ( +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from dify_graph.graph import Graph +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import ( GraphEngineEvent, GraphRunStartedEvent, GraphRunSucceededEvent, ) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import ( +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import ( ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, @@ -48,6 +50,47 @@ from .test_mock_factory import MockNodeFactory logger = logging.getLogger(__name__) +class _TableTestChildEngineBuilder: + def __init__(self, *, use_mock_factory: bool, mock_config: MockConfig | None) -> None: + self._use_mock_factory = use_mock_factory + self._mock_config = mock_config + + def build_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> GraphEngine: + if self._use_mock_factory: + node_factory = MockNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + mock_config=self._mock_config, + ) + else: + node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) + + child_graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id) + if not child_graph: + raise ValueError("child graph not found") + + child_engine = GraphEngine( + workflow_id=workflow_id, + graph=child_graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig(), + child_engine_builder=self, + ) + for layer in layers: + child_engine.layer(cast(GraphEngineLayer, layer)) + return child_engine + + @dataclass class WorkflowTestCase: """Represents a single test case for table-driven testing.""" @@ -149,19 +192,23 @@ class WorkflowRunner: raise ValueError("Fixture missing workflow.graph configuration") graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config=graph_config, - user_id="test_user", - user_from="account", - invoke_from="debugger", # Set to debugger to avoid conversation_id requirement + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, # Set to debugger to avoid conversation_id requirement + } + }, call_depth=0, ) system_variables = SystemVariable( - user_id=graph_init_params.user_id, - app_id=graph_init_params.app_id, + user_id="test_user", + app_id="test_app", workflow_id=graph_init_params.workflow_id, files=[], query=query, @@ -315,6 +362,10 @@ class TableTestRunner: scale_up_threshold=self.graph_engine_scale_up_threshold, scale_down_idle_time=self.graph_engine_scale_down_idle_time, ), + child_engine_builder=_TableTestChildEngineBuilder( + use_mock_factory=test_case.use_auto_mock, + mock_config=test_case.mock_config, + ), ) # Execute and collect events diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py index bfcc6e1a5f..7f26bc11a7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py @@ -1,6 +1,6 @@ -from core.workflow.graph_engine import GraphEngine, GraphEngineConfig -from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_events import ( +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_events import ( GraphRunSucceededEvent, NodeRunStreamChunkEvent, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py index 221e1291d1..f63e8ff4ce 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py @@ -2,9 +2,9 @@ from unittest.mock import patch import pytest -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode from .test_table_runner import TableTestRunner, WorkflowTestCase diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 1e95ec1970..f0d80af1ed 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -2,16 +2,15 @@ import time import uuid from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.graph import Graph +from dify_graph.nodes.answer.answer_node import AnswerNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from extensions.ext_database import db -from models.enums import UserFrom +from tests.workflow_test_utils import build_test_graph_init_params def test_execute_answer(): @@ -36,11 +35,11 @@ def test_execute_answer(): ], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index 21a642c2f8..bf814d0c97 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -1,11 +1,11 @@ import pytest -from core.workflow.enums import NodeType -from core.workflow.nodes.base.entities import BaseNodeData -from core.workflow.nodes.base.node import Node +from dify_graph.enums import NodeType +from dify_graph.nodes.base.entities import BaseNodeData +from dify_graph.nodes.base.node import Node # Ensures that all node classes are imported. -from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING # Ensure `NODE_TYPE_CLASSES_MAPPING` is used and not automatically removed. _ = NODE_TYPE_CLASSES_MAPPING diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py index 45d222b98c..f8d799e446 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -1,15 +1,15 @@ import types from collections.abc import Mapping -from core.workflow.enums import NodeType -from core.workflow.nodes.base.entities import BaseNodeData -from core.workflow.nodes.base.node import Node +from dify_graph.enums import NodeType +from dify_graph.nodes.base.entities import BaseNodeData +from dify_graph.nodes.base.node import Node # Import concrete nodes we will assert on (numeric version path) -from core.workflow.nodes.variable_assigner.v1.node import ( +from dify_graph.nodes.variable_assigner.v1.node import ( VariableAssignerNode as VariableAssignerV1, ) -from core.workflow.nodes.variable_assigner.v2.node import ( +from dify_graph.nodes.variable_assigner.v2.node import ( VariableAssignerNode as VariableAssignerV2, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index 00c8cb3779..95cb653635 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -1,13 +1,13 @@ from configs import dify_config -from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.code.entities import CodeLanguage, CodeNodeData -from core.workflow.nodes.code.exc import ( +from dify_graph.nodes.code.code_node import CodeNode +from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData +from dify_graph.nodes.code.exc import ( CodeNodeError, DepthLimitError, OutputValidationError, ) -from core.workflow.nodes.code.limits import CodeNodeLimits -from core.workflow.variables.types import SegmentType +from dify_graph.nodes.code.limits import CodeNodeLimits +from dify_graph.variables.types import SegmentType CodeNode._limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, diff --git a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py index 28d59c3568..de7ed0815e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py @@ -1,8 +1,8 @@ import pytest from pydantic import ValidationError -from core.workflow.nodes.code.entities import CodeLanguage, CodeNodeData -from core.workflow.variables.types import SegmentType +from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData +from dify_graph.variables.types import SegmentType class TestCodeNodeDataOutput: diff --git a/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py b/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py index b0115310a6..e30b3776a4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py @@ -4,6 +4,9 @@ from io import BytesIO from typing import Any from unittest.mock import MagicMock +from core.workflow.enums import WorkflowNodeExecutionStatus +from core.workflow.system_variable import SystemVariable + from core.entities.provider_entities import BasicProviderConfig from core.virtual_environment.__base.entities import ( Arch, @@ -17,10 +20,8 @@ from core.virtual_environment.__base.virtual_environment import VirtualEnvironme from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser from core.virtual_environment.channel.transport import NopTransportWriteCloser from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.nodes.command.node import CommandNode from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable class FakeVirtualEnvironment(VirtualEnvironment): diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index 584ed23e91..db096b1aed 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,6 +1,7 @@ -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from dify_graph.nodes.datasource.datasource_node import DatasourceNode class _VarSeg: @@ -28,13 +29,17 @@ class _GraphState: class _GraphParams: - tenant_id = "t1" - app_id = "app-1" workflow_id = "wf-1" graph_config = {} - user_id = "u1" - user_from = "account" - invoke_from = "debugger" + run_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "t1", + "app_id": "app-1", + "user_id": "u1", + "user_from": "account", + "invoke_from": "debugger", + } + } call_depth = 0 diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py index 90f4cd018b..cd822a6f89 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py @@ -1,4 +1,4 @@ -from core.workflow.nodes.http_request import build_http_request_config +from dify_graph.nodes.http_request import build_http_request_config def test_build_http_request_config_uses_literal_defaults(): diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py index 47a5df92a4..fec6ad90eb 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py @@ -4,7 +4,7 @@ from unittest.mock import Mock, PropertyMock, patch import httpx import pytest -from core.workflow.nodes.http_request.entities import Response +from dify_graph.nodes.http_request.entities import Response @pytest.fixture @@ -104,7 +104,7 @@ def test_mimetype_based_detection(mock_response, content_type, expected_main_typ mock_response.headers = {"content-type": content_type} type(mock_response).content = PropertyMock(return_value=bytes([0x00])) # Dummy content - with patch("core.workflow.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type: + with patch("dify_graph.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type: # Mock the return value based on expected_main_type if expected_main_type: mock_guess_type.return_value = (f"{expected_main_type}/subtype", None) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index 67da890eb2..cea7195417 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -2,19 +2,19 @@ import pytest from configs import dify_config from core.helper.ssrf_proxy import ssrf_proxy -from core.workflow.file.file_manager import file_manager -from core.workflow.nodes.http_request import ( +from dify_graph.file.file_manager import file_manager +from dify_graph.nodes.http_request import ( BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeConfig, HttpRequestNodeData, ) -from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout -from core.workflow.nodes.http_request.exc import AuthorizationConfigError -from core.workflow.nodes.http_request.executor import Executor -from core.workflow.runtime import VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout +from dify_graph.nodes.http_request.exc import AuthorizationConfigError +from dify_graph.nodes.http_request.executor import Executor +from dify_graph.runtime import VariablePool +from dify_graph.system_variable import SystemVariable HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index cad0466809..5e34bf1d94 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -4,17 +4,16 @@ from typing import Any import httpx import pytest -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.file.file_manager import file_manager -from core.workflow.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig -from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout, Response -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.file.file_manager import file_manager +from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout, Response +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=10, @@ -99,11 +98,11 @@ def _build_http_node( ], "edges": [], } - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -162,7 +161,7 @@ def test_run_passes_node_data_ssl_verify_to_executor(monkeypatch: pytest.MonkeyP ) ) - monkeypatch.setattr("core.workflow.nodes.http_request.node.Executor", FakeExecutor) + monkeypatch.setattr("dify_graph.nodes.http_request.node.Executor", FakeExecutor) result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py index ca4a887d20..d4939b1071 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -1,5 +1,5 @@ -from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients -from core.workflow.runtime import VariablePool +from dify_graph.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients +from dify_graph.runtime import VariablePool def test_render_body_template_replaces_variable_values(): diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index bfe7b03c13..55aa62a1c0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -8,10 +8,11 @@ from unittest.mock import MagicMock import pytest from pydantic import ValidationError -from core.workflow.entities import GraphInitParams -from core.workflow.node_events import PauseRequestedEvent -from core.workflow.node_events.node import StreamCompletedEvent -from core.workflow.nodes.human_input.entities import ( +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.node_events import PauseRequestedEvent +from dify_graph.node_events.node import StreamCompletedEvent +from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, @@ -24,7 +25,7 @@ from core.workflow.nodes.human_input.entities import ( WebAppDeliveryMethod, _WebAppDeliveryConfig, ) -from core.workflow.nodes.human_input.enums import ( +from dify_graph.nodes.human_input.enums import ( ButtonStyle, DeliveryMethodType, EmailRecipientType, @@ -32,10 +33,10 @@ from core.workflow.nodes.human_input.enums import ( PlaceholderType, TimeoutUnit, ) -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.repositories.human_input_form_repository import HumanInputFormRepository +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository @@ -314,13 +315,17 @@ class TestHumanInputNodeVariableResolution: variable_pool.add(("start", "name"), "Jane Doe") runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -384,13 +389,17 @@ class TestHumanInputNodeVariableResolution: ) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -439,13 +448,17 @@ class TestHumanInputNodeVariableResolution: ) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user-123", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user-123", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -550,13 +563,17 @@ class TestHumanInputNodeRenderedContent: ) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index a19ee4dee3..1fea19e795 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -1,20 +1,19 @@ import datetime from types import SimpleNamespace -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.graph_init_params import GraphInitParams -from core.workflow.enums import NodeType -from core.workflow.graph_events import ( +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from dify_graph.enums import NodeType +from dify_graph.graph_events import ( NodeRunHumanInputFormFilledEvent, NodeRunHumanInputFormTimeoutEvent, NodeRunStartedEvent, ) -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.nodes.human_input.enums import HumanInputFormStatus +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now -from models.enums import UserFrom class _FakeFormRepository: @@ -32,13 +31,17 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# start_at=0.0, ) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) @@ -92,13 +95,17 @@ def _build_timeout_node() -> HumanInputNode: start_at=0.0, ) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py index d669cc7465..93c199514e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py @@ -1,4 +1,4 @@ -from core.workflow.nodes.iteration.entities import ( +from dify_graph.nodes.iteration.entities import ( ErrorHandleMode, IterationNodeData, IterationStartNodeData, diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py index b67e84d1d4..b95a7ad8ae 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py @@ -1,6 +1,6 @@ -from core.workflow.enums import NodeType -from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from core.workflow.nodes.iteration.exc import ( +from dify_graph.enums import NodeType +from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from dify_graph.nodes.iteration.exc import ( InvalidIteratorValueError, IterationGraphNotFoundError, IterationIndexNotFoundError, @@ -8,7 +8,7 @@ from core.workflow.nodes.iteration.exc import ( IteratorVariableNotFoundError, StartNodeIdNotFoundError, ) -from core.workflow.nodes.iteration.iteration_node import IterationNode +from dify_graph.nodes.iteration.iteration_node import IterationNode class TestIterationNodeExceptions: diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py new file mode 100644 index 0000000000..2eb4feef5f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -0,0 +1,100 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +import pytest + +from dify_graph.entities import GraphInitParams +from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError +from dify_graph.nodes.iteration.iteration_node import IterationNode +from dify_graph.runtime import ( + ChildEngineBuilderNotConfiguredError, + ChildGraphNotFoundError, + GraphRuntimeState, + VariablePool, +) +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params + + +class _MissingGraphBuilder: + def build_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> object: + raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") + + +def _build_runtime_state() -> GraphRuntimeState: + return GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable.default(), user_inputs={}), + start_at=0.0, + ) + + +def _build_iteration_node( + *, + graph_config: Mapping[str, Any], + runtime_state: GraphRuntimeState, + start_node_id: str, +) -> IterationNode: + init_params = build_test_graph_init_params(graph_config=graph_config) + return IterationNode( + id="iteration-node", + config={ + "id": "iteration-node", + "data": { + "type": "iteration", + "title": "Iteration", + "iterator_selector": ["start", "items"], + "output_selector": ["iteration-node", "output"], + "start_node_id": start_node_id, + }, + }, + graph_init_params=init_params, + graph_runtime_state=runtime_state, + ) + + +def test_graph_runtime_state_raises_specific_error_when_child_builder_is_missing(): + runtime_state = _build_runtime_state() + graph_init_params = build_test_graph_init_params() + + with pytest.raises(ChildEngineBuilderNotConfiguredError): + runtime_state.create_child_engine( + workflow_id="workflow", + graph_init_params=graph_init_params, + graph_runtime_state=_build_runtime_state(), + graph_config={}, + root_node_id="root", + ) + + +def test_iteration_node_only_translates_child_graph_not_found_error(): + runtime_state = _build_runtime_state() + runtime_state.bind_child_engine_builder(_MissingGraphBuilder()) + node = _build_iteration_node( + graph_config={"nodes": [{"id": "present-node"}], "edges": []}, + runtime_state=runtime_state, + start_node_id="missing-node", + ) + + with pytest.raises(IterationGraphNotFoundError): + node._create_graph_engine(index=0, item="item") + + +def test_iteration_node_propagates_non_graph_not_found_errors(): + runtime_state = _build_runtime_state() + node = _build_iteration_node( + graph_config={"nodes": [{"id": "start-node"}], "edges": []}, + runtime_state=runtime_state, + start_node_id="start-node", + ) + + with pytest.raises(ChildEngineBuilderNotConfiguredError): + node._create_graph_engine(index=0, item="item") diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index 38e434d7d8..8116fc8b3c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -4,28 +4,27 @@ from unittest.mock import Mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities import GraphInitParams -from core.workflow.enums import SystemVariableKey, WorkflowNodeExecutionStatus -from core.workflow.nodes.knowledge_index.entities import KnowledgeIndexNodeData -from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError -from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode -from core.workflow.repositories.index_processor_protocol import IndexProcessorProtocol, Preview, PreviewItem -from core.workflow.repositories.summary_index_service_protocol import SummaryIndexServiceProtocol -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables.segments import StringSegment -from models.enums import UserFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.enums import SystemVariableKey, WorkflowNodeExecutionStatus +from dify_graph.nodes.knowledge_index.entities import KnowledgeIndexNodeData +from dify_graph.nodes.knowledge_index.exc import KnowledgeIndexNodeError +from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode +from dify_graph.repositories.index_processor_protocol import IndexProcessorProtocol, Preview, PreviewItem +from dify_graph.repositories.summary_index_service_protocol import SummaryIndexServiceProtocol +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables.segments import StringSegment +from tests.workflow_test_utils import build_test_graph_init_params @pytest.fixture def mock_graph_init_params(): """Create mock GraphInitParams.""" - return GraphInitParams( - tenant_id=str(uuid.uuid4()), - app_id=str(uuid.uuid4()), + return build_test_graph_init_params( workflow_id=str(uuid.uuid4()), graph_config={}, + tenant_id=str(uuid.uuid4()), + app_id=str(uuid.uuid4()), user_id=str(uuid.uuid4()), user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index a60dde199d..e194d66ee3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -4,33 +4,32 @@ from unittest.mock import Mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.nodes.knowledge_retrieval.entities import ( +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.nodes.knowledge_retrieval.entities import ( KnowledgeRetrievalNodeData, MultipleRetrievalConfig, RerankingModelConfig, SingleRetrievalConfig, ) -from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from core.workflow.repositories.rag_retrieval_protocol import RAGRetrievalProtocol, Source -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import StringSegment -from models.enums import UserFrom +from dify_graph.nodes.knowledge_retrieval.exc import RateLimitExceededError +from dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from dify_graph.repositories.rag_retrieval_protocol import RAGRetrievalProtocol, Source +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import StringSegment +from tests.workflow_test_utils import build_test_graph_init_params @pytest.fixture def mock_graph_init_params(): """Create mock GraphInitParams.""" - return GraphInitParams( - tenant_id=str(uuid.uuid4()), - app_id=str(uuid.uuid4()), + return build_test_graph_init_params( workflow_id=str(uuid.uuid4()), graph_config={}, + tenant_id=str(uuid.uuid4()), + app_id=str(uuid.uuid4()), user_id=str(uuid.uuid4()), user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -155,7 +154,7 @@ class TestKnowledgeRetrievalNode: ): """Test _run with query variable in single mode.""" # Arrange - from core.workflow.nodes.llm.entities import ModelConfig + from dify_graph.nodes.llm.entities import ModelConfig query = "What is Python?" query_selector = ["start", "query"] @@ -444,7 +443,7 @@ class TestFetchDatasetRetriever: ): """Test _fetch_dataset_retriever in single mode.""" # Arrange - from core.workflow.nodes.llm.entities import ModelConfig + from dify_graph.nodes.llm.entities import ModelConfig query = "What is Python?" variables = {"query": query} diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index 63a87623da..25760ba352 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -1,14 +1,13 @@ from unittest.mock import MagicMock import pytest -from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams -from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.nodes.list_operator.node import ListOperatorNode -from core.workflow.variables import ArrayNumberSegment, ArrayStringSegment -from models.workflow import WorkflowType +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.nodes.list_operator.node import ListOperatorNode +from dify_graph.runtime import GraphRuntimeState +from dify_graph.variables import ArrayNumberSegment, ArrayStringSegment class TestListOperatorNode: @@ -22,43 +21,40 @@ class TestListOperatorNode: mock_state.variable_pool = mock_variable_pool return mock_state - @pytest.fixture - def mock_graph(self): - """Create mock Graph.""" - return MagicMock(spec=Graph) - @pytest.fixture def graph_init_params(self): """Create GraphInitParams fixture.""" return GraphInitParams( - tenant_id="test", - app_id="test", - workflow_type=WorkflowType.WORKFLOW, workflow_id="test", graph_config={}, - user_id="test", - user_from="test", - invoke_from="test", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": "test", + "invoke_from": "test", + } + }, call_depth=0, ) @pytest.fixture - def list_operator_node_factory(self, graph_init_params, mock_graph, mock_graph_runtime_state): + def list_operator_node_factory(self, graph_init_params, mock_graph_runtime_state): """Factory fixture for creating ListOperatorNode instances.""" def _create_node(config, mock_variable): mock_graph_runtime_state.variable_pool.get.return_value = mock_variable return ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) return _create_node - def test_node_initialization(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_node_initialization(self, mock_graph_runtime_state, graph_init_params): """Test node initializes correctly.""" config = { "title": "List Operator", @@ -70,9 +66,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -101,7 +96,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "banana", "cherry"] - def test_run_with_empty_array(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_empty_array(self, mock_graph_runtime_state, graph_init_params): """Test with empty array.""" config = { "title": "Test", @@ -116,9 +111,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -129,7 +123,7 @@ class TestListOperatorNode: assert result.outputs["first_record"] is None assert result.outputs["last_record"] is None - def test_run_with_filter_contains(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_contains(self, mock_graph_runtime_state, graph_init_params): """Test filter with contains condition.""" config = { "title": "Test", @@ -148,9 +142,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -159,7 +152,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "pineapple"] - def test_run_with_filter_not_contains(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_not_contains(self, mock_graph_runtime_state, graph_init_params): """Test filter with not contains condition.""" config = { "title": "Test", @@ -178,9 +171,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -189,7 +181,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["banana", "cherry"] - def test_run_with_number_filter_greater_than(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_filter_greater_than(self, mock_graph_runtime_state, graph_init_params): """Test filter with greater than condition on numbers.""" config = { "title": "Test", @@ -208,9 +200,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -219,7 +210,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [7, 9, 11] - def test_run_with_order_ascending(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_order_ascending(self, mock_graph_runtime_state, graph_init_params): """Test ordering in ascending order.""" config = { "title": "Test", @@ -237,9 +228,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -248,7 +238,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "banana", "cherry"] - def test_run_with_order_descending(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_order_descending(self, mock_graph_runtime_state, graph_init_params): """Test ordering in descending order.""" config = { "title": "Test", @@ -266,9 +256,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -277,7 +266,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["cherry", "banana", "apple"] - def test_run_with_limit(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_limit(self, mock_graph_runtime_state, graph_init_params): """Test with limit enabled.""" config = { "title": "Test", @@ -295,9 +284,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -306,7 +294,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "banana"] - def test_run_with_filter_order_and_limit(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_order_and_limit(self, mock_graph_runtime_state, graph_init_params): """Test with filter, order, and limit combined.""" config = { "title": "Test", @@ -331,9 +319,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -342,7 +329,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [9, 8, 7] - def test_run_with_variable_not_found(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_variable_not_found(self, mock_graph_runtime_state, graph_init_params): """Test when variable is not found.""" config = { "title": "Test", @@ -356,9 +343,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -367,7 +353,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Variable not found" in result.error - def test_run_with_first_and_last_record(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_first_and_last_record(self, mock_graph_runtime_state, graph_init_params): """Test first_record and last_record outputs.""" config = { "title": "Test", @@ -382,9 +368,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -394,7 +379,7 @@ class TestListOperatorNode: assert result.outputs["first_record"] == "first" assert result.outputs["last_record"] == "last" - def test_run_with_filter_startswith(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_startswith(self, mock_graph_runtime_state, graph_init_params): """Test filter with startswith condition.""" config = { "title": "Test", @@ -413,9 +398,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -424,7 +408,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "application"] - def test_run_with_filter_endswith(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_endswith(self, mock_graph_runtime_state, graph_init_params): """Test filter with endswith condition.""" config = { "title": "Test", @@ -443,9 +427,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -454,7 +437,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "pineapple", "table"] - def test_run_with_number_filter_equals(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_filter_equals(self, mock_graph_runtime_state, graph_init_params): """Test number filter with equals condition.""" config = { "title": "Test", @@ -473,9 +456,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -484,7 +466,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [5, 5] - def test_run_with_number_filter_not_equals(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_filter_not_equals(self, mock_graph_runtime_state, graph_init_params): """Test number filter with not equals condition.""" config = { "title": "Test", @@ -503,9 +485,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -514,7 +495,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [1, 3, 7, 9] - def test_run_with_number_order_ascending(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_order_ascending(self, mock_graph_runtime_state, graph_init_params): """Test number ordering in ascending order.""" config = { "title": "Test", @@ -532,9 +513,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py index 0677f1bb52..a3afd1ed5c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py @@ -9,8 +9,8 @@ from sqlalchemy import Engine from core.helper import ssrf_proxy from core.tools import signature from core.tools.tool_file_manager import ToolFileManager -from core.workflow.file import FileTransferMethod, FileType, models -from core.workflow.nodes.llm.file_saver import ( +from dify_graph.file import FileTransferMethod, FileType, models +from dify_graph.nodes.llm.file_saver import ( FileSaverImpl, _extract_content_type_and_extension, _get_extension, diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 94b5b72ee1..90308facc3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -5,13 +5,16 @@ from unittest import mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, SystemConfiguration from core.model_manager import ModelInstance -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.message_entities import ( +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from dify_graph.entities import GraphInitParams +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, @@ -19,13 +22,10 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.workflow.entities import GraphInitParams -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.nodes.llm import llm_utils -from core.workflow.nodes.llm.entities import ( +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from dify_graph.nodes.llm import llm_utils +from dify_graph.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, @@ -33,14 +33,14 @@ from core.workflow.nodes.llm.entities import ( VisionConfig, VisionConfigOptions, ) -from core.workflow.nodes.llm.file_saver import LLMFileSaver -from core.workflow.nodes.llm.node import LLMNode, _handle_memory_completion_mode -from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment -from models.enums import UserFrom +from dify_graph.nodes.llm.file_saver import LLMFileSaver +from dify_graph.nodes.llm.node import LLMNode, _handle_memory_completion_mode +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from models.provider import ProviderType +from tests.workflow_test_utils import build_test_graph_init_params class MockTokenBufferMemory: @@ -76,11 +76,11 @@ def llm_node_data() -> LLMNodeData: @pytest.fixture def graph_init_params() -> GraphInitParams: - return GraphInitParams( - tenant_id="1", - app_id="1", + return build_test_graph_init_params( workflow_id="1", graph_config={}, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.SERVICE_API, @@ -611,7 +611,7 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface(): window=MemoryConfig.WindowConfig(enabled=True, size=3), ) - with mock.patch("core.workflow.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token: + with mock.patch("dify_graph.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token: memory_text = _handle_memory_completion_mode( memory=memory, memory_config=memory_config, diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py index ac0c1df9c5..e40d565ef5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py @@ -2,10 +2,10 @@ from collections.abc import Mapping, Sequence from pydantic import BaseModel, Field -from core.model_runtime.entities.message_entities import PromptMessage -from core.model_runtime.entities.model_entities import ModelFeature -from core.workflow.file import File -from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage +from dify_graph.file import File +from dify_graph.model_runtime.entities.message_entities import PromptMessage +from dify_graph.model_runtime.entities.model_entities import ModelFeature +from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage class LLMNodeTestScenario(BaseModel): diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py index 2742b7dab0..fd48edc58c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py @@ -1,5 +1,5 @@ -from core.workflow.nodes.parameter_extractor.entities import ParameterConfig -from core.workflow.variables.types import SegmentType +from dify_graph.nodes.parameter_extractor.entities import ParameterConfig +from dify_graph.variables.types import SegmentType class TestParameterConfig: diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py index ae229bbe2e..7eca531b62 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py @@ -7,17 +7,17 @@ from typing import Any import pytest -from core.model_runtime.entities import LLMMode -from core.workflow.nodes.llm import ModelConfig, VisionConfig -from core.workflow.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData -from core.workflow.nodes.parameter_extractor.exc import ( +from dify_graph.model_runtime.entities import LLMMode +from dify_graph.nodes.llm import ModelConfig, VisionConfig +from dify_graph.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData +from dify_graph.nodes.parameter_extractor.exc import ( InvalidNumberOfParametersError, InvalidSelectValueError, InvalidValueTypeError, RequiredParameterMissingError, ) -from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from core.workflow.variables.types import SegmentType +from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from dify_graph.variables.types import SegmentType from factories.variable_factory import build_segment_with_type diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py index 5eb302798f..e57ebbd83e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py @@ -1,8 +1,8 @@ import pytest from pydantic import ValidationError -from core.workflow.enums import ErrorStrategy -from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData +from dify_graph.enums import ErrorStrategy +from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData class TestTemplateTransformNodeData: diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index 0fb76fb7e7..6831626f58 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -1,14 +1,14 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest -from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams -from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus -from core.workflow.nodes.template_transform.template_renderer import TemplateRenderError -from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from models.workflow import WorkflowType +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from dify_graph.graph import Graph +from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError +from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode +from dify_graph.runtime import GraphRuntimeState +from tests.workflow_test_utils import build_test_graph_init_params class TestTemplateTransformNode: @@ -24,21 +24,20 @@ class TestTemplateTransformNode: @pytest.fixture def mock_graph(self): - """Create a mock Graph.""" + """Create a mock Graph (kept for backward compat in other tests).""" return MagicMock(spec=Graph) @pytest.fixture def graph_init_params(self): """Create a mock GraphInitParams.""" - return GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", - workflow_type=WorkflowType.WORKFLOW, + return build_test_graph_init_params( workflow_id="test_workflow", graph_config={}, + tenant_id="test_tenant", + app_id="test_app", user_id="test_user", - user_from="test", - invoke_from="test", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, call_depth=0, ) @@ -55,14 +54,15 @@ class TestTemplateTransformNode: "template": "Hello {{ name }}, you are {{ age }} years old!", } - def test_node_initialization(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_node_initialization(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test that TemplateTransformNode initializes correctly.""" + mock_renderer = MagicMock() node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) assert node.node_type == NodeType.TEMPLATE_TRANSFORM @@ -70,31 +70,33 @@ class TestTemplateTransformNode: assert len(node._node_data.variables) == 2 assert node._node_data.template == "Hello {{ name }}, you are {{ age }} years old!" - def test_get_title(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_get_title(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _get_title method.""" + mock_renderer = MagicMock() node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) assert node._get_title() == "Template Transform" - def test_get_description(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_get_description(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _get_description method.""" + mock_renderer = MagicMock() node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) assert node._get_description() == "Transform data using template" - def test_get_error_strategy(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_get_error_strategy(self, mock_graph_runtime_state, graph_init_params): """Test _get_error_strategy method.""" node_data = { "title": "Test", @@ -103,12 +105,13 @@ class TestTemplateTransformNode: "error_strategy": "fail-branch", } + mock_renderer = MagicMock() node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) assert node._get_error_strategy() == ErrorStrategy.FAIL_BRANCH @@ -127,14 +130,8 @@ class TestTemplateTransformNode: """Test version class method.""" assert TemplateTransformNode.version() == "1" - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_simple_template( - self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params - ): - """Test _run with simple template transformation.""" + def test_run_simple_template(self, basic_node_data, mock_graph_runtime_state, graph_init_params): + """Test _run with simple template transformation using injected renderer.""" # Setup mock variable pool mock_name_value = MagicMock() mock_name_value.to_object.return_value = "Alice" @@ -147,15 +144,16 @@ class TestTemplateTransformNode: } mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector)) - # Setup mock executor - mock_execute.return_value = "Hello Alice, you are 30 years old!" + # Setup mock renderer + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "Hello Alice, you are 30 years old!" node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -165,11 +163,7 @@ class TestTemplateTransformNode: assert result.inputs["name"] == "Alice" assert result.inputs["age"] == 30 - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_none_values(self, mock_graph_runtime_state, graph_init_params): """Test _run with None variable values.""" node_data = { "title": "Test", @@ -178,14 +172,16 @@ class TestTemplateTransformNode: } mock_graph_runtime_state.variable_pool.get.return_value = None - mock_execute.return_value = "Value: " + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "Value: " node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -193,23 +189,19 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.inputs["value"] is None - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_code_execution_error( - self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params - ): - """Test _run when code execution fails.""" + def test_run_with_render_error(self, basic_node_data, mock_graph_runtime_state, graph_init_params): + """Test _run when template rendering fails.""" mock_graph_runtime_state.variable_pool.get.return_value = MagicMock() - mock_execute.side_effect = TemplateRenderError("Template syntax error") + + mock_renderer = MagicMock() + mock_renderer.render_template.side_effect = TemplateRenderError("Template syntax error") node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -217,23 +209,19 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Template syntax error" in result.error - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_output_length_exceeds_limit( - self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params - ): + def test_run_output_length_exceeds_limit(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _run when output exceeds maximum length.""" mock_graph_runtime_state.variable_pool.get.return_value = MagicMock() - mock_execute.return_value = "This is a very long output that exceeds the limit" + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "This is a very long output that exceeds the limit" node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, max_output_length=10, ) @@ -242,13 +230,7 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Output length exceeds" in result.error - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_complex_jinja2_template( - self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params - ): + def test_run_with_complex_jinja2_template(self, mock_graph_runtime_state, graph_init_params): """Test _run with complex Jinja2 template including loops and conditions.""" node_data = { "title": "Complex Template", @@ -272,14 +254,16 @@ class TestTemplateTransformNode: ("sys", "show_total"): mock_show_total, } mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector)) - mock_execute.return_value = "apple, banana, orange (Total: 3)" + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "apple, banana, orange (Total: 3)" node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -307,11 +291,7 @@ class TestTemplateTransformNode: assert mapping["node_123.var1"] == ["sys", "input1"] assert mapping["node_123.var2"] == ["sys", "input2"] - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_empty_variables(self, mock_graph_runtime_state, graph_init_params): """Test _run with no variables (static template).""" node_data = { "title": "Static Template", @@ -319,14 +299,15 @@ class TestTemplateTransformNode: "template": "This is a static message.", } - mock_execute.return_value = "This is a static message." + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "This is a static message." node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -335,11 +316,7 @@ class TestTemplateTransformNode: assert result.outputs["output"] == "This is a static message." assert result.inputs == {} - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_numeric_values(self, mock_graph_runtime_state, graph_init_params): """Test _run with numeric variable values.""" node_data = { "title": "Numeric Template", @@ -360,14 +337,16 @@ class TestTemplateTransformNode: ("sys", "quantity"): mock_quantity, } mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector)) - mock_execute.return_value = "Total: $31.5" + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "Total: $31.5" node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -375,11 +354,7 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["output"] == "Total: $31.5" - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_dict_values(self, mock_graph_runtime_state, graph_init_params): """Test _run with dictionary variable values.""" node_data = { "title": "Dict Template", @@ -391,14 +366,16 @@ class TestTemplateTransformNode: mock_user.to_object.return_value = {"name": "John Doe", "email": "john@example.com"} mock_graph_runtime_state.variable_pool.get.return_value = mock_user - mock_execute.return_value = "Name: John Doe, Email: john@example.com" + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "Name: John Doe, Email: john@example.com" node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -407,11 +384,7 @@ class TestTemplateTransformNode: assert "John Doe" in result.outputs["output"] assert "john@example.com" in result.outputs["output"] - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_list_values(self, mock_graph_runtime_state, graph_init_params): """Test _run with list variable values.""" node_data = { "title": "List Template", @@ -423,14 +396,16 @@ class TestTemplateTransformNode: mock_tags.to_object.return_value = ["python", "ai", "workflow"] mock_graph_runtime_state.variable_pool.get.return_value = mock_tags - mock_execute.return_value = "Tags: #python #ai #workflow " + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "Tags: #python #ai #workflow " node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 1854cca236..44abf430c0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -2,12 +2,14 @@ from collections.abc import Mapping import pytest -from core.workflow.entities import GraphInitParams -from core.workflow.enums import NodeType -from core.workflow.nodes.base.entities import BaseNodeData -from core.workflow.nodes.base.node import Node -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities import GraphInitParams +from dify_graph.enums import NodeType +from dify_graph.nodes.base.entities import BaseNodeData +from dify_graph.nodes.base.node import Node +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params class _SampleNodeData(BaseNodeData): @@ -26,15 +28,10 @@ class _SampleNode(Node[_SampleNodeData]): def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]: - init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", + init_params = build_test_graph_init_params( graph_config=graph_config, - user_id="user", user_from="account", invoke_from="debugger", - call_depth=0, ) runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), @@ -56,6 +53,36 @@ def test_node_hydrates_data_during_initialization(): assert node.node_data.foo == "bar" assert node.title == "Sample" + dify_ctx = node.require_dify_context() + assert dify_ctx.user_from == "account" + assert dify_ctx.invoke_from == "debugger" + + +def test_node_accepts_invoke_from_enum(): + graph_config: dict[str, object] = {} + init_params = build_test_graph_init_params( + graph_config=graph_config, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + start_at=0.0, + ) + + node = _SampleNode( + id="node-1", + config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}}, + graph_init_params=init_params, + graph_runtime_state=runtime_state, + ) + + dify_ctx = node.require_dify_context() + assert dify_ctx.user_from == UserFrom.ACCOUNT + assert dify_ctx.invoke_from == InvokeFrom.DEBUGGER + assert node.get_run_context_value("missing") is None + with pytest.raises(ValueError): + node.require_run_context_value("missing") def test_missing_generic_argument_raises_type_error(): diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 35c59b92c4..5e20b1e12f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -5,31 +5,31 @@ import pandas as pd import pytest from docx.oxml.text.paragraph import CT_P -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities import GraphInitParams -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.file import File, FileTransferMethod -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData -from core.workflow.nodes.document_extractor.node import ( +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities import GraphInitParams +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus +from dify_graph.file import File, FileTransferMethod +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData +from dify_graph.nodes.document_extractor.node import ( _extract_text_from_docx, _extract_text_from_excel, _extract_text_from_pdf, _extract_text_from_plain_text, ) -from core.workflow.variables import ArrayFileSegment -from core.workflow.variables.segments import ArrayStringSegment -from core.workflow.variables.variables import StringVariable -from models.enums import UserFrom +from dify_graph.variables import ArrayFileSegment +from dify_graph.variables.segments import ArrayStringSegment +from dify_graph.variables.variables import StringVariable +from tests.workflow_test_utils import build_test_graph_init_params @pytest.fixture def graph_init_params() -> GraphInitParams: - return GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", + return build_test_graph_init_params( workflow_id="test_workflow", graph_config={}, + tenant_id="test_tenant", + app_id="test_app", user_id="test_user", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -44,11 +44,13 @@ def document_extractor_node(graph_init_params): variable_selector=["node_id", "variable_name"], ) node_config = {"id": "test_node_id", "data": node_data.model_dump()} + http_client = Mock() node = DocumentExtractorNode( id="test_node_id", config=node_config, graph_init_params=graph_init_params, graph_runtime_state=Mock(), + http_client=http_client, ) return node @@ -142,19 +144,20 @@ def test_run_extract_text( mock_graph_runtime_state.variable_pool.get.return_value = mock_array_file_segment mock_download = Mock(return_value=file_content) - mock_ssrf_proxy_get = Mock() - mock_ssrf_proxy_get.return_value.content = file_content - mock_ssrf_proxy_get.return_value.raise_for_status = Mock() - monkeypatch.setattr("core.workflow.file.file_manager.download", mock_download) - monkeypatch.setattr("core.helper.ssrf_proxy.get", mock_ssrf_proxy_get) + mock_response = Mock() + mock_response.content = file_content + mock_response.raise_for_status = Mock() + document_extractor_node._http_client.get = Mock(return_value=mock_response) + + monkeypatch.setattr("dify_graph.file.file_manager.download", mock_download) if mime_type == "application/pdf": mock_pdf_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract) + monkeypatch.setattr("dify_graph.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract) elif mime_type.startswith("application/vnd.openxmlformats"): mock_docx_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract) + monkeypatch.setattr("dify_graph.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract) result = document_extractor_node._run() @@ -164,7 +167,7 @@ def test_run_extract_text( assert result.outputs["text"] == ArrayStringSegment(value=expected_text) if transfer_method == FileTransferMethod.REMOTE_URL: - mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt") + document_extractor_node._http_client.get.assert_called_once_with("https://example.com/file.txt") elif transfer_method == FileTransferMethod.LOCAL_FILE: mock_download.assert_called_once_with(mock_file) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index bc87a64161..041bd66d03 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -4,30 +4,30 @@ from unittest.mock import MagicMock, Mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.graph import Graph -from core.workflow.nodes.if_else.entities import IfElseNodeData -from core.workflow.nodes.if_else.if_else_node import IfElseNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition -from core.workflow.variables import ArrayFileSegment +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.graph import Graph +from dify_graph.nodes.if_else.entities import IfElseNodeData +from dify_graph.nodes.if_else.if_else_node import IfElseNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.utils.condition.entities import Condition, SubCondition, SubVariableCondition +from dify_graph.variables import ArrayFileSegment from extensions.ext_database import db -from models.enums import UserFrom +from tests.workflow_test_utils import build_test_graph_init_params def test_execute_if_else_result_true(): graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -129,11 +129,11 @@ def test_execute_if_else_result_false(): # Create a simple graph for IfElse node testing graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -230,14 +230,18 @@ def test_array_file_contains_file_name(): # Create properly configured mock for graph_init_params graph_init_params = Mock() - graph_init_params.tenant_id = "test_tenant" - graph_init_params.app_id = "test_app" graph_init_params.workflow_id = "test_workflow" graph_init_params.graph_config = {} - graph_init_params.user_id = "test_user" - graph_init_params.user_from = UserFrom.ACCOUNT - graph_init_params.invoke_from = InvokeFrom.SERVICE_API graph_init_params.call_depth = 0 + graph_init_params.run_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + } node = IfElseNode( id=str(uuid.uuid4()), @@ -299,11 +303,11 @@ def test_execute_if_else_boolean_conditions(condition: Condition): """Test IfElseNode with boolean conditions using various operators""" graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -354,11 +358,11 @@ def test_execute_if_else_boolean_false_conditions(): """Test IfElseNode with boolean conditions that should evaluate to false""" graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -423,11 +427,11 @@ def test_execute_if_else_boolean_cases_structure(): """Test IfElseNode with boolean conditions using the new cases structure""" graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 73c17ee45a..6ca72b64b2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -2,10 +2,11 @@ from unittest.mock import MagicMock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.nodes.list_operator.entities import ( +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.nodes.list_operator.entities import ( ExtractConfig, FilterBy, FilterCondition, @@ -14,10 +15,9 @@ from core.workflow.nodes.list_operator.entities import ( Order, OrderByConfig, ) -from core.workflow.nodes.list_operator.exc import InvalidKeyError -from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func -from core.workflow.variables import ArrayFileSegment -from models.enums import UserFrom +from dify_graph.nodes.list_operator.exc import InvalidKeyError +from dify_graph.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func +from dify_graph.variables import ArrayFileSegment @pytest.fixture @@ -42,14 +42,18 @@ def list_operator_node(): } # Create properly configured mock for graph_init_params graph_init_params = MagicMock() - graph_init_params.tenant_id = "test_tenant" - graph_init_params.app_id = "test_app" graph_init_params.workflow_id = "test_workflow" graph_init_params.graph_config = {} - graph_init_params.user_id = "test_user" - graph_init_params.user_from = UserFrom.ACCOUNT - graph_init_params.invoke_from = InvokeFrom.SERVICE_API graph_init_params.call_depth = 0 + graph_init_params.run_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + } node = ListOperatorNode( id="test_node_id", diff --git a/api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py b/api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py index 0774348ac6..27d3848fb4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py @@ -3,13 +3,13 @@ from collections.abc import Generator from typing import Any import pytest - from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.entities import ToolCallResult from core.workflow.entities.tool_entities import ToolResultStatus -from core.workflow.node_events import ModelInvokeCompletedEvent, NodeEventBase from core.workflow.nodes.llm.node import LLMNode +from core.workflow.entities import ToolCallResult +from core.workflow.node_events import ModelInvokeCompletedEvent, NodeEventBase + class _StubModelInstance: """Minimal stub to satisfy _stream_llm_events signature.""" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py index 47ef289ef3..4dfec5ef60 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py @@ -1,5 +1,5 @@ -from core.model_runtime.entities import ImagePromptMessageContent -from core.workflow.nodes.question_classifier import QuestionClassifierNodeData +from dify_graph.model_runtime.entities import ImagePromptMessageContent +from dify_graph.nodes.question_classifier import QuestionClassifierNodeData def test_init_question_classifier_node_data(): diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 8c7dc24868..b8f0e25e91 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -4,12 +4,12 @@ import time import pytest from pydantic import ValidationError as PydanticValidationError -from core.workflow.entities import GraphInitParams -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables.input_entities import VariableEntity, VariableEntityType +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from tests.workflow_test_utils import build_test_graph_init_params def make_start_node(user_inputs, variables): @@ -32,11 +32,11 @@ def make_start_node(user_inputs, variables): return StartNode( id="start", config=config, - graph_init_params=GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, + tenant_id="tenant", + app_id="app", user_id="u", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 678691439f..11554169e1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -8,18 +8,18 @@ from unittest.mock import MagicMock, patch import pytest -from core.model_runtime.entities.llm_entities import LLMUsage from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.workflow.entities import GraphInitParams -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables.segments import ArrayFileSegment +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables.segments import ArrayFileSegment +from tests.workflow_test_utils import build_test_graph_init_params if TYPE_CHECKING: # pragma: no cover - imported for type checking only - from core.workflow.nodes.tool.tool_node import ToolNode + from dify_graph.nodes.tool.tool_node import ToolNode @pytest.fixture @@ -31,7 +31,7 @@ def tool_node(monkeypatch) -> ToolNode: ops_stub.TraceTask = object # pragma: no cover - stub attribute monkeypatch.setitem(sys.modules, module_name, ops_stub) - from core.workflow.nodes.tool.tool_node import ToolNode + from dify_graph.nodes.tool.tool_node import ToolNode graph_config: dict[str, Any] = { "nodes": [ @@ -54,11 +54,11 @@ def tool_node(monkeypatch) -> ToolNode: "edges": [], } - init_params = GraphInitParams( - tenant_id="tenant-id", - app_id="app-id", + init_params = build_test_graph_init_params( workflow_id="workflow-id", graph_config=graph_config, + tenant_id="tenant-id", + app_id="app-id", user_id="user-id", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index 8a52f963ef..2cd3a38fa6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -2,18 +2,18 @@ import time import uuid from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.graph import Graph -from core.workflow.graph_events.node import NodeRunSucceededEvent -from core.workflow.nodes.variable_assigner.common import helpers as common_helpers -from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode -from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import ArrayStringVariable, StringVariable -from models.enums import UserFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.graph import Graph +from dify_graph.graph_events.node import NodeRunSucceededEvent +from dify_graph.nodes.variable_assigner.common import helpers as common_helpers +from dify_graph.nodes.variable_assigner.v1 import VariableAssignerNode +from dify_graph.nodes.variable_assigner.v1.node_data import WriteMode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import ArrayStringVariable, StringVariable DEFAULT_NODE_ID = "node_id" @@ -43,13 +43,17 @@ def test_overwrite_string_variable(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -141,13 +145,17 @@ def test_append_variable_to_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -236,13 +244,17 @@ def test_clear_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py index 9a874337ed..a7673c5a14 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py @@ -1,6 +1,6 @@ -from core.workflow.nodes.variable_assigner.v2.enums import Operation -from core.workflow.nodes.variable_assigner.v2.helpers import is_input_value_valid -from core.workflow.variables import SegmentType +from dify_graph.nodes.variable_assigner.v2.enums import Operation +from dify_graph.nodes.variable_assigner.v2.helpers import is_input_value_valid +from dify_graph.variables import SegmentType def test_is_input_value_valid_overwrite_array_string(): diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index 5ed68fe8d0..5b285c2681 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -2,16 +2,16 @@ import time import uuid from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.graph import Graph -from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode -from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import ArrayStringVariable -from models.enums import UserFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.graph import Graph +from dify_graph.nodes.variable_assigner.v2 import VariableAssignerNode +from dify_graph.nodes.variable_assigner.v2.enums import InputType, Operation +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import ArrayStringVariable DEFAULT_NODE_ID = "node_id" @@ -85,13 +85,17 @@ def test_remove_first_from_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -169,13 +173,17 @@ def test_remove_last_from_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -250,13 +258,17 @@ def test_remove_first_from_empty_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -331,13 +343,17 @@ def test_remove_last_from_empty_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -404,13 +420,17 @@ def test_node_factory_creates_variable_assigner_node(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) variable_pool = VariablePool( diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py index 4fa9a01b61..410c4993e4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from core.workflow.nodes.trigger_webhook.entities import ( +from dify_graph.nodes.trigger_webhook.entities import ( ContentType, Method, WebhookBodyParameter, @@ -297,7 +297,7 @@ def test_webhook_body_parameter_edge_cases(): def test_webhook_data_inheritance(): """Test WebhookData inherits from BaseNodeData correctly.""" - from core.workflow.nodes.base import BaseNodeData + from dify_graph.nodes.base import BaseNodeData # Test that WebhookData is a subclass of BaseNodeData assert issubclass(WebhookData, BaseNodeData) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py index 374d5183c8..f2273e441e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py @@ -1,7 +1,7 @@ import pytest -from core.workflow.nodes.base.exc import BaseNodeError -from core.workflow.nodes.trigger_webhook.exc import ( +from dify_graph.nodes.base.exc import BaseNodeError +from dify_graph.nodes.trigger_webhook.exc import ( WebhookConfigError, WebhookNodeError, WebhookNotFoundError, @@ -149,7 +149,7 @@ def test_webhook_error_attributes(): assert WebhookConfigError.__name__ == "WebhookConfigError" # Test that all error classes have proper __module__ - expected_module = "core.workflow.nodes.trigger_webhook.exc" + expected_module = "dify_graph.nodes.trigger_webhook.exc" assert WebhookNodeError.__module__ == expected_module assert WebhookTimeoutError.__module__ == expected_module assert WebhookNotFoundError.__module__ == expected_module diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index d8f6b41f89..c750e74182 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -8,21 +8,19 @@ when passing files to downstream LLM nodes. from unittest.mock import Mock, patch -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.graph_init_params import GraphInitParams -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.nodes.trigger_webhook.entities import ( +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.nodes.trigger_webhook.entities import ( ContentType, Method, WebhookBodyParameter, WebhookData, ) -from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode -from core.workflow.runtime.graph_runtime_state import GraphRuntimeState -from core.workflow.runtime.variable_pool import VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom -from models.workflow import WorkflowType +from dify_graph.nodes.trigger_webhook.node import TriggerWebhookNode +from dify_graph.runtime.graph_runtime_state import GraphRuntimeState +from dify_graph.runtime.variable_pool import VariablePool +from dify_graph.system_variable import SystemVariable def create_webhook_node( @@ -37,14 +35,17 @@ def create_webhook_node( } graph_init_params = GraphInitParams( - tenant_id=tenant_id, - app_id="test-app", - workflow_type=WorkflowType.WORKFLOW, workflow_id="test-workflow", graph_config={}, - user_id="test-user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": tenant_id, + "app_id": "test-app", + "user_id": "test-user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) @@ -129,8 +130,8 @@ def test_webhook_node_file_conversion_to_file_variable(): # Mock the file factory and variable factory with ( patch("factories.file_factory.build_from_mapping") as mock_file_factory, - patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, - patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + patch("dify_graph.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("dify_graph.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): # Setup mocks mock_file_obj = Mock() @@ -321,8 +322,8 @@ def test_webhook_node_file_conversion_mixed_parameters(): with ( patch("factories.file_factory.build_from_mapping") as mock_file_factory, - patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, - patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + patch("dify_graph.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("dify_graph.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): # Setup mocks for file mock_file_obj = Mock() @@ -389,8 +390,8 @@ def test_webhook_node_different_file_types(): with ( patch("factories.file_factory.build_from_mapping") as mock_file_factory, - patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, - patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + patch("dify_graph.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("dify_graph.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): # Setup mocks for all files mock_file_objs = [Mock() for _ in range(3)] diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 24d3740b99..df13bbb92f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -2,24 +2,22 @@ from unittest.mock import patch import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.graph_init_params import GraphInitParams -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.nodes.trigger_webhook.entities import ( +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.nodes.trigger_webhook.entities import ( ContentType, Method, WebhookBodyParameter, WebhookData, WebhookParameter, ) -from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode -from core.workflow.runtime.graph_runtime_state import GraphRuntimeState -from core.workflow.runtime.variable_pool import VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import FileVariable, StringVariable -from models.enums import UserFrom -from models.workflow import WorkflowType +from dify_graph.nodes.trigger_webhook.node import TriggerWebhookNode +from dify_graph.runtime.graph_runtime_state import GraphRuntimeState +from dify_graph.runtime.variable_pool import VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import FileVariable, StringVariable def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode: @@ -30,14 +28,17 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) } graph_init_params = GraphInitParams( - tenant_id="1", - app_id="1", - workflow_type=WorkflowType.WORKFLOW, workflow_id="1", graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) runtime_state = GraphRuntimeState( diff --git a/api/tests/unit_tests/core/workflow/test_enums.py b/api/tests/unit_tests/core/workflow/test_enums.py index 078ec5f6ab..e8ce6f60f7 100644 --- a/api/tests/unit_tests/core/workflow/test_enums.py +++ b/api/tests/unit_tests/core/workflow/test_enums.py @@ -1,6 +1,6 @@ """Tests for workflow pause related enums and constants.""" -from core.workflow.enums import ( +from dify_graph.enums import ( WorkflowExecutionStatus, ) diff --git a/api/tests/unit_tests/core/workflow/test_system_variable.py b/api/tests/unit_tests/core/workflow/test_system_variable.py index 93e7c9f68d..8023a0b594 100644 --- a/api/tests/unit_tests/core/workflow/test_system_variable.py +++ b/api/tests/unit_tests/core/workflow/test_system_variable.py @@ -4,9 +4,9 @@ from typing import Any import pytest from pydantic import ValidationError -from core.workflow.file.enums import FileTransferMethod, FileType -from core.workflow.file.models import File -from core.workflow.system_variable import SystemVariable +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.file.models import File +from dify_graph.system_variable import SystemVariable # Test data constants for SystemVariable serialization tests VALID_BASE_DATA: dict[str, Any] = { diff --git a/api/tests/unit_tests/core/workflow/test_system_variable_read_only_view.py b/api/tests/unit_tests/core/workflow/test_system_variable_read_only_view.py index 743fecaed0..b7a8f2551d 100644 --- a/api/tests/unit_tests/core/workflow/test_system_variable_read_only_view.py +++ b/api/tests/unit_tests/core/workflow/test_system_variable_read_only_view.py @@ -2,8 +2,8 @@ from typing import cast import pytest -from core.workflow.file.models import File, FileTransferMethod, FileType -from core.workflow.system_variable import SystemVariable, SystemVariableReadOnlyView +from dify_graph.file.models import File, FileTransferMethod, FileType +from dify_graph.system_variable import SystemVariable, SystemVariableReadOnlyView class TestSystemVariableReadOnlyView: diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index 7f2b080498..0fa0d26114 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -3,12 +3,12 @@ from collections import defaultdict import pytest -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.runtime import VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import FileSegment, StringSegment -from core.workflow.variables.segments import ( +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.runtime import VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import FileSegment, StringSegment +from dify_graph.variables.segments import ( ArrayAnySegment, ArrayFileSegment, ArrayNumberSegment, @@ -19,7 +19,7 @@ from core.workflow.variables.segments import ( NoneSegment, ObjectSegment, ) -from core.workflow.variables.variables import ( +from dify_graph.variables.variables import ( ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 4a71692f1e..0aa6ec3f45 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -4,18 +4,18 @@ import pytest from configs import dify_config from core.helper.code_executor.code_executor import CodeLanguage -from core.workflow.constants import ( +from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, ) -from core.workflow.file.enums import FileType -from core.workflow.file.models import File, FileTransferMethod -from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.code.limits import CodeNodeLimits -from core.workflow.runtime import VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables.variables import StringVariable -from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.file.enums import FileType +from dify_graph.file.models import File, FileTransferMethod +from dify_graph.nodes.code.code_node import CodeNode +from dify_graph.nodes.code.limits import CodeNodeLimits +from dify_graph.runtime import VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables.variables import StringVariable @pytest.fixture(autouse=True) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py index 12b9bf5f14..9969c953e8 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py @@ -2,11 +2,10 @@ from unittest.mock import MagicMock, patch -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.workflow_entry import WorkflowEntry -from models.enums import UserFrom +from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel +from dify_graph.runtime import GraphRuntimeState, VariablePool class TestWorkflowEntryRedisChannel: diff --git a/api/tests/unit_tests/core/workflow/utils/test_condition.py b/api/tests/unit_tests/core/workflow/utils/test_condition.py index efedf88726..324ad5f674 100644 --- a/api/tests/unit_tests/core/workflow/utils/test_condition.py +++ b/api/tests/unit_tests/core/workflow/utils/test_condition.py @@ -1,6 +1,6 @@ -from core.workflow.runtime import VariablePool -from core.workflow.utils.condition.entities import Condition -from core.workflow.utils.condition.processor import ConditionProcessor +from dify_graph.runtime import VariablePool +from dify_graph.utils.condition.entities import Condition +from dify_graph.utils.condition.processor import ConditionProcessor def test_number_formatting(): diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py index 83867e22e4..40df9de7fa 100644 --- a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py +++ b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py @@ -1,7 +1,7 @@ import dataclasses -from core.workflow.nodes.base import variable_template_parser -from core.workflow.nodes.base.entities import VariableSelector +from dify_graph.nodes.base import variable_template_parser +from dify_graph.nodes.base.entities import VariableSelector def test_extract_selectors_from_template(): diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py index 87d02cb187..ce6b9232ce 100644 --- a/api/tests/unit_tests/factories/test_variable_factory.py +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -7,8 +7,8 @@ import pytest from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.variables import ( +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.variables import ( ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, @@ -17,8 +17,8 @@ from core.workflow.variables import ( SecretVariable, StringVariable, ) -from core.workflow.variables.exc import VariableError -from core.workflow.variables.segments import ( +from dify_graph.variables.exc import VariableError +from dify_graph.variables.segments import ( ArrayAnySegment, ArrayFileSegment, ArrayNumberSegment, @@ -33,7 +33,7 @@ from core.workflow.variables.segments import ( Segment, StringSegment, ) -from core.workflow.variables.types import SegmentType +from dify_graph.variables.types import SegmentType from factories import variable_factory from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type diff --git a/api/tests/unit_tests/libs/_human_input/support.py b/api/tests/unit_tests/libs/_human_input/support.py index bd86c13a2c..3fff54f487 100644 --- a/api/tests/unit_tests/libs/_human_input/support.py +++ b/api/tests/unit_tests/libs/_human_input/support.py @@ -4,8 +4,8 @@ from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import Any -from core.workflow.nodes.human_input.entities import FormInput -from core.workflow.nodes.human_input.enums import TimeoutUnit +from dify_graph.nodes.human_input.entities import FormInput +from dify_graph.nodes.human_input.enums import TimeoutUnit # Exceptions diff --git a/api/tests/unit_tests/libs/_human_input/test_form_service.py b/api/tests/unit_tests/libs/_human_input/test_form_service.py index 15e7d41e85..82598c5c6d 100644 --- a/api/tests/unit_tests/libs/_human_input/test_form_service.py +++ b/api/tests/unit_tests/libs/_human_input/test_form_service.py @@ -6,11 +6,11 @@ from datetime import datetime, timedelta import pytest -from core.workflow.nodes.human_input.entities import ( +from dify_graph.nodes.human_input.entities import ( FormInput, UserAction, ) -from core.workflow.nodes.human_input.enums import ( +from dify_graph.nodes.human_input.enums import ( FormInputType, TimeoutUnit, ) diff --git a/api/tests/unit_tests/libs/_human_input/test_models.py b/api/tests/unit_tests/libs/_human_input/test_models.py index 962eeb9e11..5d14b5eb4e 100644 --- a/api/tests/unit_tests/libs/_human_input/test_models.py +++ b/api/tests/unit_tests/libs/_human_input/test_models.py @@ -6,11 +6,11 @@ from datetime import datetime, timedelta import pytest -from core.workflow.nodes.human_input.entities import ( +from dify_graph.nodes.human_input.entities import ( FormInput, UserAction, ) -from core.workflow.nodes.human_input.enums import ( +from dify_graph.nodes.human_input.enums import ( FormInputType, TimeoutUnit, ) diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py new file mode 100644 index 0000000000..248aa0b145 --- /dev/null +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py @@ -0,0 +1,145 @@ +import time + +import pytest + +from libs.broadcast_channel.redis.streams_channel import ( + StreamsBroadcastChannel, + StreamsTopic, + _StreamsSubscription, +) + + +class FakeStreamsRedis: + """Minimal in-memory Redis Streams stub for unit tests. + + - Stores entries per key as [(id, {b"data": bytes}), ...] + - xadd appends entries and returns an auto-increment id like "1-0" + - xread returns entries strictly greater than last_id + - expire is recorded but has no effect on behavior + """ + + def __init__(self) -> None: + self._store: dict[str, list[tuple[str, dict]]] = {} + self._next_id: dict[str, int] = {} + self._expire_calls: dict[str, int] = {} + + # Publisher API + def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str: + """Append entry to stream; accept optional maxlen for API compatibility. + + The test double ignores maxlen trimming semantics; only records the entry. + """ + n = self._next_id.get(key, 0) + 1 + self._next_id[key] = n + entry_id = f"{n}-0" + self._store.setdefault(key, []).append((entry_id, fields)) + return entry_id + + def expire(self, key: str, seconds: int) -> None: + self._expire_calls[key] = self._expire_calls.get(key, 0) + 1 + + # Consumer API + def xread(self, streams: dict, block: int | None = None, count: int | None = None): + # Expect a single key + assert len(streams) == 1 + key, last_id = next(iter(streams.items())) + entries = self._store.get(key, []) + + # Find position strictly greater than last_id + start_idx = 0 + if last_id != "0-0": + for i, (eid, _f) in enumerate(entries): + if eid == last_id: + start_idx = i + 1 + break + if start_idx >= len(entries): + # Simulate blocking wait (bounded) if requested + if block and block > 0: + time.sleep(min(0.01, block / 1000.0)) + return [] + + end_idx = len(entries) if count is None else min(len(entries), start_idx + count) + batch = entries[start_idx:end_idx] + return [(key, batch)] + + +@pytest.fixture +def fake_redis() -> FakeStreamsRedis: + return FakeStreamsRedis() + + +@pytest.fixture +def streams_channel(fake_redis: FakeStreamsRedis) -> StreamsBroadcastChannel: + return StreamsBroadcastChannel(fake_redis, retention_seconds=60) + + +class TestStreamsBroadcastChannel: + def test_topic_creation(self, streams_channel: StreamsBroadcastChannel, fake_redis: FakeStreamsRedis): + topic = streams_channel.topic("alpha") + assert isinstance(topic, StreamsTopic) + assert topic._client is fake_redis + assert topic._topic == "alpha" + assert topic._key == "stream:alpha" + + def test_publish_calls_xadd_and_expire( + self, + streams_channel: StreamsBroadcastChannel, + fake_redis: FakeStreamsRedis, + ): + topic = streams_channel.topic("beta") + payload = b"hello" + topic.publish(payload) + # One entry stored under stream key (bytes key for payload field) + assert fake_redis._store["stream:beta"][0][1] == {b"data": payload} + # Expire called after publish + assert fake_redis._expire_calls.get("stream:beta", 0) >= 1 + + +class TestStreamsSubscription: + def test_subscribe_and_receive_from_beginning(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("gamma") + # Pre-publish events before subscribing (late subscriber) + topic.publish(b"e1") + topic.publish(b"e2") + + sub = topic.subscribe() + assert isinstance(sub, _StreamsSubscription) + + received: list[bytes] = [] + with sub: + # Give listener thread a moment to xread + time.sleep(0.05) + # Drain using receive() to avoid indefinite iteration in tests + for _ in range(5): + msg = sub.receive(timeout=0.1) + if msg is None: + break + received.append(msg) + + assert received == [b"e1", b"e2"] + + def test_receive_timeout_returns_none(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("delta") + sub = topic.subscribe() + with sub: + # No messages yet + assert sub.receive(timeout=0.05) is None + + def test_close_stops_listener(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("epsilon") + sub = topic.subscribe() + with sub: + # Listener running; now close and ensure no crash + sub.close() + # After close, receive should raise SubscriptionClosedError + from libs.broadcast_channel.exc import SubscriptionClosedError + + with pytest.raises(SubscriptionClosedError): + sub.receive() + + def test_no_expire_when_zero_retention(self, fake_redis: FakeStreamsRedis): + channel = StreamsBroadcastChannel(fake_redis, retention_seconds=0) + topic = channel.topic("zeta") + topic.publish(b"payload") + # No expire recorded when retention is disabled + assert fake_redis._expire_calls.get("stream:zeta") is None diff --git a/api/tests/unit_tests/libs/test_cron_compatibility.py b/api/tests/unit_tests/libs/test_cron_compatibility.py index 6f3a94f6dc..61103d7935 100644 --- a/api/tests/unit_tests/libs/test_cron_compatibility.py +++ b/api/tests/unit_tests/libs/test_cron_compatibility.py @@ -294,7 +294,7 @@ class TestFrontendBackendIntegration(unittest.TestCase): def test_schedule_service_integration(self): """Test integration with ScheduleService patterns.""" - from core.workflow.nodes.trigger_schedule.entities import VisualConfig + from dify_graph.nodes.trigger_schedule.entities import VisualConfig from services.trigger.schedule_service import ScheduleService # Test enhanced syntax through visual config conversion diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index 8b96c62dc9..6c619dcf98 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -1204,7 +1204,7 @@ class TestConversationStatusCount: def test_status_count_batch_loading_implementation(self): """Test that status_count uses batch loading instead of N+1 queries.""" # Arrange - from core.workflow.enums import WorkflowExecutionStatus + from dify_graph.enums import WorkflowExecutionStatus app_id = str(uuid4()) conversation_id = str(uuid4()) @@ -1411,7 +1411,7 @@ class TestConversationStatusCount: def test_status_count_paused(self): """Test status_count includes paused workflow runs.""" # Arrange - from core.workflow.enums import WorkflowExecutionStatus + from dify_graph.enums import WorkflowExecutionStatus app_id = str(uuid4()) conversation_id = str(uuid4()) diff --git a/api/tests/unit_tests/models/test_conversation_variable.py b/api/tests/unit_tests/models/test_conversation_variable.py index d44aa56488..7d7674da3c 100644 --- a/api/tests/unit_tests/models/test_conversation_variable.py +++ b/api/tests/unit_tests/models/test_conversation_variable.py @@ -1,6 +1,6 @@ from uuid import uuid4 -from core.workflow.variables import SegmentType +from dify_graph.variables import SegmentType from factories import variable_factory from models import ConversationVariable diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index 544693da34..f3b72aa128 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -4,10 +4,10 @@ from unittest import mock from uuid import uuid4 from constants import HIDDEN_VALUE -from core.workflow.file.enums import FileTransferMethod, FileType -from core.workflow.file.models import File -from core.workflow.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable -from core.workflow.variables.segments import IntegerSegment, Segment +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.file.models import File +from dify_graph.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable +from dify_graph.variables.segments import IntegerSegment, Segment from factories.variable_factory import build_segment from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable diff --git a/api/tests/unit_tests/models/test_workflow_models.py b/api/tests/unit_tests/models/test_workflow_models.py index 9907cf05c0..f66f0b657d 100644 --- a/api/tests/unit_tests/models/test_workflow_models.py +++ b/api/tests/unit_tests/models/test_workflow_models.py @@ -14,7 +14,7 @@ from uuid import uuid4 import pytest -from core.workflow.enums import ( +from dify_graph.enums import ( NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionStatus, diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 4b5b3b318c..3707ed90be 100644 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -6,9 +6,9 @@ from unittest.mock import Mock, patch import pytest -from core.workflow.entities.pause_reason import HumanInputRequired, PauseReasonType -from core.workflow.nodes.human_input.entities import FormDefinition, FormInput, UserAction -from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormStatus +from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType +from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction +from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormStatus from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType from models.workflow import WorkflowPause as WorkflowPauseModel from models.workflow import WorkflowPauseReason diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py index f5428b46ff..8daf91c538 100644 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -6,11 +6,11 @@ from datetime import UTC, datetime, timedelta from core.entities.execution_extra_content import HumanInputContent as HumanInputContentDomain from core.entities.execution_extra_content import HumanInputFormSubmissionData -from core.workflow.nodes.human_input.entities import ( +from dify_graph.nodes.human_input.entities import ( FormDefinition, UserAction, ) -from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from dify_graph.nodes.human_input.enums import HumanInputFormStatus from models.execution_extra_content import HumanInputContent as HumanInputContentModel from models.human_input import ConsoleRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py index 5cba43714a..06703b8e38 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -12,17 +12,17 @@ import pytest from pytest_mock import MockerFixture from sqlalchemy.orm import Session, sessionmaker -from core.model_runtime.utils.encoders import jsonable_encoder from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.entities import ( +from dify_graph.entities import ( WorkflowNodeExecution, ) -from core.workflow.enums import ( +from dify_graph.enums import ( NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.repositories.workflow_node_execution_repository import OrderConfig +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.repositories.workflow_node_execution_repository import OrderConfig from models.account import Account, Tenant from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py index 5539856083..95a7751273 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py @@ -11,8 +11,8 @@ from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution -from core.workflow.enums import NodeType +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution +from dify_graph.enums import NodeType from models import Account, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/services/document_service_validation.py b/api/tests/unit_tests/services/document_service_validation.py index 4923e29d73..6829691507 100644 --- a/api/tests/unit_tests/services/document_service_validation.py +++ b/api/tests/unit_tests/services/document_service_validation.py @@ -111,7 +111,7 @@ from unittest.mock import Mock, patch import pytest from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError -from core.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, DatasetProcessRule, Document from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import ( diff --git a/api/tests/unit_tests/services/external_dataset_service.py b/api/tests/unit_tests/services/external_dataset_service.py index 57364142ad..afc3b29fca 100644 --- a/api/tests/unit_tests/services/external_dataset_service.py +++ b/api/tests/unit_tests/services/external_dataset_service.py @@ -545,7 +545,7 @@ class TestExternalDatasetServiceProcessExternalApi: params={}, ) - from core.workflow.nodes.http_request.exc import InvalidHttpMethodError + from dify_graph.nodes.http_request.exc import InvalidHttpMethodError with pytest.raises(InvalidHttpMethodError): ExternalDatasetService.process_external_api(settings, files=None) diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 635c86a14b..dcd6785464 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -1125,6 +1125,38 @@ class TestRegisterService: mock_create_workspace.assert_called_once_with(account=mock_account) mock_join_default_workspace.assert_not_called() + def test_create_account_and_tenant_still_calls_default_workspace_join_when_workspace_creation_fails( + self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch + ): + """Default workspace join should still be attempted when personal workspace creation fails.""" + from services.errors.workspace import WorkSpaceNotAllowedCreateError + + monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False) + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="11111111-1111-1111-1111-111111111111" + ) + + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_workspace, + patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace, + ): + mock_create_account.return_value = mock_account + mock_create_workspace.side_effect = WorkSpaceNotAllowedCreateError() + + with pytest.raises(WorkSpaceNotAllowedCreateError): + AccountService.create_account_and_tenant( + email="test@example.com", + name="Test User", + interface_language="en-US", + password=None, + ) + + mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) + def test_register_success(self, mock_db_dependencies, mock_external_service_dependencies): """Test successful account registration.""" # Setup mocks @@ -1235,6 +1267,84 @@ class TestRegisterService: mock_join_default_workspace.assert_not_called() + def test_register_still_calls_default_workspace_join_when_personal_workspace_creation_fails( + self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch + ): + """Default workspace join should run even when personal workspace creation raises.""" + from services.errors.workspace import WorkSpaceNotAllowedCreateError + + monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False) + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="11111111-1111-1111-1111-111111111111" + ) + + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.account_service.TenantService.create_tenant") as mock_create_tenant, + patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace, + ): + mock_create_account.return_value = mock_account + mock_create_tenant.side_effect = WorkSpaceNotAllowedCreateError() + + with pytest.raises(AccountRegisterError, match="Workspace is not allowed to create."): + RegisterService.register( + email="test@example.com", + name="Test User", + password="password123", + language="en-US", + ) + + mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) + mock_db_dependencies["db"].session.commit.assert_not_called() + + def test_register_still_calls_default_workspace_join_when_workspace_limit_exceeded( + self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch + ): + """Default workspace join should run before propagating workspace-limit registration failure.""" + from services.errors.workspace import WorkspacesLimitExceededError + + monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False) + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="11111111-1111-1111-1111-111111111111" + ) + + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.account_service.TenantService.create_tenant") as mock_create_tenant, + patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace, + ): + mock_create_account.return_value = mock_account + mock_create_tenant.side_effect = WorkspacesLimitExceededError() + + with pytest.raises(AccountRegisterError, match="Registration failed:"): + RegisterService.register( + email="test@example.com", + name="Test User", + password="password123", + language="en-US", + ) + + mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) + mock_db_dependencies["db"].session.commit.assert_not_called() + def test_register_with_oauth(self, mock_db_dependencies, mock_external_service_dependencies): """Test account registration with OAuth integration.""" # Setup mocks diff --git a/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py b/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py new file mode 100644 index 0000000000..e66d52f66b --- /dev/null +++ b/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py @@ -0,0 +1,197 @@ +import json +import uuid +from collections import defaultdict, deque + +import pytest + +from core.app.apps.message_generator import MessageGenerator +from models.model import AppMode +from services.app_generate_service import AppGenerateService + + +# ----------------------------- +# Fakes for Redis Pub/Sub flow +# ----------------------------- +class _FakePubSub: + def __init__(self, store: dict[str, deque[bytes]]): + self._store = store + self._subs: set[str] = set() + self._closed = False + + def subscribe(self, topic: str) -> None: + self._subs.add(topic) + + def unsubscribe(self, topic: str) -> None: + self._subs.discard(topic) + + def close(self) -> None: + self._closed = True + + def get_message(self, ignore_subscribe_messages: bool = True, timeout: int | float | None = 1): + # simulate a non-blocking poll; return first available + if self._closed: + return None + for t in list(self._subs): + q = self._store.get(t) + if q and len(q) > 0: + payload = q.popleft() + return {"type": "message", "channel": t, "data": payload} + # no message + return None + + +class _FakeRedisClient: + def __init__(self, store: dict[str, deque[bytes]]): + self._store = store + + def pubsub(self): + return _FakePubSub(self._store) + + def publish(self, topic: str, payload: bytes) -> None: + self._store.setdefault(topic, deque()).append(payload) + + +# ------------------------------------ +# Fakes for Redis Streams (XADD/XREAD) +# ------------------------------------ +class _FakeStreams: + def __init__(self) -> None: + # key -> list[(id, {field: value})] + self._data: dict[str, list[tuple[str, dict]]] = defaultdict(list) + self._seq: dict[str, int] = defaultdict(int) + + def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str: + # maxlen is accepted for API compatibility with redis-py; ignored in this test double + self._seq[key] += 1 + eid = f"{self._seq[key]}-0" + self._data[key].append((eid, fields)) + return eid + + def expire(self, key: str, seconds: int) -> None: + # no-op for tests + return None + + def xread(self, streams: dict, block: int | None = None, count: int | None = None): + assert len(streams) == 1 + key, last_id = next(iter(streams.items())) + entries = self._data.get(key, []) + start = 0 + if last_id != "0-0": + for i, (eid, _f) in enumerate(entries): + if eid == last_id: + start = i + 1 + break + if start >= len(entries): + return [] + end = len(entries) if count is None else min(len(entries), start + count) + return [(key, entries[start:end])] + + +@pytest.fixture +def _patch_get_channel_streams(monkeypatch): + from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel + + fake = _FakeStreams() + chan = StreamsBroadcastChannel(fake, retention_seconds=60) + + def _get_channel(): + return chan + + # Patch both the source and the imported alias used by MessageGenerator + monkeypatch.setattr("extensions.ext_redis.get_pubsub_broadcast_channel", lambda: chan) + monkeypatch.setattr("core.app.apps.message_generator.get_pubsub_broadcast_channel", lambda: chan) + # Ensure AppGenerateService sees streams mode + import services.app_generate_service as ags + + monkeypatch.setattr(ags.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams", raising=False) + + +@pytest.fixture +def _patch_get_channel_pubsub(monkeypatch): + from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel + + store: dict[str, deque[bytes]] = defaultdict(deque) + client = _FakeRedisClient(store) + chan = RedisBroadcastChannel(client) + + def _get_channel(): + return chan + + # Patch both the source and the imported alias used by MessageGenerator + monkeypatch.setattr("extensions.ext_redis.get_pubsub_broadcast_channel", lambda: chan) + monkeypatch.setattr("core.app.apps.message_generator.get_pubsub_broadcast_channel", lambda: chan) + # Ensure AppGenerateService sees pubsub mode + import services.app_generate_service as ags + + monkeypatch.setattr(ags.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub", raising=False) + + +def _publish_events(app_mode: AppMode, run_id: str, events: list[dict]): + # Publish events to the same topic used by MessageGenerator + topic = MessageGenerator.get_response_topic(app_mode, run_id) + for ev in events: + topic.publish(json.dumps(ev).encode()) + + +@pytest.mark.usefixtures("_patch_get_channel_streams") +def test_streams_full_flow_prepublish_and_replay(): + app_mode = AppMode.WORKFLOW + run_id = str(uuid.uuid4()) + + # Build start_task that publishes two events immediately + events = [{"event": "workflow_started"}, {"event": "workflow_finished"}] + + def start_task(): + _publish_events(app_mode, run_id, events) + + on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(start_task) + + # Start retrieving BEFORE subscription is established; in streams mode, we also started immediately + gen = MessageGenerator.retrieve_events(app_mode, run_id, idle_timeout=2.0, on_subscribe=on_subscribe) + + received = [] + for msg in gen: + if isinstance(msg, str): + # skip ping events + continue + received.append(msg) + if msg.get("event") == "workflow_finished": + break + + assert [m.get("event") for m in received] == ["workflow_started", "workflow_finished"] + + +@pytest.mark.usefixtures("_patch_get_channel_pubsub") +def test_pubsub_full_flow_start_on_subscribe_gated(monkeypatch): + # Speed up any potential timer if it accidentally triggers + monkeypatch.setattr("services.app_generate_service.SSE_TASK_START_FALLBACK_MS", 50) + + app_mode = AppMode.WORKFLOW + run_id = str(uuid.uuid4()) + + published_order: list[str] = [] + + def start_task(): + # When called (on subscribe), publish both events + events = [{"event": "workflow_started"}, {"event": "workflow_finished"}] + _publish_events(app_mode, run_id, events) + published_order.extend([e["event"] for e in events]) + + on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(start_task) + + # Producer not started yet; only when subscribe happens + assert published_order == [] + + gen = MessageGenerator.retrieve_events(app_mode, run_id, idle_timeout=2.0, on_subscribe=on_subscribe) + + received = [] + for msg in gen: + if isinstance(msg, str): + continue + received.append(msg) + if msg.get("event") == "workflow_finished": + break + + # Verify publish happened and consumer received in order + assert published_order == ["workflow_started", "workflow_finished"] + assert [m.get("event") for m in received] == ["workflow_started", "workflow_finished"] diff --git a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py index 69766188f3..abff48347e 100644 --- a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py @@ -1,13 +1,10 @@ import datetime - -# Mock redis_client before importing dataset_service -from unittest.mock import Mock, call, patch +from unittest.mock import Mock, patch import pytest from models.dataset import Dataset, Document from services.dataset_service import DocumentService -from services.errors.document import DocumentIndexingError from tests.unit_tests.conftest import redis_mock @@ -48,7 +45,6 @@ class DocumentBatchUpdateTestDataFactory: document.indexing_status = indexing_status document.completed_at = completed_at or datetime.datetime.now() - # Set default values for optional fields document.disabled_at = None document.disabled_by = None document.archived_at = None @@ -59,32 +55,9 @@ class DocumentBatchUpdateTestDataFactory: setattr(document, key, value) return document - @staticmethod - def create_multiple_documents( - document_ids: list[str], enabled: bool = True, archived: bool = False, indexing_status: str = "completed" - ) -> list[Mock]: - """Create multiple mock documents with specified attributes.""" - documents = [] - for doc_id in document_ids: - doc = DocumentBatchUpdateTestDataFactory.create_document_mock( - document_id=doc_id, - name=f"document_{doc_id}.pdf", - enabled=enabled, - archived=archived, - indexing_status=indexing_status, - ) - documents.append(doc) - return documents - class TestDatasetServiceBatchUpdateDocumentStatus: - """ - Comprehensive unit tests for DocumentService.batch_update_document_status method. - - This test suite covers all supported actions (enable, disable, archive, un_archive), - error conditions, edge cases, and validates proper interaction with Redis cache, - database operations, and async task triggers. - """ + """Unit tests for non-SQL path in DocumentService.batch_update_document_status.""" @pytest.fixture def mock_document_service_dependencies(self): @@ -104,697 +77,24 @@ class TestDatasetServiceBatchUpdateDocumentStatus: "current_time": current_time, } - @pytest.fixture - def mock_async_task_dependencies(self): - """Mock setup for async task dependencies.""" - with ( - patch("services.dataset_service.add_document_to_index_task") as mock_add_task, - patch("services.dataset_service.remove_document_from_index_task") as mock_remove_task, - ): - yield {"add_task": mock_add_task, "remove_task": mock_remove_task} - - def _assert_document_enabled(self, document: Mock, user_id: str, current_time: datetime.datetime): - """Helper method to verify document was enabled correctly.""" - assert document.enabled == True - assert document.disabled_at is None - assert document.disabled_by is None - assert document.updated_at == current_time - - def _assert_document_disabled(self, document: Mock, user_id: str, current_time: datetime.datetime): - """Helper method to verify document was disabled correctly.""" - assert document.enabled == False - assert document.disabled_at == current_time - assert document.disabled_by == user_id - assert document.updated_at == current_time - - def _assert_document_archived(self, document: Mock, user_id: str, current_time: datetime.datetime): - """Helper method to verify document was archived correctly.""" - assert document.archived == True - assert document.archived_at == current_time - assert document.archived_by == user_id - assert document.updated_at == current_time - - def _assert_document_unarchived(self, document: Mock): - """Helper method to verify document was unarchived correctly.""" - assert document.archived == False - assert document.archived_at is None - assert document.archived_by is None - - def _assert_redis_cache_operations(self, document_ids: list[str], action: str = "setex"): - """Helper method to verify Redis cache operations.""" - if action == "setex": - expected_calls = [call(f"document_{doc_id}_indexing", 600, 1) for doc_id in document_ids] - redis_mock.setex.assert_has_calls(expected_calls) - elif action == "get": - expected_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids] - redis_mock.get.assert_has_calls(expected_calls) - - def _assert_async_task_calls(self, mock_task, document_ids: list[str], task_type: str): - """Helper method to verify async task calls.""" - expected_calls = [call(doc_id) for doc_id in document_ids] - if task_type in {"add", "remove"}: - mock_task.delay.assert_has_calls(expected_calls) - - # ==================== Enable Document Tests ==================== - - def test_batch_update_enable_documents_success( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test successful enabling of disabled documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create disabled documents - disabled_docs = DocumentBatchUpdateTestDataFactory.create_multiple_documents(["doc-1", "doc-2"], enabled=False) - mock_document_service_dependencies["get_document"].side_effect = disabled_docs - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Call the method to enable documents - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1", "doc-2"], action="enable", user=user - ) - - # Verify document attributes were updated correctly - for doc in disabled_docs: - self._assert_document_enabled(doc, user.id, mock_document_service_dependencies["current_time"]) - - # Verify Redis cache operations - self._assert_redis_cache_operations(["doc-1", "doc-2"], "get") - self._assert_redis_cache_operations(["doc-1", "doc-2"], "setex") - - # Verify async tasks were triggered for indexing - self._assert_async_task_calls(mock_async_task_dependencies["add_task"], ["doc-1", "doc-2"], "add") - - # Verify database operations - mock_db = mock_document_service_dependencies["db_session"] - assert mock_db.add.call_count == 2 - assert mock_db.commit.call_count == 1 - - def test_batch_update_enable_already_enabled_document_skipped(self, mock_document_service_dependencies): - """Test enabling documents that are already enabled.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create already enabled document - enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True) - mock_document_service_dependencies["get_document"].return_value = enabled_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Attempt to enable already enabled document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="enable", user=user - ) - - # Verify no database operations occurred (document was skipped) - mock_db = mock_document_service_dependencies["db_session"] - mock_db.commit.assert_not_called() - - # Verify no Redis setex operations occurred (document was skipped) - redis_mock.setex.assert_not_called() - - # ==================== Disable Document Tests ==================== - - def test_batch_update_disable_documents_success( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test successful disabling of enabled and completed documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create enabled documents - enabled_docs = DocumentBatchUpdateTestDataFactory.create_multiple_documents(["doc-1", "doc-2"], enabled=True) - mock_document_service_dependencies["get_document"].side_effect = enabled_docs - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Call the method to disable documents - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1", "doc-2"], action="disable", user=user - ) - - # Verify document attributes were updated correctly - for doc in enabled_docs: - self._assert_document_disabled(doc, user.id, mock_document_service_dependencies["current_time"]) - - # Verify Redis cache operations for indexing prevention - self._assert_redis_cache_operations(["doc-1", "doc-2"], "setex") - - # Verify async tasks were triggered to remove from index - self._assert_async_task_calls(mock_async_task_dependencies["remove_task"], ["doc-1", "doc-2"], "remove") - - # Verify database operations - mock_db = mock_document_service_dependencies["db_session"] - assert mock_db.add.call_count == 2 - assert mock_db.commit.call_count == 1 - - def test_batch_update_disable_already_disabled_document_skipped( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test disabling documents that are already disabled.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create already disabled document - disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False) - mock_document_service_dependencies["get_document"].return_value = disabled_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Attempt to disable already disabled document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="disable", user=user - ) - - # Verify no database operations occurred (document was skipped) - mock_db = mock_document_service_dependencies["db_session"] - mock_db.commit.assert_not_called() - - # Verify no Redis setex operations occurred (document was skipped) - redis_mock.setex.assert_not_called() - - # Verify no async tasks were triggered (document was skipped) - mock_async_task_dependencies["add_task"].delay.assert_not_called() - - def test_batch_update_disable_non_completed_document_error(self, mock_document_service_dependencies): - """Test that DocumentIndexingError is raised when trying to disable non-completed documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create a document that's not completed - non_completed_doc = DocumentBatchUpdateTestDataFactory.create_document_mock( - enabled=True, - indexing_status="indexing", # Not completed - completed_at=None, # Not completed - ) - mock_document_service_dependencies["get_document"].return_value = non_completed_doc - - # Verify that DocumentIndexingError is raised - with pytest.raises(DocumentIndexingError) as exc_info: - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="disable", user=user - ) - - # Verify error message indicates document is not completed - assert "is not completed" in str(exc_info.value) - - # ==================== Archive Document Tests ==================== - - def test_batch_update_archive_documents_success( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test successful archiving of unarchived documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create unarchived enabled document - unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=False) - mock_document_service_dependencies["get_document"].return_value = unarchived_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Call the method to archive documents - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="archive", user=user - ) - - # Verify document attributes were updated correctly - self._assert_document_archived(unarchived_doc, user.id, mock_document_service_dependencies["current_time"]) - - # Verify Redis cache was set (because document was enabled) - redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) - - # Verify async task was triggered to remove from index (because enabled) - mock_async_task_dependencies["remove_task"].delay.assert_called_once_with("doc-1") - - # Verify database operations - mock_db = mock_document_service_dependencies["db_session"] - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - def test_batch_update_archive_already_archived_document_skipped(self, mock_document_service_dependencies): - """Test archiving documents that are already archived.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create already archived document - archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=True) - mock_document_service_dependencies["get_document"].return_value = archived_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Attempt to archive already archived document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-3"], action="archive", user=user - ) - - # Verify no database operations occurred (document was skipped) - mock_db = mock_document_service_dependencies["db_session"] - mock_db.commit.assert_not_called() - - # Verify no Redis setex operations occurred (document was skipped) - redis_mock.setex.assert_not_called() - - def test_batch_update_archive_disabled_document_no_index_removal( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test archiving disabled documents (should not trigger index removal).""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Set up disabled, unarchived document - disabled_unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False, archived=False) - mock_document_service_dependencies["get_document"].return_value = disabled_unarchived_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Archive the disabled document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="archive", user=user - ) - - # Verify document was archived - self._assert_document_archived( - disabled_unarchived_doc, user.id, mock_document_service_dependencies["current_time"] - ) - - # Verify no Redis cache was set (document is disabled) - redis_mock.setex.assert_not_called() - - # Verify no index removal task was triggered (document is disabled) - mock_async_task_dependencies["remove_task"].delay.assert_not_called() - - # Verify database operations still occurred - mock_db = mock_document_service_dependencies["db_session"] - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - # ==================== Unarchive Document Tests ==================== - - def test_batch_update_unarchive_documents_success( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test successful unarchiving of archived documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create mock archived document - archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=True) - mock_document_service_dependencies["get_document"].return_value = archived_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Call the method to unarchive documents - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user - ) - - # Verify document attributes were updated correctly - self._assert_document_unarchived(archived_doc) - assert archived_doc.updated_at == mock_document_service_dependencies["current_time"] - - # Verify Redis cache was set (because document is enabled) - redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) - - # Verify async task was triggered to add back to index (because enabled) - mock_async_task_dependencies["add_task"].delay.assert_called_once_with("doc-1") - - # Verify database operations - mock_db = mock_document_service_dependencies["db_session"] - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - def test_batch_update_unarchive_already_unarchived_document_skipped( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test unarchiving documents that are already unarchived.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create already unarchived document - unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=False) - mock_document_service_dependencies["get_document"].return_value = unarchived_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Attempt to unarchive already unarchived document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user - ) - - # Verify no database operations occurred (document was skipped) - mock_db = mock_document_service_dependencies["db_session"] - mock_db.commit.assert_not_called() - - # Verify no Redis setex operations occurred (document was skipped) - redis_mock.setex.assert_not_called() - - # Verify no async tasks were triggered (document was skipped) - mock_async_task_dependencies["add_task"].delay.assert_not_called() - - def test_batch_update_unarchive_disabled_document_no_index_addition( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test unarchiving disabled documents (should not trigger index addition).""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create mock archived but disabled document - archived_disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False, archived=True) - mock_document_service_dependencies["get_document"].return_value = archived_disabled_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Unarchive the disabled document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user - ) - - # Verify document was unarchived - self._assert_document_unarchived(archived_disabled_doc) - assert archived_disabled_doc.updated_at == mock_document_service_dependencies["current_time"] - - # Verify no Redis cache was set (document is disabled) - redis_mock.setex.assert_not_called() - - # Verify no index addition task was triggered (document is disabled) - mock_async_task_dependencies["add_task"].delay.assert_not_called() - - # Verify database operations still occurred - mock_db = mock_document_service_dependencies["db_session"] - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - # ==================== Error Handling Tests ==================== - - def test_batch_update_document_indexing_error_redis_cache_hit(self, mock_document_service_dependencies): - """Test that DocumentIndexingError is raised when documents are currently being indexed.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create mock enabled document - enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True) - mock_document_service_dependencies["get_document"].return_value = enabled_doc - - # Set up mock to indicate document is being indexed - redis_mock.reset_mock() - redis_mock.get.return_value = "indexing" - - # Verify that DocumentIndexingError is raised - with pytest.raises(DocumentIndexingError) as exc_info: - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="enable", user=user - ) - - # Verify error message contains document name - assert "test_document.pdf" in str(exc_info.value) - assert "is being indexed" in str(exc_info.value) - - # Verify Redis cache was checked - redis_mock.get.assert_called_once_with("document_doc-1_indexing") - def test_batch_update_invalid_action_error(self, mock_document_service_dependencies): """Test that ValueError is raised when an invalid action is provided.""" dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() user = DocumentBatchUpdateTestDataFactory.create_user_mock() - # Create mock document doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True) mock_document_service_dependencies["get_document"].return_value = doc - # Reset module-level Redis mock redis_mock.reset_mock() redis_mock.get.return_value = None - # Test with invalid action invalid_action = "invalid_action" with pytest.raises(ValueError) as exc_info: DocumentService.batch_update_document_status( dataset=dataset, document_ids=["doc-1"], action=invalid_action, user=user ) - # Verify error message contains the invalid action assert invalid_action in str(exc_info.value) assert "Invalid action" in str(exc_info.value) - # Verify no Redis operations occurred redis_mock.setex.assert_not_called() - - def test_batch_update_async_task_error_handling( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test handling of async task errors during batch operations.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create mock disabled document - disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False) - mock_document_service_dependencies["get_document"].return_value = disabled_doc - - # Mock async task to raise an exception - mock_async_task_dependencies["add_task"].delay.side_effect = Exception("Celery task error") - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Verify that async task error is propagated - with pytest.raises(Exception) as exc_info: - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="enable", user=user - ) - - # Verify error message - assert "Celery task error" in str(exc_info.value) - - # Verify database operations completed successfully - mock_db = mock_document_service_dependencies["db_session"] - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - # Verify Redis cache was set successfully - redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) - - # Verify document was updated - self._assert_document_enabled(disabled_doc, user.id, mock_document_service_dependencies["current_time"]) - - # ==================== Edge Case Tests ==================== - - def test_batch_update_empty_document_list(self, mock_document_service_dependencies): - """Test batch operations with an empty document ID list.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Call method with empty document list - result = DocumentService.batch_update_document_status( - dataset=dataset, document_ids=[], action="enable", user=user - ) - - # Verify no document lookups were performed - mock_document_service_dependencies["get_document"].assert_not_called() - - # Verify method returns None (early return) - assert result is None - - def test_batch_update_document_not_found_skipped(self, mock_document_service_dependencies): - """Test behavior when some documents don't exist in the database.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Mock document service to return None (document not found) - mock_document_service_dependencies["get_document"].return_value = None - - # Call method with non-existent document ID - # This should not raise an error, just skip the missing document - try: - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["non-existent-doc"], action="enable", user=user - ) - except Exception as e: - pytest.fail(f"Method should not raise exception for missing documents: {e}") - - # Verify document lookup was attempted - mock_document_service_dependencies["get_document"].assert_called_once_with(dataset.id, "non-existent-doc") - - def test_batch_update_mixed_document_states_and_actions( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test batch operations on documents with mixed states and various scenarios.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create documents in various states - disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-1", enabled=False) - enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-2", enabled=True) - archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-3", enabled=True, archived=True) - - # Mix of different document states - documents = [disabled_doc, enabled_doc, archived_doc] - mock_document_service_dependencies["get_document"].side_effect = documents - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Perform enable operation on mixed state documents - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1", "doc-2", "doc-3"], action="enable", user=user - ) - - # Verify only the disabled document was processed - # (enabled and archived documents should be skipped for enable action) - - # Only one add should occur (for the disabled document that was enabled) - mock_db = mock_document_service_dependencies["db_session"] - mock_db.add.assert_called_once() - # Only one commit should occur - mock_db.commit.assert_called_once() - - # Only one Redis setex should occur (for the document that was enabled) - redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) - - # Only one async task should be triggered (for the document that was enabled) - mock_async_task_dependencies["add_task"].delay.assert_called_once_with("doc-1") - - # ==================== Performance Tests ==================== - - def test_batch_update_large_document_list_performance( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test batch operations with a large number of documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create large list of document IDs - document_ids = [f"doc-{i}" for i in range(1, 101)] # 100 documents - - # Create mock documents - mock_documents = DocumentBatchUpdateTestDataFactory.create_multiple_documents( - document_ids, - enabled=False, # All disabled, will be enabled - ) - mock_document_service_dependencies["get_document"].side_effect = mock_documents - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Perform batch enable operation - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=document_ids, action="enable", user=user - ) - - # Verify all documents were processed - assert mock_document_service_dependencies["get_document"].call_count == 100 - - # Verify all documents were updated - for mock_doc in mock_documents: - self._assert_document_enabled(mock_doc, user.id, mock_document_service_dependencies["current_time"]) - - # Verify database operations - mock_db = mock_document_service_dependencies["db_session"] - assert mock_db.add.call_count == 100 - assert mock_db.commit.call_count == 1 - - # Verify Redis cache operations occurred for each document - assert redis_mock.setex.call_count == 100 - - # Verify async tasks were triggered for each document - assert mock_async_task_dependencies["add_task"].delay.call_count == 100 - - # Verify correct Redis cache keys were set - expected_redis_calls = [call(f"document_doc-{i}_indexing", 600, 1) for i in range(1, 101)] - redis_mock.setex.assert_has_calls(expected_redis_calls) - - # Verify correct async task calls - expected_task_calls = [call(f"doc-{i}") for i in range(1, 101)] - mock_async_task_dependencies["add_task"].delay.assert_has_calls(expected_task_calls) - - def test_batch_update_mixed_document_states_complex_scenario( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test complex batch operations with documents in various states.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create documents in various states - doc1 = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-1", enabled=False) # Will be enabled - doc2 = DocumentBatchUpdateTestDataFactory.create_document_mock( - "doc-2", enabled=True - ) # Already enabled, will be skipped - doc3 = DocumentBatchUpdateTestDataFactory.create_document_mock( - "doc-3", enabled=True - ) # Already enabled, will be skipped - doc4 = DocumentBatchUpdateTestDataFactory.create_document_mock( - "doc-4", enabled=True - ) # Not affected by enable action - doc5 = DocumentBatchUpdateTestDataFactory.create_document_mock( - "doc-5", enabled=True, archived=True - ) # Not affected by enable action - doc6 = None # Non-existent, will be skipped - - mock_document_service_dependencies["get_document"].side_effect = [doc1, doc2, doc3, doc4, doc5, doc6] - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Perform mixed batch operations - DocumentService.batch_update_document_status( - dataset=dataset, - document_ids=["doc-1", "doc-2", "doc-3", "doc-4", "doc-5", "doc-6"], - action="enable", # This will only affect doc1 - user=user, - ) - - # Verify document 1 was enabled - self._assert_document_enabled(doc1, user.id, mock_document_service_dependencies["current_time"]) - - # Verify other documents were skipped appropriately - assert doc2.enabled == True # No change - assert doc3.enabled == True # No change - assert doc4.enabled == True # No change - assert doc5.enabled == True # No change - - # Verify database commits occurred for processed documents - # Only doc1 should be added (others were skipped, doc6 doesn't exist) - mock_db = mock_document_service_dependencies["db_session"] - assert mock_db.add.call_count == 1 - assert mock_db.commit.call_count == 1 - - # Verify Redis cache operations occurred for processed documents - # Only doc1 should have Redis operations - assert redis_mock.setex.call_count == 1 - - # Verify async tasks were triggered for processed documents - # Only doc1 should trigger tasks - assert mock_async_task_dependencies["add_task"].delay.call_count == 1 - - # Verify correct Redis cache keys were set - expected_redis_calls = [call("document_doc-1_indexing", 600, 1)] - redis_mock.setex.assert_has_calls(expected_redis_calls) - - # Verify correct async task calls - expected_task_calls = [call("doc-1")] - mock_async_task_dependencies["add_task"].delay.assert_has_calls(expected_task_calls) diff --git a/api/tests/unit_tests/services/test_dataset_service_create_dataset.py b/api/tests/unit_tests/services/test_dataset_service_create_dataset.py index 7c7a70f962..f8c5270656 100644 --- a/api/tests/unit_tests/services/test_dataset_service_create_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_create_dataset.py @@ -1,726 +1,39 @@ -""" -Comprehensive unit tests for DatasetService creation methods. +"""Unit tests for non-SQL validation paths in DatasetService dataset creation.""" -This test suite covers: -- create_empty_dataset for internal datasets -- create_empty_dataset for external datasets -- create_empty_rag_pipeline_dataset -- Error conditions and edge cases -""" - -from unittest.mock import Mock, create_autospec, patch +from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from core.model_runtime.entities.model_entities import ModelType -from models.account import Account -from models.dataset import Dataset, Pipeline from services.dataset_service import DatasetService -from services.entities.knowledge_entities.knowledge_entities import RetrievalModel -from services.entities.knowledge_entities.rag_pipeline_entities import ( - IconInfo, - RagPipelineDatasetCreateEntity, -) -from services.errors.dataset import DatasetNameDuplicateError +from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity -class DatasetCreateTestDataFactory: - """Factory class for creating test data and mock objects for dataset creation tests.""" - - @staticmethod - def create_account_mock( - account_id: str = "account-123", - tenant_id: str = "tenant-123", - **kwargs, - ) -> Mock: - """Create a mock account.""" - account = create_autospec(Account, instance=True) - account.id = account_id - account.current_tenant_id = tenant_id - for key, value in kwargs.items(): - setattr(account, key, value) - return account - - @staticmethod - def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock: - """Create a mock embedding model.""" - embedding_model = Mock() - embedding_model.model_name = model - embedding_model.provider = provider - return embedding_model - - @staticmethod - def create_retrieval_model_mock() -> Mock: - """Create a mock retrieval model.""" - retrieval_model = Mock(spec=RetrievalModel) - retrieval_model.model_dump.return_value = { - "search_method": "semantic_search", - "top_k": 2, - "score_threshold": 0.0, - } - retrieval_model.reranking_model = None - return retrieval_model - - @staticmethod - def create_external_knowledge_api_mock(api_id: str = "api-123", **kwargs) -> Mock: - """Create a mock external knowledge API.""" - api = Mock() - api.id = api_id - for key, value in kwargs.items(): - setattr(api, key, value) - return api - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - name: str = "Test Dataset", - tenant_id: str = "tenant-123", - **kwargs, - ) -> Mock: - """Create a mock dataset.""" - dataset = create_autospec(Dataset, instance=True) - dataset.id = dataset_id - dataset.name = name - dataset.tenant_id = tenant_id - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_pipeline_mock( - pipeline_id: str = "pipeline-123", - name: str = "Test Pipeline", - **kwargs, - ) -> Mock: - """Create a mock pipeline.""" - pipeline = Mock(spec=Pipeline) - pipeline.id = pipeline_id - pipeline.name = name - for key, value in kwargs.items(): - setattr(pipeline, key, value) - return pipeline - - -class TestDatasetServiceCreateEmptyDataset: - """ - Comprehensive unit tests for DatasetService.create_empty_dataset method. - - This test suite covers: - - Internal dataset creation (vendor provider) - - External dataset creation - - High quality indexing technique with embedding models - - Economy indexing technique - - Retrieval model configuration - - Error conditions (duplicate names, missing external knowledge IDs) - """ - - @pytest.fixture - def mock_dataset_service_dependencies(self): - """Common mock setup for dataset service dependencies.""" - with ( - patch("services.dataset_service.db.session") as mock_db, - patch("services.dataset_service.ModelManager") as mock_model_manager, - patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding, - patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking, - patch("services.dataset_service.ExternalDatasetService") as mock_external_service, - ): - yield { - "db_session": mock_db, - "model_manager": mock_model_manager, - "check_embedding": mock_check_embedding, - "check_reranking": mock_check_reranking, - "external_service": mock_external_service, - } - - # ==================== Internal Dataset Creation Tests ==================== - - def test_create_internal_dataset_basic_success(self, mock_dataset_service_dependencies): - """Test successful creation of basic internal dataset.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Test Dataset" - description = "Test description" - - # Mock database query to return None (no duplicate name) - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock database session operations - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=description, - indexing_technique=None, - account=account, - ) - - # Assert - assert result is not None - assert result.name == name - assert result.description == description - assert result.tenant_id == tenant_id - assert result.created_by == account.id - assert result.updated_by == account.id - assert result.provider == "vendor" - assert result.permission == "only_me" - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_economy_indexing(self, mock_dataset_service_dependencies): - """Test successful creation of internal dataset with economy indexing.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Economy Dataset" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique="economy", - account=account, - ) - - # Assert - assert result.indexing_technique == "economy" - assert result.embedding_model_provider is None - assert result.embedding_model is None - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_high_quality_indexing_default_embedding( - self, mock_dataset_service_dependencies - ): - """Test creation with high_quality indexing using default embedding model.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "High Quality Dataset" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock model manager - embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock() - mock_model_manager_instance = Mock() - mock_model_manager_instance.get_default_model_instance.return_value = embedding_model - mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique="high_quality", - account=account, - ) - - # Assert - assert result.indexing_technique == "high_quality" - assert result.embedding_model_provider == embedding_model.provider - assert result.embedding_model == embedding_model.model_name - mock_model_manager_instance.get_default_model_instance.assert_called_once_with( - tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING - ) - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_high_quality_indexing_custom_embedding( - self, mock_dataset_service_dependencies - ): - """Test creation with high_quality indexing using custom embedding model.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Custom Embedding Dataset" - embedding_provider = "openai" - embedding_model_name = "text-embedding-3-small" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock model manager - embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock( - model=embedding_model_name, provider=embedding_provider - ) - mock_model_manager_instance = Mock() - mock_model_manager_instance.get_model_instance.return_value = embedding_model - mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique="high_quality", - account=account, - embedding_model_provider=embedding_provider, - embedding_model_name=embedding_model_name, - ) - - # Assert - assert result.indexing_technique == "high_quality" - assert result.embedding_model_provider == embedding_provider - assert result.embedding_model == embedding_model_name - mock_dataset_service_dependencies["check_embedding"].assert_called_once_with( - tenant_id, embedding_provider, embedding_model_name - ) - mock_model_manager_instance.get_model_instance.assert_called_once_with( - tenant_id=tenant_id, - provider=embedding_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=embedding_model_name, - ) - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_retrieval_model(self, mock_dataset_service_dependencies): - """Test creation with retrieval model configuration.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Retrieval Model Dataset" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock retrieval model - retrieval_model = DatasetCreateTestDataFactory.create_retrieval_model_mock() - retrieval_model_dict = {"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0} - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - retrieval_model=retrieval_model, - ) - - # Assert - assert result.retrieval_model == retrieval_model_dict - retrieval_model.model_dump.assert_called_once() - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_retrieval_model_reranking(self, mock_dataset_service_dependencies): - """Test creation with retrieval model that includes reranking.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Reranking Dataset" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock model manager - embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock() - mock_model_manager_instance = Mock() - mock_model_manager_instance.get_default_model_instance.return_value = embedding_model - mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance - - # Mock retrieval model with reranking - reranking_model = Mock() - reranking_model.reranking_provider_name = "cohere" - reranking_model.reranking_model_name = "rerank-english-v3.0" - - retrieval_model = DatasetCreateTestDataFactory.create_retrieval_model_mock() - retrieval_model.reranking_model = reranking_model - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique="high_quality", - account=account, - retrieval_model=retrieval_model, - ) - - # Assert - mock_dataset_service_dependencies["check_reranking"].assert_called_once_with( - tenant_id, "cohere", "rerank-english-v3.0" - ) - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_custom_permission(self, mock_dataset_service_dependencies): - """Test creation with custom permission setting.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Custom Permission Dataset" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - permission="all_team_members", - ) - - # Assert - assert result.permission == "all_team_members" - mock_db.commit.assert_called_once() - - # ==================== External Dataset Creation Tests ==================== - - def test_create_external_dataset_success(self, mock_dataset_service_dependencies): - """Test successful creation of external dataset.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "External Dataset" - external_api_id = "external-api-123" - external_knowledge_id = "external-knowledge-456" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock external knowledge API - external_api = DatasetCreateTestDataFactory.create_external_knowledge_api_mock(api_id=external_api_id) - mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = external_api - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - provider="external", - external_knowledge_api_id=external_api_id, - external_knowledge_id=external_knowledge_id, - ) - - # Assert - assert result.provider == "external" - assert mock_db.add.call_count == 2 # Dataset + ExternalKnowledgeBindings - mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.assert_called_once_with( - external_api_id - ) - mock_db.commit.assert_called_once() - - def test_create_external_dataset_missing_api_id_error(self, mock_dataset_service_dependencies): - """Test error when external knowledge API is not found.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "External Dataset" - external_api_id = "non-existent-api" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock external knowledge API not found - mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = None - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - - # Act & Assert - with pytest.raises(ValueError, match="External API template not found"): - DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - provider="external", - external_knowledge_api_id=external_api_id, - external_knowledge_id="knowledge-123", - ) - - def test_create_external_dataset_missing_knowledge_id_error(self, mock_dataset_service_dependencies): - """Test error when external knowledge ID is missing.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "External Dataset" - external_api_id = "external-api-123" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock external knowledge API - external_api = DatasetCreateTestDataFactory.create_external_knowledge_api_mock(api_id=external_api_id) - mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = external_api - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - - # Act & Assert - with pytest.raises(ValueError, match="external_knowledge_id is required"): - DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - provider="external", - external_knowledge_api_id=external_api_id, - external_knowledge_id=None, - ) - - # ==================== Error Handling Tests ==================== - - def test_create_dataset_duplicate_name_error(self, mock_dataset_service_dependencies): - """Test error when dataset name already exists.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Duplicate Dataset" - - # Mock database query to return existing dataset - existing_dataset = DatasetCreateTestDataFactory.create_dataset_mock(name=name) - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = existing_dataset - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Act & Assert - with pytest.raises(DatasetNameDuplicateError, match=f"Dataset with name {name} already exists"): - DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - ) - - -class TestDatasetServiceCreateEmptyRagPipelineDataset: - """ - Comprehensive unit tests for DatasetService.create_empty_rag_pipeline_dataset method. - - This test suite covers: - - RAG pipeline dataset creation with provided name - - RAG pipeline dataset creation with auto-generated name - - Pipeline creation - - Error conditions (duplicate names, missing current user) - """ +class TestDatasetServiceCreateRagPipelineDatasetNonSQL: + """Unit coverage for non-SQL validation in create_empty_rag_pipeline_dataset.""" @pytest.fixture def mock_rag_pipeline_dependencies(self): - """Common mock setup for RAG pipeline dataset creation.""" + """Patch database session and current_user for validation-only unit coverage.""" with ( patch("services.dataset_service.db.session") as mock_db, patch("services.dataset_service.current_user") as mock_current_user, - patch("services.dataset_service.generate_incremental_name") as mock_generate_name, ): - # Configure mock_current_user to behave like a Flask-Login proxy - # Default: no user (falsy) - mock_current_user.id = None yield { "db_session": mock_db, "current_user_mock": mock_current_user, - "generate_name": mock_generate_name, } - def test_create_rag_pipeline_dataset_with_name_success(self, mock_rag_pipeline_dependencies): - """Test successful creation of RAG pipeline dataset with provided name.""" - # Arrange - tenant_id = str(uuid4()) - user_id = str(uuid4()) - name = "RAG Pipeline Dataset" - description = "RAG Pipeline Description" - - # Mock current user - set up the mock to have id attribute accessible directly - mock_rag_pipeline_dependencies["current_user_mock"].id = user_id - - # Mock database query (no duplicate name) - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - # Mock database operations - mock_db = mock_rag_pipeline_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Create entity - icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") - entity = RagPipelineDatasetCreateEntity( - name=name, - description=description, - icon_info=icon_info, - permission="only_me", - ) - - # Act - result = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity - ) - - # Assert - assert result is not None - assert result.name == name - assert result.description == description - assert result.tenant_id == tenant_id - assert result.created_by == user_id - assert result.provider == "vendor" - assert result.runtime_mode == "rag_pipeline" - assert result.permission == "only_me" - assert mock_db.add.call_count == 2 # Pipeline + Dataset - mock_db.commit.assert_called_once() - - def test_create_rag_pipeline_dataset_with_auto_generated_name(self, mock_rag_pipeline_dependencies): - """Test creation of RAG pipeline dataset with auto-generated name.""" - # Arrange - tenant_id = str(uuid4()) - user_id = str(uuid4()) - auto_name = "Untitled 1" - - # Mock current user - set up the mock to have id attribute accessible directly - mock_rag_pipeline_dependencies["current_user_mock"].id = user_id - - # Mock database query (empty name, need to generate) - mock_query = Mock() - mock_query.filter_by.return_value.all.return_value = [] - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - # Mock name generation - mock_rag_pipeline_dependencies["generate_name"].return_value = auto_name - - # Mock database operations - mock_db = mock_rag_pipeline_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Create entity with empty name - icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") - entity = RagPipelineDatasetCreateEntity( - name="", - description="", - icon_info=icon_info, - permission="only_me", - ) - - # Act - result = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity - ) - - # Assert - assert result.name == auto_name - mock_rag_pipeline_dependencies["generate_name"].assert_called_once() - mock_db.commit.assert_called_once() - - def test_create_rag_pipeline_dataset_duplicate_name_error(self, mock_rag_pipeline_dependencies): - """Test error when RAG pipeline dataset name already exists.""" - # Arrange - tenant_id = str(uuid4()) - user_id = str(uuid4()) - name = "Duplicate RAG Dataset" - - # Mock current user - set up the mock to have id attribute accessible directly - mock_rag_pipeline_dependencies["current_user_mock"].id = user_id - - # Mock database query to return existing dataset - existing_dataset = DatasetCreateTestDataFactory.create_dataset_mock(name=name) - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = existing_dataset - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - # Create entity - icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") - entity = RagPipelineDatasetCreateEntity( - name=name, - description="", - icon_info=icon_info, - permission="only_me", - ) - - # Act & Assert - with pytest.raises(DatasetNameDuplicateError, match=f"Dataset with name {name} already exists"): - DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity - ) - def test_create_rag_pipeline_dataset_missing_current_user_error(self, mock_rag_pipeline_dependencies): - """Test error when current user is not available.""" + """Raise ValueError when current_user.id is unavailable before SQL persistence.""" # Arrange tenant_id = str(uuid4()) - - # Mock current user as None - set id to None so the check fails mock_rag_pipeline_dependencies["current_user_mock"].id = None - # Mock database query mock_query = Mock() mock_query.filter_by.return_value.first.return_value = None mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - # Create entity icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") entity = RagPipelineDatasetCreateEntity( name="Test Dataset", @@ -729,91 +42,9 @@ class TestDatasetServiceCreateEmptyRagPipelineDataset: permission="only_me", ) - # Act & Assert + # Act / Assert with pytest.raises(ValueError, match="Current user or current user id not found"): DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity + tenant_id=tenant_id, + rag_pipeline_dataset_create_entity=entity, ) - - def test_create_rag_pipeline_dataset_with_custom_permission(self, mock_rag_pipeline_dependencies): - """Test creation with custom permission setting.""" - # Arrange - tenant_id = str(uuid4()) - user_id = str(uuid4()) - name = "Custom Permission RAG Dataset" - - # Mock current user - set up the mock to have id attribute accessible directly - mock_rag_pipeline_dependencies["current_user_mock"].id = user_id - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - # Mock database operations - mock_db = mock_rag_pipeline_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Create entity - icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") - entity = RagPipelineDatasetCreateEntity( - name=name, - description="", - icon_info=icon_info, - permission="all_team", - ) - - # Act - result = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity - ) - - # Assert - assert result.permission == "all_team" - mock_db.commit.assert_called_once() - - def test_create_rag_pipeline_dataset_with_icon_info(self, mock_rag_pipeline_dependencies): - """Test creation with icon info configuration.""" - # Arrange - tenant_id = str(uuid4()) - user_id = str(uuid4()) - name = "Icon Info RAG Dataset" - - # Mock current user - set up the mock to have id attribute accessible directly - mock_rag_pipeline_dependencies["current_user_mock"].id = user_id - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - # Mock database operations - mock_db = mock_rag_pipeline_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Create entity with icon info - icon_info = IconInfo( - icon="📚", - icon_background="#E8F5E9", - icon_type="emoji", - icon_url="https://example.com/icon.png", - ) - entity = RagPipelineDatasetCreateEntity( - name=name, - description="", - icon_info=icon_info, - permission="only_me", - ) - - # Act - result = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity - ) - - # Assert - assert result.icon_info == icon_info.model_dump() - mock_db.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_document_service_rename_document.py b/api/tests/unit_tests/services/test_document_service_rename_document.py deleted file mode 100644 index 94850ecb09..0000000000 --- a/api/tests/unit_tests/services/test_document_service_rename_document.py +++ /dev/null @@ -1,176 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import Mock, create_autospec, patch - -import pytest - -from models import Account -from services.dataset_service import DocumentService - - -@pytest.fixture -def mock_env(): - """Patch dependencies used by DocumentService.rename_document. - - Mocks: - - DatasetService.get_dataset - - DocumentService.get_document - - current_user (with current_tenant_id) - - db.session - """ - with ( - patch("services.dataset_service.DatasetService.get_dataset") as get_dataset, - patch("services.dataset_service.DocumentService.get_document") as get_document, - patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user, - patch("extensions.ext_database.db.session") as db_session, - ): - current_user.current_tenant_id = "tenant-123" - yield { - "get_dataset": get_dataset, - "get_document": get_document, - "current_user": current_user, - "db_session": db_session, - } - - -def make_dataset(dataset_id="dataset-123", tenant_id="tenant-123", built_in_field_enabled=False): - return SimpleNamespace(id=dataset_id, tenant_id=tenant_id, built_in_field_enabled=built_in_field_enabled) - - -def make_document( - document_id="document-123", - dataset_id="dataset-123", - tenant_id="tenant-123", - name="Old Name", - data_source_info=None, - doc_metadata=None, -): - doc = Mock() - doc.id = document_id - doc.dataset_id = dataset_id - doc.tenant_id = tenant_id - doc.name = name - doc.data_source_info = data_source_info or {} - # property-like usage in code relies on a dict - doc.data_source_info_dict = dict(doc.data_source_info) - doc.doc_metadata = dict(doc_metadata or {}) - return doc - - -def test_rename_document_success(mock_env): - dataset_id = "dataset-123" - document_id = "document-123" - new_name = "New Document Name" - - dataset = make_dataset(dataset_id) - document = make_document(document_id=document_id, dataset_id=dataset_id) - - mock_env["get_dataset"].return_value = dataset - mock_env["get_document"].return_value = document - - result = DocumentService.rename_document(dataset_id, document_id, new_name) - - assert result is document - assert document.name == new_name - mock_env["db_session"].add.assert_called_once_with(document) - mock_env["db_session"].commit.assert_called_once() - - -def test_rename_document_with_built_in_fields(mock_env): - dataset_id = "dataset-123" - document_id = "document-123" - new_name = "Renamed" - - dataset = make_dataset(dataset_id, built_in_field_enabled=True) - document = make_document(document_id=document_id, dataset_id=dataset_id, doc_metadata={"foo": "bar"}) - - mock_env["get_dataset"].return_value = dataset - mock_env["get_document"].return_value = document - - DocumentService.rename_document(dataset_id, document_id, new_name) - - assert document.name == new_name - # BuiltInField.document_name == "document_name" in service code - assert document.doc_metadata["document_name"] == new_name - assert document.doc_metadata["foo"] == "bar" - - -def test_rename_document_updates_upload_file_when_present(mock_env): - dataset_id = "dataset-123" - document_id = "document-123" - new_name = "Renamed" - file_id = "file-123" - - dataset = make_dataset(dataset_id) - document = make_document( - document_id=document_id, - dataset_id=dataset_id, - data_source_info={"upload_file_id": file_id}, - ) - - mock_env["get_dataset"].return_value = dataset - mock_env["get_document"].return_value = document - - # Intercept UploadFile rename UPDATE chain - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_env["db_session"].query.return_value = mock_query - - DocumentService.rename_document(dataset_id, document_id, new_name) - - assert document.name == new_name - mock_env["db_session"].query.assert_called() # update executed - - -def test_rename_document_does_not_update_upload_file_when_missing_id(mock_env): - """ - When data_source_info_dict exists but does not contain "upload_file_id", - UploadFile should not be updated. - """ - dataset_id = "dataset-123" - document_id = "document-123" - new_name = "Another Name" - - dataset = make_dataset(dataset_id) - # Ensure data_source_info_dict is truthy but lacks the key - document = make_document( - document_id=document_id, - dataset_id=dataset_id, - data_source_info={"url": "https://example.com"}, - ) - - mock_env["get_dataset"].return_value = dataset - mock_env["get_document"].return_value = document - - DocumentService.rename_document(dataset_id, document_id, new_name) - - assert document.name == new_name - # Should NOT attempt to update UploadFile - mock_env["db_session"].query.assert_not_called() - - -def test_rename_document_dataset_not_found(mock_env): - mock_env["get_dataset"].return_value = None - - with pytest.raises(ValueError, match="Dataset not found"): - DocumentService.rename_document("missing", "doc", "x") - - -def test_rename_document_not_found(mock_env): - dataset = make_dataset("dataset-123") - mock_env["get_dataset"].return_value = dataset - mock_env["get_document"].return_value = None - - with pytest.raises(ValueError, match="Document not found"): - DocumentService.rename_document(dataset.id, "missing", "x") - - -def test_rename_document_permission_denied_when_tenant_mismatch(mock_env): - dataset = make_dataset("dataset-123") - # different tenant than current_user.current_tenant_id - document = make_document(dataset_id=dataset.id, tenant_id="tenant-other") - - mock_env["get_dataset"].return_value = dataset - mock_env["get_document"].return_value = document - - with pytest.raises(ValueError, match="No permission"): - DocumentService.rename_document(dataset.id, document.id, "x") diff --git a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py index e0d6ad1b39..e64d3c5406 100644 --- a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py @@ -2,13 +2,13 @@ from types import SimpleNamespace import pytest -from core.workflow.nodes.human_input.entities import ( +from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, ) -from core.workflow.runtime import VariablePool +from dify_graph.runtime import VariablePool from services import human_input_delivery_test_service as service_module from services.human_input_delivery_test_service import ( DeliveryTestContext, diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py index 5800d029ca..a4c6c50593 100644 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -9,12 +9,12 @@ from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) -from core.workflow.nodes.human_input.entities import ( +from dify_graph.nodes.human_input.entities import ( FormDefinition, FormInput, UserAction, ) -from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus +from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus from models.human_input import RecipientType from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError diff --git a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py index e2360b116d..6a6b63f003 100644 --- a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py +++ b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py @@ -3,9 +3,9 @@ import types import pytest from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ConfigurateMethod +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod from models.provider import ProviderType from services.model_provider_service import ModelProviderService diff --git a/api/tests/unit_tests/services/test_schedule_service.py b/api/tests/unit_tests/services/test_schedule_service.py index e28965ea2c..5e3dd157e6 100644 --- a/api/tests/unit_tests/services/test_schedule_service.py +++ b/api/tests/unit_tests/services/test_schedule_service.py @@ -5,8 +5,8 @@ from unittest.mock import MagicMock, Mock, patch import pytest from sqlalchemy.orm import Session -from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig -from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError +from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig +from dify_graph.nodes.trigger_schedule.exc import ScheduleConfigError from events.event_handlers.sync_workflow_schedule_when_app_published import ( sync_schedule_from_workflow, ) @@ -136,7 +136,7 @@ class TestScheduleService(unittest.TestCase): def test_update_schedule_not_found(self): """Test updating a non-existent schedule raises exception.""" - from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError + from dify_graph.nodes.trigger_schedule.exc import ScheduleNotFoundError mock_session = MagicMock(spec=Session) mock_session.get.return_value = None @@ -172,7 +172,7 @@ class TestScheduleService(unittest.TestCase): def test_delete_schedule_not_found(self): """Test deleting a non-existent schedule raises exception.""" - from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError + from dify_graph.nodes.trigger_schedule.exc import ScheduleNotFoundError mock_session = MagicMock(spec=Session) mock_session.get.return_value = None diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py index 8199d586da..c703ab64d0 100644 --- a/api/tests/unit_tests/services/test_variable_truncator.py +++ b/api/tests/unit_tests/services/test_variable_truncator.py @@ -17,9 +17,9 @@ from uuid import uuid4 import pytest -from core.workflow.file.enums import FileTransferMethod, FileType -from core.workflow.file.models import File -from core.workflow.variables.segments import ( +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.file.models import File +from dify_graph.variables.segments import ( ArrayFileSegment, ArrayNumberSegment, ArraySegment, diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index 1f92ff590c..27664c7e29 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -16,7 +16,7 @@ import pytest from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.enums import WorkflowExecutionStatus from models.workflow import WorkflowPause from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index 8a97fd8a24..6b36592c41 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -14,8 +14,8 @@ from unittest.mock import MagicMock, Mock, patch import pytest -from core.workflow.enums import NodeType -from core.workflow.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from dify_graph.enums import NodeType +from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig from libs.datetime_utils import naive_utc_now from models.model import App, AppMode from models.workflow import Workflow, WorkflowType diff --git a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py index 83642fc209..1e0fdd788b 100644 --- a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py +++ b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py @@ -6,8 +6,8 @@ from unittest.mock import Mock, patch import pytest from sqlalchemy import Engine -from core.workflow.variables.segments import ObjectSegment, StringSegment -from core.workflow.variables.types import SegmentType +from dify_graph.variables.segments import ObjectSegment, StringSegment +from dify_graph.variables.types import SegmentType from models.model import UploadFile from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile from services.workflow_draft_variable_service import DraftVarLoader @@ -174,7 +174,7 @@ class TestDraftVarLoaderSimple: mock_storage.load.return_value = test_json_content.encode() with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: - from core.workflow.variables.segments import FloatSegment + from dify_graph.variables.segments import FloatSegment mock_segment = FloatSegment(value=test_number) mock_build_segment.return_value = mock_segment @@ -224,7 +224,7 @@ class TestDraftVarLoaderSimple: mock_storage.load.return_value = test_json_content.encode() with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: - from core.workflow.variables.segments import ArrayAnySegment + from dify_graph.variables.segments import ArrayAnySegment mock_segment = ArrayAnySegment(value=test_array) mock_build_segment.return_value = mock_segment diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index 8ccbfbb16e..a847c2b4d1 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -15,9 +15,9 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.helper import encrypter -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.variables.input_entities import VariableEntity, VariableEntityType +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import AppMode from services.workflow.workflow_converter import WorkflowConverter diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index 792257848f..4042e05565 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -7,10 +7,10 @@ import pytest from sqlalchemy import Engine from sqlalchemy.orm import Session -from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.enums import NodeType -from core.workflow.variables.segments import StringSegment -from core.workflow.variables.types import SegmentType +from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from dify_graph.enums import NodeType +from dify_graph.variables.segments import StringSegment +from dify_graph.variables.types import SegmentType from libs.uuid_utils import uuidv7 from models.account import Account from models.enums import DraftVariableType diff --git a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py index 844dab8976..6c1adba2b8 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py @@ -12,9 +12,9 @@ import pytest from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from core.workflow.runtime import GraphRuntimeState, VariablePool +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from dify_graph.runtime import GraphRuntimeState, VariablePool from models.enums import CreatorUserRole from models.model import AppMode from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py index 5ac5ac8ad2..5d6fa4c137 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py @@ -5,8 +5,8 @@ from unittest.mock import MagicMock import pytest from sqlalchemy.orm import sessionmaker -from core.workflow.enums import NodeType -from core.workflow.nodes.human_input.entities import ( +from dify_graph.enums import NodeType +from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 015dac257e..83c1f8d9da 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -4,9 +4,9 @@ from unittest.mock import MagicMock import pytest -from core.workflow.enums import NodeType -from core.workflow.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import FormInputType +from dify_graph.enums import NodeType +from dify_graph.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction +from dify_graph.nodes.human_input.enums import FormInputType from models.model import App from models.workflow import Workflow from services import workflow_service as workflow_service_module diff --git a/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py index 68fb8b748f..f6dbc4275b 100644 --- a/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py @@ -1,158 +1,38 @@ -""" -Unit tests for duplicate document indexing tasks. - -This module tests the duplicate document indexing task functionality including: -- Task enqueuing to different queues (normal, priority, tenant-isolated) -- Batch processing of multiple duplicate documents -- Progress tracking through task lifecycle -- Error handling and retry mechanisms -- Cleanup of old document data before re-indexing -""" +"""Unit tests for queue/wrapper behaviors in duplicate document indexing tasks (non-database logic).""" import uuid -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest -from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.pipeline.queue import TenantIsolatedTaskQueue -from enums.cloud_plan import CloudPlan -from models.dataset import Dataset, Document, DocumentSegment from tasks.duplicate_document_indexing_task import ( - _duplicate_document_indexing_task, _duplicate_document_indexing_task_with_tenant_queue, duplicate_document_indexing_task, normal_duplicate_document_indexing_task, priority_duplicate_document_indexing_task, ) -# ============================================================================ -# Fixtures -# ============================================================================ - @pytest.fixture def tenant_id(): - """Generate a unique tenant ID for testing.""" return str(uuid.uuid4()) @pytest.fixture def dataset_id(): - """Generate a unique dataset ID for testing.""" return str(uuid.uuid4()) @pytest.fixture def document_ids(): - """Generate a list of document IDs for testing.""" return [str(uuid.uuid4()) for _ in range(3)] -@pytest.fixture -def mock_dataset(dataset_id, tenant_id): - """Create a mock Dataset object.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.indexing_technique = "high_quality" - dataset.embedding_model_provider = "openai" - dataset.embedding_model = "text-embedding-ada-002" - return dataset - - -@pytest.fixture -def mock_documents(document_ids, dataset_id): - """Create mock Document objects.""" - documents = [] - for doc_id in document_ids: - doc = Mock(spec=Document) - doc.id = doc_id - doc.dataset_id = dataset_id - doc.indexing_status = "waiting" - doc.error = None - doc.stopped_at = None - doc.processing_started_at = None - doc.doc_form = "text_model" - documents.append(doc) - return documents - - -@pytest.fixture -def mock_document_segments(document_ids): - """Create mock DocumentSegment objects.""" - segments = [] - for doc_id in document_ids: - for i in range(3): - segment = Mock(spec=DocumentSegment) - segment.id = str(uuid.uuid4()) - segment.document_id = doc_id - segment.index_node_id = f"node-{doc_id}-{i}" - segments.append(segment) - return segments - - -@pytest.fixture -def mock_db_session(): - """Mock database session via session_factory.create_session().""" - with patch("tasks.duplicate_document_indexing_task.session_factory", autospec=True) as mock_sf: - session = MagicMock() - # Allow tests to observe session.close() via context manager teardown - session.close = MagicMock() - cm = MagicMock() - cm.__enter__.return_value = session - - def _exit_side_effect(*args, **kwargs): - session.close() - - cm.__exit__.side_effect = _exit_side_effect - mock_sf.create_session.return_value = cm - - query = MagicMock() - session.query.return_value = query - query.where.return_value = query - session.scalars.return_value = MagicMock() - yield session - - -@pytest.fixture -def mock_indexing_runner(): - """Mock IndexingRunner.""" - with patch("tasks.duplicate_document_indexing_task.IndexingRunner", autospec=True) as mock_runner_class: - mock_runner = MagicMock(spec=IndexingRunner) - mock_runner_class.return_value = mock_runner - yield mock_runner - - -@pytest.fixture -def mock_feature_service(): - """Mock FeatureService.""" - with patch("tasks.duplicate_document_indexing_task.FeatureService", autospec=True) as mock_service: - mock_features = Mock() - mock_features.billing = Mock() - mock_features.billing.enabled = False - mock_features.vector_space = Mock() - mock_features.vector_space.size = 0 - mock_features.vector_space.limit = 1000 - mock_service.get_features.return_value = mock_features - yield mock_service - - -@pytest.fixture -def mock_index_processor_factory(): - """Mock IndexProcessorFactory.""" - with patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory", autospec=True) as mock_factory: - mock_processor = MagicMock() - mock_processor.clean = Mock() - mock_factory.return_value.init_index_processor.return_value = mock_processor - yield mock_factory - - @pytest.fixture def mock_tenant_isolated_queue(): - """Mock TenantIsolatedTaskQueue.""" with patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) as mock_queue_class: - mock_queue = MagicMock(spec=TenantIsolatedTaskQueue) + mock_queue = Mock(spec=TenantIsolatedTaskQueue) mock_queue.pull_tasks.return_value = [] mock_queue.delete_task_key = Mock() mock_queue.set_task_waiting_time = Mock() @@ -160,11 +40,6 @@ def mock_tenant_isolated_queue(): yield mock_queue -# ============================================================================ -# Tests for deprecated duplicate_document_indexing_task -# ============================================================================ - - class TestDuplicateDocumentIndexingTask: """Tests for the deprecated duplicate_document_indexing_task function.""" @@ -190,258 +65,6 @@ class TestDuplicateDocumentIndexingTask: mock_core_func.assert_called_once_with(dataset_id, document_ids) -# ============================================================================ -# Tests for _duplicate_document_indexing_task core function -# ============================================================================ - - -class TestDuplicateDocumentIndexingTaskCore: - """Tests for the _duplicate_document_indexing_task core function.""" - - def test_successful_duplicate_document_indexing( - self, - mock_db_session, - mock_indexing_runner, - mock_feature_service, - mock_index_processor_factory, - mock_dataset, - mock_documents, - mock_document_segments, - dataset_id, - document_ids, - ): - """Test successful duplicate document indexing flow.""" - # Arrange - # Dataset via query.first() - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - # scalars() call sequence: - # 1) documents list - # 2..N) segments per document - - def _scalars_side_effect(*args, **kwargs): - m = MagicMock() - # First call returns documents; subsequent calls return segments - if not hasattr(_scalars_side_effect, "_calls"): - _scalars_side_effect._calls = 0 - if _scalars_side_effect._calls == 0: - m.all.return_value = mock_documents - else: - m.all.return_value = mock_document_segments - _scalars_side_effect._calls += 1 - return m - - mock_db_session.scalars.side_effect = _scalars_side_effect - - # Act - _duplicate_document_indexing_task(dataset_id, document_ids) - - # Assert - # Verify IndexingRunner was called - mock_indexing_runner.run.assert_called_once() - - # Verify all documents were set to parsing status - for doc in mock_documents: - assert doc.indexing_status == "parsing" - assert doc.processing_started_at is not None - - # Verify session operations - assert mock_db_session.commit.called - assert mock_db_session.close.called - - def test_duplicate_document_indexing_dataset_not_found(self, mock_db_session, dataset_id, document_ids): - """Test duplicate document indexing when dataset is not found.""" - # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = None - - # Act - _duplicate_document_indexing_task(dataset_id, document_ids) - - # Assert - # Should close the session at least once - assert mock_db_session.close.called - - def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan( - self, - mock_db_session, - mock_feature_service, - mock_dataset, - dataset_id, - document_ids, - ): - """Test duplicate document indexing with billing enabled and sandbox plan.""" - # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - mock_features = mock_feature_service.get_features.return_value - mock_features.billing.enabled = True - mock_features.billing.subscription.plan = CloudPlan.SANDBOX - - # Act - _duplicate_document_indexing_task(dataset_id, document_ids) - - # Assert - # For sandbox plan with multiple documents, should fail - mock_db_session.commit.assert_called() - - def test_duplicate_document_indexing_with_billing_limit_exceeded( - self, - mock_db_session, - mock_feature_service, - mock_dataset, - mock_documents, - dataset_id, - document_ids, - ): - """Test duplicate document indexing when billing limit is exceeded.""" - # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - # First scalars() -> documents; subsequent -> empty segments - - def _scalars_side_effect(*args, **kwargs): - m = MagicMock() - if not hasattr(_scalars_side_effect, "_calls"): - _scalars_side_effect._calls = 0 - if _scalars_side_effect._calls == 0: - m.all.return_value = mock_documents - else: - m.all.return_value = [] - _scalars_side_effect._calls += 1 - return m - - mock_db_session.scalars.side_effect = _scalars_side_effect - mock_features = mock_feature_service.get_features.return_value - mock_features.billing.enabled = True - mock_features.billing.subscription.plan = CloudPlan.TEAM - mock_features.vector_space.size = 990 - mock_features.vector_space.limit = 1000 - - # Act - _duplicate_document_indexing_task(dataset_id, document_ids) - - # Assert - # Should commit the session - assert mock_db_session.commit.called - # Should close the session - assert mock_db_session.close.called - - def test_duplicate_document_indexing_runner_error( - self, - mock_db_session, - mock_indexing_runner, - mock_feature_service, - mock_index_processor_factory, - mock_dataset, - mock_documents, - dataset_id, - document_ids, - ): - """Test duplicate document indexing when IndexingRunner raises an error.""" - # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def _scalars_side_effect(*args, **kwargs): - m = MagicMock() - if not hasattr(_scalars_side_effect, "_calls"): - _scalars_side_effect._calls = 0 - if _scalars_side_effect._calls == 0: - m.all.return_value = mock_documents - else: - m.all.return_value = [] - _scalars_side_effect._calls += 1 - return m - - mock_db_session.scalars.side_effect = _scalars_side_effect - mock_indexing_runner.run.side_effect = Exception("Indexing error") - - # Act - _duplicate_document_indexing_task(dataset_id, document_ids) - - # Assert - # Should close the session even after error - mock_db_session.close.assert_called_once() - - def test_duplicate_document_indexing_document_is_paused( - self, - mock_db_session, - mock_indexing_runner, - mock_feature_service, - mock_index_processor_factory, - mock_dataset, - mock_documents, - dataset_id, - document_ids, - ): - """Test duplicate document indexing when document is paused.""" - # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def _scalars_side_effect(*args, **kwargs): - m = MagicMock() - if not hasattr(_scalars_side_effect, "_calls"): - _scalars_side_effect._calls = 0 - if _scalars_side_effect._calls == 0: - m.all.return_value = mock_documents - else: - m.all.return_value = [] - _scalars_side_effect._calls += 1 - return m - - mock_db_session.scalars.side_effect = _scalars_side_effect - mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused") - - # Act - _duplicate_document_indexing_task(dataset_id, document_ids) - - # Assert - # Should handle DocumentIsPausedError gracefully - mock_db_session.close.assert_called_once() - - def test_duplicate_document_indexing_cleans_old_segments( - self, - mock_db_session, - mock_indexing_runner, - mock_feature_service, - mock_index_processor_factory, - mock_dataset, - mock_documents, - mock_document_segments, - dataset_id, - document_ids, - ): - """Test that duplicate document indexing cleans old segments.""" - # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def _scalars_side_effect(*args, **kwargs): - m = MagicMock() - if not hasattr(_scalars_side_effect, "_calls"): - _scalars_side_effect._calls = 0 - if _scalars_side_effect._calls == 0: - m.all.return_value = mock_documents - else: - m.all.return_value = mock_document_segments - _scalars_side_effect._calls += 1 - return m - - mock_db_session.scalars.side_effect = _scalars_side_effect - mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value - - # Act - _duplicate_document_indexing_task(dataset_id, document_ids) - - # Assert - # Verify clean was called for each document - assert mock_processor.clean.call_count == len(mock_documents) - - # Verify segments were deleted in batch (DELETE FROM document_segments) - execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list] - assert any("DELETE FROM document_segments" in sql for sql in execute_sqls) - - -# ============================================================================ -# Tests for tenant queue wrapper function -# ============================================================================ - - class TestDuplicateDocumentIndexingTaskWithTenantQueue: """Tests for _duplicate_document_indexing_task_with_tenant_queue function.""" @@ -536,11 +159,6 @@ class TestDuplicateDocumentIndexingTaskWithTenantQueue: mock_tenant_isolated_queue.pull_tasks.assert_called_once() -# ============================================================================ -# Tests for normal_duplicate_document_indexing_task -# ============================================================================ - - class TestNormalDuplicateDocumentIndexingTask: """Tests for normal_duplicate_document_indexing_task function.""" @@ -581,11 +199,6 @@ class TestNormalDuplicateDocumentIndexingTask: ) -# ============================================================================ -# Tests for priority_duplicate_document_indexing_task -# ============================================================================ - - class TestPriorityDuplicateDocumentIndexingTask: """Tests for priority_duplicate_document_indexing_task function.""" diff --git a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py index ee0699ba2d..bd0182a402 100644 --- a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py +++ b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py @@ -6,7 +6,7 @@ from typing import Any import pytest -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from tasks import human_input_timeout_tasks as task_module @@ -47,7 +47,7 @@ class _FakeSessionFactory: class _FakeFormRepo: - def __init__(self, _session_factory, form_map: dict[str, Any] | None = None): + def __init__(self, form_map: dict[str, Any] | None = None): self.calls: list[dict[str, Any]] = [] self._form_map = form_map or {} @@ -149,9 +149,9 @@ def test_check_and_handle_human_input_timeouts_marks_and_routes(monkeypatch: pyt monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory(forms, capture)) form_map = {form.id: form for form in forms} - repo = _FakeFormRepo(None, form_map=form_map) + repo = _FakeFormRepo(form_map=form_map) - def _repo_factory(_session_factory): + def _repo_factory(): return repo service = _FakeService(None) diff --git a/api/tests/unit_tests/tasks/test_workflow_execute_task.py b/api/tests/unit_tests/tasks/test_workflow_execute_task.py index 161151305d..d3cf632b47 100644 --- a/api/tests/unit_tests/tasks/test_workflow_execute_task.py +++ b/api/tests/unit_tests/tasks/test_workflow_execute_task.py @@ -2,12 +2,40 @@ from __future__ import annotations import json import uuid +from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from models.model import AppMode -from tasks.app_generate.workflow_execute_task import _publish_streaming_response +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from models.enums import CreatorUserRole +from models.model import App, AppMode, Conversation +from models.workflow import Workflow, WorkflowRun +from tasks.app_generate.workflow_execute_task import _publish_streaming_response, _resume_app_execution + + +class _FakeSessionContext: + def __init__(self, session: MagicMock): + self._session = session + + def __enter__(self) -> MagicMock: + return self._session + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + +def _build_advanced_chat_generate_entity(conversation_id: str | None) -> AdvancedChatAppGenerateEntity: + return AdvancedChatAppGenerateEntity( + task_id="task-id", + inputs={}, + files=[], + user_id="user-id", + stream=True, + invoke_from=InvokeFrom.WEB_APP, + query="query", + conversation_id=conversation_id, + ) @pytest.fixture @@ -37,3 +65,138 @@ def test_publish_streaming_response_coerces_string_uuid(mock_topic: MagicMock): _publish_streaming_response(response_stream, str(workflow_run_id), app_mode=AppMode.ADVANCED_CHAT) mock_topic.publish.assert_called_once_with(json.dumps({"event": "bar"}).encode()) + + +def test_resume_app_execution_queries_message_by_conversation_and_workflow_run(mocker): + workflow_run_id = "run-id" + conversation_id = "conversation-id" + message = MagicMock() + + mocker.patch("tasks.app_generate.workflow_execute_task.db", SimpleNamespace(engine=object())) + + pause_entity = MagicMock() + pause_entity.get_state.return_value = b"state" + + workflow_run_repo = MagicMock() + workflow_run_repo.get_workflow_pause.return_value = pause_entity + mocker.patch( + "tasks.app_generate.workflow_execute_task.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=workflow_run_repo, + ) + + generate_entity = _build_advanced_chat_generate_entity(conversation_id) + resumption_context = MagicMock() + resumption_context.serialized_graph_runtime_state = "{}" + resumption_context.get_generate_entity.return_value = generate_entity + mocker.patch( + "tasks.app_generate.workflow_execute_task.WorkflowResumptionContext.loads", return_value=resumption_context + ) + mocker.patch("tasks.app_generate.workflow_execute_task.GraphRuntimeState.from_snapshot", return_value=MagicMock()) + + workflow_run = SimpleNamespace( + workflow_id="wf-id", + app_id="app-id", + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by="account-id", + tenant_id="tenant-id", + ) + workflow = SimpleNamespace(created_by="workflow-owner") + app_model = SimpleNamespace(id="app-id") + conversation = SimpleNamespace(id=conversation_id) + + session = MagicMock() + + def _session_get(model, key): + if model is WorkflowRun: + return workflow_run + if model is Workflow: + return workflow + if model is App: + return app_model + if model is Conversation: + return conversation + return None + + session.get.side_effect = _session_get + session.scalar.return_value = message + + mocker.patch("tasks.app_generate.workflow_execute_task.Session", return_value=_FakeSessionContext(session)) + mocker.patch("tasks.app_generate.workflow_execute_task._resolve_user_for_run", return_value=MagicMock()) + resume_advanced_chat = mocker.patch("tasks.app_generate.workflow_execute_task._resume_advanced_chat") + mocker.patch("tasks.app_generate.workflow_execute_task._resume_workflow") + + _resume_app_execution({"workflow_run_id": workflow_run_id}) + + stmt = session.scalar.call_args.args[0] + stmt_text = str(stmt) + assert "messages.conversation_id = :conversation_id_1" in stmt_text + assert "messages.workflow_run_id = :workflow_run_id_1" in stmt_text + assert "ORDER BY messages.created_at DESC" in stmt_text + assert " LIMIT " in stmt_text + + compiled_params = stmt.compile().params + assert conversation_id in compiled_params.values() + assert workflow_run_id in compiled_params.values() + + workflow_run_repo.resume_workflow_pause.assert_called_once_with(workflow_run_id, pause_entity) + resume_advanced_chat.assert_called_once() + assert resume_advanced_chat.call_args.kwargs["conversation"] is conversation + assert resume_advanced_chat.call_args.kwargs["message"] is message + + +def test_resume_app_execution_returns_early_when_advanced_chat_missing_conversation_id(mocker): + workflow_run_id = "run-id" + + mocker.patch("tasks.app_generate.workflow_execute_task.db", SimpleNamespace(engine=object())) + + pause_entity = MagicMock() + pause_entity.get_state.return_value = b"state" + + workflow_run_repo = MagicMock() + workflow_run_repo.get_workflow_pause.return_value = pause_entity + mocker.patch( + "tasks.app_generate.workflow_execute_task.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=workflow_run_repo, + ) + + generate_entity = _build_advanced_chat_generate_entity(conversation_id=None) + resumption_context = MagicMock() + resumption_context.serialized_graph_runtime_state = "{}" + resumption_context.get_generate_entity.return_value = generate_entity + mocker.patch( + "tasks.app_generate.workflow_execute_task.WorkflowResumptionContext.loads", return_value=resumption_context + ) + mocker.patch("tasks.app_generate.workflow_execute_task.GraphRuntimeState.from_snapshot", return_value=MagicMock()) + + workflow_run = SimpleNamespace( + workflow_id="wf-id", + app_id="app-id", + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by="account-id", + tenant_id="tenant-id", + ) + workflow = SimpleNamespace(created_by="workflow-owner") + app_model = SimpleNamespace(id="app-id") + + session = MagicMock() + + def _session_get(model, key): + if model is WorkflowRun: + return workflow_run + if model is Workflow: + return workflow + if model is App: + return app_model + return None + + session.get.side_effect = _session_get + + mocker.patch("tasks.app_generate.workflow_execute_task.Session", return_value=_FakeSessionContext(session)) + mocker.patch("tasks.app_generate.workflow_execute_task._resolve_user_for_run", return_value=MagicMock()) + resume_advanced_chat = mocker.patch("tasks.app_generate.workflow_execute_task._resume_advanced_chat") + + _resume_app_execution({"workflow_run_id": workflow_run_id}) + + session.scalar.assert_not_called() + workflow_run_repo.resume_workflow_pause.assert_not_called() + resume_advanced_chat.assert_not_called() diff --git a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py index fd5f0713a4..54be8379d5 100644 --- a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py +++ b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py @@ -11,11 +11,11 @@ # import pytest -# from core.workflow.entities.workflow_node_execution import ( +# from dify_graph.entities.workflow_node_execution import ( # WorkflowNodeExecution, # WorkflowNodeExecutionStatus, # ) -# from core.workflow.enums import NodeType +# from dify_graph.enums import NodeType # from libs.datetime_utils import naive_utc_now # from models import WorkflowNodeExecutionModel # from models.enums import ExecutionOffLoadType diff --git a/api/tests/unit_tests/tools/test_mcp_tool.py b/api/tests/unit_tests/tools/test_mcp_tool.py index 5930b63f58..fa9c6af287 100644 --- a/api/tests/unit_tests/tools/test_mcp_tool.py +++ b/api/tests/unit_tests/tools/test_mcp_tool.py @@ -13,11 +13,11 @@ from core.mcp.types import ( TextContent, TextResourceContents, ) -from core.model_runtime.entities.llm_entities import LLMUsage from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.mcp_tool.tool import MCPTool +from dify_graph.model_runtime.entities.llm_entities import LLMUsage def _make_mcp_tool(output_schema: dict | None = None) -> MCPTool: diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index 1d85240c4c..78fa7820e8 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -11,17 +11,17 @@ from core.llm_generator.output_parser.structured_output import ( invoke_llm_with_pydantic_model, invoke_llm_with_structured_output, ) -from core.model_runtime.entities.llm_entities import ( +from dify_graph.model_runtime.entities.llm_entities import ( LLMResult, LLMResultWithStructuredOutput, LLMUsage, ) -from core.model_runtime.entities.message_entities import ( +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, SystemPromptMessage, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import AIModelEntity, ModelType +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage: diff --git a/api/tests/workflow_test_utils.py b/api/tests/workflow_test_utils.py new file mode 100644 index 0000000000..1f0bf8ef37 --- /dev/null +++ b/api/tests/workflow_test_utils.py @@ -0,0 +1,53 @@ +from collections.abc import Mapping +from typing import Any + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from dify_graph.entities.graph_init_params import GraphInitParams + + +def build_test_run_context( + *, + tenant_id: str = "tenant", + app_id: str = "app", + user_id: str = "user", + user_from: UserFrom | str = UserFrom.ACCOUNT, + invoke_from: InvokeFrom | str = InvokeFrom.DEBUGGER, + extra_context: Mapping[str, Any] | None = None, +) -> dict[str, Any]: + normalized_user_from = user_from if isinstance(user_from, UserFrom) else UserFrom(user_from) + normalized_invoke_from = invoke_from if isinstance(invoke_from, InvokeFrom) else InvokeFrom(invoke_from) + return build_dify_run_context( + tenant_id=tenant_id, + app_id=app_id, + user_id=user_id, + user_from=normalized_user_from, + invoke_from=normalized_invoke_from, + extra_context=extra_context, + ) + + +def build_test_graph_init_params( + *, + workflow_id: str = "workflow", + graph_config: Mapping[str, Any] | None = None, + call_depth: int = 0, + tenant_id: str = "tenant", + app_id: str = "app", + user_id: str = "user", + user_from: UserFrom | str = UserFrom.ACCOUNT, + invoke_from: InvokeFrom | str = InvokeFrom.DEBUGGER, + extra_context: Mapping[str, Any] | None = None, +) -> GraphInitParams: + return GraphInitParams( + workflow_id=workflow_id, + graph_config=graph_config or {}, + run_context=build_test_run_context( + tenant_id=tenant_id, + app_id=app_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + extra_context=extra_context, + ), + call_depth=call_depth, + ) diff --git a/docker/.env.example b/docker/.env.example index ead6c38f54..0f3112ad0e 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -354,6 +354,9 @@ REDIS_SSL_CERTFILE= REDIS_SSL_KEYFILE= # Path to client private key file for SSL authentication REDIS_DB=0 +# Optional: limit total Redis connections used by API/Worker (unset for default) +# Align with API's REDIS_MAX_CONNECTIONS in configs +REDIS_MAX_CONNECTIONS= # Whether to use Redis Sentinel mode. # If set to true, the application will automatically discover and connect to the master node through Sentinel. diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 51ca23db29..c21a877754 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -91,6 +91,7 @@ x-shared-env: &shared-api-worker-env REDIS_SSL_CERTFILE: ${REDIS_SSL_CERTFILE:-} REDIS_SSL_KEYFILE: ${REDIS_SSL_KEYFILE:-} REDIS_DB: ${REDIS_DB:-0} + REDIS_MAX_CONNECTIONS: ${REDIS_MAX_CONNECTIONS:-} REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false} REDIS_SENTINELS: ${REDIS_SENTINELS:-} REDIS_SENTINEL_SERVICE_NAME: ${REDIS_SENTINEL_SERVICE_NAME:-} diff --git a/docker/middleware.env.example b/docker/middleware.env.example index bb2eb84823..8c38c91f7a 100644 --- a/docker/middleware.env.example +++ b/docker/middleware.env.example @@ -91,6 +91,9 @@ MYSQL_INNODB_FLUSH_LOG_AT_TRX_COMMIT=2 # ----------------------------- REDIS_HOST_VOLUME=./volumes/redis/data REDIS_PASSWORD=difyai123456 +# Optional: limit total Redis connections used by API/Worker (unset for default) +# Align with API's REDIS_MAX_CONNECTIONS in configs +REDIS_MAX_CONNECTIONS= # ------------------------------ # Environment Variables for sandbox Service diff --git a/web/.nvmrc b/web/.nvmrc index a45fd52cc5..2bd5a0a98a 100644 --- a/web/.nvmrc +++ b/web/.nvmrc @@ -1 +1 @@ -24 +22 diff --git a/web/Dockerfile b/web/Dockerfile index 8bb39d6cb7..392d319ea8 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -1,5 +1,5 @@ # base image -FROM node:24-alpine AS base +FROM node:22-alpine AS base LABEL maintainer="takatost@gmail.com" # if you located in China, you can use aliyun mirror to speed up diff --git a/web/__tests__/apps/app-card-operations-flow.test.tsx b/web/__tests__/apps/app-card-operations-flow.test.tsx index 8e099a8c1e..763d071423 100644 --- a/web/__tests__/apps/app-card-operations-flow.test.tsx +++ b/web/__tests__/apps/app-card-operations-flow.test.tsx @@ -14,7 +14,7 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import AppCard from '@/app/components/apps/app-card' import { AccessMode } from '@/models/access-control' -import { deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps' +import { exportAppConfig, updateAppInfo } from '@/service/apps' import { AppModeEnum } from '@/types/app' let mockIsCurrentWorkspaceEditor = true @@ -26,6 +26,8 @@ let mockSystemFeatures = { const mockRouterPush = vi.fn() const mockNotify = vi.fn() const mockOnPlanInfoChanged = vi.fn() +const mockDeleteAppMutation = vi.fn().mockResolvedValue(undefined) +let mockDeleteMutationPending = false vi.mock('next/navigation', () => ({ useRouter: () => ({ @@ -117,6 +119,13 @@ vi.mock('@/service/tag', () => ({ fetchTagList: vi.fn().mockResolvedValue([]), })) +vi.mock('@/service/use-apps', () => ({ + useDeleteAppMutation: () => ({ + mutateAsync: mockDeleteAppMutation, + isPending: mockDeleteMutationPending, + }), +})) + vi.mock('@/service/apps', () => ({ deleteApp: vi.fn().mockResolvedValue({}), updateAppInfo: vi.fn().mockResolvedValue({}), @@ -271,6 +280,7 @@ const renderAppCard = (app?: Partial) => { describe('App Card Operations Flow', () => { beforeEach(() => { vi.clearAllMocks() + mockDeleteMutationPending = false mockIsCurrentWorkspaceEditor = true mockSystemFeatures = { branding: { enabled: false }, @@ -342,7 +352,7 @@ describe('App Card Operations Flow', () => { fireEvent.click(confirmBtn) await waitFor(() => { - expect(deleteApp).toHaveBeenCalledWith('app-to-delete') + expect(mockDeleteAppMutation).toHaveBeenCalledWith('app-to-delete') }) } } diff --git a/web/__tests__/apps/app-list-browsing-flow.test.tsx b/web/__tests__/apps/app-list-browsing-flow.test.tsx index 88acfc8140..ddb5113b6a 100644 --- a/web/__tests__/apps/app-list-browsing-flow.test.tsx +++ b/web/__tests__/apps/app-list-browsing-flow.test.tsx @@ -8,11 +8,11 @@ */ import type { AppListResponse } from '@/models/app' import type { App } from '@/types/app' -import { fireEvent, render, screen } from '@testing-library/react' -import { NuqsTestingAdapter } from 'nuqs/adapters/testing' +import { fireEvent, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import List from '@/app/components/apps/list' import { AccessMode } from '@/models/access-control' +import { renderWithNuqs } from '@/test/nuqs-testing' import { AppModeEnum } from '@/types/app' let mockIsCurrentWorkspaceEditor = true @@ -104,6 +104,10 @@ vi.mock('@/service/use-apps', () => ({ error: mockError, refetch: mockRefetch, }), + useDeleteAppMutation: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), })) vi.mock('@/hooks/use-pay', () => ({ @@ -162,10 +166,9 @@ const createPage = (apps: App[], hasMore = false, page = 1): AppListResponse => }) const renderList = (searchParams?: Record) => { - return render( - - - , + return renderWithNuqs( + , + { searchParams }, ) } @@ -210,11 +213,7 @@ describe('App List Browsing Flow', () => { it('should transition from loading to content when data loads', () => { mockIsLoading = true - const { rerender } = render( - - - , - ) + const { rerender } = renderWithNuqs() const skeletonCards = document.querySelectorAll('.animate-pulse') expect(skeletonCards.length).toBeGreaterThan(0) @@ -225,11 +224,7 @@ describe('App List Browsing Flow', () => { createMockApp({ id: 'app-1', name: 'Loaded App' }), ])] - rerender( - - - , - ) + rerender() expect(screen.getByText('Loaded App')).toBeInTheDocument() }) @@ -425,17 +420,9 @@ describe('App List Browsing Flow', () => { it('should call refetch when controlRefreshList increments', () => { mockPages = [createPage([createMockApp()])] - const { rerender } = render( - - - , - ) + const { rerender } = renderWithNuqs() - rerender( - - - , - ) + rerender() expect(mockRefetch).toHaveBeenCalled() }) diff --git a/web/__tests__/apps/create-app-flow.test.tsx b/web/__tests__/apps/create-app-flow.test.tsx index 9a859ef908..d81d1473d2 100644 --- a/web/__tests__/apps/create-app-flow.test.tsx +++ b/web/__tests__/apps/create-app-flow.test.tsx @@ -9,11 +9,11 @@ */ import type { AppListResponse } from '@/models/app' import type { App } from '@/types/app' -import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import { NuqsTestingAdapter } from 'nuqs/adapters/testing' +import { fireEvent, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import List from '@/app/components/apps/list' import { AccessMode } from '@/models/access-control' +import { renderWithNuqs } from '@/test/nuqs-testing' import { AppModeEnum } from '@/types/app' let mockIsCurrentWorkspaceEditor = true @@ -91,6 +91,10 @@ vi.mock('@/service/use-apps', () => ({ error: null, refetch: mockRefetch, }), + useDeleteAppMutation: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), })) vi.mock('@/hooks/use-pay', () => ({ @@ -215,11 +219,7 @@ const createPage = (apps: App[]): AppListResponse => ({ }) const renderList = () => { - return render( - - - , - ) + return renderWithNuqs() } describe('Create App Flow', () => { diff --git a/web/__tests__/datasets/document-management.test.tsx b/web/__tests__/datasets/document-management.test.tsx index 3b901ccee2..8aedd4fc63 100644 --- a/web/__tests__/datasets/document-management.test.tsx +++ b/web/__tests__/datasets/document-management.test.tsx @@ -7,9 +7,10 @@ */ import type { SimpleDocumentDetail } from '@/models/datasets' -import { act, renderHook } from '@testing-library/react' +import { act, renderHook, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import { DataSourceType } from '@/models/datasets' +import { renderHookWithNuqs } from '@/test/nuqs-testing' const mockPush = vi.fn() vi.mock('next/navigation', () => ({ @@ -28,12 +29,16 @@ const { useDocumentSort } = await import( const { useDocumentSelection } = await import( '@/app/components/datasets/documents/components/document-list/hooks/use-document-selection', ) -const { default: useDocumentListQueryState } = await import( +const { useDocumentListQueryState } = await import( '@/app/components/datasets/documents/hooks/use-document-list-query-state', ) type LocalDoc = SimpleDocumentDetail & { percent?: number } +const renderQueryStateHook = (searchParams = '') => { + return renderHookWithNuqs(() => useDocumentListQueryState(), { searchParams }) +} + const createDoc = (overrides?: Partial): LocalDoc => ({ id: `doc-${Math.random().toString(36).slice(2, 8)}`, name: 'test-doc.txt', @@ -85,7 +90,7 @@ describe('Document Management Flow', () => { describe('URL-based Query State', () => { it('should parse default query from empty URL params', () => { - const { result } = renderHook(() => useDocumentListQueryState()) + const { result } = renderQueryStateHook() expect(result.current.query).toEqual({ page: 1, @@ -96,107 +101,85 @@ describe('Document Management Flow', () => { }) }) - it('should update query and push to router', () => { - const { result } = renderHook(() => useDocumentListQueryState()) + it('should update keyword query with replace history', async () => { + const { result, onUrlUpdate } = renderQueryStateHook() act(() => { result.current.updateQuery({ keyword: 'test', page: 2 }) }) - expect(mockPush).toHaveBeenCalled() - // The push call should contain the updated query params - const pushUrl = mockPush.mock.calls[0][0] as string - expect(pushUrl).toContain('keyword=test') - expect(pushUrl).toContain('page=2') + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.options.history).toBe('replace') + expect(update.searchParams.get('keyword')).toBe('test') + expect(update.searchParams.get('page')).toBe('2') }) - it('should reset query to defaults', () => { - const { result } = renderHook(() => useDocumentListQueryState()) + it('should reset query to defaults', async () => { + const { result, onUrlUpdate } = renderQueryStateHook() act(() => { result.current.resetQuery() }) - expect(mockPush).toHaveBeenCalled() - // Default query omits default values from URL - const pushUrl = mockPush.mock.calls[0][0] as string - expect(pushUrl).toBe('/datasets/ds-1/documents') + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.options.history).toBe('replace') + expect(update.searchParams.toString()).toBe('') }) }) describe('Document Sort Integration', () => { - it('should return documents unsorted when no sort field set', () => { - const docs = [ - createDoc({ id: 'doc-1', name: 'Banana.txt', word_count: 300 }), - createDoc({ id: 'doc-2', name: 'Apple.txt', word_count: 100 }), - createDoc({ id: 'doc-3', name: 'Cherry.txt', word_count: 200 }), - ] - + it('should derive sort field and order from remote sort value', () => { const { result } = renderHook(() => useDocumentSort({ - documents: docs, - statusFilterValue: '', remoteSortValue: '-created_at', + onRemoteSortChange: vi.fn(), })) - expect(result.current.sortField).toBeNull() - expect(result.current.sortedDocuments).toHaveLength(3) + expect(result.current.sortField).toBe('created_at') + expect(result.current.sortOrder).toBe('desc') }) - it('should sort by name descending', () => { - const docs = [ - createDoc({ id: 'doc-1', name: 'Banana.txt' }), - createDoc({ id: 'doc-2', name: 'Apple.txt' }), - createDoc({ id: 'doc-3', name: 'Cherry.txt' }), - ] - + it('should call remote sort change with descending sort for a new field', () => { + const onRemoteSortChange = vi.fn() const { result } = renderHook(() => useDocumentSort({ - documents: docs, - statusFilterValue: '', remoteSortValue: '-created_at', + onRemoteSortChange, })) act(() => { - result.current.handleSort('name') + result.current.handleSort('hit_count') }) - expect(result.current.sortField).toBe('name') - expect(result.current.sortOrder).toBe('desc') - const names = result.current.sortedDocuments.map(d => d.name) - expect(names).toEqual(['Cherry.txt', 'Banana.txt', 'Apple.txt']) + expect(onRemoteSortChange).toHaveBeenCalledWith('-hit_count') }) - it('should toggle sort order on same field click', () => { - const docs = [createDoc({ id: 'doc-1', name: 'A.txt' }), createDoc({ id: 'doc-2', name: 'B.txt' })] - + it('should toggle descending to ascending when clicking active field', () => { + const onRemoteSortChange = vi.fn() const { result } = renderHook(() => useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '-created_at', + remoteSortValue: '-hit_count', + onRemoteSortChange, })) - act(() => result.current.handleSort('name')) - expect(result.current.sortOrder).toBe('desc') + act(() => { + result.current.handleSort('hit_count') + }) - act(() => result.current.handleSort('name')) - expect(result.current.sortOrder).toBe('asc') + expect(onRemoteSortChange).toHaveBeenCalledWith('hit_count') }) - it('should filter by status before sorting', () => { - const docs = [ - createDoc({ id: 'doc-1', name: 'A.txt', display_status: 'available' }), - createDoc({ id: 'doc-2', name: 'B.txt', display_status: 'error' }), - createDoc({ id: 'doc-3', name: 'C.txt', display_status: 'available' }), - ] - + it('should ignore null sort field updates', () => { + const onRemoteSortChange = vi.fn() const { result } = renderHook(() => useDocumentSort({ - documents: docs, - statusFilterValue: 'available', remoteSortValue: '-created_at', + onRemoteSortChange, })) - // Only 'available' documents should remain - expect(result.current.sortedDocuments).toHaveLength(2) - expect(result.current.sortedDocuments.every(d => d.display_status === 'available')).toBe(true) + act(() => { + result.current.handleSort(null) + }) + + expect(onRemoteSortChange).not.toHaveBeenCalled() }) }) @@ -309,14 +292,13 @@ describe('Document Management Flow', () => { describe('Cross-Module: Query State → Sort → Selection Pipeline', () => { it('should maintain consistent default state across all hooks', () => { const docs = [createDoc({ id: 'doc-1' })] - const { result: queryResult } = renderHook(() => useDocumentListQueryState()) + const { result: queryResult } = renderQueryStateHook() const { result: sortResult } = renderHook(() => useDocumentSort({ - documents: docs, - statusFilterValue: queryResult.current.query.status, remoteSortValue: queryResult.current.query.sort, + onRemoteSortChange: vi.fn(), })) const { result: selResult } = renderHook(() => useDocumentSelection({ - documents: sortResult.current.sortedDocuments, + documents: docs, selectedIds: [], onSelectedIdChange: vi.fn(), })) @@ -325,8 +307,9 @@ describe('Document Management Flow', () => { expect(queryResult.current.query.sort).toBe('-created_at') expect(queryResult.current.query.status).toBe('all') - // Sort inherits 'all' status → no filtering applied - expect(sortResult.current.sortedDocuments).toHaveLength(1) + // Sort state is derived from URL default sort. + expect(sortResult.current.sortField).toBe('created_at') + expect(sortResult.current.sortOrder).toBe('desc') // Selection starts empty expect(selResult.current.isAllSelected).toBe(false) diff --git a/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts b/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts index 578552840d..dc5ab3fc86 100644 --- a/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts +++ b/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts @@ -19,7 +19,7 @@ vi.mock('react-i18next', () => ({ }), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/__tests__/tools/tool-browsing-and-filtering.test.tsx b/web/__tests__/tools/tool-browsing-and-filtering.test.tsx index 4e7fa4952b..dbefb1fdc3 100644 --- a/web/__tests__/tools/tool-browsing-and-filtering.test.tsx +++ b/web/__tests__/tools/tool-browsing-and-filtering.test.tsx @@ -28,9 +28,13 @@ vi.mock('react-i18next', () => ({ }), })) -vi.mock('nuqs', () => ({ - useQueryState: () => ['builtin', vi.fn()], -})) +vi.mock('nuqs', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + useQueryState: () => ['builtin', vi.fn()], + } +}) vi.mock('@/context/global-public-context', () => ({ useGlobalPublicStore: () => ({ enable_marketplace: false }), @@ -212,6 +216,12 @@ vi.mock('@/app/components/tools/marketplace', () => ({ default: () => null, })) +vi.mock('@/app/components/tools/marketplace/hooks', () => ({ + useMarketplace: () => ({ + handleScroll: vi.fn(), + }), +})) + vi.mock('@/app/components/tools/mcp', () => ({ default: () =>
MCP List
, })) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx index 470f4477fa..fd0bf2c8bd 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx @@ -1,6 +1,6 @@ 'use client' import type { FC } from 'react' -import type { NavIcon } from '@/app/components/app-sidebar/navLink' +import type { NavIcon } from '@/app/components/app-sidebar/nav-link' import type { App } from '@/types/app' import { RiDashboard2Fill, diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx index abdb8cd196..cd542cac9b 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx @@ -13,7 +13,7 @@ import AppCard from '@/app/components/app/overview/app-card' import TriggerCard from '@/app/components/app/overview/trigger-card' import { useStore as useAppStore } from '@/app/components/app/store' import Loading from '@/app/components/base/loading' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import MCPServiceCard from '@/app/components/tools/mcp/mcp-service-card' import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager' import { webSocketClient } from '@/app/components/workflow/collaboration/core/websocket-manager' diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx index abd5dd96fd..db2786f6cf 100644 --- a/web/app/(commonLayout)/layout.tsx +++ b/web/app/(commonLayout)/layout.tsx @@ -8,10 +8,10 @@ import GotoAnything from '@/app/components/goto-anything' import Header from '@/app/components/header' import HeaderWrapper from '@/app/components/header/header-wrapper' import ReadmePanel from '@/app/components/plugins/readme-panel' -import { AppContextProvider } from '@/context/app-context' -import { EventEmitterContextProvider } from '@/context/event-emitter' -import { ModalContextProvider } from '@/context/modal-context' -import { ProviderContextProvider } from '@/context/provider-context' +import { AppContextProvider } from '@/context/app-context-provider' +import { EventEmitterContextProvider } from '@/context/event-emitter-provider' +import { ModalContextProvider } from '@/context/modal-context-provider' +import { ProviderContextProvider } from '@/context/provider-context-provider' import PartnerStack from '../components/billing/partner-stack' import Splash from '../components/splash' import RoleRouteGuard from './role-route-guard' diff --git a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx index 15c1865eb0..76db83c1ba 100644 --- a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx +++ b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx @@ -16,7 +16,7 @@ import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config' import { updateUserProfile } from '@/service/common' diff --git a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx index 461f37e978..c146174ea9 100644 --- a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx +++ b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx @@ -9,7 +9,7 @@ import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { checkEmailExisted, resetEmail, diff --git a/web/app/account/(commonLayout)/account-page/index.tsx b/web/app/account/(commonLayout)/account-page/index.tsx index 3a99d778ab..835663c721 100644 --- a/web/app/account/(commonLayout)/account-page/index.tsx +++ b/web/app/account/(commonLayout)/account-page/index.tsx @@ -12,7 +12,7 @@ import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' import PremiumBadge from '@/app/components/base/premium-badge' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Collapse from '@/app/components/header/account-setting/collapse' import { IS_CE_EDITION, validPassword } from '@/config' import { useAppContext } from '@/context/app-context' diff --git a/web/app/account/(commonLayout)/layout.tsx b/web/app/account/(commonLayout)/layout.tsx index e4125015d9..8fdbd8a238 100644 --- a/web/app/account/(commonLayout)/layout.tsx +++ b/web/app/account/(commonLayout)/layout.tsx @@ -4,10 +4,10 @@ import { AppInitializer } from '@/app/components/app-initializer' import AmplitudeProvider from '@/app/components/base/amplitude' import GA, { GaType } from '@/app/components/base/ga' import HeaderWrapper from '@/app/components/header/header-wrapper' -import { AppContextProvider } from '@/context/app-context' -import { EventEmitterContextProvider } from '@/context/event-emitter' -import { ModalContextProvider } from '@/context/modal-context' -import { ProviderContextProvider } from '@/context/provider-context' +import { AppContextProvider } from '@/context/app-context-provider' +import { EventEmitterContextProvider } from '@/context/event-emitter-provider' +import { ModalContextProvider } from '@/context/modal-context-provider' +import { ProviderContextProvider } from '@/context/provider-context-provider' import Header from './header' const Layout = ({ children }: { children: ReactNode }) => { diff --git a/web/app/account/oauth/authorize/layout.tsx b/web/app/account/oauth/authorize/layout.tsx index b7e7aa09ba..7f6b270b45 100644 --- a/web/app/account/oauth/authorize/layout.tsx +++ b/web/app/account/oauth/authorize/layout.tsx @@ -2,7 +2,7 @@ import Loading from '@/app/components/base/loading' import Header from '@/app/signin/_header' -import { AppContextProvider } from '@/context/app-context' +import { AppContextProvider } from '@/context/app-context-provider' import { useGlobalPublicStore } from '@/context/global-public-context' import useDocumentTitle from '@/hooks/use-document-title' import { useIsLogin } from '@/service/use-common' diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx index e4cd10175a..bf7aa39580 100644 --- a/web/app/components/app-initializer.tsx +++ b/web/app/components/app-initializer.tsx @@ -26,11 +26,10 @@ export const AppInitializer = ({ // Tokens are now stored in cookies, no need to check localStorage const pathname = usePathname() const [init, setInit] = useState(false) - const [oauthNewUser, setOauthNewUser] = useQueryState( + const [oauthNewUser] = useQueryState( 'oauth_new_user', parseAsBoolean.withOptions({ history: 'replace' }), ) - const isSetupFinished = useCallback(async () => { try { const setUpStatus = await fetchSetupStatusWithCache() @@ -69,11 +68,12 @@ export const AppInitializer = ({ ...utmInfo, }) - // Clean up: remove utm_info cookie and URL params Cookies.remove('utm_info') - setOauthNewUser(null) } + if (oauthNewUser !== null) + router.replace(pathname) + if (action === EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION) localStorage.setItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, 'yes') @@ -96,7 +96,7 @@ export const AppInitializer = ({ router.replace('/signin') } })() - }, [isSetupFinished, router, pathname, searchParams, oauthNewUser, setOauthNewUser]) + }, [isSetupFinished, router, pathname, searchParams, oauthNewUser]) return init ? children : null } diff --git a/web/app/components/app-sidebar/__tests__/app-sidebar-dropdown.spec.tsx b/web/app/components/app-sidebar/__tests__/app-sidebar-dropdown.spec.tsx new file mode 100644 index 0000000000..5018709da1 --- /dev/null +++ b/web/app/components/app-sidebar/__tests__/app-sidebar-dropdown.spec.tsx @@ -0,0 +1,177 @@ +import type { App, AppSSO } from '@/types/app' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import { AppModeEnum } from '@/types/app' +import AppSidebarDropdown from '../app-sidebar-dropdown' + +let mockAppDetail: (App & Partial) | undefined + +vi.mock('@/app/components/app/store', () => ({ + useStore: (selector: (state: Record) => unknown) => selector({ + appDetail: mockAppDetail, + }), +})) + +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + isCurrentWorkspaceEditor: true, + }), +})) + +vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ + PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => ( +
{children}
+ ), + PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick?: () => void }) => ( +
{children}
+ ), + PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), +})) + +vi.mock('../../base/app-icon', () => ({ + default: ({ size, icon }: { size: string, icon: string }) => ( +
+ ), +})) + +vi.mock('../../base/divider', () => ({ + default: () =>
, +})) + +vi.mock('../app-info', () => ({ + default: ({ expand, onlyShowDetail, openState }: { + expand: boolean + onlyShowDetail?: boolean + openState?: boolean + }) => ( +
+ ), +})) + +vi.mock('../nav-link', () => ({ + default: ({ name, href, mode }: { name: string, href: string, mode?: string }) => ( + {name} + ), +})) + +const MockIcon = (props: React.SVGProps) => + +const createAppDetail = (overrides: Partial = {}): App & Partial => ({ + id: 'app-1', + name: 'Test App', + mode: AppModeEnum.CHAT, + icon: '🤖', + icon_type: 'emoji', + icon_background: '#FFEAD5', + icon_url: '', + description: '', + use_icon_as_answer_icon: false, + ...overrides, +} as App & Partial) + +const navigation = [ + { name: 'Overview', href: '/overview', icon: MockIcon, selectedIcon: MockIcon }, + { name: 'Logs', href: '/logs', icon: MockIcon, selectedIcon: MockIcon }, +] + +describe('AppSidebarDropdown', () => { + beforeEach(() => { + vi.clearAllMocks() + mockAppDetail = createAppDetail() + }) + + it('should return null when appDetail is not available', () => { + mockAppDetail = undefined + const { container } = render() + expect(container.innerHTML).toBe('') + }) + + it('should render trigger with app icon', () => { + render() + const icons = screen.getAllByTestId('app-icon') + const smallIcon = icons.find(i => i.getAttribute('data-size') === 'small') + expect(smallIcon).toBeInTheDocument() + }) + + it('should render navigation links', () => { + render() + expect(screen.getByTestId('nav-link-Overview')).toBeInTheDocument() + expect(screen.getByTestId('nav-link-Logs')).toBeInTheDocument() + }) + + it('should display app name', () => { + render() + expect(screen.getByText('Test App')).toBeInTheDocument() + }) + + it('should display app mode label', () => { + render() + expect(screen.getByText('app.types.chatbot')).toBeInTheDocument() + }) + + it('should display mode labels for different modes', () => { + mockAppDetail = createAppDetail({ mode: AppModeEnum.ADVANCED_CHAT }) + render() + expect(screen.getByText('app.types.advanced')).toBeInTheDocument() + }) + + it('should render AppInfo component for detail expand', () => { + render() + expect(screen.getByTestId('app-info')).toBeInTheDocument() + expect(screen.getByTestId('app-info')).toHaveAttribute('data-only-detail', 'true') + }) + + it('should toggle portal open state when trigger is clicked', async () => { + const user = userEvent.setup() + render() + + const trigger = screen.getByTestId('portal-trigger') + await user.click(trigger) + + const portal = screen.getByTestId('portal-elem') + expect(portal).toHaveAttribute('data-open', 'true') + }) + + it('should render divider between app info and navigation', () => { + render() + expect(screen.getByTestId('divider')).toBeInTheDocument() + }) + + it('should render large app icon in dropdown content', () => { + render() + const icons = screen.getAllByTestId('app-icon') + const largeIcon = icons.find(icon => icon.getAttribute('data-size') === 'large') + expect(largeIcon).toBeInTheDocument() + }) + + it('should set detailExpand when clicking app info area', async () => { + const user = userEvent.setup() + render() + + const appName = screen.getByText('Test App') + const appInfoArea = appName.closest('[class*="cursor-pointer"]') + if (appInfoArea) + await user.click(appInfoArea) + }) + + it('should display workflow mode label', () => { + mockAppDetail = createAppDetail({ mode: AppModeEnum.WORKFLOW }) + render() + expect(screen.getByText('app.types.workflow')).toBeInTheDocument() + }) + + it('should display agent mode label', () => { + mockAppDetail = createAppDetail({ mode: AppModeEnum.AGENT_CHAT }) + render() + expect(screen.getByText('app.types.agent')).toBeInTheDocument() + }) + + it('should display completion mode label', () => { + mockAppDetail = createAppDetail({ mode: AppModeEnum.COMPLETION }) + render() + expect(screen.getByText('app.types.completion')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app-sidebar/__tests__/basic.spec.tsx b/web/app/components/app-sidebar/__tests__/basic.spec.tsx new file mode 100644 index 0000000000..67e708eb02 --- /dev/null +++ b/web/app/components/app-sidebar/__tests__/basic.spec.tsx @@ -0,0 +1,110 @@ +import { render, screen } from '@testing-library/react' +import * as React from 'react' +import AppBasic from '../basic' + +vi.mock('@/app/components/base/icons/src/vender/workflow', () => ({ + ApiAggregate: (props: React.SVGProps) => , + WindowCursor: (props: React.SVGProps) => , +})) + +vi.mock('@/app/components/base/tooltip', () => ({ + default: ({ popupContent }: { popupContent: React.ReactNode }) => ( +
{popupContent}
+ ), +})) + +vi.mock('../../base/app-icon', () => ({ + default: ({ icon, background, innerIcon, className }: { + icon?: string + background?: string + innerIcon?: React.ReactNode + className?: string + }) => ( +
+ {innerIcon} +
+ ), +})) + +describe('AppBasic', () => { + describe('Icon rendering', () => { + it('should render app icon when iconType is app with valid icon and background', () => { + render() + expect(screen.getByTestId('app-icon')).toBeInTheDocument() + }) + + it('should not render app icon when icon is empty', () => { + render() + expect(screen.queryByTestId('app-icon')).not.toBeInTheDocument() + }) + + it('should render api icon when iconType is api', () => { + render() + expect(screen.getByTestId('api-icon')).toBeInTheDocument() + }) + + it('should render webapp icon when iconType is webapp', () => { + render() + expect(screen.getByTestId('webapp-icon')).toBeInTheDocument() + }) + + it('should render dataset icon when iconType is dataset', () => { + render() + const icons = screen.getAllByTestId('app-icon') + expect(icons.length).toBeGreaterThan(0) + }) + + it('should render notion icon when iconType is notion', () => { + render() + const icons = screen.getAllByTestId('app-icon') + expect(icons.length).toBeGreaterThan(0) + }) + }) + + describe('Expand mode', () => { + it('should show name and type in expand mode', () => { + render() + expect(screen.getByText('My App')).toBeInTheDocument() + expect(screen.getByText('Chatbot')).toBeInTheDocument() + }) + + it('should hide name and type in collapse mode', () => { + render() + expect(screen.queryByText('My App')).not.toBeInTheDocument() + }) + + it('should show hover tip when provided', () => { + render() + expect(screen.getByTestId('tooltip')).toBeInTheDocument() + expect(screen.getByText('Some tip')).toBeInTheDocument() + }) + + it('should not show hover tip when not provided', () => { + render() + expect(screen.queryByTestId('tooltip')).not.toBeInTheDocument() + }) + }) + + describe('Type display', () => { + it('should hide type when hideType is true', () => { + render() + expect(screen.queryByText('Chatbot')).not.toBeInTheDocument() + }) + + it('should show external tag when isExternal is true', () => { + render() + expect(screen.getByText('dataset.externalTag')).toBeInTheDocument() + }) + + it('should show type inline when isExtraInLine is true and hideType is false', () => { + render() + expect(screen.getByText('Chatbot')).toBeInTheDocument() + }) + + it('should apply custom text styles', () => { + render() + const nameContainer = screen.getByText('My App').parentElement + expect(nameContainer).toHaveClass('text-red-500') + }) + }) +}) diff --git a/web/app/components/app-sidebar/__tests__/dataset-sidebar-dropdown.spec.tsx b/web/app/components/app-sidebar/__tests__/dataset-sidebar-dropdown.spec.tsx new file mode 100644 index 0000000000..1f3a5f9ad8 --- /dev/null +++ b/web/app/components/app-sidebar/__tests__/dataset-sidebar-dropdown.spec.tsx @@ -0,0 +1,193 @@ +import type { DataSet } from '@/models/datasets' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import DatasetSidebarDropdown from '../dataset-sidebar-dropdown' + +let mockDataset: DataSet + +vi.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: (selector: (state: { dataset: DataSet }) => unknown) => + selector({ dataset: mockDataset }), +})) + +vi.mock('@/service/knowledge/use-dataset', () => ({ + useDatasetRelatedApps: () => ({ data: [] }), +})) + +vi.mock('@/hooks/use-knowledge', () => ({ + useKnowledge: () => ({ + formatIndexingTechniqueAndMethod: () => 'method-text', + }), +})) + +vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ + PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => ( +
{children}
+ ), + PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick?: () => void }) => ( +
{children}
+ ), + PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), +})) + +vi.mock('../../base/app-icon', () => ({ + default: ({ size, icon }: { size: string, icon: string }) => ( +
+ ), +})) + +vi.mock('../../base/divider', () => ({ + default: () =>
, +})) + +vi.mock('../../base/effect', () => ({ + default: ({ className }: { className?: string }) =>
, +})) + +vi.mock('../../datasets/extra-info', () => ({ + default: ({ expand, documentCount }: { + relatedApps?: unknown[] + expand: boolean + documentCount: number + }) => ( +
+ ), +})) + +vi.mock('../dataset-info/dropdown', () => ({ + default: ({ expand }: { expand: boolean }) => ( +
+ ), +})) + +vi.mock('../nav-link', () => ({ + default: ({ name, href, mode, disabled }: { name: string, href: string, mode?: string, disabled?: boolean }) => ( + {name} + ), +})) + +const MockIcon = (props: React.SVGProps) => + +const createDataset = (overrides: Partial = {}): DataSet => ({ + id: 'dataset-1', + name: 'Test Dataset', + description: 'A test dataset', + provider: 'internal', + icon_info: { + icon: '📙', + icon_type: 'emoji', + icon_background: '#FFF4ED', + icon_url: '', + }, + doc_form: 'text_model' as DataSet['doc_form'], + indexing_technique: 'high_quality' as DataSet['indexing_technique'], + document_count: 10, + runtime_mode: 'general', + retrieval_model_dict: { + search_method: 'semantic_search' as DataSet['retrieval_model_dict']['search_method'], + reranking_enable: false, + reranking_model: { reranking_provider_name: '', reranking_model_name: '' }, + top_k: 5, + score_threshold_enabled: false, + score_threshold: 0, + }, + ...overrides, +} as DataSet) + +const navigation = [ + { name: 'Documents', href: '/documents', icon: MockIcon, selectedIcon: MockIcon }, + { name: 'Settings', href: '/settings', icon: MockIcon, selectedIcon: MockIcon, disabled: true }, +] + +describe('DatasetSidebarDropdown', () => { + beforeEach(() => { + vi.clearAllMocks() + mockDataset = createDataset() + }) + + it('should render trigger with dataset icon', () => { + render() + const icons = screen.getAllByTestId('app-icon') + const smallIcon = icons.find(i => i.getAttribute('data-size') === 'small') + expect(smallIcon).toBeInTheDocument() + expect(smallIcon).toHaveAttribute('data-icon', '📙') + }) + + it('should display dataset name in dropdown content', () => { + render() + expect(screen.getByText('Test Dataset')).toBeInTheDocument() + }) + + it('should display dataset description', () => { + render() + expect(screen.getByText('A test dataset')).toBeInTheDocument() + }) + + it('should not display description when empty', () => { + mockDataset = createDataset({ description: '' }) + render() + expect(screen.queryByText('A test dataset')).not.toBeInTheDocument() + }) + + it('should render navigation links', () => { + render() + expect(screen.getByTestId('nav-link-Documents')).toBeInTheDocument() + expect(screen.getByTestId('nav-link-Settings')).toBeInTheDocument() + }) + + it('should render ExtraInfo', () => { + render() + const extraInfo = screen.getByTestId('extra-info') + expect(extraInfo).toHaveAttribute('data-expand', 'true') + expect(extraInfo).toHaveAttribute('data-doc-count', '10') + }) + + it('should render Effect component', () => { + render() + expect(screen.getByTestId('effect')).toBeInTheDocument() + }) + + it('should render Dropdown component with expand=true', () => { + render() + expect(screen.getByTestId('dataset-dropdown')).toHaveAttribute('data-expand', 'true') + }) + + it('should show external tag for external provider', () => { + mockDataset = createDataset({ provider: 'external' }) + render() + expect(screen.getByText('dataset.externalTag')).toBeInTheDocument() + }) + + it('should use fallback icon info when icon_info is missing', () => { + mockDataset = createDataset({ icon_info: undefined as unknown as DataSet['icon_info'] }) + render() + const icons = screen.getAllByTestId('app-icon') + const fallbackIcon = icons.find(i => i.getAttribute('data-icon') === '📙') + expect(fallbackIcon).toBeInTheDocument() + }) + + it('should toggle dropdown open state on trigger click', async () => { + const user = userEvent.setup() + render() + + const trigger = screen.getByTestId('portal-trigger') + await user.click(trigger) + + expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'true') + }) + + it('should render divider', () => { + render() + expect(screen.getByTestId('divider')).toBeInTheDocument() + }) + + it('should render medium app icon in content area', () => { + render() + const icons = screen.getAllByTestId('app-icon') + const mediumIcon = icons.find(i => i.getAttribute('data-size') === 'medium') + expect(mediumIcon).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app-sidebar/__tests__/index.spec.tsx b/web/app/components/app-sidebar/__tests__/index.spec.tsx new file mode 100644 index 0000000000..89db80e0f1 --- /dev/null +++ b/web/app/components/app-sidebar/__tests__/index.spec.tsx @@ -0,0 +1,298 @@ +import { act, render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import AppDetailNav from '..' + +let mockAppSidebarExpand = 'expand' +const mockSetAppSidebarExpand = vi.fn() +let mockPathname = '/app/123/overview' + +vi.mock('@/app/components/app/store', () => ({ + useStore: (selector: (state: Record) => unknown) => selector({ + appDetail: { id: 'app-1', name: 'Test', mode: 'chat', icon: '🤖', icon_type: 'emoji', icon_background: '#fff' }, + appSidebarExpand: mockAppSidebarExpand, + setAppSidebarExpand: mockSetAppSidebarExpand, + }), +})) + +vi.mock('zustand/react/shallow', () => ({ + useShallow: (fn: unknown) => fn, +})) + +vi.mock('next/navigation', () => ({ + usePathname: () => mockPathname, +})) + +let mockIsHovering = true +let mockKeyPressCallback: ((e: { preventDefault: () => void }) => void) | null = null + +vi.mock('ahooks', () => ({ + useHover: () => mockIsHovering, + useKeyPress: (_key: string, cb: (e: { preventDefault: () => void }) => void) => { + mockKeyPressCallback = cb + }, +})) + +vi.mock('@/hooks/use-breakpoints', () => ({ + default: () => 'desktop', + MediaType: { mobile: 'mobile', desktop: 'desktop' }, +})) + +let mockSubscriptionCallback: ((v: unknown) => void) | null = null + +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: () => ({ + eventEmitter: { + useSubscription: (cb: (v: unknown) => void) => { mockSubscriptionCallback = cb }, + }, + }), +})) + +vi.mock('../../base/divider', () => ({ + default: ({ className }: { className?: string }) =>
, +})) + +vi.mock('@/app/components/workflow/utils', () => ({ + getKeyboardKeyCodeBySystem: () => 'ctrl', +})) + +vi.mock('../app-info', () => ({ + default: ({ expand }: { expand: boolean }) => ( +
+ ), +})) + +vi.mock('../app-sidebar-dropdown', () => ({ + default: ({ navigation }: { navigation: unknown[] }) => ( +
+ ), +})) + +vi.mock('../dataset-info', () => ({ + default: ({ expand }: { expand: boolean }) => ( +
+ ), +})) + +vi.mock('../dataset-sidebar-dropdown', () => ({ + default: ({ navigation }: { navigation: unknown[] }) => ( +
+ ), +})) + +vi.mock('../nav-link', () => ({ + default: ({ name, href, mode }: { name: string, href: string, mode?: string }) => ( + {name} + ), +})) + +vi.mock('../toggle-button', () => ({ + default: ({ expand, handleToggle, className }: { expand: boolean, handleToggle: () => void, className?: string }) => ( + + ), +})) + +const MockIcon = (props: React.SVGProps) => + +const navigation = [ + { name: 'Overview', href: '/overview', icon: MockIcon, selectedIcon: MockIcon }, + { name: 'Logs', href: '/logs', icon: MockIcon, selectedIcon: MockIcon }, +] + +describe('AppDetailNav', () => { + beforeEach(() => { + vi.clearAllMocks() + mockAppSidebarExpand = 'expand' + mockPathname = '/app/123/overview' + mockIsHovering = true + }) + + describe('Normal sidebar mode', () => { + it('should render AppInfo when iconType is app', () => { + render() + expect(screen.getByTestId('app-info')).toBeInTheDocument() + expect(screen.getByTestId('app-info')).toHaveAttribute('data-expand', 'true') + }) + + it('should render DatasetInfo when iconType is dataset', () => { + render() + expect(screen.getByTestId('dataset-info')).toBeInTheDocument() + }) + + it('should render navigation links', () => { + render() + expect(screen.getByTestId('nav-link-Overview')).toBeInTheDocument() + expect(screen.getByTestId('nav-link-Logs')).toBeInTheDocument() + }) + + it('should render divider', () => { + render() + expect(screen.getByTestId('divider')).toBeInTheDocument() + }) + + it('should apply expanded width class', () => { + const { container } = render() + const sidebar = container.firstElementChild as HTMLElement + expect(sidebar).toHaveClass('w-[216px]') + }) + + it('should apply collapsed width class', () => { + mockAppSidebarExpand = 'collapse' + const { container } = render() + const sidebar = container.firstElementChild as HTMLElement + expect(sidebar).toHaveClass('w-14') + }) + + it('should render extraInfo when iconType is dataset and extraInfo provided', () => { + render( +
} + />, + ) + expect(screen.getByTestId('extra-info')).toBeInTheDocument() + }) + + it('should not render extraInfo when iconType is app', () => { + render( +
} + />, + ) + expect(screen.queryByTestId('extra-info')).not.toBeInTheDocument() + }) + }) + + describe('Workflow canvas mode', () => { + it('should render AppSidebarDropdown when in workflow canvas with hidden header', () => { + mockPathname = '/app/123/workflow' + localStorage.setItem('workflow-canvas-maximize', 'true') + + render() + + expect(screen.getByTestId('app-sidebar-dropdown')).toBeInTheDocument() + expect(screen.queryByTestId('app-info')).not.toBeInTheDocument() + }) + + it('should render normal sidebar when workflow canvas is not maximized', () => { + mockPathname = '/app/123/workflow' + localStorage.setItem('workflow-canvas-maximize', 'false') + + render() + + expect(screen.queryByTestId('app-sidebar-dropdown')).not.toBeInTheDocument() + expect(screen.getByTestId('app-info')).toBeInTheDocument() + }) + }) + + describe('Pipeline canvas mode', () => { + it('should render DatasetSidebarDropdown when in pipeline canvas with hidden header', () => { + mockPathname = '/dataset/123/pipeline' + localStorage.setItem('workflow-canvas-maximize', 'true') + + render() + + expect(screen.getByTestId('dataset-sidebar-dropdown')).toBeInTheDocument() + expect(screen.queryByTestId('app-info')).not.toBeInTheDocument() + }) + }) + + describe('Navigation mode', () => { + it('should pass expand mode to nav links when expanded', () => { + render() + expect(screen.getByTestId('nav-link-Overview')).toHaveAttribute('data-mode', 'expand') + }) + + it('should pass collapse mode to nav links when collapsed', () => { + mockAppSidebarExpand = 'collapse' + render() + expect(screen.getByTestId('nav-link-Overview')).toHaveAttribute('data-mode', 'collapse') + }) + }) + + describe('Toggle behavior', () => { + it('should call setAppSidebarExpand on toggle', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('toggle-button')) + + expect(mockSetAppSidebarExpand).toHaveBeenCalledWith('collapse') + }) + + it('should toggle from collapse to expand', async () => { + const user = userEvent.setup() + mockAppSidebarExpand = 'collapse' + render() + + await user.click(screen.getByTestId('toggle-button')) + + expect(mockSetAppSidebarExpand).toHaveBeenCalledWith('expand') + }) + }) + + describe('Sidebar persistence', () => { + it('should persist expand state to localStorage', () => { + render() + expect(localStorage.setItem).toHaveBeenCalledWith('app-detail-collapse-or-expand', 'expand') + }) + }) + + describe('Disabled navigation items', () => { + it('should render disabled navigation items', () => { + const navWithDisabled = [ + ...navigation, + { name: 'Disabled', href: '/disabled', icon: MockIcon, selectedIcon: MockIcon, disabled: true }, + ] + render() + expect(screen.getByTestId('nav-link-Disabled')).toBeInTheDocument() + }) + }) + + describe('Event emitter subscription', () => { + it('should handle workflow-canvas-maximize event', () => { + mockPathname = '/app/123/workflow' + render() + + const cb = mockSubscriptionCallback + expect(cb).not.toBeNull() + act(() => { + cb!({ type: 'workflow-canvas-maximize', payload: true }) + }) + }) + + it('should ignore non-maximize events', () => { + render() + + const cb = mockSubscriptionCallback + act(() => { + cb!({ type: 'other-event' }) + }) + }) + }) + + describe('Keyboard shortcut', () => { + it('should toggle sidebar on ctrl+b', () => { + render() + + const cb = mockKeyPressCallback + expect(cb).not.toBeNull() + act(() => { + cb!({ preventDefault: vi.fn() }) + }) + expect(mockSetAppSidebarExpand).toHaveBeenCalledWith('collapse') + }) + }) + + describe('Hover-based toggle button visibility', () => { + it('should hide toggle button when not hovering', () => { + mockIsHovering = false + render() + expect(screen.queryByTestId('toggle-button')).not.toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app-sidebar/sidebar-animation-issues.spec.tsx b/web/app/components/app-sidebar/__tests__/sidebar-animation-issues.spec.tsx similarity index 80% rename from web/app/components/app-sidebar/sidebar-animation-issues.spec.tsx rename to web/app/components/app-sidebar/__tests__/sidebar-animation-issues.spec.tsx index 5d85b99d9a..fef65fcad3 100644 --- a/web/app/components/app-sidebar/sidebar-animation-issues.spec.tsx +++ b/web/app/components/app-sidebar/__tests__/sidebar-animation-issues.spec.tsx @@ -143,12 +143,6 @@ describe('Sidebar Animation Issues Reproduction', () => { expect(toggleSection).toHaveClass('px-4') // Same consistent padding expect(toggleSection).not.toHaveClass('px-5') expect(toggleSection).not.toHaveClass('px-6') - - // THE FIX: px-4 in both states prevents position movement - console.log('✅ Issue #1 FIXED: Toggle button now has consistent padding') - console.log(' - Before: px-4 (collapsed) vs px-6 (expanded) - 8px difference') - console.log(' - After: px-4 (both states) - 0px difference') - console.log(' - Result: No button position movement during transition') }) it('should verify sidebar width animation is working correctly', () => { @@ -164,8 +158,6 @@ describe('Sidebar Animation Issues Reproduction', () => { // Expanded state rerender() expect(container).toHaveClass('w-[216px]') - - console.log('✅ Sidebar width transition is properly configured') }) }) @@ -188,13 +180,6 @@ describe('Sidebar Animation Issues Reproduction', () => { expect(link).toHaveClass('px-3') // 12px padding (+2px) expect(icon).toHaveClass('mr-2') // 8px margin (+8px) expect(screen.getByTestId('nav-text-Orchestrate')).toBeInTheDocument() - - // THE BUG: Multiple simultaneous changes create squeeze effect - console.log('🐛 Issue #2 Reproduced: Text squeeze effect from multiple layout changes') - console.log(' - Link padding: px-2.5 → px-3 (+2px)') - console.log(' - Icon margin: mr-0 → mr-2 (+8px)') - console.log(' - Text appears: none → visible (abrupt)') - console.log(' - Result: Text appears with squeeze effect due to layout shifts') }) it('should document the abrupt text rendering issue', () => { @@ -207,10 +192,6 @@ describe('Sidebar Animation Issues Reproduction', () => { // Text suddenly appears - no transition expect(screen.getByTestId('nav-text-API Access')).toBeInTheDocument() - - console.log('🐛 Issue #2 Detail: Conditional rendering {mode === "expand" && name}') - console.log(' - Problem: Text appears/disappears abruptly without transition') - console.log(' - Should use: opacity or width transition for smooth appearance') }) }) @@ -234,13 +215,6 @@ describe('Sidebar Animation Issues Reproduction', () => { expect(iconContainer).toHaveClass('gap-1') expect(iconContainer).not.toHaveClass('justify-between') expect(appIcon).toHaveAttribute('data-size', 'small') - - // THE BUG: Layout mode switch causes icon to "bounce" - console.log('🐛 Issue #3 Reproduced: Icon bounce from layout mode switching') - console.log(' - Layout change: justify-between → flex-col gap-1') - console.log(' - Icon size: large (40px) → small (24px)') - console.log(' - Transition: transition-all causes excessive animation') - console.log(' - Result: Icon appears to bounce to right then back during collapse') }) it('should identify the problematic transition-all property', () => { @@ -251,10 +225,6 @@ describe('Sidebar Animation Issues Reproduction', () => { // The problematic broad transition expect(computedStyle.transition).toContain('all') - - console.log('🐛 Issue #3 Detail: transition-all affects ALL CSS properties') - console.log(' - Problem: Animates layout properties that should not transition') - console.log(' - Solution: Use specific transition properties instead of "all"') }) }) @@ -276,7 +246,6 @@ describe('Sidebar Animation Issues Reproduction', () => { // Initial state verification expect(expanded).toBe(false) - console.log('🔄 Starting interactive test - all issues will be reproduced') // Simulate toggle click fireEvent.click(toggleButton) @@ -287,11 +256,6 @@ describe('Sidebar Animation Issues Reproduction', () => {
, ) - - console.log('✨ All three issues successfully reproduced in interactive test:') - console.log(' 1. Toggle button position movement (padding inconsistency)') - console.log(' 2. Navigation text squeeze effect (multiple layout changes)') - console.log(' 3. App icon bounce animation (layout mode switching)') }) }) }) diff --git a/web/app/components/app-sidebar/text-squeeze-fix-verification.spec.tsx b/web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx similarity index 65% rename from web/app/components/app-sidebar/text-squeeze-fix-verification.spec.tsx rename to web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx index f7e91b3dea..fb19833dd2 100644 --- a/web/app/components/app-sidebar/text-squeeze-fix-verification.spec.tsx +++ b/web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx @@ -13,7 +13,7 @@ vi.mock('next/navigation', () => ({ // Mock classnames utility vi.mock('@/utils/classnames', () => ({ - default: (...classes: any[]) => classes.filter(Boolean).join(' '), + default: (...classes: unknown[]) => classes.filter(Boolean).join(' '), })) // Simplified NavLink component to test the fix @@ -101,12 +101,6 @@ describe('Text Squeeze Fix Verification', () => { expect(textElement).toHaveClass('whitespace-nowrap') expect(textElement).toHaveClass('transition-all') - console.log('✅ NavLink Collapsed State:') - console.log(' - Text is in DOM but visually hidden') - console.log(' - Uses opacity-0 and w-0 for hiding') - console.log(' - Has whitespace-nowrap to prevent wrapping') - console.log(' - Has transition-all for smooth animation') - // Switch to expanded state rerender() @@ -115,13 +109,6 @@ describe('Text Squeeze Fix Verification', () => { expect(expandedText).toHaveClass('opacity-100') expect(expandedText).toHaveClass('w-auto') expect(expandedText).not.toHaveClass('pointer-events-none') - - console.log('✅ NavLink Expanded State:') - console.log(' - Text is visible with opacity-100') - console.log(' - Uses w-auto for natural width') - console.log(' - No layout jumps during transition') - - console.log('🎯 NavLink Fix Result: Text squeeze effect ELIMINATED') }) it('should verify smooth transition properties', () => { @@ -131,11 +118,6 @@ describe('Text Squeeze Fix Verification', () => { expect(textElement).toHaveClass('transition-all') expect(textElement).toHaveClass('duration-200') expect(textElement).toHaveClass('ease-in-out') - - console.log('✅ Transition Properties Verified:') - console.log(' - transition-all: Smooth property changes') - console.log(' - duration-200: 200ms transition time') - console.log(' - ease-in-out: Smooth easing function') }) }) @@ -159,11 +141,6 @@ describe('Text Squeeze Fix Verification', () => { expect(appName).toHaveClass('whitespace-nowrap') expect(appType).toHaveClass('whitespace-nowrap') - console.log('✅ AppInfo Collapsed State:') - console.log(' - Text container is in DOM but visually hidden') - console.log(' - App name and type elements always present') - console.log(' - Uses whitespace-nowrap to prevent wrapping') - // Switch to expanded state rerender() @@ -172,13 +149,6 @@ describe('Text Squeeze Fix Verification', () => { expect(expandedContainer).toHaveClass('opacity-100') expect(expandedContainer).toHaveClass('w-auto') expect(expandedContainer).not.toHaveClass('pointer-events-none') - - console.log('✅ AppInfo Expanded State:') - console.log(' - Text container is visible with opacity-100') - console.log(' - Uses w-auto for natural width') - console.log(' - No layout jumps during transition') - - console.log('🎯 AppInfo Fix Result: Text squeeze effect ELIMINATED') }) it('should verify transition properties on text container', () => { @@ -188,45 +158,11 @@ describe('Text Squeeze Fix Verification', () => { expect(textContainer).toHaveClass('transition-all') expect(textContainer).toHaveClass('duration-200') expect(textContainer).toHaveClass('ease-in-out') - - console.log('✅ AppInfo Transition Properties Verified:') - console.log(' - Container has smooth CSS transitions') - console.log(' - Same 200ms duration as NavLink for consistency') }) }) describe('Fix Strategy Comparison', () => { it('should document the fix strategy differences', () => { - console.log('\n📋 TEXT SQUEEZE FIX STRATEGY COMPARISON') - console.log('='.repeat(60)) - - console.log('\n❌ BEFORE (Problematic):') - console.log(' NavLink: {mode === "expand" && name}') - console.log(' AppInfo: {expand && (
...
)}') - console.log(' Problem: Conditional rendering causes abrupt appearance') - console.log(' Result: Text "squeezes" from center during layout changes') - - console.log('\n✅ AFTER (Fixed):') - console.log(' NavLink: {name}') - console.log(' AppInfo:
...
') - console.log(' Solution: CSS controls visibility, element always in DOM') - console.log(' Result: Smooth opacity and width transitions') - - console.log('\n🎯 KEY FIX PRINCIPLES:') - console.log(' 1. ✅ Always keep text elements in DOM') - console.log(' 2. ✅ Use opacity for show/hide transitions') - console.log(' 3. ✅ Use width (w-0/w-auto) for layout control') - console.log(' 4. ✅ Add whitespace-nowrap to prevent wrapping') - console.log(' 5. ✅ Use pointer-events-none when hidden') - console.log(' 6. ✅ Add overflow-hidden for clean hiding') - - console.log('\n🚀 BENEFITS:') - console.log(' - No more abrupt text appearance') - console.log(' - Smooth 200ms transitions') - console.log(' - No layout jumps or shifts') - console.log(' - Consistent animation timing') - console.log(' - Better user experience') - // Always pass documentation test expect(true).toBe(true) }) diff --git a/web/app/components/app-sidebar/__tests__/toggle-button.spec.tsx b/web/app/components/app-sidebar/__tests__/toggle-button.spec.tsx new file mode 100644 index 0000000000..1a117ac5e3 --- /dev/null +++ b/web/app/components/app-sidebar/__tests__/toggle-button.spec.tsx @@ -0,0 +1,46 @@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import ToggleButton from '../toggle-button' + +vi.mock('@/app/components/workflow/shortcuts-name', () => ({ + default: ({ keys }: { keys: string[] }) => ( + {keys.join('+')} + ), +})) + +describe('ToggleButton', () => { + it('should render collapse arrow when expanded', () => { + render() + const button = screen.getByRole('button') + expect(button).toBeInTheDocument() + }) + + it('should render expand arrow when collapsed', () => { + render() + const button = screen.getByRole('button') + expect(button).toBeInTheDocument() + }) + + it('should call handleToggle when clicked', async () => { + const user = userEvent.setup() + const handleToggle = vi.fn() + render() + + await user.click(screen.getByRole('button')) + + expect(handleToggle).toHaveBeenCalledTimes(1) + }) + + it('should apply custom className', () => { + render() + const button = screen.getByRole('button') + expect(button).toHaveClass('custom-class') + }) + + it('should have rounded-full style', () => { + render() + const button = screen.getByRole('button') + expect(button).toHaveClass('rounded-full') + }) +}) diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index 3603ded71c..90199fe1bb 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -1,4 +1,4 @@ -import type { Operation } from './app-operations' +import type { Operation } from './app-info/app-operations' import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal' import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import type { EnvironmentVariable } from '@/app/components/workflow/types' @@ -21,7 +21,7 @@ import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overvie import { useStore as useAppStore } from '@/app/components/app/store' import Button from '@/app/components/base/button' import ContentDialog from '@/app/components/base/content-dialog' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager' import { webSocketClient } from '@/app/components/workflow/collaboration/core/websocket-manager' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' @@ -35,7 +35,7 @@ import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' import { downloadBlob } from '@/utils/download' import AppIcon from '../base/app-icon' -import AppOperations from './app-operations' +import AppOperations from './app-info/app-operations' const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), { ssr: false, diff --git a/web/app/components/app-sidebar/app-info/__tests__/app-info-detail-panel.spec.tsx b/web/app/components/app-sidebar/app-info/__tests__/app-info-detail-panel.spec.tsx new file mode 100644 index 0000000000..3082eb3789 --- /dev/null +++ b/web/app/components/app-sidebar/app-info/__tests__/app-info-detail-panel.spec.tsx @@ -0,0 +1,298 @@ +import type { App, AppSSO } from '@/types/app' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import { AppModeEnum } from '@/types/app' +import AppInfoDetailPanel from '../app-info-detail-panel' + +vi.mock('../../../base/app-icon', () => ({ + default: ({ size, icon }: { size: string, icon: string }) => ( +
+ ), +})) + +vi.mock('@/app/components/base/content-dialog', () => ({ + default: ({ show, onClose, children, className }: { + show: boolean + onClose: () => void + children: React.ReactNode + className?: string + }) => ( + show + ? ( +
+ + {children} +
+ ) + : null + ), +})) + +vi.mock('@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view', () => ({ + default: ({ appId }: { appId: string }) => ( +
+ ), +})) + +vi.mock('@/app/components/base/button', () => ({ + default: ({ children, onClick, className, size, variant }: { + children: React.ReactNode + onClick?: () => void + className?: string + size?: string + variant?: string + }) => ( + + ), +})) + +vi.mock('../app-operations', () => ({ + default: ({ primaryOperations, secondaryOperations }: { + primaryOperations?: Array<{ id: string, title: string, onClick: () => void }> + secondaryOperations?: Array<{ id: string, title: string, onClick: () => void, type?: string }> + }) => ( +
+ {primaryOperations?.map(op => ( + + ))} + {secondaryOperations?.map(op => ( + op.type === 'divider' + ? + : + ))} +
+ ), +})) + +const createAppDetail = (overrides: Partial = {}): App & Partial => ({ + id: 'app-1', + name: 'Test App', + mode: AppModeEnum.CHAT, + icon: '🤖', + icon_type: 'emoji', + icon_background: '#FFEAD5', + icon_url: '', + description: 'A test description', + use_icon_as_answer_icon: false, + ...overrides, +} as App & Partial) + +describe('AppInfoDetailPanel', () => { + const defaultProps = { + appDetail: createAppDetail(), + show: true, + onClose: vi.fn(), + openModal: vi.fn(), + exportCheck: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should not render when show is false', () => { + render() + expect(screen.queryByTestId('content-dialog')).not.toBeInTheDocument() + }) + + it('should render dialog when show is true', () => { + render() + expect(screen.getByTestId('content-dialog')).toBeInTheDocument() + }) + + it('should display app name', () => { + render() + expect(screen.getByText('Test App')).toBeInTheDocument() + }) + + it('should display app mode label', () => { + render() + expect(screen.getByText('app.types.chatbot')).toBeInTheDocument() + }) + + it('should display description when available', () => { + render() + expect(screen.getByText('A test description')).toBeInTheDocument() + }) + + it('should not display description when empty', () => { + render() + expect(screen.queryByText('A test description')).not.toBeInTheDocument() + }) + + it('should not display description when undefined', () => { + render() + expect(screen.queryByText('A test description')).not.toBeInTheDocument() + }) + + it('should render CardView with correct appId', () => { + render() + const cardView = screen.getByTestId('card-view') + expect(cardView).toHaveAttribute('data-app-id', 'app-1') + }) + + it('should render app icon with large size', () => { + render() + const icon = screen.getByTestId('app-icon') + expect(icon).toHaveAttribute('data-size', 'large') + }) + }) + + describe('Operations', () => { + it('should render edit, duplicate, and export operations', () => { + render() + expect(screen.getByTestId('op-edit')).toBeInTheDocument() + expect(screen.getByTestId('op-duplicate')).toBeInTheDocument() + expect(screen.getByTestId('op-export')).toBeInTheDocument() + }) + + it('should call openModal with edit when edit is clicked', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('op-edit')) + + expect(defaultProps.openModal).toHaveBeenCalledWith('edit') + }) + + it('should call openModal with duplicate when duplicate is clicked', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('op-duplicate')) + + expect(defaultProps.openModal).toHaveBeenCalledWith('duplicate') + }) + + it('should call exportCheck when export is clicked', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('op-export')) + + expect(defaultProps.exportCheck).toHaveBeenCalledTimes(1) + }) + + it('should render delete operation', () => { + render() + expect(screen.getByTestId('op-delete')).toBeInTheDocument() + }) + + it('should call openModal with delete when delete is clicked', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('op-delete')) + + expect(defaultProps.openModal).toHaveBeenCalledWith('delete') + }) + }) + + describe('Import DSL option', () => { + it('should show import DSL for advanced_chat mode', () => { + render( + , + ) + expect(screen.getByTestId('op-import')).toBeInTheDocument() + }) + + it('should show import DSL for workflow mode', () => { + render( + , + ) + expect(screen.getByTestId('op-import')).toBeInTheDocument() + }) + + it('should not show import DSL for chat mode', () => { + render() + expect(screen.queryByTestId('op-import')).not.toBeInTheDocument() + }) + + it('should call openModal with importDSL when import is clicked', async () => { + const user = userEvent.setup() + render( + , + ) + await user.click(screen.getByTestId('op-import')) + expect(defaultProps.openModal).toHaveBeenCalledWith('importDSL') + }) + + it('should render divider in secondary operations', async () => { + const user = userEvent.setup() + render() + const divider = screen.getByTestId('op-divider-1') + expect(divider).toBeInTheDocument() + await user.click(divider) + }) + }) + + describe('Switch operation', () => { + it('should show switch button for chat mode', () => { + render() + expect(screen.getByText('app.switch')).toBeInTheDocument() + }) + + it('should show switch button for completion mode', () => { + render( + , + ) + expect(screen.getByText('app.switch')).toBeInTheDocument() + }) + + it('should not show switch button for workflow mode', () => { + render( + , + ) + expect(screen.queryByText('app.switch')).not.toBeInTheDocument() + }) + + it('should not show switch button for advanced_chat mode', () => { + render( + , + ) + expect(screen.queryByText('app.switch')).not.toBeInTheDocument() + }) + + it('should call openModal with switch when switch button is clicked', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByText('app.switch')) + + expect(defaultProps.openModal).toHaveBeenCalledWith('switch') + }) + }) + + describe('Dialog interactions', () => { + it('should call onClose when dialog close button is clicked', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('dialog-close')) + + expect(defaultProps.onClose).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx b/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx new file mode 100644 index 0000000000..f8612e8057 --- /dev/null +++ b/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx @@ -0,0 +1,264 @@ +import type { App, AppSSO } from '@/types/app' +import { act, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import { AppModeEnum } from '@/types/app' +import AppInfoModals from '../app-info-modals' + +vi.mock('next/dynamic', () => ({ + default: (loader: () => Promise<{ default: React.ComponentType }>) => { + const LazyComp = React.lazy(loader) + return function DynamicWrapper(props: Record) { + return React.createElement( + React.Suspense, + { fallback: null }, + React.createElement(LazyComp, props), + ) + } + }, +})) + +vi.mock('@/app/components/app/switch-app-modal', () => ({ + default: ({ show, onClose }: { show: boolean, onClose: () => void }) => ( + show ?
: null + ), +})) + +vi.mock('@/app/components/explore/create-app-modal', () => ({ + default: ({ show, onHide, isEditModal }: { show: boolean, onHide: () => void, isEditModal?: boolean }) => ( + show ?
: null + ), +})) + +vi.mock('@/app/components/app/duplicate-modal', () => ({ + default: ({ show, onHide }: { show: boolean, onHide: () => void }) => ( + show ?
: null + ), +})) + +vi.mock('@/app/components/base/confirm', () => ({ + default: ({ isShow, title, onConfirm, onCancel }: { + isShow: boolean + title: string + onConfirm: () => void + onCancel: () => void + }) => ( + isShow + ? ( +
+ + +
+ ) + : null + ), +})) + +vi.mock('@/app/components/workflow/update-dsl-modal', () => ({ + default: ({ onCancel, onBackup }: { onCancel: () => void, onBackup: () => void }) => ( +
+ + +
+ ), +})) + +vi.mock('@/app/components/workflow/dsl-export-confirm-modal', () => ({ + default: ({ onConfirm, onClose }: { onConfirm: (include?: boolean) => void, onClose: () => void }) => ( +
+ + +
+ ), +})) + +const createAppDetail = (overrides: Partial = {}): App & Partial => ({ + id: 'app-1', + name: 'Test App', + mode: AppModeEnum.CHAT, + icon: '🤖', + icon_type: 'emoji', + icon_background: '#FFEAD5', + icon_url: '', + description: '', + use_icon_as_answer_icon: false, + max_active_requests: null, + ...overrides, +} as App & Partial) + +const defaultProps = { + appDetail: createAppDetail(), + closeModal: vi.fn(), + secretEnvList: [] as never[], + setSecretEnvList: vi.fn(), + onEdit: vi.fn(), + onCopy: vi.fn(), + onExport: vi.fn(), + exportCheck: vi.fn(), + handleConfirmExport: vi.fn(), + onConfirmDelete: vi.fn(), +} + +describe('AppInfoModals', () => { + beforeAll(async () => { + await new Promise(resolve => setTimeout(resolve, 0)) + }) + + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render nothing when activeModal is null', async () => { + await act(async () => { + render() + }) + expect(screen.queryByTestId('switch-modal')).not.toBeInTheDocument() + expect(screen.queryByTestId('confirm-modal')).not.toBeInTheDocument() + }) + + it('should render SwitchAppModal when activeModal is switch', async () => { + await act(async () => { + render() + }) + await waitFor(() => { + expect(screen.getByTestId('switch-modal')).toBeInTheDocument() + }) + }) + + it('should render CreateAppModal in edit mode when activeModal is edit', async () => { + await act(async () => { + render() + }) + await waitFor(() => { + expect(screen.getByTestId('edit-modal')).toBeInTheDocument() + }) + }) + + it('should render DuplicateAppModal when activeModal is duplicate', async () => { + await act(async () => { + render() + }) + await waitFor(() => { + expect(screen.getByTestId('duplicate-modal')).toBeInTheDocument() + }) + }) + + it('should render Confirm for delete when activeModal is delete', async () => { + await act(async () => { + render() + }) + await waitFor(() => { + const confirm = screen.getByTestId('confirm-modal') + expect(confirm).toBeInTheDocument() + expect(confirm).toHaveAttribute('data-title', 'app.deleteAppConfirmTitle') + }) + }) + + it('should render UpdateDSLModal when activeModal is importDSL', async () => { + await act(async () => { + render() + }) + await waitFor(() => { + expect(screen.getByTestId('import-dsl-modal')).toBeInTheDocument() + }) + }) + + it('should render export warning Confirm when activeModal is exportWarning', async () => { + await act(async () => { + render() + }) + await waitFor(() => { + const confirm = screen.getByTestId('confirm-modal') + expect(confirm).toBeInTheDocument() + expect(confirm).toHaveAttribute('data-title', 'workflow.sidebar.exportWarning') + }) + }) + + it('should render DSLExportConfirmModal when secretEnvList is not empty', async () => { + await act(async () => { + render( + , + ) + }) + await waitFor(() => { + expect(screen.getByTestId('dsl-export-confirm-modal')).toBeInTheDocument() + }) + }) + + it('should not render DSLExportConfirmModal when secretEnvList is empty', async () => { + await act(async () => { + render() + }) + expect(screen.queryByTestId('dsl-export-confirm-modal')).not.toBeInTheDocument() + }) + + it('should call closeModal when cancel on delete modal', async () => { + const user = userEvent.setup() + await act(async () => { + render() + }) + + await waitFor(() => expect(screen.getByText('Cancel')).toBeInTheDocument()) + await user.click(screen.getByText('Cancel')) + + expect(defaultProps.closeModal).toHaveBeenCalledTimes(1) + }) + + it('should call onConfirmDelete when confirm on delete modal', async () => { + const user = userEvent.setup() + await act(async () => { + render() + }) + + await waitFor(() => expect(screen.getByText('Confirm')).toBeInTheDocument()) + await user.click(screen.getByText('Confirm')) + + expect(defaultProps.onConfirmDelete).toHaveBeenCalledTimes(1) + }) + + it('should call handleConfirmExport when confirm on export warning', async () => { + const user = userEvent.setup() + await act(async () => { + render() + }) + + await waitFor(() => expect(screen.getByText('Confirm')).toBeInTheDocument()) + await user.click(screen.getByText('Confirm')) + + expect(defaultProps.handleConfirmExport).toHaveBeenCalledTimes(1) + }) + + it('should call exportCheck when backup on importDSL modal', async () => { + const user = userEvent.setup() + await act(async () => { + render() + }) + + await waitFor(() => expect(screen.getByText('Backup')).toBeInTheDocument()) + await user.click(screen.getByText('Backup')) + + expect(defaultProps.exportCheck).toHaveBeenCalledTimes(1) + }) + + it('should call setSecretEnvList with empty array when closing DSLExportConfirmModal', async () => { + const user = userEvent.setup() + await act(async () => { + render( + , + ) + }) + + await waitFor(() => expect(screen.getByText('Close Export')).toBeInTheDocument()) + await user.click(screen.getByText('Close Export')) + + expect(defaultProps.setSecretEnvList).toHaveBeenCalledWith([]) + }) +}) diff --git a/web/app/components/app-sidebar/app-info/__tests__/app-info-trigger.spec.tsx b/web/app/components/app-sidebar/app-info/__tests__/app-info-trigger.spec.tsx new file mode 100644 index 0000000000..65d660876c --- /dev/null +++ b/web/app/components/app-sidebar/app-info/__tests__/app-info-trigger.spec.tsx @@ -0,0 +1,99 @@ +import type { App, AppSSO } from '@/types/app' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import { AppModeEnum } from '@/types/app' +import AppInfoTrigger from '../app-info-trigger' + +vi.mock('../../../base/app-icon', () => ({ + default: ({ size, icon, background }: { + size: string + icon: string + background: string + iconType?: string + imageUrl?: string + }) => ( +
+ ), +})) + +const createAppDetail = (overrides: Partial = {}): App & Partial => ({ + id: 'app-1', + name: 'Test App', + mode: AppModeEnum.CHAT, + icon: '🤖', + icon_type: 'emoji', + icon_background: '#FFEAD5', + icon_url: '', + description: 'A test app', + use_icon_as_answer_icon: false, + ...overrides, +} as App & Partial) + +describe('AppInfoTrigger', () => { + it('should render app icon with correct size when expanded', () => { + render() + const icon = screen.getByTestId('app-icon') + expect(icon).toHaveAttribute('data-size', 'large') + }) + + it('should render app icon with small size when collapsed', () => { + render() + const icon = screen.getByTestId('app-icon') + expect(icon).toHaveAttribute('data-size', 'small') + }) + + it('should show app name when expanded', () => { + render() + expect(screen.getByText('My Chatbot')).toBeInTheDocument() + }) + + it('should not show app name when collapsed', () => { + render() + expect(screen.queryByText('My Chatbot')).not.toBeInTheDocument() + }) + + it('should show app mode label when expanded', () => { + render() + expect(screen.getByText('app.types.advanced')).toBeInTheDocument() + }) + + it('should not show mode label when collapsed', () => { + render() + expect(screen.queryByText('app.types.chatbot')).not.toBeInTheDocument() + }) + + it('should call onClick when button is clicked', async () => { + const user = userEvent.setup() + const onClick = vi.fn() + render() + + await user.click(screen.getByRole('button')) + + expect(onClick).toHaveBeenCalledTimes(1) + }) + + it('should show settings icon in expanded and collapsed states', () => { + const { container, rerender } = render( + , + ) + expect(container.querySelector('svg')).toBeInTheDocument() + + rerender() + expect(container.querySelector('svg')).toBeInTheDocument() + }) + + it('should apply ml-1 class to icon wrapper when collapsed', () => { + render( + , + ) + const iconWrapper = screen.getByTestId('app-icon').parentElement + expect(iconWrapper).toHaveClass('ml-1') + }) + + it('should not apply ml-1 class when expanded', () => { + render() + const iconWrapper = screen.getByTestId('app-icon').parentElement + expect(iconWrapper).not.toHaveClass('ml-1') + }) +}) diff --git a/web/app/components/app-sidebar/app-info/__tests__/app-mode-labels.spec.ts b/web/app/components/app-sidebar/app-info/__tests__/app-mode-labels.spec.ts new file mode 100644 index 0000000000..ac4318278c --- /dev/null +++ b/web/app/components/app-sidebar/app-info/__tests__/app-mode-labels.spec.ts @@ -0,0 +1,34 @@ +import type { TFunction } from 'i18next' +import { AppModeEnum } from '@/types/app' +import { getAppModeLabel } from '../app-mode-labels' + +describe('getAppModeLabel', () => { + const t: TFunction = ((key: string, options?: Record) => { + const ns = (options?.ns as string | undefined) ?? '' + return ns ? `${ns}.${key}` : key + }) as TFunction + + it('should return advanced chat label', () => { + expect(getAppModeLabel(AppModeEnum.ADVANCED_CHAT, t)).toBe('app.types.advanced') + }) + + it('should return agent chat label', () => { + expect(getAppModeLabel(AppModeEnum.AGENT_CHAT, t)).toBe('app.types.agent') + }) + + it('should return chatbot label', () => { + expect(getAppModeLabel(AppModeEnum.CHAT, t)).toBe('app.types.chatbot') + }) + + it('should return completion label', () => { + expect(getAppModeLabel(AppModeEnum.COMPLETION, t)).toBe('app.types.completion') + }) + + it('should return workflow label for unknown mode', () => { + expect(getAppModeLabel('unknown-mode', t)).toBe('app.types.workflow') + }) + + it('should return workflow label for workflow mode', () => { + expect(getAppModeLabel(AppModeEnum.WORKFLOW, t)).toBe('app.types.workflow') + }) +}) diff --git a/web/app/components/app-sidebar/app-info/__tests__/app-operations.spec.tsx b/web/app/components/app-sidebar/app-info/__tests__/app-operations.spec.tsx new file mode 100644 index 0000000000..1df23c2d20 --- /dev/null +++ b/web/app/components/app-sidebar/app-info/__tests__/app-operations.spec.tsx @@ -0,0 +1,253 @@ +import type { Operation } from '../app-operations' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import AppOperations from '../app-operations' + +vi.mock('../../../base/button', () => ({ + default: ({ children, onClick, className, size, variant, id, tabIndex, ...rest }: { + 'children': React.ReactNode + 'onClick'?: () => void + 'className'?: string + 'size'?: string + 'variant'?: string + 'id'?: string + 'tabIndex'?: number + 'data-targetid'?: string + }) => ( + + ), +})) + +vi.mock('../../../base/portal-to-follow-elem', () => ({ + PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => ( +
{children}
+ ), + PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick?: () => void }) => ( +
{children}
+ ), + PortalToFollowElemContent: ({ children, className }: { children: React.ReactNode, className?: string }) => ( +
{children}
+ ), +})) + +const createOperation = (id: string, title: string, type?: 'divider'): Operation => ({ + id, + title, + icon: , + onClick: vi.fn(), + type, +}) + +function setupDomMeasurements(navWidth: number, moreWidth: number, childWidths: number[]) { + const originalClientWidth = Object.getOwnPropertyDescriptor(HTMLElement.prototype, 'clientWidth') + + Object.defineProperty(HTMLElement.prototype, 'clientWidth', { + configurable: true, + get(this: HTMLElement) { + if (this.getAttribute('aria-hidden') === 'true') + return navWidth + if (this.id === 'more-measure') + return moreWidth + if (this.dataset.targetid) { + const idx = Array.from(this.parentElement?.children ?? []).indexOf(this) + return childWidths[idx] ?? 50 + } + return 0 + }, + }) + + return () => { + if (originalClientWidth) + Object.defineProperty(HTMLElement.prototype, 'clientWidth', originalClientWidth) + } +} + +describe('AppOperations', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering with operations prop', () => { + it('should render measurement container', () => { + const ops = [createOperation('edit', 'Edit'), createOperation('copy', 'Copy')] + const { container } = render() + expect(container.querySelector('[aria-hidden="true"]')).toBeInTheDocument() + }) + + it('should render operation buttons in measurement container', () => { + const ops = [createOperation('edit', 'Edit'), createOperation('copy', 'Copy')] + render() + const editButtons = screen.getAllByText('Edit') + expect(editButtons.length).toBeGreaterThanOrEqual(1) + }) + + it('should use operations as primary when provided', () => { + const ops = [createOperation('edit', 'Edit')] + const secondary = [createOperation('delete', 'Delete')] + render() + const editButtons = screen.getAllByText('Edit') + expect(editButtons.length).toBeGreaterThanOrEqual(1) + }) + }) + + describe('Rendering with primaryOperations and secondaryOperations', () => { + it('should render primary operations in measurement container', () => { + const primary = [createOperation('edit', 'Edit')] + render() + const editButtons = screen.getAllByText('Edit') + expect(editButtons.length).toBeGreaterThanOrEqual(1) + }) + + it('should use secondary operations when provided', () => { + const primary = [createOperation('edit', 'Edit')] + const secondary = [createOperation('delete', 'Delete')] + render() + const editButtons = screen.getAllByText('Edit') + expect(editButtons.length).toBeGreaterThanOrEqual(1) + }) + + it('should use empty operations array when neither operations nor primaryOperations provided', () => { + const { container } = render() + expect(container).toBeInTheDocument() + }) + }) + + describe('Overflow behavior', () => { + it('should show all operations when container is wide enough', () => { + const cleanup = setupDomMeasurements(500, 60, [80, 80]) + const ops = [createOperation('edit', 'Edit'), createOperation('copy', 'Copy')] + + render() + + cleanup() + }) + + it('should move operations to more menu when container is narrow', () => { + const cleanup = setupDomMeasurements(100, 60, [80, 80]) + const ops = [createOperation('edit', 'Edit'), createOperation('copy', 'Copy')] + + render() + + cleanup() + }) + + it('should show last item without more button if it fits alone', () => { + const cleanup = setupDomMeasurements(90, 60, [80]) + const ops = [createOperation('edit', 'Edit')] + + render() + + cleanup() + }) + }) + + describe('More button', () => { + it('should render more button text in measurement container', () => { + const ops = [createOperation('edit', 'Edit')] + render() + const moreButtons = screen.getAllByText('common.operation.more') + expect(moreButtons.length).toBeGreaterThanOrEqual(1) + }) + + it('should handle trigger more click', async () => { + const cleanup = setupDomMeasurements(100, 60, [80, 80]) + const user = userEvent.setup() + const ops = [createOperation('edit', 'Edit'), createOperation('copy', 'Copy')] + const secondary = [createOperation('delete', 'Delete')] + + render() + + const trigger = screen.queryByTestId('portal-trigger') + if (trigger) + await user.click(trigger) + + cleanup() + }) + }) + + describe('Visible operations click', () => { + it('should call onClick when a visible operation is clicked', async () => { + const cleanup = setupDomMeasurements(500, 60, [80, 80]) + const user = userEvent.setup() + const editOp = createOperation('edit', 'Edit') + const copyOp = createOperation('copy', 'Copy') + + render() + + const visibleButtons = screen.getAllByText('Edit') + const clickableButton = visibleButtons.find(btn => btn.closest('button')?.tabIndex !== -1) + if (clickableButton) + await user.click(clickableButton) + + cleanup() + }) + }) + + describe('Divider operations', () => { + it('should filter out divider operations from inline display', () => { + const ops = [ + createOperation('edit', 'Edit'), + createOperation('div-1', '', 'divider'), + createOperation('delete', 'Delete'), + ] + render() + const editButtons = screen.getAllByText('Edit') + expect(editButtons.length).toBeGreaterThanOrEqual(1) + }) + }) + + describe('Gap styling', () => { + it('should apply gap to measurement and visible containers', () => { + const ops = [createOperation('edit', 'Edit')] + const { container } = render() + const hiddenContainer = container.querySelector('[aria-hidden="true"]') + expect(hiddenContainer).toHaveStyle({ gap: '8px' }) + }) + + it('should apply gap to visible container', () => { + const ops = [createOperation('edit', 'Edit')] + const { container } = render() + const containers = container.querySelectorAll('div[style]') + const visibleContainer = Array.from(containers).find( + el => el.getAttribute('aria-hidden') !== 'true', + ) + if (visibleContainer) + expect(visibleContainer).toHaveStyle({ gap: '4px' }) + }) + }) + + describe('More menu content', () => { + it('should render divider items in more menu', () => { + const cleanup = setupDomMeasurements(100, 60, [80, 80]) + const primary = [createOperation('edit', 'Edit'), createOperation('copy', 'Copy')] + const secondary = [ + createOperation('divider-1', '', 'divider'), + createOperation('delete', 'Delete'), + ] + + render() + + cleanup() + }) + }) + + describe('Empty inline operations', () => { + it('should handle when all operations are dividers', () => { + const ops = [createOperation('div-1', '', 'divider'), createOperation('div-2', '', 'divider')] + const { container } = render() + expect(container).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app-sidebar/app-info/__tests__/index.spec.tsx b/web/app/components/app-sidebar/app-info/__tests__/index.spec.tsx new file mode 100644 index 0000000000..fc0bb56f75 --- /dev/null +++ b/web/app/components/app-sidebar/app-info/__tests__/index.spec.tsx @@ -0,0 +1,147 @@ +import type { App, AppSSO } from '@/types/app' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import { AppModeEnum } from '@/types/app' +import AppInfo from '..' + +let mockIsCurrentWorkspaceEditor = true +const mockSetPanelOpen = vi.fn() + +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + isCurrentWorkspaceEditor: mockIsCurrentWorkspaceEditor, + }), +})) + +vi.mock('../app-info-trigger', () => ({ + default: React.memo(({ appDetail, expand, onClick }: { + appDetail: App & Partial + expand: boolean + onClick: () => void + }) => ( + + )), +})) + +vi.mock('../app-info-detail-panel', () => ({ + default: React.memo(({ show, onClose }: { show: boolean, onClose: () => void }) => ( + show ?
: null + )), +})) + +vi.mock('../app-info-modals', () => ({ + default: React.memo(({ activeModal }: { activeModal: string | null }) => ( + activeModal ?
: null + )), +})) + +const mockAppDetail: App & Partial = { + id: 'app-1', + name: 'Test App', + mode: AppModeEnum.CHAT, + icon: '🤖', + icon_type: 'emoji', + icon_background: '#FFEAD5', + icon_url: '', + description: '', + use_icon_as_answer_icon: false, +} as App & Partial + +const mockUseAppInfoActions = { + appDetail: mockAppDetail, + panelOpen: false, + setPanelOpen: mockSetPanelOpen, + closePanel: vi.fn(), + activeModal: null as string | null, + openModal: vi.fn(), + closeModal: vi.fn(), + secretEnvList: [], + setSecretEnvList: vi.fn(), + onEdit: vi.fn(), + onCopy: vi.fn(), + onExport: vi.fn(), + exportCheck: vi.fn(), + handleConfirmExport: vi.fn(), + onConfirmDelete: vi.fn(), +} + +vi.mock('../use-app-info-actions', () => ({ + useAppInfoActions: () => mockUseAppInfoActions, +})) + +describe('AppInfo', () => { + beforeEach(() => { + vi.clearAllMocks() + mockIsCurrentWorkspaceEditor = true + mockUseAppInfoActions.appDetail = mockAppDetail + mockUseAppInfoActions.panelOpen = false + mockUseAppInfoActions.activeModal = null + }) + + it('should return null when appDetail is not available', () => { + mockUseAppInfoActions.appDetail = undefined as unknown as App & Partial + const { container } = render() + expect(container.innerHTML).toBe('') + }) + + it('should render trigger when not onlyShowDetail', () => { + render() + expect(screen.getByTestId('trigger')).toBeInTheDocument() + }) + + it('should not render trigger when onlyShowDetail is true', () => { + render() + expect(screen.queryByTestId('trigger')).not.toBeInTheDocument() + }) + + it('should pass expand prop to trigger', () => { + render() + expect(screen.getByTestId('trigger')).toHaveAttribute('data-expand', 'true') + + const { unmount } = render() + const triggers = screen.getAllByTestId('trigger') + expect(triggers[triggers.length - 1]).toHaveAttribute('data-expand', 'false') + unmount() + }) + + it('should toggle panel when trigger is clicked and user is editor', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('trigger')) + + expect(mockSetPanelOpen).toHaveBeenCalled() + const updater = mockSetPanelOpen.mock.calls[0][0] as (v: boolean) => boolean + expect(updater(false)).toBe(true) + expect(updater(true)).toBe(false) + }) + + it('should not toggle panel when trigger is clicked and user is not editor', async () => { + const user = userEvent.setup() + mockIsCurrentWorkspaceEditor = false + render() + + await user.click(screen.getByTestId('trigger')) + + expect(mockSetPanelOpen).not.toHaveBeenCalled() + }) + + it('should show detail panel based on panelOpen when not onlyShowDetail', () => { + mockUseAppInfoActions.panelOpen = true + render() + expect(screen.getByTestId('detail-panel')).toBeInTheDocument() + }) + + it('should show detail panel based on openState when onlyShowDetail', () => { + render() + expect(screen.getByTestId('detail-panel')).toBeInTheDocument() + }) + + it('should hide detail panel when openState is false and onlyShowDetail', () => { + render() + expect(screen.queryByTestId('detail-panel')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts new file mode 100644 index 0000000000..6104e2b641 --- /dev/null +++ b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts @@ -0,0 +1,492 @@ +import { act, renderHook } from '@testing-library/react' +import { AppModeEnum } from '@/types/app' +import { useAppInfoActions } from '../use-app-info-actions' + +const mockNotify = vi.fn() +const mockReplace = vi.fn() +const mockOnPlanInfoChanged = vi.fn() +const mockInvalidateAppList = vi.fn() +const mockSetAppDetail = vi.fn() +const mockUpdateAppInfo = vi.fn() +const mockCopyApp = vi.fn() +const mockExportAppConfig = vi.fn() +const mockDeleteApp = vi.fn() +const mockFetchWorkflowDraft = vi.fn() +const mockDownloadBlob = vi.fn() + +let mockAppDetail: Record | undefined = { + id: 'app-1', + name: 'Test App', + mode: AppModeEnum.CHAT, + icon: '🤖', + icon_type: 'emoji', + icon_background: '#FFEAD5', +} + +vi.mock('next/navigation', () => ({ + useRouter: () => ({ replace: mockReplace }), +})) + +vi.mock('use-context-selector', () => ({ + useContext: () => ({ notify: mockNotify }), +})) + +vi.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ onPlanInfoChanged: mockOnPlanInfoChanged }), +})) + +vi.mock('@/app/components/app/store', () => ({ + useStore: (selector: (state: Record) => unknown) => selector({ + appDetail: mockAppDetail, + setAppDetail: mockSetAppDetail, + }), +})) + +vi.mock('@/app/components/base/toast/context', () => ({ + ToastContext: {}, +})) + +vi.mock('@/service/use-apps', () => ({ + useInvalidateAppList: () => mockInvalidateAppList, +})) + +vi.mock('@/service/apps', () => ({ + updateAppInfo: (...args: unknown[]) => mockUpdateAppInfo(...args), + copyApp: (...args: unknown[]) => mockCopyApp(...args), + exportAppConfig: (...args: unknown[]) => mockExportAppConfig(...args), + deleteApp: (...args: unknown[]) => mockDeleteApp(...args), +})) + +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: (...args: unknown[]) => mockFetchWorkflowDraft(...args), +})) + +vi.mock('@/utils/download', () => ({ + downloadBlob: (...args: unknown[]) => mockDownloadBlob(...args), +})) + +vi.mock('@/utils/app-redirection', () => ({ + getRedirection: vi.fn(), +})) + +vi.mock('@/config', () => ({ + NEED_REFRESH_APP_LIST_KEY: 'test-refresh-key', +})) + +describe('useAppInfoActions', () => { + beforeEach(() => { + vi.clearAllMocks() + mockAppDetail = { + id: 'app-1', + name: 'Test App', + mode: AppModeEnum.CHAT, + icon: '🤖', + icon_type: 'emoji', + icon_background: '#FFEAD5', + } + }) + + describe('Initial state', () => { + it('should return initial state correctly', () => { + const { result } = renderHook(() => useAppInfoActions({})) + expect(result.current.appDetail).toEqual(mockAppDetail) + expect(result.current.panelOpen).toBe(false) + expect(result.current.activeModal).toBeNull() + expect(result.current.secretEnvList).toEqual([]) + }) + }) + + describe('Panel management', () => { + it('should toggle panelOpen', () => { + const { result } = renderHook(() => useAppInfoActions({})) + + act(() => { + result.current.setPanelOpen(true) + }) + + expect(result.current.panelOpen).toBe(true) + }) + + it('should close panel and call onDetailExpand', () => { + const onDetailExpand = vi.fn() + const { result } = renderHook(() => useAppInfoActions({ onDetailExpand })) + + act(() => { + result.current.setPanelOpen(true) + }) + + act(() => { + result.current.closePanel() + }) + + expect(result.current.panelOpen).toBe(false) + expect(onDetailExpand).toHaveBeenCalledWith(false) + }) + }) + + describe('Modal management', () => { + it('should open modal and close panel', () => { + const { result } = renderHook(() => useAppInfoActions({})) + + act(() => { + result.current.setPanelOpen(true) + }) + + act(() => { + result.current.openModal('edit') + }) + + expect(result.current.activeModal).toBe('edit') + expect(result.current.panelOpen).toBe(false) + }) + + it('should close modal', () => { + const { result } = renderHook(() => useAppInfoActions({})) + + act(() => { + result.current.openModal('delete') + }) + + act(() => { + result.current.closeModal() + }) + + expect(result.current.activeModal).toBeNull() + }) + }) + + describe('onEdit', () => { + it('should update app info and close modal on success', async () => { + const updatedApp = { ...mockAppDetail, name: 'Updated' } + mockUpdateAppInfo.mockResolvedValue(updatedApp) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onEdit({ + name: 'Updated', + icon_type: 'emoji', + icon: '🤖', + icon_background: '#fff', + description: '', + use_icon_as_answer_icon: false, + }) + }) + + expect(mockUpdateAppInfo).toHaveBeenCalled() + expect(mockSetAppDetail).toHaveBeenCalledWith(updatedApp) + expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.editDone' }) + }) + + it('should notify error on edit failure', async () => { + mockUpdateAppInfo.mockRejectedValue(new Error('fail')) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onEdit({ + name: 'Updated', + icon_type: 'emoji', + icon: '🤖', + icon_background: '#fff', + description: '', + use_icon_as_answer_icon: false, + }) + }) + + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.editFailed' }) + }) + + it('should not call updateAppInfo when appDetail is undefined', async () => { + mockAppDetail = undefined + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onEdit({ + name: 'Updated', + icon_type: 'emoji', + icon: '🤖', + icon_background: '#fff', + description: '', + use_icon_as_answer_icon: false, + }) + }) + + expect(mockUpdateAppInfo).not.toHaveBeenCalled() + }) + }) + + describe('onCopy', () => { + it('should copy app and redirect on success', async () => { + const newApp = { id: 'app-2', name: 'Copy', mode: 'chat' } + mockCopyApp.mockResolvedValue(newApp) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onCopy({ + name: 'Copy', + icon_type: 'emoji', + icon: '🤖', + icon_background: '#fff', + }) + }) + + expect(mockCopyApp).toHaveBeenCalled() + expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.newApp.appCreated' }) + expect(mockOnPlanInfoChanged).toHaveBeenCalled() + }) + + it('should notify error on copy failure', async () => { + mockCopyApp.mockRejectedValue(new Error('fail')) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onCopy({ + name: 'Copy', + icon_type: 'emoji', + icon: '🤖', + icon_background: '#fff', + }) + }) + + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.newApp.appCreateFailed' }) + }) + }) + + describe('onCopy - early return', () => { + it('should not call copyApp when appDetail is undefined', async () => { + mockAppDetail = undefined + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onCopy({ + name: 'Copy', + icon_type: 'emoji', + icon: '🤖', + icon_background: '#fff', + }) + }) + + expect(mockCopyApp).not.toHaveBeenCalled() + }) + }) + + describe('onExport', () => { + it('should export app config and trigger download', async () => { + mockExportAppConfig.mockResolvedValue({ data: 'yaml-content' }) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onExport(false) + }) + + expect(mockExportAppConfig).toHaveBeenCalledWith({ appID: 'app-1', include: false }) + expect(mockDownloadBlob).toHaveBeenCalled() + }) + + it('should notify error on export failure', async () => { + mockExportAppConfig.mockRejectedValue(new Error('fail')) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onExport() + }) + + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' }) + }) + }) + + describe('onExport - early return', () => { + it('should not export when appDetail is undefined', async () => { + mockAppDetail = undefined + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onExport() + }) + + expect(mockExportAppConfig).not.toHaveBeenCalled() + }) + }) + + describe('exportCheck', () => { + it('should call onExport directly for non-workflow modes', async () => { + mockExportAppConfig.mockResolvedValue({ data: 'yaml' }) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(mockExportAppConfig).toHaveBeenCalled() + }) + + it('should open export warning modal for workflow mode', async () => { + mockAppDetail = { ...mockAppDetail, mode: AppModeEnum.WORKFLOW } + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(result.current.activeModal).toBe('exportWarning') + }) + + it('should open export warning modal for advanced_chat mode', async () => { + mockAppDetail = { ...mockAppDetail, mode: AppModeEnum.ADVANCED_CHAT } + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(result.current.activeModal).toBe('exportWarning') + }) + }) + + describe('exportCheck - early return', () => { + it('should not do anything when appDetail is undefined', async () => { + mockAppDetail = undefined + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(mockExportAppConfig).not.toHaveBeenCalled() + }) + }) + + describe('handleConfirmExport', () => { + it('should export directly when no secret env variables', async () => { + mockAppDetail = { ...mockAppDetail, mode: AppModeEnum.WORKFLOW } + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: [{ value_type: 'string' }], + }) + mockExportAppConfig.mockResolvedValue({ data: 'yaml' }) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.handleConfirmExport() + }) + + expect(mockExportAppConfig).toHaveBeenCalled() + }) + + it('should set secret env list when secret variables exist', async () => { + mockAppDetail = { ...mockAppDetail, mode: AppModeEnum.WORKFLOW } + const secretVars = [{ value_type: 'secret', key: 'API_KEY' }] + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: secretVars, + }) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.handleConfirmExport() + }) + + expect(result.current.secretEnvList).toEqual(secretVars) + }) + + it('should notify error on workflow draft fetch failure', async () => { + mockFetchWorkflowDraft.mockRejectedValue(new Error('fail')) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.handleConfirmExport() + }) + + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' }) + }) + }) + + describe('handleConfirmExport - early return', () => { + it('should not do anything when appDetail is undefined', async () => { + mockAppDetail = undefined + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.handleConfirmExport() + }) + + expect(mockFetchWorkflowDraft).not.toHaveBeenCalled() + }) + }) + + describe('handleConfirmExport - with environment variables', () => { + it('should handle empty environment_variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: undefined, + }) + mockExportAppConfig.mockResolvedValue({ data: 'yaml' }) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.handleConfirmExport() + }) + + expect(mockExportAppConfig).toHaveBeenCalled() + }) + }) + + describe('onConfirmDelete', () => { + it('should delete app and redirect on success', async () => { + mockDeleteApp.mockResolvedValue({}) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onConfirmDelete() + }) + + expect(mockDeleteApp).toHaveBeenCalledWith('app-1') + expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.appDeleted' }) + expect(mockInvalidateAppList).toHaveBeenCalled() + expect(mockReplace).toHaveBeenCalledWith('/apps') + expect(mockSetAppDetail).toHaveBeenCalledWith() + }) + + it('should not delete when appDetail is undefined', async () => { + mockAppDetail = undefined + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onConfirmDelete() + }) + + expect(mockDeleteApp).not.toHaveBeenCalled() + }) + + it('should notify error on delete failure', async () => { + mockDeleteApp.mockRejectedValue({ message: 'cannot delete' }) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onConfirmDelete() + }) + + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: expect.stringContaining('app.appDeleteFailed'), + }) + }) + }) +}) diff --git a/web/app/components/app-sidebar/app-info/app-info-detail-panel.tsx b/web/app/components/app-sidebar/app-info/app-info-detail-panel.tsx new file mode 100644 index 0000000000..70dcb8df70 --- /dev/null +++ b/web/app/components/app-sidebar/app-info/app-info-detail-panel.tsx @@ -0,0 +1,151 @@ +import type { Operation } from './app-operations' +import type { AppInfoModalType } from './use-app-info-actions' +import type { App, AppSSO } from '@/types/app' +import { + RiDeleteBinLine, + RiEditLine, + RiExchange2Line, + RiFileCopy2Line, + RiFileDownloadLine, + RiFileUploadLine, +} from '@remixicon/react' +import * as React from 'react' +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view' +import Button from '@/app/components/base/button' +import ContentDialog from '@/app/components/base/content-dialog' +import { AppModeEnum } from '@/types/app' +import AppIcon from '../../base/app-icon' +import { getAppModeLabel } from './app-mode-labels' +import AppOperations from './app-operations' + +type AppInfoDetailPanelProps = { + appDetail: App & Partial + show: boolean + onClose: () => void + openModal: (modal: Exclude) => void + exportCheck: () => void +} + +const AppInfoDetailPanel = ({ + appDetail, + show, + onClose, + openModal, + exportCheck, +}: AppInfoDetailPanelProps) => { + const { t } = useTranslation() + + const primaryOperations = useMemo(() => [ + { + id: 'edit', + title: t('editApp', { ns: 'app' }), + icon: , + onClick: () => openModal('edit'), + }, + { + id: 'duplicate', + title: t('duplicate', { ns: 'app' }), + icon: , + onClick: () => openModal('duplicate'), + }, + { + id: 'export', + title: t('export', { ns: 'app' }), + icon: , + onClick: exportCheck, + }, + ], [t, openModal, exportCheck]) + + const secondaryOperations = useMemo(() => [ + ...(appDetail.mode === AppModeEnum.ADVANCED_CHAT || appDetail.mode === AppModeEnum.WORKFLOW) + ? [{ + id: 'import', + title: t('common.importDSL', { ns: 'workflow' }), + icon: , + onClick: () => openModal('importDSL'), + }] + : [], + { + id: 'divider-1', + title: '', + icon: <>, + onClick: () => {}, + type: 'divider' as const, + }, + { + id: 'delete', + title: t('operation.delete', { ns: 'common' }), + icon: , + onClick: () => openModal('delete'), + }, + ], [appDetail.mode, t, openModal]) + + const switchOperation = useMemo(() => { + if (appDetail.mode !== AppModeEnum.COMPLETION && appDetail.mode !== AppModeEnum.CHAT) + return null + return { + id: 'switch', + title: t('switch', { ns: 'app' }), + icon: , + onClick: () => openModal('switch'), + } + }, [appDetail.mode, t, openModal]) + + return ( + +
+
+ +
+
{appDetail.name}
+
+ {getAppModeLabel(appDetail.mode, t)} +
+
+
+ {appDetail.description && ( +
+ {appDetail.description} +
+ )} + +
+ + {switchOperation && ( +
+ +
+ )} +
+ ) +} + +export default React.memo(AppInfoDetailPanel) diff --git a/web/app/components/app-sidebar/app-info/app-info-modals.tsx b/web/app/components/app-sidebar/app-info/app-info-modals.tsx new file mode 100644 index 0000000000..4ca7f6adbc --- /dev/null +++ b/web/app/components/app-sidebar/app-info/app-info-modals.tsx @@ -0,0 +1,122 @@ +import type { AppInfoModalType } from './use-app-info-actions' +import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal' +import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' +import type { EnvironmentVariable } from '@/app/components/workflow/types' +import type { App, AppSSO } from '@/types/app' +import dynamic from 'next/dynamic' +import * as React from 'react' +import { useTranslation } from 'react-i18next' + +const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), { ssr: false }) +const CreateAppModal = dynamic(() => import('@/app/components/explore/create-app-modal'), { ssr: false }) +const DuplicateAppModal = dynamic(() => import('@/app/components/app/duplicate-modal'), { ssr: false }) +const Confirm = dynamic(() => import('@/app/components/base/confirm'), { ssr: false }) +const UpdateDSLModal = dynamic(() => import('@/app/components/workflow/update-dsl-modal'), { ssr: false }) +const DSLExportConfirmModal = dynamic(() => import('@/app/components/workflow/dsl-export-confirm-modal'), { ssr: false }) + +type AppInfoModalsProps = { + appDetail: App & Partial + activeModal: AppInfoModalType + closeModal: () => void + secretEnvList: EnvironmentVariable[] + setSecretEnvList: (list: EnvironmentVariable[]) => void + onEdit: CreateAppModalProps['onConfirm'] + onCopy: DuplicateAppModalProps['onConfirm'] + onExport: (include?: boolean) => Promise + exportCheck: () => void + handleConfirmExport: () => void + onConfirmDelete: () => void +} + +const AppInfoModals = ({ + appDetail, + activeModal, + closeModal, + secretEnvList, + setSecretEnvList, + onEdit, + onCopy, + onExport, + exportCheck, + handleConfirmExport, + onConfirmDelete, +}: AppInfoModalsProps) => { + const { t } = useTranslation() + + return ( + <> + {activeModal === 'switch' && ( + + )} + {activeModal === 'edit' && ( + + )} + {activeModal === 'duplicate' && ( + + )} + {activeModal === 'delete' && ( + + )} + {activeModal === 'importDSL' && ( + + )} + {activeModal === 'exportWarning' && ( + + )} + {secretEnvList.length > 0 && ( + setSecretEnvList([])} + /> + )} + + ) +} + +export default React.memo(AppInfoModals) diff --git a/web/app/components/app-sidebar/app-info/app-info-trigger.tsx b/web/app/components/app-sidebar/app-info/app-info-trigger.tsx new file mode 100644 index 0000000000..07a41124e3 --- /dev/null +++ b/web/app/components/app-sidebar/app-info/app-info-trigger.tsx @@ -0,0 +1,67 @@ +import type { App, AppSSO } from '@/types/app' +import { RiEqualizer2Line } from '@remixicon/react' +import * as React from 'react' +import { useTranslation } from 'react-i18next' +import { cn } from '@/utils/classnames' +import AppIcon from '../../base/app-icon' +import { getAppModeLabel } from './app-mode-labels' + +type AppInfoTriggerProps = { + appDetail: App & Partial + expand: boolean + onClick: () => void +} + +const AppInfoTrigger = ({ appDetail, expand, onClick }: AppInfoTriggerProps) => { + const { t } = useTranslation() + const modeLabel = getAppModeLabel(appDetail.mode, t) + + return ( + + ) +} + +export default React.memo(AppInfoTrigger) diff --git a/web/app/components/app-sidebar/app-info/app-mode-labels.ts b/web/app/components/app-sidebar/app-info/app-mode-labels.ts new file mode 100644 index 0000000000..1d72feb089 --- /dev/null +++ b/web/app/components/app-sidebar/app-info/app-mode-labels.ts @@ -0,0 +1,17 @@ +import type { TFunction } from 'i18next' +import { AppModeEnum } from '@/types/app' + +export function getAppModeLabel(mode: string, t: TFunction): string { + switch (mode) { + case AppModeEnum.ADVANCED_CHAT: + return t('types.advanced', { ns: 'app' }) + case AppModeEnum.AGENT_CHAT: + return t('types.agent', { ns: 'app' }) + case AppModeEnum.CHAT: + return t('types.chatbot', { ns: 'app' }) + case AppModeEnum.COMPLETION: + return t('types.completion', { ns: 'app' }) + default: + return t('types.workflow', { ns: 'app' }) + } +} diff --git a/web/app/components/app-sidebar/app-operations.tsx b/web/app/components/app-sidebar/app-info/app-operations.tsx similarity index 99% rename from web/app/components/app-sidebar/app-operations.tsx rename to web/app/components/app-sidebar/app-info/app-operations.tsx index 1cf6acaf2e..a182db7cc8 100644 --- a/web/app/components/app-sidebar/app-operations.tsx +++ b/web/app/components/app-sidebar/app-info/app-operations.tsx @@ -3,7 +3,7 @@ import { RiMoreLine } from '@remixicon/react' import { cloneElement, useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' -import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../base/portal-to-follow-elem' +import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem' export type Operation = { id: string diff --git a/web/app/components/app-sidebar/app-info/index.tsx b/web/app/components/app-sidebar/app-info/index.tsx new file mode 100644 index 0000000000..2530add2dc --- /dev/null +++ b/web/app/components/app-sidebar/app-info/index.tsx @@ -0,0 +1,75 @@ +import * as React from 'react' +import { useAppContext } from '@/context/app-context' +import AppInfoDetailPanel from './app-info-detail-panel' +import AppInfoModals from './app-info-modals' +import AppInfoTrigger from './app-info-trigger' +import { useAppInfoActions } from './use-app-info-actions' + +export type IAppInfoProps = { + expand: boolean + onlyShowDetail?: boolean + openState?: boolean + onDetailExpand?: (expand: boolean) => void +} + +const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailExpand }: IAppInfoProps) => { + const { isCurrentWorkspaceEditor } = useAppContext() + + const { + appDetail, + panelOpen, + setPanelOpen, + closePanel, + activeModal, + openModal, + closeModal, + secretEnvList, + setSecretEnvList, + onEdit, + onCopy, + onExport, + exportCheck, + handleConfirmExport, + onConfirmDelete, + } = useAppInfoActions({ onDetailExpand }) + + if (!appDetail) + return null + + return ( +
+ {!onlyShowDetail && ( + { + if (isCurrentWorkspaceEditor) + setPanelOpen(v => !v) + }} + /> + )} + + +
+ ) +} + +export default React.memo(AppInfo) diff --git a/web/app/components/app-sidebar/app-info/use-app-info-actions.ts b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts new file mode 100644 index 0000000000..800f21de44 --- /dev/null +++ b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts @@ -0,0 +1,189 @@ +import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal' +import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' +import type { EnvironmentVariable } from '@/app/components/workflow/types' +import { useRouter } from 'next/navigation' +import { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' +import { useStore as useAppStore } from '@/app/components/app/store' +import { ToastContext } from '@/app/components/base/toast/context' +import { NEED_REFRESH_APP_LIST_KEY } from '@/config' +import { useProviderContext } from '@/context/provider-context' +import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps' +import { useInvalidateAppList } from '@/service/use-apps' +import { fetchWorkflowDraft } from '@/service/workflow' +import { AppModeEnum } from '@/types/app' +import { getRedirection } from '@/utils/app-redirection' +import { downloadBlob } from '@/utils/download' + +export type AppInfoModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | 'importDSL' | 'exportWarning' | null + +type UseAppInfoActionsParams = { + onDetailExpand?: (expand: boolean) => void +} + +export function useAppInfoActions({ onDetailExpand }: UseAppInfoActionsParams) { + const { t } = useTranslation() + const { notify } = useContext(ToastContext) + const { replace } = useRouter() + const { onPlanInfoChanged } = useProviderContext() + const appDetail = useAppStore(state => state.appDetail) + const setAppDetail = useAppStore(state => state.setAppDetail) + const invalidateAppList = useInvalidateAppList() + + const [panelOpen, setPanelOpen] = useState(false) + const [activeModal, setActiveModal] = useState(null) + const [secretEnvList, setSecretEnvList] = useState([]) + + const closePanel = useCallback(() => { + setPanelOpen(false) + onDetailExpand?.(false) + }, [onDetailExpand]) + + const openModal = useCallback((modal: Exclude) => { + closePanel() + setActiveModal(modal) + }, [closePanel]) + + const closeModal = useCallback(() => { + setActiveModal(null) + }, []) + + const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ + name, + icon_type, + icon, + icon_background, + description, + use_icon_as_answer_icon, + max_active_requests, + }) => { + if (!appDetail) + return + try { + const app = await updateAppInfo({ + appID: appDetail.id, + name, + icon_type, + icon, + icon_background, + description, + use_icon_as_answer_icon, + max_active_requests, + }) + closeModal() + notify({ type: 'success', message: t('editDone', { ns: 'app' }) }) + setAppDetail(app) + } + catch { + notify({ type: 'error', message: t('editFailed', { ns: 'app' }) }) + } + }, [appDetail, closeModal, notify, setAppDetail, t]) + + const onCopy: DuplicateAppModalProps['onConfirm'] = useCallback(async ({ + name, + icon_type, + icon, + icon_background, + }) => { + if (!appDetail) + return + try { + const newApp = await copyApp({ + appID: appDetail.id, + name, + icon_type, + icon, + icon_background, + mode: appDetail.mode, + }) + closeModal() + notify({ type: 'success', message: t('newApp.appCreated', { ns: 'app' }) }) + localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') + onPlanInfoChanged() + getRedirection(true, newApp, replace) + } + catch { + notify({ type: 'error', message: t('newApp.appCreateFailed', { ns: 'app' }) }) + } + }, [appDetail, closeModal, notify, onPlanInfoChanged, replace, t]) + + const onExport = useCallback(async (include = false) => { + if (!appDetail) + return + try { + const { data } = await exportAppConfig({ appID: appDetail.id, include }) + const file = new Blob([data], { type: 'application/yaml' }) + downloadBlob({ data: file, fileName: `${appDetail.name}.yml` }) + } + catch { + notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) + } + }, [appDetail, notify, t]) + + const exportCheck = useCallback(async () => { + if (!appDetail) + return + if (appDetail.mode !== AppModeEnum.WORKFLOW && appDetail.mode !== AppModeEnum.ADVANCED_CHAT) { + onExport() + return + } + setActiveModal('exportWarning') + }, [appDetail, onExport]) + + const handleConfirmExport = useCallback(async () => { + if (!appDetail) + return + closeModal() + try { + const workflowDraft = await fetchWorkflowDraft(`/apps/${appDetail.id}/workflows/draft`) + const list = (workflowDraft.environment_variables || []).filter(env => env.value_type === 'secret') + if (list.length === 0) { + onExport() + return + } + setSecretEnvList(list) + } + catch { + notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) + } + }, [appDetail, closeModal, notify, onExport, t]) + + const onConfirmDelete = useCallback(async () => { + if (!appDetail) + return + try { + await deleteApp(appDetail.id) + notify({ type: 'success', message: t('appDeleted', { ns: 'app' }) }) + invalidateAppList() + onPlanInfoChanged() + setAppDetail() + replace('/apps') + } + catch (e: unknown) { + notify({ + type: 'error', + message: `${t('appDeleteFailed', { ns: 'app' })}${e instanceof Error && e.message ? `: ${e.message}` : ''}`, + }) + } + closeModal() + }, [appDetail, closeModal, invalidateAppList, notify, onPlanInfoChanged, replace, setAppDetail, t]) + + return { + appDetail, + panelOpen, + setPanelOpen, + closePanel, + activeModal, + openModal, + closeModal, + secretEnvList, + setSecretEnvList, + onEdit, + onCopy, + onExport, + exportCheck, + handleConfirmExport, + onConfirmDelete, + } +} diff --git a/web/app/components/app-sidebar/app-sidebar-dropdown.tsx b/web/app/components/app-sidebar/app-sidebar-dropdown.tsx index 521342238e..87632ba647 100644 --- a/web/app/components/app-sidebar/app-sidebar-dropdown.tsx +++ b/web/app/components/app-sidebar/app-sidebar-dropdown.tsx @@ -1,4 +1,4 @@ -import type { NavIcon } from './navLink' +import type { NavIcon } from './nav-link' import { RiEqualizer2Line, RiMenuLine, @@ -13,12 +13,12 @@ import { PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' import { useAppContext } from '@/context/app-context' -import { AppModeEnum } from '@/types/app' import { cn } from '@/utils/classnames' import AppIcon from '../base/app-icon' import Divider from '../base/divider' import AppInfo from './app-info' -import NavLink from './navLink' +import { getAppModeLabel } from './app-info/app-mode-labels' +import NavLink from './nav-link' type Props = { navigation: Array<{ @@ -99,7 +99,7 @@ const AppSidebarDropdown = ({ navigation }: Props) => {
{appDetail.name}
-
{appDetail.mode === AppModeEnum.ADVANCED_CHAT ? t('types.advanced', { ns: 'app' }) : appDetail.mode === AppModeEnum.AGENT_CHAT ? t('types.agent', { ns: 'app' }) : appDetail.mode === AppModeEnum.CHAT ? t('types.chatbot', { ns: 'app' }) : appDetail.mode === AppModeEnum.COMPLETION ? t('types.completion', { ns: 'app' }) : t('types.workflow', { ns: 'app' })}
+
{getAppModeLabel(appDetail.mode, t)}
diff --git a/web/app/components/app-sidebar/completion.png b/web/app/components/app-sidebar/completion.png deleted file mode 100644 index 7a3cbd5107..0000000000 Binary files a/web/app/components/app-sidebar/completion.png and /dev/null differ diff --git a/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx b/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx new file mode 100644 index 0000000000..512f9490c2 --- /dev/null +++ b/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx @@ -0,0 +1,228 @@ +import type { DataSet } from '@/models/datasets' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import { + ChunkingMode, + DatasetPermission, + DataSourceType, +} from '@/models/datasets' +import { RETRIEVE_METHOD } from '@/types/app' +import Dropdown from '../dropdown' + +let mockDataset: DataSet +let mockIsDatasetOperator = false +const mockReplace = vi.fn() +const mockInvalidDatasetList = vi.fn() +const mockInvalidDatasetDetail = vi.fn() +const mockExportPipeline = vi.fn() +const mockCheckIsUsedInApp = vi.fn() +const mockDeleteDataset = vi.fn() + +const createDataset = (overrides: Partial = {}): DataSet => ({ + id: 'dataset-1', + name: 'Dataset Name', + indexing_status: 'completed', + icon_info: { + icon: '📙', + icon_background: '#FFF4ED', + icon_type: 'emoji', + icon_url: '', + }, + description: 'Dataset description', + permission: DatasetPermission.onlyMe, + data_source_type: DataSourceType.FILE, + indexing_technique: 'high_quality' as DataSet['indexing_technique'], + created_by: 'user-1', + updated_by: 'user-1', + updated_at: 1690000000, + app_count: 0, + doc_form: ChunkingMode.text, + document_count: 1, + total_document_count: 1, + word_count: 1000, + provider: 'internal', + embedding_model: 'text-embedding-3', + embedding_model_provider: 'openai', + embedding_available: true, + retrieval_model_dict: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { reranking_provider_name: '', reranking_model_name: '' }, + top_k: 5, + score_threshold_enabled: false, + score_threshold: 0, + }, + retrieval_model: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { reranking_provider_name: '', reranking_model_name: '' }, + top_k: 5, + score_threshold_enabled: false, + score_threshold: 0, + }, + tags: [], + external_knowledge_info: { + external_knowledge_id: '', + external_knowledge_api_id: '', + external_knowledge_api_name: '', + external_knowledge_api_endpoint: '', + }, + external_retrieval_model: { + top_k: 0, + score_threshold: 0, + score_threshold_enabled: false, + }, + built_in_field_enabled: false, + runtime_mode: 'rag_pipeline', + enable_api: false, + is_multimodal: false, + ...overrides, +}) + +vi.mock('next/navigation', () => ({ + useRouter: () => ({ replace: mockReplace }), +})) + +vi.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: (selector: (state: { dataset?: DataSet }) => unknown) => selector({ dataset: mockDataset }), +})) + +vi.mock('@/context/app-context', () => ({ + useSelector: (selector: (state: { isCurrentWorkspaceDatasetOperator: boolean }) => unknown) => + selector({ isCurrentWorkspaceDatasetOperator: mockIsDatasetOperator }), +})) + +vi.mock('@/service/knowledge/use-dataset', () => ({ + datasetDetailQueryKeyPrefix: ['dataset', 'detail'], + useInvalidDatasetList: () => mockInvalidDatasetList, +})) + +vi.mock('@/service/use-base', () => ({ + useInvalid: () => mockInvalidDatasetDetail, +})) + +vi.mock('@/service/use-pipeline', () => ({ + useExportPipelineDSL: () => ({ mutateAsync: mockExportPipeline }), +})) + +vi.mock('@/service/datasets', () => ({ + checkIsUsedInApp: (...args: unknown[]) => mockCheckIsUsedInApp(...args), + deleteDataset: (...args: unknown[]) => mockDeleteDataset(...args), +})) + +vi.mock('@/app/components/datasets/rename-modal', () => ({ + default: ({ + show, + onClose, + onSuccess, + }: { + show: boolean + onClose: () => void + onSuccess?: () => void + }) => { + if (!show) + return null + return ( +
+ + +
+ ) + }, +})) + +vi.mock('@/app/components/base/confirm', () => ({ + default: ({ + isShow, + onConfirm, + onCancel, + title, + content, + }: { + isShow: boolean + onConfirm: () => void + onCancel: () => void + title: string + content: string + }) => { + if (!isShow) + return null + return ( +
+ {title} + {content} + + +
+ ) + }, +})) + +vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ + PortalToFollowElem: ({ children }: { children: React.ReactNode }) =>
{children}
, + PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick?: () => void }) => ( +
{children}
+ ), + PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) =>
{children}
, +})) + +describe('Dropdown callback coverage', () => { + beforeEach(() => { + vi.clearAllMocks() + mockDataset = createDataset({ pipeline_id: 'pipeline-1', runtime_mode: 'rag_pipeline' }) + mockIsDatasetOperator = false + mockExportPipeline.mockResolvedValue({ data: 'pipeline-content' }) + mockCheckIsUsedInApp.mockResolvedValue({ is_using: false }) + mockDeleteDataset.mockResolvedValue({}) + }) + + it('should call refreshDataset when rename succeeds', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('portal-trigger')) + await user.click(screen.getByText('common.operation.edit')) + + expect(screen.getByTestId('rename-modal')).toBeInTheDocument() + await user.click(screen.getByText('Success')) + + await waitFor(() => { + expect(mockInvalidDatasetList).toHaveBeenCalled() + expect(mockInvalidDatasetDetail).toHaveBeenCalled() + }) + }) + + it('should close rename modal when onClose is called', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('portal-trigger')) + await user.click(screen.getByText('common.operation.edit')) + + expect(screen.getByTestId('rename-modal')).toBeInTheDocument() + await user.click(screen.getByText('Close')) + + await waitFor(() => { + expect(screen.queryByTestId('rename-modal')).not.toBeInTheDocument() + }) + }) + + it('should close confirm dialog when cancel is clicked', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('portal-trigger')) + await user.click(screen.getByText('common.operation.delete')) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + await user.click(screen.getByText('cancel')) + + await waitFor(() => { + expect(screen.queryByTestId('confirm-dialog')).not.toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app-sidebar/dataset-info/index.spec.tsx b/web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx similarity index 98% rename from web/app/components/app-sidebar/dataset-info/index.spec.tsx rename to web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx index 9996ef2b4d..be27e247d7 100644 --- a/web/app/components/app-sidebar/dataset-info/index.spec.tsx +++ b/web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx @@ -9,10 +9,10 @@ import { DataSourceType, } from '@/models/datasets' import { RETRIEVE_METHOD } from '@/types/app' -import Dropdown from './dropdown' -import DatasetInfo from './index' -import Menu from './menu' -import MenuItem from './menu-item' +import DatasetInfo from '..' +import Dropdown from '../dropdown' +import Menu from '../menu' +import MenuItem from '../menu-item' let mockDataset: DataSet let mockIsDatasetOperator = false diff --git a/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx b/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx index c6e7e04375..5beea54ab0 100644 --- a/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx @@ -1,4 +1,4 @@ -import type { NavIcon } from './navLink' +import type { NavIcon } from './nav-link' import type { DataSet } from '@/models/datasets' import { RiMenuLine, @@ -21,7 +21,7 @@ import Divider from '../base/divider' import Effect from '../base/effect' import ExtraInfo from '../datasets/extra-info' import Dropdown from './dataset-info/dropdown' -import NavLink from './navLink' +import NavLink from './nav-link' type DatasetSidebarDropdownProps = { navigation: Array<{ diff --git a/web/app/components/app-sidebar/expert.png b/web/app/components/app-sidebar/expert.png deleted file mode 100644 index ba941a5865..0000000000 Binary files a/web/app/components/app-sidebar/expert.png and /dev/null differ diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx index afc6bd0f13..e24b005d01 100644 --- a/web/app/components/app-sidebar/index.tsx +++ b/web/app/components/app-sidebar/index.tsx @@ -1,4 +1,4 @@ -import type { NavIcon } from './navLink' +import type { NavIcon } from './nav-link' import { useHover, useKeyPress } from 'ahooks' import { usePathname } from 'next/navigation' import * as React from 'react' @@ -14,7 +14,7 @@ import AppInfo from './app-info' import AppSidebarDropdown from './app-sidebar-dropdown' import DatasetInfo from './dataset-info' import DatasetSidebarDropdown from './dataset-sidebar-dropdown' -import NavLink from './navLink' +import NavLink from './nav-link' import ToggleButton from './toggle-button' export type IAppDetailNavProps = { diff --git a/web/app/components/app-sidebar/navLink.spec.tsx b/web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx similarity index 98% rename from web/app/components/app-sidebar/navLink.spec.tsx rename to web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx index 62ef553386..04ca7bd0e4 100644 --- a/web/app/components/app-sidebar/navLink.spec.tsx +++ b/web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx @@ -1,7 +1,7 @@ -import type { NavLinkProps } from './navLink' +import type { NavLinkProps } from '..' import { render, screen } from '@testing-library/react' import * as React from 'react' -import NavLink from './navLink' +import NavLink from '..' // Mock Next.js navigation vi.mock('next/navigation', () => ({ @@ -10,7 +10,7 @@ vi.mock('next/navigation', () => ({ // Mock Next.js Link component vi.mock('next/link', () => ({ - default: function MockLink({ children, href, className, title }: any) { + default: function MockLink({ children, href, className, title }: { children: React.ReactNode, href: string, className?: string, title?: string }) { return ( {children} diff --git a/web/app/components/app-sidebar/navLink.tsx b/web/app/components/app-sidebar/nav-link/index.tsx similarity index 100% rename from web/app/components/app-sidebar/navLink.tsx rename to web/app/components/app-sidebar/nav-link/index.tsx diff --git a/web/app/components/app-sidebar/style.module.css b/web/app/components/app-sidebar/style.module.css deleted file mode 100644 index ca0978b760..0000000000 --- a/web/app/components/app-sidebar/style.module.css +++ /dev/null @@ -1,11 +0,0 @@ -.sidebar { - border-right: 1px solid #F3F4F6; -} - -.completionPic { -background-image: url('./completion.png') -} - -.expertPic { -background-image: url('./expert.png') -} diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx index 6a67ba3207..55f5ee0564 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx @@ -1,7 +1,7 @@ import type { Props } from './csv-uploader' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import CSVUploader from './csv-uploader' describe('CSVUploader', () => { diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx index 5bfade82ea..a969b3d491 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx @@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { cn } from '@/utils/classnames' export type Props = { diff --git a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx index d0e9eb586c..9625204d81 100644 --- a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx @@ -20,7 +20,7 @@ import { } from '@/app/components/base/icons/src/vender/line/files' import PromptEditor from '@/app/components/base/prompt-editor' import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import ConfigContext from '@/context/debug-configuration' import { useEventEmitterContextContext } from '@/context/event-emitter' diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index a651d935a4..39a1699063 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -17,7 +17,7 @@ import { useFeaturesStore } from '@/app/components/base/features/hooks' import PromptEditor from '@/app/components/base/prompt-editor' import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block' import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import ConfigContext from '@/context/debug-configuration' import { useEventEmitterContextContext } from '@/context/event-emitter' diff --git a/web/app/components/app/configuration/config/agent/prompt-editor.tsx b/web/app/components/app/configuration/config/agent/prompt-editor.tsx index b0134b1f8d..f719d87261 100644 --- a/web/app/components/app/configuration/config/agent/prompt-editor.tsx +++ b/web/app/components/app/configuration/config/agent/prompt-editor.tsx @@ -12,7 +12,7 @@ import { CopyCheck, } from '@/app/components/base/icons/src/vender/line/files' import PromptEditor from '@/app/components/base/prompt-editor' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import ConfigContext from '@/context/debug-configuration' import { useModalContext } from '@/context/modal-context' import { cn } from '@/utils/classnames' diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx index b6273f66ff..264e66fd96 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx @@ -3,7 +3,7 @@ import type { DataSet } from '@/models/datasets' import type { RetrievalConfig } from '@/types/app' import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { IndexingType } from '@/app/components/datasets/create/step-two' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx index b03423ded4..4435e1b311 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx @@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Textarea from '@/app/components/base/textarea' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' import { IndexingType } from '@/app/components/datasets/create/step-two' import IndexMethod from '@/app/components/datasets/settings/index-method' diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/context-provider.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/context-provider.tsx new file mode 100644 index 0000000000..74aed2d1e2 --- /dev/null +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/context-provider.tsx @@ -0,0 +1,28 @@ +'use client' + +import type { ReactNode } from 'react' +import type { DebugWithMultipleModelContextType } from './context' +import { DebugWithMultipleModelContext } from './context' + +type DebugWithMultipleModelContextProviderProps = { + children: ReactNode +} & DebugWithMultipleModelContextType +export const DebugWithMultipleModelContextProvider = ({ + children, + onMultipleModelConfigsChange, + multipleModelConfigs, + onDebugWithMultipleModelChange, + checkCanSend, +}: DebugWithMultipleModelContextProviderProps) => { + return ( + + {children} + + ) +} diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/context.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/context.spec.tsx index e26fcec607..989285f812 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/context.spec.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/context.spec.tsx @@ -1,10 +1,8 @@ import type { ModelAndParameter } from '../types' import type { DebugWithMultipleModelContextType } from './context' import { render, screen } from '@testing-library/react' -import { - DebugWithMultipleModelContextProvider, - useDebugWithMultipleModelContext, -} from './context' +import { useDebugWithMultipleModelContext } from './context' +import { DebugWithMultipleModelContextProvider } from './context-provider' const createModelAndParameter = (overrides: Partial = {}): ModelAndParameter => ({ id: 'model-1', diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/context.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/context.ts similarity index 50% rename from web/app/components/app/configuration/debug/debug-with-multiple-model/context.tsx rename to web/app/components/app/configuration/debug/debug-with-multiple-model/context.ts index 38f803f8ab..e3ad06f1b9 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/context.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/context.ts @@ -10,7 +10,8 @@ export type DebugWithMultipleModelContextType = { onDebugWithMultipleModelChange: (singleModelConfig: ModelAndParameter) => void checkCanSend?: () => boolean } -const DebugWithMultipleModelContext = createContext({ + +export const DebugWithMultipleModelContext = createContext({ multipleModelConfigs: [], onMultipleModelConfigsChange: noop, onDebugWithMultipleModelChange: noop, @@ -18,27 +19,4 @@ const DebugWithMultipleModelContext = createContext useContext(DebugWithMultipleModelContext) -type DebugWithMultipleModelContextProviderProps = { - children: React.ReactNode -} & DebugWithMultipleModelContextType -export const DebugWithMultipleModelContextProvider = ({ - children, - onMultipleModelConfigsChange, - multipleModelConfigs, - onDebugWithMultipleModelChange, - checkCanSend, -}: DebugWithMultipleModelContextProviderProps) => { - return ( - - {children} - - ) -} - export default DebugWithMultipleModelContext diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx index c73eb54329..f98e8c1f06 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx @@ -14,10 +14,8 @@ import { useDebugConfigurationContext } from '@/context/debug-configuration' import { useEventEmitterContextContext } from '@/context/event-emitter' import { AppModeEnum } from '@/types/app' import { APP_CHAT_WITH_MULTIPLE_MODEL } from '../types' -import { - DebugWithMultipleModelContextProvider, - useDebugWithMultipleModelContext, -} from './context' +import { useDebugWithMultipleModelContext } from './context' +import { DebugWithMultipleModelContextProvider } from './context-provider' import DebugItem from './debug-item' const DebugWithMultipleModel = () => { diff --git a/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx b/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx index 08bdd2bfcb..48141d0045 100644 --- a/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx +++ b/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx @@ -387,7 +387,7 @@ vi.mock('@/context/event-emitter', () => ({ })) // Mock toast context -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: vi.fn(() => ({ notify: vi.fn(), })), diff --git a/web/app/components/app/configuration/debug/index.tsx b/web/app/components/app/configuration/debug/index.tsx index c52af813ab..1bef7f367a 100644 --- a/web/app/components/app/configuration/debug/index.tsx +++ b/web/app/components/app/configuration/debug/index.tsx @@ -29,7 +29,7 @@ import Button from '@/app/components/base/button' import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks' import { RefreshCcw01 } from '@/app/components/base/icons/src/vender/line/arrows' import PromptLogModal from '@/app/components/base/prompt-log-modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import TooltipPlus from '@/app/components/base/tooltip' import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index 765f0fa9b9..091192646d 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -50,7 +50,8 @@ import { FeaturesProvider } from '@/app/components/base/features' import NewFeaturePanel from '@/app/components/base/features/new-feature-panel' import Loading from '@/app/components/base/loading' import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' -import Toast, { ToastContext } from '@/app/components/base/toast' +import Toast from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { @@ -67,7 +68,7 @@ import { SupportUploadFileTypes } from '@/app/components/workflow/types' import { ANNOTATION_DEFAULT, DATASET_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config' import { useAppContext } from '@/context/app-context' import ConfigContext from '@/context/debug-configuration' -import { MittProvider } from '@/context/mitt-context' +import { MittProvider } from '@/context/mitt-context-provider' import { useModalContext } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' diff --git a/web/app/components/app/configuration/tools/external-data-tool-modal.tsx b/web/app/components/app/configuration/tools/external-data-tool-modal.tsx index 62c29bd9fc..dd7a0c6a6c 100644 --- a/web/app/components/app/configuration/tools/external-data-tool-modal.tsx +++ b/web/app/components/app/configuration/tools/external-data-tool-modal.tsx @@ -13,7 +13,7 @@ import FormGeneration from '@/app/components/base/features/new-feature-panel/mod import { BookOpen01 } from '@/app/components/base/icons/src/vender/line/education' import Modal from '@/app/components/base/modal' import { SimpleSelect } from '@/app/components/base/select' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import ApiBasedExtensionSelector from '@/app/components/header/account-setting/api-based-extension-page/selector' import { useDocLink, useLocale } from '@/context/i18n' import { LanguagesSupported } from '@/i18n-config/language' diff --git a/web/app/components/app/configuration/tools/index.tsx b/web/app/components/app/configuration/tools/index.tsx index d2873b0be3..f348a7718d 100644 --- a/web/app/components/app/configuration/tools/index.tsx +++ b/web/app/components/app/configuration/tools/index.tsx @@ -15,7 +15,7 @@ import { } from '@/app/components/base/icons/src/vender/line/general' import { Tool03 } from '@/app/components/base/icons/src/vender/solid/general' import Switch from '@/app/components/base/switch' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import ConfigContext from '@/context/debug-configuration' import { useModalContext } from '@/context/modal-context' diff --git a/web/app/components/app/create-app-modal/index.spec.tsx b/web/app/components/app/create-app-modal/index.spec.tsx index 75d650742d..b1f00b481d 100644 --- a/web/app/components/app/create-app-modal/index.spec.tsx +++ b/web/app/components/app/create-app-modal/index.spec.tsx @@ -4,7 +4,7 @@ import { useRouter } from 'next/navigation' import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest' import { trackEvent } from '@/app/components/base/amplitude' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { MARKETPLACE_URL_PREFIX, NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index 12d4a98d8f..6d5bdc2448 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -19,7 +19,7 @@ import { BubbleTextMod, ChatBot, ListSparkle, Logic } from '@/app/components/bas import Input from '@/app/components/base/input' import CustomSelect from '@/app/components/base/select/custom' import Textarea from '@/app/components/base/textarea' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { MARKETPLACE_URL_PREFIX, NEED_REFRESH_APP_LIST_KEY } from '@/config' import { STORAGE_KEYS } from '@/config/storage-keys' diff --git a/web/app/components/app/create-from-dsl-modal/dsl-confirm-modal.tsx b/web/app/components/app/create-from-dsl-modal/dsl-confirm-modal.tsx index add1ffbba5..f032474257 100644 --- a/web/app/components/app/create-from-dsl-modal/dsl-confirm-modal.tsx +++ b/web/app/components/app/create-from-dsl-modal/dsl-confirm-modal.tsx @@ -3,7 +3,7 @@ import { useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Modal from '@/app/components/base/modal' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx b/web/app/components/app/create-from-dsl-modal/index.tsx index 89bffd14d3..0d3ad6c233 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -11,7 +11,7 @@ import { trackEvent } from '@/app/components/base/amplitude' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' diff --git a/web/app/components/app/create-from-dsl-modal/uploader.tsx b/web/app/components/app/create-from-dsl-modal/uploader.tsx index 509b7f101c..677c671980 100644 --- a/web/app/components/app/create-from-dsl-modal/uploader.tsx +++ b/web/app/components/app/create-from-dsl-modal/uploader.tsx @@ -10,7 +10,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import ActionButton from '@/app/components/base/action-button' import { Yaml as YamlIcon } from '@/app/components/base/icons/src/public/files' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { cn } from '@/utils/classnames' import { formatFileSize } from '@/utils/format' diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index b43d44397d..146af44a10 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -31,7 +31,7 @@ import Drawer from '@/app/components/base/drawer' import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils' import Loading from '@/app/components/base/loading' import MessageLogModal from '@/app/components/base/message-log-modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils' import { WorkflowContextProvider } from '@/app/components/workflow/context' diff --git a/web/app/components/app/overview/settings/index.spec.tsx b/web/app/components/app/overview/settings/index.spec.tsx index c9cbe0b724..d98e02ad57 100644 --- a/web/app/components/app/overview/settings/index.spec.tsx +++ b/web/app/components/app/overview/settings/index.spec.tsx @@ -59,16 +59,12 @@ vi.mock('@/context/modal-context', () => ({ useModalContext: () => buildModalContext(), })) -vi.mock('@/app/components/base/toast', async () => { - const actual = await vi.importActual('@/app/components/base/toast') - return { - ...actual, - useToastContext: () => ({ - notify: mockNotify, - close: vi.fn(), - }), - } -}) +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ + notify: mockNotify, + close: vi.fn(), + }), +})) vi.mock('@/context/i18n', async () => { const actual = await vi.importActual('@/context/i18n') diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx index 2a5770b2a2..20461dda7e 100644 --- a/web/app/components/app/overview/settings/index.tsx +++ b/web/app/components/app/overview/settings/index.tsx @@ -20,7 +20,7 @@ import PremiumBadge from '@/app/components/base/premium-badge' import { SimpleSelect } from '@/app/components/base/select' import Switch from '@/app/components/base/switch' import Textarea from '@/app/components/base/textarea' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { useModalContext } from '@/context/modal-context' diff --git a/web/app/components/app/switch-app-modal/index.spec.tsx b/web/app/components/app/switch-app-modal/index.spec.tsx index 14607a1c95..fa6c099e1b 100644 --- a/web/app/components/app/switch-app-modal/index.spec.tsx +++ b/web/app/components/app/switch-app-modal/index.spec.tsx @@ -3,7 +3,7 @@ import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' import { useStore as useAppStore } from '@/app/components/app/store' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { Plan } from '@/app/components/billing/type' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { AppModeEnum } from '@/types/app' diff --git a/web/app/components/app/switch-app-modal/index.tsx b/web/app/components/app/switch-app-modal/index.tsx index 30d7877ed0..8caa07c187 100644 --- a/web/app/components/app/switch-app-modal/index.tsx +++ b/web/app/components/app/switch-app-modal/index.tsx @@ -15,7 +15,7 @@ import Confirm from '@/app/components/base/confirm' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' diff --git a/web/app/components/apps/__tests__/app-card.spec.tsx b/web/app/components/apps/__tests__/app-card.spec.tsx index ee36d471fd..9bc23ce199 100644 --- a/web/app/components/apps/__tests__/app-card.spec.tsx +++ b/web/app/components/apps/__tests__/app-card.spec.tsx @@ -63,6 +63,15 @@ vi.mock('@/service/apps', () => ({ exportAppConfig: vi.fn(() => Promise.resolve({ data: 'yaml: content' })), })) +const mockDeleteAppMutation = vi.fn(() => Promise.resolve()) +let mockDeleteMutationPending = false +vi.mock('@/service/use-apps', () => ({ + useDeleteAppMutation: () => ({ + mutateAsync: mockDeleteAppMutation, + isPending: mockDeleteMutationPending, + }), +})) + vi.mock('@/service/workflow', () => ({ fetchWorkflowDraft: vi.fn(() => Promise.resolve({ environment_variables: [] })), })) @@ -146,13 +155,6 @@ vi.mock('next/dynamic', () => ({ return React.createElement('div', { 'data-testid': 'switch-modal' }, React.createElement('button', { 'onClick': onClose, 'data-testid': 'close-switch-modal' }, 'Close'), React.createElement('button', { 'onClick': onSuccess, 'data-testid': 'confirm-switch-modal' }, 'Switch')) } } - if (fnString.includes('base/confirm')) { - return function MockConfirm({ isShow, onCancel, onConfirm }: { isShow: boolean, onCancel: () => void, onConfirm: () => void }) { - if (!isShow) - return null - return React.createElement('div', { 'data-testid': 'confirm-dialog' }, React.createElement('button', { 'onClick': onCancel, 'data-testid': 'cancel-confirm' }, 'Cancel'), React.createElement('button', { 'onClick': onConfirm, 'data-testid': 'confirm-confirm' }, 'Confirm')) - } - } if (fnString.includes('dsl-export-confirm-modal')) { return function MockDSLExportModal({ onClose, onConfirm }: { onClose?: () => void, onConfirm?: (withSecrets: boolean) => void }) { return React.createElement('div', { 'data-testid': 'dsl-export-modal' }, React.createElement('button', { 'onClick': () => onClose?.(), 'data-testid': 'close-dsl-export' }, 'Close'), React.createElement('button', { 'onClick': () => onConfirm?.(true), 'data-testid': 'confirm-dsl-export' }, 'Export with secrets'), React.createElement('button', { 'onClick': () => onConfirm?.(false), 'data-testid': 'confirm-dsl-export-no-secrets' }, 'Export without secrets')) @@ -235,6 +237,7 @@ describe('AppCard', () => { vi.clearAllMocks() mockOpenAsyncWindow.mockReset() mockWebappAuthEnabled = false + mockDeleteMutationPending = false }) describe('Rendering', () => { @@ -461,35 +464,19 @@ describe('AppCard', () => { render() fireEvent.click(screen.getByTestId('popover-trigger')) - - await waitFor(() => { - const deleteButton = screen.getByText('common.operation.delete') - fireEvent.click(deleteButton) - }) - - await waitFor(() => { - expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() - }) + fireEvent.click(await screen.findByRole('button', { name: 'common.operation.delete' })) + expect(await screen.findByRole('alertdialog')).toBeInTheDocument() }) it('should close confirm dialog when cancel is clicked', async () => { render() fireEvent.click(screen.getByTestId('popover-trigger')) - + fireEvent.click(await screen.findByRole('button', { name: 'common.operation.delete' })) + expect(await screen.findByRole('alertdialog')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' })) await waitFor(() => { - const deleteButton = screen.getByText('common.operation.delete') - fireEvent.click(deleteButton) - }) - - await waitFor(() => { - expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() - }) - - fireEvent.click(screen.getByTestId('cancel-confirm')) - - await waitFor(() => { - expect(screen.queryByTestId('confirm-dialog')).not.toBeInTheDocument() + expect(screen.queryByRole('alertdialog')).not.toBeInTheDocument() }) }) @@ -554,59 +541,41 @@ describe('AppCard', () => { // Open popover and click delete fireEvent.click(screen.getByTestId('popover-trigger')) - await waitFor(() => { - fireEvent.click(screen.getByText('common.operation.delete')) - }) - - // Confirm delete - await waitFor(() => { - expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() - }) - - fireEvent.click(screen.getByTestId('confirm-confirm')) + fireEvent.click(await screen.findByRole('button', { name: 'common.operation.delete' })) + expect(await screen.findByRole('alertdialog')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' })) await waitFor(() => { - expect(appsService.deleteApp).toHaveBeenCalled() + expect(mockDeleteAppMutation).toHaveBeenCalled() }) }) - it('should call onRefresh after successful delete', async () => { + it('should not call onRefresh after successful delete', async () => { render() fireEvent.click(screen.getByTestId('popover-trigger')) - await waitFor(() => { - fireEvent.click(screen.getByText('common.operation.delete')) - }) + fireEvent.click(await screen.findByRole('button', { name: 'common.operation.delete' })) + expect(await screen.findByRole('alertdialog')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' })) await waitFor(() => { - expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() - }) - - fireEvent.click(screen.getByTestId('confirm-confirm')) - - await waitFor(() => { - expect(mockOnRefresh).toHaveBeenCalled() + expect(mockDeleteAppMutation).toHaveBeenCalled() }) + expect(mockOnRefresh).not.toHaveBeenCalled() }) it('should handle delete failure', async () => { - (appsService.deleteApp as Mock).mockRejectedValueOnce(new Error('Delete failed')) + ;(mockDeleteAppMutation as Mock).mockRejectedValueOnce(new Error('Delete failed')) render() fireEvent.click(screen.getByTestId('popover-trigger')) - await waitFor(() => { - fireEvent.click(screen.getByText('common.operation.delete')) - }) + fireEvent.click(await screen.findByRole('button', { name: 'common.operation.delete' })) + expect(await screen.findByRole('alertdialog')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' })) await waitFor(() => { - expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() - }) - - fireEvent.click(screen.getByTestId('confirm-confirm')) - - await waitFor(() => { - expect(appsService.deleteApp).toHaveBeenCalled() + expect(mockDeleteAppMutation).toHaveBeenCalled() expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: expect.stringContaining('Delete failed') }) }) }) diff --git a/web/app/components/apps/__tests__/list.spec.tsx b/web/app/components/apps/__tests__/list.spec.tsx index 4dd2472756..82e2347781 100644 --- a/web/app/components/apps/__tests__/list.spec.tsx +++ b/web/app/components/apps/__tests__/list.spec.tsx @@ -1,10 +1,7 @@ -import type { UrlUpdateEvent } from 'nuqs/adapters/testing' -import type { ReactNode } from 'react' -import { QueryClient, QueryClientProvider } from '@tanstack/react-query' -import { act, fireEvent, render, screen } from '@testing-library/react' -import { NuqsTestingAdapter } from 'nuqs/adapters/testing' +import { act, fireEvent, screen } from '@testing-library/react' import * as React from 'react' import { useStore as useTagStore } from '@/app/components/base/tag-management/store' +import { renderWithNuqs } from '@/test/nuqs-testing' import { AppModeEnum } from '@/types/app' import List from '../list' @@ -117,6 +114,10 @@ vi.mock('@/service/use-apps', () => ({ error: mockServiceState.error, refetch: mockRefetch, }), + useDeleteAppMutation: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), })) vi.mock('@/service/tag', () => ({ @@ -199,30 +200,14 @@ beforeAll(() => { } as unknown as typeof IntersectionObserver }) -// Render helper wrapping with NuqsTestingAdapter -const onUrlUpdate = vi.fn<(event: UrlUpdateEvent) => void>() +// Render helper wrapping with shared nuqs testing helper. const renderList = (searchParams = '') => { - const queryClient = new QueryClient({ - defaultOptions: { - queries: { - retry: false, - }, - }, - }) - const wrapper = ({ children }: { children: ReactNode }) => ( - - - {children} - - - ) - return render(, { wrapper }) + return renderWithNuqs(, { searchParams }) } describe('List', () => { beforeEach(() => { vi.clearAllMocks() - onUrlUpdate.mockClear() useTagStore.setState({ tagList: [{ id: 'tag-1', name: 'Test Tag', type: 'app', binding_count: 0 }], showTagManagementModal: false, @@ -300,7 +285,7 @@ describe('List', () => { describe('Tab Navigation', () => { it('should update URL when workflow tab is clicked', async () => { - renderList() + const { onUrlUpdate } = renderList() fireEvent.click(screen.getByText('app.types.workflow')) @@ -310,7 +295,7 @@ describe('List', () => { }) it('should update URL when all tab is clicked', async () => { - renderList('?category=workflow') + const { onUrlUpdate } = renderList('?category=workflow') fireEvent.click(screen.getByText('app.types.all')) @@ -414,7 +399,7 @@ describe('List', () => { describe('Edge Cases', () => { it('should handle multiple renders without issues', () => { - const { rerender } = renderList() + const { rerender } = renderWithNuqs() expect(screen.getByText('app.types.all')).toBeInTheDocument() rerender() @@ -463,7 +448,7 @@ describe('List', () => { }) it('should update URL for each app type tab click', async () => { - renderList() + const { onUrlUpdate } = renderList() const appTypeTexts = [ { mode: AppModeEnum.WORKFLOW, text: 'app.types.workflow' }, diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index a41ead0240..9c6c98a55e 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -18,8 +18,18 @@ import AppIcon from '@/app/components/base/app-icon' import Divider from '@/app/components/base/divider' import CustomPopover from '@/app/components/base/popover' import TagSelector from '@/app/components/base/tag-management/selector' -import Toast, { ToastContext } from '@/app/components/base/toast' +import Toast from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' +import { + AlertDialog, + AlertDialogActions, + AlertDialogCancelButton, + AlertDialogConfirmButton, + AlertDialogContent, + AlertDialogDescription, + AlertDialogTitle, +} from '@/app/components/base/ui/alert-dialog' import { UserAvatarList } from '@/app/components/base/user-avatar-list' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' @@ -28,8 +38,9 @@ import { useProviderContext } from '@/context/provider-context' import { useAsyncWindowOpen } from '@/hooks/use-async-window-open' import { AccessMode } from '@/models/access-control' import { useGetUserCanAccessApp } from '@/service/access-control' -import { copyApp, deleteApp, exportAppBundle, exportAppConfig, updateAppInfo, upgradeAppRuntime } from '@/service/apps' +import { copyApp, exportAppBundle, exportAppConfig, updateAppInfo, upgradeAppRuntime } from '@/service/apps' import { fetchInstalledAppList } from '@/service/explore' +import { useDeleteAppMutation } from '@/service/use-apps' import { fetchWorkflowDraft } from '@/service/workflow' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' @@ -47,9 +58,6 @@ const DuplicateAppModal = dynamic(() => import('@/app/components/app/duplicate-m const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), { ssr: false, }) -const Confirm = dynamic(() => import('@/app/components/base/confirm'), { - ssr: false, -}) const DSLExportConfirmModal = dynamic(() => import('@/app/components/workflow/dsl-export-confirm-modal'), { ssr: false, }) @@ -79,13 +87,12 @@ const AppCard = ({ app, onRefresh, onlineUsers = [] }: AppCardProps) => { const [showAccessControl, setShowAccessControl] = useState(false) const [secretEnvList, setSecretEnvList] = useState([]) const [exporting, startExport] = useTransition() + const { mutateAsync: mutateDeleteApp, isPending: isDeleting } = useDeleteAppMutation() const onConfirmDelete = useCallback(async () => { try { - await deleteApp(app.id) + await mutateDeleteApp(app.id) notify({ type: 'success', message: t('appDeleted', { ns: 'app' }) }) - if (onRefresh) - onRefresh() onPlanInfoChanged() } catch (e: unknown) { @@ -94,8 +101,17 @@ const AppCard = ({ app, onRefresh, onlineUsers = [] }: AppCardProps) => { message: `${t('appDeleteFailed', { ns: 'app' })}${e instanceof Error ? `: ${e.message}` : ''}`, }) } - setShowConfirmDelete(false) - }, [app.id, notify, onPlanInfoChanged, onRefresh, t]) + finally { + setShowConfirmDelete(false) + } + }, [app.id, mutateDeleteApp, notify, onPlanInfoChanged, t]) + + const onDeleteDialogOpenChange = useCallback((open: boolean) => { + if (isDeleting) + return + + setShowConfirmDelete(open) + }, [isDeleting]) const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ name, @@ -509,7 +525,8 @@ const AppCard = ({ app, onRefresh, onlineUsers = [] }: AppCardProps) => {
- + {t('operation.more', { ns: 'common' })} +
)} btnClassName={open => @@ -566,15 +583,26 @@ const AppCard = ({ app, onRefresh, onlineUsers = [] }: AppCardProps) => { onSuccess={onSwitch} /> )} - {showConfirmDelete && ( - setShowConfirmDelete(false)} - /> - )} + + +
+ + {t('deleteAppConfirmTitle', { ns: 'app' })} + + + {t('deleteAppConfirmContent', { ns: 'app' })} + +
+ + + {t('operation.cancel', { ns: 'common' })} + + + {t('operation.confirm', { ns: 'common' })} + + +
+
{secretEnvList.length > 0 && ( { - const onUrlUpdate = vi.fn<(event: UrlUpdateEvent) => void>() - const wrapper = ({ children }: { children: ReactNode }) => ( - - {children} - - ) - const { result } = renderHook(() => useAppsQueryState(), { wrapper }) - return { result, onUrlUpdate } + return renderHookWithNuqs(() => useAppsQueryState(), { searchParams }) } describe('useAppsQueryState', () => { diff --git a/web/app/components/apps/import-from-marketplace-template-modal.tsx b/web/app/components/apps/import-from-marketplace-template-modal.tsx index 42d705409b..b44bf36d32 100644 --- a/web/app/components/apps/import-from-marketplace-template-modal.tsx +++ b/web/app/components/apps/import-from-marketplace-template-modal.tsx @@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next' import AppIcon from '@/app/components/base/app-icon' import Button from '@/app/components/base/button' import Modal from '@/app/components/base/modal' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { MARKETPLACE_API_PREFIX, MARKETPLACE_URL_PREFIX } from '@/config' import { fetchMarketplaceTemplateDSL, diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index f5ffcc2320..aac265228a 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -4,7 +4,7 @@ import type { FC } from 'react' import { useQuery } from '@tanstack/react-query' import { useDebounceFn } from 'ahooks' import dynamic from 'next/dynamic' -import { parseAsString, useQueryState } from 'nuqs' +import { parseAsStringLiteral, useQueryState } from 'nuqs' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' @@ -20,7 +20,7 @@ import { useGlobalPublicStore } from '@/context/global-public-context' import { CheckModal } from '@/hooks/use-pay' import { fetchWorkflowOnlineUsers } from '@/service/apps' import { useInfiniteAppList } from '@/service/use-apps' -import { AppModeEnum } from '@/types/app' +import { AppModeEnum, AppModes } from '@/types/app' import { cn } from '@/utils/classnames' import AppCard from './app-card' import { AppCardSkeleton } from './app-card-skeleton' @@ -37,6 +37,18 @@ const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-fro ssr: false, }) +const APP_LIST_CATEGORY_VALUES = ['all', ...AppModes] as const +type AppListCategory = typeof APP_LIST_CATEGORY_VALUES[number] +const appListCategorySet = new Set(APP_LIST_CATEGORY_VALUES) + +const isAppListCategory = (value: string): value is AppListCategory => { + return appListCategorySet.has(value) +} + +const parseAsAppListCategory = parseAsStringLiteral(APP_LIST_CATEGORY_VALUES) + .withDefault('all') + .withOptions({ history: 'push' }) + type Props = { controlRefreshList?: number } @@ -49,7 +61,7 @@ const List: FC = ({ const showTagManagementModal = useTagStore(s => s.showTagManagementModal) const [activeTab, setActiveTab] = useQueryState( 'category', - parseAsString.withDefault('all').withOptions({ history: 'push' }), + parseAsAppListCategory, ) const { query: { tagIDs = [], keywords = '', isCreatedByMe: queryIsCreatedByMe = false }, setQuery } = useAppsQueryState() @@ -90,7 +102,7 @@ const List: FC = ({ name: searchKeywords, tag_ids: tagIDs, is_created_by_me: isCreatedByMe, - ...(activeTab !== 'all' ? { mode: activeTab as AppModeEnum } : {}), + ...(activeTab !== 'all' ? { mode: activeTab } : {}), } const { @@ -227,7 +239,10 @@ const List: FC = ({
{ + if (isAppListCategory(nextValue)) + setActiveTab(nextValue) + }} options={options} />
diff --git a/web/app/components/base/agent-log-modal/__tests__/detail.spec.tsx b/web/app/components/base/agent-log-modal/__tests__/detail.spec.tsx index c77f144da2..47d854e028 100644 --- a/web/app/components/base/agent-log-modal/__tests__/detail.spec.tsx +++ b/web/app/components/base/agent-log-modal/__tests__/detail.spec.tsx @@ -2,7 +2,7 @@ import type { ComponentProps } from 'react' import type { IChatItem } from '@/app/components/base/chat/chat/type' import type { AgentLogDetailResponse } from '@/models/log' import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { fetchAgentLogDetail } from '@/service/log' import AgentLogDetail from '../detail' diff --git a/web/app/components/base/agent-log-modal/__tests__/index.spec.tsx b/web/app/components/base/agent-log-modal/__tests__/index.spec.tsx index 6b59e90c77..6437ae5b43 100644 --- a/web/app/components/base/agent-log-modal/__tests__/index.spec.tsx +++ b/web/app/components/base/agent-log-modal/__tests__/index.spec.tsx @@ -1,7 +1,7 @@ import type { IChatItem } from '@/app/components/base/chat/chat/type' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { useClickAway } from 'ahooks' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { fetchAgentLogDetail } from '@/service/log' import AgentLogModal from '../index' diff --git a/web/app/components/base/agent-log-modal/detail.tsx b/web/app/components/base/agent-log-modal/detail.tsx index 36b502e9a5..21ed0be7e8 100644 --- a/web/app/components/base/agent-log-modal/detail.tsx +++ b/web/app/components/base/agent-log-modal/detail.tsx @@ -10,7 +10,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { useStore as useAppStore } from '@/app/components/app/store' import Loading from '@/app/components/base/loading' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { fetchAgentLogDetail } from '@/service/log' import { cn } from '@/utils/classnames' import ResultPanel from './result' diff --git a/web/app/components/base/button/__tests__/index.spec.tsx b/web/app/components/base/button/__tests__/index.spec.tsx index b43ae89403..4fe0ab3570 100644 --- a/web/app/components/base/button/__tests__/index.spec.tsx +++ b/web/app/components/base/button/__tests__/index.spec.tsx @@ -1,110 +1,156 @@ -import { cleanup, fireEvent, render } from '@testing-library/react' -import * as React from 'react' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' import Button from '../index' afterEach(cleanup) -// https://testing-library.com/docs/queries/about + describe('Button', () => { - describe('Button text', () => { - it('Button text should be same as children', async () => { - const { getByRole, container } = render() - expect(getByRole('button').textContent).toBe('Click me') - expect(container.querySelector('button')?.textContent).toBe('Click me') + describe('rendering', () => { + it('renders children text', () => { + render() + expect(screen.getByRole('button')).toHaveTextContent('Click me') + }) + + it('renders as a native button element by default', () => { + render() + expect(screen.getByRole('button').tagName).toBe('BUTTON') + }) + + it('defaults to type="button"', () => { + render() + expect(screen.getByRole('button')).toHaveAttribute('type', 'button') + }) + + it('allows type override to submit', () => { + render() + expect(screen.getByRole('button')).toHaveAttribute('type', 'submit') + }) + + it('renders custom element via render prop', () => { + render() + const link = screen.getByRole('link') + expect(link).toHaveTextContent('Link') + expect(link).toHaveAttribute('href', '/test') }) }) - describe('Button loading', () => { - it('Loading button text should include same as children', async () => { - const { getByRole } = render() - expect(getByRole('button').textContent?.includes('Loading')).toBe(true) - }) - it('Not loading button text should include same as children', async () => { - const { getByRole } = render() - expect(getByRole('button').textContent?.includes('Loading')).toBe(false) + describe('variants', () => { + it('applies default secondary variant', () => { + render() + expect(screen.getByRole('button').className).toContain('btn-secondary') }) - it('Loading button should have loading classname', async () => { + it.each([ + 'primary', + 'warning', + 'secondary', + 'secondary-accent', + 'ghost', + 'ghost-accent', + 'tertiary', + ] as const)('applies %s variant', (variant) => { + render() + expect(screen.getByRole('button').className).toContain(`btn-${variant}`) + }) + + it('applies destructive modifier', () => { + render() + expect(screen.getByRole('button').className).toContain('btn-destructive') + }) + }) + + describe('sizes', () => { + it('applies default medium size', () => { + render() + expect(screen.getByRole('button').className).toContain('btn-medium') + }) + + it.each(['small', 'medium', 'large'] as const)('applies %s size', (size) => { + render() + expect(screen.getByRole('button').className).toContain(`btn-${size}`) + }) + }) + + describe('loading', () => { + it('shows spinner when loading', () => { + render() + expect(screen.getByRole('button').querySelector('.animate-spin')).toBeInTheDocument() + }) + + it('hides spinner when not loading', () => { + render() + expect(screen.getByRole('button').querySelector('.animate-spin')).not.toBeInTheDocument() + }) + + it('auto-disables when loading', () => { + render() + expect(screen.getByRole('button')).toBeDisabled() + }) + + it('sets aria-busy when loading', () => { + render() + expect(screen.getByRole('button')).toHaveAttribute('aria-busy', 'true') + }) + + it('does not set aria-busy when not loading', () => { + render() + expect(screen.getByRole('button')).not.toHaveAttribute('aria-busy') + }) + + it('applies custom spinnerClassName', () => { const animClassName = 'anim-breath' - const { getByRole } = render() - expect(getByRole('button').getElementsByClassName('animate-spin')[0]?.className).toContain(animClassName) + render() + expect(screen.getByRole('button').querySelector('.animate-spin')?.className).toContain(animClassName) }) }) - describe('Button style', () => { - it('Button should have default variant', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-secondary') + describe('disabled', () => { + it('disables button when disabled prop is set', () => { + render() + expect(screen.getByRole('button')).toBeDisabled() }) - it('Button should have primary variant', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-primary') - }) - - it('Button should have warning variant', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-warning') - }) - - it('Button should have secondary variant', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-secondary') - }) - - it('Button should have secondary-accent variant', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-secondary-accent') - }) - it('Button should have ghost variant', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-ghost') - }) - it('Button should have ghost-accent variant', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-ghost-accent') - }) - - it('Button disabled should have disabled variant', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-disabled') + it('keeps focusable when loading with focusableWhenDisabled', () => { + render() + const button = screen.getByRole('button') + expect(button).toHaveAttribute('aria-disabled', 'true') }) }) - describe('Button size', () => { - it('Button should have default size', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-medium') - }) - - it('Button should have small size', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-small') - }) - - it('Button should have medium size', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-medium') - }) - - it('Button should have large size', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-large') - }) - }) - - describe('Button destructive', () => { - it('Button should have destructive classname', async () => { - const { getByRole } = render() - expect(getByRole('button').className).toContain('btn-destructive') - }) - }) - - describe('Button events', () => { - it('onClick should been call after clicked', async () => { + describe('events', () => { + it('fires onClick when clicked', () => { const onClick = vi.fn() - const { getByRole } = render() - fireEvent.click(getByRole('button')) - expect(onClick).toHaveBeenCalled() + render() + fireEvent.click(screen.getByRole('button')) + expect(onClick).toHaveBeenCalledTimes(1) + }) + + it('does not fire onClick when disabled', () => { + const onClick = vi.fn() + render() + fireEvent.click(screen.getByRole('button')) + expect(onClick).not.toHaveBeenCalled() + }) + + it('does not fire onClick when loading', () => { + const onClick = vi.fn() + render() + fireEvent.click(screen.getByRole('button')) + expect(onClick).not.toHaveBeenCalled() + }) + }) + + describe('ref forwarding', () => { + it('forwards ref to the button element', () => { + let buttonRef: HTMLButtonElement | null = null + render( + , + ) + expect(buttonRef).toBeInstanceOf(HTMLButtonElement) }) }) }) diff --git a/web/app/components/base/button/index.css b/web/app/components/base/button/index.css index 5899c027d3..6360ed9d0c 100644 --- a/web/app/components/base/button/index.css +++ b/web/app/components/base/button/index.css @@ -2,10 +2,11 @@ @layer components { .btn { - @apply inline-flex justify-center items-center cursor-pointer whitespace-nowrap; + @apply inline-flex justify-center items-center cursor-pointer whitespace-nowrap + outline-none focus-visible:ring-2 focus-visible:ring-state-accent-solid; } - .btn-disabled { + .btn:is(:disabled, [data-disabled]) { @apply cursor-not-allowed; } @@ -40,7 +41,7 @@ text-components-button-destructive-primary-text; } - .btn-primary.btn-disabled { + .btn-primary:is(:disabled, [data-disabled]) { @apply shadow-none bg-components-button-primary-bg-disabled @@ -48,7 +49,7 @@ text-components-button-primary-text-disabled; } - .btn-primary.btn-destructive.btn-disabled { + .btn-primary.btn-destructive:is(:disabled, [data-disabled]) { @apply shadow-none bg-components-button-destructive-primary-bg-disabled @@ -68,7 +69,7 @@ text-components-button-secondary-text; } - .btn-secondary.btn-disabled { + .btn-secondary:is(:disabled, [data-disabled]) { @apply backdrop-blur-sm bg-components-button-secondary-bg-disabled @@ -85,7 +86,7 @@ text-components-button-destructive-secondary-text; } - .btn-secondary.btn-destructive.btn-disabled { + .btn-secondary.btn-destructive:is(:disabled, [data-disabled]) { @apply bg-components-button-destructive-secondary-bg-disabled border-components-button-destructive-secondary-border-disabled @@ -104,7 +105,7 @@ text-components-button-secondary-accent-text; } - .btn-secondary-accent.btn-disabled { + .btn-secondary-accent:is(:disabled, [data-disabled]) { @apply bg-components-button-secondary-bg-disabled border-components-button-secondary-border-disabled @@ -120,7 +121,7 @@ text-components-button-destructive-primary-text; } - .btn-warning.btn-disabled { + .btn-warning:is(:disabled, [data-disabled]) { @apply bg-components-button-destructive-primary-bg-disabled border-components-button-destructive-primary-border-disabled @@ -134,7 +135,7 @@ text-components-button-tertiary-text; } - .btn-tertiary.btn-disabled { + .btn-tertiary:is(:disabled, [data-disabled]) { @apply bg-components-button-tertiary-bg-disabled text-components-button-tertiary-text-disabled; @@ -147,7 +148,7 @@ text-components-button-destructive-tertiary-text; } - .btn-tertiary.btn-destructive.btn-disabled { + .btn-tertiary.btn-destructive:is(:disabled, [data-disabled]) { @apply bg-components-button-destructive-tertiary-bg-disabled text-components-button-destructive-tertiary-text-disabled; @@ -159,7 +160,7 @@ text-components-button-ghost-text; } - .btn-ghost.btn-disabled { + .btn-ghost:is(:disabled, [data-disabled]) { @apply text-components-button-ghost-text-disabled; } @@ -170,7 +171,7 @@ text-components-button-destructive-ghost-text; } - .btn-ghost.btn-destructive.btn-disabled { + .btn-ghost.btn-destructive:is(:disabled, [data-disabled]) { @apply text-components-button-destructive-ghost-text-disabled; } @@ -181,7 +182,7 @@ text-components-button-secondary-accent-text; } - .btn-ghost-accent.btn-disabled { + .btn-ghost-accent:is(:disabled, [data-disabled]) { @apply text-components-button-secondary-accent-text-disabled; } diff --git a/web/app/components/base/button/index.stories.tsx b/web/app/components/base/button/index.stories.tsx index 25bd5957e1..5a7ec55e8f 100644 --- a/web/app/components/base/button/index.stories.tsx +++ b/web/app/components/base/button/index.stories.tsx @@ -1,6 +1,5 @@ import type { Meta, StoryObj } from '@storybook/nextjs-vite' -import { RocketLaunchIcon } from '@heroicons/react/20/solid' import { Button } from '.' const meta = { @@ -12,10 +11,16 @@ const meta = { tags: ['autodocs'], argTypes: { loading: { control: 'boolean' }, + destructive: { control: 'boolean' }, + disabled: { control: 'boolean' }, variant: { control: 'select', options: ['primary', 'warning', 'secondary', 'secondary-accent', 'ghost', 'ghost-accent', 'tertiary'], }, + size: { + control: 'select', + options: ['small', 'medium', 'large'], + }, }, args: { variant: 'ghost', @@ -29,11 +34,7 @@ type Story = StoryObj export const Default: Story = { args: { variant: 'primary', - loading: false, children: 'Primary Button', - styleCss: {}, - spinnerClassName: '', - destructive: false, }, } @@ -95,14 +96,46 @@ export const Loading: Story = { }, } +export const Destructive: Story = { + args: { + variant: 'primary', + destructive: true, + children: 'Delete', + }, +} + export const WithIcon: Story = { args: { variant: 'primary', children: ( <> - + Launch ), }, } + +export const SmallSize: Story = { + args: { + variant: 'secondary', + size: 'small', + children: 'Small', + }, +} + +export const LargeSize: Story = { + args: { + variant: 'primary', + size: 'large', + children: 'Large Button', + }, +} + +export const AsLink: Story = { + args: { + variant: 'ghost-accent', + render: , + children: 'Link Button', + }, +} diff --git a/web/app/components/base/button/index.tsx b/web/app/components/base/button/index.tsx index 0de57617af..047ced4c53 100644 --- a/web/app/components/base/button/index.tsx +++ b/web/app/components/base/button/index.tsx @@ -1,12 +1,12 @@ import type { VariantProps } from 'class-variance-authority' -import type { CSSProperties } from 'react' +import { Button as BaseButton } from '@base-ui/react/button' import { cva } from 'class-variance-authority' import * as React from 'react' import { cn } from '@/utils/classnames' import Spinner from '../spinner' const buttonVariants = cva( - 'btn disabled:btn-disabled', + 'btn', { variants: { variant: { @@ -23,6 +23,9 @@ const buttonVariants = cva( medium: 'btn-medium', large: 'btn-large', }, + destructive: { + true: 'btn-destructive', + }, }, defaultVariants: { variant: 'secondary', @@ -32,25 +35,44 @@ const buttonVariants = cva( ) export type ButtonProps = { - destructive?: boolean loading?: boolean - styleCss?: CSSProperties spinnerClassName?: string ref?: React.Ref + render?: React.ReactElement + focusableWhenDisabled?: boolean } & React.ButtonHTMLAttributes & VariantProps -const Button = ({ className, variant, size, destructive, loading, styleCss, children, spinnerClassName, ref, ...props }: ButtonProps) => { +const Button = ({ + className, + variant, + size, + destructive, + loading, + children, + spinnerClassName, + ref, + render, + focusableWhenDisabled, + disabled, + type = 'button', + ...props +}: ButtonProps) => { + const isDisabled = disabled || loading + return ( - + ) } Button.displayName = 'Button' diff --git a/web/app/components/base/chat/chat-with-history/context.tsx b/web/app/components/base/chat/chat-with-history/context.ts similarity index 100% rename from web/app/components/base/chat/chat-with-history/context.tsx rename to web/app/components/base/chat/chat-with-history/context.ts diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index c086bce327..e19c57bd83 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -23,7 +23,7 @@ import { } from 'react' import { useTranslation } from 'react-i18next' import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { InputVarType } from '@/app/components/workflow/types' import { useWebAppStore } from '@/context/web-app-context' import { useAppFavicon } from '@/hooks/use-app-favicon' diff --git a/web/app/components/base/chat/chat/__tests__/context.spec.tsx b/web/app/components/base/chat/chat/__tests__/context.spec.tsx index fd00156e59..aeba073a7b 100644 --- a/web/app/components/base/chat/chat/__tests__/context.spec.tsx +++ b/web/app/components/base/chat/chat/__tests__/context.spec.tsx @@ -3,7 +3,8 @@ import type { ChatContextValue } from '../context' import { render, screen } from '@testing-library/react' import userEvent from '@testing-library/user-event' import { vi } from 'vitest' -import { ChatContextProvider, useChatContext } from '../context' +import { useChatContext } from '../context' +import { ChatContextProvider } from '../context-provider' const TestConsumer = () => { const context = useChatContext() diff --git a/web/app/components/base/chat/chat/__tests__/question.spec.tsx b/web/app/components/base/chat/chat/__tests__/question.spec.tsx index 1d0584805b..7c717b6e31 100644 --- a/web/app/components/base/chat/chat/__tests__/question.spec.tsx +++ b/web/app/components/base/chat/chat/__tests__/question.spec.tsx @@ -9,7 +9,7 @@ import { vi } from 'vitest' import Toast from '../../../toast' import { ThemeBuilder } from '../../embedded-chatbot/theme/theme-context' -import { ChatContextProvider } from '../context' +import { ChatContextProvider } from '../context-provider' import Question from '../question' // Global Mocks diff --git a/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx b/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx index a6d4570fcb..03f3c673ce 100644 --- a/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx @@ -91,15 +91,9 @@ vi.mock('@/app/components/base/features/hooks', () => ({ // --------------------------------------------------------------------------- // Toast context // --------------------------------------------------------------------------- -vi.mock('@/app/components/base/toast', async () => { - const actual = await vi.importActual( - '@/app/components/base/toast', - ) - return { - ...actual, - useToastContext: () => ({ notify: mockNotify }), - } -}) +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ notify: mockNotify, close: vi.fn() }), +})) // --------------------------------------------------------------------------- // Internal layout hook – controls single/multi-line textarea mode diff --git a/web/app/components/base/chat/chat/chat-input-area/index.tsx b/web/app/components/base/chat/chat/chat-input-area/index.tsx index 5caede0391..25636a1fb3 100644 --- a/web/app/components/base/chat/chat/chat-input-area/index.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/index.tsx @@ -22,7 +22,7 @@ import { FileContextProvider, useFileStore, } from '@/app/components/base/file-uploader/store' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import VoiceInput from '@/app/components/base/voice-input' import { TransferMethod } from '@/types/app' import { cn } from '@/utils/classnames' diff --git a/web/app/components/base/chat/chat/check-input-forms-hooks.ts b/web/app/components/base/chat/chat/check-input-forms-hooks.ts index 2da57b289e..842e89070b 100644 --- a/web/app/components/base/chat/chat/check-input-forms-hooks.ts +++ b/web/app/components/base/chat/chat/check-input-forms-hooks.ts @@ -1,7 +1,7 @@ import type { InputForm } from './type' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { InputVarType } from '@/app/components/workflow/types' import { TransferMethod } from '@/types/app' diff --git a/web/app/components/base/chat/chat/context.tsx b/web/app/components/base/chat/chat/context-provider.tsx similarity index 58% rename from web/app/components/base/chat/chat/context.tsx rename to web/app/components/base/chat/chat/context-provider.tsx index 7843665ad7..02503521e5 100644 --- a/web/app/components/base/chat/chat/context.tsx +++ b/web/app/components/base/chat/chat/context-provider.tsx @@ -1,30 +1,8 @@ 'use client' import type { ReactNode } from 'react' -import type { ChatProps } from './index' -import { createContext, useContext } from 'use-context-selector' - -export type ChatContextValue = Pick & { - readonly?: boolean - } - -const ChatContext = createContext({ - chatList: [], - readonly: false, -}) +import type { ChatContextValue } from './context' +import { ChatContext } from './context' type ChatContextProviderProps = { children: ReactNode @@ -71,7 +49,3 @@ export const ChatContextProvider = ({ ) } - -export const useChatContext = () => useContext(ChatContext) - -export default ChatContext diff --git a/web/app/components/base/chat/chat/context.ts b/web/app/components/base/chat/chat/context.ts new file mode 100644 index 0000000000..ff0bd26336 --- /dev/null +++ b/web/app/components/base/chat/chat/context.ts @@ -0,0 +1,30 @@ +'use client' + +import type { ChatProps } from './index' +import { createContext, useContext } from 'use-context-selector' + +export type ChatContextValue = Pick & { + readonly?: boolean + } + +export const ChatContext = createContext({ + chatList: [], + readonly: false, +}) + +export const useChatContext = () => useContext(ChatContext) + +export default ChatContext diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index cc3866c2b0..2361175e3d 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -30,7 +30,7 @@ import { getProcessedFiles, getProcessedFilesFromResponse, } from '@/app/components/base/file-uploader/utils' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { NodeRunningStatus, WorkflowRunningStatus } from '@/app/components/workflow/types' import useTimestamp from '@/hooks/use-timestamp' import { diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx index a77911d895..69c064e3e2 100644 --- a/web/app/components/base/chat/chat/index.tsx +++ b/web/app/components/base/chat/chat/index.tsx @@ -30,7 +30,7 @@ import PromptLogModal from '@/app/components/base/prompt-log-modal' import { cn } from '@/utils/classnames' import Answer from './answer' import ChatInputArea from './chat-input-area' -import { ChatContextProvider } from './context' +import { ChatContextProvider } from './context-provider' import Question from './question' import TryToAsk from './try-to-ask' diff --git a/web/app/components/base/chat/embedded-chatbot/context.tsx b/web/app/components/base/chat/embedded-chatbot/context.ts similarity index 100% rename from web/app/components/base/chat/embedded-chatbot/context.tsx rename to web/app/components/base/chat/embedded-chatbot/context.ts diff --git a/web/app/components/base/chat/embedded-chatbot/hooks.tsx b/web/app/components/base/chat/embedded-chatbot/hooks.tsx index bffee78792..da142a69ec 100644 --- a/web/app/components/base/chat/embedded-chatbot/hooks.tsx +++ b/web/app/components/base/chat/embedded-chatbot/hooks.tsx @@ -21,7 +21,7 @@ import { useState, } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils' import { InputVarType } from '@/app/components/workflow/types' import { useWebAppStore } from '@/context/web-app-context' diff --git a/web/app/components/base/chat/embedded-chatbot/inputs-form/__tests__/content.spec.tsx b/web/app/components/base/chat/embedded-chatbot/inputs-form/__tests__/content.spec.tsx index 08c9a035e7..aad2d3d09b 100644 --- a/web/app/components/base/chat/embedded-chatbot/inputs-form/__tests__/content.spec.tsx +++ b/web/app/components/base/chat/embedded-chatbot/inputs-form/__tests__/content.spec.tsx @@ -16,7 +16,7 @@ vi.mock('next/navigation', () => ({ useSearchParams: () => new URLSearchParams(), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: vi.fn() }), })) diff --git a/web/app/components/base/confirm/index.tsx b/web/app/components/base/confirm/index.tsx index 704c94a9fe..c19fd3f625 100644 --- a/web/app/components/base/confirm/index.tsx +++ b/web/app/components/base/confirm/index.tsx @@ -1,3 +1,8 @@ +/** + * @deprecated Use `@/app/components/base/ui/alert-dialog` instead. + * See issue #32767 for migration details. + */ + import * as React from 'react' import { useEffect, useRef, useState } from 'react' import { createPortal } from 'react-dom' @@ -5,6 +10,7 @@ import { useTranslation } from 'react-i18next' import Button from '../button' import Tooltip from '../tooltip' +/** @deprecated Use `@/app/components/base/ui/alert-dialog` instead. */ export type IConfirm = { className?: string isShow: boolean diff --git a/web/app/components/base/features/new-feature-panel/moderation/__tests__/moderation-setting-modal.spec.tsx b/web/app/components/base/features/new-feature-panel/moderation/__tests__/moderation-setting-modal.spec.tsx index 3c690635da..88f74d2686 100644 --- a/web/app/components/base/features/new-feature-panel/moderation/__tests__/moderation-setting-modal.spec.tsx +++ b/web/app/components/base/features/new-feature-panel/moderation/__tests__/moderation-setting-modal.spec.tsx @@ -3,7 +3,7 @@ import { fireEvent, render, screen } from '@testing-library/react' import ModerationSettingModal from '../moderation-setting-modal' const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx b/web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx index 6760e000b1..ccb90fa229 100644 --- a/web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx +++ b/web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx @@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' import Modal from '@/app/components/base/modal' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import ApiBasedExtensionSelector from '@/app/components/header/account-setting/api-based-extension-page/selector' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { CustomConfigurationStatusEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' diff --git a/web/app/components/base/file-uploader/__tests__/hooks.spec.ts b/web/app/components/base/file-uploader/__tests__/hooks.spec.ts index 00c64224aa..8343974967 100644 --- a/web/app/components/base/file-uploader/__tests__/hooks.spec.ts +++ b/web/app/components/base/file-uploader/__tests__/hooks.spec.ts @@ -11,7 +11,7 @@ vi.mock('next/navigation', () => ({ })) // Exception: hook requires toast context that isn't available without a provider wrapper -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/base/file-uploader/hooks.ts b/web/app/components/base/file-uploader/hooks.ts index 14e46548d8..4aab60175c 100644 --- a/web/app/components/base/file-uploader/hooks.ts +++ b/web/app/components/base/file-uploader/hooks.ts @@ -18,7 +18,7 @@ import { MAX_FILE_UPLOAD_LIMIT, VIDEO_SIZE_LIMIT, } from '@/app/components/base/file-uploader/constants' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { SupportUploadFileTypes } from '@/app/components/workflow/types' import { uploadRemoteFileInfo } from '@/service/common' import { TransferMethod } from '@/types/app' diff --git a/web/app/components/base/form/hooks/__tests__/use-check-validated.spec.ts b/web/app/components/base/form/hooks/__tests__/use-check-validated.spec.ts index b9d60594f7..28eb5bd5ed 100644 --- a/web/app/components/base/form/hooks/__tests__/use-check-validated.spec.ts +++ b/web/app/components/base/form/hooks/__tests__/use-check-validated.spec.ts @@ -5,7 +5,7 @@ import { useCheckValidated } from '../use-check-validated' const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/base/form/hooks/use-check-validated.ts b/web/app/components/base/form/hooks/use-check-validated.ts index 7ed6164bb2..d186996035 100644 --- a/web/app/components/base/form/hooks/use-check-validated.ts +++ b/web/app/components/base/form/hooks/use-check-validated.ts @@ -1,7 +1,7 @@ import type { AnyFormApi } from '@tanstack/react-form' import type { FormSchema } from '@/app/components/base/form/types' import { useCallback } from 'react' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' export const useCheckValidated = (form: AnyFormApi, FormSchemas: FormSchema[]) => { const { notify } = useToastContext() diff --git a/web/app/components/base/image-uploader/__tests__/hooks.spec.ts b/web/app/components/base/image-uploader/__tests__/hooks.spec.ts index 4d150830d0..f79ea98081 100644 --- a/web/app/components/base/image-uploader/__tests__/hooks.spec.ts +++ b/web/app/components/base/image-uploader/__tests__/hooks.spec.ts @@ -5,7 +5,7 @@ import { Resolution, TransferMethod } from '@/types/app' import { useClipboardUploader, useDraggableUploader, useImageFiles, useLocalFileUploader } from '../hooks' const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/app/components/base/image-uploader/hooks.ts b/web/app/components/base/image-uploader/hooks.ts index cd309a1f7b..03cf0feeca 100644 --- a/web/app/components/base/image-uploader/hooks.ts +++ b/web/app/components/base/image-uploader/hooks.ts @@ -3,7 +3,7 @@ import type { ImageFile, VisionSettings } from '@/types/app' import { useParams } from 'next/navigation' import { useCallback, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { ALLOW_FILE_EXTENSIONS, TransferMethod } from '@/types/app' import { getImageUploadErrorMessage, imageUpload } from './utils' diff --git a/web/app/components/base/markdown-blocks/__tests__/button.spec.tsx b/web/app/components/base/markdown-blocks/__tests__/button.spec.tsx index 305896f4f1..95ed788db3 100644 --- a/web/app/components/base/markdown-blocks/__tests__/button.spec.tsx +++ b/web/app/components/base/markdown-blocks/__tests__/button.spec.tsx @@ -5,7 +5,7 @@ import userEvent from '@testing-library/user-event' // markdown-button.spec.tsx import * as React from 'react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { ChatContextProvider } from '@/app/components/base/chat/chat/context' +import { ChatContextProvider } from '@/app/components/base/chat/chat/context-provider' import MarkdownButton from '../button' diff --git a/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx b/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx index 2cd31f9a49..e8b956cbbf 100644 --- a/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx +++ b/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx @@ -1,6 +1,6 @@ import { act, render, screen } from '@testing-library/react' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' -import { ChatContextProvider } from '@/app/components/base/chat/chat/context' +import { ChatContextProvider } from '@/app/components/base/chat/chat/context-provider' import ThinkBlock from '../think-block' // Mock react-i18next diff --git a/web/app/components/base/markdown-blocks/think-block.stories.tsx b/web/app/components/base/markdown-blocks/think-block.stories.tsx index 23713fb263..7c3f809ee7 100644 --- a/web/app/components/base/markdown-blocks/think-block.stories.tsx +++ b/web/app/components/base/markdown-blocks/think-block.stories.tsx @@ -1,6 +1,6 @@ import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' -import { ChatContextProvider } from '@/app/components/base/chat/chat/context' +import { ChatContextProvider } from '@/app/components/base/chat/chat/context-provider' import ThinkBlock from './think-block' const THOUGHT_TEXT = ` diff --git a/web/app/components/base/modal/index.tsx b/web/app/components/base/modal/index.tsx index 5745a7a7be..2a5f647e6c 100644 --- a/web/app/components/base/modal/index.tsx +++ b/web/app/components/base/modal/index.tsx @@ -1,3 +1,8 @@ +/** + * @deprecated Use `@/app/components/base/ui/dialog` instead. + * This component will be removed after migration is complete. + * See: https://github.com/langgenius/dify/issues/32767 + */ import { Dialog, DialogPanel, DialogTitle, Transition, TransitionChild } from '@headlessui/react' import { noop } from 'es-toolkit/function' import { Fragment } from 'react' diff --git a/web/app/components/base/modal/modal.tsx b/web/app/components/base/modal/modal.tsx index e7138b01de..15e7eaf94a 100644 --- a/web/app/components/base/modal/modal.tsx +++ b/web/app/components/base/modal/modal.tsx @@ -1,3 +1,8 @@ +/** + * @deprecated Use `@/app/components/base/ui/dialog` instead. + * This component will be removed after migration is complete. + * See: https://github.com/langgenius/dify/issues/32767 + */ import type { ButtonProps } from '@/app/components/base/button' import { noop } from 'es-toolkit/function' import { memo } from 'react' diff --git a/web/app/components/base/portal-to-follow-elem/index.tsx b/web/app/components/base/portal-to-follow-elem/index.tsx index c57fba9dd0..7d4f6baa9b 100644 --- a/web/app/components/base/portal-to-follow-elem/index.tsx +++ b/web/app/components/base/portal-to-follow-elem/index.tsx @@ -1,4 +1,16 @@ 'use client' +/** + * @deprecated Use semantic overlay primitives from `@/app/components/base/ui/` instead. + * This component will be removed after migration is complete. + * See: https://github.com/langgenius/dify/issues/32767 + * + * Migration guide: + * - Tooltip → `@/app/components/base/ui/tooltip` + * - Menu/Dropdown → `@/app/components/base/ui/dropdown-menu` + * - Popover → `@/app/components/base/ui/popover` + * - Dialog/Modal → `@/app/components/base/ui/dialog` + * - Select → `@/app/components/base/ui/select` + */ import type { OffsetOptions, Placement } from '@floating-ui/react' import { autoUpdate, @@ -33,6 +45,7 @@ export type PortalToFollowElemOptions = { triggerPopupSameWidth?: boolean } +/** @deprecated Use semantic overlay primitives instead. See #32767. */ export function usePortalToFollowElem({ placement = 'bottom', open: controlledOpen, @@ -110,6 +123,7 @@ export function usePortalToFollowElemContext() { return context } +/** @deprecated Use semantic overlay primitives instead. See #32767. */ export function PortalToFollowElem({ children, ...options @@ -124,6 +138,7 @@ export function PortalToFollowElem({ ) } +/** @deprecated Use semantic overlay primitives instead. See #32767. */ export const PortalToFollowElemTrigger = ( { ref: propRef, @@ -164,6 +179,7 @@ export const PortalToFollowElemTrigger = ( } PortalToFollowElemTrigger.displayName = 'PortalToFollowElemTrigger' +/** @deprecated Use semantic overlay primitives instead. See #32767. */ export const PortalToFollowElemContent = ( { ref: propRef, diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx index 2deec561e9..6cc6c3a67f 100644 --- a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx @@ -28,7 +28,8 @@ import { import * as React from 'react' import { GeneratorType } from '@/app/components/app/configuration/config/automatic/types' import { VarType } from '@/app/components/workflow/types' -import { EventEmitterContextProvider, useEventEmitterContextContext } from '@/context/event-emitter' +import { useEventEmitterContextContext } from '@/context/event-emitter' +import { EventEmitterContextProvider } from '@/context/event-emitter-provider' import { INSERT_CONTEXT_BLOCK_COMMAND } from '../../context-block' import { INSERT_CURRENT_BLOCK_COMMAND } from '../../current-block' import { INSERT_ERROR_MESSAGE_BLOCK_COMMAND } from '../../error-message-block' diff --git a/web/app/components/base/radio/context/index.tsx b/web/app/components/base/radio/context/index.ts similarity index 100% rename from web/app/components/base/radio/context/index.tsx rename to web/app/components/base/radio/context/index.ts diff --git a/web/app/components/base/select/index.tsx b/web/app/components/base/select/index.tsx index ddfa800dbb..144629c380 100644 --- a/web/app/components/base/select/index.tsx +++ b/web/app/components/base/select/index.tsx @@ -1,4 +1,9 @@ 'use client' +/** + * @deprecated Use `@/app/components/base/ui/select` instead. + * This component will be removed after migration is complete. + * See: https://github.com/langgenius/dify/issues/32767 + */ import type { FC } from 'react' import { Combobox, ComboboxButton, ComboboxInput, ComboboxOption, ComboboxOptions, Listbox, ListboxButton, ListboxOption, ListboxOptions } from '@headlessui/react' import { ChevronDownIcon, ChevronUpIcon, XMarkIcon } from '@heroicons/react/20/solid' diff --git a/web/app/components/base/tag-input/__tests__/index.spec.tsx b/web/app/components/base/tag-input/__tests__/index.spec.tsx index b091d9cd03..f07399f8af 100644 --- a/web/app/components/base/tag-input/__tests__/index.spec.tsx +++ b/web/app/components/base/tag-input/__tests__/index.spec.tsx @@ -5,7 +5,7 @@ import TagInput from '../index' const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/base/tag-input/index.tsx b/web/app/components/base/tag-input/index.tsx index 1c49b026fb..377e68abe4 100644 --- a/web/app/components/base/tag-input/index.tsx +++ b/web/app/components/base/tag-input/index.tsx @@ -2,7 +2,7 @@ import type { ChangeEvent, FC, KeyboardEvent } from 'react' import { useCallback, useState } from 'react' import AutosizeInput from 'react-18-input-autosize' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { cn } from '@/utils/classnames' type TagInputProps = { diff --git a/web/app/components/base/tag-management/__tests__/panel.spec.tsx b/web/app/components/base/tag-management/__tests__/panel.spec.tsx index c91c72e583..cd9e37e286 100644 --- a/web/app/components/base/tag-management/__tests__/panel.spec.tsx +++ b/web/app/components/base/tag-management/__tests__/panel.spec.tsx @@ -3,7 +3,7 @@ import { render, screen, waitFor, within } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' import { act } from 'react' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Panel from '../panel' import { useStore as useTagStore } from '../store' diff --git a/web/app/components/base/tag-management/__tests__/selector.spec.tsx b/web/app/components/base/tag-management/__tests__/selector.spec.tsx index dc58ca37e6..43f17a1e8c 100644 --- a/web/app/components/base/tag-management/__tests__/selector.spec.tsx +++ b/web/app/components/base/tag-management/__tests__/selector.spec.tsx @@ -3,7 +3,7 @@ import { render, screen, waitFor, within } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' import { act } from 'react' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import TagSelector from '../selector' import { useStore as useTagStore } from '../store' diff --git a/web/app/components/base/tag-management/index.tsx b/web/app/components/base/tag-management/index.tsx index e9ce85ecc0..b7682bcdad 100644 --- a/web/app/components/base/tag-management/index.tsx +++ b/web/app/components/base/tag-management/index.tsx @@ -4,7 +4,7 @@ import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { createTag, fetchTagList, diff --git a/web/app/components/base/tag-management/panel.tsx b/web/app/components/base/tag-management/panel.tsx index 4174ba0476..cebed74f3b 100644 --- a/web/app/components/base/tag-management/panel.tsx +++ b/web/app/components/base/tag-management/panel.tsx @@ -10,7 +10,7 @@ import { useContext } from 'use-context-selector' import Checkbox from '@/app/components/base/checkbox' import Divider from '@/app/components/base/divider' import Input from '@/app/components/base/input' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { bindTag, createTag, unBindTag } from '@/service/tag' import { useStore as useTagStore } from './store' diff --git a/web/app/components/base/tag-management/tag-item-editor.tsx b/web/app/components/base/tag-management/tag-item-editor.tsx index a37e42dcce..3cff335f58 100644 --- a/web/app/components/base/tag-management/tag-item-editor.tsx +++ b/web/app/components/base/tag-management/tag-item-editor.tsx @@ -6,7 +6,7 @@ import { useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import Confirm from '@/app/components/base/confirm' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { deleteTag, diff --git a/web/app/components/base/text-generation/__tests__/hooks.spec.ts b/web/app/components/base/text-generation/__tests__/hooks.spec.ts index cab06f1c8a..a5b5578158 100644 --- a/web/app/components/base/text-generation/__tests__/hooks.spec.ts +++ b/web/app/components/base/text-generation/__tests__/hooks.spec.ts @@ -5,7 +5,7 @@ import { useTextGeneration } from '../hooks' const mockNotify = vi.fn() const mockSsePost = vi.fn<(url: string, fetchOptions: { body: Record }, otherOptions: IOtherOptions) => void>() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/base/text-generation/hooks.ts b/web/app/components/base/text-generation/hooks.ts index c5d008956b..4314a81925 100644 --- a/web/app/components/base/text-generation/hooks.ts +++ b/web/app/components/base/text-generation/hooks.ts @@ -1,6 +1,6 @@ import { useState } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { ssePost } from '@/service/base' export const useTextGeneration = () => { diff --git a/web/app/components/base/theme-switcher.tsx b/web/app/components/base/theme-switcher.tsx index 86e24a443c..58da8f4664 100644 --- a/web/app/components/base/theme-switcher.tsx +++ b/web/app/components/base/theme-switcher.tsx @@ -13,44 +13,50 @@ export default function ThemeSwitcher() { return (
-
handleThemeChange('system')} + aria-label="System theme" data-testid="system-theme-container" >
-
+
-
handleThemeChange('light')} + aria-label="Light theme" data-testid="light-theme-container" >
-
+
-
handleThemeChange('dark')} + aria-label="Dark theme" data-testid="dark-theme-container" >
-
+
) } diff --git a/web/app/components/base/toast/__tests__/index.spec.tsx b/web/app/components/base/toast/__tests__/index.spec.tsx index f526290fa1..2f5fa49823 100644 --- a/web/app/components/base/toast/__tests__/index.spec.tsx +++ b/web/app/components/base/toast/__tests__/index.spec.tsx @@ -2,7 +2,8 @@ import type { ReactNode } from 'react' import { act, render, screen, waitFor } from '@testing-library/react' import { noop } from 'es-toolkit/function' import * as React from 'react' -import Toast, { ToastProvider, useToastContext } from '..' +import Toast, { ToastProvider } from '..' +import { useToastContext } from '../context' const TestComponent = () => { const { notify, close } = useToastContext() diff --git a/web/app/components/base/toast/context.ts b/web/app/components/base/toast/context.ts new file mode 100644 index 0000000000..ddd8f91336 --- /dev/null +++ b/web/app/components/base/toast/context.ts @@ -0,0 +1,23 @@ +'use client' + +import type { ReactNode } from 'react' +import { createContext, useContext } from 'use-context-selector' + +export type IToastProps = { + type?: 'success' | 'error' | 'warning' | 'info' + size?: 'md' | 'sm' + duration?: number + message: string + children?: ReactNode + onClose?: () => void + className?: string + customComponent?: ReactNode +} + +type IToastContext = { + notify: (props: IToastProps) => void + close: () => void +} + +export const ToastContext = createContext({} as IToastContext) +export const useToastContext = () => useContext(ToastContext) diff --git a/web/app/components/base/toast/index.stories.tsx b/web/app/components/base/toast/index.stories.tsx index 4ab9138070..40d6fecfc2 100644 --- a/web/app/components/base/toast/index.stories.tsx +++ b/web/app/components/base/toast/index.stories.tsx @@ -1,6 +1,7 @@ import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useCallback } from 'react' -import Toast, { ToastProvider, useToastContext } from '.' +import Toast, { ToastProvider } from '.' +import { useToastContext } from './context' const ToastControls = () => { const { notify } = useToastContext() diff --git a/web/app/components/base/toast/index.tsx b/web/app/components/base/toast/index.tsx index 0346524e6f..a70a0db06c 100644 --- a/web/app/components/base/toast/index.tsx +++ b/web/app/components/base/toast/index.tsx @@ -1,5 +1,6 @@ 'use client' import type { ReactNode } from 'react' +import type { IToastProps } from './context' import { RiAlertFill, RiCheckboxCircleFill, @@ -11,31 +12,13 @@ import { noop } from 'es-toolkit/function' import * as React from 'react' import { useEffect, useState } from 'react' import { createRoot } from 'react-dom/client' -import { createContext, useContext } from 'use-context-selector' import ActionButton from '@/app/components/base/action-button' import { cn } from '@/utils/classnames' - -export type IToastProps = { - type?: 'success' | 'error' | 'warning' | 'info' - size?: 'md' | 'sm' - duration?: number - message: string - children?: ReactNode - onClose?: () => void - className?: string - customComponent?: ReactNode -} -type IToastContext = { - notify: (props: IToastProps) => void - close: () => void -} +import { ToastContext, useToastContext } from './context' export type ToastHandle = { clear?: VoidFunction } - -export const ToastContext = createContext({} as IToastContext) -export const useToastContext = () => useContext(ToastContext) const Toast = ({ type = 'info', size = 'md', @@ -183,3 +166,5 @@ Toast.notify = ({ } export default Toast + +export type { IToastProps } from './context' diff --git a/web/app/components/base/tooltip/index.tsx b/web/app/components/base/tooltip/index.tsx index fc0788d81f..7eb15b2c19 100644 --- a/web/app/components/base/tooltip/index.tsx +++ b/web/app/components/base/tooltip/index.tsx @@ -1,4 +1,9 @@ 'use client' +/** + * @deprecated Use `@/app/components/base/ui/tooltip` instead. + * This component will be removed after migration is complete. + * See: https://github.com/langgenius/dify/issues/32767 + */ import type { OffsetOptions, Placement } from '@floating-ui/react' import type { FC } from 'react' import { RiQuestionLine } from '@remixicon/react' diff --git a/web/app/components/base/ui/alert-dialog/__tests__/index.spec.tsx b/web/app/components/base/ui/alert-dialog/__tests__/index.spec.tsx new file mode 100644 index 0000000000..adbcb621c9 --- /dev/null +++ b/web/app/components/base/ui/alert-dialog/__tests__/index.spec.tsx @@ -0,0 +1,145 @@ +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import { + AlertDialog, + AlertDialogActions, + AlertDialogCancelButton, + AlertDialogClose, + AlertDialogConfirmButton, + AlertDialogContent, + AlertDialogDescription, + AlertDialogTitle, + AlertDialogTrigger, +} from '../index' + +describe('AlertDialog wrapper', () => { + describe('Rendering', () => { + it('should render alert dialog content when dialog is open', () => { + render( + + + Confirm Delete + This action cannot be undone. + + , + ) + + const dialog = screen.getByRole('alertdialog') + expect(dialog).toHaveTextContent('Confirm Delete') + expect(dialog).toHaveTextContent('This action cannot be undone.') + }) + + it('should not render content when dialog is closed', () => { + render( + + + Hidden Title + + , + ) + + expect(screen.queryByRole('alertdialog')).not.toBeInTheDocument() + }) + }) + + describe('Props', () => { + it('should apply custom className to popup', () => { + render( + + + Title + + , + ) + + const dialog = screen.getByRole('alertdialog') + expect(dialog).toHaveClass('custom-class') + }) + + it('should not render a close button by default', () => { + render( + + + Title + + , + ) + + expect(screen.queryByRole('button', { name: 'Close' })).not.toBeInTheDocument() + }) + }) + + describe('User Interactions', () => { + it('should open and close dialog when trigger and close are clicked', async () => { + render( + + Open Dialog + + Action Required + Please confirm the action. + Cancel + + , + ) + + expect(screen.queryByRole('alertdialog')).not.toBeInTheDocument() + + fireEvent.click(screen.getByRole('button', { name: 'Open Dialog' })) + expect(await screen.findByRole('alertdialog')).toHaveTextContent('Action Required') + + fireEvent.click(screen.getByRole('button', { name: 'Cancel' })) + await waitFor(() => { + expect(screen.queryByRole('alertdialog')).not.toBeInTheDocument() + }) + }) + }) + + describe('Composition Helpers', () => { + it('should render actions wrapper and default confirm button styles', () => { + render( + + + Action Required + + Confirm + + + , + ) + + expect(screen.getByTestId('actions')).toHaveClass('flex', 'items-start', 'justify-end', 'gap-2', 'self-stretch', 'p-6', 'custom-actions') + const confirmButton = screen.getByRole('button', { name: 'Confirm' }) + expect(confirmButton).toHaveClass('btn-primary') + expect(confirmButton).toHaveClass('btn-destructive') + }) + + it('should keep dialog open after confirm click and close via cancel helper', async () => { + const onConfirm = vi.fn() + + render( + + Open Dialog + + Action Required + + Cancel + Confirm + + + , + ) + + fireEvent.click(screen.getByRole('button', { name: 'Open Dialog' })) + expect(await screen.findByRole('alertdialog')).toBeInTheDocument() + + fireEvent.click(screen.getByRole('button', { name: 'Confirm' })) + expect(onConfirm).toHaveBeenCalledTimes(1) + expect(screen.getByRole('alertdialog')).toBeInTheDocument() + + fireEvent.click(screen.getByRole('button', { name: 'Cancel' })) + await waitFor(() => { + expect(screen.queryByRole('alertdialog')).not.toBeInTheDocument() + }) + }) + }) +}) diff --git a/web/app/components/base/ui/alert-dialog/index.tsx b/web/app/components/base/ui/alert-dialog/index.tsx new file mode 100644 index 0000000000..8d48c5b998 --- /dev/null +++ b/web/app/components/base/ui/alert-dialog/index.tsx @@ -0,0 +1,106 @@ +'use client' + +import type { ButtonProps } from '@/app/components/base/button' +import { AlertDialog as BaseAlertDialog } from '@base-ui/react/alert-dialog' +import * as React from 'react' +import Button from '@/app/components/base/button' +import { cn } from '@/utils/classnames' + +// z-index strategy (relies on root `isolation: isolate` in layout.tsx): +// All overlay primitives (Tooltip / Popover / Dropdown / Select / Dialog / AlertDialog) — z-50 +// Overlays share the same z-index; DOM order handles stacking when multiple are open. +// This ensures overlays inside an AlertDialog (e.g. a Tooltip on a dialog button) render +// above the dialog backdrop instead of being clipped by it. +// Toast — z-[99], always on top (defined in toast component) + +export const AlertDialog = BaseAlertDialog.Root +export const AlertDialogTrigger = BaseAlertDialog.Trigger +export const AlertDialogTitle = BaseAlertDialog.Title +export const AlertDialogDescription = BaseAlertDialog.Description +export const AlertDialogClose = BaseAlertDialog.Close + +type AlertDialogContentProps = { + children: React.ReactNode + className?: string + overlayClassName?: string + popupProps?: Omit, 'children' | 'className'> + backdropProps?: Omit, 'className'> +} + +export function AlertDialogContent({ + children, + className, + overlayClassName, + popupProps, + backdropProps, +}: AlertDialogContentProps) { + return ( + + + + {children} + + + ) +} + +type AlertDialogActionsProps = React.ComponentPropsWithoutRef<'div'> + +export function AlertDialogActions({ className, ...props }: AlertDialogActionsProps) { + return ( +
+ ) +} + +type AlertDialogCancelButtonProps = Omit & { + children: React.ReactNode + closeProps?: Omit, 'children' | 'render'> +} + +export function AlertDialogCancelButton({ + children, + closeProps, + ...buttonProps +}: AlertDialogCancelButtonProps) { + return ( + } + > + {children} + + ) +} + +type AlertDialogConfirmButtonProps = ButtonProps + +export function AlertDialogConfirmButton({ + variant = 'primary', + destructive = true, + ...props +}: AlertDialogConfirmButtonProps) { + return ( +
- {doc.name} + {doc.name} {doc.summary_index_status && (
@@ -113,7 +114,7 @@ const DocumentTableRow: FC = React.memo(({ className="cursor-pointer rounded-md p-1 hover:bg-state-base-hover" onClick={handleRenameClick} > - +
diff --git a/web/app/components/datasets/documents/components/document-list/components/sort-header.tsx b/web/app/components/datasets/documents/components/document-list/components/sort-header.tsx index 1dc13df2b0..1d693565cb 100644 --- a/web/app/components/datasets/documents/components/document-list/components/sort-header.tsx +++ b/web/app/components/datasets/documents/components/document-list/components/sort-header.tsx @@ -1,6 +1,5 @@ import type { FC } from 'react' import type { SortField, SortOrder } from '../hooks' -import { RiArrowDownLine } from '@remixicon/react' import * as React from 'react' import { cn } from '@/utils/classnames' @@ -23,19 +22,20 @@ const SortHeader: FC = React.memo(({ const isDesc = isActive && sortOrder === 'desc' return ( -
onSort(field)} > {label} - -
+ ) }) diff --git a/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-sort.spec.ts b/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-sort.spec.ts index 43bc0e1dd5..004597afa9 100644 --- a/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-sort.spec.ts +++ b/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-sort.spec.ts @@ -1,340 +1,98 @@ -import type { SimpleDocumentDetail } from '@/models/datasets' import { act, renderHook } from '@testing-library/react' -import { describe, expect, it } from 'vitest' +import { describe, expect, it, vi } from 'vitest' import { useDocumentSort } from '../use-document-sort' -type LocalDoc = SimpleDocumentDetail & { percent?: number } - -const createMockDocument = (overrides: Partial = {}): LocalDoc => ({ - id: 'doc1', - name: 'Test Document', - data_source_type: 'upload_file', - data_source_info: {}, - data_source_detail_dict: {}, - word_count: 100, - hit_count: 10, - created_at: 1000000, - position: 1, - doc_form: 'text_model', - enabled: true, - archived: false, - display_status: 'available', - created_from: 'api', - ...overrides, -} as LocalDoc) - describe('useDocumentSort', () => { - describe('initial state', () => { - it('should return null sortField initially', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: [], - statusFilterValue: '', - remoteSortValue: '', - }), - ) + describe('remote state parsing', () => { + it('should parse descending created_at sort', () => { + const onRemoteSortChange = vi.fn() + const { result } = renderHook(() => useDocumentSort({ + remoteSortValue: '-created_at', + onRemoteSortChange, + })) - expect(result.current.sortField).toBeNull() + expect(result.current.sortField).toBe('created_at') expect(result.current.sortOrder).toBe('desc') }) - it('should return documents unchanged when no sort is applied', () => { - const docs = [ - createMockDocument({ id: 'doc1', name: 'B' }), - createMockDocument({ id: 'doc2', name: 'A' }), - ] + it('should parse ascending hit_count sort', () => { + const onRemoteSortChange = vi.fn() + const { result } = renderHook(() => useDocumentSort({ + remoteSortValue: 'hit_count', + onRemoteSortChange, + })) - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '', - }), - ) + expect(result.current.sortField).toBe('hit_count') + expect(result.current.sortOrder).toBe('asc') + }) - expect(result.current.sortedDocuments).toEqual(docs) + it('should fallback to inactive field for unsupported sort key', () => { + const onRemoteSortChange = vi.fn() + const { result } = renderHook(() => useDocumentSort({ + remoteSortValue: '-name', + onRemoteSortChange, + })) + + expect(result.current.sortField).toBeNull() + expect(result.current.sortOrder).toBe('desc') }) }) describe('handleSort', () => { - it('should set sort field when called', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: [], - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - result.current.handleSort('name') - }) - - expect(result.current.sortField).toBe('name') - expect(result.current.sortOrder).toBe('desc') - }) - - it('should toggle sort order when same field is clicked twice', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: [], - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - result.current.handleSort('name') - }) - expect(result.current.sortOrder).toBe('desc') - - act(() => { - result.current.handleSort('name') - }) - expect(result.current.sortOrder).toBe('asc') - - act(() => { - result.current.handleSort('name') - }) - expect(result.current.sortOrder).toBe('desc') - }) - - it('should reset to desc when different field is selected', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: [], - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - result.current.handleSort('name') - }) - act(() => { - result.current.handleSort('name') - }) - expect(result.current.sortOrder).toBe('asc') - - act(() => { - result.current.handleSort('word_count') - }) - expect(result.current.sortField).toBe('word_count') - expect(result.current.sortOrder).toBe('desc') - }) - - it('should not change state when null is passed', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: [], - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - result.current.handleSort(null) - }) - - expect(result.current.sortField).toBeNull() - }) - }) - - describe('sorting documents', () => { - const docs = [ - createMockDocument({ id: 'doc1', name: 'Banana', word_count: 200, hit_count: 5, created_at: 3000 }), - createMockDocument({ id: 'doc2', name: 'Apple', word_count: 100, hit_count: 10, created_at: 1000 }), - createMockDocument({ id: 'doc3', name: 'Cherry', word_count: 300, hit_count: 1, created_at: 2000 }), - ] - - it('should sort by name descending', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - result.current.handleSort('name') - }) - - const names = result.current.sortedDocuments.map(d => d.name) - expect(names).toEqual(['Cherry', 'Banana', 'Apple']) - }) - - it('should sort by name ascending', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - result.current.handleSort('name') - }) - act(() => { - result.current.handleSort('name') - }) - - const names = result.current.sortedDocuments.map(d => d.name) - expect(names).toEqual(['Apple', 'Banana', 'Cherry']) - }) - - it('should sort by word_count descending', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - result.current.handleSort('word_count') - }) - - const counts = result.current.sortedDocuments.map(d => d.word_count) - expect(counts).toEqual([300, 200, 100]) - }) - - it('should sort by hit_count ascending', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '', - }), - ) + it('should switch to desc when selecting a different field', () => { + const onRemoteSortChange = vi.fn() + const { result } = renderHook(() => useDocumentSort({ + remoteSortValue: '-created_at', + onRemoteSortChange, + })) act(() => { result.current.handleSort('hit_count') }) + + expect(onRemoteSortChange).toHaveBeenCalledWith('-hit_count') + }) + + it('should toggle desc -> asc when clicking active field', () => { + const onRemoteSortChange = vi.fn() + const { result } = renderHook(() => useDocumentSort({ + remoteSortValue: '-hit_count', + onRemoteSortChange, + })) + act(() => { result.current.handleSort('hit_count') }) - const counts = result.current.sortedDocuments.map(d => d.hit_count) - expect(counts).toEqual([1, 5, 10]) + expect(onRemoteSortChange).toHaveBeenCalledWith('hit_count') }) - it('should sort by created_at descending', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '', - }), - ) + it('should toggle asc -> desc when clicking active field', () => { + const onRemoteSortChange = vi.fn() + const { result } = renderHook(() => useDocumentSort({ + remoteSortValue: 'created_at', + onRemoteSortChange, + })) act(() => { result.current.handleSort('created_at') }) - const times = result.current.sortedDocuments.map(d => d.created_at) - expect(times).toEqual([3000, 2000, 1000]) - }) - }) - - describe('status filtering', () => { - const docs = [ - createMockDocument({ id: 'doc1', display_status: 'available' }), - createMockDocument({ id: 'doc2', display_status: 'error' }), - createMockDocument({ id: 'doc3', display_status: 'available' }), - ] - - it('should not filter when statusFilterValue is empty', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - expect(result.current.sortedDocuments.length).toBe(3) + expect(onRemoteSortChange).toHaveBeenCalledWith('-created_at') }) - it('should not filter when statusFilterValue is all', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: 'all', - remoteSortValue: '', - }), - ) - - expect(result.current.sortedDocuments.length).toBe(3) - }) - }) - - describe('remoteSortValue reset', () => { - it('should reset sort state when remoteSortValue changes', () => { - const { result, rerender } = renderHook( - ({ remoteSortValue }) => - useDocumentSort({ - documents: [], - statusFilterValue: '', - remoteSortValue, - }), - { initialProps: { remoteSortValue: 'initial' } }, - ) + it('should ignore null field', () => { + const onRemoteSortChange = vi.fn() + const { result } = renderHook(() => useDocumentSort({ + remoteSortValue: '-created_at', + onRemoteSortChange, + })) act(() => { - result.current.handleSort('name') - }) - act(() => { - result.current.handleSort('name') - }) - expect(result.current.sortField).toBe('name') - expect(result.current.sortOrder).toBe('asc') - - rerender({ remoteSortValue: 'changed' }) - - expect(result.current.sortField).toBeNull() - expect(result.current.sortOrder).toBe('desc') - }) - }) - - describe('edge cases', () => { - it('should handle documents with missing values', () => { - const docs = [ - createMockDocument({ id: 'doc1', name: undefined as unknown as string, word_count: undefined }), - createMockDocument({ id: 'doc2', name: 'Test', word_count: 100 }), - ] - - const { result } = renderHook(() => - useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - result.current.handleSort('name') + result.current.handleSort(null) }) - expect(result.current.sortedDocuments.length).toBe(2) - }) - - it('should handle empty documents array', () => { - const { result } = renderHook(() => - useDocumentSort({ - documents: [], - statusFilterValue: '', - remoteSortValue: '', - }), - ) - - act(() => { - result.current.handleSort('name') - }) - - expect(result.current.sortedDocuments).toEqual([]) + expect(onRemoteSortChange).not.toHaveBeenCalled() }) }) }) diff --git a/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.ts b/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.ts index 98cf244f36..0e0b07db6f 100644 --- a/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.ts +++ b/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.ts @@ -1,102 +1,42 @@ -import type { SimpleDocumentDetail } from '@/models/datasets' -import { useCallback, useMemo, useRef, useState } from 'react' -import { normalizeStatusForQuery } from '@/app/components/datasets/documents/status-filter' +import { useCallback, useMemo } from 'react' -export type SortField = 'name' | 'word_count' | 'hit_count' | 'created_at' | null +type RemoteSortField = 'hit_count' | 'created_at' +const REMOTE_SORT_FIELDS = new Set(['hit_count', 'created_at']) + +export type SortField = RemoteSortField | null export type SortOrder = 'asc' | 'desc' -type LocalDoc = SimpleDocumentDetail & { percent?: number } - type UseDocumentSortOptions = { - documents: LocalDoc[] - statusFilterValue: string remoteSortValue: string + onRemoteSortChange: (nextSortValue: string) => void } export const useDocumentSort = ({ - documents, - statusFilterValue, remoteSortValue, + onRemoteSortChange, }: UseDocumentSortOptions) => { - const [sortField, setSortField] = useState(null) - const [sortOrder, setSortOrder] = useState('desc') - const prevRemoteSortValueRef = useRef(remoteSortValue) + const sortOrder: SortOrder = remoteSortValue.startsWith('-') ? 'desc' : 'asc' + const sortKey = remoteSortValue.startsWith('-') ? remoteSortValue.slice(1) : remoteSortValue - // Reset sort when remote sort changes - if (prevRemoteSortValueRef.current !== remoteSortValue) { - prevRemoteSortValueRef.current = remoteSortValue - setSortField(null) - setSortOrder('desc') - } + const sortField = useMemo(() => { + return REMOTE_SORT_FIELDS.has(sortKey as RemoteSortField) ? sortKey as RemoteSortField : null + }, [sortKey]) const handleSort = useCallback((field: SortField) => { - if (field === null) + if (!field) return if (sortField === field) { - setSortOrder(prev => prev === 'asc' ? 'desc' : 'asc') + const nextSortOrder = sortOrder === 'desc' ? 'asc' : 'desc' + onRemoteSortChange(nextSortOrder === 'desc' ? `-${field}` : field) + return } - else { - setSortField(field) - setSortOrder('desc') - } - }, [sortField]) - - const sortedDocuments = useMemo(() => { - let filteredDocs = documents - - if (statusFilterValue && statusFilterValue !== 'all') { - filteredDocs = filteredDocs.filter(doc => - typeof doc.display_status === 'string' - && normalizeStatusForQuery(doc.display_status) === statusFilterValue, - ) - } - - if (!sortField) - return filteredDocs - - const sortedDocs = [...filteredDocs].sort((a, b) => { - let aValue: string | number - let bValue: string | number - - switch (sortField) { - case 'name': - aValue = a.name?.toLowerCase() || '' - bValue = b.name?.toLowerCase() || '' - break - case 'word_count': - aValue = a.word_count || 0 - bValue = b.word_count || 0 - break - case 'hit_count': - aValue = a.hit_count || 0 - bValue = b.hit_count || 0 - break - case 'created_at': - aValue = a.created_at - bValue = b.created_at - break - default: - return 0 - } - - if (sortField === 'name') { - const result = (aValue as string).localeCompare(bValue as string) - return sortOrder === 'asc' ? result : -result - } - else { - const result = (aValue as number) - (bValue as number) - return sortOrder === 'asc' ? result : -result - } - }) - - return sortedDocs - }, [documents, sortField, sortOrder, statusFilterValue]) + onRemoteSortChange(`-${field}`) + }, [onRemoteSortChange, sortField, sortOrder]) return { sortField, sortOrder, handleSort, - sortedDocuments, } } diff --git a/web/app/components/datasets/documents/components/list.tsx b/web/app/components/datasets/documents/components/list.tsx index 3106f6c30b..e40e4c061b 100644 --- a/web/app/components/datasets/documents/components/list.tsx +++ b/web/app/components/datasets/documents/components/list.tsx @@ -14,7 +14,7 @@ import { useDatasetDetailContextWithSelector as useDatasetDetailContext } from ' import { ChunkingMode, DocumentActionType } from '@/models/datasets' import BatchAction from '../detail/completed/common/batch-action' import s from '../style.module.css' -import { DocumentTableRow, renderTdValue, SortHeader } from './document-list/components' +import { DocumentTableRow, SortHeader } from './document-list/components' import { useDocumentActions, useDocumentSelection, useDocumentSort } from './document-list/hooks' import RenameModal from './rename-modal' @@ -29,8 +29,8 @@ type DocumentListProps = { pagination: PaginationProps onUpdate: () => void onManageMetadata: () => void - statusFilterValue: string remoteSortValue: string + onSortChange: (value: string) => void } /** @@ -45,8 +45,8 @@ const DocumentList: FC = ({ pagination, onUpdate, onManageMetadata, - statusFilterValue, remoteSortValue, + onSortChange, }) => { const { t } = useTranslation() const datasetConfig = useDatasetDetailContext(s => s.dataset) @@ -55,10 +55,9 @@ const DocumentList: FC = ({ const isQAMode = chunkingMode === ChunkingMode.qa // Sorting - const { sortField, sortOrder, handleSort, sortedDocuments } = useDocumentSort({ - documents, - statusFilterValue, + const { sortField, sortOrder, handleSort } = useDocumentSort({ remoteSortValue, + onRemoteSortChange: onSortChange, }) // Selection @@ -71,7 +70,7 @@ const DocumentList: FC = ({ downloadableSelectedIds, clearSelection, } = useDocumentSelection({ - documents: sortedDocuments, + documents, selectedIds, onSelectedIdChange, }) @@ -135,24 +134,10 @@ const DocumentList: FC = ({
- + {t('list.table.header.fileName', { ns: 'datasetDocuments' })} {t('list.table.header.chunkingMode', { ns: 'datasetDocuments' })} - - - + {t('list.table.header.words', { ns: 'datasetDocuments' })} = ({ - {sortedDocuments.map((doc, index) => ( + {documents.map((doc, index) => ( = ({ } export default DocumentList - -export { renderTdValue } diff --git a/web/app/components/datasets/documents/components/operations.tsx b/web/app/components/datasets/documents/components/operations.tsx index 15c89a9b26..84e16c7c48 100644 --- a/web/app/components/datasets/documents/components/operations.tsx +++ b/web/app/components/datasets/documents/components/operations.tsx @@ -24,7 +24,7 @@ import Divider from '@/app/components/base/divider' import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge' import CustomPopover from '@/app/components/base/popover' import Switch from '@/app/components/base/switch' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { IS_CE_EDITION } from '@/config' import { DataSourceType, DocumentActionType } from '@/models/datasets' diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx index bc9ce04beb..efd1f2a483 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx @@ -9,7 +9,7 @@ const mockNotify = vi.fn() const mockClose = vi.fn() // Mock ToastContext with factory function -vi.mock('@/app/components/base/toast', async () => { +vi.mock('@/app/components/base/toast/context', async () => { const { createContext, useContext } = await import('use-context-selector') const context = createContext({ notify: mockNotify, close: mockClose }) return { @@ -87,7 +87,7 @@ vi.mock('@/service/base', () => ({ // Import after all mocks are set up const { useLocalFileUpload } = await import('../use-local-file-upload') -const { ToastContext } = await import('@/app/components/base/toast') +const { ToastContext } = await import('@/app/components/base/toast/context') const createWrapper = () => { return ({ children }: { children: ReactNode }) => ( diff --git a/web/app/components/datasets/documents/detail/__tests__/index.spec.tsx b/web/app/components/datasets/documents/detail/__tests__/index.spec.tsx index ad8741a8e1..f01a64e34e 100644 --- a/web/app/components/datasets/documents/detail/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/detail/__tests__/index.spec.tsx @@ -9,6 +9,7 @@ const mocks = vi.hoisted(() => { documentError: null as Error | null, documentMetadata: null as Record | null, media: 'desktop' as string, + searchParams: '' as string, } return { state, @@ -26,6 +27,7 @@ const mocks = vi.hoisted(() => { // --- External mocks --- vi.mock('next/navigation', () => ({ useRouter: () => ({ push: mocks.push }), + useSearchParams: () => new URLSearchParams(mocks.state.searchParams), })) vi.mock('@/hooks/use-breakpoints', () => ({ @@ -193,6 +195,7 @@ describe('DocumentDetail', () => { mocks.state.documentError = null mocks.state.documentMetadata = null mocks.state.media = 'desktop' + mocks.state.searchParams = '' }) afterEach(() => { @@ -286,15 +289,23 @@ describe('DocumentDetail', () => { }) it('should toggle metadata panel when button clicked', () => { - const { container } = render() + render() expect(screen.getByTestId('metadata')).toBeInTheDocument() - const svgs = container.querySelectorAll('svg') - const toggleBtn = svgs[svgs.length - 1].closest('button')! - fireEvent.click(toggleBtn) + fireEvent.click(screen.getByTestId('document-detail-metadata-toggle')) expect(screen.queryByTestId('metadata')).not.toBeInTheDocument() }) + it('should expose aria semantics for metadata toggle button', () => { + render() + const toggle = screen.getByTestId('document-detail-metadata-toggle') + expect(toggle).toHaveAttribute('aria-label') + expect(toggle).toHaveAttribute('aria-pressed', 'true') + + fireEvent.click(toggle) + expect(toggle).toHaveAttribute('aria-pressed', 'false') + }) + it('should pass correct props to Metadata', () => { render() const metadata = screen.getByTestId('metadata') @@ -305,20 +316,21 @@ describe('DocumentDetail', () => { describe('Navigation', () => { it('should navigate back when back button clicked', () => { - const { container } = render() - const backBtn = container.querySelector('svg')!.parentElement! - fireEvent.click(backBtn) + render() + fireEvent.click(screen.getByTestId('document-detail-back-button')) expect(mocks.push).toHaveBeenCalledWith('/datasets/ds-1/documents') }) + it('should expose aria label for back button', () => { + render() + expect(screen.getByTestId('document-detail-back-button')).toHaveAttribute('aria-label') + }) + it('should preserve query params when navigating back', () => { - const origLocation = window.location - window.history.pushState({}, '', '?page=2&status=active') - const { container } = render() - const backBtn = container.querySelector('svg')!.parentElement! - fireEvent.click(backBtn) + mocks.state.searchParams = 'page=2&status=active' + render() + fireEvent.click(screen.getByTestId('document-detail-back-button')) expect(mocks.push).toHaveBeenCalledWith('/datasets/ds-1/documents?page=2&status=active') - window.history.pushState({}, '', origLocation.href) }) }) diff --git a/web/app/components/datasets/documents/detail/batch-modal/__tests__/csv-uploader.spec.tsx b/web/app/components/datasets/documents/detail/batch-modal/__tests__/csv-uploader.spec.tsx index 7fb1de7cf9..6876753714 100644 --- a/web/app/components/datasets/documents/detail/batch-modal/__tests__/csv-uploader.spec.tsx +++ b/web/app/components/datasets/documents/detail/batch-modal/__tests__/csv-uploader.spec.tsx @@ -25,7 +25,7 @@ vi.mock('@/hooks/use-theme', () => ({ })) const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ ToastContext: { Provider: ({ children }: { children: ReactNode }) => children, Consumer: ({ children }: { children: (ctx: { notify: typeof mockNotify }) => ReactNode }) => children({ notify: mockNotify }), diff --git a/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx b/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx index 919895f104..93992a1aba 100644 --- a/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx +++ b/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx @@ -12,7 +12,7 @@ import Button from '@/app/components/base/button' import { getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils' import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files' import SimplePieChart from '@/app/components/base/simple-pie-chart' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import useTheme from '@/hooks/use-theme' import { upload } from '@/service/base' import { useFileUploadConfig } from '@/service/use-common' diff --git a/web/app/components/datasets/documents/detail/completed/__tests__/index.spec.tsx b/web/app/components/datasets/documents/detail/completed/__tests__/index.spec.tsx index 5802fb8b82..59ecbf5f25 100644 --- a/web/app/components/datasets/documents/detail/completed/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/detail/completed/__tests__/index.spec.tsx @@ -65,7 +65,7 @@ vi.mock('../../context', () => ({ }, })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ ToastContext: { Provider: ({ children }: { children: React.ReactNode }) => children, Consumer: () => null }, useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/app/components/datasets/documents/detail/completed/common/__tests__/regeneration-modal.spec.tsx b/web/app/components/datasets/documents/detail/completed/common/__tests__/regeneration-modal.spec.tsx index 719e2867b7..cc7f1aafa4 100644 --- a/web/app/components/datasets/documents/detail/completed/common/__tests__/regeneration-modal.spec.tsx +++ b/web/app/components/datasets/documents/detail/completed/common/__tests__/regeneration-modal.spec.tsx @@ -1,7 +1,8 @@ import type { ReactNode } from 'react' import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { EventEmitterContextProvider, useEventEmitterContextContext } from '@/context/event-emitter' +import { useEventEmitterContextContext } from '@/context/event-emitter' +import { EventEmitterContextProvider } from '@/context/event-emitter-provider' import RegenerationModal from '../regeneration-modal' // Store emit function for triggering events in tests diff --git a/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-child-segment-data.spec.ts b/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-child-segment-data.spec.ts index 83918a3f30..4cfb4d5927 100644 --- a/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-child-segment-data.spec.ts +++ b/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-child-segment-data.spec.ts @@ -59,7 +59,7 @@ vi.mock('../../../context', () => ({ }, })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-segment-list-data.spec.ts b/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-segment-list-data.spec.ts index aef2053298..f54c00e3e7 100644 --- a/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-segment-list-data.spec.ts +++ b/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-segment-list-data.spec.ts @@ -92,7 +92,7 @@ vi.mock('../../../context', () => ({ }, })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/app/components/datasets/documents/detail/completed/hooks/use-child-segment-data.ts b/web/app/components/datasets/documents/detail/completed/hooks/use-child-segment-data.ts index 4f4c6a532d..cdc8a0b22d 100644 --- a/web/app/components/datasets/documents/detail/completed/hooks/use-child-segment-data.ts +++ b/web/app/components/datasets/documents/detail/completed/hooks/use-child-segment-data.ts @@ -2,7 +2,7 @@ import type { ChildChunkDetail, ChildSegmentsResponse, SegmentDetailModel, Segme import { useQueryClient } from '@tanstack/react-query' import { useCallback, useEffect, useMemo, useRef } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { useEventEmitterContextContext } from '@/context/event-emitter' import { useChildSegmentList, diff --git a/web/app/components/datasets/documents/detail/completed/hooks/use-segment-list-data.ts b/web/app/components/datasets/documents/detail/completed/hooks/use-segment-list-data.ts index fd391d2864..aa91e9f464 100644 --- a/web/app/components/datasets/documents/detail/completed/hooks/use-segment-list-data.ts +++ b/web/app/components/datasets/documents/detail/completed/hooks/use-segment-list-data.ts @@ -4,7 +4,7 @@ import { useQueryClient } from '@tanstack/react-query' import { usePathname } from 'next/navigation' import { useCallback, useEffect, useMemo, useRef } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { useEventEmitterContextContext } from '@/context/event-emitter' import { ChunkingMode } from '@/models/datasets' import { diff --git a/web/app/components/datasets/documents/detail/completed/new-child-segment.tsx b/web/app/components/datasets/documents/detail/completed/new-child-segment.tsx index 19f27743b1..e28fb774fb 100644 --- a/web/app/components/datasets/documents/detail/completed/new-child-segment.tsx +++ b/web/app/components/datasets/documents/detail/completed/new-child-segment.tsx @@ -8,7 +8,7 @@ import { useContext } from 'use-context-selector' import { useShallow } from 'zustand/react/shallow' import { useStore as useAppStore } from '@/app/components/app/store' import Divider from '@/app/components/base/divider' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { ChunkingMode } from '@/models/datasets' import { useAddChildSegment } from '@/service/knowledge/use-segment' import { cn } from '@/utils/classnames' diff --git a/web/app/components/datasets/documents/detail/embedding/index.tsx b/web/app/components/datasets/documents/detail/embedding/index.tsx index e89a85c6de..bd344800db 100644 --- a/web/app/components/datasets/documents/detail/embedding/index.tsx +++ b/web/app/components/datasets/documents/detail/embedding/index.tsx @@ -5,7 +5,7 @@ import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { useProcessRule } from '@/service/knowledge/use-dataset' import { useDocumentContext } from '../context' import { ProgressBar, RuleDetail, SegmentProgress, StatusHeader } from './components' diff --git a/web/app/components/datasets/documents/detail/index.tsx b/web/app/components/datasets/documents/detail/index.tsx index e147bf9aba..b6842605c6 100644 --- a/web/app/components/datasets/documents/detail/index.tsx +++ b/web/app/components/datasets/documents/detail/index.tsx @@ -1,8 +1,7 @@ 'use client' import type { FC } from 'react' import type { DataSourceInfo, FileItem, FullDocumentDetail, LegacyDataSourceInfo } from '@/models/datasets' -import { RiArrowLeftLine, RiLayoutLeft2Line, RiLayoutRight2Line } from '@remixicon/react' -import { useRouter } from 'next/navigation' +import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -35,6 +34,7 @@ type DocumentDetailProps = { const DocumentDetail: FC = ({ datasetId, documentId }) => { const router = useRouter() + const searchParams = useSearchParams() const { t } = useTranslation() const media = useBreakpoints() @@ -98,11 +98,8 @@ const DocumentDetail: FC = ({ datasetId, documentId }) => { }) const backToPrev = () => { - // Preserve pagination and filter states when navigating back - const searchParams = new URLSearchParams(window.location.search) const queryString = searchParams.toString() - const separator = queryString ? '?' : '' - const backPath = `/datasets/${datasetId}/documents${separator}${queryString}` + const backPath = `/datasets/${datasetId}/documents${queryString ? `?${queryString}` : ''}` router.push(backPath) } @@ -152,6 +149,11 @@ const DocumentDetail: FC = ({ datasetId, documentId }) => { return chunkMode === ChunkingMode.parentChild && parentMode === 'full-doc' }, [documentDetail?.doc_form, parentMode]) + const backButtonLabel = t('operation.back', { ns: 'common' }) + const metadataToggleLabel = `${showMetadata + ? t('operation.close', { ns: 'common' }) + : t('operation.view', { ns: 'common' })} ${t('metadata.title', { ns: 'datasetDocuments' })}` + return ( = ({ datasetId, documentId }) => { >
-
- -
+ = ({ datasetId, documentId }) => { />
diff --git a/web/app/components/datasets/documents/detail/metadata/hooks/__tests__/use-metadata-state.spec.ts b/web/app/components/datasets/documents/detail/metadata/hooks/__tests__/use-metadata-state.spec.ts index ab1d45338f..3d7b28c78c 100644 --- a/web/app/components/datasets/documents/detail/metadata/hooks/__tests__/use-metadata-state.spec.ts +++ b/web/app/components/datasets/documents/detail/metadata/hooks/__tests__/use-metadata-state.spec.ts @@ -3,7 +3,7 @@ import type { FullDocumentDetail } from '@/models/datasets' import { act, renderHook } from '@testing-library/react' import * as React from 'react' import { describe, expect, it, vi } from 'vitest' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { useMetadataState } from '../use-metadata-state' diff --git a/web/app/components/datasets/documents/detail/metadata/hooks/use-metadata-state.ts b/web/app/components/datasets/documents/detail/metadata/hooks/use-metadata-state.ts index 08651b699e..f786609981 100644 --- a/web/app/components/datasets/documents/detail/metadata/hooks/use-metadata-state.ts +++ b/web/app/components/datasets/documents/detail/metadata/hooks/use-metadata-state.ts @@ -4,7 +4,7 @@ import type { DocType, FullDocumentDetail } from '@/models/datasets' import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { modifyDocMetadata } from '@/service/datasets' import { asyncRunSafe } from '@/utils' import { useDocumentContext } from '../../context' diff --git a/web/app/components/datasets/documents/detail/new-segment.tsx b/web/app/components/datasets/documents/detail/new-segment.tsx index bd35613e42..d2e27e9969 100644 --- a/web/app/components/datasets/documents/detail/new-segment.tsx +++ b/web/app/components/datasets/documents/detail/new-segment.tsx @@ -9,7 +9,7 @@ import { useContext } from 'use-context-selector' import { useShallow } from 'zustand/react/shallow' import { useStore as useAppStore } from '@/app/components/app/store' import Divider from '@/app/components/base/divider' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import ImageUploaderInChunk from '@/app/components/datasets/common/image-uploader/image-uploader-in-chunk' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { ChunkingMode } from '@/models/datasets' diff --git a/web/app/components/datasets/documents/hooks/__tests__/use-document-list-query-state.spec.ts b/web/app/components/datasets/documents/hooks/__tests__/use-document-list-query-state.spec.ts deleted file mode 100644 index e31d4ac547..0000000000 --- a/web/app/components/datasets/documents/hooks/__tests__/use-document-list-query-state.spec.ts +++ /dev/null @@ -1,439 +0,0 @@ -import type { DocumentListQuery } from '../use-document-list-query-state' -import { act, renderHook } from '@testing-library/react' - -import { beforeEach, describe, expect, it, vi } from 'vitest' -import useDocumentListQueryState from '../use-document-list-query-state' - -const mockPush = vi.fn() -const mockSearchParams = new URLSearchParams() - -vi.mock('@/models/datasets', () => ({ - DisplayStatusList: [ - 'queuing', - 'indexing', - 'paused', - 'error', - 'available', - 'enabled', - 'disabled', - 'archived', - ], -})) - -vi.mock('next/navigation', () => ({ - useRouter: () => ({ push: mockPush }), - usePathname: () => '/datasets/test-id/documents', - useSearchParams: () => mockSearchParams, -})) - -describe('useDocumentListQueryState', () => { - beforeEach(() => { - vi.clearAllMocks() - // Reset mock search params to empty - for (const key of [...mockSearchParams.keys()]) - mockSearchParams.delete(key) - }) - - // Tests for parseParams (exposed via the query property) - describe('parseParams (via query)', () => { - it('should return default query when no search params present', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query).toEqual({ - page: 1, - limit: 10, - keyword: '', - status: 'all', - sort: '-created_at', - }) - }) - - it('should parse page from search params', () => { - mockSearchParams.set('page', '3') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.page).toBe(3) - }) - - it('should default page to 1 when page is zero', () => { - mockSearchParams.set('page', '0') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.page).toBe(1) - }) - - it('should default page to 1 when page is negative', () => { - mockSearchParams.set('page', '-5') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.page).toBe(1) - }) - - it('should default page to 1 when page is NaN', () => { - mockSearchParams.set('page', 'abc') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.page).toBe(1) - }) - - it('should parse limit from search params', () => { - mockSearchParams.set('limit', '50') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.limit).toBe(50) - }) - - it('should default limit to 10 when limit is zero', () => { - mockSearchParams.set('limit', '0') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.limit).toBe(10) - }) - - it('should default limit to 10 when limit exceeds 100', () => { - mockSearchParams.set('limit', '101') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.limit).toBe(10) - }) - - it('should default limit to 10 when limit is negative', () => { - mockSearchParams.set('limit', '-1') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.limit).toBe(10) - }) - - it('should accept limit at boundary 100', () => { - mockSearchParams.set('limit', '100') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.limit).toBe(100) - }) - - it('should accept limit at boundary 1', () => { - mockSearchParams.set('limit', '1') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.limit).toBe(1) - }) - - it('should parse and decode keyword from search params', () => { - mockSearchParams.set('keyword', encodeURIComponent('hello world')) - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.keyword).toBe('hello world') - }) - - it('should return empty keyword when not present', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.keyword).toBe('') - }) - - it('should sanitize status from search params', () => { - mockSearchParams.set('status', 'available') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.status).toBe('available') - }) - - it('should fallback status to all for unknown status', () => { - mockSearchParams.set('status', 'badvalue') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.status).toBe('all') - }) - - it('should resolve active status alias to available', () => { - mockSearchParams.set('status', 'active') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.status).toBe('available') - }) - - it('should parse valid sort value from search params', () => { - mockSearchParams.set('sort', 'hit_count') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.sort).toBe('hit_count') - }) - - it('should default sort to -created_at for invalid sort value', () => { - mockSearchParams.set('sort', 'invalid_sort') - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.sort).toBe('-created_at') - }) - - it('should default sort to -created_at when not present', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.sort).toBe('-created_at') - }) - - it.each([ - '-created_at', - 'created_at', - '-hit_count', - 'hit_count', - ] as const)('should accept valid sort value %s', (sortValue) => { - mockSearchParams.set('sort', sortValue) - - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current.query.sort).toBe(sortValue) - }) - }) - - // Tests for updateQuery - describe('updateQuery', () => { - it('should call router.push with updated params when page is changed', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ page: 3 }) - }) - - expect(mockPush).toHaveBeenCalledTimes(1) - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).toContain('page=3') - }) - - it('should call router.push with scroll false', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ page: 2 }) - }) - - expect(mockPush).toHaveBeenCalledWith( - expect.any(String), - { scroll: false }, - ) - }) - - it('should set status in URL when status is not all', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ status: 'error' }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).toContain('status=error') - }) - - it('should not set status in URL when status is all', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ status: 'all' }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).not.toContain('status=') - }) - - it('should set sort in URL when sort is not the default -created_at', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ sort: 'hit_count' }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).toContain('sort=hit_count') - }) - - it('should not set sort in URL when sort is default -created_at', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ sort: '-created_at' }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).not.toContain('sort=') - }) - - it('should encode keyword in URL when keyword is provided', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ keyword: 'test query' }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - // Source code applies encodeURIComponent before setting in URLSearchParams - expect(pushedUrl).toContain('keyword=') - const params = new URLSearchParams(pushedUrl.split('?')[1]) - // params.get decodes one layer, but the value was pre-encoded with encodeURIComponent - expect(decodeURIComponent(params.get('keyword')!)).toBe('test query') - }) - - it('should remove keyword from URL when keyword is empty', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ keyword: '' }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).not.toContain('keyword=') - }) - - it('should sanitize invalid status to all and not include in URL', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ status: 'invalidstatus' }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).not.toContain('status=') - }) - - it('should sanitize invalid sort to -created_at and not include in URL', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ sort: 'invalidsort' as DocumentListQuery['sort'] }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).not.toContain('sort=') - }) - - it('should omit page and limit when they are default and no keyword', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ page: 1, limit: 10 }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).not.toContain('page=') - expect(pushedUrl).not.toContain('limit=') - }) - - it('should include page and limit when page is greater than 1', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ page: 2 }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).toContain('page=2') - expect(pushedUrl).toContain('limit=10') - }) - - it('should include page and limit when limit is non-default', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ limit: 25 }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).toContain('page=1') - expect(pushedUrl).toContain('limit=25') - }) - - it('should include page and limit when keyword is provided', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ keyword: 'search' }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).toContain('page=1') - expect(pushedUrl).toContain('limit=10') - }) - - it('should use pathname prefix in pushed URL', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({ page: 2 }) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).toMatch(/^\/datasets\/test-id\/documents/) - }) - - it('should push path without query string when all values are defaults', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.updateQuery({}) - }) - - const pushedUrl = mockPush.mock.calls[0][0] as string - expect(pushedUrl).toBe('/datasets/test-id/documents') - }) - }) - - // Tests for resetQuery - describe('resetQuery', () => { - it('should push URL with default query params when called', () => { - mockSearchParams.set('page', '5') - mockSearchParams.set('status', 'error') - - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.resetQuery() - }) - - expect(mockPush).toHaveBeenCalledTimes(1) - const pushedUrl = mockPush.mock.calls[0][0] as string - // Default query has all defaults, so no params should be in the URL - expect(pushedUrl).toBe('/datasets/test-id/documents') - }) - - it('should call router.push with scroll false when resetting', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - act(() => { - result.current.resetQuery() - }) - - expect(mockPush).toHaveBeenCalledWith( - expect.any(String), - { scroll: false }, - ) - }) - }) - - // Tests for return value stability - describe('return value', () => { - it('should return query, updateQuery, and resetQuery', () => { - const { result } = renderHook(() => useDocumentListQueryState()) - - expect(result.current).toHaveProperty('query') - expect(result.current).toHaveProperty('updateQuery') - expect(result.current).toHaveProperty('resetQuery') - expect(typeof result.current.updateQuery).toBe('function') - expect(typeof result.current.resetQuery).toBe('function') - }) - }) -}) diff --git a/web/app/components/datasets/documents/hooks/__tests__/use-document-list-query-state.spec.tsx b/web/app/components/datasets/documents/hooks/__tests__/use-document-list-query-state.spec.tsx new file mode 100644 index 0000000000..5879e72782 --- /dev/null +++ b/web/app/components/datasets/documents/hooks/__tests__/use-document-list-query-state.spec.tsx @@ -0,0 +1,426 @@ +import type { DocumentListQuery } from '../use-document-list-query-state' +import { act, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { renderHookWithNuqs } from '@/test/nuqs-testing' +import { useDocumentListQueryState } from '../use-document-list-query-state' + +vi.mock('@/models/datasets', () => ({ + DisplayStatusList: [ + 'queuing', + 'indexing', + 'paused', + 'error', + 'available', + 'enabled', + 'disabled', + 'archived', + ], +})) + +const renderWithAdapter = (searchParams = '') => { + return renderHookWithNuqs(() => useDocumentListQueryState(), { searchParams }) +} + +describe('useDocumentListQueryState', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('query parsing', () => { + it('should return default query when no search params present', () => { + const { result } = renderWithAdapter() + + expect(result.current.query).toEqual({ + page: 1, + limit: 10, + keyword: '', + status: 'all', + sort: '-created_at', + }) + }) + + it('should parse page from search params', () => { + const { result } = renderWithAdapter('?page=3') + expect(result.current.query.page).toBe(3) + }) + + it('should default page to 1 when page is zero', () => { + const { result } = renderWithAdapter('?page=0') + expect(result.current.query.page).toBe(1) + }) + + it('should default page to 1 when page is negative', () => { + const { result } = renderWithAdapter('?page=-5') + expect(result.current.query.page).toBe(1) + }) + + it('should default page to 1 when page is NaN', () => { + const { result } = renderWithAdapter('?page=abc') + expect(result.current.query.page).toBe(1) + }) + + it('should parse limit from search params', () => { + const { result } = renderWithAdapter('?limit=50') + expect(result.current.query.limit).toBe(50) + }) + + it('should default limit to 10 when limit is zero', () => { + const { result } = renderWithAdapter('?limit=0') + expect(result.current.query.limit).toBe(10) + }) + + it('should default limit to 10 when limit exceeds 100', () => { + const { result } = renderWithAdapter('?limit=101') + expect(result.current.query.limit).toBe(10) + }) + + it('should default limit to 10 when limit is negative', () => { + const { result } = renderWithAdapter('?limit=-1') + expect(result.current.query.limit).toBe(10) + }) + + it('should accept limit at boundary 100', () => { + const { result } = renderWithAdapter('?limit=100') + expect(result.current.query.limit).toBe(100) + }) + + it('should accept limit at boundary 1', () => { + const { result } = renderWithAdapter('?limit=1') + expect(result.current.query.limit).toBe(1) + }) + + it('should parse keyword from search params', () => { + const { result } = renderWithAdapter('?keyword=hello+world') + expect(result.current.query.keyword).toBe('hello world') + }) + + it('should preserve legacy double-encoded keyword text after URL decoding', () => { + const { result } = renderWithAdapter('?keyword=test%2520query') + expect(result.current.query.keyword).toBe('test%20query') + }) + + it('should return empty keyword when not present', () => { + const { result } = renderWithAdapter() + expect(result.current.query.keyword).toBe('') + }) + + it('should sanitize status from search params', () => { + const { result } = renderWithAdapter('?status=available') + expect(result.current.query.status).toBe('available') + }) + + it('should fallback status to all for unknown status', () => { + const { result } = renderWithAdapter('?status=badvalue') + expect(result.current.query.status).toBe('all') + }) + + it('should resolve active status alias to available', () => { + const { result } = renderWithAdapter('?status=active') + expect(result.current.query.status).toBe('available') + }) + + it('should parse valid sort value from search params', () => { + const { result } = renderWithAdapter('?sort=hit_count') + expect(result.current.query.sort).toBe('hit_count') + }) + + it('should default sort to -created_at for invalid sort value', () => { + const { result } = renderWithAdapter('?sort=invalid_sort') + expect(result.current.query.sort).toBe('-created_at') + }) + + it('should default sort to -created_at when not present', () => { + const { result } = renderWithAdapter() + expect(result.current.query.sort).toBe('-created_at') + }) + + it.each([ + '-created_at', + 'created_at', + '-hit_count', + 'hit_count', + ] as const)('should accept valid sort value %s', (sortValue) => { + const { result } = renderWithAdapter(`?sort=${sortValue}`) + expect(result.current.query.sort).toBe(sortValue) + }) + }) + + describe('updateQuery', () => { + it('should update page in state when page is changed', () => { + const { result } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ page: 3 }) + }) + + expect(result.current.query.page).toBe(3) + }) + + it('should sync page to URL with push history', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ page: 2 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.get('page')).toBe('2') + expect(update.options.history).toBe('push') + }) + + it('should set status in URL when status is not all', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ status: 'error' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.get('status')).toBe('error') + }) + + it('should not set status in URL when status is all', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ status: 'all' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('status')).toBe(false) + }) + + it('should set sort in URL when sort is not the default -created_at', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ sort: 'hit_count' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.get('sort')).toBe('hit_count') + }) + + it('should not set sort in URL when sort is default -created_at', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ sort: '-created_at' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('sort')).toBe(false) + }) + + it('should set keyword in URL when keyword is provided', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ keyword: 'test query' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.get('keyword')).toBe('test query') + expect(update.options.history).toBe('replace') + }) + + it('should use replace history when keyword update also resets page', async () => { + const { result, onUrlUpdate } = renderWithAdapter('?page=3') + + act(() => { + result.current.updateQuery({ keyword: 'hello', page: 1 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.get('keyword')).toBe('hello') + expect(update.searchParams.has('page')).toBe(false) + expect(update.options.history).toBe('replace') + }) + + it('should remove keyword from URL when keyword is empty', async () => { + const { result, onUrlUpdate } = renderWithAdapter('?keyword=existing') + + act(() => { + result.current.updateQuery({ keyword: '' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('keyword')).toBe(false) + expect(update.options.history).toBe('replace') + }) + + it('should remove keyword from URL when keyword contains only whitespace', async () => { + const { result, onUrlUpdate } = renderWithAdapter('?keyword=existing') + + act(() => { + result.current.updateQuery({ keyword: ' ' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('keyword')).toBe(false) + expect(result.current.query.keyword).toBe('') + }) + + it('should preserve literal percent-encoded-like keyword values', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ keyword: '%2F' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.get('keyword')).toBe('%2F') + expect(result.current.query.keyword).toBe('%2F') + }) + + it('should keep keyword text unchanged when updating query from legacy URL', async () => { + const { result, onUrlUpdate } = renderWithAdapter('?keyword=test%2520query') + + act(() => { + result.current.updateQuery({ page: 2 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + expect(result.current.query.keyword).toBe('test%20query') + }) + + it('should sanitize invalid status to all and not include in URL', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ status: 'invalidstatus' }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('status')).toBe(false) + }) + + it('should sanitize invalid sort to -created_at and not include in URL', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ sort: 'invalidsort' as DocumentListQuery['sort'] }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('sort')).toBe(false) + }) + + it('should not include page in URL when page is default', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ page: 1 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('page')).toBe(false) + }) + + it('should include page in URL when page is greater than 1', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ page: 2 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.get('page')).toBe('2') + }) + + it('should include limit in URL when limit is non-default', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ limit: 25 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.get('limit')).toBe('25') + }) + + it('should sanitize invalid page to default and omit page from URL', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ page: -1 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('page')).toBe(false) + expect(result.current.query.page).toBe(1) + }) + + it('should sanitize invalid limit to default and omit limit from URL', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.updateQuery({ limit: 999 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('limit')).toBe(false) + expect(result.current.query.limit).toBe(10) + }) + }) + + describe('resetQuery', () => { + it('should reset all values to defaults', () => { + const { result } = renderWithAdapter('?page=5&status=error&sort=hit_count') + + act(() => { + result.current.resetQuery() + }) + + expect(result.current.query).toEqual({ + page: 1, + limit: 10, + keyword: '', + status: 'all', + sort: '-created_at', + }) + }) + + it('should clear all params from URL when called', async () => { + const { result, onUrlUpdate } = renderWithAdapter('?page=5&status=error') + + act(() => { + result.current.resetQuery() + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('page')).toBe(false) + expect(update.searchParams.has('status')).toBe(false) + }) + }) + + describe('return value', () => { + it('should return query, updateQuery, and resetQuery', () => { + const { result } = renderWithAdapter() + + expect(result.current).toHaveProperty('query') + expect(result.current).toHaveProperty('updateQuery') + expect(result.current).toHaveProperty('resetQuery') + expect(typeof result.current.updateQuery).toBe('function') + expect(typeof result.current.resetQuery).toBe('function') + }) + }) +}) diff --git a/web/app/components/datasets/documents/hooks/__tests__/use-documents-page-state.spec.ts b/web/app/components/datasets/documents/hooks/__tests__/use-documents-page-state.spec.ts index 34911e9e9c..e0dbee6660 100644 --- a/web/app/components/datasets/documents/hooks/__tests__/use-documents-page-state.spec.ts +++ b/web/app/components/datasets/documents/hooks/__tests__/use-documents-page-state.spec.ts @@ -1,12 +1,10 @@ import type { DocumentListQuery } from '../use-document-list-query-state' -import type { DocumentListResponse } from '@/models/datasets' import { act, renderHook } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import { useDocumentsPageState } from '../use-documents-page-state' const mockUpdateQuery = vi.fn() -const mockResetQuery = vi.fn() let mockQuery: DocumentListQuery = { page: 1, limit: 10, keyword: '', status: 'all', sort: '-created_at' } vi.mock('@/models/datasets', () => ({ @@ -22,151 +20,70 @@ vi.mock('@/models/datasets', () => ({ ], })) -vi.mock('next/navigation', () => ({ - useRouter: () => ({ push: vi.fn() }), - usePathname: () => '/datasets/test-id/documents', - useSearchParams: () => new URLSearchParams(), -})) - -// Mock ahooks debounce utilities: required because tests capture the debounce -// callback reference to invoke it synchronously, bypassing real timer delays. -let capturedDebounceFnCallback: (() => void) | null = null - vi.mock('ahooks', () => ({ useDebounce: (value: unknown, _options?: { wait?: number }) => value, - useDebounceFn: (fn: () => void, _options?: { wait?: number }) => { - capturedDebounceFnCallback = fn - return { run: fn, cancel: vi.fn(), flush: vi.fn() } - }, })) -// Mock the dependent hook -vi.mock('../use-document-list-query-state', () => ({ - default: () => ({ - query: mockQuery, - updateQuery: mockUpdateQuery, - resetQuery: mockResetQuery, - }), -})) - -// Factory for creating DocumentListResponse test data -function createDocumentListResponse(overrides: Partial = {}): DocumentListResponse { +vi.mock('../use-document-list-query-state', async () => { + const React = await import('react') return { - data: [], - has_more: false, - total: 0, - page: 1, - limit: 10, - ...overrides, + useDocumentListQueryState: () => { + const [query, setQuery] = React.useState(mockQuery) + return { + query, + updateQuery: (updates: Partial) => { + mockUpdateQuery(updates) + setQuery(prev => ({ ...prev, ...updates })) + }, + } + }, } -} - -// Factory for creating a minimal document item -function createDocumentItem(overrides: Record = {}) { - return { - id: `doc-${Math.random().toString(36).slice(2, 8)}`, - name: 'test-doc.txt', - indexing_status: 'completed' as string, - display_status: 'available' as string, - enabled: true, - archived: false, - word_count: 100, - created_at: Date.now(), - updated_at: Date.now(), - created_from: 'web' as const, - created_by: 'user-1', - dataset_process_rule_id: 'rule-1', - doc_form: 'text_model' as const, - doc_language: 'en', - position: 1, - data_source_type: 'upload_file', - ...overrides, - } -} +}) describe('useDocumentsPageState', () => { beforeEach(() => { vi.clearAllMocks() - capturedDebounceFnCallback = null mockQuery = { page: 1, limit: 10, keyword: '', status: 'all', sort: '-created_at' } }) // Initial state verification describe('initial state', () => { - it('should return correct initial search state', () => { + it('should return correct initial query-derived state', () => { const { result } = renderHook(() => useDocumentsPageState()) expect(result.current.inputValue).toBe('') - expect(result.current.searchValue).toBe('') expect(result.current.debouncedSearchValue).toBe('') - }) - - it('should return correct initial filter and sort state', () => { - const { result } = renderHook(() => useDocumentsPageState()) - expect(result.current.statusFilterValue).toBe('all') expect(result.current.sortValue).toBe('-created_at') expect(result.current.normalizedStatusFilterValue).toBe('all') - }) - - it('should return correct initial pagination state', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - // page is query.page - 1 = 0 expect(result.current.currPage).toBe(0) expect(result.current.limit).toBe(10) - }) - - it('should return correct initial selection state', () => { - const { result } = renderHook(() => useDocumentsPageState()) - expect(result.current.selectedIds).toEqual([]) }) - it('should return correct initial polling state', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - expect(result.current.timerCanRun).toBe(true) - }) - - it('should initialize from query when query has keyword', () => { - mockQuery = { ...mockQuery, keyword: 'initial search' } + it('should initialize from non-default query values', () => { + mockQuery = { + page: 3, + limit: 25, + keyword: 'initial', + status: 'enabled', + sort: 'hit_count', + } const { result } = renderHook(() => useDocumentsPageState()) - expect(result.current.inputValue).toBe('initial search') - expect(result.current.searchValue).toBe('initial search') - }) - - it('should initialize pagination from query with non-default page', () => { - mockQuery = { ...mockQuery, page: 3, limit: 25 } - - const { result } = renderHook(() => useDocumentsPageState()) - - expect(result.current.currPage).toBe(2) // page - 1 + expect(result.current.inputValue).toBe('initial') + expect(result.current.currPage).toBe(2) expect(result.current.limit).toBe(25) - }) - - it('should initialize status filter from query', () => { - mockQuery = { ...mockQuery, status: 'error' } - - const { result } = renderHook(() => useDocumentsPageState()) - - expect(result.current.statusFilterValue).toBe('error') - }) - - it('should initialize sort from query', () => { - mockQuery = { ...mockQuery, sort: 'hit_count' } - - const { result } = renderHook(() => useDocumentsPageState()) - + expect(result.current.statusFilterValue).toBe('enabled') + expect(result.current.normalizedStatusFilterValue).toBe('available') expect(result.current.sortValue).toBe('hit_count') }) }) // Handler behaviors describe('handleInputChange', () => { - it('should update input value when called', () => { + it('should update keyword and reset page', () => { const { result } = renderHook(() => useDocumentsPageState()) act(() => { @@ -174,30 +91,59 @@ describe('useDocumentsPageState', () => { }) expect(result.current.inputValue).toBe('new value') + expect(mockUpdateQuery).toHaveBeenCalledWith({ keyword: 'new value', page: 1 }) }) - it('should trigger debounced search callback when called', () => { + it('should clear selected ids when keyword changes', () => { const { result } = renderHook(() => useDocumentsPageState()) - // First call sets inputValue and triggers the debounced fn act(() => { - result.current.handleInputChange('search term') + result.current.setSelectedIds(['doc-1']) + }) + expect(result.current.selectedIds).toEqual(['doc-1']) + + act(() => { + result.current.handleInputChange('keyword') }) - // The debounced fn captures inputValue from its render closure. - // After re-render with new inputValue, calling the captured callback again - // should reflect the updated state. + expect(result.current.selectedIds).toEqual([]) + }) + + it('should keep selected ids when keyword is unchanged', () => { + mockQuery = { ...mockQuery, keyword: 'same' } + const { result } = renderHook(() => useDocumentsPageState()) + act(() => { - if (capturedDebounceFnCallback) - capturedDebounceFnCallback() + result.current.setSelectedIds(['doc-1']) }) - expect(result.current.searchValue).toBe('search term') + act(() => { + result.current.handleInputChange('same') + }) + + expect(result.current.selectedIds).toEqual(['doc-1']) + expect(mockUpdateQuery).toHaveBeenCalledWith({ keyword: 'same', page: 1 }) }) }) describe('handleStatusFilterChange', () => { - it('should update status filter value when called with valid status', () => { + it('should sanitize status, reset page, and clear selection', () => { + const { result } = renderHook(() => useDocumentsPageState()) + + act(() => { + result.current.setSelectedIds(['doc-1']) + }) + + act(() => { + result.current.handleStatusFilterChange('invalid') + }) + + expect(result.current.statusFilterValue).toBe('all') + expect(result.current.selectedIds).toEqual([]) + expect(mockUpdateQuery).toHaveBeenCalledWith({ status: 'all', page: 1 }) + }) + + it('should update to valid status value', () => { const { result } = renderHook(() => useDocumentsPageState()) act(() => { @@ -205,61 +151,23 @@ describe('useDocumentsPageState', () => { }) expect(result.current.statusFilterValue).toBe('error') - }) - - it('should reset page to 0 when status filter changes', () => { - mockQuery = { ...mockQuery, page: 3 } - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('error') - }) - - expect(result.current.currPage).toBe(0) - }) - - it('should call updateQuery with sanitized status and page 1', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('error') - }) - expect(mockUpdateQuery).toHaveBeenCalledWith({ status: 'error', page: 1 }) }) - - it('should sanitize invalid status to all', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('invalid') - }) - - expect(result.current.statusFilterValue).toBe('all') - expect(mockUpdateQuery).toHaveBeenCalledWith({ status: 'all', page: 1 }) - }) }) describe('handleStatusFilterClear', () => { - it('should set status to all and reset page when status is not all', () => { + it('should reset status to all when status is not all', () => { + mockQuery = { ...mockQuery, status: 'error' } const { result } = renderHook(() => useDocumentsPageState()) - // First set a non-all status - act(() => { - result.current.handleStatusFilterChange('error') - }) - vi.clearAllMocks() - - // Then clear act(() => { result.current.handleStatusFilterClear() }) - expect(result.current.statusFilterValue).toBe('all') expect(mockUpdateQuery).toHaveBeenCalledWith({ status: 'all', page: 1 }) }) - it('should not call updateQuery when status is already all', () => { + it('should do nothing when status is already all', () => { const { result } = renderHook(() => useDocumentsPageState()) act(() => { @@ -271,7 +179,7 @@ describe('useDocumentsPageState', () => { }) describe('handleSortChange', () => { - it('should update sort value and call updateQuery when value changes', () => { + it('should update sort and reset page when sort changes', () => { const { result } = renderHook(() => useDocumentsPageState()) act(() => { @@ -282,18 +190,7 @@ describe('useDocumentsPageState', () => { expect(mockUpdateQuery).toHaveBeenCalledWith({ sort: 'hit_count', page: 1 }) }) - it('should reset page to 0 when sort changes', () => { - mockQuery = { ...mockQuery, page: 5 } - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleSortChange('hit_count') - }) - - expect(result.current.currPage).toBe(0) - }) - - it('should not call updateQuery when sort value is same as current', () => { + it('should ignore sort update when value is unchanged', () => { const { result } = renderHook(() => useDocumentsPageState()) act(() => { @@ -304,8 +201,8 @@ describe('useDocumentsPageState', () => { }) }) - describe('handlePageChange', () => { - it('should update current page and call updateQuery', () => { + describe('pagination handlers', () => { + it('should update page with one-based value', () => { const { result } = renderHook(() => useDocumentsPageState()) act(() => { @@ -313,23 +210,10 @@ describe('useDocumentsPageState', () => { }) expect(result.current.currPage).toBe(2) - expect(mockUpdateQuery).toHaveBeenCalledWith({ page: 3 }) // newPage + 1 + expect(mockUpdateQuery).toHaveBeenCalledWith({ page: 3 }) }) - it('should handle page 0 (first page)', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handlePageChange(0) - }) - - expect(result.current.currPage).toBe(0) - expect(mockUpdateQuery).toHaveBeenCalledWith({ page: 1 }) - }) - }) - - describe('handleLimitChange', () => { - it('should update limit, reset page to 0, and call updateQuery', () => { + it('should update limit and reset page', () => { const { result } = renderHook(() => useDocumentsPageState()) act(() => { @@ -342,359 +226,29 @@ describe('useDocumentsPageState', () => { }) }) - // Selection state - describe('selection state', () => { - it('should update selectedIds via setSelectedIds', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.setSelectedIds(['doc-1', 'doc-2']) - }) - - expect(result.current.selectedIds).toEqual(['doc-1', 'doc-2']) - }) - }) - - // Polling state management - describe('updatePollingState', () => { - it('should not update timer when documentsRes is undefined', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.updatePollingState(undefined) - }) - - // timerCanRun remains true (initial value) - expect(result.current.timerCanRun).toBe(true) - }) - - it('should not update timer when documentsRes.data is undefined', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.updatePollingState({ data: undefined } as unknown as DocumentListResponse) - }) - - expect(result.current.timerCanRun).toBe(true) - }) - - it('should set timerCanRun to false when all documents are completed and status filter is all', () => { - const response = createDocumentListResponse({ - data: [ - createDocumentItem({ indexing_status: 'completed' }), - createDocumentItem({ indexing_status: 'completed' }), - ] as DocumentListResponse['data'], - total: 2, - }) - - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.updatePollingState(response) - }) - - expect(result.current.timerCanRun).toBe(false) - }) - - it('should set timerCanRun to true when some documents are not completed', () => { - const response = createDocumentListResponse({ - data: [ - createDocumentItem({ indexing_status: 'completed' }), - createDocumentItem({ indexing_status: 'indexing' }), - ] as DocumentListResponse['data'], - total: 2, - }) - - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.updatePollingState(response) - }) - - expect(result.current.timerCanRun).toBe(true) - }) - - it('should count paused documents as completed for polling purposes', () => { - const response = createDocumentListResponse({ - data: [ - createDocumentItem({ indexing_status: 'paused' }), - createDocumentItem({ indexing_status: 'completed' }), - ] as DocumentListResponse['data'], - total: 2, - }) - - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.updatePollingState(response) - }) - - // All docs are "embedded" (completed, paused, error), so hasIncomplete = false - // statusFilter is 'all', so shouldForcePolling = false - expect(result.current.timerCanRun).toBe(false) - }) - - it('should count error documents as completed for polling purposes', () => { - const response = createDocumentListResponse({ - data: [ - createDocumentItem({ indexing_status: 'error' }), - createDocumentItem({ indexing_status: 'completed' }), - ] as DocumentListResponse['data'], - total: 2, - }) - - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.updatePollingState(response) - }) - - expect(result.current.timerCanRun).toBe(false) - }) - - it('should force polling when status filter is a transient status (queuing)', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - // Set status filter to queuing - act(() => { - result.current.handleStatusFilterChange('queuing') - }) - - const response = createDocumentListResponse({ - data: [ - createDocumentItem({ indexing_status: 'completed' }), - ] as DocumentListResponse['data'], - total: 1, - }) - - act(() => { - result.current.updatePollingState(response) - }) - - // shouldForcePolling = true (queuing is transient), hasIncomplete = false - // timerCanRun = true || false = true - expect(result.current.timerCanRun).toBe(true) - }) - - it('should force polling when status filter is indexing', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('indexing') - }) - - const response = createDocumentListResponse({ - data: [ - createDocumentItem({ indexing_status: 'completed' }), - ] as DocumentListResponse['data'], - total: 1, - }) - - act(() => { - result.current.updatePollingState(response) - }) - - expect(result.current.timerCanRun).toBe(true) - }) - - it('should force polling when status filter is paused', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('paused') - }) - - const response = createDocumentListResponse({ - data: [ - createDocumentItem({ indexing_status: 'paused' }), - ] as DocumentListResponse['data'], - total: 1, - }) - - act(() => { - result.current.updatePollingState(response) - }) - - expect(result.current.timerCanRun).toBe(true) - }) - - it('should not force polling when status filter is a non-transient status (error)', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('error') - }) - - const response = createDocumentListResponse({ - data: [ - createDocumentItem({ indexing_status: 'error' }), - ] as DocumentListResponse['data'], - total: 1, - }) - - act(() => { - result.current.updatePollingState(response) - }) - - // shouldForcePolling = false (error is not transient), hasIncomplete = false (error is embedded) - expect(result.current.timerCanRun).toBe(false) - }) - - it('should set timerCanRun to true when data is empty and filter is transient', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('indexing') - }) - - const response = createDocumentListResponse({ data: [] as DocumentListResponse['data'], total: 0 }) - - act(() => { - result.current.updatePollingState(response) - }) - - // shouldForcePolling = true (indexing is transient), hasIncomplete = false (0 !== 0 is false) - expect(result.current.timerCanRun).toBe(true) - }) - }) - - // Page adjustment - describe('adjustPageForTotal', () => { - it('should not adjust page when documentsRes is undefined', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.adjustPageForTotal(undefined) - }) - - expect(result.current.currPage).toBe(0) - }) - - it('should not adjust page when currPage is within total pages', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - const response = createDocumentListResponse({ total: 20 }) - - act(() => { - result.current.adjustPageForTotal(response) - }) - - // currPage is 0, totalPages is 2, so no adjustment needed - expect(result.current.currPage).toBe(0) - }) - - it('should adjust page to last page when currPage exceeds total pages', () => { - mockQuery = { ...mockQuery, page: 6 } - const { result } = renderHook(() => useDocumentsPageState()) - - // currPage should be 5 (page - 1) - expect(result.current.currPage).toBe(5) - - const response = createDocumentListResponse({ total: 30 }) // 30/10 = 3 pages - - act(() => { - result.current.adjustPageForTotal(response) - }) - - // currPage (5) + 1 > totalPages (3), so adjust to totalPages - 1 = 2 - expect(result.current.currPage).toBe(2) - expect(mockUpdateQuery).toHaveBeenCalledWith({ page: 3 }) // handlePageChange passes newPage + 1 - }) - - it('should adjust page to 0 when total is 0 and currPage > 0', () => { - mockQuery = { ...mockQuery, page: 3 } - const { result } = renderHook(() => useDocumentsPageState()) - - const response = createDocumentListResponse({ total: 0 }) - - act(() => { - result.current.adjustPageForTotal(response) - }) - - // totalPages = 0, so adjust to max(0 - 1, 0) = 0 - expect(result.current.currPage).toBe(0) - expect(mockUpdateQuery).toHaveBeenCalledWith({ page: 1 }) - }) - - it('should not adjust page when currPage is 0 even if total is 0', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - const response = createDocumentListResponse({ total: 0 }) - - act(() => { - result.current.adjustPageForTotal(response) - }) - - // currPage is 0, condition is currPage > 0 so no adjustment - expect(mockUpdateQuery).not.toHaveBeenCalled() - }) - }) - - // Normalized status filter value - describe('normalizedStatusFilterValue', () => { - it('should return all for default status', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - expect(result.current.normalizedStatusFilterValue).toBe('all') - }) - - it('should normalize enabled to available', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('enabled') - }) - - expect(result.current.normalizedStatusFilterValue).toBe('available') - }) - - it('should return non-aliased status as-is', () => { - const { result } = renderHook(() => useDocumentsPageState()) - - act(() => { - result.current.handleStatusFilterChange('error') - }) - - expect(result.current.normalizedStatusFilterValue).toBe('error') - }) - }) - // Return value shape describe('return value', () => { it('should return all expected properties', () => { const { result } = renderHook(() => useDocumentsPageState()) - // Search state expect(result.current).toHaveProperty('inputValue') - expect(result.current).toHaveProperty('searchValue') expect(result.current).toHaveProperty('debouncedSearchValue') expect(result.current).toHaveProperty('handleInputChange') - - // Filter & sort state expect(result.current).toHaveProperty('statusFilterValue') expect(result.current).toHaveProperty('sortValue') expect(result.current).toHaveProperty('normalizedStatusFilterValue') expect(result.current).toHaveProperty('handleStatusFilterChange') expect(result.current).toHaveProperty('handleStatusFilterClear') expect(result.current).toHaveProperty('handleSortChange') - - // Pagination state expect(result.current).toHaveProperty('currPage') expect(result.current).toHaveProperty('limit') expect(result.current).toHaveProperty('handlePageChange') expect(result.current).toHaveProperty('handleLimitChange') - - // Selection state expect(result.current).toHaveProperty('selectedIds') expect(result.current).toHaveProperty('setSelectedIds') - - // Polling state - expect(result.current).toHaveProperty('timerCanRun') - expect(result.current).toHaveProperty('updatePollingState') - expect(result.current).toHaveProperty('adjustPageForTotal') }) - it('should have function types for all handlers', () => { + it('should expose function handlers', () => { const { result } = renderHook(() => useDocumentsPageState()) expect(typeof result.current.handleInputChange).toBe('function') @@ -704,8 +258,6 @@ describe('useDocumentsPageState', () => { expect(typeof result.current.handlePageChange).toBe('function') expect(typeof result.current.handleLimitChange).toBe('function') expect(typeof result.current.setSelectedIds).toBe('function') - expect(typeof result.current.updatePollingState).toBe('function') - expect(typeof result.current.adjustPageForTotal).toBe('function') }) }) }) diff --git a/web/app/components/datasets/documents/hooks/use-document-list-query-state.ts b/web/app/components/datasets/documents/hooks/use-document-list-query-state.ts index 505f15efc0..60717d532c 100644 --- a/web/app/components/datasets/documents/hooks/use-document-list-query-state.ts +++ b/web/app/components/datasets/documents/hooks/use-document-list-query-state.ts @@ -1,6 +1,6 @@ -import type { ReadonlyURLSearchParams } from 'next/navigation' +import type { inferParserType } from 'nuqs' import type { SortType } from '@/service/datasets' -import { usePathname, useRouter, useSearchParams } from 'next/navigation' +import { createParser, parseAsString, throttle, useQueryStates } from 'nuqs' import { useCallback, useMemo } from 'react' import { sanitizeStatusValue } from '../status-filter' @@ -13,99 +13,87 @@ const sanitizeSortValue = (value?: string | null): SortType => { return (ALLOWED_SORT_VALUES.includes(value as SortType) ? value : '-created_at') as SortType } -export type DocumentListQuery = { - page: number - limit: number - keyword: string - status: string - sort: SortType +const sanitizePageValue = (value: number): number => { + return Number.isInteger(value) && value > 0 ? value : 1 } -const DEFAULT_QUERY: DocumentListQuery = { - page: 1, - limit: 10, - keyword: '', - status: 'all', - sort: '-created_at', +const sanitizeLimitValue = (value: number): number => { + return Number.isInteger(value) && value > 0 && value <= 100 ? value : 10 } -// Parse the query parameters from the URL search string. -function parseParams(params: ReadonlyURLSearchParams): DocumentListQuery { - const page = Number.parseInt(params.get('page') || '1', 10) - const limit = Number.parseInt(params.get('limit') || '10', 10) - const keyword = params.get('keyword') || '' - const status = sanitizeStatusValue(params.get('status')) - const sort = sanitizeSortValue(params.get('sort')) +const parseAsPage = createParser({ + parse: (value) => { + const n = Number.parseInt(value, 10) + return Number.isNaN(n) || n <= 0 ? null : n + }, + serialize: value => value.toString(), +}).withDefault(1) - return { - page: page > 0 ? page : 1, - limit: (limit > 0 && limit <= 100) ? limit : 10, - keyword: keyword ? decodeURIComponent(keyword) : '', - status, - sort, - } +const parseAsLimit = createParser({ + parse: (value) => { + const n = Number.parseInt(value, 10) + return Number.isNaN(n) || n <= 0 || n > 100 ? null : n + }, + serialize: value => value.toString(), +}).withDefault(10) + +const parseAsDocStatus = createParser({ + parse: value => sanitizeStatusValue(value), + serialize: value => value, +}).withDefault('all') + +const parseAsDocSort = createParser({ + parse: value => sanitizeSortValue(value), + serialize: value => value, +}).withDefault('-created_at' as SortType) + +const parseAsKeyword = parseAsString.withDefault('') + +export const documentListParsers = { + page: parseAsPage, + limit: parseAsLimit, + keyword: parseAsKeyword, + status: parseAsDocStatus, + sort: parseAsDocSort, } -// Update the URL search string with the given query parameters. -function updateSearchParams(query: DocumentListQuery, searchParams: URLSearchParams) { - const { page, limit, keyword, status, sort } = query || {} +export type DocumentListQuery = inferParserType - const hasNonDefaultParams = (page && page > 1) || (limit && limit !== 10) || (keyword && keyword.trim()) +// Search input updates can be frequent; throttle URL writes to reduce history/api churn. +const KEYWORD_URL_UPDATE_THROTTLE = throttle(300) - if (hasNonDefaultParams) { - searchParams.set('page', (page || 1).toString()) - searchParams.set('limit', (limit || 10).toString()) - } - else { - searchParams.delete('page') - searchParams.delete('limit') - } +export function useDocumentListQueryState() { + const [query, setQuery] = useQueryStates(documentListParsers) - if (keyword && keyword.trim()) - searchParams.set('keyword', encodeURIComponent(keyword)) - else - searchParams.delete('keyword') - - const sanitizedStatus = sanitizeStatusValue(status) - if (sanitizedStatus && sanitizedStatus !== 'all') - searchParams.set('status', sanitizedStatus) - else - searchParams.delete('status') - - const sanitizedSort = sanitizeSortValue(sort) - if (sanitizedSort !== '-created_at') - searchParams.set('sort', sanitizedSort) - else - searchParams.delete('sort') -} - -function useDocumentListQueryState() { - const searchParams = useSearchParams() - const query = useMemo(() => parseParams(searchParams), [searchParams]) - - const router = useRouter() - const pathname = usePathname() - - // Helper function to update specific query parameters const updateQuery = useCallback((updates: Partial) => { - const newQuery = { ...query, ...updates } - newQuery.status = sanitizeStatusValue(newQuery.status) - newQuery.sort = sanitizeSortValue(newQuery.sort) - const params = new URLSearchParams() - updateSearchParams(newQuery, params) - const search = params.toString() - const queryString = search ? `?${search}` : '' - router.push(`${pathname}${queryString}`, { scroll: false }) - }, [query, router, pathname]) + const patch = { ...updates } + if ('page' in patch && patch.page !== undefined) + patch.page = sanitizePageValue(patch.page) + if ('limit' in patch && patch.limit !== undefined) + patch.limit = sanitizeLimitValue(patch.limit) + if ('status' in patch) + patch.status = sanitizeStatusValue(patch.status) + if ('sort' in patch) + patch.sort = sanitizeSortValue(patch.sort) + if ('keyword' in patch && typeof patch.keyword === 'string' && patch.keyword.trim() === '') + patch.keyword = '' + + // If keyword is part of this patch (even with page reset), treat it as a search update: + // use replace to avoid creating a history entry per input-driven change. + if ('keyword' in patch) { + setQuery(patch, { + history: 'replace', + limitUrlUpdates: patch.keyword === '' ? undefined : KEYWORD_URL_UPDATE_THROTTLE, + }) + return + } + + setQuery(patch, { history: 'push' }) + }, [setQuery]) - // Helper function to reset query to defaults const resetQuery = useCallback(() => { - const params = new URLSearchParams() - updateSearchParams(DEFAULT_QUERY, params) - const search = params.toString() - const queryString = search ? `?${search}` : '' - router.push(`${pathname}${queryString}`, { scroll: false }) - }, [router, pathname]) + setQuery(null, { history: 'replace' }) + }, [setQuery]) return useMemo(() => ({ query, @@ -113,5 +101,3 @@ function useDocumentListQueryState() { resetQuery, }), [query, updateQuery, resetQuery]) } - -export default useDocumentListQueryState diff --git a/web/app/components/datasets/documents/hooks/use-documents-page-state.ts b/web/app/components/datasets/documents/hooks/use-documents-page-state.ts index 4fb227f717..36b1e8c760 100644 --- a/web/app/components/datasets/documents/hooks/use-documents-page-state.ts +++ b/web/app/components/datasets/documents/hooks/use-documents-page-state.ts @@ -1,175 +1,63 @@ -import type { DocumentListResponse } from '@/models/datasets' import type { SortType } from '@/service/datasets' -import { useDebounce, useDebounceFn } from 'ahooks' -import { useCallback, useEffect, useMemo, useState } from 'react' +import { useDebounce } from 'ahooks' +import { useCallback, useState } from 'react' import { normalizeStatusForQuery, sanitizeStatusValue } from '../status-filter' -import useDocumentListQueryState from './use-document-list-query-state' +import { useDocumentListQueryState } from './use-document-list-query-state' -/** - * Custom hook to manage documents page state including: - * - Search state (input value, debounced search value) - * - Filter state (status filter, sort value) - * - Pagination state (current page, limit) - * - Selection state (selected document ids) - * - Polling state (timer control for auto-refresh) - */ export function useDocumentsPageState() { const { query, updateQuery } = useDocumentListQueryState() - // Search state - const [inputValue, setInputValue] = useState('') - const [searchValue, setSearchValue] = useState('') - const debouncedSearchValue = useDebounce(searchValue, { wait: 500 }) + const inputValue = query.keyword + const debouncedSearchValue = useDebounce(query.keyword, { wait: 500 }) - // Filter & sort state - const [statusFilterValue, setStatusFilterValue] = useState(() => sanitizeStatusValue(query.status)) - const [sortValue, setSortValue] = useState(query.sort) - const normalizedStatusFilterValue = useMemo( - () => normalizeStatusForQuery(statusFilterValue), - [statusFilterValue], - ) + const statusFilterValue = sanitizeStatusValue(query.status) + const sortValue = query.sort + const normalizedStatusFilterValue = normalizeStatusForQuery(statusFilterValue) - // Pagination state - const [currPage, setCurrPage] = useState(query.page - 1) - const [limit, setLimit] = useState(query.limit) + const currPage = query.page - 1 + const limit = query.limit - // Selection state const [selectedIds, setSelectedIds] = useState([]) - // Polling state - const [timerCanRun, setTimerCanRun] = useState(true) - - // Initialize search value from URL on mount - useEffect(() => { - if (query.keyword) { - setInputValue(query.keyword) - setSearchValue(query.keyword) - } - }, []) // Only run on mount - - // Sync local state with URL query changes - useEffect(() => { - setCurrPage(query.page - 1) - setLimit(query.limit) - if (query.keyword !== searchValue) { - setInputValue(query.keyword) - setSearchValue(query.keyword) - } - setStatusFilterValue((prev) => { - const nextValue = sanitizeStatusValue(query.status) - return prev === nextValue ? prev : nextValue - }) - setSortValue(query.sort) - }, [query]) - - // Update URL when search changes - useEffect(() => { - if (debouncedSearchValue !== query.keyword) { - setCurrPage(0) - updateQuery({ keyword: debouncedSearchValue, page: 1 }) - } - }, [debouncedSearchValue, query.keyword, updateQuery]) - - // Clear selection when search changes - useEffect(() => { - if (searchValue !== query.keyword) - setSelectedIds([]) - }, [searchValue, query.keyword]) - - // Clear selection when status filter changes - useEffect(() => { - setSelectedIds([]) - }, [normalizedStatusFilterValue]) - - // Page change handler const handlePageChange = useCallback((newPage: number) => { - setCurrPage(newPage) updateQuery({ page: newPage + 1 }) }, [updateQuery]) - // Limit change handler const handleLimitChange = useCallback((newLimit: number) => { - setLimit(newLimit) - setCurrPage(0) updateQuery({ limit: newLimit, page: 1 }) }, [updateQuery]) - // Debounced search handler - const { run: handleSearch } = useDebounceFn(() => { - setSearchValue(inputValue) - }, { wait: 500 }) - - // Input change handler const handleInputChange = useCallback((value: string) => { - setInputValue(value) - handleSearch() - }, [handleSearch]) + if (value !== query.keyword) + setSelectedIds([]) + updateQuery({ keyword: value, page: 1 }) + }, [query.keyword, updateQuery]) - // Status filter change handler const handleStatusFilterChange = useCallback((value: string) => { const selectedValue = sanitizeStatusValue(value) - setStatusFilterValue(selectedValue) - setCurrPage(0) + setSelectedIds([]) updateQuery({ status: selectedValue, page: 1 }) }, [updateQuery]) - // Status filter clear handler const handleStatusFilterClear = useCallback(() => { if (statusFilterValue === 'all') return - setStatusFilterValue('all') - setCurrPage(0) + setSelectedIds([]) updateQuery({ status: 'all', page: 1 }) }, [statusFilterValue, updateQuery]) - // Sort change handler const handleSortChange = useCallback((value: string) => { const next = value as SortType if (next === sortValue) return - setSortValue(next) - setCurrPage(0) updateQuery({ sort: next, page: 1 }) }, [sortValue, updateQuery]) - // Update polling state based on documents response - const updatePollingState = useCallback((documentsRes: DocumentListResponse | undefined) => { - if (!documentsRes?.data) - return - - let completedNum = 0 - documentsRes.data.forEach((documentItem) => { - const { indexing_status } = documentItem - const isEmbedded = indexing_status === 'completed' || indexing_status === 'paused' || indexing_status === 'error' - if (isEmbedded) - completedNum++ - }) - - const hasIncompleteDocuments = completedNum !== documentsRes.data.length - const transientStatuses = ['queuing', 'indexing', 'paused'] - const shouldForcePolling = normalizedStatusFilterValue === 'all' - ? false - : transientStatuses.includes(normalizedStatusFilterValue) - setTimerCanRun(shouldForcePolling || hasIncompleteDocuments) - }, [normalizedStatusFilterValue]) - - // Adjust page when total pages change - const adjustPageForTotal = useCallback((documentsRes: DocumentListResponse | undefined) => { - if (!documentsRes) - return - const totalPages = Math.ceil(documentsRes.total / limit) - if (currPage > 0 && currPage + 1 > totalPages) - handlePageChange(totalPages > 0 ? totalPages - 1 : 0) - }, [limit, currPage, handlePageChange]) - return { - // Search state inputValue, - searchValue, debouncedSearchValue, handleInputChange, - // Filter & sort state statusFilterValue, sortValue, normalizedStatusFilterValue, @@ -177,21 +65,12 @@ export function useDocumentsPageState() { handleStatusFilterClear, handleSortChange, - // Pagination state currPage, limit, handlePageChange, handleLimitChange, - // Selection state selectedIds, setSelectedIds, - - // Polling state - timerCanRun, - updatePollingState, - adjustPageForTotal, } } - -export default useDocumentsPageState diff --git a/web/app/components/datasets/documents/index.tsx b/web/app/components/datasets/documents/index.tsx index 676e715f56..764b04227c 100644 --- a/web/app/components/datasets/documents/index.tsx +++ b/web/app/components/datasets/documents/index.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC } from 'react' import { useRouter } from 'next/navigation' -import { useCallback, useEffect } from 'react' +import { useCallback } from 'react' import Loading from '@/app/components/base/loading' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useProviderContext } from '@/context/provider-context' @@ -13,12 +13,16 @@ import useEditDocumentMetadata from '../metadata/hooks/use-edit-dataset-metadata import DocumentsHeader from './components/documents-header' import EmptyElement from './components/empty-element' import List from './components/list' -import useDocumentsPageState from './hooks/use-documents-page-state' +import { useDocumentsPageState } from './hooks/use-documents-page-state' type IDocumentsProps = { datasetId: string } +const POLLING_INTERVAL = 2500 +const TERMINAL_INDEXING_STATUSES = new Set(['completed', 'paused', 'error']) +const FORCED_POLLING_STATUSES = new Set(['queuing', 'indexing', 'paused']) + const Documents: FC = ({ datasetId }) => { const router = useRouter() const { plan } = useProviderContext() @@ -44,9 +48,6 @@ const Documents: FC = ({ datasetId }) => { handleLimitChange, selectedIds, setSelectedIds, - timerCanRun, - updatePollingState, - adjustPageForTotal, } = useDocumentsPageState() // Fetch document list @@ -59,19 +60,17 @@ const Documents: FC = ({ datasetId }) => { status: normalizedStatusFilterValue, sort: sortValue, }, - refetchInterval: timerCanRun ? 2500 : 0, + refetchInterval: (query) => { + const shouldForcePolling = normalizedStatusFilterValue !== 'all' + && FORCED_POLLING_STATUSES.has(normalizedStatusFilterValue) + const documents = query.state.data?.data + if (!documents) + return POLLING_INTERVAL + const hasIncompleteDocuments = documents.some(({ indexing_status }) => !TERMINAL_INDEXING_STATUSES.has(indexing_status)) + return shouldForcePolling || hasIncompleteDocuments ? POLLING_INTERVAL : false + }, }) - // Update polling state when documents change - useEffect(() => { - updatePollingState(documentsRes) - }, [documentsRes, updatePollingState]) - - // Adjust page when total changes - useEffect(() => { - adjustPageForTotal(documentsRes) - }, [documentsRes, adjustPageForTotal]) - // Invalidation hooks const invalidDocumentList = useInvalidDocumentList(datasetId) const invalidDocumentDetail = useInvalidDocumentDetail() @@ -119,7 +118,7 @@ const Documents: FC = ({ datasetId }) => { // Render content based on loading and data state const renderContent = () => { - if (isListLoading) + if (isListLoading && !documentsRes) return if (total > 0) { @@ -131,8 +130,8 @@ const Documents: FC = ({ datasetId }) => { onUpdate={handleUpdate} selectedIds={selectedIds} onSelectedIdChange={setSelectedIds} - statusFilterValue={normalizedStatusFilterValue} remoteSortValue={sortValue} + onSortChange={handleSortChange} pagination={{ total, limit, diff --git a/web/app/components/datasets/documents/status-item/index.tsx b/web/app/components/datasets/documents/status-item/index.tsx index 60d837fd81..8d3abed7cf 100644 --- a/web/app/components/datasets/documents/status-item/index.tsx +++ b/web/app/components/datasets/documents/status-item/index.tsx @@ -8,7 +8,7 @@ import { useMemo } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import Switch from '@/app/components/base/switch' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import Indicator from '@/app/components/header/indicator' import { useDocumentDelete, useDocumentDisable, useDocumentEnable } from '@/service/knowledge/use-document' diff --git a/web/app/components/datasets/external-api/external-api-modal/__tests__/index.spec.tsx b/web/app/components/datasets/external-api/external-api-modal/__tests__/index.spec.tsx index a631de3ea0..66d9a163be 100644 --- a/web/app/components/datasets/external-api/external-api-modal/__tests__/index.spec.tsx +++ b/web/app/components/datasets/external-api/external-api-modal/__tests__/index.spec.tsx @@ -12,7 +12,7 @@ vi.mock('@/service/datasets', () => ({ })) const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/datasets/external-api/external-api-modal/index.tsx b/web/app/components/datasets/external-api/external-api-modal/index.tsx index 723cad199d..38c290a3a7 100644 --- a/web/app/components/datasets/external-api/external-api-modal/index.tsx +++ b/web/app/components/datasets/external-api/external-api-modal/index.tsx @@ -19,7 +19,7 @@ import { PortalToFollowElem, PortalToFollowElemContent, } from '@/app/components/base/portal-to-follow-elem' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { createExternalAPI } from '@/service/datasets' import Form from './Form' diff --git a/web/app/components/datasets/external-knowledge-base/connector/__tests__/index.spec.tsx b/web/app/components/datasets/external-knowledge-base/connector/__tests__/index.spec.tsx index ccd637887b..a6a60aa856 100644 --- a/web/app/components/datasets/external-knowledge-base/connector/__tests__/index.spec.tsx +++ b/web/app/components/datasets/external-knowledge-base/connector/__tests__/index.spec.tsx @@ -22,7 +22,7 @@ vi.mock('@/context/i18n', () => ({ })) const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/datasets/external-knowledge-base/connector/index.tsx b/web/app/components/datasets/external-knowledge-base/connector/index.tsx index 1545c0d232..cf36eed382 100644 --- a/web/app/components/datasets/external-knowledge-base/connector/index.tsx +++ b/web/app/components/datasets/external-knowledge-base/connector/index.tsx @@ -5,7 +5,7 @@ import { useRouter } from 'next/navigation' import * as React from 'react' import { useState } from 'react' import { trackEvent } from '@/app/components/base/amplitude' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import ExternalKnowledgeBaseCreate from '@/app/components/datasets/external-knowledge-base/create' import { createExternalKnowledgeBase } from '@/service/datasets' diff --git a/web/app/components/datasets/hit-testing/__tests__/index.spec.tsx b/web/app/components/datasets/hit-testing/__tests__/index.spec.tsx index 0a5a55b744..fe7510b498 100644 --- a/web/app/components/datasets/hit-testing/__tests__/index.spec.tsx +++ b/web/app/components/datasets/hit-testing/__tests__/index.spec.tsx @@ -579,10 +579,20 @@ describe('HitTestingPage', () => { }) describe('Integration: Hit Testing Flow', () => { - beforeEach(() => { + beforeEach(async () => { vi.clearAllMocks() mockHitTestingMutateAsync.mockReset() mockExternalHitTestingMutateAsync.mockReset() + + const { useHitTesting, useExternalKnowledgeBaseHitTesting } = await import('@/service/knowledge/use-hit-testing') + vi.mocked(useHitTesting).mockReturnValue({ + mutateAsync: mockHitTestingMutateAsync, + isPending: false, + } as unknown as ReturnType) + vi.mocked(useExternalKnowledgeBaseHitTesting).mockReturnValue({ + mutateAsync: mockExternalHitTestingMutateAsync, + isPending: false, + } as unknown as ReturnType) }) it('should complete a full hit testing flow', async () => { @@ -781,8 +791,18 @@ describe('Integration: Hit Testing Flow', () => { // Drawer and Modal Interaction Tests describe('Drawer and Modal Interactions', () => { - beforeEach(() => { + beforeEach(async () => { vi.clearAllMocks() + + const { useHitTesting, useExternalKnowledgeBaseHitTesting } = await import('@/service/knowledge/use-hit-testing') + vi.mocked(useHitTesting).mockReturnValue({ + mutateAsync: mockHitTestingMutateAsync, + isPending: false, + } as unknown as ReturnType) + vi.mocked(useExternalKnowledgeBaseHitTesting).mockReturnValue({ + mutateAsync: mockExternalHitTestingMutateAsync, + isPending: false, + } as unknown as ReturnType) }) it('should save retrieval config when ModifyRetrievalModal onSave is called', async () => { @@ -828,9 +848,19 @@ describe('Drawer and Modal Interactions', () => { // renderHitResults Coverage Tests describe('renderHitResults Coverage', () => { - beforeEach(() => { + beforeEach(async () => { vi.clearAllMocks() mockHitTestingMutateAsync.mockReset() + + const { useHitTesting, useExternalKnowledgeBaseHitTesting } = await import('@/service/knowledge/use-hit-testing') + vi.mocked(useHitTesting).mockReturnValue({ + mutateAsync: mockHitTestingMutateAsync, + isPending: false, + } as unknown as ReturnType) + vi.mocked(useExternalKnowledgeBaseHitTesting).mockReturnValue({ + mutateAsync: mockExternalHitTestingMutateAsync, + isPending: false, + } as unknown as ReturnType) }) it('should render hit results panel with records count', async () => { @@ -952,10 +982,20 @@ describe('ModifyRetrievalModal onSave Coverage', () => { // Direct Component Coverage Tests describe('HitTestingPage Internal Functions Coverage', () => { - beforeEach(() => { + beforeEach(async () => { vi.clearAllMocks() mockHitTestingMutateAsync.mockReset() mockExternalHitTestingMutateAsync.mockReset() + + const { useHitTesting, useExternalKnowledgeBaseHitTesting } = await import('@/service/knowledge/use-hit-testing') + vi.mocked(useHitTesting).mockReturnValue({ + mutateAsync: mockHitTestingMutateAsync, + isPending: false, + } as unknown as ReturnType) + vi.mocked(useExternalKnowledgeBaseHitTesting).mockReturnValue({ + mutateAsync: mockExternalHitTestingMutateAsync, + isPending: false, + } as unknown as ReturnType) }) it('should trigger renderHitResults when mutation succeeds with records', async () => { diff --git a/web/app/components/devtools/react-scan/loader.tsx b/web/app/components/devtools/react-scan/loader.tsx index ee702216f7..a5956d7825 100644 --- a/web/app/components/devtools/react-scan/loader.tsx +++ b/web/app/components/devtools/react-scan/loader.tsx @@ -1,21 +1,15 @@ -'use client' - -import { lazy, Suspense } from 'react' +import Script from 'next/script' import { IS_DEV } from '@/config' -const ReactScan = lazy(() => - import('./scan').then(module => ({ - default: module.ReactScan, - })), -) - -export const ReactScanLoader = () => { +export function ReactScanLoader() { if (!IS_DEV) return null return ( - - - +