diff --git a/.claude/skills/component-refactoring/SKILL.md b/.agents/skills/component-refactoring/SKILL.md similarity index 99% rename from .claude/skills/component-refactoring/SKILL.md rename to .agents/skills/component-refactoring/SKILL.md index ea695ea442..7006c382c8 100644 --- a/.claude/skills/component-refactoring/SKILL.md +++ b/.agents/skills/component-refactoring/SKILL.md @@ -187,7 +187,7 @@ const Template = useMemo(() => { **When**: Component directly handles API calls, data transformation, or complex async operations. -**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks. Project is migrating from SWR to React Query. +**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks. ```typescript // ❌ Before: API logic in component diff --git a/.claude/skills/component-refactoring/references/complexity-patterns.md b/.agents/skills/component-refactoring/references/complexity-patterns.md similarity index 100% rename from .claude/skills/component-refactoring/references/complexity-patterns.md rename to .agents/skills/component-refactoring/references/complexity-patterns.md diff --git a/.claude/skills/component-refactoring/references/component-splitting.md b/.agents/skills/component-refactoring/references/component-splitting.md similarity index 100% rename from .claude/skills/component-refactoring/references/component-splitting.md rename to .agents/skills/component-refactoring/references/component-splitting.md diff --git a/.claude/skills/component-refactoring/references/hook-extraction.md b/.agents/skills/component-refactoring/references/hook-extraction.md similarity index 100% rename from .claude/skills/component-refactoring/references/hook-extraction.md rename to .agents/skills/component-refactoring/references/hook-extraction.md diff --git a/.agents/skills/frontend-code-review/SKILL.md b/.agents/skills/frontend-code-review/SKILL.md new file mode 100644 index 
0000000000..6cc23ca171
--- /dev/null
+++ b/.agents/skills/frontend-code-review/SKILL.md
@@ -0,0 +1,73 @@
+---
+name: frontend-code-review
+description: "Trigger when the user requests a review of frontend files (e.g., `.tsx`, `.ts`, `.js`). Support both pending-change reviews and focused file reviews while applying the checklist rules."
+---
+
+# Frontend Code Review
+
+## Intent
+Use this skill whenever the user asks to review frontend code (especially `.tsx`, `.ts`, or `.js` files). Support two review modes:
+
+1. **Pending-change review** – inspect staged/working-tree files slated for commit and flag checklist violations before submission.
+2. **File-targeted review** – review the specific file(s) the user names and report the relevant checklist findings.
+
+Stick to the checklist below for every applicable file and mode.
+
+## Checklist
+See [references/code-quality.md](references/code-quality.md), [references/performance.md](references/performance.md), and [references/business-logic.md](references/business-logic.md) for the living checklist split by category; treat it as the canonical set of rules to follow.
+
+Flag each rule violation with urgency metadata so future reviewers can prioritize fixes.
+
+## Review Process
+1. Open the relevant component/module. Gather lines that relate to class names, React Flow hooks, prop memoization, and styling.
+2. For each rule in the checklist, note where the code deviates and capture a representative snippet.
+3. Compose the review section per the template below. Group violations first by **Urgent** flag, then by category order (Code Quality, Performance, Business Logic).
+
+## Required output
+When invoked, the response must exactly follow one of the two templates:
+
+### Template A (any findings)
+```
+# Code review
+Found urgent issues that need to be fixed:
+
+## 1
+FilePath: line
+
+<description of the issue>
+
+### Suggested fix
+
+<suggested change>
+---
+... (repeat for each urgent issue) ...
+
+Found suggestions for improvement:
+
+## 1
+FilePath: line
+
+<description of the issue>
+
+### Suggested fix
+
+<suggested change>
+---
+
+... (repeat for each suggestion) ...
+```
+
+If there are no urgent issues, omit that section. If there are no suggestions, omit that section.
+
+If there are more than 10 issues, summarize the count as "10+ urgent issues" or "10+ suggestions" and output only the first 10.
+
+Don't compress the blank lines between sections; keep them as-is for readability.
+
+If you use Template A (i.e., there are issues to fix) and at least one issue requires code changes, append a brief follow-up question after the structured output asking whether the user wants you to apply the suggested fix(es). For example: "Would you like me to use the Suggested fix section to address these issues?"
+
+### Template B (no issues)
+```
+# Code review
+No issues found.
+```
+
diff --git a/.agents/skills/frontend-code-review/references/business-logic.md b/.agents/skills/frontend-code-review/references/business-logic.md
new file mode 100644
index 0000000000..4584f99dfc
--- /dev/null
+++ b/.agents/skills/frontend-code-review/references/business-logic.md
@@ -0,0 +1,15 @@
+# Rule Catalog — Business Logic
+
+## Can't use workflowStore in Node components
+
+IsUrgent: True
+
+### Description
+
+File path pattern of node components: `web/app/components/workflow/nodes/[nodeName]/node.tsx`
+
+Node components are also used when creating a RAG Pipeline from a template, but in that context there is no workflowStore Provider, which results in a blank screen. [This issue](https://github.com/langgenius/dify/issues/29168) was caused by exactly this problem.
+
+### Suggested Fix
+
+Use `import { useNodes } from 'reactflow'` instead of `import useNodes from '@/app/components/workflow/store/workflow/use-nodes'`.
diff --git a/.agents/skills/frontend-code-review/references/code-quality.md b/.agents/skills/frontend-code-review/references/code-quality.md
new file mode 100644
index 0000000000..afdd40deb3
--- /dev/null
+++ b/.agents/skills/frontend-code-review/references/code-quality.md
@@ -0,0 +1,44 @@
+# Rule Catalog — Code Quality
+
+## Conditional class names use utility function
+
+IsUrgent: True
+Category: Code Quality
+
+### Description
+
+Ensure conditional CSS is handled via the shared `classNames` utility instead of custom ternaries, string concatenation, or template strings. Centralizing class logic keeps components consistent and easier to maintain.
+
+### Suggested Fix
+
+```ts
+import { cn } from '@/utils/classnames'
+const labelClassName = cn(isActive ? 'text-primary-600' : 'text-gray-500')
+```
+
+## Tailwind-first styling
+
+IsUrgent: True
+Category: Code Quality
+
+### Description
+
+Favor Tailwind CSS utility classes instead of adding new `.module.css` files unless a Tailwind combination cannot achieve the required styling. Keeping styles in Tailwind improves consistency and reduces maintenance overhead.
+
+Update this file when adding, editing, or removing Code Quality rules so the catalog remains accurate.
+
+## Classname ordering for easy overrides
+
+### Description
+
+When writing components, always place the incoming `className` prop after the component’s own class values so that downstream consumers can override or extend the styling. This keeps your component’s defaults but still lets external callers change or remove specific styles.
+
+Example:
+
+```tsx
+import { cn } from '@/utils/classnames'
+
+const Button = ({ className }) => {
+  return <button className={cn('btn-default', className)} /> // defaults first, incoming className last (classes illustrative)
+}
+```
diff --git a/.agents/skills/frontend-code-review/references/performance.md b/.agents/skills/frontend-code-review/references/performance.md
new file mode 100644
index 0000000000..2d60072f5c
--- /dev/null
+++ b/.agents/skills/frontend-code-review/references/performance.md
@@ -0,0 +1,45 @@
+# Rule Catalog — Performance
+
+## React Flow data usage
+
+IsUrgent: True
+Category: Performance
+
+### Description
+
+When rendering React Flow, prefer `useNodes`/`useEdges` for UI consumption and rely on `useStoreApi` inside callbacks that mutate or read node/edge state. Avoid manually pulling Flow data outside of these hooks.
+
+## Complex prop memoization
+
+IsUrgent: True
+Category: Performance
+
+### Description
+
+Wrap complex prop values (objects, arrays, maps) in `useMemo` before passing them into child components to guarantee stable references and prevent unnecessary renders.
+
+Update this file when adding, editing, or removing Performance rules so the catalog remains accurate.
+
+Wrong:
+
+```tsx
+<ConfigPanel config={{ provider, detail }} /> {/* inline literal: new reference every render (ConfigPanel is illustrative) */}
+```
+
+Right:
+
+```tsx
+const config = useMemo(() => ({
+  provider: ...,
+  detail: ...
+}), [provider, detail]);
+
+<ConfigPanel config={config} />
+```
diff --git a/.claude/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md
similarity index 98%
rename from .claude/skills/frontend-testing/SKILL.md
rename to .agents/skills/frontend-testing/SKILL.md
index dd9677a78e..0716c81ef7
--- a/.claude/skills/frontend-testing/SKILL.md
+++ b/.agents/skills/frontend-testing/SKILL.md
@@ -83,6 +83,9 @@ vi.mock('next/navigation', () => ({
   usePathname: () => '/test',
 }))
 
+// ✅ Zustand stores: Use real stores (auto-mocked globally)
+// Set test state with: useAppStore.setState({ ...
}) + // Shared state for mocks (if needed) let mockSharedState = false @@ -296,7 +299,7 @@ For each test file generated, aim for: For more detailed information, refer to: - `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing) -- `references/mocking.md` - Mock patterns and best practices +- `references/mocking.md` - Mock patterns, Zustand store testing, and best practices - `references/async-testing.md` - Async operations and API calls - `references/domain-components.md` - Workflow, Dataset, Configuration testing - `references/common-patterns.md` - Frequently used testing patterns diff --git a/.claude/skills/frontend-testing/assets/component-test.template.tsx b/.agents/skills/frontend-testing/assets/component-test.template.tsx similarity index 97% rename from .claude/skills/frontend-testing/assets/component-test.template.tsx rename to .agents/skills/frontend-testing/assets/component-test.template.tsx index c39baff916..6b7803bd4b 100644 --- a/.claude/skills/frontend-testing/assets/component-test.template.tsx +++ b/.agents/skills/frontend-testing/assets/component-test.template.tsx @@ -28,17 +28,14 @@ import userEvent from '@testing-library/user-event' // i18n (automatically mocked) // WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup -// No explicit mock needed - it returns translation keys as-is +// The global mock provides: useTranslation, Trans, useMixedTranslation, useGetLanguage +// No explicit mock needed for most tests +// // Override only if custom translations are required: -// vi.mock('react-i18next', () => ({ -// useTranslation: () => ({ -// t: (key: string) => { -// const customTranslations: Record = { -// 'my.custom.key': 'Custom Translation', -// } -// return customTranslations[key] || key -// }, -// }), +// import { createReactI18nextMock } from '@/test/i18n-mock' +// vi.mock('react-i18next', () => createReactI18nextMock({ +// 'my.custom.key': 'Custom Translation', +// 'button.save': 'Save', 
// })) // Router (if component uses useRouter, usePathname, useSearchParams) diff --git a/.claude/skills/frontend-testing/assets/hook-test.template.ts b/.agents/skills/frontend-testing/assets/hook-test.template.ts similarity index 100% rename from .claude/skills/frontend-testing/assets/hook-test.template.ts rename to .agents/skills/frontend-testing/assets/hook-test.template.ts diff --git a/.claude/skills/frontend-testing/assets/utility-test.template.ts b/.agents/skills/frontend-testing/assets/utility-test.template.ts similarity index 100% rename from .claude/skills/frontend-testing/assets/utility-test.template.ts rename to .agents/skills/frontend-testing/assets/utility-test.template.ts diff --git a/.claude/skills/frontend-testing/references/async-testing.md b/.agents/skills/frontend-testing/references/async-testing.md similarity index 100% rename from .claude/skills/frontend-testing/references/async-testing.md rename to .agents/skills/frontend-testing/references/async-testing.md diff --git a/.claude/skills/frontend-testing/references/checklist.md b/.agents/skills/frontend-testing/references/checklist.md similarity index 100% rename from .claude/skills/frontend-testing/references/checklist.md rename to .agents/skills/frontend-testing/references/checklist.md diff --git a/.claude/skills/frontend-testing/references/common-patterns.md b/.agents/skills/frontend-testing/references/common-patterns.md similarity index 100% rename from .claude/skills/frontend-testing/references/common-patterns.md rename to .agents/skills/frontend-testing/references/common-patterns.md diff --git a/.claude/skills/frontend-testing/references/domain-components.md b/.agents/skills/frontend-testing/references/domain-components.md similarity index 100% rename from .claude/skills/frontend-testing/references/domain-components.md rename to .agents/skills/frontend-testing/references/domain-components.md diff --git a/.claude/skills/frontend-testing/references/mocking.md 
b/.agents/skills/frontend-testing/references/mocking.md
similarity index 59%
rename from .claude/skills/frontend-testing/references/mocking.md
rename to .agents/skills/frontend-testing/references/mocking.md
index 23889c8d3d..86bd375987 100644
--- a/.claude/skills/frontend-testing/references/mocking.md
+++ b/.agents/skills/frontend-testing/references/mocking.md
@@ -37,38 +37,64 @@ Only mock these categories:
 1. **Third-party libraries with side effects** - `next/navigation`, external SDKs
 1. **i18n** - Always mock to return keys
 
+### Zustand Stores - DO NOT Mock Manually
+
+**Zustand is globally mocked** in `web/vitest.setup.ts`. Use real stores with `setState()`:
+
+```typescript
+// ✅ CORRECT: Use real store, set test state
+import { useAppStore } from '@/app/components/app/store'
+
+useAppStore.setState({ appDetail: { id: 'test', name: 'Test' } })
+render(<MyComponent />)
+
+// ❌ WRONG: Don't mock the store module
+vi.mock('@/app/components/app/store', () => ({ ... }))
+```
+
+See the [Zustand Store Testing](#zustand-store-testing) section for full details.
+
 ## Mock Placement
 
 | Location | Purpose |
 |----------|---------|
-| `web/vitest.setup.ts` | Global mocks shared by all tests (for example `react-i18next`, `next/image`) |
+| `web/vitest.setup.ts` | Global mocks shared by all tests (`react-i18next`, `next/image`, `zustand`) |
+| `web/__mocks__/zustand.ts` | Zustand mock implementation (auto-resets stores after each test) |
 | `web/__mocks__/` | Reusable mock factories shared across multiple test files |
 | Test file | Test-specific mocks, inline with `vi.mock()` |
 
 Modules are not mocked automatically. Use `vi.mock` in test files, or add global mocks in `web/vitest.setup.ts`.
 
+**Note**: Zustand is special - it's globally mocked, but you should NOT mock store modules manually. See [Zustand Store Testing](#zustand-store-testing).
+
 ## Essential Mocks
 
 ### 1. i18n (Auto-loaded via Global Mock)
 
 A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
-**No explicit mock needed** for most tests - it returns translation keys as-is. -For tests requiring custom translations, override the mock: +The global mock provides: + +- `useTranslation` - returns translation keys with namespace prefix +- `Trans` component - renders i18nKey and components +- `useMixedTranslation` (from `@/app/components/plugins/marketplace/hooks`) +- `useGetLanguage` (from `@/context/i18n`) - returns `'en-US'` + +**Default behavior**: Most tests should use the global mock (no local override needed). + +**For custom translations**: Use the helper function from `@/test/i18n-mock`: ```typescript -vi.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => { - const translations: Record = { - 'my.custom.key': 'Custom translation', - } - return translations[key] || key - }, - }), +import { createReactI18nextMock } from '@/test/i18n-mock' + +vi.mock('react-i18next', () => createReactI18nextMock({ + 'my.custom.key': 'Custom translation', + 'button.save': 'Save', })) ``` +**Avoid**: Manually defining `useTranslation` mocks that just return the key - the global mock already does this. + ### 2. Next.js Router ```typescript @@ -270,6 +296,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => { 1. **Use real base components** - Import from `@/app/components/base/` directly 1. **Use real project components** - Prefer importing over mocking +1. **Use real Zustand stores** - Set test state via `store.setState()` 1. **Reset mocks in `beforeEach`**, not `afterEach` 1. **Match actual component behavior** in mocks (when mocking is necessary) 1. **Use factory functions** for complex mock data @@ -279,6 +306,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => { ### ❌ DON'T 1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.) +1. **Don't mock Zustand store modules** - Use real stores with `setState()` 1. Don't mock components you can import directly 1. 
Don't create overly simplified mocks that miss conditional logic
1. Don't forget to clean up nock after each test
@@ -302,10 +330,151 @@ Need to use a component in test?
 ├─ Is it a third-party lib with side effects?
 │   └─ YES → Mock it (next/navigation, external SDKs)
 │
+├─ Is it a Zustand store?
+│   └─ YES → DO NOT mock the module!
+│            Use real store + setState() to set test state
+│            (Global mock handles auto-reset)
+│
 └─ Is it i18n?
     └─ YES → Uses shared mock (auto-loaded). Override only for custom translations
 ```
 
+## Zustand Store Testing
+
+### Global Zustand Mock (Auto-loaded)
+
+Zustand is globally mocked in `web/vitest.setup.ts` following the [official Zustand testing guide](https://zustand.docs.pmnd.rs/guides/testing). The mock in `web/__mocks__/zustand.ts` provides:
+
+- Real store behavior with `getState()`, `setState()`, `subscribe()` methods
+- Automatic store reset after each test via `afterEach`
+- Proper test isolation between tests
+
+### ✅ Recommended: Use Real Stores (Official Best Practice)
+
+**DO NOT mock store modules manually.** Import and use the real store, then use `setState()` to set test state:
+
+```typescript
+// ✅ CORRECT: Use real store with setState
+import { useAppStore } from '@/app/components/app/store'
+
+describe('MyComponent', () => {
+  it('should render app details', () => {
+    // Arrange: Set test state via setState
+    useAppStore.setState({
+      appDetail: {
+        id: 'test-app',
+        name: 'Test App',
+        mode: 'chat',
+      },
+    })
+
+    // Act
+    render(<MyComponent />)
+
+    // Assert
+    expect(screen.getByText('Test App')).toBeInTheDocument()
+    // Can also verify store state directly
+    expect(useAppStore.getState().appDetail?.name).toBe('Test App')
+  })
+
+  // No cleanup needed - global mock auto-resets after each test
+})
+```
+
+### ❌ Avoid: Manual Store Module Mocking
+
+Manual mocking conflicts with the global Zustand mock and loses store functionality:
+
+```typescript
+// ❌ WRONG: Don't mock the store module
+vi.mock('@/app/components/app/store', () =>
({ + useStore: (selector) => mockSelector(selector), // Missing getState, setState! +})) + +// ❌ WRONG: This conflicts with global zustand mock +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: vi.fn(() => mockState), +})) +``` + +**Problems with manual mocking:** + +1. Loses `getState()`, `setState()`, `subscribe()` methods +1. Conflicts with global Zustand mock behavior +1. Requires manual maintenance of store API +1. Tests don't reflect actual store behavior + +### When Manual Store Mocking is Necessary + +In rare cases where the store has complex initialization or side effects, you can mock it, but ensure you provide the full store API: + +```typescript +// If you MUST mock (rare), include full store API +const mockStore = { + appDetail: { id: 'test', name: 'Test' }, + setAppDetail: vi.fn(), +} + +vi.mock('@/app/components/app/store', () => ({ + useStore: Object.assign( + (selector: (state: typeof mockStore) => unknown) => selector(mockStore), + { + getState: () => mockStore, + setState: vi.fn(), + subscribe: vi.fn(), + }, + ), +})) +``` + +### Store Testing Decision Tree + +``` +Need to test a component using Zustand store? +│ +├─ Can you use the real store? +│ └─ YES → Use real store + setState (RECOMMENDED) +│ useAppStore.setState({ ... }) +│ +├─ Does the store have complex initialization/side effects? +│ └─ YES → Consider mocking, but include full API +│ (getState, setState, subscribe) +│ +└─ Are you testing the store itself (not a component)? 
+ └─ YES → Test store directly with getState/setState + const store = useMyStore + store.setState({ count: 0 }) + store.getState().increment() + expect(store.getState().count).toBe(1) +``` + +### Example: Testing Store Actions + +```typescript +import { useCounterStore } from '@/stores/counter' + +describe('Counter Store', () => { + it('should increment count', () => { + // Initial state (auto-reset by global mock) + expect(useCounterStore.getState().count).toBe(0) + + // Call action + useCounterStore.getState().increment() + + // Verify state change + expect(useCounterStore.getState().count).toBe(1) + }) + + it('should reset to initial state', () => { + // Set some state + useCounterStore.setState({ count: 100 }) + expect(useCounterStore.getState().count).toBe(100) + + // After this test, global mock will reset to initial state + }) +}) +``` + ## Factory Function Pattern ```typescript diff --git a/.claude/skills/frontend-testing/references/workflow.md b/.agents/skills/frontend-testing/references/workflow.md similarity index 100% rename from .claude/skills/frontend-testing/references/workflow.md rename to .agents/skills/frontend-testing/references/workflow.md diff --git a/.agents/skills/orpc-contract-first/SKILL.md b/.agents/skills/orpc-contract-first/SKILL.md new file mode 100644 index 0000000000..4e3bfc7a37 --- /dev/null +++ b/.agents/skills/orpc-contract-first/SKILL.md @@ -0,0 +1,46 @@ +--- +name: orpc-contract-first +description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Triggers when creating new API contracts, adding service endpoints, integrating TanStack Query with typed contracts, or migrating legacy service calls to oRPC. Use for all API layer work in web/contract and web/service directories. 
+---
+
+# oRPC Contract-First Development
+
+## Project Structure
+
+```
+web/contract/
+├── base.ts          # Base contract (inputStructure: 'detailed')
+├── router.ts        # Router composition & type exports
+├── marketplace.ts   # Marketplace contracts
+└── console/         # Console contracts by domain
+    ├── system.ts
+    └── billing.ts
+```
+
+## Workflow
+
+1. **Create contract** in `web/contract/console/{domain}.ts`
+   - Import `base` from `../base` and `type` from `@orpc/contract`
+   - Define route with `path`, `method`, `input`, `output`
+
+2. **Register in router** at `web/contract/router.ts`
+   - Import directly from domain file (no barrel files)
+   - Nest by API prefix: `billing: { invoices, bindPartnerStack }`
+
+3. **Create hooks** in `web/service/use-{domain}.ts`
+   - Use `consoleQuery.{group}.{contract}.queryKey()` for query keys
+   - Use `consoleClient.{group}.{contract}()` for API calls
+
+## Key Rules
+
+- **Input structure**: Always use `{ params, query?, body? }` format
+- **Path params**: Use `{paramName}` in path, match in `params` object
+- **Router nesting**: Group by API prefix (e.g., `/billing/*` → `billing: {}`)
+- **No barrel files**: Import directly from specific files
+- **Types**: Import from `@/types/`, use `type()` helper
+
+## Type Export
+
+```typescript
+export type ConsoleInputs = InferContractRouterInputs<typeof contract> // `contract` = the composed contract router (name assumed)
+```
diff --git a/.claude/settings.json b/.claude/settings.json
new file mode 100644
index 0000000000..fe108722be
--- /dev/null
+++ b/.claude/settings.json
@@ -0,0 +1,15 @@
+{
+  "hooks": {
+    "PreToolUse": [
+      {
+        "matcher": "Bash",
+        "hooks": [
+          {
+            "type": "command",
+            "command": "npx -y block-no-verify@1.1.1"
+          }
+        ]
+      }
+    ]
+  }
+}
diff --git a/.claude/settings.json.example b/.claude/settings.json.example
deleted file mode 100644
index 1149895340..0000000000
--- a/.claude/settings.json.example
+++ /dev/null
@@ -1,19 +0,0 @@
-{
-  "permissions": {
-    "allow": [],
-    "deny": []
-  },
-  "env": {
-    "__comment": "Environment variables for MCP servers.
Override in .claude/settings.local.json with actual values.", - "GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" - }, - "enabledMcpjsonServers": [ - "context7", - "sequential-thinking", - "github", - "fetch", - "playwright", - "ide" - ], - "enableAllProjectMcpServers": true - } \ No newline at end of file diff --git a/.claude/skills/component-refactoring b/.claude/skills/component-refactoring new file mode 120000 index 0000000000..53ae67e2f2 --- /dev/null +++ b/.claude/skills/component-refactoring @@ -0,0 +1 @@ +../../.agents/skills/component-refactoring \ No newline at end of file diff --git a/.claude/skills/frontend-code-review b/.claude/skills/frontend-code-review new file mode 120000 index 0000000000..55654ffbd7 --- /dev/null +++ b/.claude/skills/frontend-code-review @@ -0,0 +1 @@ +../../.agents/skills/frontend-code-review \ No newline at end of file diff --git a/.claude/skills/frontend-testing b/.claude/skills/frontend-testing new file mode 120000 index 0000000000..092cec7745 --- /dev/null +++ b/.claude/skills/frontend-testing @@ -0,0 +1 @@ +../../.agents/skills/frontend-testing \ No newline at end of file diff --git a/.claude/skills/orpc-contract-first b/.claude/skills/orpc-contract-first new file mode 120000 index 0000000000..da47b335c7 --- /dev/null +++ b/.claude/skills/orpc-contract-first @@ -0,0 +1 @@ +../../.agents/skills/orpc-contract-first \ No newline at end of file diff --git a/.codex/skills b/.codex/skills deleted file mode 120000 index 454b8427cd..0000000000 --- a/.codex/skills +++ /dev/null @@ -1 +0,0 @@ -../.claude/skills \ No newline at end of file diff --git a/.codex/skills/component-refactoring b/.codex/skills/component-refactoring new file mode 120000 index 0000000000..53ae67e2f2 --- /dev/null +++ b/.codex/skills/component-refactoring @@ -0,0 +1 @@ +../../.agents/skills/component-refactoring \ No newline at end of file diff --git a/.codex/skills/frontend-code-review b/.codex/skills/frontend-code-review new file mode 
120000 index 0000000000..55654ffbd7 --- /dev/null +++ b/.codex/skills/frontend-code-review @@ -0,0 +1 @@ +../../.agents/skills/frontend-code-review \ No newline at end of file diff --git a/.codex/skills/frontend-testing b/.codex/skills/frontend-testing new file mode 120000 index 0000000000..092cec7745 --- /dev/null +++ b/.codex/skills/frontend-testing @@ -0,0 +1 @@ +../../.agents/skills/frontend-testing \ No newline at end of file diff --git a/.codex/skills/orpc-contract-first b/.codex/skills/orpc-contract-first new file mode 120000 index 0000000000..da47b335c7 --- /dev/null +++ b/.codex/skills/orpc-contract-first @@ -0,0 +1 @@ +../../.agents/skills/orpc-contract-first \ No newline at end of file diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index 220f77e5ce..637593b9de 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -8,7 +8,7 @@ pipx install uv echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc -echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc +echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file 
middleware.env down\"" >> ~/.bashrc diff --git a/.github/labeler.yml b/.github/labeler.yml new file mode 100644 index 0000000000..d1d324d381 --- /dev/null +++ b/.github/labeler.yml @@ -0,0 +1,3 @@ +web: + - changed-files: + - any-glob-to-any-file: 'web/**' diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index aa5a50918a..50dbde2aee 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -20,4 +20,4 @@ - [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!) - [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change. - [x] I've updated the documentation accordingly. -- [x] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods +- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 76cbf64fca..190e00d9fe 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -22,12 +22,12 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@v6 + uses: astral-sh/setup-uv@v7 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -39,12 +39,6 @@ jobs: - name: Install dependencies run: uv sync --project api --dev - - name: Run pyrefly check - run: | - cd api - uv add --dev pyrefly - uv run pyrefly check || true - - name: Run dify config tests run: uv run --project api dev/pytest/pytest_config_tests.py @@ -57,7 +51,7 @@ jobs: run: sh .github/workflows/expose_service_ports.sh - name: Set up Sandbox - uses: hoverkraft-tech/compose-action@v2.0.2 + uses: hoverkraft-tech/compose-action@v2 with: compose-file: | 
docker/docker-compose.middleware.yaml diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 97027c2218..4a8c61e7d2 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -12,22 +12,22 @@ jobs: if: github.repository == 'langgenius/dify' runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Check Docker Compose inputs id: docker-compose-changes - uses: tj-actions/changed-files@v46 + uses: tj-actions/changed-files@v47 with: files: | docker/generate_docker_compose docker/.env.example docker/docker-compose-template.yaml docker/docker-compose.yaml - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 with: python-version: "3.11" - - uses: astral-sh/setup-uv@v6 + - uses: astral-sh/setup-uv@v7 - name: Generate Docker Compose if: steps.docker-compose-changes.outputs.any_changed == 'true' @@ -79,9 +79,32 @@ jobs: find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \; find . -name "*.py.bak" -type f -delete + - name: Install pnpm + uses: pnpm/action-setup@v4 + with: + package_json_file: web/package.json + run_install: false + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: 24 + cache: pnpm + cache-dependency-path: ./web/pnpm-lock.yaml + + - name: Install web dependencies + run: | + cd web + pnpm install --frozen-lockfile + + - name: ESLint autofix + run: | + cd web + pnpm lint:fix || true + # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter. - name: mdformat run: | - uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md" + uvx --python 3.13 mdformat . 
--exclude ".agents/skills/**" - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27 diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index 44c9a92ab7..ac7f3a6b48 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -91,7 +91,7 @@ jobs: touch "/tmp/digests/${sanitized_digest}" - name: Upload digest - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }} path: /tmp/digests/* @@ -113,7 +113,7 @@ jobs: context: "web" steps: - name: Download digests - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: path: /tmp/digests pattern: digests-${{ matrix.context }}-* diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml index 101d973466..e20cf9850b 100644 --- a/.github/workflows/db-migration-test.yml +++ b/.github/workflows/db-migration-test.yml @@ -13,13 +13,13 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@v6 + uses: astral-sh/setup-uv@v7 with: enable-cache: true python-version: "3.12" @@ -63,13 +63,13 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@v6 + uses: astral-sh/setup-uv@v7 with: enable-cache: true python-version: "3.12" diff --git a/.github/workflows/deploy-trigger-dev.yml b/.github/workflows/deploy-agent-dev.yml similarity index 69% rename from .github/workflows/deploy-trigger-dev.yml rename to .github/workflows/deploy-agent-dev.yml index 2d9a904fc5..dd759f7ba5 100644 --- a/.github/workflows/deploy-trigger-dev.yml +++ b/.github/workflows/deploy-agent-dev.yml @@ -1,4 +1,4 @@ -name: Deploy Trigger Dev +name: Deploy Agent Dev permissions: contents: 
read @@ -7,7 +7,7 @@ on: workflow_run: workflows: ["Build and Push API & Web"] branches: - - "deploy/trigger-dev" + - "deploy/agent-dev" types: - completed @@ -16,12 +16,12 @@ jobs: runs-on: ubuntu-latest if: | github.event.workflow_run.conclusion == 'success' && - github.event.workflow_run.head_branch == 'deploy/trigger-dev' + github.event.workflow_run.head_branch == 'deploy/agent-dev' steps: - name: Deploy to server - uses: appleboy/ssh-action@v0.1.8 + uses: appleboy/ssh-action@v1 with: - host: ${{ secrets.TRIGGER_SSH_HOST }} + host: ${{ secrets.AGENT_DEV_SSH_HOST }} username: ${{ secrets.SSH_USER }} key: ${{ secrets.SSH_PRIVATE_KEY }} script: | diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml index cd1c86e668..38fa0b9a7f 100644 --- a/.github/workflows/deploy-dev.yml +++ b/.github/workflows/deploy-dev.yml @@ -16,7 +16,7 @@ jobs: github.event.workflow_run.head_branch == 'deploy/dev' steps: - name: Deploy to server - uses: appleboy/ssh-action@v0.1.8 + uses: appleboy/ssh-action@v1 with: host: ${{ secrets.SSH_HOST }} username: ${{ secrets.SSH_USER }} diff --git a/.github/workflows/deploy-hitl.yml b/.github/workflows/deploy-hitl.yml new file mode 100644 index 0000000000..7d5f0a22e7 --- /dev/null +++ b/.github/workflows/deploy-hitl.yml @@ -0,0 +1,29 @@ +name: Deploy HITL + +on: + workflow_run: + workflows: ["Build and Push API & Web"] + branches: + - "feat/hitl-frontend" + - "feat/hitl-backend" + types: + - completed + +jobs: + deploy: + runs-on: ubuntu-latest + if: | + github.event.workflow_run.conclusion == 'success' && + ( + github.event.workflow_run.head_branch == 'feat/hitl-frontend' || + github.event.workflow_run.head_branch == 'feat/hitl-backend' + ) + steps: + - name: Deploy to server + uses: appleboy/ssh-action@v1 + with: + host: ${{ secrets.HITL_SSH_HOST }} + username: ${{ secrets.SSH_USER }} + key: ${{ secrets.SSH_PRIVATE_KEY }} + script: | + ${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }} diff --git 
a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml new file mode 100644 index 0000000000..06782b53c1 --- /dev/null +++ b/.github/workflows/labeler.yml @@ -0,0 +1,14 @@ +name: "Pull Request Labeler" +on: + pull_request_target: + +jobs: + labeler: + permissions: + contents: read + pull-requests: write + runs-on: ubuntu-latest + steps: + - uses: actions/labeler@v6 + with: + sync-labels: true diff --git a/.github/workflows/main-ci.yml b/.github/workflows/main-ci.yml index 876ec23a3d..d6653de950 100644 --- a/.github/workflows/main-ci.yml +++ b/.github/workflows/main-ci.yml @@ -27,7 +27,7 @@ jobs: vdb-changed: ${{ steps.changes.outputs.vdb }} migration-changed: ${{ steps.changes.outputs.migration }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - uses: dorny/paths-filter@v3 id: changes with: @@ -38,6 +38,7 @@ jobs: - '.github/workflows/api-tests.yml' web: - 'web/**' + - '.github/workflows/web-tests.yml' vdb: - 'api/core/rag/datasource/**' - 'docker/**' diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 1870b1f670..b6df1d7e93 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -18,7 +18,7 @@ jobs: pull-requests: write steps: - - uses: actions/stale@v5 + - uses: actions/stale@v10 with: days-before-issue-stale: 15 days-before-issue-close: 3 diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 8710f422fc..fdc05d1d65 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -19,13 +19,13 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: persist-credentials: false - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v46 + uses: tj-actions/changed-files@v47 with: files: | api/** @@ -33,7 +33,7 @@ jobs: - name: Setup UV and Python if: steps.changed-files.outputs.any_changed == 'true' - uses: astral-sh/setup-uv@v6 + uses: astral-sh/setup-uv@v7 with: enable-cache: false 
python-version: "3.12" @@ -65,18 +65,23 @@ jobs: defaults: run: working-directory: ./web + permissions: + checks: write + pull-requests: read steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: persist-credentials: false - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v46 + uses: tj-actions/changed-files@v47 with: - files: web/** + files: | + web/** + .github/workflows/style.yml - name: Install pnpm uses: pnpm/action-setup@v4 @@ -85,10 +90,10 @@ jobs: run_install: false - name: Setup NodeJS - uses: actions/setup-node@v4 + uses: actions/setup-node@v6 if: steps.changed-files.outputs.any_changed == 'true' with: - node-version: 22 + node-version: 24 cache: pnpm cache-dependency-path: ./web/pnpm-lock.yaml @@ -101,12 +106,31 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web run: | - pnpm run lint + pnpm run lint:ci + # pnpm run lint:report + # continue-on-error: true + + # - name: Annotate Code + # if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request' + # uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae + # with: + # eslint-report: web/eslint_report.json + # github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Web tsslint + if: steps.changed-files.outputs.any_changed == 'true' + working-directory: ./web + run: pnpm run lint:tss - name: Web type check if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web - run: pnpm run type-check:tsgo + run: pnpm run type-check + + - name: Web dead code check + if: steps.changed-files.outputs.any_changed == 'true' + working-directory: ./web + run: pnpm run knip superlinter: name: SuperLinter @@ -114,14 +138,14 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 persist-credentials: false - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v46 + uses: 
tj-actions/changed-files@v47 with: files: | **.sh diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml index b1ccd7417a..ec392cb3b2 100644 --- a/.github/workflows/tool-test-sdks.yaml +++ b/.github/workflows/tool-test-sdks.yaml @@ -16,23 +16,19 @@ jobs: name: unit test for Node.js SDK runs-on: ubuntu-latest - strategy: - matrix: - node-version: [16, 18, 20, 22] - defaults: run: working-directory: sdks/nodejs-client steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: persist-credentials: false - - name: Use Node.js ${{ matrix.node-version }} - uses: actions/setup-node@v4 + - name: Use Node.js + uses: actions/setup-node@v6 with: - node-version: ${{ matrix.node-version }} + node-version: 24 cache: '' cache-dependency-path: 'pnpm-lock.yaml' diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml deleted file mode 100644 index 87e24a4f90..0000000000 --- a/.github/workflows/translate-i18n-base-on-english.yml +++ /dev/null @@ -1,85 +0,0 @@ -name: Translate i18n Files Based on English - -on: - push: - branches: [main] - paths: - - 'web/i18n/en-US/*.ts' - -permissions: - contents: write - pull-requests: write - -jobs: - check-and-update: - if: github.repository == 'langgenius/dify' - runs-on: ubuntu-latest - defaults: - run: - working-directory: web - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - token: ${{ secrets.GITHUB_TOKEN }} - - - name: Check for file changes in i18n/en-US - id: check_files - run: | - git fetch origin "${{ github.event.before }}" || true - git fetch origin "${{ github.sha }}" || true - changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.ts') - echo "Changed files: $changed_files" - if [ -n "$changed_files" ]; then - echo "FILES_CHANGED=true" >> $GITHUB_ENV - file_args="" - for file in $changed_files; do - filename=$(basename "$file" .ts) - file_args="$file_args 
--file $filename" - done - echo "FILE_ARGS=$file_args" >> $GITHUB_ENV - echo "File arguments: $file_args" - else - echo "FILES_CHANGED=false" >> $GITHUB_ENV - fi - - - name: Install pnpm - uses: pnpm/action-setup@v4 - with: - package_json_file: web/package.json - run_install: false - - - name: Set up Node.js - if: env.FILES_CHANGED == 'true' - uses: actions/setup-node@v4 - with: - node-version: 'lts/*' - cache: pnpm - cache-dependency-path: ./web/pnpm-lock.yaml - - - name: Install dependencies - if: env.FILES_CHANGED == 'true' - working-directory: ./web - run: pnpm install --frozen-lockfile - - - name: Generate i18n translations - if: env.FILES_CHANGED == 'true' - working-directory: ./web - run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }} - - - name: Create Pull Request - if: env.FILES_CHANGED == 'true' - uses: peter-evans/create-pull-request@v6 - with: - token: ${{ secrets.GITHUB_TOKEN }} - commit-message: 'chore(i18n): update translations based on en-US changes' - title: 'chore(i18n): translate i18n files based on en-US changes' - body: | - This PR was automatically created to update i18n translation files based on changes in en-US locale. - - **Triggered by:** ${{ github.sha }} - - **Changes included:** - - Updated translation files for all locales - branch: chore/automated-i18n-updates-${{ github.sha }} - delete-branch: true diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml new file mode 100644 index 0000000000..5d9440ff35 --- /dev/null +++ b/.github/workflows/translate-i18n-claude.yml @@ -0,0 +1,440 @@ +name: Translate i18n Files with Claude Code + +# Note: claude-code-action doesn't support push events directly. +# Push events are handled by trigger-i18n-sync.yml which sends repository_dispatch. 
+# See: https://github.com/langgenius/dify/issues/30743 + +on: + repository_dispatch: + types: [i18n-sync] + workflow_dispatch: + inputs: + files: + description: 'Specific files to translate (space-separated, e.g., "app common"). Leave empty for all files.' + required: false + type: string + languages: + description: 'Specific languages to translate (space-separated, e.g., "zh-Hans ja-JP"). Leave empty for all supported languages.' + required: false + type: string + mode: + description: 'Sync mode: incremental (only changes) or full (re-check all keys)' + required: false + default: 'incremental' + type: choice + options: + - incremental + - full + +permissions: + contents: write + pull-requests: write + +jobs: + translate: + if: github.repository == 'langgenius/dify' + runs-on: ubuntu-latest + timeout-minutes: 60 + + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Configure Git + run: | + git config --global user.name "github-actions[bot]" + git config --global user.email "github-actions[bot]@users.noreply.github.com" + + - name: Install pnpm + uses: pnpm/action-setup@v4 + with: + package_json_file: web/package.json + run_install: false + + - name: Set up Node.js + uses: actions/setup-node@v6 + with: + node-version: 24 + cache: pnpm + cache-dependency-path: ./web/pnpm-lock.yaml + + - name: Detect changed files and generate diff + id: detect_changes + run: | + if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then + # Manual trigger + if [ -n "${{ github.event.inputs.files }}" ]; then + echo "CHANGED_FILES=${{ github.event.inputs.files }}" >> $GITHUB_OUTPUT + else + # Get all JSON files in en-US directory + files=$(ls web/i18n/en-US/*.json 2>/dev/null | xargs -n1 basename | sed 's/.json$//' | tr '\n' ' ') + echo "CHANGED_FILES=$files" >> $GITHUB_OUTPUT + fi + echo "TARGET_LANGS=${{ github.event.inputs.languages }}" >> $GITHUB_OUTPUT + echo "SYNC_MODE=${{ 
github.event.inputs.mode || 'incremental' }}" >> $GITHUB_OUTPUT + + # For manual trigger with incremental mode, get diff from last commit + # For full mode, we'll do a complete check anyway + if [ "${{ github.event.inputs.mode }}" == "full" ]; then + echo "Full mode: will check all keys" > /tmp/i18n-diff.txt + echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT + else + git diff HEAD~1..HEAD -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt + if [ -s /tmp/i18n-diff.txt ]; then + echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT + else + echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT + fi + fi + elif [ "${{ github.event_name }}" == "repository_dispatch" ]; then + # Triggered by push via trigger-i18n-sync.yml workflow + # Validate required payload fields + if [ -z "${{ github.event.client_payload.changed_files }}" ]; then + echo "Error: repository_dispatch payload missing required 'changed_files' field" >&2 + exit 1 + fi + echo "CHANGED_FILES=${{ github.event.client_payload.changed_files }}" >> $GITHUB_OUTPUT + echo "TARGET_LANGS=" >> $GITHUB_OUTPUT + echo "SYNC_MODE=${{ github.event.client_payload.sync_mode || 'incremental' }}" >> $GITHUB_OUTPUT + + # Decode the base64-encoded diff from the trigger workflow + if [ -n "${{ github.event.client_payload.diff_base64 }}" ]; then + if ! 
echo "${{ github.event.client_payload.diff_base64 }}" | base64 -d > /tmp/i18n-diff.txt 2>&1; then + echo "Warning: Failed to decode base64 diff payload" >&2 + echo "" > /tmp/i18n-diff.txt + echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT + elif [ -s /tmp/i18n-diff.txt ]; then + echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT + else + echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT + fi + else + echo "" > /tmp/i18n-diff.txt + echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT + fi + else + echo "Unsupported event type: ${{ github.event_name }}" + exit 1 + fi + + # Truncate diff if too large (keep first 50KB) + if [ -f /tmp/i18n-diff.txt ]; then + head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt + mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt + fi + + echo "Detected files: $(cat $GITHUB_OUTPUT | grep CHANGED_FILES || echo 'none')" + + - name: Run Claude Code for Translation Sync + if: steps.detect_changes.outputs.CHANGED_FILES != '' + uses: anthropics/claude-code-action@v1 + with: + anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} + github_token: ${{ secrets.GITHUB_TOKEN }} + # Allow github-actions bot to trigger this workflow via repository_dispatch + # See: https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md + allowed_bots: 'github-actions[bot]' + prompt: | + You are a professional i18n synchronization engineer for the Dify project. + Your task is to keep all language translations in sync with the English source (en-US). + + ## CRITICAL TOOL RESTRICTIONS + - Use **Read** tool to read files (NOT cat or bash) + - Use **Edit** tool to modify JSON files (NOT node, jq, or bash scripts) + - Use **Bash** ONLY for: git commands, gh commands, pnpm commands + - Run bash commands ONE BY ONE, never combine with && or || + - NEVER use `$()` command substitution - it's not supported. Split into separate commands instead. + + ## WORKING DIRECTORY & ABSOLUTE PATHS + Claude Code sandbox working directory may vary. 
Always use absolute paths: + - For pnpm: `pnpm --dir ${{ github.workspace }}/web <command>` + - For git: `git -C ${{ github.workspace }} <command>` + - For gh: `gh --repo ${{ github.repository }} <command>` + - For file paths: `${{ github.workspace }}/web/i18n/` + + ## EFFICIENCY RULES + - **ONE Edit per language file** - batch all key additions into a single Edit + - Insert new keys at the beginning of JSON (after `{`), lint:fix will sort them + - Translate ALL keys for a language mentally first, then do ONE Edit + + ## Context + - Changed/target files: ${{ steps.detect_changes.outputs.CHANGED_FILES }} + - Target languages (empty means all supported): ${{ steps.detect_changes.outputs.TARGET_LANGS }} + - Sync mode: ${{ steps.detect_changes.outputs.SYNC_MODE }} + - Translation files are located in: ${{ github.workspace }}/web/i18n/{locale}/(unknown).json + - Language configuration is in: ${{ github.workspace }}/web/i18n-config/languages.ts + - Git diff is available: ${{ steps.detect_changes.outputs.DIFF_AVAILABLE }} + + ## CRITICAL DESIGN: Verify First, Then Sync + + You MUST follow this three-phase approach: + + ═══════════════════════════════════════════════════════════════ + ║ PHASE 1: VERIFY - Analyze and Generate Change Report ║ + ═══════════════════════════════════════════════════════════════ + + ### Step 1.1: Analyze Git Diff (for incremental mode) + Use the Read tool to read `/tmp/i18n-diff.txt` to see the git diff. + + Parse the diff to categorize changes: + - Lines with `+` (not `+++`): Added or modified values + - Lines with `-` (not `---`): Removed or old values + - Identify specific keys for each category: + * ADD: Keys that appear only in `+` lines (new keys) + * UPDATE: Keys that appear in both `-` and `+` lines (value changed) + * DELETE: Keys that appear only in `-` lines (removed keys) + + ### Step 1.2: Read Language Configuration + Use the Read tool to read `${{ github.workspace }}/web/i18n-config/languages.ts`. + Extract all languages with `supported: true`. 
+ + ### Step 1.3: Run i18n:check for Each Language + ```bash + pnpm --dir ${{ github.workspace }}/web install --frozen-lockfile + ``` + ```bash + pnpm --dir ${{ github.workspace }}/web run i18n:check + ``` + + This will report: + - Missing keys (need to ADD) + - Extra keys (need to DELETE) + + ### Step 1.4: Generate Change Report + + Create a structured report identifying: + ``` + ╔══════════════════════════════════════════════════════════════╗ + ║ I18N SYNC CHANGE REPORT ║ + ╠══════════════════════════════════════════════════════════════╣ + ║ Files to process: [list] ║ + ║ Languages to sync: [list] ║ + ╠══════════════════════════════════════════════════════════════╣ + ║ ADD (New Keys): ║ + ║ - [filename].[key]: "English value" ║ + ║ ... ║ + ╠══════════════════════════════════════════════════════════════╣ + ║ UPDATE (Modified Keys - MUST re-translate): ║ + ║ - [filename].[key]: "Old value" → "New value" ║ + ║ ... ║ + ╠══════════════════════════════════════════════════════════════╣ + ║ DELETE (Extra Keys): ║ + ║ - [language]/[filename].[key] ║ + ║ ... ║ + ╚══════════════════════════════════════════════════════════════╝ + ``` + + **IMPORTANT**: For UPDATE detection, compare git diff to find keys where + the English value changed. These MUST be re-translated even if target + language already has a translation (it's now stale!). + + ═══════════════════════════════════════════════════════════════ + ║ PHASE 2: SYNC - Execute Changes Based on Report ║ + ═══════════════════════════════════════════════════════════════ + + ### Step 2.1: Process ADD Operations (BATCH per language file) + + **CRITICAL WORKFLOW for efficiency:** + 1. First, translate ALL new keys for ALL languages mentally + 2. 
Then, for EACH language file, do ONE Edit operation: + - Read the file once + - Insert ALL new keys at the beginning (right after the opening `{`) + - Don't worry about alphabetical order - lint:fix will sort them later + + Example Edit (adding 3 keys to zh-Hans/app.json): + ``` + old_string: '{\n "accessControl"' + new_string: '{\n "newKey1": "translation1",\n "newKey2": "translation2",\n "newKey3": "translation3",\n "accessControl"' + ``` + + **IMPORTANT**: + - ONE Edit per language file (not one Edit per key!) + - Always use the Edit tool. NEVER use bash scripts, node, or jq. + + ### Step 2.2: Process UPDATE Operations + + **IMPORTANT: Special handling for zh-Hans and ja-JP** + If zh-Hans or ja-JP files were ALSO modified in the same push: + - Run: `git -C ${{ github.workspace }} diff HEAD~1 --name-only` and check for zh-Hans or ja-JP files + - If found, it means someone manually translated them. Apply these rules: + + 1. **Missing keys**: Still ADD them (completeness required) + 2. **Existing translations**: Compare with the NEW English value: + - If translation is **completely wrong** or **unrelated** → Update it + - If translation is **roughly correct** (captures the meaning) → Keep it, respect manual work + - When in doubt, **keep the manual translation** + + Example: + - English changed: "Save" → "Save Changes" + - Manual translation: "保存更改" → Keep it (correct meaning) + - Manual translation: "删除" → Update it (completely wrong) + + For other languages: + Use Edit tool to replace the old value with the new translation. + You can batch multiple updates in one Edit if they are adjacent. 
+ + ### Step 2.3: Process DELETE Operations + For extra keys reported by i18n:check: + - Run: `pnpm --dir ${{ github.workspace }}/web run i18n:check --auto-remove` + - Or manually remove from target language JSON files + + ## Translation Guidelines + + - PRESERVE all placeholders exactly as-is: + - `{{variable}}` - Mustache interpolation + - `${variable}` - Template literal + - `<tag>content</tag>` - HTML tags + - `_one`, `_other` - Pluralization suffixes (these are KEY suffixes, not values) + + **CRITICAL: Variable names and tag names MUST stay in English - NEVER translate them** + + ✅ CORRECT examples: + - English: "{{count}} items" → Japanese: "{{count}} 個のアイテム" + - English: "{{name}} updated" → Korean: "{{name}} 업데이트됨" + - English: "{{email}}" → Chinese: "{{email}}" + - English: "Marketplace" → Japanese: "マーケットプレイス" + + ❌ WRONG examples (NEVER do this - will break the application): + - "{{count}}" → "{{カウント}}" ❌ (variable name translated to Japanese) + - "{{name}}" → "{{이름}}" ❌ (variable name translated to Korean) + - "{{email}}" → "{{邮箱}}" ❌ (variable name translated to Chinese) + - "<email>" → "<メール>" ❌ (tag name translated) + - "<CustomLink>" → "<自定义链接>" ❌ (component name translated) + + - Use appropriate language register (formal/informal) based on existing translations + - Match existing translation style in each language + - Technical terms: check existing conventions per language + - For CJK languages: no spaces between characters unless necessary + - For RTL languages (ar-TN, fa-IR): ensure proper text handling + + ## Output Format Requirements + - Alphabetical key ordering (if original file uses it) + - 2-space indentation + - Trailing newline at end of file + - Valid JSON (use proper escaping for special characters) + + ═══════════════════════════════════════════════════════════════ + ║ PHASE 3: RE-VERIFY - Confirm All Issues Resolved ║ + ═══════════════════════════════════════════════════════════════ + + ### Step 3.1: Run Lint Fix (IMPORTANT!) 
+ ```bash + pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- 'i18n/**/*.json' + ``` + This ensures: + - JSON keys are sorted alphabetically (jsonc/sort-keys rule) + - Valid i18n keys (dify-i18n/valid-i18n-keys rule) + - No extra keys (dify-i18n/no-extra-keys rule) + + ### Step 3.2: Run Final i18n Check + ```bash + pnpm --dir ${{ github.workspace }}/web run i18n:check + ``` + + ### Step 3.3: Fix Any Remaining Issues + If check reports issues: + - Go back to PHASE 2 for unresolved items + - Repeat until check passes + + ### Step 3.4: Generate Final Summary + ``` + ╔══════════════════════════════════════════════════════════════╗ + ║ SYNC COMPLETED SUMMARY ║ + ╠══════════════════════════════════════════════════════════════╣ + ║ Language │ Added │ Updated │ Deleted │ Status ║ + ╠══════════════════════════════════════════════════════════════╣ + ║ zh-Hans │ 5 │ 2 │ 1 │ ✓ Complete ║ + ║ ja-JP │ 5 │ 2 │ 1 │ ✓ Complete ║ + ║ ... │ ... │ ... │ ... │ ... ║ + ╠══════════════════════════════════════════════════════════════╣ + ║ i18n:check │ PASSED - All keys in sync ║ + ╚══════════════════════════════════════════════════════════════╝ + ``` + + ## Mode-Specific Behavior + + **SYNC_MODE = "incremental"** (default): + - Focus on keys identified from git diff + - Also check i18n:check output for any missing/extra keys + - Efficient for small changes + + **SYNC_MODE = "full"**: + - Compare ALL keys between en-US and each language + - Run i18n:check to identify all discrepancies + - Use for first-time sync or fixing historical issues + + ## Important Notes + + 1. Always run i18n:check BEFORE and AFTER making changes + 2. The check script is the source of truth for missing/extra keys + 3. For UPDATE scenario: git diff is the source of truth for changed values + 4. Create a single commit with all translation changes + 5. 
If any translation fails, continue with others and report failures + + ═══════════════════════════════════════════════════════════════ + ║ PHASE 4: COMMIT AND CREATE PR ║ + ═══════════════════════════════════════════════════════════════ + + After all translations are complete and verified: + + ### Step 4.1: Check for changes + ```bash + git -C ${{ github.workspace }} status --porcelain + ``` + + If there are changes: + + ### Step 4.2: Create a new branch and commit + Run these git commands ONE BY ONE (not combined with &&). + **IMPORTANT**: Do NOT use `$()` command substitution. Use two separate commands: + + 1. First, get the timestamp: + ```bash + date +%Y%m%d-%H%M%S + ``` + (Note the output, e.g., "20260115-143052") + + 2. Then create branch using the timestamp value: + ```bash + git -C ${{ github.workspace }} checkout -b chore/i18n-sync-20260115-143052 + ``` + (Replace "20260115-143052" with the actual timestamp from step 1) + + 3. Stage changes: + ```bash + git -C ${{ github.workspace }} add web/i18n/ + ``` + + 4. Commit: + ```bash + git -C ${{ github.workspace }} commit -m "chore(i18n): sync translations with en-US - Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}" + ``` + + 5. Push: + ```bash + git -C ${{ github.workspace }} push origin HEAD + ``` + + ### Step 4.3: Create Pull Request + ```bash + gh pr create --repo ${{ github.repository }} --title "chore(i18n): sync translations with en-US" --body "## Summary + + This PR was automatically generated to sync i18n translation files. 
+ + ### Changes + - Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }} + - Files processed: ${{ steps.detect_changes.outputs.CHANGED_FILES }} + + ### Verification + - [x] \`i18n:check\` passed + - [x] \`lint:fix\` applied + + 🤖 Generated with Claude Code GitHub Action" --base main + ``` + + claude_args: | + --max-turns 150 + --allowedTools "Read,Write,Edit,Bash(git *),Bash(git:*),Bash(gh *),Bash(gh:*),Bash(pnpm *),Bash(pnpm:*),Bash(date *),Bash(date:*),Glob,Grep" diff --git a/.github/workflows/trigger-i18n-sync.yml b/.github/workflows/trigger-i18n-sync.yml new file mode 100644 index 0000000000..66a29453b4 --- /dev/null +++ b/.github/workflows/trigger-i18n-sync.yml @@ -0,0 +1,66 @@ +name: Trigger i18n Sync on Push + +# This workflow bridges the push event to repository_dispatch +# because claude-code-action doesn't support push events directly. +# See: https://github.com/langgenius/dify/issues/30743 + +on: + push: + branches: [main] + paths: + - 'web/i18n/en-US/*.json' + +permissions: + contents: write + +jobs: + trigger: + if: github.repository == 'langgenius/dify' + runs-on: ubuntu-latest + timeout-minutes: 5 + + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: Detect changed files and generate diff + id: detect + run: | + BEFORE_SHA="${{ github.event.before }}" + # Handle edge case: force push may have null/zero SHA + if [ -z "$BEFORE_SHA" ] || [ "$BEFORE_SHA" = "0000000000000000000000000000000000000000" ]; then + BEFORE_SHA="HEAD~1" + fi + + # Detect changed i18n files + changed=$(git diff --name-only "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' 2>/dev/null | xargs -n1 basename 2>/dev/null | sed 's/.json$//' | tr '\n' ' ' || echo "") + echo "changed_files=$changed" >> $GITHUB_OUTPUT + + # Generate diff for context + git diff "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt + + # Truncate if too large (keep first 50KB 
to match receiving workflow) + head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt + mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt + + # Base64 encode the diff for safe JSON transport (portable, single-line) + diff_base64=$(base64 < /tmp/i18n-diff.txt | tr -d '\n') + echo "diff_base64=$diff_base64" >> $GITHUB_OUTPUT + + if [ -n "$changed" ]; then + echo "has_changes=true" >> $GITHUB_OUTPUT + echo "Detected changed files: $changed" + else + echo "has_changes=false" >> $GITHUB_OUTPUT + echo "No i18n changes detected" + fi + + - name: Trigger i18n sync workflow + if: steps.detect.outputs.has_changes == 'true' + uses: peter-evans/repository-dispatch@v3 + with: + token: ${{ secrets.GITHUB_TOKEN }} + event-type: i18n-sync + client-payload: '{"changed_files": "${{ steps.detect.outputs.changed_files }}", "diff_base64": "${{ steps.detect.outputs.diff_base64 }}", "sync_mode": "incremental", "trigger_sha": "${{ github.sha }}"}' diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index 291171e5c7..7735afdaca 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -19,19 +19,19 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: persist-credentials: false - name: Free Disk Space - uses: endersonmenezes/free-disk-space@v2 + uses: endersonmenezes/free-disk-space@v3 with: remove_dotnet: true remove_haskell: true remove_tool_cache: true - name: Setup UV and Python - uses: astral-sh/setup-uv@v6 + uses: astral-sh/setup-uv@v7 with: enable-cache: true python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index 1a8925e38d..191ce56aaa 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: persist-credentials: false @@ -29,9 +29,9 @@ jobs: run_install: false - 
name: Setup Node.js - uses: actions/setup-node@v4 + uses: actions/setup-node@v6 with: - node-version: 22 + node-version: 24 cache: pnpm cache-dependency-path: ./web/pnpm-lock.yaml @@ -360,9 +360,54 @@ jobs: - name: Upload Coverage Artifact if: steps.coverage-summary.outputs.has_coverage == 'true' - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: web-coverage-report path: web/coverage retention-days: 30 if-no-files-found: error + + web-build: + name: Web Build + runs-on: ubuntu-latest + defaults: + run: + working-directory: ./web + + steps: + - name: Checkout code + uses: actions/checkout@v6 + with: + persist-credentials: false + + - name: Check changed files + id: changed-files + uses: tj-actions/changed-files@v47 + with: + files: | + web/** + .github/workflows/web-tests.yml + + - name: Install pnpm + uses: pnpm/action-setup@v4 + with: + package_json_file: web/package.json + run_install: false + + - name: Setup NodeJS + uses: actions/setup-node@v6 + if: steps.changed-files.outputs.any_changed == 'true' + with: + node-version: 24 + cache: pnpm + cache-dependency-path: ./web/pnpm-lock.yaml + + - name: Web dependencies + if: steps.changed-files.outputs.any_changed == 'true' + working-directory: ./web + run: pnpm install --frozen-lockfile + + - name: Web build check + if: steps.changed-files.outputs.any_changed == 'true' + working-directory: ./web + run: pnpm run build diff --git a/.gitignore b/.gitignore index 17a2bd5b7b..7bd919f095 100644 --- a/.gitignore +++ b/.gitignore @@ -235,3 +235,4 @@ scripts/stress-test/reports/ # settings *.local.json +*.local.md diff --git a/.mcp.json b/.mcp.json deleted file mode 100644 index 8eceaf9ead..0000000000 --- a/.mcp.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "mcpServers": { - "context7": { - "type": "http", - "url": "https://mcp.context7.com/mcp" - }, - "sequential-thinking": { - "type": "stdio", - "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-sequential-thinking"], - "env": {} - 
}, - "github": { - "type": "stdio", - "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-github"], - "env": { - "GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}" - } - }, - "fetch": { - "type": "stdio", - "command": "uvx", - "args": ["mcp-server-fetch"], - "env": {} - }, - "playwright": { - "type": "stdio", - "command": "npx", - "args": ["-y", "@playwright/mcp@latest"], - "env": {} - } - } - } \ No newline at end of file diff --git a/.nvmrc b/.nvmrc deleted file mode 100644 index 7af24b7ddb..0000000000 --- a/.nvmrc +++ /dev/null @@ -1 +0,0 @@ -22.11.0 diff --git a/AGENTS.md b/AGENTS.md index 782861ad36..7d96ac3a6d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -12,12 +12,8 @@ The codebase is split into: ## Backend Workflow +- Read `api/AGENTS.md` for details - Run backend CLI commands through `uv run --project api `. - -- Before submission, all backend modifications must pass local checks: `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`. - -- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks. - - Integration tests are CI-only and are not expected to run in the local environment. ## Frontend Workflow @@ -29,6 +25,30 @@ pnpm type-check:tsgo pnpm test ``` +### Frontend Linting + +ESLint is used for frontend code quality. Available commands: + +```bash +# Lint all files (report only) +pnpm lint + +# Lint and auto-fix issues +pnpm lint:fix + +# Lint specific files or directories +pnpm lint:fix app/components/base/button/ +pnpm lint:fix app/components/base/button/index.tsx + +# Lint quietly (errors only, no warnings) +pnpm lint:quiet + +# Check code complexity +pnpm lint:complexity +``` + +**Important**: Always run `pnpm lint:fix` before committing. The pre-commit hook runs `lint-staged` which only lints staged files. + ## Testing & Quality Practices - Follow TDD: red → green → refactor. 
diff --git a/Makefile b/Makefile index 07afd8187e..e92a7b1314 100644 --- a/Makefile +++ b/Makefile @@ -60,9 +60,11 @@ check: @echo "✅ Code check complete" lint: - @echo "🔧 Running ruff format, check with fixes, and import linter..." - @uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api' + @echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..." + @uv run --project api --dev ruff format ./api + @uv run --project api --dev ruff check --fix ./api @uv run --directory api --dev lint-imports + @uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example @echo "✅ Linting complete" type-check: @@ -72,7 +74,12 @@ type-check: test: @echo "🧪 Running backend unit tests..." - @uv run --project api --dev dev/pytest/pytest_unit_tests.sh + @if [ -n "$(TARGET_TESTS)" ]; then \ + echo "Target: $(TARGET_TESTS)"; \ + uv run --project api --dev pytest $(TARGET_TESTS); \ + else \ + uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \ + fi @echo "✅ Tests complete" # Build Docker images @@ -122,9 +129,9 @@ help: @echo "Backend Code Quality:" @echo " make format - Format code with ruff" @echo " make check - Check code with ruff" - @echo " make lint - Format and fix code with ruff" + @echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)" @echo " make type-check - Run type checking with basedpyright" - @echo " make test - Run backend unit tests" + @echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/)" @echo "" @echo "Docker Build Targets:" @echo " make build-web - Build web Docker image" diff --git a/api/.env.example b/api/.env.example index c195f1ae87..bf998a6cdc 100644 --- a/api/.env.example +++ b/api/.env.example @@ -101,6 +101,15 @@ S3_ACCESS_KEY=your-access-key S3_SECRET_KEY=your-secret-key S3_REGION=your-region +# Workflow run and Conversation archive storage (S3-compatible) +ARCHIVE_STORAGE_ENABLED=false +ARCHIVE_STORAGE_ENDPOINT= +ARCHIVE_STORAGE_ARCHIVE_BUCKET= 
+ARCHIVE_STORAGE_EXPORT_BUCKET= +ARCHIVE_STORAGE_ACCESS_KEY= +ARCHIVE_STORAGE_SECRET_KEY= +ARCHIVE_STORAGE_REGION=auto + # Azure Blob Storage configuration AZURE_BLOB_ACCOUNT_NAME=your-account-name AZURE_BLOB_ACCOUNT_KEY=your-account-key @@ -128,6 +137,7 @@ TENCENT_COS_SECRET_KEY=your-secret-key TENCENT_COS_SECRET_ID=your-secret-id TENCENT_COS_REGION=your-region TENCENT_COS_SCHEME=your-scheme +TENCENT_COS_CUSTOM_DOMAIN=your-custom-domain # Huawei OBS Storage Configuration HUAWEI_OBS_BUCKET_NAME=your-bucket-name @@ -407,6 +417,8 @@ SMTP_USERNAME=123 SMTP_PASSWORD=abc SMTP_USE_TLS=true SMTP_OPPORTUNISTIC_TLS=false +# Optional: override the local hostname used for SMTP HELO/EHLO +SMTP_LOCAL_HOSTNAME= # Sendgid configuration SENDGRID_API_KEY= # Sentry configuration @@ -492,6 +504,8 @@ LOG_FILE_BACKUP_COUNT=5 LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S # Log Timezone LOG_TZ=UTC +# Log output format: text or json +LOG_OUTPUT_FORMAT=text # Log format LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s @@ -563,6 +577,10 @@ LOGSTORE_DUAL_WRITE_ENABLED=false # Enable dual-read fallback to SQL database when LogStore returns no results (default: true) # Useful for migration scenarios where historical data exists only in SQL database LOGSTORE_DUAL_READ_ENABLED=true +# Control flag for whether to write the `graph` field to LogStore. +# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; +# otherwise write an empty {} instead. Defaults to writing the `graph` field. 
+LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true # Celery beat configuration CELERY_BEAT_SCHEDULER_TIME=1 @@ -573,6 +591,7 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false ENABLE_CREATE_TIDB_SERVERLESS_TASK=false ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false ENABLE_CLEAN_MESSAGES=false +ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false ENABLE_DATASETS_QUEUE_MONITOR=false ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true @@ -721,3 +740,7 @@ PUBSUB_REDIS_USE_CLUSTERS=false ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true # Human input timeout check interval in minutes HUMAN_INPUT_TIMEOUT_TASK_INTERVAL=1 + + +SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000 + diff --git a/api/.importlinter b/api/.importlinter index c99c72c5e3..d7f3767442 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -3,9 +3,11 @@ root_packages = core configs controllers + extensions models tasks services +include_external_packages = True [importlinter:contract:workflow] name = Workflow @@ -25,7 +27,9 @@ ignore_imports = core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events core.workflow.nodes.loop.loop_node -> core.workflow.graph_events - core.workflow.nodes.node_factory -> core.workflow.graph + core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory + core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory + core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine core.workflow.nodes.iteration.iteration_node -> core.workflow.graph core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels @@ -35,6 +39,270 @@ ignore_imports = # TODO(QuantumGhost): fix the import violation later core.workflow.entities.pause_reason -> core.workflow.nodes.human_input.entities +[importlinter:contract:workflow-infrastructure-dependencies] +name = Workflow Infrastructure Dependencies +type = forbidden +source_modules = + core.workflow +forbidden_modules = + extensions.ext_database + extensions.ext_redis 
+allow_indirect_imports = True +ignore_imports = + core.workflow.nodes.agent.agent_node -> extensions.ext_database + core.workflow.nodes.datasource.datasource_node -> extensions.ext_database + core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database + core.workflow.nodes.llm.file_saver -> extensions.ext_database + core.workflow.nodes.llm.llm_utils -> extensions.ext_database + core.workflow.nodes.llm.node -> extensions.ext_database + core.workflow.nodes.tool.tool_node -> extensions.ext_database + core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis + core.workflow.graph_engine.manager -> extensions.ext_redis + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis + +[importlinter:contract:workflow-external-imports] +name = Workflow External Imports +type = forbidden +source_modules = + core.workflow +forbidden_modules = + configs + controllers + extensions + models + services + tasks + core.agent + core.app + core.base + core.callback_handler + core.datasource + core.db + core.entities + core.errors + core.extension + core.external_data_tool + core.file + core.helper + core.hosting_configuration + core.indexing_runner + core.llm_generator + core.logging + core.mcp + core.memory + core.model_manager + core.moderation + core.ops + core.plugin + core.prompt + core.provider_manager + core.rag + core.repositories + core.schemas + core.tools + core.trigger + core.variables +ignore_imports = + core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory + core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis + core.workflow.workflow_entry -> core.app.workflow.layers.observability + core.workflow.graph_engine.worker_management.worker_pool -> configs + core.workflow.nodes.agent.agent_node -> core.model_manager + core.workflow.nodes.agent.agent_node -> 
core.provider_manager + core.workflow.nodes.agent.agent_node -> core.tools.tool_manager + core.workflow.nodes.code.code_node -> core.helper.code_executor.code_executor + core.workflow.nodes.datasource.datasource_node -> models.model + core.workflow.nodes.datasource.datasource_node -> models.tools + core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service + core.workflow.nodes.document_extractor.node -> configs + core.workflow.nodes.document_extractor.node -> core.file.file_manager + core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy + core.workflow.nodes.http_request.entities -> configs + core.workflow.nodes.http_request.executor -> configs + core.workflow.nodes.http_request.executor -> core.file.file_manager + core.workflow.nodes.http_request.node -> configs + core.workflow.nodes.http_request.node -> core.tools.tool_file_manager + core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory + core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.datasource.retrieval_service + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.dataset_retrieval + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> models.dataset + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> services.feature_service + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_runtime.model_providers.__base.large_language_model + core.workflow.nodes.llm.llm_utils -> configs + core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities + core.workflow.nodes.llm.llm_utils -> core.file.models + core.workflow.nodes.llm.llm_utils -> core.model_manager + core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model + core.workflow.nodes.llm.llm_utils -> 
models.model + core.workflow.nodes.llm.llm_utils -> models.provider + core.workflow.nodes.llm.llm_utils -> services.credit_pool_service + core.workflow.nodes.llm.node -> core.tools.signature + core.workflow.nodes.template_transform.template_transform_node -> configs + core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler + core.workflow.nodes.tool.tool_node -> core.tools.tool_engine + core.workflow.nodes.tool.tool_node -> core.tools.tool_manager + core.workflow.workflow_entry -> configs + core.workflow.workflow_entry -> models.workflow + core.workflow.nodes.agent.agent_node -> core.agent.entities + core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities + core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities + core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities + core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model + core.workflow.nodes.question_classifier.question_classifier_node -> core.app.entities.app_invoke_entities + core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.advanced_prompt_transform + core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform + core.workflow.nodes.start.entities -> 
core.app.app_config.entities + core.workflow.nodes.start.start_node -> core.app.app_config.entities + core.workflow.workflow_entry -> core.app.apps.exc + core.workflow.workflow_entry -> core.app.entities.app_invoke_entities + core.workflow.workflow_entry -> core.app.workflow.node_factory + core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager + core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.agent_entities + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.model_entities + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_manager + core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager + core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager + core.workflow.node_events.node -> core.file + core.workflow.nodes.agent.agent_node -> core.file + core.workflow.nodes.datasource.datasource_node -> core.file + core.workflow.nodes.datasource.datasource_node -> core.file.enums + core.workflow.nodes.document_extractor.node -> core.file + core.workflow.nodes.http_request.executor -> core.file.enums + core.workflow.nodes.http_request.node -> core.file + core.workflow.nodes.http_request.node -> core.file.file_manager + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.file.models + core.workflow.nodes.list_operator.node -> core.file + core.workflow.nodes.llm.file_saver -> core.file + core.workflow.nodes.llm.llm_utils -> core.variables.segments + core.workflow.nodes.llm.node -> core.file + core.workflow.nodes.llm.node -> core.file.file_manager + core.workflow.nodes.llm.node -> core.file.models + core.workflow.nodes.loop.entities -> core.variables.types + 
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.file + core.workflow.nodes.protocols -> core.file + core.workflow.nodes.question_classifier.question_classifier_node -> core.file.models + core.workflow.nodes.tool.tool_node -> core.file + core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer + core.workflow.nodes.tool.tool_node -> models + core.workflow.nodes.trigger_webhook.node -> core.file + core.workflow.runtime.variable_pool -> core.file + core.workflow.runtime.variable_pool -> core.file.file_manager + core.workflow.system_variable -> core.file.models + core.workflow.utils.condition.processor -> core.file + core.workflow.utils.condition.processor -> core.file.file_manager + core.workflow.workflow_entry -> core.file.models + core.workflow.workflow_type_encoder -> core.file.models + core.workflow.nodes.agent.agent_node -> models.model + core.workflow.nodes.code.code_node -> core.helper.code_executor.code_node_provider + core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider + core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider + core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor + core.workflow.nodes.datasource.datasource_node -> core.variables.variables + core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy + core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy + core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy + core.workflow.nodes.llm.node -> core.helper.code_executor + core.workflow.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor + core.workflow.nodes.llm.node -> core.llm_generator.output_parser.errors + core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output + core.workflow.nodes.llm.node -> core.model_manager + core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities + 
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform + core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util + core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util + core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util + core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods + core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods + core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods + core.workflow.nodes.llm.node -> models.dataset + core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer + core.workflow.nodes.llm.file_saver -> core.tools.signature + core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager + core.workflow.nodes.tool.tool_node -> core.tools.errors + core.workflow.conversation_variable_updater -> core.variables + core.workflow.graph_engine.entities.commands -> core.variables.variables + core.workflow.nodes.agent.agent_node -> core.variables.segments + core.workflow.nodes.answer.answer_node -> core.variables + core.workflow.nodes.code.code_node -> core.variables.segments + 
core.workflow.nodes.code.code_node -> core.variables.types + core.workflow.nodes.code.entities -> core.variables.types + core.workflow.nodes.datasource.datasource_node -> core.variables.segments + core.workflow.nodes.document_extractor.node -> core.variables + core.workflow.nodes.document_extractor.node -> core.variables.segments + core.workflow.nodes.http_request.executor -> core.variables.segments + core.workflow.nodes.http_request.node -> core.variables.segments + core.workflow.nodes.iteration.iteration_node -> core.variables + core.workflow.nodes.iteration.iteration_node -> core.variables.segments + core.workflow.nodes.iteration.iteration_node -> core.variables.variables + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables.segments + core.workflow.nodes.list_operator.node -> core.variables + core.workflow.nodes.list_operator.node -> core.variables.segments + core.workflow.nodes.llm.node -> core.variables + core.workflow.nodes.loop.loop_node -> core.variables + core.workflow.nodes.parameter_extractor.entities -> core.variables.types + core.workflow.nodes.parameter_extractor.exc -> core.variables.types + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.variables.types + core.workflow.nodes.tool.tool_node -> core.variables.segments + core.workflow.nodes.tool.tool_node -> core.variables.variables + core.workflow.nodes.trigger_webhook.node -> core.variables.types + core.workflow.nodes.trigger_webhook.node -> core.variables.variables + core.workflow.nodes.variable_aggregator.entities -> core.variables.types + core.workflow.nodes.variable_aggregator.variable_aggregator_node -> core.variables.segments + core.workflow.nodes.variable_assigner.common.helpers -> core.variables + core.workflow.nodes.variable_assigner.common.helpers -> core.variables.consts + core.workflow.nodes.variable_assigner.common.helpers -> core.variables.types + 
core.workflow.nodes.variable_assigner.v1.node -> core.variables + core.workflow.nodes.variable_assigner.v2.helpers -> core.variables + core.workflow.nodes.variable_assigner.v2.node -> core.variables + core.workflow.nodes.variable_assigner.v2.node -> core.variables.consts + core.workflow.runtime.graph_runtime_state_protocol -> core.variables.segments + core.workflow.runtime.read_only_wrappers -> core.variables.segments + core.workflow.runtime.variable_pool -> core.variables + core.workflow.runtime.variable_pool -> core.variables.consts + core.workflow.runtime.variable_pool -> core.variables.segments + core.workflow.runtime.variable_pool -> core.variables.variables + core.workflow.utils.condition.processor -> core.variables + core.workflow.utils.condition.processor -> core.variables.segments + core.workflow.variable_loader -> core.variables + core.workflow.variable_loader -> core.variables.consts + core.workflow.workflow_type_encoder -> core.variables + core.workflow.graph_engine.manager -> extensions.ext_redis + core.workflow.nodes.agent.agent_node -> extensions.ext_database + core.workflow.nodes.datasource.datasource_node -> extensions.ext_database + core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis + core.workflow.nodes.llm.file_saver -> extensions.ext_database + core.workflow.nodes.llm.llm_utils -> extensions.ext_database + core.workflow.nodes.llm.node -> extensions.ext_database + core.workflow.nodes.tool.tool_node -> extensions.ext_database + core.workflow.workflow_entry -> extensions.otel.runtime + core.workflow.nodes.agent.agent_node -> models + core.workflow.nodes.base.node -> models.enums + core.workflow.nodes.llm.llm_utils -> models.provider_ids + core.workflow.nodes.llm.node -> models.model + core.workflow.workflow_entry -> models.enums + 
core.workflow.nodes.agent.agent_node -> services + core.workflow.nodes.tool.tool_node -> services + [importlinter:contract:rsc] name = RSC type = layers diff --git a/api/.ruff.toml b/api/.ruff.toml index 7206f7fa0f..8db0cbcb21 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -1,4 +1,8 @@ -exclude = ["migrations/*"] +exclude = [ + "migrations/*", + ".git", + ".git/**", +] line-length = 120 [format] diff --git a/api/AGENTS.md b/api/AGENTS.md index 17398ec4b8..13adb42276 100644 --- a/api/AGENTS.md +++ b/api/AGENTS.md @@ -1,62 +1,186 @@ -# Agent Skill Index +# API Agent Guide -Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it. +## Notes for Agent (must-check) -______________________________________________________________________ +Before changing any backend code under `api/`, you MUST read the surrounding docstrings and comments. These notes contain required context (invariants, edge cases, trade-offs) and are treated as part of the spec. -## Platform Foundations +Look for: -- **[Infrastructure Overview](agent_skills/infra.md)**\ - When to read this: +- The module (file) docstring at the top of a source code file +- Docstrings on classes and functions/methods +- Paragraph/block comments for non-obvious logic - - You need to understand where a feature belongs in the architecture. - - You’re wiring storage, Redis, vector stores, or OTEL. - - You’re about to add CLI commands or async jobs.\ - What it covers: configuration stack (`configs/app_config.py`, remote settings), storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`), Redis conventions (`extensions/ext_redis.py`), plugin runtime topology, vector-store factory (`core/rag/datasource/vdb/*`), observability hooks, SSRF proxy usage, and core CLI commands. 
+### What to write where -- **[Coding Style](agent_skills/coding_style.md)**\ - When to read this: +- Keep notes scoped: module notes cover module-wide context, class notes cover class-wide context, function/method notes cover behavioural contracts, and paragraph/block comments cover local “why”. Avoid duplicating the same content across scopes unless repetition prevents misuse. +- **Module (file) docstring**: purpose, boundaries, key invariants, and “gotchas” that a new reader must know before editing. + - Include cross-links to the key collaborators (modules/services) when discovery is otherwise hard. + - Prefer stable facts (invariants, contracts) over ephemeral “today we…” notes. +- **Class docstring**: responsibility, lifecycle, invariants, and how it should be used (or not used). + - If the class is intentionally stateful, note what state exists and what methods mutate it. + - If concurrency/async assumptions matter, state them explicitly. +- **Function/method docstring**: behavioural contract. + - Document arguments, return shape, side effects (DB writes, external I/O, task dispatch), and raised domain exceptions. + - Add examples only when they prevent misuse. +- **Paragraph/block comments**: explain *why* (trade-offs, historical constraints, surprising edge cases), not what the code already states. + - Keep comments adjacent to the logic they justify; delete or rewrite comments that no longer match reality. - - You’re writing or reviewing backend code and need the authoritative checklist. - - You’re unsure about Pydantic validators, SQLAlchemy session usage, or logging patterns. - - You want the exact lint/type/test commands used in PRs.\ - Includes: Ruff & BasedPyright commands, no-annotation policy, session examples (`with Session(db.engine, ...)`), `@field_validator` usage, logging expectations, and the rule set for file size, helpers, and package management. 
+### Rules (must follow) -______________________________________________________________________ +In this section, “notes” means module/class/function docstrings plus any relevant paragraph/block comments. -## Plugin & Extension Development +- **Before working** + - Read the notes in the area you’ll touch; treat them as part of the spec. + - If a docstring or comment conflicts with the current code, treat the **code as the single source of truth** and update the docstring or comment to match reality. + - If important intent/invariants/edge cases are missing, add them in the closest docstring or comment (module for overall scope, function for behaviour). +- **During working** + - Keep the notes in sync as you discover constraints, make decisions, or change approach. + - If you move/rename responsibilities across modules/classes, update the affected docstrings and comments so readers can still find the “why” and the invariants. + - Record non-obvious edge cases, trade-offs, and the test/verification plan in the nearest docstring or comment that will stay correct. + - Keep the notes **coherent**: integrate new findings into the relevant docstrings and comments; avoid append-only “recent fix” / changelog-style additions. +- **When finishing** + - Update the notes to reflect what changed, why, and any new edge cases/tests. + - Remove or rewrite any comments that could be mistaken as current guidance but no longer apply. + - Keep docstrings and comments concise and accurate; they are meant to prevent repeated rediscovery. -- **[Plugin Systems](agent_skills/plugin.md)**\ - When to read this: +## Coding Style - - You’re building or debugging a marketplace plugin. 
- - You need to know how manifests, providers, daemons, and migrations fit together.\ - What it covers: plugin manifests (`core/plugin/entities/plugin.py`), installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands), runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent), daemon coordination (`core/plugin/entities/plugin_daemon.py`), and how provider registries surface capabilities to the rest of the platform. +This is the default standard for backend code in this repo. Follow it for new code and use it as the checklist when reviewing changes. -- **[Plugin OAuth](agent_skills/plugin_oauth.md)**\ - When to read this: +### Linting & Formatting - - You must integrate OAuth for a plugin or datasource. - - You’re handling credential encryption or refresh flows.\ - Topics: credential storage, encryption helpers (`core/helper/provider_encryption.py`), OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`), and how console/API layers expose the flows. +- Use Ruff for formatting and linting (follow `.ruff.toml`). +- Keep each line under 120 characters (including spaces). -______________________________________________________________________ +### Naming Conventions -## Workflow Entry & Execution +- Use `snake_case` for variables and functions. +- Use `PascalCase` for classes. +- Use `UPPER_CASE` for constants. -- **[Trigger Concepts](agent_skills/trigger.md)**\ - When to read this: - - You’re debugging why a workflow didn’t start. - - You’re adding a new trigger type or hook. - - You need to trace async execution, draft debugging, or webhook/schedule pipelines.\ - Details: Start-node taxonomy, webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`), async orchestration (`services/async_workflow_service.py`, Celery queues), debug event bus, and storage/logging interactions. 
+### Typing & Class Layout -______________________________________________________________________ +- Code should usually include type annotations that match the repo’s current Python version (avoid untyped public APIs and “mystery” values). +- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless there’s a strong reason. +- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance: -## Additional Notes for Agents +```python +from datetime import datetime -- All skill docs assume you follow the coding style guide—run Ruff/BasedPyright/tests listed there before submitting changes. -- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`). -- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules. -- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`. -- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently. + +class Example: + user_id: str + created_at: datetime + + def __init__(self, user_id: str, created_at: datetime) -> None: + self.user_id = user_id + self.created_at = created_at +``` + +### General Rules + +- Use Pydantic v2 conventions. +- Use `uv` for Python package management in this repo (usually with `--project api`). +- Prefer simple functions over small “utility classes” for lightweight helpers. +- Avoid implementing dunder methods unless it’s clearly needed and matches existing patterns. +- Never start long-running services as part of agent work (`uv run app.py`, `flask run`, etc.); running tests is allowed. +- Keep files below ~800 lines; split when necessary. 
+- Keep code readable and explicit—avoid clever hacks.
+
+### Architecture & Boundaries
+
+- Mirror the layered architecture: controller → service → core/domain.
+- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
+- Optimise for observability: deterministic control flow, clear logging, actionable errors.
+
+### Logging & Errors
+
+- Never use `print`; use a module-level logger:
+  - `logger = logging.getLogger(__name__)`
+- Include tenant/app/workflow identifiers in log context when relevant.
+- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate them into HTTP responses in controllers.
+- Log retryable events at `warning`, terminal failures at `error`.
+
+### SQLAlchemy Patterns
+
+- Models inherit from `models.base.TypeBase`; do not create ad-hoc metadata or engines.
+- Open sessions with context managers:
+
+```python
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+with Session(db.engine, expire_on_commit=False) as session:
+    stmt = select(Workflow).where(
+        Workflow.id == workflow_id,
+        Workflow.tenant_id == tenant_id,
+    )
+    workflow = session.execute(stmt).scalar_one_or_none()
+```
+
+- Prefer SQLAlchemy expressions; avoid raw SQL unless necessary.
+- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
+- Introduce repository abstractions only for very large tables (e.g., workflow executions) or when alternative storage strategies are required.
+
+### Storage & External I/O
+
+- Access storage via `extensions.ext_storage.storage`.
+- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
+- Background tasks that touch storage must be idempotent and should log relevant object identifiers.
+
+### Pydantic Usage
+
+- Define DTOs with Pydantic v2 models and forbid extras by default.
+- Use `@field_validator` / `@model_validator` for domain rules.
+
+Example:
+
+```python
+from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
+
+
+class TriggerConfig(BaseModel):
+    endpoint: HttpUrl
+    secret: str
+
+    model_config = ConfigDict(extra="forbid")
+
+    @field_validator("secret")
+    @classmethod
+    def ensure_secret_prefix(cls, value: str) -> str:
+        if not value.startswith("dify_"):
+            raise ValueError("secret must start with dify_")
+        return value
+```
+
+### Generics & Protocols
+
+- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
+- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
+- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
+
+### Tooling & Checks
+
+Quick checks while iterating:
+
+- Format: `make format`
+- Lint (includes auto-fix): `make lint`
+- Type check: `make type-check`
+- Targeted tests: `make test TARGET_TESTS=./api/tests/`
+
+Before opening a PR / submitting:
+
+- `make lint`
+- `make type-check`
+- `make test`
+
+### Controllers & Services
+
+- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
+- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
+- Document non-obvious behaviour with concise docstrings and comments.
+
+### Miscellaneous
+
+- Use `configs.dify_config` for configuration—never read environment variables directly.
+- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
+- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
+- Keep experimental scripts under `dev/`; do not ship them in production builds.
diff --git a/api/Dockerfile b/api/Dockerfile index 02df91bfc1..a08d4e3aab 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -50,16 +50,33 @@ WORKDIR /app/api # Create non-root user ARG dify_uid=1001 +ARG NODE_MAJOR=22 +ARG NODE_PACKAGE_VERSION=22.21.0-1nodesource1 +ARG NODESOURCE_KEY_FPR=6F71F525282841EEDAF851B42F59B5F99B1BE0B4 RUN groupadd -r -g ${dify_uid} dify && \ useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \ chown -R dify:dify /app RUN \ apt-get update \ + && apt-get install -y --no-install-recommends \ + ca-certificates \ + curl \ + gnupg \ + && mkdir -p /etc/apt/keyrings \ + && curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key -o /tmp/nodesource.gpg \ + && gpg --show-keys --with-colons /tmp/nodesource.gpg \ + | awk -F: '/^fpr:/ {print $10}' \ + | grep -Fx "${NODESOURCE_KEY_FPR}" \ + && gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg /tmp/nodesource.gpg \ + && rm -f /tmp/nodesource.gpg \ + && echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" \ + > /etc/apt/sources.list.d/nodesource.list \ + && apt-get update \ # Install dependencies && apt-get install -y --no-install-recommends \ # basic environment - curl nodejs \ + nodejs=${NODE_PACKAGE_VERSION} \ # for gmpy2 \ libgmp-dev libmpfr-dev libmpc-dev \ # For Security @@ -79,7 +96,8 @@ COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV} ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" # Download nltk data -RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \ +RUN mkdir -p /usr/local/share/nltk_data \ + && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \ && 
chmod -R 755 /usr/local/share/nltk_data ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache diff --git a/api/README.md b/api/README.md index 794b05d3af..9d89b490b0 100644 --- a/api/README.md +++ b/api/README.md @@ -1,6 +1,6 @@ # Dify Backend API -## Usage +## Setup and Run > [!IMPORTANT] > @@ -8,48 +8,77 @@ > [`uv`](https://docs.astral.sh/uv/) as the package manager > for Dify API backend service. -1. Start the docker-compose stack +`uv` and `pnpm` are required to run the setup and development commands below. - The backend require some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`. +### Using scripts (recommended) + +The scripts resolve paths relative to their location, so you can run them from anywhere. + +1. Run setup (copies env files and installs dependencies). ```bash - cd ../docker - cp middleware.env.example middleware.env - # change the profile to mysql if you are not using postgres,change the profile to other vector database if you are not using weaviate - docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d - cd ../api + ./dev/setup ``` -1. Copy `.env.example` to `.env` +1. Review `api/.env`, `web/.env.local`, and `docker/middleware.env` values (see the `SECRET_KEY` note below). - ```cli - cp .env.example .env +1. Start middleware (PostgreSQL/Redis/Weaviate). + + ```bash + ./dev/start-docker-compose ``` -> [!IMPORTANT] -> -> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies. +1. Start backend (runs migrations first). -1. Generate a `SECRET_KEY` in the `.env` file. - - bash for Linux - - ```bash for Linux - sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env + ```bash + ./dev/start-api ``` - bash for Mac +1. Start Dify [web](../web) service. 
- ```bash for Mac - secret_key=$(openssl rand -base64 42) - sed -i '' "/^SECRET_KEY=/c\\ - SECRET_KEY=${secret_key}" .env + ```bash + ./dev/start-web ``` -1. Create environment. +1. Set up your application by visiting `http://localhost:3000`. - Dify API service uses [UV](https://docs.astral.sh/uv/) to manage dependencies. - First, you need to add the uv package manager, if you don't have it already. +1. Optional: start the worker service (async tasks, runs from `api`). + + ```bash + ./dev/start-worker + ``` + +1. Optional: start Celery Beat (scheduled tasks). + + ```bash + ./dev/start-beat + ``` + +### Manual commands + +
+Show manual setup and run steps + +These commands assume you start from the repository root. + +1. Start the docker-compose stack. + + The backend requires middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`. + + ```bash + cp docker/middleware.env.example docker/middleware.env + # Use mysql or another vector database profile if you are not using postgres/weaviate. + docker compose -f docker/docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d + ``` + +1. Copy env files. + + ```bash + cp api/.env.example api/.env + cp web/.env.example web/.env.local + ``` + +1. Install UV if needed. ```bash pip install uv @@ -57,60 +86,96 @@ brew install uv ``` -1. Install dependencies +1. Install API dependencies. ```bash - uv sync --dev + cd api + uv sync --group dev ``` -1. Run migrate - - Before the first launch, migrate the database to the latest version. +1. Install web dependencies. ```bash + cd web + pnpm install + cd .. + ``` + +1. Start backend (runs migrations first, in a new terminal). + + ```bash + cd api uv run flask db upgrade - ``` - -1. Start backend - - ```bash uv run flask run --host 0.0.0.0 --port=5001 --debug ``` -1. Start Dify [web](../web) service. +1. Start Dify [web](../web) service (in a new terminal). -1. Setup your application by visiting `http://localhost:3000`. + ```bash + cd web + pnpm dev:inspect + ``` -1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. +1. Set up your application by visiting `http://localhost:3000`. -```bash -uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention -``` +1. 
Optional: start the worker service (async tasks, in a new terminal). -Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service: + ```bash + cd api + uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention + ``` -```bash -uv run celery -A app.celery beat -``` +1. Optional: start Celery Beat (scheduled tasks, in a new terminal). + + ```bash + cd api + uv run celery -A app.celery beat + ``` + +
+ +### Environment notes + +> [!IMPORTANT] +> +> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies. + +- Generate a `SECRET_KEY` in the `.env` file. + + bash for Linux + + ```bash + sed -i "/^SECRET_KEY=/c\\SECRET_KEY=$(openssl rand -base64 42)" .env + ``` + + bash for Mac + + ```bash + secret_key=$(openssl rand -base64 42) + sed -i '' "/^SECRET_KEY=/c\\ + SECRET_KEY=${secret_key}" .env + ``` ## Testing 1. Install dependencies for both the backend and the test environment ```bash - uv sync --dev + cd api + uv sync --group dev ``` 1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`, more can check [Claude.md](../CLAUDE.md) ```bash + cd api uv run pytest # Run all tests uv run pytest tests/unit_tests/ # Unit tests only uv run pytest tests/integration_tests/ # Integration tests # Code quality - ../dev/reformat # Run all formatters and linters - uv run ruff check --fix ./ # Fix linting issues - uv run ruff format ./ # Format code - uv run basedpyright . # Type checking + ./dev/reformat # Run all formatters and linters + uv run ruff check --fix ./ # Fix linting issues + uv run ruff format ./ # Format code + uv run basedpyright . # Type checking ``` diff --git a/api/agent_skills/coding_style.md b/api/agent_skills/coding_style.md deleted file mode 100644 index a2b66f0bd5..0000000000 --- a/api/agent_skills/coding_style.md +++ /dev/null @@ -1,115 +0,0 @@ -## Linter - -- Always follow `.ruff.toml`. -- Run `uv run ruff check --fix --unsafe-fixes`. -- Keep each line under 100 characters (including spaces). - -## Code Style - -- `snake_case` for variables and functions. -- `PascalCase` for classes. -- `UPPER_CASE` for constants. - -## Rules - -- Use Pydantic v2 standard. -- Use `uv` for package management. 
-- Do not override dunder methods like `__init__`, `__iadd__`, etc. -- Never launch services (`uv run app.py`, `flask run`, etc.); running tests under `tests/` is allowed. -- Prefer simple functions over classes for lightweight helpers. -- Keep files below 800 lines; split when necessary. -- Keep code readable—no clever hacks. -- Never use `print`; log with `logger = logging.getLogger(__name__)`. - -## Guiding Principles - -- Mirror the project’s layered architecture: controller → service → core/domain. -- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions. -- Optimise for observability: deterministic control flow, clear logging, actionable errors. - -## SQLAlchemy Patterns - -- Models inherit from `models.base.Base`; never create ad-hoc metadata or engines. - -- Open sessions with context managers: - - ```python - from sqlalchemy.orm import Session - - with Session(db.engine, expire_on_commit=False) as session: - stmt = select(Workflow).where( - Workflow.id == workflow_id, - Workflow.tenant_id == tenant_id, - ) - workflow = session.execute(stmt).scalar_one_or_none() - ``` - -- Use SQLAlchemy expressions; avoid raw SQL unless necessary. - -- Introduce repository abstractions only for very large tables (e.g., workflow executions) to support alternative storage strategies. - -- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.). - -## Storage & External IO - -- Access storage via `extensions.ext_storage.storage`. -- Use `core.helper.ssrf_proxy` for outbound HTTP fetches. -- Background tasks that touch storage must be idempotent and log the relevant object identifiers. - -## Pydantic Usage - -- Define DTOs with Pydantic v2 models and forbid extras by default. - -- Use `@field_validator` / `@model_validator` for domain rules. 
- -- Example: - - ```python - from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator - - class TriggerConfig(BaseModel): - endpoint: HttpUrl - secret: str - - model_config = ConfigDict(extra="forbid") - - @field_validator("secret") - def ensure_secret_prefix(cls, value: str) -> str: - if not value.startswith("dify_"): - raise ValueError("secret must start with dify_") - return value - ``` - -## Generics & Protocols - -- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces). -- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers. -- Validate dynamic inputs at runtime when generics cannot enforce safety alone. - -## Error Handling & Logging - -- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate to HTTP responses in controllers. -- Declare `logger = logging.getLogger(__name__)` at module top. -- Include tenant/app/workflow identifiers in log context. -- Log retryable events at `warning`, terminal failures at `error`. - -## Tooling & Checks - -- Format/lint: `uv run --project api --dev ruff format ./api` and `uv run --project api --dev ruff check --fix --unsafe-fixes ./api`. -- Type checks: `uv run --directory api --dev basedpyright`. -- Tests: `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`. -- Run all of the above before submitting your work. - -## Controllers & Services - -- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic. -- Services: coordinate repositories, providers, background tasks; keep side effects explicit. -- Avoid repositories unless necessary; direct SQLAlchemy usage is preferred for typical tables. -- Document non-obvious behaviour with concise comments. - -## Miscellaneous - -- Use `configs.dify_config` for configuration—never read environment variables directly. -- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources. 
-- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection. -- Keep experimental scripts under `dev/`; do not ship them in production builds. diff --git a/api/agent_skills/infra.md b/api/agent_skills/infra.md deleted file mode 100644 index bc36c7bf64..0000000000 --- a/api/agent_skills/infra.md +++ /dev/null @@ -1,96 +0,0 @@ -## Configuration - -- Import `configs.dify_config` for every runtime toggle. Do not read environment variables directly. -- Add new settings to the proper mixin inside `configs/` (deployment, feature, middleware, etc.) so they load through `DifyConfig`. -- Remote overrides come from the optional providers in `configs/remote_settings_sources`; keep defaults in code safe when the value is missing. -- Example: logging pulls targets from `extensions/ext_logging.py`, and model provider URLs are assembled in `services/entities/model_provider_entities.py`. - -## Dependencies - -- Runtime dependencies live in `[project].dependencies` inside `pyproject.toml`. Optional clients go into the `storage`, `tools`, or `vdb` groups under `[dependency-groups]`. -- Always pin versions and keep the list alphabetised. Shared tooling (lint, typing, pytest) belongs in the `dev` group. -- When code needs a new package, explain why in the PR and run `uv lock` so the lockfile stays current. - -## Storage & Files - -- Use `extensions.ext_storage.storage` for all blob IO; it already respects the configured backend. -- Convert files for workflows with helpers in `core/file/file_manager.py`; they handle signed URLs and multimodal payloads. -- When writing controller logic, delegate upload quotas and metadata to `services/file_service.py` instead of touching storage directly. -- All outbound HTTP fetches (webhooks, remote files) must go through the SSRF-safe client in `core/helper/ssrf_proxy.py`; it wraps `httpx` with the allow/deny rules configured for the platform. 
- -## Redis & Shared State - -- Access Redis through `extensions.ext_redis.redis_client`. For locking, reuse `redis_client.lock`. -- Prefer higher-level helpers when available: rate limits use `libs.helper.RateLimiter`, provider metadata uses caches in `core/helper/provider_cache.py`. - -## Models - -- SQLAlchemy models sit in `models/` and inherit from the shared declarative `Base` defined in `models/base.py` (metadata configured via `models/engine.py`). -- `models/__init__.py` exposes grouped aggregates: account/tenant models, app and conversation tables, datasets, providers, workflow runs, triggers, etc. Import from there to avoid deep path churn. -- Follow the DDD boundary: persistence objects live in `models/`, repositories under `repositories/` translate them into domain entities, and services consume those repositories. -- When adding a table, create the model class, register it in `models/__init__.py`, wire a repository if needed, and generate an Alembic migration as described below. - -## Vector Stores - -- Vector client implementations live in `core/rag/datasource/vdb/`, with a common factory in `core/rag/datasource/vdb/vector_factory.py` and enums in `core/rag/datasource/vdb/vector_type.py`. -- Retrieval pipelines call these providers through `core/rag/datasource/retrieval_service.py` and dataset ingestion flows in `services/dataset_service.py`. -- The CLI helper `flask vdb-migrate` orchestrates bulk migrations using routines in `commands.py`; reuse that pattern when adding new backend transitions. -- To add another store, mirror the provider layout, register it with the factory, and include any schema changes in Alembic migrations. - -## Observability & OTEL - -- OpenTelemetry settings live under the observability mixin in `configs/observability`. Toggle exporters and sampling via `dify_config`, not ad-hoc env reads. 
-- HTTP, Celery, Redis, SQLAlchemy, and httpx instrumentation is initialised in `extensions/ext_app_metrics.py` and `extensions/ext_request_logging.py`; reuse these hooks when adding new workers or entrypoints. -- When creating background tasks or external calls, propagate tracing context with helpers in the existing instrumented clients (e.g. use the shared `httpx` session from `core/helper/http_client_pooling.py`). -- If you add a new external integration, ensure spans and metrics are emitted by wiring the appropriate OTEL instrumentation package in `pyproject.toml` and configuring it in `extensions/`. - -## Ops Integrations - -- Langfuse support and other tracing bridges live under `core/ops/opik_trace`. Config toggles sit in `configs/observability`, while exporters are initialised in the OTEL extensions mentioned above. -- External monitoring services should follow this pattern: keep client code in `core/ops`, expose switches via `dify_config`, and hook initialisation in `extensions/ext_app_metrics.py` or sibling modules. -- Before instrumenting new code paths, check whether existing context helpers (e.g. `extensions/ext_request_logging.py`) already capture the necessary metadata. - -## Controllers, Services, Core - -- Controllers only parse HTTP input and call a service method. Keep business rules in `services/`. -- Services enforce tenant rules, quotas, and orchestration, then call into `core/` engines (workflow execution, tools, LLMs). -- When adding a new endpoint, search for an existing service to extend before introducing a new layer. Example: workflow APIs pipe through `services/workflow_service.py` into `core/workflow`. - -## Plugins, Tools, Providers - -- In Dify a plugin is a tenant-installable bundle that declares one or more providers (tool, model, datasource, trigger, endpoint, agent strategy) plus its resource needs and version metadata. The manifest (`core/plugin/entities/plugin.py`) mirrors what you see in the marketplace documentation. 
-- Installation, upgrades, and migrations are orchestrated by `services/plugin/plugin_service.py` together with helpers such as `services/plugin/plugin_migration.py`. -- Runtime loading happens through the implementations under `core/plugin/impl/*` (tool/model/datasource/trigger/endpoint/agent). These modules normalise plugin providers so that downstream systems (`core/tools/tool_manager.py`, `services/model_provider_service.py`, `services/trigger/*`) can treat builtin and plugin capabilities the same way. -- For remote execution, plugin daemons (`core/plugin/entities/plugin_daemon.py`, `core/plugin/impl/plugin.py`) manage lifecycle hooks, credential forwarding, and background workers that keep plugin processes in sync with the main application. -- Acquire tool implementations through `core/tools/tool_manager.py`; it resolves builtin, plugin, and workflow-as-tool providers uniformly, injecting the right context (tenant, credentials, runtime config). -- To add a new plugin capability, extend the relevant `core/plugin/entities` schema and register the implementation in the matching `core/plugin/impl` module rather than importing the provider directly. - -## Async Workloads - -see `agent_skills/trigger.md` for more detailed documentation. - -- Enqueue background work through `services/async_workflow_service.py`. It routes jobs to the tiered Celery queues defined in `tasks/`. -- Workers boot from `celery_entrypoint.py` and execute functions in `tasks/workflow_execution_tasks.py`, `tasks/trigger_processing_tasks.py`, etc. -- Scheduled workflows poll from `schedule/workflow_schedule_tasks.py`. Follow the same pattern if you need new periodic jobs. - -## Database & Migrations - -- SQLAlchemy models live under `models/` and map directly to migration files in `migrations/versions`. -- Generate migrations with `uv run --project api flask db revision --autogenerate -m ""`, then review the diff; never hand-edit the database outside Alembic. 
-- Apply migrations locally using `uv run --project api flask db upgrade`; production deploys expect the same history. -- If you add tenant-scoped data, confirm the upgrade includes tenant filters or defaults consistent with the service logic touching those tables. - -## CLI Commands - -- Maintenance commands from `commands.py` are registered on the Flask CLI. Run them via `uv run --project api flask `. -- Use the built-in `db` commands from Flask-Migrate for schema operations (`flask db upgrade`, `flask db stamp`, etc.). Only fall back to custom helpers if you need their extra behaviour. -- Custom entries such as `flask reset-password`, `flask reset-email`, and `flask vdb-migrate` handle self-hosted account recovery and vector database migrations. -- Before adding a new command, check whether an existing service can be reused and ensure the command guards edition-specific behaviour (many enforce `SELF_HOSTED`). Document any additions in the PR. -- Ruff helpers are run directly with `uv`: `uv run --project api --dev ruff format ./api` for formatting and `uv run --project api --dev ruff check ./api` (add `--fix` if you want automatic fixes). - -## When You Add Features - -- Check for an existing helper or service before writing a new util. -- Uphold tenancy: every service method should receive the tenant ID from controller wrappers such as `controllers/console/wraps.py`. -- Update or create tests alongside behaviour changes (`tests/unit_tests` for fast coverage, `tests/integration_tests` when touching orchestrations). -- Run `uv run --project api --dev ruff check ./api`, `uv run --directory api --dev basedpyright`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before submitting changes. 
diff --git a/api/agent_skills/plugin.md b/api/agent_skills/plugin.md deleted file mode 100644 index 954ddd236b..0000000000 --- a/api/agent_skills/plugin.md +++ /dev/null @@ -1 +0,0 @@ -// TBD diff --git a/api/agent_skills/plugin_oauth.md b/api/agent_skills/plugin_oauth.md deleted file mode 100644 index 954ddd236b..0000000000 --- a/api/agent_skills/plugin_oauth.md +++ /dev/null @@ -1 +0,0 @@ -// TBD diff --git a/api/agent_skills/trigger.md b/api/agent_skills/trigger.md deleted file mode 100644 index f4b076332c..0000000000 --- a/api/agent_skills/trigger.md +++ /dev/null @@ -1,53 +0,0 @@ -## Overview - -Trigger is a collection of nodes that we called `Start` nodes, also, the concept of `Start` is the same as `RootNode` in the workflow engine `core/workflow/graph_engine`, On the other hand, `Start` node is the entry point of workflows, every workflow run always starts from a `Start` node. - -## Trigger nodes - -- `UserInput` -- `Trigger Webhook` -- `Trigger Schedule` -- `Trigger Plugin` - -### UserInput - -Before `Trigger` concept is introduced, it's what we called `Start` node, but now, to avoid confusion, it was renamed to `UserInput` node, has a strong relation with `ServiceAPI` in `controllers/service_api/app` - -1. `UserInput` node introduces a list of arguments that need to be provided by the user, finally it will be converted into variables in the workflow variable pool. -1. `ServiceAPI` accept those arguments, and pass through them into `UserInput` node. -1. 
For its detailed implementation, please refer to `core/workflow/nodes/start` - -### Trigger Webhook - -Inside Webhook Node, Dify provided a UI panel that allows user define a HTTP manifest `core/workflow/nodes/trigger_webhook/entities.py`.`WebhookData`, also, Dify generates a random webhook id for each `Trigger Webhook` node, the implementation was implemented in `core/trigger/utils/endpoint.py`, as you can see, `webhook-debug` is a debug mode for webhook, you may find it in `controllers/trigger/webhook.py`. - -Finally, requests to `webhook` endpoint will be converted into variables in workflow variable pool during workflow execution. - -### Trigger Schedule - -`Trigger Schedule` node is a node that allows user define a schedule to trigger the workflow, detailed manifest is here `core/workflow/nodes/trigger_schedule/entities.py`, we have a poller and executor to handle millions of schedules, see `docker/entrypoint.sh` / `schedule/workflow_schedule_task.py` for help. - -To Achieve this, a `WorkflowSchedulePlan` model was introduced in `models/trigger.py`, and a `events/event_handlers/sync_workflow_schedule_when_app_published.py` was used to sync workflow schedule plans when app is published. - -### Trigger Plugin - -`Trigger Plugin` node allows user define there own distributed trigger plugin, whenever a request was received, Dify forwards it to the plugin and wait for parsed variables from it. - -1. Requests were saved in storage by `services/trigger/trigger_request_service.py`, referenced by `services/trigger/trigger_service.py`.`TriggerService`.`process_endpoint` -1. Plugins accept those requests and parse variables from it, see `core/plugin/impl/trigger.py` for details. - -A `subscription` concept was out here by Dify, it means an endpoint address from Dify was bound to thirdparty webhook service like `Github` `Slack` `Linear` `GoogleDrive` `Gmail` etc. 
Once a subscription was created, Dify continually receives requests from the platforms and handle them one by one. - -## Worker Pool / Async Task - -All the events that triggered a new workflow run is always in async mode, a unified entrypoint can be found here `services/async_workflow_service.py`.`AsyncWorkflowService`.`trigger_workflow_async`. - -The infrastructure we used is `celery`, we've already configured it in `docker/entrypoint.sh`, and the consumers are in `tasks/async_workflow_tasks.py`, 3 queues were used to handle different tiers of users, `PROFESSIONAL_QUEUE` `TEAM_QUEUE` `SANDBOX_QUEUE`. - -## Debug Strategy - -Dify divided users into 2 groups: builders / end users. - -Builders are the users who create workflows, in this stage, debugging a workflow becomes a critical part of the workflow development process, as the start node in workflows, trigger nodes can `listen` to the events from `WebhookDebug` `Schedule` `Plugin`, debugging process was created in `controllers/console/app/workflow.py`.`DraftWorkflowTriggerNodeApi`. - -A polling process can be considered as combine of few single `poll` operations, each `poll` operation fetches events cached in `Redis`, returns `None` if no event was found, more detailed implemented: `core/trigger/debug/event_bus.py` was used to handle the polling process, and `core/trigger/debug/event_selectors.py` was used to select the event poller based on the trigger type. 
diff --git a/api/app_factory.py b/api/app_factory.py index bcad88e9e0..07859a3758 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -2,9 +2,11 @@ import logging import time from opentelemetry.trace import get_current_span +from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID from configs import dify_config from contexts.wrapper import RecyclableContextVar +from core.logging.context import init_request_context from dify_app import DifyApp logger = logging.getLogger(__name__) @@ -25,28 +27,35 @@ def create_flask_app_with_configs() -> DifyApp: # add before request hook @dify_app.before_request def before_request(): - # add an unique identifier to each request + # Initialize logging context for this request + init_request_context() RecyclableContextVar.increment_thread_recycles() - # add after request hook for injecting X-Trace-Id header from OpenTelemetry span context + # add after request hook for injecting trace headers from OpenTelemetry span context + # Only adds headers when OTEL is enabled and has valid context @dify_app.after_request - def add_trace_id_header(response): + def add_trace_headers(response): try: span = get_current_span() ctx = span.get_span_context() if span else None - if ctx and ctx.is_valid: - trace_id_hex = format(ctx.trace_id, "032x") - # Avoid duplicates if some middleware added it - if "X-Trace-Id" not in response.headers: - response.headers["X-Trace-Id"] = trace_id_hex + + if not ctx or not ctx.is_valid: + return response + + # Inject trace headers from OTEL context + if ctx.trace_id != INVALID_TRACE_ID and "X-Trace-Id" not in response.headers: + response.headers["X-Trace-Id"] = format(ctx.trace_id, "032x") + if ctx.span_id != INVALID_SPAN_ID and "X-Span-Id" not in response.headers: + response.headers["X-Span-Id"] = format(ctx.span_id, "016x") + except Exception: # Never break the response due to tracing header injection - logger.warning("Failed to add trace ID to response header", exc_info=True) + 
logger.warning("Failed to add trace headers to response", exc_info=True) return response # Capture the decorator's return value to avoid pyright reportUnusedFunction _ = before_request - _ = add_trace_id_header + _ = add_trace_headers return dify_app @@ -62,6 +71,8 @@ def create_app() -> DifyApp: def initialize_extensions(app: DifyApp): + # Initialize Flask context capture for workflow execution + from context.flask_app_context import init_flask_context from extensions import ( ext_app_metrics, ext_blueprints, @@ -70,6 +81,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_compress, ext_database, + ext_fastopenapi, ext_forward_refs, ext_hosting_provider, ext_import_modules, @@ -91,6 +103,8 @@ def initialize_extensions(app: DifyApp): ext_warnings, ) + init_flask_context() + extensions = [ ext_timezone, ext_logging, @@ -115,6 +129,7 @@ def initialize_extensions(app: DifyApp): ext_proxy_fix, ext_blueprints, ext_commands, + ext_fastopenapi, ext_otel, ext_request_logging, ext_session_factory, diff --git a/api/commands.py b/api/commands.py index a8d89ac200..3d68de4cb4 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1,7 +1,9 @@ import base64 +import datetime import json import logging import secrets +import time from typing import Any import click @@ -34,7 +36,7 @@ from libs.rsa import generate_key_pair from models import Tenant from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment from models.dataset import Document as DatasetDocument -from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile +from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile from models.oauth import DatasourceOauthParamConfig, DatasourceProvider from models.provider import Provider, ProviderModel from models.provider_ids import DatasourceProviderID, ToolProviderID @@ -45,6 +47,9 @@ from 
services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi from services.plugin.data_migration import PluginDataMigration from services.plugin.plugin_migration import PluginMigration from services.plugin.plugin_service import PluginService +from services.retention.conversation.messages_clean_policy import create_message_clean_policy +from services.retention.conversation.messages_clean_service import MessagesCleanService +from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup from tasks.remove_app_and_related_data_task import delete_draft_variables_batch logger = logging.getLogger(__name__) @@ -62,8 +67,10 @@ def reset_password(email, new_password, password_confirm): if str(new_password).strip() != str(password_confirm).strip(): click.echo(click.style("Passwords do not match.", fg="red")) return + normalized_email = email.strip().lower() + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = session.query(Account).where(Account.email == email).one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) if not account: click.echo(click.style(f"Account not found for email: {email}", fg="red")) @@ -84,7 +91,7 @@ def reset_password(email, new_password, password_confirm): base64_password_hashed = base64.b64encode(password_hashed).decode() account.password = base64_password_hashed account.password_salt = base64_salt - AccountService.reset_login_error_rate_limit(email) + AccountService.reset_login_error_rate_limit(normalized_email) click.echo(click.style("Password reset successfully.", fg="green")) @@ -100,20 +107,22 @@ def reset_email(email, new_email, email_confirm): if str(new_email).strip() != str(email_confirm).strip(): click.echo(click.style("New emails do not match.", fg="red")) return + normalized_new_email = new_email.strip().lower() + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - 
account = session.query(Account).where(Account.email == email).one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) if not account: click.echo(click.style(f"Account not found for email: {email}", fg="red")) return try: - email_validate(new_email) + email_validate(normalized_new_email) except: click.echo(click.style(f"Invalid email: {new_email}", fg="red")) return - account.email = new_email + account.email = normalized_new_email click.echo(click.style("Email updated successfully.", fg="green")) @@ -235,7 +244,7 @@ def migrate_annotation_vector_database(): if annotations: for annotation in annotations: document = Document( - page_content=annotation.question, + page_content=annotation.question_text, metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id}, ) documents.append(document) @@ -658,7 +667,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No return # Create account - email = email.strip() + email = email.strip().lower() if "@" not in email: click.echo(click.style("Invalid email address.", fg="red")) @@ -852,6 +861,435 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[ click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green")) +@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.") +@click.option( + "--before-days", + "--days", + default=30, + show_default=True, + type=click.IntRange(min=0), + help="Delete workflow runs created before N days ago.", +) +@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.") +@click.option( + "--from-days-ago", + default=None, + type=click.IntRange(min=0), + help="Lower bound in days ago (older). 
Must be paired with --to-days-ago.", +) +@click.option( + "--to-days-ago", + default=None, + type=click.IntRange(min=0), + help="Upper bound in days ago (newer). Must be paired with --from-days-ago.", +) +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", +) +@click.option( + "--dry-run", + is_flag=True, + help="Preview cleanup results without deleting any workflow run data.", +) +def clean_workflow_runs( + before_days: int, + batch_size: int, + from_days_ago: int | None, + to_days_ago: int | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + dry_run: bool, +): + """ + Clean workflow runs and related workflow data for free tenants. 
+ """ + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + + if (from_days_ago is None) ^ (to_days_ago is None): + raise click.UsageError("--from-days-ago and --to-days-ago must be provided together.") + + if from_days_ago is not None and to_days_ago is not None: + if start_from or end_before: + raise click.UsageError("Choose either day offsets or explicit dates, not both.") + if from_days_ago <= to_days_ago: + raise click.UsageError("--from-days-ago must be greater than --to-days-ago.") + now = datetime.datetime.now() + start_from = now - datetime.timedelta(days=from_days_ago) + end_before = now - datetime.timedelta(days=to_days_ago) + before_days = 0 + + start_time = datetime.datetime.now(datetime.UTC) + click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white")) + + WorkflowRunCleanup( + days=before_days, + batch_size=batch_size, + start_from=start_from, + end_before=end_before, + dry_run=dry_run, + ).run() + + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + click.echo( + click.style( + f"Workflow run cleanup completed. start={start_time.isoformat()} " + f"end={end_time.isoformat()} duration={elapsed}", + fg="green", + ) + ) + + +@click.command( + "archive-workflow-runs", + help="Archive workflow runs for paid plan tenants to S3-compatible storage.", +) +@click.option("--tenant-ids", default=None, help="Optional comma-separated tenant IDs for grayscale rollout.") +@click.option("--before-days", default=90, show_default=True, help="Archive runs older than N days.") +@click.option( + "--from-days-ago", + default=None, + type=click.IntRange(min=0), + help="Lower bound in days ago (older). Must be paired with --to-days-ago.", +) +@click.option( + "--to-days-ago", + default=None, + type=click.IntRange(min=0), + help="Upper bound in days ago (newer). 
Must be paired with --from-days-ago.", +) +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Archive runs created at or after this timestamp (UTC if no timezone).", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Archive runs created before this timestamp (UTC if no timezone).", +) +@click.option("--batch-size", default=100, show_default=True, help="Batch size for processing.") +@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to archive.") +@click.option("--limit", default=None, type=int, help="Maximum number of runs to archive.") +@click.option("--dry-run", is_flag=True, help="Preview without archiving.") +@click.option("--delete-after-archive", is_flag=True, help="Delete runs and related data after archiving.") +def archive_workflow_runs( + tenant_ids: str | None, + before_days: int, + from_days_ago: int | None, + to_days_ago: int | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + batch_size: int, + workers: int, + limit: int | None, + dry_run: bool, + delete_after_archive: bool, +): + """ + Archive workflow runs for paid plan tenants older than the specified days. + + This command archives the following tables to storage: + - workflow_node_executions + - workflow_node_execution_offload + - workflow_pauses + - workflow_pause_reasons + - workflow_trigger_logs + + The workflow_runs and workflow_app_logs tables are preserved for UI listing. 
+ """ + from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver + + run_started_at = datetime.datetime.now(datetime.UTC) + click.echo( + click.style( + f"Starting workflow run archiving at {run_started_at.isoformat()}.", + fg="white", + ) + ) + + if (start_from is None) ^ (end_before is None): + click.echo(click.style("start-from and end-before must be provided together.", fg="red")) + return + + if (from_days_ago is None) ^ (to_days_ago is None): + click.echo(click.style("from-days-ago and to-days-ago must be provided together.", fg="red")) + return + + if from_days_ago is not None and to_days_ago is not None: + if start_from or end_before: + click.echo(click.style("Choose either day offsets or explicit dates, not both.", fg="red")) + return + if from_days_ago <= to_days_ago: + click.echo(click.style("from-days-ago must be greater than to-days-ago.", fg="red")) + return + now = datetime.datetime.now() + start_from = now - datetime.timedelta(days=from_days_ago) + end_before = now - datetime.timedelta(days=to_days_ago) + before_days = 0 + + if start_from and end_before and start_from >= end_before: + click.echo(click.style("start-from must be earlier than end-before.", fg="red")) + return + if workers < 1: + click.echo(click.style("workers must be at least 1.", fg="red")) + return + + archiver = WorkflowRunArchiver( + days=before_days, + batch_size=batch_size, + start_from=start_from, + end_before=end_before, + workers=workers, + tenant_ids=[tid.strip() for tid in tenant_ids.split(",")] if tenant_ids else None, + limit=limit, + dry_run=dry_run, + delete_after_archive=delete_after_archive, + ) + summary = archiver.run() + click.echo( + click.style( + f"Summary: processed={summary.total_runs_processed}, archived={summary.runs_archived}, " + f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, " + f"time={summary.total_elapsed_time:.2f}s", + fg="cyan", + ) + ) + + run_finished_at = datetime.datetime.now(datetime.UTC) + 
elapsed = run_finished_at - run_started_at + click.echo( + click.style( + f"Workflow run archiving completed. start={run_started_at.isoformat()} " + f"end={run_finished_at.isoformat()} duration={elapsed}", + fg="green", + ) + ) + + +@click.command( + "restore-workflow-runs", + help="Restore archived workflow runs from S3-compatible storage.", +) +@click.option( + "--tenant-ids", + required=False, + help="Tenant IDs (comma-separated).", +) +@click.option("--run-id", required=False, help="Workflow run ID to restore.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", +) +@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to restore.") +@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to restore.") +@click.option("--dry-run", is_flag=True, help="Preview without restoring.") +def restore_workflow_runs( + tenant_ids: str | None, + run_id: str | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + workers: int, + limit: int, + dry_run: bool, +): + """ + Restore an archived workflow run from storage to the database. 
+ + This restores the following tables: + - workflow_node_executions + - workflow_node_execution_offload + - workflow_pauses + - workflow_pause_reasons + - workflow_trigger_logs + """ + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + parsed_tenant_ids = None + if tenant_ids: + parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()] + if not parsed_tenant_ids: + raise click.BadParameter("tenant-ids must not be empty") + + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + if run_id is None and (start_from is None or end_before is None): + raise click.UsageError("--start-from and --end-before are required for batch restore.") + if workers < 1: + raise click.BadParameter("workers must be at least 1") + + start_time = datetime.datetime.now(datetime.UTC) + click.echo( + click.style( + f"Starting restore of workflow run {run_id} at {start_time.isoformat()}.", + fg="white", + ) + ) + + restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers) + if run_id: + results = [restorer.restore_by_run_id(run_id)] + else: + assert start_from is not None + assert end_before is not None + results = restorer.restore_batch( + parsed_tenant_ids, + start_date=start_from, + end_date=end_before, + limit=limit, + ) + + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + + successes = sum(1 for result in results if result.success) + failures = len(results) - successes + + if failures == 0: + click.echo( + click.style( + f"Restore completed successfully. success={successes} duration={elapsed}", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Restore completed with failures. 
success={successes} failed={failures} duration={elapsed}", + fg="red", + ) + ) + + +@click.command( + "delete-archived-workflow-runs", + help="Delete archived workflow runs from the database.", +) +@click.option( + "--tenant-ids", + required=False, + help="Tenant IDs (comma-separated).", +) +@click.option("--run-id", required=False, help="Workflow run ID to delete.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", +) +@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to delete.") +@click.option("--dry-run", is_flag=True, help="Preview without deleting.") +def delete_archived_workflow_runs( + tenant_ids: str | None, + run_id: str | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + limit: int, + dry_run: bool, +): + """ + Delete archived workflow runs from the database. 
+ """ + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + parsed_tenant_ids = None + if tenant_ids: + parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()] + if not parsed_tenant_ids: + raise click.BadParameter("tenant-ids must not be empty") + + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + if run_id is None and (start_from is None or end_before is None): + raise click.UsageError("--start-from and --end-before are required for batch delete.") + + start_time = datetime.datetime.now(datetime.UTC) + target_desc = f"workflow run {run_id}" if run_id else "workflow runs" + click.echo( + click.style( + f"Starting delete of {target_desc} at {start_time.isoformat()}.", + fg="white", + ) + ) + + deleter = ArchivedWorkflowRunDeletion(dry_run=dry_run) + if run_id: + results = [deleter.delete_by_run_id(run_id)] + else: + assert start_from is not None + assert end_before is not None + results = deleter.delete_batch( + parsed_tenant_ids, + start_date=start_from, + end_date=end_before, + limit=limit, + ) + + for result in results: + if result.success: + click.echo( + click.style( + f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} " + f"workflow run {result.run_id} (tenant={result.tenant_id})", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Failed to delete workflow run {result.run_id}: {result.error}", + fg="red", + ) + ) + + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + + successes = sum(1 for result in results if result.success) + failures = len(results) - successes + + if failures == 0: + click.echo( + click.style( + f"Delete completed successfully. success={successes} duration={elapsed}", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Delete completed with failures. 
success={successes} failed={failures} duration={elapsed}", + fg="red", + ) + ) + + @click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") @click.command("clear-orphaned-file-records", help="Clear orphaned file records.") def clear_orphaned_file_records(force: bool): @@ -1184,6 +1622,217 @@ def remove_orphaned_files_on_storage(force: bool): click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow")) +@click.command("file-usage", help="Query file usages and show where files are referenced.") +@click.option("--file-id", type=str, default=None, help="Filter by file UUID.") +@click.option("--key", type=str, default=None, help="Filter by storage key.") +@click.option("--src", type=str, default=None, help="Filter by table.column pattern (e.g., 'documents.%' or '%.icon').") +@click.option("--limit", type=int, default=100, help="Limit number of results (default: 100).") +@click.option("--offset", type=int, default=0, help="Offset for pagination (default: 0).") +@click.option("--json", "output_json", is_flag=True, help="Output results in JSON format.") +def file_usage( + file_id: str | None, + key: str | None, + src: str | None, + limit: int, + offset: int, + output_json: bool, +): + """ + Query file usages and show where files are referenced in the database. + + This command reuses the same reference checking logic as clear-orphaned-file-records + and displays detailed information about where each file is referenced. 
+ """ + # define tables and columns to process + files_tables = [ + {"table": "upload_files", "id_column": "id", "key_column": "key"}, + {"table": "tool_files", "id_column": "id", "key_column": "file_key"}, + ] + ids_tables = [ + {"type": "uuid", "table": "message_files", "column": "upload_file_id", "pk_column": "id"}, + {"type": "text", "table": "documents", "column": "data_source_info", "pk_column": "id"}, + {"type": "text", "table": "document_segments", "column": "content", "pk_column": "id"}, + {"type": "text", "table": "messages", "column": "answer", "pk_column": "id"}, + {"type": "text", "table": "workflow_node_executions", "column": "inputs", "pk_column": "id"}, + {"type": "text", "table": "workflow_node_executions", "column": "process_data", "pk_column": "id"}, + {"type": "text", "table": "workflow_node_executions", "column": "outputs", "pk_column": "id"}, + {"type": "text", "table": "conversations", "column": "introduction", "pk_column": "id"}, + {"type": "text", "table": "conversations", "column": "system_instruction", "pk_column": "id"}, + {"type": "text", "table": "accounts", "column": "avatar", "pk_column": "id"}, + {"type": "text", "table": "apps", "column": "icon", "pk_column": "id"}, + {"type": "text", "table": "sites", "column": "icon", "pk_column": "id"}, + {"type": "json", "table": "messages", "column": "inputs", "pk_column": "id"}, + {"type": "json", "table": "messages", "column": "message", "pk_column": "id"}, + ] + + # Stream file usages with pagination to avoid holding all results in memory + paginated_usages = [] + total_count = 0 + + # First, build a mapping of file_id -> storage_key from the base tables + file_key_map = {} + for files_table in files_tables: + query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}" + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + file_key_map[str(row[0])] = f"{files_table['table']}:{row[1]}" + + # If filtering by key or 
file_id, verify it exists + if file_id and file_id not in file_key_map: + if output_json: + click.echo(json.dumps({"error": f"File ID {file_id} not found in base tables"})) + else: + click.echo(click.style(f"File ID {file_id} not found in base tables.", fg="red")) + return + + if key: + valid_prefixes = {f"upload_files:{key}", f"tool_files:{key}"} + matching_file_ids = [fid for fid, fkey in file_key_map.items() if fkey in valid_prefixes] + if not matching_file_ids: + if output_json: + click.echo(json.dumps({"error": f"Key {key} not found in base tables"})) + else: + click.echo(click.style(f"Key {key} not found in base tables.", fg="red")) + return + + guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" + + # For each reference table/column, find matching file IDs and record the references + for ids_table in ids_tables: + src_filter = f"{ids_table['table']}.{ids_table['column']}" + + # Skip if src filter doesn't match (use fnmatch for wildcard patterns) + if src: + if "%" in src or "_" in src: + import fnmatch + + # Convert SQL LIKE wildcards to fnmatch wildcards (% -> *, _ -> ?) 
+ pattern = src.replace("%", "*").replace("_", "?") + if not fnmatch.fnmatch(src_filter, pattern): + continue + else: + if src_filter != src: + continue + + if ids_table["type"] == "uuid": + # Direct UUID match + query = ( + f"SELECT {ids_table['pk_column']}, {ids_table['column']} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + ref_file_id = str(row[1]) + if ref_file_id not in file_key_map: + continue + storage_key = file_key_map[ref_file_id] + + # Apply filters + if file_id and ref_file_id != file_id: + continue + if key and not storage_key.endswith(key): + continue + + # Only collect items within the requested page range + if offset <= total_count < offset + limit: + paginated_usages.append( + { + "src": f"{ids_table['table']}.{ids_table['column']}", + "record_id": record_id, + "file_id": ref_file_id, + "key": storage_key, + } + ) + total_count += 1 + + elif ids_table["type"] in ("text", "json"): + # Extract UUIDs from text/json content + column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"] + query = ( + f"SELECT {ids_table['pk_column']}, {column_cast} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + content = str(row[1]) + + # Find all UUIDs in the content + import re + + uuid_pattern = re.compile(guid_regexp, re.IGNORECASE) + matches = uuid_pattern.findall(content) + + for ref_file_id in matches: + if ref_file_id not in file_key_map: + continue + storage_key = file_key_map[ref_file_id] + + # Apply filters + if file_id and ref_file_id != file_id: + continue + if key and not storage_key.endswith(key): + continue + + # Only collect items within the requested page range + if offset <= total_count < offset + limit: + 
paginated_usages.append( + { + "src": f"{ids_table['table']}.{ids_table['column']}", + "record_id": record_id, + "file_id": ref_file_id, + "key": storage_key, + } + ) + total_count += 1 + + # Output results + if output_json: + result = { + "total": total_count, + "offset": offset, + "limit": limit, + "usages": paginated_usages, + } + click.echo(json.dumps(result, indent=2)) + else: + click.echo( + click.style(f"Found {total_count} file usages (showing {len(paginated_usages)} results)", fg="white") + ) + click.echo("") + + if not paginated_usages: + click.echo(click.style("No file usages found matching the specified criteria.", fg="yellow")) + return + + # Print table header + click.echo( + click.style( + f"{'Src (Table.Column)':<50} {'Record ID':<40} {'File ID':<40} {'Storage Key':<60}", + fg="cyan", + ) + ) + click.echo(click.style("-" * 190, fg="white")) + + # Print each usage + for usage in paginated_usages: + click.echo(f"{usage['src']:<50} {usage['record_id']:<40} {usage['file_id']:<40} {usage['key']:<60}") + + # Show pagination info + if offset + limit < total_count: + click.echo("") + click.echo( + click.style( + f"Showing {offset + 1}-{offset + len(paginated_usages)} of {total_count} results", fg="white" + ) + ) + click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white")) + + @click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.") @click.option("--provider", prompt=True, help="Provider name") @click.option("--client-params", prompt=True, help="Client Params") @@ -1900,3 +2549,79 @@ def migrate_oss( except Exception as e: db.session.rollback() click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red")) + + +@click.command("clean-expired-messages", help="Clean expired messages.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + required=True, + help="Lower bound (inclusive) for created_at.", +) +@click.option( + 
"--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + required=True, + help="Upper bound (exclusive) for created_at.", +) +@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.") +@click.option( + "--graceful-period", + default=21, + show_default=True, + help="Graceful period in days after subscription expiration; ignored when billing is disabled.", +) +@click.option("--dry-run", is_flag=True, default=False, help="Show which messages would be cleaned without deleting them.") +def clean_expired_messages( + batch_size: int, + graceful_period: int, + start_from: datetime.datetime, + end_before: datetime.datetime, + dry_run: bool, +): + """ + Clean expired messages and related data for tenants based on clean policy. + """ + click.echo(click.style("clean_messages: start clean messages.", fg="green")) + + start_at = time.perf_counter() + + try: + # Create policy based on billing configuration + # NOTE: graceful_period will be ignored when billing is disabled.
+ policy = create_message_clean_policy(graceful_period_days=graceful_period) + + # Create and run the cleanup service + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=batch_size, + dry_run=dry_run, + ) + stats = service.run() + + end_at = time.perf_counter() + click.echo( + click.style( + f"clean_messages: completed successfully\n" + f" - Latency: {end_at - start_at:.2f}s\n" + f" - Batches processed: {stats['batches']}\n" + f" - Total messages scanned: {stats['total_messages']}\n" + f" - Messages filtered: {stats['filtered_messages']}\n" + f" - Messages deleted: {stats['total_deleted']}", + fg="green", + ) + ) + except Exception as e: + end_at = time.perf_counter() + logger.exception("clean_messages failed") + click.echo( + click.style( + f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}", + fg="red", + ) + ) + raise + + click.echo(click.style("messages cleanup completed.", fg="green")) diff --git a/api/configs/extra/__init__.py b/api/configs/extra/__init__.py index 4543b5389d..de97adfc0e 100644 --- a/api/configs/extra/__init__.py +++ b/api/configs/extra/__init__.py @@ -1,9 +1,11 @@ +from configs.extra.archive_config import ArchiveStorageConfig from configs.extra.notion_config import NotionConfig from configs.extra.sentry_config import SentryConfig class ExtraServiceConfig( # place the configs in alphabet order + ArchiveStorageConfig, NotionConfig, SentryConfig, ): diff --git a/api/configs/extra/archive_config.py b/api/configs/extra/archive_config.py new file mode 100644 index 0000000000..a85628fa61 --- /dev/null +++ b/api/configs/extra/archive_config.py @@ -0,0 +1,43 @@ +from pydantic import Field +from pydantic_settings import BaseSettings + + +class ArchiveStorageConfig(BaseSettings): + """ + Configuration settings for workflow run logs archiving storage. 
+ """ + + ARCHIVE_STORAGE_ENABLED: bool = Field( + description="Enable workflow run logs archiving to S3-compatible storage", + default=False, + ) + + ARCHIVE_STORAGE_ENDPOINT: str | None = Field( + description="URL of the S3-compatible storage endpoint (e.g., 'https://storage.example.com')", + default=None, + ) + + ARCHIVE_STORAGE_ARCHIVE_BUCKET: str | None = Field( + description="Name of the bucket to store archived workflow logs", + default=None, + ) + + ARCHIVE_STORAGE_EXPORT_BUCKET: str | None = Field( + description="Name of the bucket to store exported workflow runs", + default=None, + ) + + ARCHIVE_STORAGE_ACCESS_KEY: str | None = Field( + description="Access key ID for authenticating with storage", + default=None, + ) + + ARCHIVE_STORAGE_SECRET_KEY: str | None = Field( + description="Secret access key for authenticating with storage", + default=None, + ) + + ARCHIVE_STORAGE_REGION: str = Field( + description="Region for storage (use 'auto' if the provider supports it)", + default="auto", + ) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 0e20cc9f9e..f427466c1e 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -604,6 +604,11 @@ class LoggingConfig(BaseSettings): default="INFO", ) + LOG_OUTPUT_FORMAT: Literal["text", "json"] = Field( + description="Log output format: 'text' for human-readable, 'json' for structured JSON logs.", + default="text", + ) + LOG_FILE: str | None = Field( description="File path for log output.", default=None, @@ -961,6 +966,12 @@ class MailConfig(BaseSettings): default=False, ) + SMTP_LOCAL_HOSTNAME: str | None = Field( + description="Override the local hostname used in SMTP HELO/EHLO. 
" + "Useful behind NAT or when the default hostname causes rejections.", + default=None, + ) + EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field( description="Maximum number of emails allowed to be sent from the same IP address in a minute", default=50, @@ -971,6 +982,16 @@ class MailConfig(BaseSettings): default=None, ) + ENABLE_TRIAL_APP: bool = Field( + description="Enable trial app", + default=False, + ) + + ENABLE_EXPLORE_BANNER: bool = Field( + description="Enable explore banner", + default=False, + ) + class RagEtlConfig(BaseSettings): """ @@ -1113,6 +1134,10 @@ class CeleryScheduleTasksConfig(BaseSettings): description="Enable clean messages task", default=False, ) + ENABLE_WORKFLOW_RUN_CLEANUP_TASK: bool = Field( + description="Enable scheduled workflow run cleanup task", + default=False, + ) ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field( description="Enable mail clean document notify task", default=False, @@ -1308,6 +1333,10 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings): description="Retention days for sandbox expired workflow_run records and message records", default=30, ) + SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: PositiveInt = Field( + description="Lock TTL for sandbox expired records clean task in seconds", + default=90000, + ) class FeatureConfig( diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 4ad30014c7..42ede718c4 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -8,6 +8,11 @@ class HostedCreditConfig(BaseSettings): default="", ) + HOSTED_POOL_CREDITS: int = Field( + description="Pool credits for hosted service", + default=200, + ) + def get_model_credits(self, model_name: str) -> int: """ Get credit value for a specific model name. 
@@ -60,19 +65,46 @@ class HostedOpenAiConfig(BaseSettings): HOSTED_OPENAI_TRIAL_MODELS: str = Field( description="Comma-separated list of available models for trial access", - default="gpt-3.5-turbo," - "gpt-3.5-turbo-1106," - "gpt-3.5-turbo-instruct," + default="gpt-4," + "gpt-4-turbo-preview," + "gpt-4-turbo-2024-04-09," + "gpt-4-1106-preview," + "gpt-4-0125-preview," + "gpt-4-turbo," + "gpt-4.1," + "gpt-4.1-2025-04-14," + "gpt-4.1-mini," + "gpt-4.1-mini-2025-04-14," + "gpt-4.1-nano," + "gpt-4.1-nano-2025-04-14," + "gpt-3.5-turbo," "gpt-3.5-turbo-16k," "gpt-3.5-turbo-16k-0613," + "gpt-3.5-turbo-1106," "gpt-3.5-turbo-0613," "gpt-3.5-turbo-0125," - "text-davinci-003", - ) - - HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field( - description="Quota limit for hosted OpenAI service usage", - default=200, + "gpt-3.5-turbo-instruct," + "text-davinci-003," + "chatgpt-4o-latest," + "gpt-4o," + "gpt-4o-2024-05-13," + "gpt-4o-2024-08-06," + "gpt-4o-2024-11-20," + "gpt-4o-audio-preview," + "gpt-4o-audio-preview-2025-06-03," + "gpt-4o-mini," + "gpt-4o-mini-2024-07-18," + "o3-mini," + "o3-mini-2025-01-31," + "gpt-5-mini-2025-08-07," + "gpt-5-mini," + "o4-mini," + "o4-mini-2025-04-16," + "gpt-5-chat-latest," + "gpt-5," + "gpt-5-2025-08-07," + "gpt-5-nano," + "gpt-5-nano-2025-08-07", ) HOSTED_OPENAI_PAID_ENABLED: bool = Field( @@ -87,6 +119,13 @@ class HostedOpenAiConfig(BaseSettings): "gpt-4-turbo-2024-04-09," "gpt-4-1106-preview," "gpt-4-0125-preview," + "gpt-4-turbo," + "gpt-4.1," + "gpt-4.1-2025-04-14," + "gpt-4.1-mini," + "gpt-4.1-mini-2025-04-14," + "gpt-4.1-nano," + "gpt-4.1-nano-2025-04-14," "gpt-3.5-turbo," "gpt-3.5-turbo-16k," "gpt-3.5-turbo-16k-0613," @@ -94,7 +133,150 @@ class HostedOpenAiConfig(BaseSettings): "gpt-3.5-turbo-0613," "gpt-3.5-turbo-0125," "gpt-3.5-turbo-instruct," - "text-davinci-003", + "text-davinci-003," + "chatgpt-4o-latest," + "gpt-4o," + "gpt-4o-2024-05-13," + "gpt-4o-2024-08-06," + "gpt-4o-2024-11-20," + "gpt-4o-audio-preview," + 
"gpt-4o-audio-preview-2025-06-03," + "gpt-4o-mini," + "gpt-4o-mini-2024-07-18," + "o3-mini," + "o3-mini-2025-01-31," + "gpt-5-mini-2025-08-07," + "gpt-5-mini," + "o4-mini," + "o4-mini-2025-04-16," + "gpt-5-chat-latest," + "gpt-5," + "gpt-5-2025-08-07," + "gpt-5-nano," + "gpt-5-nano-2025-08-07", + ) + + +class HostedGeminiConfig(BaseSettings): + """ + Configuration for fetching Gemini service + """ + + HOSTED_GEMINI_API_KEY: str | None = Field( + description="API key for hosted Gemini service", + default=None, + ) + + HOSTED_GEMINI_API_BASE: str | None = Field( + description="Base URL for hosted Gemini API", + default=None, + ) + + HOSTED_GEMINI_API_ORGANIZATION: str | None = Field( + description="Organization ID for hosted Gemini service", + default=None, + ) + + HOSTED_GEMINI_TRIAL_ENABLED: bool = Field( + description="Enable trial access to hosted Gemini service", + default=False, + ) + + HOSTED_GEMINI_TRIAL_MODELS: str = Field( + description="Comma-separated list of available models for trial access", + default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite", + ) + + HOSTED_GEMINI_PAID_ENABLED: bool = Field( + description="Enable paid access to hosted Gemini service", + default=False, + ) + + HOSTED_GEMINI_PAID_MODELS: str = Field( + description="Comma-separated list of available models for paid access", + default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite", + ) + + +class HostedXAIConfig(BaseSettings): + """ + Configuration for fetching XAI service + """ + + HOSTED_XAI_API_KEY: str | None = Field( + description="API key for hosted XAI service", + default=None, + ) + + HOSTED_XAI_API_BASE: str | None = Field( + description="Base URL for hosted XAI API", + default=None, + ) + + HOSTED_XAI_API_ORGANIZATION: str | None = Field( + description="Organization ID for hosted XAI service", + default=None, + ) + + HOSTED_XAI_TRIAL_ENABLED: bool = Field( + description="Enable trial access to hosted XAI service", + default=False, + ) + + 
HOSTED_XAI_TRIAL_MODELS: str = Field( + description="Comma-separated list of available models for trial access", + default="grok-3,grok-3-mini,grok-3-mini-fast", + ) + + HOSTED_XAI_PAID_ENABLED: bool = Field( + description="Enable paid access to hosted XAI service", + default=False, + ) + + HOSTED_XAI_PAID_MODELS: str = Field( + description="Comma-separated list of available models for paid access", + default="grok-3,grok-3-mini,grok-3-mini-fast", + ) + + +class HostedDeepseekConfig(BaseSettings): + """ + Configuration for fetching Deepseek service + """ + + HOSTED_DEEPSEEK_API_KEY: str | None = Field( + description="API key for hosted Deepseek service", + default=None, + ) + + HOSTED_DEEPSEEK_API_BASE: str | None = Field( + description="Base URL for hosted Deepseek API", + default=None, + ) + + HOSTED_DEEPSEEK_API_ORGANIZATION: str | None = Field( + description="Organization ID for hosted Deepseek service", + default=None, + ) + + HOSTED_DEEPSEEK_TRIAL_ENABLED: bool = Field( + description="Enable trial access to hosted Deepseek service", + default=False, + ) + + HOSTED_DEEPSEEK_TRIAL_MODELS: str = Field( + description="Comma-separated list of available models for trial access", + default="deepseek-chat,deepseek-reasoner", + ) + + HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field( + description="Enable paid access to hosted Deepseek service", + default=False, + ) + + HOSTED_DEEPSEEK_PAID_MODELS: str = Field( + description="Comma-separated list of available models for paid access", + default="deepseek-chat,deepseek-reasoner", ) @@ -144,16 +326,66 @@ class HostedAnthropicConfig(BaseSettings): default=False, ) - HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field( - description="Quota limit for hosted Anthropic service usage", - default=600000, - ) - HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field( description="Enable paid access to hosted Anthropic service", default=False, ) + HOSTED_ANTHROPIC_TRIAL_MODELS: str = Field( + description="Comma-separated list of available models 
for trial access", + default="claude-opus-4-20250514," + "claude-sonnet-4-20250514," + "claude-3-5-haiku-20241022," + "claude-3-opus-20240229," + "claude-3-7-sonnet-20250219," + "claude-3-haiku-20240307", + ) + HOSTED_ANTHROPIC_PAID_MODELS: str = Field( + description="Comma-separated list of available models for paid access", + default="claude-opus-4-20250514," + "claude-sonnet-4-20250514," + "claude-3-5-haiku-20241022," + "claude-3-opus-20240229," + "claude-3-7-sonnet-20250219," + "claude-3-haiku-20240307", + ) + + +class HostedTongyiConfig(BaseSettings): + """ + Configuration for hosted Tongyi service + """ + + HOSTED_TONGYI_API_KEY: str | None = Field( + description="API key for hosted Tongyi service", + default=None, + ) + + HOSTED_TONGYI_USE_INTERNATIONAL_ENDPOINT: bool = Field( + description="Use international endpoint for hosted Tongyi service", + default=False, + ) + + HOSTED_TONGYI_TRIAL_ENABLED: bool = Field( + description="Enable trial access to hosted Tongyi service", + default=False, + ) + + HOSTED_TONGYI_PAID_ENABLED: bool = Field( + description="Enable paid access to hosted Tongyi service", + default=False, + ) + + HOSTED_TONGYI_TRIAL_MODELS: str = Field( + description="Comma-separated list of available models for trial access", + default="", + ) + + HOSTED_TONGYI_PAID_MODELS: str = Field( + description="Comma-separated list of available models for paid access", + default="", + ) + class HostedMinmaxConfig(BaseSettings): """ @@ -246,9 +478,13 @@ class HostedServiceConfig( HostedOpenAiConfig, HostedSparkConfig, HostedZhipuAIConfig, + HostedTongyiConfig, # moderation HostedModerationConfig, # credit config HostedCreditConfig, + HostedGeminiConfig, + HostedXAIConfig, + HostedDeepseekConfig, ): pass diff --git a/api/configs/middleware/storage/tencent_cos_storage_config.py b/api/configs/middleware/storage/tencent_cos_storage_config.py index e297e748e9..cdd10740f8 100644 --- a/api/configs/middleware/storage/tencent_cos_storage_config.py +++ 
b/api/configs/middleware/storage/tencent_cos_storage_config.py @@ -31,3 +31,8 @@ class TencentCloudCOSStorageConfig(BaseSettings): description="Protocol scheme for COS requests: 'https' (recommended) or 'http'", default=None, ) + + TENCENT_COS_CUSTOM_DOMAIN: str | None = Field( + description="Tencent Cloud COS custom domain setting", + default=None, + ) diff --git a/api/configs/middleware/storage/volcengine_tos_storage_config.py b/api/configs/middleware/storage/volcengine_tos_storage_config.py index be01f2dc36..2a35300401 100644 --- a/api/configs/middleware/storage/volcengine_tos_storage_config.py +++ b/api/configs/middleware/storage/volcengine_tos_storage_config.py @@ -4,7 +4,7 @@ from pydantic_settings import BaseSettings class VolcengineTOSStorageConfig(BaseSettings): """ - Configuration settings for Volcengine Tinder Object Storage (TOS) + Configuration settings for Volcengine Torch Object Storage (TOS) """ VOLCENGINE_TOS_BUCKET_NAME: str | None = Field( diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py index 05cee51cc9..eb9b0ac2ab 100644 --- a/api/configs/middleware/vdb/milvus_config.py +++ b/api/configs/middleware/vdb/milvus_config.py @@ -16,7 +16,6 @@ class MilvusConfig(BaseSettings): description="Authentication token for Milvus, if token-based authentication is enabled", default=None, ) - MILVUS_USER: str | None = Field( description="Username for authenticating with Milvus, if username/password authentication is enabled", default=None, diff --git a/api/context/__init__.py b/api/context/__init__.py new file mode 100644 index 0000000000..aebf9750ce --- /dev/null +++ b/api/context/__init__.py @@ -0,0 +1,74 @@ +""" +Core Context - Framework-agnostic context management. + +This module provides context management that is independent of any specific +web framework. Framework-specific implementations register their context +capture functions at application initialization time. 
+ +This ensures the workflow layer remains completely decoupled from Flask +or any other web framework. +""" + +import contextvars +from collections.abc import Callable + +from core.workflow.context.execution_context import ( + ExecutionContext, + IExecutionContext, + NullAppContext, +) + +# Global capturer function - set by framework-specific modules +_capturer: Callable[[], IExecutionContext] | None = None + + +def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None: + """ + Register a context capture function. + + This should be called by framework-specific modules (e.g., Flask) + during application initialization. + + Args: + capturer: Function that captures current context and returns IExecutionContext + """ + global _capturer + _capturer = capturer + + +def capture_current_context() -> IExecutionContext: + """ + Capture current execution context. + + This function uses the registered context capturer. If no capturer + is registered, it returns a minimal context with only contextvars + (suitable for non-framework environments like tests or standalone scripts). + + Returns: + IExecutionContext with captured context + """ + if _capturer is None: + # No framework registered - return minimal context + return ExecutionContext( + app_context=NullAppContext(), + context_vars=contextvars.copy_context(), + ) + + return _capturer() + + +def reset_context_provider() -> None: + """ + Reset the context capturer. + + This is primarily useful for testing to ensure a clean state. + """ + global _capturer + _capturer = None + + +__all__ = [ + "capture_current_context", + "register_context_capturer", + "reset_context_provider", +] diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py new file mode 100644 index 0000000000..2d465c8cf4 --- /dev/null +++ b/api/context/flask_app_context.py @@ -0,0 +1,192 @@ +""" +Flask App Context - Flask implementation of AppContext interface. 
+""" + +import contextvars +import threading +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any, final + +from flask import Flask, current_app, g + +from core.workflow.context import register_context_capturer +from core.workflow.context.execution_context import ( + AppContext, + IExecutionContext, +) + + +@final +class FlaskAppContext(AppContext): + """ + Flask implementation of AppContext. + + This adapts Flask's app context to the AppContext interface. + """ + + def __init__(self, flask_app: Flask) -> None: + """ + Initialize Flask app context. + + Args: + flask_app: The Flask application instance + """ + self._flask_app = flask_app + + def get_config(self, key: str, default: Any = None) -> Any: + """Get configuration value from Flask app config.""" + return self._flask_app.config.get(key, default) + + def get_extension(self, name: str) -> Any: + """Get Flask extension by name.""" + return self._flask_app.extensions.get(name) + + @contextmanager + def enter(self) -> Generator[None, None, None]: + """Enter Flask app context.""" + with self._flask_app.app_context(): + yield + + @property + def flask_app(self) -> Flask: + """Get the underlying Flask app instance.""" + return self._flask_app + + +def capture_flask_context(user: Any = None) -> IExecutionContext: + """ + Capture current Flask execution context. + + This function captures the Flask app context and contextvars from the + current environment. It should be called from within a Flask request or + app context. 
+ + Args: + user: Optional user object to include in context + + Returns: + IExecutionContext with captured Flask context + + Raises: + RuntimeError: If called outside Flask context + """ + # Get Flask app instance + flask_app = current_app._get_current_object() # type: ignore + + # Save current user if available + saved_user = user + if saved_user is None: + # Check for user in g (flask-login) + if hasattr(g, "_login_user"): + saved_user = g._login_user + + # Capture contextvars + context_vars = contextvars.copy_context() + + return FlaskExecutionContext( + flask_app=flask_app, + context_vars=context_vars, + user=saved_user, + ) + + +@final +class FlaskExecutionContext: + """ + Flask-specific execution context. + + This is a specialized version of ExecutionContext that includes Flask app + context. It provides the same interface as ExecutionContext but with + Flask-specific implementation. + """ + + def __init__( + self, + flask_app: Flask, + context_vars: contextvars.Context, + user: Any = None, + ) -> None: + """ + Initialize Flask execution context. 
+ + Args: + flask_app: Flask application instance + context_vars: Python contextvars + user: Optional user object + """ + self._app_context = FlaskAppContext(flask_app) + self._context_vars = context_vars + self._user = user + self._flask_app = flask_app + self._local = threading.local() + + @property + def app_context(self) -> FlaskAppContext: + """Get Flask app context.""" + return self._app_context + + @property + def context_vars(self) -> contextvars.Context: + """Get context variables.""" + return self._context_vars + + @property + def user(self) -> Any: + """Get user object.""" + return self._user + + def __enter__(self) -> "FlaskExecutionContext": + """Enter the Flask execution context.""" + # Restore non-Flask context variables to avoid leaking Flask tokens across threads + for var, val in self._context_vars.items(): + var.set(val) + + # Enter Flask app context + cm = self._app_context.enter() + self._local.cm = cm + cm.__enter__() + + # Restore user in new app context + if self._user is not None: + g._login_user = self._user + + return self + + def __exit__(self, *args: Any) -> None: + """Exit the Flask execution context.""" + cm = getattr(self._local, "cm", None) + if cm is not None: + cm.__exit__(*args) + + @contextmanager + def enter(self) -> Generator[None, None, None]: + """Enter Flask execution context as context manager.""" + # Restore non-Flask context variables to avoid leaking Flask tokens across threads + for var, val in self._context_vars.items(): + var.set(val) + + # Enter Flask app context + with self._flask_app.app_context(): + # Restore user in new app context + if self._user is not None: + g._login_user = self._user + yield + + +def init_flask_context() -> None: + """ + Initialize Flask context capture by registering the capturer. + + This function should be called during Flask application initialization + to register the Flask-specific context capturer with the core context module. 
+ + Example: + app = Flask(__name__) + init_flask_context() # Register Flask context capturer + + Note: + This function does not need the app instance as it uses Flask's + `current_app` to get the app when capturing context. + """ + register_context_capturer(capture_flask_context) diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index df9de825de..c16a23fac8 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -1,62 +1,59 @@ -from flask_restx import Api, Namespace, fields +from __future__ import annotations -from libs.helper import AppIconUrlField +from typing import Any, TypeAlias -parameters__system_parameters = { - "image_file_size_limit": fields.Integer, - "video_file_size_limit": fields.Integer, - "audio_file_size_limit": fields.Integer, - "file_size_limit": fields.Integer, - "workflow_file_upload_limit": fields.Integer, -} +from pydantic import BaseModel, ConfigDict, computed_field + +from core.file import helpers as file_helpers +from models.model import IconType + +JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any] +JSONObject: TypeAlias = dict[str, Any] -def build_system_parameters_model(api_or_ns: Api | Namespace): - """Build the system parameters model for the API or Namespace.""" - return api_or_ns.model("SystemParameters", parameters__system_parameters) +class SystemParameters(BaseModel): + image_file_size_limit: int + video_file_size_limit: int + audio_file_size_limit: int + file_size_limit: int + workflow_file_upload_limit: int -parameters_fields = { - "opening_statement": fields.String, - "suggested_questions": fields.Raw, - "suggested_questions_after_answer": fields.Raw, - "speech_to_text": fields.Raw, - "text_to_speech": fields.Raw, - "retriever_resource": fields.Raw, - "annotation_reply": fields.Raw, - "more_like_this": fields.Raw, - "user_input_form": fields.Raw, - "sensitive_word_avoidance": fields.Raw, - "file_upload": fields.Raw, - 
"system_parameters": fields.Nested(parameters__system_parameters), -} +class Parameters(BaseModel): + opening_statement: str | None = None + suggested_questions: list[str] + suggested_questions_after_answer: JSONObject + speech_to_text: JSONObject + text_to_speech: JSONObject + retriever_resource: JSONObject + annotation_reply: JSONObject + more_like_this: JSONObject + user_input_form: list[JSONObject] + sensitive_word_avoidance: JSONObject + file_upload: JSONObject + system_parameters: SystemParameters -def build_parameters_model(api_or_ns: Api | Namespace): - """Build the parameters model for the API or Namespace.""" - copied_fields = parameters_fields.copy() - copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns)) - return api_or_ns.model("Parameters", copied_fields) +class Site(BaseModel): + model_config = ConfigDict(from_attributes=True) + title: str + chat_color_theme: str | None = None + chat_color_theme_inverted: bool + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + description: str | None = None + copyright: str | None = None + privacy_policy: str | None = None + custom_disclaimer: str | None = None + default_language: str + show_workflow_steps: bool + use_icon_as_answer_icon: bool -site_fields = { - "title": fields.String, - "chat_color_theme": fields.String, - "chat_color_theme_inverted": fields.Boolean, - "icon_type": fields.String, - "icon": fields.String, - "icon_background": fields.String, - "icon_url": AppIconUrlField, - "description": fields.String, - "copyright": fields.String, - "privacy_policy": fields.String, - "custom_disclaimer": fields.String, - "default_language": fields.String, - "show_workflow_steps": fields.Boolean, - "use_icon_as_answer_icon": fields.Boolean, -} - - -def build_site_model(api_or_ns: Api | Namespace): - """Build the site model for the API or Namespace.""" - return api_or_ns.model("Site", site_fields) + @computed_field(return_type=str | 
None) # type: ignore + @property + def icon_url(self) -> str | None: + if self.icon and self.icon_type == IconType.IMAGE: + return file_helpers.get_signed_file_url(self.icon) + return None diff --git a/api/controllers/common/schema.py b/api/controllers/common/schema.py index e0896a8dc2..a5a3e4ebbd 100644 --- a/api/controllers/common/schema.py +++ b/api/controllers/common/schema.py @@ -1,7 +1,11 @@ """Helpers for registering Pydantic models with Flask-RESTX namespaces.""" +from enum import StrEnum + from flask_restx import Namespace -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter + +from controllers.console import console_ns DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -19,8 +23,25 @@ def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> No register_schema_model(namespace, model) +def get_or_create_model(model_name: str, field_def): + existing = console_ns.models.get(model_name) + if existing is None: + existing = console_ns.model(model_name, field_def) + return existing + + +def register_enum_models(namespace: Namespace, *models: type[StrEnum]) -> None: + """Register multiple StrEnum with a namespace.""" + for model in models: + namespace.schema_model( + model.__name__, TypeAdapter(model).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) + ) + + __all__ = [ "DEFAULT_REF_TEMPLATE_SWAGGER_2_0", + "get_or_create_model", + "register_enum_models", "register_schema_model", "register_schema_models", ] diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 2cd7e230f4..902d67174b 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -108,10 +108,12 @@ from .datasets.rag_pipeline import ( # Import explore controllers from .explore import ( + banner, installed_app, parameter, recommended_app, saved_message, + trial, ) # Import tag controllers @@ -146,6 +148,7 @@ __all__ = [ "apikey", "app", "audio", + "banner", "billing", 
"bp", "completion", @@ -200,6 +203,7 @@ __all__ = [ "statistic", "tags", "tool_providers", + "trial", "trigger_providers", "version", "website", diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index a25ca5ef51..e1ee2c24b8 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -15,7 +15,7 @@ from controllers.console.wraps import only_edition_cloud from core.db.session_factory import session_factory from extensions.ext_database import db from libs.token import extract_access_token -from models.model import App, InstalledApp, RecommendedApp +from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp P = ParamSpec("P") R = TypeVar("R") @@ -32,6 +32,8 @@ class InsertExploreAppPayload(BaseModel): language: str = Field(...) category: str = Field(...) position: int = Field(...) + can_trial: bool = Field(default=False) + trial_limit: int = Field(default=0) @field_validator("language") @classmethod @@ -39,11 +41,33 @@ class InsertExploreAppPayload(BaseModel): return supported_language(value) +class InsertExploreBannerPayload(BaseModel): + category: str = Field(...) + title: str = Field(...) + description: str = Field(...) + img_src: str = Field(..., alias="img-src") + language: str = Field(default="en-US") + link: str = Field(...) + sort: int = Field(...) 
+ + @field_validator("language") + @classmethod + def validate_language(cls, value: str) -> str: + return supported_language(value) + + model_config = {"populate_by_name": True} + + console_ns.schema_model( InsertExploreAppPayload.__name__, InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), ) +console_ns.schema_model( + InsertExploreBannerPayload.__name__, + InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + def admin_required(view: Callable[P, R]): @wraps(view) @@ -109,6 +133,20 @@ class InsertExploreAppListApi(Resource): ) db.session.add(recommended_app) + if payload.can_trial: + trial_app = db.session.execute( + select(TrialApp).where(TrialApp.app_id == payload.app_id) + ).scalar_one_or_none() + if not trial_app: + db.session.add( + TrialApp( + app_id=payload.app_id, + tenant_id=app.tenant_id, + trial_limit=payload.trial_limit, + ) + ) + else: + trial_app.trial_limit = payload.trial_limit app.is_public = True db.session.commit() @@ -123,6 +161,20 @@ class InsertExploreAppListApi(Resource): recommended_app.category = payload.category recommended_app.position = payload.position + if payload.can_trial: + trial_app = db.session.execute( + select(TrialApp).where(TrialApp.app_id == payload.app_id) + ).scalar_one_or_none() + if not trial_app: + db.session.add( + TrialApp( + app_id=payload.app_id, + tenant_id=app.tenant_id, + trial_limit=payload.trial_limit, + ) + ) + else: + trial_app.trial_limit = payload.trial_limit app.is_public = True db.session.commit() @@ -168,7 +220,62 @@ class InsertExploreAppApi(Resource): for installed_app in installed_apps: session.delete(installed_app) + trial_app = session.execute( + select(TrialApp).where(TrialApp.app_id == recommended_app.app_id) + ).scalar_one_or_none() + if trial_app: + session.delete(trial_app) + db.session.delete(recommended_app) db.session.commit() return {"result": "success"}, 204 + + 
+@console_ns.route("/admin/insert-explore-banner") +class InsertExploreBannerApi(Resource): + @console_ns.doc("insert_explore_banner") + @console_ns.doc(description="Insert an explore banner") + @console_ns.expect(console_ns.models[InsertExploreBannerPayload.__name__]) + @console_ns.response(201, "Banner inserted successfully") + @only_edition_cloud + @admin_required + def post(self): + payload = InsertExploreBannerPayload.model_validate(console_ns.payload) + + content = { + "category": payload.category, + "title": payload.title, + "description": payload.description, + "img-src": payload.img_src, + } + + banner = ExporleBanner( + content=content, + link=payload.link, + sort=payload.sort, + language=payload.language, + ) + db.session.add(banner) + db.session.commit() + + return {"result": "success"}, 201 + + +@console_ns.route("/admin/delete-explore-banner/<string:banner_id>") +class DeleteExploreBannerApi(Resource): + @console_ns.doc("delete_explore_banner") + @console_ns.doc(description="Delete an explore banner") + @console_ns.doc(params={"banner_id": "Banner ID to delete"}) + @console_ns.response(204, "Banner deleted successfully") + @only_edition_cloud + @admin_required + def delete(self, banner_id): + banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none() + if not banner: + raise NotFound(f"Banner '{banner_id}' is not found") + + db.session.delete(banner) + db.session.commit() + + return {"result": "success"}, 204 diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 9b0d4b1a78..c81709e985 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -22,10 +22,10 @@ api_key_fields = { "created_at": TimestampField, } -api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")} - api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields) +api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")} 
+ api_key_list_model = console_ns.model( "ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")} ) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 62e997dae2..8c371da596 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,15 +1,19 @@ import uuid -from typing import Literal +from datetime import datetime +from typing import Any, Literal, TypeAlias from flask import request -from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel, Field, field_validator +from flask_restx import Resource +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest +from controllers.common.helpers import FileInfo +from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model +from controllers.console.workspace.models import LoadBalancingPayload from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, @@ -18,27 +22,37 @@ from controllers.console.wraps import ( is_admin_or_owner_required, setup_required, ) +from core.file import helpers as file_helpers from core.ops.ops_trace_manager import OpsTraceManager -from core.workflow.enums import NodeType +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.workflow.enums import NodeType, WorkflowExecutionStatus from extensions.ext_database import db -from fields.app_fields import ( - deleted_tool_fields, - model_config_fields, - model_config_partial_fields, - site_fields, - tag_fields, -) -from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict -from libs.helper import AppIconUrlField, TimestampField from libs.login 
import current_account_with_tenant, login_required -from models import App, Workflow +from models import App, DatasetPermissionEnum, Workflow +from models.model import IconType from services.app_dsl_service import AppDslService, ImportMode from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService +from services.entities.knowledge_entities.knowledge_entities import ( + DataSource, + InfoList, + NotionIcon, + NotionInfo, + NotionPage, + PreProcessingRule, + RerankingModel, + Rule, + Segmentation, + WebsiteInfo, + WeightKeywordSetting, + WeightModel, + WeightVectorSetting, +) from services.feature_service import FeatureService ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + +register_enum_models(console_ns, IconType) class AppListQuery(BaseModel): @@ -134,124 +148,310 @@ class AppTracePayload(BaseModel): return value -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +JSONValue: TypeAlias = Any -reg(AppListQuery) -reg(CreateAppPayload) -reg(UpdateAppPayload) -reg(CopyAppPayload) -reg(AppExportQuery) -reg(AppNamePayload) -reg(AppIconPayload) -reg(AppSiteStatusPayload) -reg(AppApiStatusPayload) -reg(AppTracePayload) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) -# Register models for flask_restx to avoid dict type issues in Swagger -# Register base models first -tag_model = console_ns.model("Tag", tag_fields) -workflow_partial_model = console_ns.model("WorkflowPartial", _workflow_partial_fields_dict) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -model_config_model = 
console_ns.model("ModelConfig", model_config_fields) -model_config_partial_model = console_ns.model("ModelConfigPartial", model_config_partial_fields) +def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None: + if icon is None or icon_type is None: + return None + icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type) + if icon_type_value.lower() != IconType.IMAGE: + return None + return file_helpers.get_signed_file_url(icon) -deleted_tool_model = console_ns.model("DeletedTool", deleted_tool_fields) -site_model = console_ns.model("Site", site_fields) +class Tag(ResponseModel): + id: str + name: str + type: str -app_partial_model = console_ns.model( - "AppPartial", - { - "id": fields.String, - "name": fields.String, - "max_active_requests": fields.Raw(), - "description": fields.String(attribute="desc_or_prompt"), - "mode": fields.String(attribute="mode_compatible_with_agent"), - "icon_type": fields.String, - "icon": fields.String, - "icon_background": fields.String, - "icon_url": AppIconUrlField, - "model_config": fields.Nested(model_config_partial_model, attribute="app_model_config", allow_null=True), - "workflow": fields.Nested(workflow_partial_model, allow_null=True), - "use_icon_as_answer_icon": fields.Boolean, - "created_by": fields.String, - "created_at": TimestampField, - "updated_by": fields.String, - "updated_at": TimestampField, - "tags": fields.List(fields.Nested(tag_model)), - "access_mode": fields.String, - "create_user_name": fields.String, - "author_name": fields.String, - "has_draft_trigger": fields.Boolean, - }, -) -app_detail_model = console_ns.model( - "AppDetail", - { - "id": fields.String, - "name": fields.String, - "description": fields.String, - "mode": fields.String(attribute="mode_compatible_with_agent"), - "icon": fields.String, - "icon_background": fields.String, - "enable_site": fields.Boolean, - "enable_api": fields.Boolean, - "model_config": 
fields.Nested(model_config_model, attribute="app_model_config", allow_null=True), - "workflow": fields.Nested(workflow_partial_model, allow_null=True), - "tracing": fields.Raw, - "use_icon_as_answer_icon": fields.Boolean, - "created_by": fields.String, - "created_at": TimestampField, - "updated_by": fields.String, - "updated_at": TimestampField, - "access_mode": fields.String, - "tags": fields.List(fields.Nested(tag_model)), - }, -) +class WorkflowPartial(ResponseModel): + id: str + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None -app_detail_with_site_model = console_ns.model( - "AppDetailWithSite", - { - "id": fields.String, - "name": fields.String, - "description": fields.String, - "mode": fields.String(attribute="mode_compatible_with_agent"), - "icon_type": fields.String, - "icon": fields.String, - "icon_background": fields.String, - "icon_url": AppIconUrlField, - "enable_site": fields.Boolean, - "enable_api": fields.Boolean, - "model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True), - "workflow": fields.Nested(workflow_partial_model, allow_null=True), - "api_base_url": fields.String, - "use_icon_as_answer_icon": fields.Boolean, - "max_active_requests": fields.Integer, - "created_by": fields.String, - "created_at": TimestampField, - "updated_by": fields.String, - "updated_at": TimestampField, - "deleted_tools": fields.List(fields.Nested(deleted_tool_model)), - "access_mode": fields.String, - "tags": fields.List(fields.Nested(tag_model)), - "site": fields.Nested(site_model), - }, -) + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -app_pagination_model = console_ns.model( - "AppPagination", - { - "page": fields.Integer, - "limit": fields.Integer(attribute="per_page"), - "total": fields.Integer, - "has_more": 
fields.Boolean(attribute="has_next"), - "data": fields.List(fields.Nested(app_partial_model), attribute="items"), - }, + +class ModelConfigPartial(ResponseModel): + model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model")) + pre_prompt: str | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class ModelConfig(ResponseModel): + opening_statement: str | None = None + suggested_questions: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("suggested_questions_list", "suggested_questions") + ) + suggested_questions_after_answer: JSONValue | None = Field( + default=None, + validation_alias=AliasChoices("suggested_questions_after_answer_dict", "suggested_questions_after_answer"), + ) + speech_to_text: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("speech_to_text_dict", "speech_to_text") + ) + text_to_speech: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("text_to_speech_dict", "text_to_speech") + ) + retriever_resource: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("retriever_resource_dict", "retriever_resource") + ) + annotation_reply: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("annotation_reply_dict", "annotation_reply") + ) + more_like_this: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("more_like_this_dict", "more_like_this") + ) + sensitive_word_avoidance: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("sensitive_word_avoidance_dict", "sensitive_word_avoidance") + ) + external_data_tools: JSONValue | None = Field( + default=None, 
validation_alias=AliasChoices("external_data_tools_list", "external_data_tools") + ) + model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model")) + user_input_form: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("user_input_form_list", "user_input_form") + ) + dataset_query_variable: str | None = None + pre_prompt: str | None = None + agent_mode: JSONValue | None = Field(default=None, validation_alias=AliasChoices("agent_mode_dict", "agent_mode")) + prompt_type: str | None = None + chat_prompt_config: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("chat_prompt_config_dict", "chat_prompt_config") + ) + completion_prompt_config: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("completion_prompt_config_dict", "completion_prompt_config") + ) + dataset_configs: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("dataset_configs_dict", "dataset_configs") + ) + file_upload: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("file_upload_dict", "file_upload") + ) + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class Site(ResponseModel): + access_token: str | None = Field(default=None, validation_alias="code") + code: str | None = None + title: str | None = None + icon_type: str | IconType | None = None + icon: str | None = None + icon_background: str | None = None + description: str | None = None + default_language: str | None = None + chat_color_theme: str | None = None + chat_color_theme_inverted: bool | None = None + customize_domain: str | None = None + copyright: str | None = None + privacy_policy: str | None = None + custom_disclaimer: str | None 
= None + customize_token_strategy: str | None = None + prompt_public: bool | None = None + app_base_url: str | None = None + show_workflow_steps: bool | None = None + use_icon_as_answer_icon: bool | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + + @computed_field(return_type=str | None) # type: ignore + @property + def icon_url(self) -> str | None: + return _build_icon_url(self.icon_type, self.icon) + + @field_validator("icon_type", mode="before") + @classmethod + def _normalize_icon_type(cls, value: str | IconType | None) -> str | None: + if isinstance(value, IconType): + return value.value + return value + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class DeletedTool(ResponseModel): + type: str + tool_name: str + provider_id: str + + +class AppPartial(ResponseModel): + id: str + name: str + max_active_requests: int | None = None + description: str | None = Field(default=None, validation_alias=AliasChoices("desc_or_prompt", "description")) + mode: str = Field(validation_alias="mode_compatible_with_agent") + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + model_config_: ModelConfigPartial | None = Field( + default=None, + validation_alias=AliasChoices("app_model_config", "model_config"), + alias="model_config", + ) + workflow: WorkflowPartial | None = None + use_icon_as_answer_icon: bool | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + tags: list[Tag] = Field(default_factory=list) + access_mode: str | None = None + create_user_name: str | None = None + author_name: str | None = None + has_draft_trigger: bool | None = None + + @computed_field(return_type=str | None) # type: ignore + @property 
+ def icon_url(self) -> str | None: + return _build_icon_url(self.icon_type, self.icon) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AppDetail(ResponseModel): + id: str + name: str + description: str | None = None + mode: str = Field(validation_alias="mode_compatible_with_agent") + icon: str | None = None + icon_background: str | None = None + enable_site: bool + enable_api: bool + model_config_: ModelConfig | None = Field( + default=None, + validation_alias=AliasChoices("app_model_config", "model_config"), + alias="model_config", + ) + workflow: WorkflowPartial | None = None + tracing: JSONValue | None = None + use_icon_as_answer_icon: bool | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + access_mode: str | None = None + tags: list[Tag] = Field(default_factory=list) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AppDetailWithSite(AppDetail): + icon_type: str | None = None + api_base_url: str | None = None + max_active_requests: int | None = None + deleted_tools: list[DeletedTool] = Field(default_factory=list) + site: Site | None = None + + @computed_field(return_type=str | None) # type: ignore + @property + def icon_url(self) -> str | None: + return _build_icon_url(self.icon_type, self.icon) + + +class AppPagination(ResponseModel): + page: int + limit: int = Field(validation_alias=AliasChoices("per_page", "limit")) + total: int + has_more: bool = Field(validation_alias=AliasChoices("has_next", "has_more")) + data: list[AppPartial] = Field(validation_alias=AliasChoices("items", "data")) + + +class AppExportResponse(ResponseModel): + data: str + + 
+register_enum_models(console_ns, RetrievalMethod, WorkflowExecutionStatus, DatasetPermissionEnum) + +register_schema_models( + console_ns, + AppListQuery, + CreateAppPayload, + UpdateAppPayload, + CopyAppPayload, + AppExportQuery, + AppNamePayload, + AppIconPayload, + AppSiteStatusPayload, + AppApiStatusPayload, + AppTracePayload, + Tag, + WorkflowPartial, + ModelConfigPartial, + ModelConfig, + Site, + DeletedTool, + AppPartial, + AppDetail, + AppDetailWithSite, + AppPagination, + AppExportResponse, + Segmentation, + PreProcessingRule, + Rule, + WeightVectorSetting, + WeightKeywordSetting, + WeightModel, + RerankingModel, + InfoList, + NotionInfo, + FileInfo, + WebsiteInfo, + NotionPage, + NotionIcon, + RerankingModel, + DataSource, + LoadBalancingPayload, ) @@ -260,7 +460,7 @@ class AppListApi(Resource): @console_ns.doc("list_apps") @console_ns.doc(description="Get list of applications with pagination and filtering") @console_ns.expect(console_ns.models[AppListQuery.__name__]) - @console_ns.response(200, "Success", app_pagination_model) + @console_ns.response(200, "Success", console_ns.models[AppPagination.__name__]) @setup_required @login_required @account_initialization_required @@ -276,7 +476,8 @@ class AppListApi(Resource): app_service = AppService() app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict) if not app_pagination: - return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} + empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[]) + return empty.model_dump(mode="json"), 200 if FeatureService.get_system_features().webapp_auth.enabled: app_ids = [str(app.id) for app in app_pagination.items] @@ -320,18 +521,18 @@ class AppListApi(Resource): for app in app_pagination.items: app.has_draft_trigger = str(app.id) in draft_trigger_app_ids - return marshal(app_pagination, app_pagination_model), 200 + pagination_model = AppPagination.model_validate(app_pagination, 
from_attributes=True) + return pagination_model.model_dump(mode="json"), 200 @console_ns.doc("create_app") @console_ns.doc(description="Create a new application") @console_ns.expect(console_ns.models[CreateAppPayload.__name__]) - @console_ns.response(201, "App created successfully", app_detail_model) + @console_ns.response(201, "App created successfully", console_ns.models[AppDetail.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required - @marshal_with(app_detail_model) @cloud_edition_billing_resource_check("apps") @edit_permission_required def post(self): @@ -341,8 +542,8 @@ class AppListApi(Resource): app_service = AppService() app = app_service.create_app(current_tenant_id, args.model_dump(), current_user) - - return app, 201 + app_detail = AppDetail.model_validate(app, from_attributes=True) + return app_detail.model_dump(mode="json"), 201 @console_ns.route("/apps/") @@ -350,13 +551,12 @@ class AppApi(Resource): @console_ns.doc("get_app_detail") @console_ns.doc(description="Get application details") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.response(200, "Success", app_detail_with_site_model) + @console_ns.response(200, "Success", console_ns.models[AppDetailWithSite.__name__]) @setup_required @login_required @account_initialization_required @enterprise_license_required - @get_app_model - @marshal_with(app_detail_with_site_model) + @get_app_model(mode=None) def get(self, app_model): """Get app detail""" app_service = AppService() @@ -367,21 +567,21 @@ class AppApi(Resource): app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id)) app_model.access_mode = app_setting.access_mode - return app_model + response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.doc("update_app") 
@console_ns.doc(description="Update application details") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[UpdateAppPayload.__name__]) - @console_ns.response(200, "App updated successfully", app_detail_with_site_model) + @console_ns.response(200, "App updated successfully", console_ns.models[AppDetailWithSite.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required - @get_app_model + @get_app_model(mode=None) @edit_permission_required - @marshal_with(app_detail_with_site_model) def put(self, app_model): """Update app""" args = UpdateAppPayload.model_validate(console_ns.payload) @@ -398,8 +598,8 @@ class AppApi(Resource): "max_active_requests": args.max_active_requests or 0, } app_model = app_service.update_app(app_model, args_dict) - - return app_model + response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.doc("delete_app") @console_ns.doc(description="Delete application") @@ -425,14 +625,13 @@ class AppCopyApi(Resource): @console_ns.doc(description="Create a copy of an existing application") @console_ns.doc(params={"app_id": "Application ID to copy"}) @console_ns.expect(console_ns.models[CopyAppPayload.__name__]) - @console_ns.response(201, "App copied successfully", app_detail_with_site_model) + @console_ns.response(201, "App copied successfully", console_ns.models[AppDetailWithSite.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required - @get_app_model + @get_app_model(mode=None) @edit_permission_required - @marshal_with(app_detail_with_site_model) def post(self, app_model): """Copy app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -458,7 +657,8 @@ class AppCopyApi(Resource): stmt = 
select(App).where(App.id == result.app_id) app = session.scalar(stmt) - return app, 201 + response_model = AppDetailWithSite.model_validate(app, from_attributes=True) + return response_model.model_dump(mode="json"), 201 @console_ns.route("/apps//export") @@ -467,11 +667,7 @@ class AppExportApi(Resource): @console_ns.doc(description="Export application configuration as DSL") @console_ns.doc(params={"app_id": "Application ID to export"}) @console_ns.expect(console_ns.models[AppExportQuery.__name__]) - @console_ns.response( - 200, - "App exported successfully", - console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}), - ) + @console_ns.response(200, "App exported successfully", console_ns.models[AppExportResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @get_app_model @setup_required @@ -482,13 +678,14 @@ class AppExportApi(Resource): """Export app""" args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - return { - "data": AppDslService.export_dsl( + payload = AppExportResponse( + data=AppDslService.export_dsl( app_model=app_model, include_secret=args.include_secret, workflow_id=args.workflow_id, ) - } + ) + return payload.model_dump(mode="json") @console_ns.route("/apps//name") @@ -497,20 +694,19 @@ class AppNameApi(Resource): @console_ns.doc(description="Check if app name is available") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[AppNamePayload.__name__]) - @console_ns.response(200, "Name availability checked") + @console_ns.response(200, "Name availability checked", console_ns.models[AppDetail.__name__]) @setup_required @login_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) @edit_permission_required def post(self, app_model): args = AppNamePayload.model_validate(console_ns.payload) app_service = AppService() app_model = 
app_service.update_app_name(app_model, args.name) - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//icon") @@ -524,16 +720,15 @@ class AppIconApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) @edit_permission_required def post(self, app_model): args = AppIconPayload.model_validate(console_ns.payload or {}) app_service = AppService() app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "") - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//site-enable") @@ -542,21 +737,20 @@ class AppSiteStatus(Resource): @console_ns.doc(description="Enable or disable app site") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__]) - @console_ns.response(200, "Site status updated successfully", app_detail_model) + @console_ns.response(200, "Site status updated successfully", console_ns.models[AppDetail.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) @edit_permission_required def post(self, app_model): args = AppSiteStatusPayload.model_validate(console_ns.payload) app_service = AppService() app_model = app_service.update_app_site_status(app_model, args.enable_site) - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//api-enable") @@ -565,21 +759,20 @@ class AppApiStatus(Resource): @console_ns.doc(description="Enable or disable app API") 
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[AppApiStatusPayload.__name__]) - @console_ns.response(200, "API status updated successfully", app_detail_model) + @console_ns.response(200, "API status updated successfully", console_ns.models[AppDetail.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @is_admin_or_owner_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) def post(self, app_model): args = AppApiStatusPayload.model_validate(console_ns.payload) app_service = AppService() app_model = app_service.update_app_api_status(app_model, args.enable_api) - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//trace") diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 22e2aeb720..fdef54ba5a 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -41,14 +41,14 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class AppImportPayload(BaseModel): mode: str = Field(..., description="Import mode") - yaml_content: str | None = None - yaml_url: str | None = None - name: str | None = None - description: str | None = None - icon_type: str | None = None - icon: str | None = None - icon_background: str | None = None - app_id: str | None = None + yaml_content: str | None = Field(None) + yaml_url: str | None = Field(None) + name: str | None = Field(None) + description: str | None = Field(None) + icon_type: str | None = Field(None) + icon: str | None = Field(None) + icon_background: str | None = Field(None) + app_id: str | None = Field(None) console_ns.schema_model( diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 1501d39485..14910c5895 
100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -13,7 +13,6 @@ from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db -from fields.conversation_fields import MessageTextField from fields.raws import FilesContainedField from libs.datetime_utils import naive_utc_now, parse_time_range from libs.helper import TimestampField @@ -178,6 +177,12 @@ annotation_hit_history_model = console_ns.model( }, ) + +class MessageTextField(fields.Raw): + def format(self, value): + return value[0]["text"] if value else "" + + # Simple message detail model simple_message_detail_model = console_ns.model( "SimpleMessageDetail", @@ -344,10 +349,13 @@ class CompletionConversationApi(Resource): ) if args.keyword: + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(args.keyword) query = query.join(Message, Message.conversation_id == Conversation.id).where( or_( - Message.query.ilike(f"%{args.keyword}%"), - Message.answer.ilike(f"%{args.keyword}%"), + Message.query.ilike(f"%{escaped_keyword}%", escape="\\"), + Message.answer.ilike(f"%{escaped_keyword}%", escape="\\"), ) ) @@ -456,7 +464,10 @@ class ChatConversationApi(Resource): query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False)) if args.keyword: - keyword_filter = f"%{args.keyword}%" + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(args.keyword) + keyword_filter = f"%{escaped_keyword}%" query = ( query.join( Message, @@ -465,11 +476,11 @@ class ChatConversationApi(Resource): .join(subquery, subquery.c.conversation_id == Conversation.id) .where( or_( - Message.query.ilike(keyword_filter), - Message.answer.ilike(keyword_filter), - 
Conversation.name.ilike(keyword_filter), - Conversation.introduction.ilike(keyword_filter), - subquery.c.from_end_user_session_id.ilike(keyword_filter), + Message.query.ilike(keyword_filter, escape="\\"), + Message.answer.ilike(keyword_filter, escape="\\"), + Conversation.name.ilike(keyword_filter, escape="\\"), + Conversation.introduction.ilike(keyword_filter, escape="\\"), + subquery.c.from_end_user_session_id.ilike(keyword_filter, escape="\\"), ), ) .group_by(Conversation.id) @@ -582,9 +593,12 @@ def _get_conversation(app_model, conversation_id): if not conversation: raise NotFound("Conversation Not Exists.") - if not conversation.read_at: - conversation.read_at = naive_utc_now() - conversation.read_account_id = current_user.id - db.session.commit() + db.session.execute( + sa.update(Conversation) + .where(Conversation.id == conversation_id, Conversation.read_at.is_(None)) + .values(read_at=naive_utc_now(), read_account_id=current_user.id) + ) + db.session.commit() + db.session.refresh(conversation) return conversation diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py index fbd7901646..3fa15d6d6d 100644 --- a/api/controllers/console/app/error.py +++ b/api/controllers/console/app/error.py @@ -82,13 +82,13 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException): class DraftWorkflowNotExist(BaseHTTPException): error_code = "draft_workflow_not_exist" description = "Draft workflow need to be initialized." - code = 400 + code = 404 class DraftWorkflowNotSync(BaseHTTPException): error_code = "draft_workflow_not_sync" description = "Workflow graph might have been modified, please refresh and resubmit." - code = 400 + code = 409 class TracingConfigNotExist(BaseHTTPException): @@ -115,3 +115,9 @@ class InvokeRateLimitError(BaseHTTPException): error_code = "rate_limit_error" description = "Rate Limit Error" code = 429 + + +class NeedAddIdsError(BaseHTTPException): + error_code = "need_add_ids" + description = "Need to add ids." 
+ code = 400 diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index b0bdf2657d..94b9693929 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -12,6 +12,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services from controllers.console import console_ns from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync +from controllers.console.app.workflow_run import workflow_run_node_execution_model from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.web.error import InvalidArgumentError @@ -36,7 +37,6 @@ from extensions.ext_database import db from factories import file_factory, variable_factory from fields.member_fields import simple_account_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields -from fields.workflow_run_fields import workflow_run_node_execution_fields from libs import helper from libs.datetime_utils import naive_utc_now from libs.helper import TimestampField, uuid_value @@ -89,26 +89,6 @@ workflow_pagination_fields_copy = workflow_pagination_fields.copy() workflow_pagination_fields_copy["items"] = fields.List(fields.Nested(workflow_model), attribute="items") workflow_pagination_model = console_ns.model("WorkflowPagination", workflow_pagination_fields_copy) -# Reuse workflow_run_node_execution_model from workflow_run.py if already registered -# Otherwise register it here -from fields.end_user_fields import simple_end_user_fields - -simple_end_user_model = None -try: - simple_end_user_model = console_ns.models.get("SimpleEndUser") -except AttributeError: - pass -if simple_end_user_model is None: - simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields) - 
-workflow_run_node_execution_model = None -try: - workflow_run_node_execution_model = console_ns.models.get("WorkflowRunNodeExecution") -except AttributeError: - pass -if workflow_run_node_execution_model is None: - workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields) - class SyncDraftWorkflowPayload(BaseModel): graph: dict[str, Any] @@ -471,7 +451,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource): Run draft workflow loop node """ current_user, _ = current_account_with_tenant() - args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True) + args = LoopNodeRunPayload.model_validate(console_ns.payload or {}) try: response = AppGenerateService.generate_single_loop( @@ -509,7 +489,7 @@ class WorkflowDraftRunLoopNodeApi(Resource): Run draft workflow loop node """ current_user, _ = current_account_with_tenant() - args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True) + args = LoopNodeRunPayload.model_validate(console_ns.payload or {}) try: response = AppGenerateService.generate_single_loop( @@ -1176,6 +1156,7 @@ class DraftWorkflowTriggerRunApi(Resource): if not event: return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN}) workflow_args = dict(event.workflow_args) + workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True return helper.compact_generate_response( AppGenerateService.generate( @@ -1324,6 +1305,7 @@ class DraftWorkflowTriggerRunAllApi(Resource): try: workflow_args = dict(trigger_debug_event.workflow_args) + workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True response = AppGenerateService.generate( app_model=app_model, diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index fa67fb8154..6736f24a2e 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -11,7 +11,10 @@ from 
 controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, setup_required
 from core.workflow.enums import WorkflowExecutionStatus
 from extensions.ext_database import db
-from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
+from fields.workflow_app_log_fields import (
+    build_workflow_app_log_pagination_model,
+    build_workflow_archived_log_pagination_model,
+)
 from libs.login import login_required
 from models import App
 from models.model import AppMode
@@ -61,6 +64,7 @@ console_ns.schema_model(
 # Register model for flask_restx to avoid dict type issues in Swagger
 workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
+workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns)


 @console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
@@ -99,3 +103,33 @@ class WorkflowAppLogApi(Resource):
         )

         return workflow_app_log_pagination
+
+
+@console_ns.route("/apps/<uuid:app_id>/workflow-archived-logs")
+class WorkflowArchivedLogApi(Resource):
+    @console_ns.doc("get_workflow_archived_logs")
+    @console_ns.doc(description="Get workflow archived execution logs")
+    @console_ns.doc(params={"app_id": "Application ID"})
+    @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
+    @console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model)
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.WORKFLOW])
+    @marshal_with(workflow_archived_log_pagination_model)
+    def get(self, app_model: App):
+        """
+        Get workflow archived logs
+        """
+        args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
+
+        workflow_app_service = WorkflowAppService()
+        with Session(db.engine) as session:
+            workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
+                session=session,
+                app_model=app_model,
+                page=args.page,
+                limit=args.limit,
+            )
+
+        return workflow_app_log_pagination
diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py
index 36b2e40928..585ede6f5d 100644
--- a/api/controllers/console/app/workflow_run.py
+++ b/api/controllers/console/app/workflow_run.py
@@ -1,8 +1,10 @@
+from datetime import UTC, datetime, timedelta
 from typing import Literal, cast

 from flask import request
 from flask_restx import Resource, fields, marshal_with
 from pydantic import BaseModel, Field, field_validator
+from sqlalchemy import select
 from sqlalchemy.orm import sessionmaker

 from configs import dify_config
@@ -25,12 +27,14 @@ from fields.workflow_run_fields import (
     workflow_run_node_execution_list_fields,
     workflow_run_pagination_fields,
 )
+from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
 from libs.custom_inputs import time_duration
 from libs.helper import uuid_value
 from libs.login import current_user, login_required
-from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
+from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
 from models.workflow import WorkflowRun
 from repositories.factory import DifyAPIRepositoryFactory
+from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
 from services.workflow_run_service import WorkflowRunService
@@ -45,6 +49,7 @@ def _build_backstage_input_url(form_token: str | None) -> str | None:

 # Workflow run status choices for filtering
 WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
+EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600

 # Register models for flask_restx to avoid dict type issues in Swagger
 # Register in dependency order: base models first, then dependent models
@@ -111,6 +116,15 @@ workflow_run_node_execution_list_model = console_ns.model(
     "WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
 )
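The new archived-logs endpoint validates query-string parameters by feeding `request.args.to_dict(flat=True)` into a pydantic model. A minimal sketch of that pattern, using a hypothetical `LogQuery` model in place of the real `WorkflowAppLogQuery` (whose definition lives elsewhere in the codebase):

```python
from pydantic import BaseModel, Field


# Hypothetical stand-in for WorkflowAppLogQuery; it only illustrates the
# validation pattern, not the real model's fields.
class LogQuery(BaseModel):
    page: int = Field(default=1, ge=1)
    limit: int = Field(default=20, ge=1, le=100)


# Flask's request.args.to_dict(flat=True) yields string values;
# pydantic's non-strict mode coerces them to the declared int types.
args = LogQuery.model_validate({"page": "2", "limit": "50"})
print(args.page, args.limit)  # 2 50
```

The same call raises a `ValidationError` for out-of-range or non-numeric values, which is why the handler needs no manual `type=int` parsing.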
+workflow_run_export_fields = console_ns.model(
+    "WorkflowRunExport",
+    {
+        "status": fields.String(description="Export status: success/failed"),
+        "presigned_url": fields.String(description="Pre-signed URL for download", required=False),
+        "presigned_url_expires_at": fields.String(description="Pre-signed URL expiration time", required=False),
+    },
+)
+
 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -199,6 +213,56 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
         return result


+@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/export")
+class WorkflowRunExportApi(Resource):
+    @console_ns.doc("get_workflow_run_export_url")
+    @console_ns.doc(description="Generate a download URL for an archived workflow run.")
+    @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
+    @console_ns.response(200, "Export URL generated", workflow_run_export_fields)
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model()
+    def get(self, app_model: App, run_id: str):
+        tenant_id = str(app_model.tenant_id)
+        app_id = str(app_model.id)
+        run_id_str = str(run_id)
+
+        run_created_at = db.session.scalar(
+            select(WorkflowArchiveLog.run_created_at)
+            .where(
+                WorkflowArchiveLog.tenant_id == tenant_id,
+                WorkflowArchiveLog.app_id == app_id,
+                WorkflowArchiveLog.workflow_run_id == run_id_str,
+            )
+            .limit(1)
+        )
+        if not run_created_at:
+            return {"code": "archive_log_not_found", "message": "workflow run archive not found"}, 404
+
+        prefix = (
+            f"{tenant_id}/app_id={app_id}/year={run_created_at.strftime('%Y')}/"
+            f"month={run_created_at.strftime('%m')}/workflow_run_id={run_id_str}"
+        )
+        archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
+
+        try:
+            archive_storage = get_archive_storage()
+        except ArchiveStorageNotConfiguredError as e:
+            return {"code": "archive_storage_not_configured", "message": str(e)}, 500
+
+        presigned_url = archive_storage.generate_presigned_url(
+            archive_key,
+            expires_in=EXPORT_SIGNED_URL_EXPIRE_SECONDS,
+        )
+        expires_at = datetime.now(UTC) + timedelta(seconds=EXPORT_SIGNED_URL_EXPIRE_SECONDS)
+        return {
+            "status": "success",
+            "presigned_url": presigned_url,
+            "presigned_url_expires_at": expires_at.isoformat(),
+        }, 200
+
+
 @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
 class AdvancedChatAppWorkflowRunCountApi(Resource):
     @console_ns.doc("get_advanced_chat_workflow_runs_count")
diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py
index 9433b732e4..8236e766ae 100644
--- a/api/controllers/console/app/workflow_trigger.py
+++ b/api/controllers/console/app/workflow_trigger.py
@@ -1,13 +1,14 @@
 import logging

 from flask import request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource, fields, marshal_with
 from pydantic import BaseModel
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound

 from configs import dify_config
+from controllers.common.schema import get_or_create_model
 from extensions.ext_database import db
 from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
 from libs.login import current_user, login_required
@@ -22,6 +23,14 @@ from ..wraps import account_initialization_required, edit_permission_required, s
 logger = logging.getLogger(__name__)
 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

+trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields)
+
+triggers_list_fields_copy = triggers_list_fields.copy()
+triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model))
+triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy)
+
+webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields)
+

 class Parser(BaseModel):
     node_id: str
@@ -48,7 +57,7 @@ class WebhookTriggerApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(webhook_trigger_fields)
+    @marshal_with(webhook_trigger_model)
     def get(self, app_model: App):
         """Get webhook trigger for a node"""
         args = Parser.model_validate(request.args.to_dict(flat=True))  # type: ignore
@@ -80,7 +89,7 @@ class AppTriggersApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(triggers_list_fields)
+    @marshal_with(triggers_list_model)
     def get(self, app_model: App):
         """Get app triggers list"""
         assert isinstance(current_user, Account)
@@ -120,7 +129,7 @@ class AppTriggerEnableApi(Resource):
     @account_initialization_required
     @edit_permission_required
     @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(trigger_fields)
+    @marshal_with(trigger_model)
     def post(self, app_model: App):
         """Update app trigger (enable/disable)"""
         args = ParserEnable.model_validate(console_ns.payload)
diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py
index 9bb2718f89..e687d980fa 100644
--- a/api/controllers/console/app/wraps.py
+++ b/api/controllers/console/app/wraps.py
@@ -23,6 +23,11 @@ def _load_app_model(app_id: str) -> App | None:
     return app_model


+def _load_app_model_with_trial(app_id: str) -> App | None:
+    app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
+    return app_model
+
+
 def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
     def decorator(view_func: Callable[P1, R1]):
         @wraps(view_func)
@@ -62,3 +67,44 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
         return decorator
     else:
         return decorator(view)
+
+
+def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
+    def decorator(view_func: Callable[P, R]):
+        @wraps(view_func)
+        def decorated_view(*args: P.args, **kwargs: P.kwargs):
+            if not kwargs.get("app_id"):
+                raise ValueError("missing app_id in path parameters")
+
+            app_id = kwargs.get("app_id")
+            app_id = str(app_id)
+
+            del kwargs["app_id"]
+
+            app_model = _load_app_model_with_trial(app_id)
+
+            if not app_model:
+                raise AppNotFoundError()
+
+            app_mode = AppMode.value_of(app_model.mode)
+
+            if mode is not None:
+                if isinstance(mode, list):
+                    modes = mode
+                else:
+                    modes = [mode]
+
+                if app_mode not in modes:
+                    mode_values = {m.value for m in modes}
+                    raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
+
+            kwargs["app_model"] = app_model
+
+            return view_func(*args, **kwargs)
+
+        return decorated_view
+
+    if view is None:
+        return decorator
+    else:
+        return decorator(view)
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index fe70d930fb..f741107b87 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -63,13 +63,19 @@ class ActivateCheckApi(Resource):
         args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

         workspaceId = args.workspace_id
-        reg_email = args.email
         token = args.token

-        invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
+        invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
         if invitation:
             data = invitation.get("data", {})
             tenant = invitation.get("tenant", None)
+
+            # Check workspace permission
+            if tenant:
+                from libs.workspace_permission import check_workspace_member_invite_permission
+
+                check_workspace_member_invite_permission(tenant.id)
+
             workspace_name = tenant.name if tenant else None
             workspace_id = tenant.id if tenant else None
             invitee_email = data.get("email") if data else None
@@ -100,11 +106,12 @@ class ActivateApi(Resource):
     def post(self):
         args = ActivatePayload.model_validate(console_ns.payload)

-        invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
+        normalized_request_email = args.email.lower() if args.email else None
+        invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
         if invitation is None:
             raise AlreadyActivateError()

-        RegisterService.revoke_token(args.workspace_id, args.email, args.token)
+        RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token)

         account = invitation["account"]
         account.name = args.name
diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py
index fa082c735d..c2a95ddad2 100644
--- a/api/controllers/console/auth/email_register.py
+++ b/api/controllers/console/auth/email_register.py
@@ -1,7 +1,6 @@
 from flask import request
 from flask_restx import Resource
 from pydantic import BaseModel, Field, field_validator
-from sqlalchemy import select
 from sqlalchemy.orm import Session

 from configs import dify_config
@@ -62,6 +61,7 @@ class EmailRegisterSendEmailApi(Resource):
     @email_register_enabled
     def post(self):
         args = EmailRegisterSendPayload.model_validate(console_ns.payload)
+        normalized_email = args.email.lower()

         ip_address = extract_remote_ip(request)
         if AccountService.is_email_send_ip_limit(ip_address):
@@ -70,13 +70,12 @@ class EmailRegisterSendEmailApi(Resource):
         if args.language in languages:
             language = args.language

-        if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
+        if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
             raise AccountInFreezeError()

         with Session(db.engine) as session:
-            account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
-            token = None
-            token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
+            account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
+            token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)

         return {"result": "success", "data": token}
@@ -88,9 +87,9 @@ class EmailRegisterCheckApi(Resource):
     def post(self):
         args = EmailRegisterValidityPayload.model_validate(console_ns.payload)

-        user_email = args.email
+        user_email = args.email.lower()

-        is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
+        is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email)
         if is_email_register_error_rate_limit:
             raise EmailRegisterLimitError()
@@ -98,11 +97,14 @@ class EmailRegisterCheckApi(Resource):
         if token_data is None:
             raise InvalidTokenError()

-        if user_email != token_data.get("email"):
+        token_email = token_data.get("email")
+        normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
+
+        if user_email != normalized_token_email:
             raise InvalidEmailError()

         if args.code != token_data.get("code"):
-            AccountService.add_email_register_error_rate_limit(args.email)
+            AccountService.add_email_register_error_rate_limit(user_email)
             raise EmailCodeError()

         # Verified, revoke the first token
@@ -113,8 +115,8 @@ class EmailRegisterCheckApi(Resource):
             user_email, code=args.code, additional_data={"phase": "register"}
         )

-        AccountService.reset_email_register_error_rate_limit(args.email)
-        return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+        AccountService.reset_email_register_error_rate_limit(user_email)
+        return {"is_valid": True, "email": normalized_token_email, "token": new_token}


 @console_ns.route("/email-register")
@@ -141,22 +143,23 @@ class EmailRegisterResetApi(Resource):
         AccountService.revoke_email_register_token(args.token)

         email = register_data.get("email", "")
+        normalized_email = email.lower()

         with Session(db.engine) as session:
-            account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+            account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
             if account:
                 raise EmailAlreadyInUseError()
             else:
-                account = self._create_new_account(email, args.password_confirm)
+                account = self._create_new_account(normalized_email, args.password_confirm)

         if not account:
             raise AccountNotFoundError()

         token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
-        AccountService.reset_login_error_rate_limit(email)
+        AccountService.reset_login_error_rate_limit(normalized_email)
         return {"result": "success", "data": token_pair.model_dump()}

-    def _create_new_account(self, email, password) -> Account | None:
+    def _create_new_account(self, email: str, password: str) -> Account | None:
         # Create new account if allowed
         account = None
         try:
diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py
index 661f591182..394f205d93 100644
--- a/api/controllers/console/auth/forgot_password.py
+++ b/api/controllers/console/auth/forgot_password.py
@@ -4,7 +4,6 @@ import secrets
 from flask import request
 from flask_restx import Resource, fields
 from pydantic import BaseModel, Field, field_validator
-from sqlalchemy import select
 from sqlalchemy.orm import Session

 from controllers.console import console_ns
@@ -21,7 +20,6 @@ from events.tenant_event import tenant_was_created
 from extensions.ext_database import db
 from libs.helper import EmailStr, extract_remote_ip
 from libs.password import hash_password, valid_password
-from models import Account
 from services.account_service import AccountService, TenantService
 from services.feature_service import FeatureService
@@ -76,6 +74,7 @@ class ForgotPasswordSendEmailApi(Resource):
     @email_password_login_enabled
     def post(self):
         args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
+        normalized_email = args.email.lower()

         ip_address = extract_remote_ip(request)
         if AccountService.is_email_send_ip_limit(ip_address):
@@ -87,11 +86,11 @@ class ForgotPasswordSendEmailApi(Resource):
             language = "en-US"

         with Session(db.engine) as session:
-            account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
+            account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)

         token = AccountService.send_reset_password_email(
             account=account,
-            email=args.email,
+            email=normalized_email,
             language=language,
             is_allow_register=FeatureService.get_system_features().is_allow_register,
         )
@@ -122,9 +121,9 @@ class ForgotPasswordCheckApi(Resource):
     def post(self):
         args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)

-        user_email = args.email
+        user_email = args.email.lower()

-        is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
+        is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
         if is_forgot_password_error_rate_limit:
             raise EmailPasswordResetLimitError()
@@ -132,11 +131,16 @@ class ForgotPasswordCheckApi(Resource):
         if token_data is None:
             raise InvalidTokenError()

-        if user_email != token_data.get("email"):
+        token_email = token_data.get("email")
+        if not isinstance(token_email, str):
+            raise InvalidEmailError()
+        normalized_token_email = token_email.lower()
+
+        if user_email != normalized_token_email:
             raise InvalidEmailError()

         if args.code != token_data.get("code"):
-            AccountService.add_forgot_password_error_rate_limit(args.email)
+            AccountService.add_forgot_password_error_rate_limit(user_email)
             raise EmailCodeError()

         # Verified, revoke the first token
@@ -144,11 +148,11 @@ class ForgotPasswordCheckApi(Resource):

         # Refresh token data by generating a new token
         _, new_token = AccountService.generate_reset_password_token(
-            user_email, code=args.code, additional_data={"phase": "reset"}
+            token_email, code=args.code, additional_data={"phase": "reset"}
         )

-        AccountService.reset_forgot_password_error_rate_limit(args.email)
-        return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+        AccountService.reset_forgot_password_error_rate_limit(user_email)
+        return {"is_valid": True, "email": normalized_token_email, "token": new_token}


 @console_ns.route("/forgot-password/resets")
@@ -187,9 +191,8 @@ class ForgotPasswordResetApi(Resource):
         password_hashed = hash_password(args.new_password, salt)

         email = reset_data.get("email", "")
-
         with Session(db.engine) as session:
-            account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+            account = AccountService.get_account_by_email_with_case_fallback(email, session=session)

             if account:
                 self._update_existing_account(account, password_hashed, salt, session)
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index 772d98822e..400df138b8 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 import flask_login
 from flask import make_response, request
 from flask_restx import Resource
@@ -88,33 +90,38 @@ class LoginApi(Resource):
     def post(self):
         """Authenticate user and login."""
         args = LoginPayload.model_validate(console_ns.payload)
+        request_email = args.email
+        normalized_email = request_email.lower()

-        if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
+        if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
             raise AccountInFreezeError()

-        is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
+        is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
         if is_login_error_rate_limit:
             raise EmailPasswordLoginLimitError()

-        # TODO: why invitation is re-assigned with different type?
-        invitation = args.invite_token  # type: ignore
-        if invitation:
-            invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation)  # type: ignore
+        invite_token = args.invite_token
+        invitation_data: dict[str, Any] | None = None
+        if invite_token:
+            invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
+            if invitation_data is None:
+                invite_token = None

         try:
-            if invitation:
-                data = invitation.get("data", {})  # type: ignore
+            if invitation_data:
+                data = invitation_data.get("data", {})
                 invitee_email = data.get("email") if data else None
-                if invitee_email != args.email:
+                invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
+                if invitee_email_normalized != normalized_email:
                     raise InvalidEmailError()

-                account = AccountService.authenticate(args.email, args.password, args.invite_token)
-            else:
-                account = AccountService.authenticate(args.email, args.password)
+            account = _authenticate_account_with_case_fallback(
+                request_email, normalized_email, args.password, invite_token
+            )
         except services.errors.account.AccountLoginError:
             raise AccountBannedError()
-        except services.errors.account.AccountPasswordError:
-            AccountService.add_login_error_rate_limit(args.email)
-            raise AuthenticationFailedError()
+        except services.errors.account.AccountPasswordError as exc:
+            AccountService.add_login_error_rate_limit(normalized_email)
+            raise AuthenticationFailedError() from exc
         # SELF_HOSTED only have one workspace
         tenants = TenantService.get_join_tenants(account)
         if len(tenants) == 0:
@@ -129,7 +136,7 @@ class LoginApi(Resource):
             }

         token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
-        AccountService.reset_login_error_rate_limit(args.email)
+        AccountService.reset_login_error_rate_limit(normalized_email)

         # Create response with cookies instead of returning tokens in body
         response = make_response({"result": "success"})
@@ -169,18 +176,19 @@ class ResetPasswordSendEmailApi(Resource):
     @console_ns.expect(console_ns.models[EmailPayload.__name__])
     def post(self):
         args = EmailPayload.model_validate(console_ns.payload)
+        normalized_email = args.email.lower()
         if args.language is not None and args.language == "zh-Hans":
             language = "zh-Hans"
         else:
             language = "en-US"

         try:
-            account = AccountService.get_user_through_email(args.email)
+            account = _get_account_with_case_fallback(args.email)
         except AccountRegisterError:
             raise AccountInFreezeError()

         token = AccountService.send_reset_password_email(
-            email=args.email,
+            email=normalized_email,
             account=account,
             language=language,
             is_allow_register=FeatureService.get_system_features().is_allow_register,
@@ -195,6 +203,7 @@ class EmailCodeLoginSendEmailApi(Resource):
     @console_ns.expect(console_ns.models[EmailPayload.__name__])
     def post(self):
         args = EmailPayload.model_validate(console_ns.payload)
+        normalized_email = args.email.lower()

         ip_address = extract_remote_ip(request)
         if AccountService.is_email_send_ip_limit(ip_address):
@@ -205,13 +214,13 @@ class EmailCodeLoginSendEmailApi(Resource):
         else:
             language = "en-US"
         try:
-            account = AccountService.get_user_through_email(args.email)
+            account = _get_account_with_case_fallback(args.email)
         except AccountRegisterError:
             raise AccountInFreezeError()

         if account is None:
             if FeatureService.get_system_features().is_allow_register:
-                token = AccountService.send_email_code_login_email(email=args.email, language=language)
+                token = AccountService.send_email_code_login_email(email=normalized_email, language=language)
             else:
                 raise AccountNotFound()
         else:
@@ -228,14 +237,17 @@ class EmailCodeLoginApi(Resource):
     def post(self):
         args = EmailCodeLoginPayload.model_validate(console_ns.payload)

-        user_email = args.email
+        original_email = args.email
+        user_email = original_email.lower()
         language = args.language

         token_data = AccountService.get_email_code_login_data(args.token)
         if token_data is None:
             raise InvalidTokenError()

-        if token_data["email"] != args.email:
+        token_email = token_data.get("email")
+        normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
+        if normalized_token_email != user_email:
             raise InvalidEmailError()

         if token_data["code"] != args.code:
@@ -243,7 +255,7 @@ class EmailCodeLoginApi(Resource):
         AccountService.revoke_email_code_login_token(args.token)

         try:
-            account = AccountService.get_user_through_email(user_email)
+            account = _get_account_with_case_fallback(original_email)
         except AccountRegisterError:
             raise AccountInFreezeError()
         if account:
@@ -274,7 +286,7 @@ class EmailCodeLoginApi(Resource):
         except WorkspacesLimitExceededError:
             raise WorkspacesLimitExceeded()
         token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
-        AccountService.reset_login_error_rate_limit(args.email)
+        AccountService.reset_login_error_rate_limit(user_email)

         # Create response with cookies instead of returning tokens in body
         response = make_response({"result": "success"})
@@ -308,3 +320,22 @@ class RefreshTokenApi(Resource):
             return response
         except Exception as e:
             return {"result": "fail", "message": str(e)}, 401
+
+
+def _get_account_with_case_fallback(email: str):
+    account = AccountService.get_user_through_email(email)
+    if account or email == email.lower():
+        return account
+
+    return AccountService.get_user_through_email(email.lower())
+
+
+def _authenticate_account_with_case_fallback(
+    original_email: str, normalized_email: str, password: str, invite_token: str | None
+):
+    try:
+        return AccountService.authenticate(original_email, password, invite_token)
+    except services.errors.account.AccountPasswordError:
+        if original_email == normalized_email:
+            raise
+        return AccountService.authenticate(normalized_email, password, invite_token)
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index 7ad1e56373..112e152432 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -3,7 +3,6 @@ import logging
 import httpx
 from flask import current_app, redirect, request
 from flask_restx import Resource
-from sqlalchemy import select
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Unauthorized
@@ -118,13 +117,16 @@ class OAuthCallback(Resource):
             invitation = RegisterService.get_invitation_by_token(token=invite_token)
             if invitation:
                 invitation_email = invitation.get("email", None)
-                if invitation_email != user_info.email:
+                invitation_email_normalized = (
+                    invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
+                )
+                if invitation_email_normalized != user_info.email.lower():
                     return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
             return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")

         try:
-            account = _generate_account(provider, user_info)
+            account, oauth_new_user = _generate_account(provider, user_info)
         except AccountNotFoundError:
             return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
         except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
@@ -159,7 +161,10 @@ class OAuthCallback(Resource):
             ip_address=extract_remote_ip(request),
         )

-        response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
+        base_url = dify_config.CONSOLE_WEB_URL
+        query_char = "&" if "?" in base_url else "?"
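The `_get_account_with_case_fallback` helper added to `login.py` above only pays for a second lookup when the exact-case query misses and the address actually contains uppercase characters. A standalone sketch of that logic, with a plain dict lookup standing in for `AccountService.get_user_through_email`:

```python
from collections.abc import Callable


def get_account_with_case_fallback(lookup: Callable[[str], object], email: str):
    # Try the address exactly as submitted first; retry lowercased only if
    # that misses and the submitted address differs from its lowercase form.
    account = lookup(email)
    if account or email == email.lower():
        return account
    return lookup(email.lower())


accounts = {"alice@example.com": "alice"}  # stand-in account store
assert get_account_with_case_fallback(accounts.get, "Alice@Example.com") == "alice"
assert get_account_with_case_fallback(accounts.get, "bob@example.com") is None
```

The `email == email.lower()` guard keeps the common all-lowercase path to a single lookup, so the fallback adds cost only for mixed-case input.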
+        target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"
+        response = redirect(target_url)

         set_access_token_to_cookie(request, response, token_pair.access_token)
         set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
@@ -172,14 +177,15 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->

     if not account:
         with Session(db.engine) as session:
-            account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
+            account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)

     return account


-def _generate_account(provider: str, user_info: OAuthUserInfo):
+def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]:
     # Get account by openid or email.
     account = _get_account_by_openid_or_email(provider, user_info)
+    oauth_new_user = False

     if account:
         tenants = TenantService.get_join_tenants(account)
@@ -193,8 +199,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
             tenant_was_created.send(new_tenant)

     if not account:
+        normalized_email = user_info.email.lower()
+        oauth_new_user = True
         if not FeatureService.get_system_features().is_allow_register:
-            if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
+            if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
                 raise AccountRegisterError(
                     description=(
                         "This email account has been deleted within the past "
@@ -205,7 +213,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
             raise AccountRegisterError(description=("Invalid email or password"))
         account_name = user_info.name or "Dify"
         account = RegisterService.register(
-            email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
+            email=normalized_email,
+            name=account_name,
+            password=None,
+            open_id=user_info.id,
+            provider=provider,
         )

         # Set interface language
@@ -220,4 +232,4 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):

     # Link account
     AccountService.link_account_integrate(provider, user_info.id, account)
-    return account
+    return account, oauth_new_user
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index cd958bbb36..01e9bf77c0 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -3,13 +3,13 @@ from collections.abc import Generator
 from typing import Any, cast

 from flask import request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource, fields, marshal_with
 from pydantic import BaseModel, Field
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound

-from controllers.common.schema import register_schema_model
+from controllers.common.schema import get_or_create_model, register_schema_model
 from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
 from core.indexing_runner import IndexingRunner
@@ -17,7 +17,14 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
 from core.rag.extractor.notion_extractor import NotionExtractor
 from extensions.ext_database import db
-from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
+from fields.data_source_fields import (
+    integrate_fields,
+    integrate_icon_fields,
+    integrate_list_fields,
+    integrate_notion_info_list_fields,
+    integrate_page_fields,
+    integrate_workspace_fields,
+)
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_account_with_tenant, login_required
 from models import DataSourceOauthBinding, Document
@@ -36,9 +43,62 @@ class NotionEstimatePayload(BaseModel):
     doc_language: str = Field(default="English")


+class DataSourceNotionListQuery(BaseModel):
+    dataset_id: str | None = Field(default=None, description="Dataset ID")
+    credential_id: str = Field(..., description="Credential ID", min_length=1)
+    datasource_parameters: dict[str, Any] | None = Field(default=None, description="Datasource parameters JSON string")
+
+
+class DataSourceNotionPreviewQuery(BaseModel):
+    credential_id: str = Field(..., description="Credential ID", min_length=1)
+
+
 register_schema_model(console_ns, NotionEstimatePayload)

+integrate_icon_model = get_or_create_model("DataSourceIntegrateIcon", integrate_icon_fields)
+
+integrate_page_fields_copy = integrate_page_fields.copy()
+integrate_page_fields_copy["page_icon"] = fields.Nested(integrate_icon_model, allow_null=True)
+integrate_page_model = get_or_create_model("DataSourceIntegratePage", integrate_page_fields_copy)
+
+integrate_workspace_fields_copy = integrate_workspace_fields.copy()
+integrate_workspace_fields_copy["pages"] = fields.List(fields.Nested(integrate_page_model))
+integrate_workspace_model = get_or_create_model("DataSourceIntegrateWorkspace", integrate_workspace_fields_copy)
+
+integrate_fields_copy = integrate_fields.copy()
+integrate_fields_copy["source_info"] = fields.Nested(integrate_workspace_model)
+integrate_model = get_or_create_model("DataSourceIntegrate", integrate_fields_copy)
+
+integrate_list_fields_copy = integrate_list_fields.copy()
+integrate_list_fields_copy["data"] = fields.List(fields.Nested(integrate_model))
+integrate_list_model = get_or_create_model("DataSourceIntegrateList", integrate_list_fields_copy)
+
+notion_page_fields = {
+    "page_name": fields.String,
+    "page_id": fields.String,
+    "page_icon": fields.Nested(integrate_icon_model, allow_null=True),
+    "is_bound": fields.Boolean,
+    "parent_id": fields.String,
+    "type": fields.String,
+}
+notion_page_model = get_or_create_model("NotionIntegratePage", notion_page_fields)
+
+notion_workspace_fields = {
+    "workspace_name": fields.String,
+    "workspace_id": fields.String,
+    "workspace_icon": fields.String,
+    "pages": fields.List(fields.Nested(notion_page_model)),
+}
+notion_workspace_model = get_or_create_model("NotionIntegrateWorkspace", notion_workspace_fields)
+
+integrate_notion_info_list_fields_copy = integrate_notion_info_list_fields.copy()
+integrate_notion_info_list_fields_copy["notion_info"] = fields.List(fields.Nested(notion_workspace_model))
+integrate_notion_info_list_model = get_or_create_model(
+    "NotionIntegrateInfoList", integrate_notion_info_list_fields_copy
+)
+
+
 @console_ns.route(
     "/data-source/integrates",
     "/data-source/integrates/<uuid:binding_id>/<string:action>",
 )
@@ -47,7 +107,7 @@ class DataSourceApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(integrate_list_fields)
+    @marshal_with(integrate_list_model)
     def get(self):
         _, current_tenant_id = current_account_with_tenant()
@@ -132,30 +192,19 @@ class DataSourceNotionListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(integrate_notion_info_list_fields)
+    @marshal_with(integrate_notion_info_list_model)
     def get(self):
         current_user, current_tenant_id = current_account_with_tenant()
-        dataset_id = request.args.get("dataset_id", default=None, type=str)
-        credential_id = request.args.get("credential_id", default=None, type=str)
-        if not credential_id:
-            raise ValueError("Credential id is required.")
+        query = DataSourceNotionListQuery.model_validate(request.args.to_dict())

         # Get datasource_parameters from query string (optional, for GitHub and other datasources)
-        datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str)
-        datasource_parameters = {}
-        if datasource_parameters_str:
-            try:
-                datasource_parameters = json.loads(datasource_parameters_str)
-                if not isinstance(datasource_parameters, dict):
-                    raise ValueError("datasource_parameters must be a JSON object.")
-            except
json.JSONDecodeError: - raise ValueError("Invalid datasource_parameters JSON format.") + datasource_parameters = query.datasource_parameters or {} datasource_provider_service = DatasourceProviderService() credential = datasource_provider_service.get_datasource_credentials( tenant_id=current_tenant_id, - credential_id=credential_id, + credential_id=query.credential_id, provider="notion_datasource", plugin_id="langgenius/notion_datasource", ) @@ -164,8 +213,8 @@ class DataSourceNotionListApi(Resource): exist_page_ids = [] with Session(db.engine) as session: # import notion in the exist dataset - if dataset_id: - dataset = DatasetService.get_dataset(dataset_id) + if query.dataset_id: + dataset = DatasetService.get_dataset(query.dataset_id) if not dataset: raise NotFound("Dataset not found.") if dataset.data_source_type != "notion_import": @@ -173,7 +222,7 @@ class DataSourceNotionListApi(Resource): documents = session.scalars( select(Document).filter_by( - dataset_id=dataset_id, + dataset_id=query.dataset_id, tenant_id=current_tenant_id, data_source_type="notion_import", enabled=True, @@ -240,13 +289,12 @@ class DataSourceNotionApi(Resource): def get(self, page_id, page_type): _, current_tenant_id = current_account_with_tenant() - credential_id = request.args.get("credential_id", default=None, type=str) - if not credential_id: - raise ValueError("Credential id is required.") + query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict()) + datasource_provider_service = DatasourceProviderService() credential = datasource_provider_service.get_datasource_credentials( tenant_id=current_tenant_id, - credential_id=credential_id, + credential_id=query.credential_id, provider="notion_datasource", plugin_id="langgenius/notion_datasource", ) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 8ceb896d4f..8fbbc51e21 100644 --- a/api/controllers/console/datasets/datasets.py +++ 
b/api/controllers/console/datasets/datasets.py @@ -8,7 +8,7 @@ from werkzeug.exceptions import Forbidden, NotFound import services from configs import dify_config -from controllers.common.schema import register_schema_models +from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns from controllers.console.apikey import ( api_key_item_model, @@ -34,6 +34,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from fields.app_fields import app_detail_kernel_fields, related_app_list from fields.dataset_fields import ( + content_fields, dataset_detail_fields, dataset_fields, dataset_query_detail_fields, @@ -41,6 +42,7 @@ from fields.dataset_fields import ( doc_metadata_fields, external_knowledge_info_fields, external_retrieval_model_fields, + file_info_fields, icon_info_fields, keyword_setting_fields, reranking_model_fields, @@ -55,41 +57,33 @@ from models.dataset import DatasetPermissionEnum from models.provider_ids import ModelProviderID from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService - -def _get_or_create_model(model_name: str, field_def): - existing = console_ns.models.get(model_name) - if existing is None: - existing = console_ns.model(model_name, field_def) - return existing - - # Register models for flask_restx to avoid dict type issues in Swagger -dataset_base_model = _get_or_create_model("DatasetBase", dataset_fields) +dataset_base_model = get_or_create_model("DatasetBase", dataset_fields) -tag_model = _get_or_create_model("Tag", tag_fields) +tag_model = get_or_create_model("Tag", tag_fields) -keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields) -vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields) +keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields) +vector_setting_model = 
get_or_create_model("DatasetVectorSetting", vector_setting_fields) weighted_score_fields_copy = weighted_score_fields.copy() weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model) weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model) -weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy) +weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy) -reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields) +reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields) dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy() dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model) dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True) -dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy) +dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy) -external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields) +external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields) -external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields) +external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields) -doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields) +doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields) -icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields) +icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields) dataset_detail_fields_copy = dataset_detail_fields.copy() 
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model) @@ -98,14 +92,22 @@ dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_k dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True) dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model)) dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model) -dataset_detail_model = _get_or_create_model("DatasetDetail", dataset_detail_fields_copy) +dataset_detail_model = get_or_create_model("DatasetDetail", dataset_detail_fields_copy) -dataset_query_detail_model = _get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields) +file_info_model = get_or_create_model("DatasetFileInfo", file_info_fields) -app_detail_kernel_model = _get_or_create_model("AppDetailKernel", app_detail_kernel_fields) +content_fields_copy = content_fields.copy() +content_fields_copy["file_info"] = fields.Nested(file_info_model, allow_null=True) +content_model = get_or_create_model("DatasetContent", content_fields_copy) + +dataset_query_detail_fields_copy = dataset_query_detail_fields.copy() +dataset_query_detail_fields_copy["queries"] = fields.Nested(content_model) +dataset_query_detail_model = get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields_copy) + +app_detail_kernel_model = get_or_create_model("AppDetailKernel", app_detail_kernel_fields) related_app_list_copy = related_app_list.copy() related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model)) -related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy) +related_app_list_model = get_or_create_model("RelatedAppList", related_app_list_copy) def _validate_indexing_technique(value: str | None) -> str | None: @@ -176,7 +178,18 @@ class IndexingEstimatePayload(BaseModel): return result -register_schema_models(console_ns, DatasetCreatePayload, 
DatasetUpdatePayload, IndexingEstimatePayload) +class ConsoleDatasetListQuery(BaseModel): + page: int = Field(default=1, description="Page number") + limit: int = Field(default=20, description="Number of items per page") + keyword: str | None = Field(default=None, description="Search keyword") + include_all: bool = Field(default=False, description="Include all datasets") + ids: list[str] = Field(default_factory=list, description="Filter by dataset IDs") + tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs") + + +register_schema_models( + console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload, ConsoleDatasetListQuery +) def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]: @@ -275,18 +288,19 @@ class DatasetListApi(Resource): @enterprise_license_required def get(self): current_user, current_tenant_id = current_account_with_tenant() - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) - ids = request.args.getlist("ids") + query = ConsoleDatasetListQuery.model_validate(request.args.to_dict()) # provider = request.args.get("provider", default="vendor") - search = request.args.get("keyword", default=None, type=str) - tag_ids = request.args.getlist("tag_ids") - include_all = request.args.get("include_all", default="false").lower() == "true" - if ids: - datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id) + if query.ids: + datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id) else: datasets, total = DatasetService.get_datasets( - page, limit, current_tenant_id, current_user, search, tag_ids, include_all + query.page, + query.limit, + current_tenant_id, + current_user, + query.keyword, + query.tag_ids, + query.include_all, ) # check embedding setting @@ -318,7 +332,13 @@ class DatasetListApi(Resource): else: item.update({"partial_member_list": []}) 
- response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} + response = { + "data": data, + "has_more": len(datasets) == query.limit, + "limit": query.limit, + "total": total, + "page": query.page, + } return response, 200 @console_ns.doc("create_dataset") diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index e94768f985..57fb9abf29 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -2,17 +2,19 @@ import json import logging from argparse import ArgumentTypeError from collections.abc import Sequence -from typing import Literal, cast +from contextlib import ExitStack +from typing import Any, Literal, cast +from uuid import UUID import sqlalchemy as sa -from flask import request +from flask import request, send_file from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel +from pydantic import BaseModel, Field from sqlalchemy import asc, desc, select from werkzeug.exceptions import Forbidden, NotFound import services -from controllers.common.schema import register_schema_models +from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns from core.errors.error import ( LLMBadRequestError, @@ -42,6 +44,7 @@ from models import DatasetProcessRule, Document, DocumentSegment, UploadFile from models.dataset import DocumentPipelineExecutionLog from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel +from services.file_service import FileService from ..app.error import ( ProviderModelCurrentlyNotSupportError, @@ -65,35 +68,31 @@ from ..wraps import ( logger = logging.getLogger(__name__) - -def _get_or_create_model(model_name: str, field_def): - 
existing = console_ns.models.get(model_name) - if existing is None: - existing = console_ns.model(model_name, field_def) - return existing +# NOTE: Keep constants near the top of the module for discoverability. +DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100 # Register models for flask_restx to avoid dict type issues in Swagger -dataset_model = _get_or_create_model("Dataset", dataset_fields) +dataset_model = get_or_create_model("Dataset", dataset_fields) -document_metadata_model = _get_or_create_model("DocumentMetadata", document_metadata_fields) +document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields) document_fields_copy = document_fields.copy() document_fields_copy["doc_metadata"] = fields.List( fields.Nested(document_metadata_model), attribute="doc_metadata_details" ) -document_model = _get_or_create_model("Document", document_fields_copy) +document_model = get_or_create_model("Document", document_fields_copy) document_with_segments_fields_copy = document_with_segments_fields.copy() document_with_segments_fields_copy["doc_metadata"] = fields.List( fields.Nested(document_metadata_model), attribute="doc_metadata_details" ) -document_with_segments_model = _get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy) +document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy) dataset_and_document_fields_copy = dataset_and_document_fields.copy() dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model) dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model)) -dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy) +dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy) class DocumentRetryPayload(BaseModel): @@ -104,6 +103,21 @@ class DocumentRenamePayload(BaseModel): name: str +class 
DocumentBatchDownloadZipPayload(BaseModel): + """Request payload for bulk downloading documents as a zip archive.""" + + document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS) + + +class DocumentDatasetListParam(BaseModel): + page: int = Field(1, title="Page", description="Page number.") + limit: int = Field(20, title="Limit", description="Page size.") + search: str | None = Field(None, alias="keyword", title="Search", description="Search keyword.") + sort_by: str = Field("-created_at", alias="sort", title="SortBy", description="Sort by field.") + status: str | None = Field(None, title="Status", description="Document status.") + fetch_val: str = Field("false", alias="fetch") + + register_schema_models( console_ns, KnowledgeConfig, @@ -111,6 +125,7 @@ register_schema_models( RetrievalModel, DocumentRetryPayload, DocumentRenamePayload, + DocumentBatchDownloadZipPayload, ) @@ -225,14 +240,16 @@ class DatasetDocumentListApi(Resource): def get(self, dataset_id): current_user, current_tenant_id = current_account_with_tenant() dataset_id = str(dataset_id) - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) - search = request.args.get("keyword", default=None, type=str) - sort = request.args.get("sort", default="-created_at", type=str) - status = request.args.get("status", default=None, type=str) + raw_args = request.args.to_dict() + param = DocumentDatasetListParam.model_validate(raw_args) + page = param.page + limit = param.limit + search = param.search + sort = param.sort_by + status = param.status # "yes", "true", "t", "y", "1" convert to True, while others convert to False. 
try: - fetch_val = request.args.get("fetch", default="false") + fetch_val = param.fetch_val if isinstance(fetch_val, bool): fetch = fetch_val else: @@ -751,12 +768,12 @@ class DocumentApi(DocumentResource): elif metadata == "without": dataset_process_rules = DatasetService.get_process_rules(dataset_id) document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} - data_source_info = document.data_source_detail_dict response = { "id": document.id, "position": document.position, "data_source_type": document.data_source_type, - "data_source_info": data_source_info, + "data_source_info": document.data_source_info_dict, + "data_source_detail_dict": document.data_source_detail_dict, "dataset_process_rule_id": document.dataset_process_rule_id, "dataset_process_rule": dataset_process_rules, "document_process_rule": document_process_rules, @@ -784,12 +801,12 @@ class DocumentApi(DocumentResource): else: dataset_process_rules = DatasetService.get_process_rules(dataset_id) document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} - data_source_info = document.data_source_detail_dict response = { "id": document.id, "position": document.position, "data_source_type": document.data_source_type, - "data_source_info": data_source_info, + "data_source_info": document.data_source_info_dict, + "data_source_detail_dict": document.data_source_detail_dict, "dataset_process_rule_id": document.dataset_process_rule_id, "dataset_process_rule": dataset_process_rules, "document_process_rule": document_process_rules, @@ -842,6 +859,62 @@ class DocumentApi(DocumentResource): return {"result": "success"}, 204 +@console_ns.route("/datasets//documents//download") +class DocumentDownloadApi(DocumentResource): + """Return a signed download URL for a dataset document's original uploaded file.""" + + @console_ns.doc("get_dataset_document_download_url") + @console_ns.doc(description="Get a signed download URL 
for a dataset document's original uploaded file") + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") + def get(self, dataset_id: str, document_id: str) -> dict[str, Any]: + # Reuse the shared permission/tenant checks implemented in DocumentResource. + document = self.get_document(str(dataset_id), str(document_id)) + return {"url": DocumentService.get_document_download_url(document)} + + +@console_ns.route("/datasets//documents/download-zip") +class DocumentBatchDownloadZipApi(DocumentResource): + """Download multiple uploaded-file documents as a single ZIP (avoids browser multi-download limits).""" + + @console_ns.doc("download_dataset_documents_as_zip") + @console_ns.doc(description="Download selected dataset documents as a single ZIP archive (upload-file only)") + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") + @console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__]) + def post(self, dataset_id: str): + """Stream a ZIP archive containing the requested uploaded documents.""" + # Parse and validate request payload. + payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {}) + + current_user, current_tenant_id = current_account_with_tenant() + dataset_id = str(dataset_id) + document_ids: list[str] = [str(document_id) for document_id in payload.document_ids] + upload_files, download_name = DocumentService.prepare_document_batch_download_zip( + dataset_id=dataset_id, + document_ids=document_ids, + tenant_id=current_tenant_id, + current_user=current_user, + ) + + # Delegate ZIP packing to FileService, but keep Flask response+cleanup in the route. 
+ with ExitStack() as stack: + zip_path = stack.enter_context(FileService.build_upload_files_zip_tempfile(upload_files=upload_files)) + response = send_file( + zip_path, + mimetype="application/zip", + as_attachment=True, + download_name=download_name, + ) + cleanup = stack.pop_all() + response.call_on_close(cleanup.close) + return response + + @console_ns.route("/datasets//documents//processing/") class DocumentProcessingApi(DocumentResource): @console_ns.doc("update_document_processing") @@ -1098,7 +1171,7 @@ class DocumentRenameApi(DocumentResource): @setup_required @login_required @account_initialization_required - @marshal_with(document_fields) + @marshal_with(document_model) @console_ns.expect(console_ns.models[DocumentRenamePayload.__name__]) def post(self, dataset_id, document_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index e73abc2555..08e1ddd3e0 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -3,10 +3,12 @@ import uuid from flask import request from flask_restx import Resource, marshal from pydantic import BaseModel, Field -from sqlalchemy import select +from sqlalchemy import String, cast, func, or_, select +from sqlalchemy.dialects.postgresql import JSONB from werkzeug.exceptions import Forbidden, NotFound import services +from configs import dify_config from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ProviderNotInitializeError @@ -28,6 +30,7 @@ from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields +from libs.helper import 
escape_like_pattern from libs.login import current_account_with_tenant, login_required from models.dataset import ChildChunk, DocumentSegment from models.model import UploadFile @@ -87,6 +90,7 @@ register_schema_models( ChildChunkCreatePayload, ChildChunkUpdatePayload, ChildChunkBatchUpdatePayload, + ChildChunkUpdateArgs, ) @@ -143,7 +147,31 @@ class DatasetDocumentSegmentListApi(Resource): query = query.where(DocumentSegment.hit_count >= hit_count_gte) if keyword: - query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) + # Escape special characters in keyword to prevent SQL injection via LIKE wildcards + escaped_keyword = escape_like_pattern(keyword) + # Search in both content and keywords fields + # Use database-specific methods for JSON array search + if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql": + # PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text + keywords_condition = func.array_to_string( + func.array( + select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB))) + .correlate(DocumentSegment) + .scalar_subquery() + ), + ",", + ).ilike(f"%{escaped_keyword}%", escape="\\") + else: + # MySQL: Cast JSON to string for pattern matching + # MySQL stores Chinese text directly in JSON without Unicode escaping + keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{escaped_keyword}%", escape="\\") + + query = query.where( + or_( + DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"), + keywords_condition, + ) + ) if args.enabled.lower() != "all": if args.enabled.lower() == "true": diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 89c9fcad36..86090bcd10 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services 
-from controllers.common.schema import register_schema_models +from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required @@ -28,34 +28,27 @@ from services.hit_testing_service import HitTestingService from services.knowledge_service import ExternalDatasetTestService -def _get_or_create_model(model_name: str, field_def): - existing = console_ns.models.get(model_name) - if existing is None: - existing = console_ns.model(model_name, field_def) - return existing - - def _build_dataset_detail_model(): - keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields) - vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields) + keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields) + vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields) weighted_score_fields_copy = weighted_score_fields.copy() weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model) weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model) - weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy) + weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy) - reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields) + reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields) dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy() dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model) dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True) - 
dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy) + dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy) - tag_model = _get_or_create_model("Tag", tag_fields) - doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields) - external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields) - external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields) - icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields) + tag_model = get_or_create_model("Tag", tag_fields) + doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields) + external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields) + external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields) + icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields) dataset_detail_fields_copy = dataset_detail_fields.copy() dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model) @@ -64,7 +57,7 @@ def _build_dataset_detail_model(): dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True) dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model)) dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model) - return _get_or_create_model("DatasetDetail", dataset_detail_fields_copy) + return get_or_create_model("DatasetDetail", dataset_detail_fields_copy) try: @@ -81,7 +74,7 @@ class ExternalKnowledgeApiPayload(BaseModel): class ExternalDatasetCreatePayload(BaseModel): external_knowledge_api_id: str external_knowledge_id: str - name: str = Field(..., min_length=1, max_length=40) + name: str = Field(..., min_length=1, 
max_length=100) description: str | None = Field(None, max_length=400) external_retrieval_model: dict[str, object] | None = None @@ -98,12 +91,19 @@ class BedrockRetrievalPayload(BaseModel): knowledge_id: str +class ExternalApiTemplateListQuery(BaseModel): + page: int = Field(default=1, description="Page number") + limit: int = Field(default=20, description="Number of items per page") + keyword: str | None = Field(default=None, description="Search keyword") + + register_schema_models( console_ns, ExternalKnowledgeApiPayload, ExternalDatasetCreatePayload, ExternalHitTestingPayload, BedrockRetrievalPayload, + ExternalApiTemplateListQuery, ) @@ -124,19 +124,17 @@ class ExternalApiTemplateListApi(Resource): @account_initialization_required def get(self): _, current_tenant_id = current_account_with_tenant() - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) - search = request.args.get("keyword", default=None, type=str) + query = ExternalApiTemplateListQuery.model_validate(request.args.to_dict()) external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis( - page, limit, current_tenant_id, search + query.page, query.limit, current_tenant_id, query.keyword ) response = { "data": [item.to_dict() for item in external_knowledge_apis], - "has_more": len(external_knowledge_apis) == limit, - "limit": limit, + "has_more": len(external_knowledge_apis) == query.limit, + "limit": query.limit, "total": total, - "page": page, + "page": query.page, } return response, 200 diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index db7c50f422..db1a874437 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,7 @@ import logging from typing import Any -from flask_restx import marshal, reqparse +from flask_restx import marshal from pydantic import BaseModel, 
Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -56,15 +56,10 @@ class DatasetsHitTestingBase: HitTestingService.hit_testing_args_check(args) @staticmethod - def parse_args(): - parser = ( - reqparse.RequestParser() - .add_argument("query", type=str, required=False, location="json") - .add_argument("attachment_ids", type=list, required=False, location="json") - .add_argument("retrieval_model", type=dict, required=False, location="json") - .add_argument("external_retrieval_model", type=dict, required=False, location="json") - ) - return parser.parse_args() + def parse_args(payload: dict[str, Any]) -> dict[str, Any]: + """Validate and return hit-testing arguments from an incoming payload.""" + hit_testing_payload = HitTestingPayload.model_validate(payload or {}) + return hit_testing_payload.model_dump(exclude_none=True) @staticmethod def perform_hit_testing(dataset, args): diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 8eead1696a..05fc4cd714 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -4,14 +4,16 @@ from flask_restx import Resource, marshal_with from pydantic import BaseModel from werkzeug.exceptions import NotFound -from controllers.common.schema import register_schema_model, register_schema_models +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from fields.dataset_fields import dataset_metadata_fields from libs.login import current_account_with_tenant, login_required from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, MetadataArgs, + MetadataDetail, MetadataOperationData, ) from services.metadata_service import MetadataService @@ -21,8 
+23,9 @@ class MetadataUpdatePayload(BaseModel): name: str -register_schema_models(console_ns, MetadataArgs, MetadataOperationData) -register_schema_model(console_ns, MetadataUpdatePayload) +register_schema_models( + console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail +) @console_ns.route("/datasets//metadata") diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index 720e2ce365..2911b1cf18 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -2,7 +2,7 @@ import logging from typing import Any, NoReturn from flask import Response, request -from flask_restx import Resource, fields, marshal, marshal_with +from flask_restx import Resource, marshal, marshal_with from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -14,7 +14,9 @@ from controllers.console.app.error import ( ) from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage] - _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage] + workflow_draft_variable_list_model, + workflow_draft_variable_list_without_value_model, + workflow_draft_variable_model, ) from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required @@ -27,7 +29,6 @@ from factories.variable_factory import build_segment_with_type from libs.login import current_user, login_required from models import Account from models.dataset import Pipeline -from models.workflow import WorkflowDraftVariable from services.rag_pipeline.rag_pipeline import RagPipelineService from services.workflow_draft_variable_service import 
WorkflowDraftVariableList, WorkflowDraftVariableService @@ -52,20 +53,6 @@ class WorkflowDraftVariablePatchPayload(BaseModel): register_schema_models(console_ns, WorkflowDraftVariablePatchPayload) -def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]: - return var_list.variables - - -_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = { - "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items), - "total": fields.Raw(), -} - -_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = { - "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items), -} - - def _api_prerequisite(f): """Common prerequisites for all draft workflow variable APIs. @@ -92,7 +79,7 @@ def _api_prerequisite(f): @console_ns.route("/rag/pipelines//workflows/draft/variables") class RagPipelineVariableCollectionApi(Resource): @_api_prerequisite - @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) + @marshal_with(workflow_draft_variable_list_without_value_model) def get(self, pipeline: Pipeline): """ Get draft workflow @@ -150,7 +137,7 @@ def validate_node_id(node_id: str) -> NoReturn | None: @console_ns.route("/rag/pipelines//workflows/draft/nodes//variables") class RagPipelineNodeVariableCollectionApi(Resource): @_api_prerequisite - @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + @marshal_with(workflow_draft_variable_list_model) def get(self, pipeline: Pipeline, node_id: str): validate_node_id(node_id) with Session(bind=db.engine, expire_on_commit=False) as session: @@ -176,7 +163,7 @@ class RagPipelineVariableApi(Resource): _PATCH_VALUE_FIELD = "value" @_api_prerequisite - @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) + @marshal_with(workflow_draft_variable_model) def get(self, pipeline: Pipeline, variable_id: str): draft_var_srv = WorkflowDraftVariableService( session=db.session(), @@ -189,7 +176,7 @@ class RagPipelineVariableApi(Resource): return variable @_api_prerequisite - 
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) + @marshal_with(workflow_draft_variable_model) @console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__]) def patch(self, pipeline: Pipeline, variable_id: str): # Request payload for file types: @@ -307,7 +294,7 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList @console_ns.route("/rag/pipelines//workflows/draft/system-variables") class RagPipelineSystemVariableCollectionApi(Resource): @_api_prerequisite - @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + @marshal_with(workflow_draft_variable_list_model) def get(self, pipeline: Pipeline): return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py index d43ee9a6e0..af142b4646 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -1,9 +1,9 @@ from flask import request -from flask_restx import Resource, marshal_with # type: ignore +from flask_restx import Resource, fields, marshal_with # type: ignore from pydantic import BaseModel, Field from sqlalchemy.orm import Session -from controllers.common.schema import register_schema_models +from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import ( @@ -12,7 +12,11 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields +from fields.rag_pipeline_fields import ( + leaked_dependency_fields, + pipeline_import_check_dependencies_fields, + pipeline_import_fields, +) from libs.login import 
current_account_with_tenant, login_required from models.dataset import Pipeline from services.app_dsl_service import ImportStatus @@ -38,13 +42,25 @@ class IncludeSecretQuery(BaseModel): register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery) +pipeline_import_model = get_or_create_model("RagPipelineImport", pipeline_import_fields) + +leaked_dependency_model = get_or_create_model("RagPipelineLeakedDependency", leaked_dependency_fields) +pipeline_import_check_dependencies_fields_copy = pipeline_import_check_dependencies_fields.copy() +pipeline_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List( + fields.Nested(leaked_dependency_model) +) +pipeline_import_check_dependencies_model = get_or_create_model( + "RagPipelineImportCheckDependencies", pipeline_import_check_dependencies_fields_copy +) + + @console_ns.route("/rag/pipelines/imports") class RagPipelineImportApi(Resource): @setup_required @login_required @account_initialization_required @edit_permission_required - @marshal_with(pipeline_import_fields) + @marshal_with(pipeline_import_model) @console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__]) def post(self): # Check user role first @@ -81,7 +97,7 @@ class RagPipelineImportConfirmApi(Resource): @login_required @account_initialization_required @edit_permission_required - @marshal_with(pipeline_import_fields) + @marshal_with(pipeline_import_model) def post(self, import_id): current_user, _ = current_account_with_tenant() @@ -106,7 +122,7 @@ class RagPipelineImportCheckDependenciesApi(Resource): @get_rag_pipeline @account_initialization_required @edit_permission_required - @marshal_with(pipeline_import_check_dependencies_fields) + @marshal_with(pipeline_import_check_dependencies_model) def get(self, pipeline: Pipeline): with Session(db.engine) as session: import_service = RagPipelineDslService(session) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py 
b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 46d67f0581..d34fd5088d 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -17,6 +17,13 @@ from controllers.console.app.error import ( DraftWorkflowNotExist, DraftWorkflowNotSync, ) +from controllers.console.app.workflow import workflow_model, workflow_pagination_model +from controllers.console.app.workflow_run import ( + workflow_run_detail_model, + workflow_run_node_execution_list_model, + workflow_run_node_execution_model, + workflow_run_pagination_model, +) from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import ( account_initialization_required, @@ -30,13 +37,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from factories import variable_factory -from fields.workflow_fields import workflow_fields, workflow_pagination_fields -from fields.workflow_run_fields import ( - workflow_run_detail_fields, - workflow_run_node_execution_fields, - workflow_run_node_execution_list_fields, - workflow_run_pagination_fields, -) from libs import helper from libs.helper import TimestampField from libs.login import current_account_with_tenant, current_user, login_required @@ -145,7 +145,7 @@ class DraftRagPipelineApi(Resource): @account_initialization_required @get_rag_pipeline @edit_permission_required - @marshal_with(workflow_fields) + @marshal_with(workflow_model) def get(self, pipeline: Pipeline): """ Get draft rag pipeline's workflow @@ -355,7 +355,7 @@ class PublishedRagPipelineRunApi(Resource): pipeline=pipeline, user=current_user, args=args, - invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED_PIPELINE, 
streaming=streaming, ) @@ -521,7 +521,7 @@ class RagPipelineDraftNodeRunApi(Resource): @edit_permission_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_node_execution_fields) + @marshal_with(workflow_run_node_execution_model) def post(self, pipeline: Pipeline, node_id: str): """ Run draft workflow node @@ -569,7 +569,7 @@ class PublishedRagPipelineApi(Resource): @account_initialization_required @edit_permission_required @get_rag_pipeline - @marshal_with(workflow_fields) + @marshal_with(workflow_model) def get(self, pipeline: Pipeline): """ Get published pipeline @@ -664,7 +664,7 @@ class PublishedAllRagPipelineApi(Resource): @account_initialization_required @edit_permission_required @get_rag_pipeline - @marshal_with(workflow_pagination_fields) + @marshal_with(workflow_pagination_model) def get(self, pipeline: Pipeline): """ Get published workflows @@ -708,7 +708,7 @@ class RagPipelineByIdApi(Resource): @account_initialization_required @edit_permission_required @get_rag_pipeline - @marshal_with(workflow_fields) + @marshal_with(workflow_model) def patch(self, pipeline: Pipeline, workflow_id: str): """ Update workflow attributes @@ -830,7 +830,7 @@ class RagPipelineWorkflowRunListApi(Resource): @login_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_pagination_fields) + @marshal_with(workflow_run_pagination_model) def get(self, pipeline: Pipeline): """ Get workflow run list @@ -858,7 +858,7 @@ class RagPipelineWorkflowRunDetailApi(Resource): @login_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_detail_fields) + @marshal_with(workflow_run_detail_model) def get(self, pipeline: Pipeline, run_id): """ Get workflow run detail @@ -877,7 +877,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource): @login_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_node_execution_list_fields) + 
@marshal_with(workflow_run_node_execution_list_model) def get(self, pipeline: Pipeline, run_id: str): """ Get workflow run node execution list @@ -911,7 +911,7 @@ class RagPipelineWorkflowLastRunApi(Resource): @login_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_node_execution_fields) + @marshal_with(workflow_run_node_execution_model) def get(self, pipeline: Pipeline, node_id: str): rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) @@ -952,7 +952,7 @@ class RagPipelineDatasourceVariableApi(Resource): @account_initialization_required @get_rag_pipeline @edit_permission_required - @marshal_with(workflow_run_node_execution_fields) + @marshal_with(workflow_run_node_execution_model) def post(self, pipeline: Pipeline): """ Set datasource variables diff --git a/api/controllers/console/explore/banner.py b/api/controllers/console/explore/banner.py new file mode 100644 index 0000000000..da306fbc9d --- /dev/null +++ b/api/controllers/console/explore/banner.py @@ -0,0 +1,43 @@ +from flask import request +from flask_restx import Resource + +from controllers.console import api +from controllers.console.explore.wraps import explore_banner_enabled +from extensions.ext_database import db +from models.model import ExporleBanner + + +class BannerApi(Resource): + """Resource for banner list.""" + + @explore_banner_enabled + def get(self): + """Get banner list.""" + language = request.args.get("language", "en-US") + + # Build base query for enabled banners + base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled") + + # Try to get banners in the requested language + banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all() + + # Fallback to en-US if no banners found and language is not en-US + if not banners and language != "en-US": + banners = base_query.where(ExporleBanner.language == 
"en-US").order_by(ExporleBanner.sort).all() + # Convert banners to serializable format + result = [] + for banner in banners: + banner_data = { + "id": banner.id, + "content": banner.content, # Already parsed as JSON by SQLAlchemy + "link": banner.link, + "sort": banner.sort, + "status": banner.status, + "created_at": banner.created_at.isoformat() if banner.created_at else None, + } + result.append(banner_data) + + return result + + +api.add_resource(BannerApi, "/explore/banners") diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 51995b8b8a..933c80f509 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,8 +1,7 @@ from typing import Any from flask import request -from flask_restx import marshal_with -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, TypeAdapter, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -11,7 +10,11 @@ from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db -from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields +from fields.conversation_fields import ( + ConversationInfiniteScrollPagination, + ResultResponse, + SimpleConversation, +) from libs.helper import UUIDStrOrEmpty from libs.login import current_user from models import Account @@ -49,7 +52,6 @@ register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayl endpoint="installed_app_conversations", ) class ConversationListApi(InstalledAppResource): - @marshal_with(conversation_infinite_scroll_pagination_fields) @console_ns.expect(console_ns.models[ConversationListQuery.__name__]) def get(self, installed_app): app_model 
= installed_app.app @@ -73,7 +75,7 @@ class ConversationListApi(InstalledAppResource): if not isinstance(current_user, Account): raise ValueError("current_user must be an Account instance") with Session(db.engine) as session: - return WebConversationService.pagination_by_last_id( + pagination = WebConversationService.pagination_by_last_id( session=session, app_model=app_model, user=current_user, @@ -82,6 +84,13 @@ class ConversationListApi(InstalledAppResource): invoke_from=InvokeFrom.EXPLORE, pinned=args.pinned, ) + adapter = TypeAdapter(SimpleConversation) + conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data] + return ConversationInfiniteScrollPagination( + limit=pagination.limit, + has_more=pagination.has_more, + data=conversations, + ).model_dump(mode="json") except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -105,7 +114,7 @@ class ConversationApi(InstalledAppResource): except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"}, 204 + return ResultResponse(result="success").model_dump(mode="json"), 204 @console_ns.route( @@ -113,7 +122,6 @@ class ConversationApi(InstalledAppResource): endpoint="installed_app_conversation_rename", ) class ConversationRenameApi(InstalledAppResource): - @marshal_with(simple_conversation_fields) @console_ns.expect(console_ns.models[ConversationRenamePayload.__name__]) def post(self, installed_app, c_id): app_model = installed_app.app @@ -128,9 +136,14 @@ class ConversationRenameApi(InstalledAppResource): try: if not isinstance(current_user, Account): raise ValueError("current_user must be an Account instance") - return ConversationService.rename( + conversation = ConversationService.rename( app_model, conversation_id, current_user, payload.name, payload.auto_generate ) + return ( + TypeAdapter(SimpleConversation) + .validate_python(conversation, from_attributes=True) + .model_dump(mode="json") + ) 
except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -155,7 +168,7 @@ class ConversationPinApi(InstalledAppResource): except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"} + return ResultResponse(result="success").model_dump(mode="json") @console_ns.route( @@ -174,4 +187,4 @@ class ConversationUnPinApi(InstalledAppResource): raise ValueError("current_user must be an Account instance") WebConversationService.unpin(app_model, conversation_id, current_user) - return {"result": "success"} + return ResultResponse(result="success").model_dump(mode="json") diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py index 1e05ff4206..e96fa64f84 100644 --- a/api/controllers/console/explore/error.py +++ b/api/controllers/console/explore/error.py @@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException): error_code = "access_denied" description = "App access denied." code = 403 + + +class TrialAppNotAllowed(BaseHTTPException): + """*403* `Trial App Not Allowed` + + Raise if the app is not allowed for trial. + """ + + error_code = "trial_app_not_allowed" + code = 403 + description = "The app is not allowed for trial." + + +class TrialAppLimitExceeded(BaseHTTPException): + """*403* `Trial App Limit Exceeded` + + Raise if the user has exceeded the trial app limit. + """ + + error_code = "trial_app_limit_exceeded" + code = 403 + description = "The user has exceeded the trial app limit."
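The two 403 errors added to `explore/error.py` follow the controllers' existing convention of pairing a machine-readable `error_code` with an HTTP status `code` and a human-readable `description`. A minimal, self-contained sketch of how such an error typically surfaces — using a stand-in `BaseHTTPException` and a hypothetical `check_trial_quota` guard, since neither the real base class nor the quota logic appears in this patch:

```python
class BaseHTTPException(Exception):
    # Stand-in for the project's BaseHTTPException; the real class lives
    # in the web error module and integrates with the API error handler.
    error_code: str = "unknown"
    code: int = 500
    description: str = "Unknown error."

    def to_response(self) -> tuple[dict, int]:
        # Shape mirrors the JSON body an error handler would emit.
        return {"code": self.error_code, "message": self.description}, self.code


class TrialAppLimitExceeded(BaseHTTPException):
    error_code = "trial_app_limit_exceeded"
    code = 403
    description = "The user has exceeded the trial app limit."


def check_trial_quota(used: int, limit: int) -> None:
    # Hypothetical guard: raise the 403 error once the trial quota is spent.
    if used >= limit:
        raise TrialAppLimitExceeded()


try:
    check_trial_quota(used=3, limit=3)
except BaseHTTPException as exc:
    body, status = exc.to_response()
```

Subclasses only override the three class attributes, so adding a new error (as this patch does) never requires touching the error-handling path.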
diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index e42db10ba6..aca766567f 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -2,16 +2,17 @@ import logging from typing import Any from flask import request -from flask_restx import Resource, marshal_with -from pydantic import BaseModel +from flask_restx import Resource, fields, marshal_with +from pydantic import BaseModel, Field from sqlalchemy import and_, select from werkzeug.exceptions import BadRequest, Forbidden, NotFound +from controllers.common.schema import get_or_create_model from controllers.console import console_ns from controllers.console.explore.wraps import InstalledAppResource from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_database import db -from fields.installed_app_fields import installed_app_list_fields +from fields.installed_app_fields import app_fields, installed_app_fields, installed_app_list_fields from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import App, InstalledApp, RecommendedApp @@ -28,22 +29,37 @@ class InstalledAppUpdatePayload(BaseModel): is_pinned: bool | None = None +class InstalledAppsListQuery(BaseModel): + app_id: str | None = Field(default=None, description="App ID to filter by") + + logger = logging.getLogger(__name__) +app_model = get_or_create_model("InstalledAppInfo", app_fields) + +installed_app_fields_copy = installed_app_fields.copy() +installed_app_fields_copy["app"] = fields.Nested(app_model) +installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy) + +installed_app_list_fields_copy = installed_app_list_fields.copy() +installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model)) +installed_app_list_model = 
get_or_create_model("InstalledAppList", installed_app_list_fields_copy) + + @console_ns.route("/installed-apps") class InstalledAppsListApi(Resource): @login_required @account_initialization_required - @marshal_with(installed_app_list_fields) + @marshal_with(installed_app_list_model) def get(self): - app_id = request.args.get("app_id", default=None, type=str) + query = InstalledAppsListQuery.model_validate(request.args.to_dict()) current_user, current_tenant_id = current_account_with_tenant() - if app_id: + if query.app_id: installed_apps = db.session.scalars( select(InstalledApp).where( - and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id) + and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == query.app_id) ) ).all() else: diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 229b7c8865..88487ac96f 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,10 +1,8 @@ import logging from typing import Literal -from uuid import UUID from flask import request -from flask_restx import marshal_with -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import InternalServerError, NotFound from controllers.common.schema import register_schema_models @@ -24,8 +22,10 @@ from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from fields.message_fields import message_infinite_scroll_pagination_fields +from fields.conversation_fields import ResultResponse +from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse from libs import helper +from libs.helper import UUIDStrOrEmpty 
from libs.login import current_account_with_tenant from models.model import AppMode from services.app_generate_service import AppGenerateService @@ -44,8 +44,8 @@ logger = logging.getLogger(__name__) class MessageListQuery(BaseModel): - conversation_id: UUID - first_id: UUID | None = None + conversation_id: UUIDStrOrEmpty + first_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100) @@ -66,7 +66,6 @@ register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, Mor endpoint="installed_app_messages", ) class MessageListApi(InstalledAppResource): - @marshal_with(message_infinite_scroll_pagination_fields) @console_ns.expect(console_ns.models[MessageListQuery.__name__]) def get(self, installed_app): current_user, _ = current_account_with_tenant() @@ -78,13 +77,20 @@ class MessageListApi(InstalledAppResource): args = MessageListQuery.model_validate(request.args.to_dict()) try: - return MessageService.pagination_by_first_id( + pagination = MessageService.pagination_by_first_id( app_model, current_user, str(args.conversation_id), str(args.first_id) if args.first_id else None, args.limit, ) + adapter = TypeAdapter(MessageListItem) + items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data] + return MessageInfiniteScrollPagination( + limit=pagination.limit, + has_more=pagination.has_more, + data=items, + ).model_dump(mode="json") except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except FirstMessageNotExistsError: @@ -116,7 +122,7 @@ class MessageFeedbackApi(InstalledAppResource): except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {"result": "success"} + return ResultResponse(result="success").model_dump(mode="json") @console_ns.route( @@ -201,4 +207,4 @@ class MessageSuggestedQuestionApi(InstalledAppResource): logger.exception("internal server error.") raise InternalServerError() - return {"data": questions} + return 
SuggestedQuestionsResponse(data=questions).model_dump(mode="json") diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 9c6b2aedfb..660a4d5aea 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,5 +1,3 @@ -from flask_restx import marshal_with - from controllers.common import fields from controllers.console import console_ns from controllers.console.app.error import AppUnavailableError @@ -13,7 +11,6 @@ from services.app_service import AppService class AppParameterApi(InstalledAppResource): """Resource for app variables.""" - @marshal_with(fields.parameters_fields) def get(self, installed_app: InstalledApp): """Retrieve app parameters.""" app_model = installed_app.app @@ -37,7 +34,8 @@ class AppParameterApi(InstalledAppResource): user_input_form = features_dict.get("user_input_form", []) - return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + return fields.Parameters.model_validate(parameters).model_dump(mode="json") @console_ns.route("/installed-apps//meta", endpoint="installed_app_meta") diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 2b2f807694..c9920c97cf 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -3,6 +3,7 @@ from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field from constants.languages import languages +from controllers.common.schema import get_or_create_model from controllers.console import console_ns from controllers.console.wraps import account_initialization_required from libs.helper import AppIconUrlField @@ -19,8 +20,10 @@ app_fields = { "icon_background": fields.String, } +app_model = 
get_or_create_model("RecommendedAppInfo", app_fields) + recommended_app_fields = { - "app": fields.Nested(app_fields, attribute="app"), + "app": fields.Nested(app_model, attribute="app"), "app_id": fields.String, "description": fields.String(attribute="description"), "copyright": fields.String, @@ -29,13 +32,18 @@ recommended_app_fields = { "category": fields.String, "position": fields.Integer, "is_listed": fields.Boolean, + "can_trial": fields.Boolean, } +recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields) + recommended_app_list_fields = { - "recommended_apps": fields.List(fields.Nested(recommended_app_fields)), + "recommended_apps": fields.List(fields.Nested(recommended_app_model)), "categories": fields.List(fields.String), } +recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields) + class RecommendedAppsQuery(BaseModel): language: str | None = Field(default=None) @@ -52,7 +60,7 @@ class RecommendedAppListApi(Resource): @console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__]) @login_required @account_initialization_required - @marshal_with(recommended_app_list_fields) + @marshal_with(recommended_app_list_model) def get(self): # language args args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 6a9e274a0e..ea3de91741 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,55 +1,33 @@ -from uuid import UUID - from flask import request -from flask_restx import fields, marshal_with -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import NotFound from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.explore.error import 
NotCompletionAppError
 from controllers.console.explore.wraps import InstalledAppResource
-from fields.conversation_fields import message_file_fields
-from libs.helper import TimestampField
+from fields.conversation_fields import ResultResponse
+from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
+from libs.helper import UUIDStrOrEmpty
 from libs.login import current_account_with_tenant
 from services.errors.message import MessageNotExistsError
 from services.saved_message_service import SavedMessageService
 
 
 class SavedMessageListQuery(BaseModel):
-    last_id: UUID | None = None
+    last_id: UUIDStrOrEmpty | None = None
     limit: int = Field(default=20, ge=1, le=100)
 
 
 class SavedMessageCreatePayload(BaseModel):
-    message_id: UUID
+    message_id: UUIDStrOrEmpty
 
 
 register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
 
-feedback_fields = {"rating": fields.String}
-
-message_fields = {
-    "id": fields.String,
-    "inputs": fields.Raw,
-    "query": fields.String,
-    "answer": fields.String,
-    "message_files": fields.List(fields.Nested(message_file_fields)),
-    "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
-    "created_at": TimestampField,
-}
-
-
 @console_ns.route("/installed-apps//saved-messages", endpoint="installed_app_saved_messages")
 class SavedMessageListApi(InstalledAppResource):
-    saved_message_infinite_scroll_pagination_fields = {
-        "limit": fields.Integer,
-        "has_more": fields.Boolean,
-        "data": fields.List(fields.Nested(message_fields)),
-    }
-
-    @marshal_with(saved_message_infinite_scroll_pagination_fields)
     @console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
     def get(self, installed_app):
         current_user, _ = current_account_with_tenant()
@@ -59,12 +37,19 @@ class SavedMessageListApi(InstalledAppResource):
         args = SavedMessageListQuery.model_validate(request.args.to_dict())
 
-        return SavedMessageService.pagination_by_last_id(
+        pagination = SavedMessageService.pagination_by_last_id(
             app_model,
             current_user,
             str(args.last_id) if args.last_id else None,
             args.limit,
         )
+        adapter = TypeAdapter(SavedMessageItem)
+        items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
+        return SavedMessageInfiniteScrollPagination(
+            limit=pagination.limit,
+            has_more=pagination.has_more,
+            data=items,
+        ).model_dump(mode="json")
 
     @console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
     def post(self, installed_app):
@@ -80,7 +65,7 @@ class SavedMessageListApi(InstalledAppResource):
         except MessageNotExistsError:
             raise NotFound("Message Not Exists.")
 
-        return {"result": "success"}
+        return ResultResponse(result="success").model_dump(mode="json")
 
 
 @console_ns.route(
@@ -98,4 +83,4 @@ class SavedMessageApi(InstalledAppResource):
 
         SavedMessageService.delete(app_model, current_user, message_id)
 
-        return {"result": "success"}, 204
+        return ResultResponse(result="success").model_dump(mode="json"), 204
diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py
new file mode 100644
index 0000000000..1eb0cdb019
--- /dev/null
+++ b/api/controllers/console/explore/trial.py
@@ -0,0 +1,555 @@
+import logging
+from typing import Any, cast
+
+from flask import request
+from flask_restx import Resource, fields, marshal, marshal_with, reqparse
+from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
+
+import services
+from controllers.common.fields import Parameters as ParametersResponse
+from controllers.common.fields import Site as SiteResponse
+from controllers.common.schema import get_or_create_model
+from controllers.console import api, console_ns
+from controllers.console.app.error import (
+    AppUnavailableError,
+    AudioTooLargeError,
+    CompletionRequestError,
+    ConversationCompletedError,
+    NeedAddIdsError,
+    NoAudioUploadedError,
+    ProviderModelCurrentlyNotSupportError,
+    ProviderNotInitializeError,
+    ProviderNotSupportSpeechToTextError,
+    ProviderQuotaExceededError,
+    UnsupportedAudioTypeError,
+)
+from controllers.console.app.wraps import get_app_model_with_trial
+from controllers.console.explore.error import (
+    AppSuggestedQuestionsAfterAnswerDisabledError,
+    NotChatAppError,
+    NotCompletionAppError,
+    NotWorkflowAppError,
+)
+from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
+from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
+from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
+from core.app.apps.base_app_queue_manager import AppQueueManager
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.errors.error import (
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
+from core.model_runtime.errors.invoke import InvokeError
+from core.workflow.graph_engine.manager import GraphEngineManager
+from extensions.ext_database import db
+from fields.app_fields import (
+    app_detail_fields_with_site,
+    deleted_tool_fields,
+    model_config_fields,
+    site_fields,
+    tag_fields,
+)
+from fields.dataset_fields import dataset_fields
+from fields.member_fields import build_simple_account_model
+from fields.workflow_fields import (
+    conversation_variable_fields,
+    pipeline_variable_fields,
+    workflow_fields,
+    workflow_partial_fields,
+)
+from libs import helper
+from libs.helper import uuid_value
+from libs.login import current_user
+from models import Account
+from models.account import TenantStatus
+from models.model import AppMode, Site
+from models.workflow import Workflow
+from services.app_generate_service import AppGenerateService
+from services.app_service import AppService
+from services.audio_service import AudioService
+from services.dataset_service import DatasetService
+from services.errors.audio import (
+    AudioTooLargeServiceError,
+    NoAudioUploadedServiceError,
+    ProviderNotSupportSpeechToTextServiceError,
+    UnsupportedAudioTypeServiceError,
+)
+from services.errors.conversation import ConversationNotExistsError
+from services.errors.llm import InvokeRateLimitError
+from services.errors.message import (
+    MessageNotExistsError,
+    SuggestedQuestionsAfterAnswerDisabledError,
+)
+from services.message_service import MessageService
+from services.recommended_app_service import RecommendedAppService
+
+logger = logging.getLogger(__name__)
+
+
+model_config_model = get_or_create_model("TrialAppModelConfig", model_config_fields)
+workflow_partial_model = get_or_create_model("TrialWorkflowPartial", workflow_partial_fields)
+deleted_tool_model = get_or_create_model("TrialDeletedTool", deleted_tool_fields)
+tag_model = get_or_create_model("TrialTag", tag_fields)
+site_model = get_or_create_model("TrialSite", site_fields)
+
+app_detail_fields_with_site_copy = app_detail_fields_with_site.copy()
+app_detail_fields_with_site_copy["model_config"] = fields.Nested(
+    model_config_model, attribute="app_model_config", allow_null=True
+)
+app_detail_fields_with_site_copy["workflow"] = fields.Nested(workflow_partial_model, allow_null=True)
+app_detail_fields_with_site_copy["deleted_tools"] = fields.List(fields.Nested(deleted_tool_model))
+app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model))
+app_detail_fields_with_site_copy["site"] = fields.Nested(site_model)
+app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy)
+
+simple_account_model = build_simple_account_model(console_ns)
+conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields)
+pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields)
+
+workflow_fields_copy = workflow_fields.copy()
+workflow_fields_copy["created_by"] = fields.Nested(simple_account_model, attribute="created_by_account")
+workflow_fields_copy["updated_by"] = fields.Nested(
+    simple_account_model, attribute="updated_by_account", allow_null=True
+)
+workflow_fields_copy["conversation_variables"] = fields.List(fields.Nested(conversation_variable_model))
+workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipeline_variable_model))
+workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy)
+
+
+class TrialAppWorkflowRunApi(TrialAppResource):
+    def post(self, trial_app):
+        """
+        Run workflow
+        """
+        app_model = trial_app
+        if not app_model:
+            raise NotWorkflowAppError()
+        app_mode = AppMode.value_of(app_model.mode)
+        if app_mode != AppMode.WORKFLOW:
+            raise NotWorkflowAppError()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("files", type=list, required=False, location="json")
+        args = parser.parse_args()
+        assert current_user is not None
+        try:
+            app_id = app_model.id
+            user_id = current_user.id
+            response = AppGenerateService.generate(
+                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
+            )
+            RecommendedAppService.add_trial_app_record(app_id, user_id)
+            return helper.compact_generate_response(response)
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
+        except InvokeRateLimitError as ex:
+            raise InvokeRateLimitHttpError(ex.description)
+        except ValueError as e:
+            raise e
+        except Exception:
+            logger.exception("internal server error.")
+            raise InternalServerError()
+
+
+class TrialAppWorkflowTaskStopApi(TrialAppResource):
+    def post(self, trial_app, task_id: str):
+        """
+        Stop workflow task
+        """
+        app_model = trial_app
+        if not app_model:
+            raise NotWorkflowAppError()
+        app_mode = AppMode.value_of(app_model.mode)
+        if app_mode != AppMode.WORKFLOW:
+            raise NotWorkflowAppError()
+        assert current_user is not None
+
+        # Stop using both mechanisms for backward compatibility
+        # Legacy stop flag mechanism (without user check)
+        AppQueueManager.set_stop_flag_no_user_check(task_id)
+
+        # New graph engine command channel mechanism
+        GraphEngineManager.send_stop_command(task_id)
+
+        return {"result": "success"}
+
+
+class TrialChatApi(TrialAppResource):
+    @trial_feature_enable
+    def post(self, trial_app):
+        app_model = trial_app
+        app_mode = AppMode.value_of(app_model.mode)
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+            raise NotChatAppError()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("inputs", type=dict, required=True, location="json")
+        parser.add_argument("query", type=str, required=True, location="json")
+        parser.add_argument("files", type=list, required=False, location="json")
+        parser.add_argument("conversation_id", type=uuid_value, location="json")
+        parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
+        parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
+        args = parser.parse_args()
+
+        args["auto_generate_name"] = False
+
+        try:
+            if not isinstance(current_user, Account):
+                raise ValueError("current_user must be an Account instance")
+
+            # Get IDs before they might be detached from session
+            app_id = app_model.id
+            user_id = current_user.id
+
+            response = AppGenerateService.generate(
+                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
+            )
+            RecommendedAppService.add_trial_app_record(app_id, user_id)
+            return helper.compact_generate_response(response)
+        except services.errors.conversation.ConversationNotExistsError:
+            raise NotFound("Conversation Not Exists.")
+        except services.errors.conversation.ConversationCompletedError:
+            raise ConversationCompletedError()
+        except services.errors.app_model_config.AppModelConfigBrokenError:
+            logger.exception("App model config broken.")
+            raise AppUnavailableError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
+        except InvokeRateLimitError as ex:
+            raise InvokeRateLimitHttpError(ex.description)
+        except ValueError as e:
+            raise e
+        except Exception:
+            logger.exception("internal server error.")
+            raise InternalServerError()
+
+
+class TrialMessageSuggestedQuestionApi(TrialAppResource):
+    @trial_feature_enable
+    def get(self, trial_app, message_id):
+        app_model = trial_app
+        app_mode = AppMode.value_of(app_model.mode)
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+            raise NotChatAppError()
+
+        message_id = str(message_id)
+
+        try:
+            if not isinstance(current_user, Account):
+                raise ValueError("current_user must be an Account instance")
+            questions = MessageService.get_suggested_questions_after_answer(
+                app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
+            )
+        except MessageNotExistsError:
+            raise NotFound("Message not found")
+        except ConversationNotExistsError:
+            raise NotFound("Conversation not found")
+        except SuggestedQuestionsAfterAnswerDisabledError:
+            raise AppSuggestedQuestionsAfterAnswerDisabledError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
+        except Exception:
+            logger.exception("internal server error.")
+            raise InternalServerError()
+
+        return {"data": questions}
+
+
+class TrialChatAudioApi(TrialAppResource):
+    @trial_feature_enable
+    def post(self, trial_app):
+        app_model = trial_app
+
+        file = request.files["file"]
+
+        try:
+            if not isinstance(current_user, Account):
+                raise ValueError("current_user must be an Account instance")
+
+            # Get IDs before they might be detached from session
+            app_id = app_model.id
+            user_id = current_user.id
+
+            response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
+            RecommendedAppService.add_trial_app_record(app_id, user_id)
+            return response
+        except services.errors.app_model_config.AppModelConfigBrokenError:
+            logger.exception("App model config broken.")
+            raise AppUnavailableError()
+        except NoAudioUploadedServiceError:
+            raise NoAudioUploadedError()
+        except AudioTooLargeServiceError as e:
+            raise AudioTooLargeError(str(e))
+        except UnsupportedAudioTypeServiceError:
+            raise UnsupportedAudioTypeError()
+        except ProviderNotSupportSpeechToTextServiceError:
+            raise ProviderNotSupportSpeechToTextError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
+        except ValueError as e:
+            raise e
+        except Exception as e:
+            logger.exception("internal server error.")
+            raise InternalServerError()
+
+
+class TrialChatTextApi(TrialAppResource):
+    @trial_feature_enable
+    def post(self, trial_app):
+        app_model = trial_app
+        try:
+            parser = reqparse.RequestParser()
+            parser.add_argument("message_id", type=str, required=False, location="json")
+            parser.add_argument("voice", type=str, location="json")
+            parser.add_argument("text", type=str, location="json")
+            parser.add_argument("streaming", type=bool, location="json")
+            args = parser.parse_args()
+
+            message_id = args.get("message_id", None)
+            text = args.get("text", None)
+            voice = args.get("voice", None)
+            if not isinstance(current_user, Account):
+                raise ValueError("current_user must be an Account instance")
+
+            # Get IDs before they might be detached from session
+            app_id = app_model.id
+            user_id = current_user.id
+
+            response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
+            RecommendedAppService.add_trial_app_record(app_id, user_id)
+            return response
+        except services.errors.app_model_config.AppModelConfigBrokenError:
+            logger.exception("App model config broken.")
+            raise AppUnavailableError()
+        except NoAudioUploadedServiceError:
+            raise NoAudioUploadedError()
+        except AudioTooLargeServiceError as e:
+            raise AudioTooLargeError(str(e))
+        except UnsupportedAudioTypeServiceError:
+            raise UnsupportedAudioTypeError()
+        except ProviderNotSupportSpeechToTextServiceError:
+            raise ProviderNotSupportSpeechToTextError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
+        except ValueError as e:
+            raise e
+        except Exception as e:
+            logger.exception("internal server error.")
+            raise InternalServerError()
+
+
+class TrialCompletionApi(TrialAppResource):
+    @trial_feature_enable
+    def post(self, trial_app):
+        app_model = trial_app
+        if app_model.mode != "completion":
+            raise NotCompletionAppError()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("inputs", type=dict, required=True, location="json")
+        parser.add_argument("query", type=str, location="json", default="")
+        parser.add_argument("files", type=list, required=False, location="json")
+        parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+        parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
+        args = parser.parse_args()
+
+        streaming = args["response_mode"] == "streaming"
+        args["auto_generate_name"] = False
+
+        try:
+            if not isinstance(current_user, Account):
+                raise ValueError("current_user must be an Account instance")
+
+            # Get IDs before they might be detached from session
+            app_id = app_model.id
+            user_id = current_user.id
+
+            response = AppGenerateService.generate(
+                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
+            )
+
+            RecommendedAppService.add_trial_app_record(app_id, user_id)
+            return helper.compact_generate_response(response)
+        except services.errors.conversation.ConversationNotExistsError:
+            raise NotFound("Conversation Not Exists.")
+        except services.errors.conversation.ConversationCompletedError:
+            raise ConversationCompletedError()
+        except services.errors.app_model_config.AppModelConfigBrokenError:
+            logger.exception("App model config broken.")
+            raise AppUnavailableError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
+        except ValueError as e:
+            raise e
+        except Exception:
+            logger.exception("internal server error.")
+            raise InternalServerError()
+
+
+class TrialSitApi(Resource):
+    """Resource for trial app sites."""
+
+    @trial_feature_enable
+    @get_app_model_with_trial
+    def get(self, app_model):
+        """Retrieve app site info.
+
+        Returns the site configuration for the application including theme, icons, and text.
+        """
+        site = db.session.query(Site).where(Site.app_id == app_model.id).first()
+
+        if not site:
+            raise Forbidden()
+
+        assert app_model.tenant
+        if app_model.tenant.status == TenantStatus.ARCHIVE:
+            raise Forbidden()
+
+        return SiteResponse.model_validate(site).model_dump(mode="json")
+
+
+class TrialAppParameterApi(Resource):
+    """Resource for app variables."""
+
+    @trial_feature_enable
+    @get_app_model_with_trial
+    def get(self, app_model):
+        """Retrieve app parameters."""
+
+        if app_model is None:
+            raise AppUnavailableError()
+
+        if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
+            workflow = app_model.workflow
+            if workflow is None:
+                raise AppUnavailableError()
+
+            features_dict = workflow.features_dict
+            user_input_form = workflow.user_input_form(to_old_structure=True)
+        else:
+            app_model_config = app_model.app_model_config
+            if app_model_config is None:
+                raise AppUnavailableError()
+
+            features_dict = app_model_config.to_dict()
+
+            user_input_form = features_dict.get("user_input_form", [])
+
+        parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+        return ParametersResponse.model_validate(parameters).model_dump(mode="json")
+
+
+class AppApi(Resource):
+    @trial_feature_enable
+    @get_app_model_with_trial
+    @marshal_with(app_detail_with_site_model)
+    def get(self, app_model):
+        """Get app detail"""
+
+        app_service = AppService()
+        app_model = app_service.get_app(app_model)
+
+        return app_model
+
+
+class AppWorkflowApi(Resource):
+    @trial_feature_enable
+    @get_app_model_with_trial
+    @marshal_with(workflow_model)
+    def get(self, app_model):
+        """Get workflow detail"""
+        if not app_model.workflow_id:
+            raise AppUnavailableError()
+
+        workflow = (
+            db.session.query(Workflow)
+            .where(
+                Workflow.id == app_model.workflow_id,
+            )
+            .first()
+        )
+        return workflow
+
+
+class DatasetListApi(Resource):
+    @trial_feature_enable
+    @get_app_model_with_trial
+    def get(self, app_model):
+        page = request.args.get("page", default=1, type=int)
+        limit = request.args.get("limit", default=20, type=int)
+        ids = request.args.getlist("ids")
+
+        tenant_id = app_model.tenant_id
+        if ids:
+            datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
+        else:
+            raise NeedAddIdsError()
+
+        data = cast(list[dict[str, Any]], marshal(datasets, dataset_fields))
+
+        response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
+        return response
+
+
+api.add_resource(TrialChatApi, "/trial-apps//chat-messages", endpoint="trial_app_chat_completion")
+
+api.add_resource(
+    TrialMessageSuggestedQuestionApi,
+    "/trial-apps//messages//suggested-questions",
+    endpoint="trial_app_suggested_question",
+)
+
+api.add_resource(TrialChatAudioApi, "/trial-apps//audio-to-text", endpoint="trial_app_audio")
+api.add_resource(TrialChatTextApi, "/trial-apps//text-to-audio", endpoint="trial_app_text")
+
+api.add_resource(TrialCompletionApi, "/trial-apps//completion-messages", endpoint="trial_app_completion")
+
+api.add_resource(TrialSitApi, "/trial-apps//site")
+
+api.add_resource(TrialAppParameterApi, "/trial-apps//parameters", endpoint="trial_app_parameters")
+
+api.add_resource(AppApi, "/trial-apps/", endpoint="trial_app")
+
+api.add_resource(TrialAppWorkflowRunApi, "/trial-apps//workflows/run", endpoint="trial_app_workflow_run")
+api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps//workflows/tasks//stop")
+
+api.add_resource(AppWorkflowApi, "/trial-apps//workflows", endpoint="trial_app_workflow")
+api.add_resource(DatasetListApi, "/trial-apps//datasets", endpoint="trial_app_datasets")
diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py
index 2a97d312aa..38f0a04904 100644
--- a/api/controllers/console/explore/wraps.py
+++ b/api/controllers/console/explore/wraps.py
@@ -2,14 +2,15 @@ from collections.abc import Callable
 from functools import wraps
 from typing import Concatenate, ParamSpec, TypeVar
 
+from flask import abort
 from flask_restx import Resource
 from werkzeug.exceptions import NotFound
 
-from controllers.console.explore.error import AppAccessDeniedError
+from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
 from controllers.console.wraps import account_initialization_required
 from extensions.ext_database import db
 from libs.login import current_account_with_tenant, login_required
-from models import InstalledApp
+from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
 from services.enterprise.enterprise_service import EnterpriseService
 from services.feature_service import FeatureService
 
@@ -71,6 +72,61 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
     return decorator
 
 
+def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
+    def decorator(view: Callable[Concatenate[App, P], R]):
+        @wraps(view)
+        def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
+            current_user, _ = current_account_with_tenant()
+
+            trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
+
+            if trial_app is None:
+                raise TrialAppNotAllowed()
+            app = trial_app.app
+
+            if app is None:
+                raise TrialAppNotAllowed()
+
+            account_trial_app_record = (
+                db.session.query(AccountTrialAppRecord)
+                .where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
+                .first()
+            )
+            if account_trial_app_record:
+                if account_trial_app_record.count >= trial_app.trial_limit:
+                    raise TrialAppLimitExceeded()
+
+            return view(app, *args, **kwargs)
+
+        return decorated
+
+    if view:
+        return decorator(view)
+    return decorator
+
+
+def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
+    @wraps(view)
+    def decorated(*args, **kwargs):
+        features = FeatureService.get_system_features()
+        if not features.enable_trial_app:
+            abort(403, "Trial app feature is not enabled.")
+        return view(*args, **kwargs)
+
+    return decorated
+
+
+def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
+    @wraps(view)
+    def decorated(*args, **kwargs):
+        features = FeatureService.get_system_features()
+        if not features.enable_explore_banner:
+            abort(403, "Explore banner feature is not enabled.")
+        return view(*args, **kwargs)
+
+    return decorated
+
+
 class InstalledAppResource(Resource):
     # must be reversed if there are multiple decorators
@@ -80,3 +136,13 @@ class InstalledAppResource(Resource):
         account_initialization_required,
         login_required,
     ]
+
+
+class TrialAppResource(Resource):
+    # must be reversed if there are multiple decorators
+
+    method_decorators = [
+        trial_app_required,
+        account_initialization_required,
+        login_required,
+    ]
diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py
index 6951c906e9..d3811e2d1b 100644
--- a/api/controllers/console/feature.py
+++ b/api/controllers/console/feature.py
@@ -1,6 +1,7 @@
 from flask_restx import Resource, fields
+from werkzeug.exceptions import Unauthorized
 
-from libs.login import current_account_with_tenant, login_required
+from libs.login import current_account_with_tenant, current_user, login_required
 from services.feature_service import FeatureService
 
 from . import console_ns
@@ -39,5 +40,21 @@ class SystemFeatureApi(Resource):
         ),
     )
     def get(self):
-        """Get system-wide feature configuration"""
-        return FeatureService.get_system_features().model_dump()
+        """Get system-wide feature configuration
+
+        NOTE: This endpoint is unauthenticated by design, as it provides system features
+        data required for dashboard initialization.
+
+        Authentication would create circular dependency (can't login without dashboard loading).
+
+        Only non-sensitive configuration data should be returned by this endpoint.
+        """
+        # NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
+        # without a try-catch. However, due to the implementation of user loader (the `load_user_from_request`
+        # in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
+        # raise `Unauthorized` exception if authentication token is not provided.
+        try:
+            is_authenticated = current_user.is_authenticated
+        except Unauthorized:
+            is_authenticated = False
+        return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump()
diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py
index 29417dc896..109a3cd0d3 100644
--- a/api/controllers/console/files.py
+++ b/api/controllers/console/files.py
@@ -1,7 +1,7 @@
 from typing import Literal
 
 from flask import request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource
 from werkzeug.exceptions import Forbidden
 
 import services
@@ -15,18 +15,21 @@ from controllers.common.errors import (
     TooManyFilesError,
     UnsupportedFileTypeError,
 )
+from controllers.common.schema import register_schema_models
 from controllers.console.wraps import (
     account_initialization_required,
     cloud_edition_billing_resource_check,
     setup_required,
 )
 from extensions.ext_database import db
-from fields.file_fields import file_fields, upload_config_fields
+from fields.file_fields import FileResponse, UploadConfig
 from libs.login import current_account_with_tenant, login_required
 from services.file_service import FileService
 
 from . import console_ns
 
+register_schema_models(console_ns, UploadConfig, FileResponse)
+
 PREVIEW_WORDS_LIMIT = 3000
@@ -35,26 +38,27 @@ class FileApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(upload_config_fields)
+    @console_ns.response(200, "Success", console_ns.models[UploadConfig.__name__])
     def get(self):
-        return {
-            "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
-            "batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
-            "file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT,
-            "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
-            "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
-            "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
-            "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
-            "image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
-            "single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
-            "attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
-        }, 200
+        config = UploadConfig(
+            file_size_limit=dify_config.UPLOAD_FILE_SIZE_LIMIT,
+            batch_count_limit=dify_config.UPLOAD_FILE_BATCH_LIMIT,
+            file_upload_limit=dify_config.BATCH_UPLOAD_LIMIT,
+            image_file_size_limit=dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
+            video_file_size_limit=dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
+            audio_file_size_limit=dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
+            workflow_file_upload_limit=dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
+            image_file_batch_limit=dify_config.IMAGE_FILE_BATCH_LIMIT,
+            single_chunk_attachment_limit=dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
+            attachment_image_file_size_limit=dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
+        )
+        return config.model_dump(mode="json"), 200
 
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(file_fields)
     @cloud_edition_billing_resource_check("documents")
+    @console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__])
     def post(self):
         current_user, _ = current_account_with_tenant()
         source_str = request.form.get("source")
@@ -90,7 +94,8 @@ class FileApi(Resource):
         except services.errors.file.BlockedFileExtensionError as blocked_extension_error:
             raise BlockedFileExtensionError(blocked_extension_error.description)
 
-        return upload_file, 201
+        response = FileResponse.model_validate(upload_file, from_attributes=True)
+        return response.model_dump(mode="json"), 201
 
 
 @console_ns.route("/files//preview")
diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py
index 25a3d80522..d480af312b 100644
--- a/api/controllers/console/ping.py
+++ b/api/controllers/console/ping.py
@@ -1,17 +1,17 @@
-from flask_restx import Resource, fields
+from pydantic import BaseModel, Field
 
-from . import console_ns
+from controllers.fastopenapi import console_router
 
 
-@console_ns.route("/ping")
-class PingApi(Resource):
-    @console_ns.doc("health_check")
-    @console_ns.doc(description="Health check endpoint for connection testing")
-    @console_ns.response(
-        200,
-        "Success",
-        console_ns.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
-    )
-    def get(self):
-        """Health check endpoint for connection testing"""
-        return {"result": "pong"}
+class PingResponse(BaseModel):
+    result: str = Field(description="Health check result", examples=["pong"])
+
+
+@console_router.get(
+    "/ping",
+    response_model=PingResponse,
+    tags=["console"],
+)
+def ping() -> PingResponse:
+    """Health check endpoint for connection testing."""
+    return PingResponse(result="pong")
diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py
index 47eef7eb7e..70c7b80ffa 100644
--- a/api/controllers/console/remote_files.py
+++ b/api/controllers/console/remote_files.py
@@ -1,7 +1,7 @@
 import urllib.parse
 
 import httpx
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource
 from pydantic import BaseModel, Field
 
 import services
@@ -11,19 +11,22 @@ from controllers.common.errors import (
     RemoteFileUploadError,
     UnsupportedFileTypeError,
 )
+from controllers.common.schema import register_schema_models
 from core.file import helpers as file_helpers
 from core.helper import ssrf_proxy
 from extensions.ext_database import db
-from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
+from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
 from libs.login import current_account_with_tenant
 from services.file_service import FileService
 
 from . import console_ns
 
+register_schema_models(console_ns, RemoteFileInfo, FileWithSignedUrl)
+
 
 @console_ns.route("/remote-files/")
 class RemoteFileInfoApi(Resource):
-    @marshal_with(remote_file_info_fields)
+    @console_ns.response(200, "Remote file info", console_ns.models[RemoteFileInfo.__name__])
     def get(self, url):
         decoded_url = urllib.parse.unquote(url)
         resp = ssrf_proxy.head(decoded_url)
@@ -31,10 +34,11 @@ class RemoteFileInfoApi(Resource):
             # failed back to get method
             resp = ssrf_proxy.get(decoded_url, timeout=3)
         resp.raise_for_status()
-        return {
-            "file_type": resp.headers.get("Content-Type", "application/octet-stream"),
-            "file_length": int(resp.headers.get("Content-Length", 0)),
-        }
+        info = RemoteFileInfo(
+            file_type=resp.headers.get("Content-Type", "application/octet-stream"),
+            file_length=int(resp.headers.get("Content-Length", 0)),
+        )
+        return info.model_dump(mode="json")
 
 
 class RemoteFileUploadPayload(BaseModel):
@@ -50,7 +54,7 @@ console_ns.schema_model(
 @console_ns.route("/remote-files/upload")
 class RemoteFileUploadApi(Resource):
     @console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
-    @marshal_with(file_fields_with_signed_url)
+    @console_ns.response(201, "Remote file uploaded", console_ns.models[FileWithSignedUrl.__name__])
     def post(self):
         args = RemoteFileUploadPayload.model_validate(console_ns.payload)
         url = args.url
@@ -85,13 +89,14 @@ class RemoteFileUploadApi(Resource):
         except services.errors.file.UnsupportedFileTypeError:
             raise UnsupportedFileTypeError()
 
-        return {
-            "id": upload_file.id,
-            "name": upload_file.name,
-            "size": upload_file.size,
-            "extension": upload_file.extension,
-            "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
-            "mime_type": upload_file.mime_type,
-            "created_by": upload_file.created_by,
-            "created_at": upload_file.created_at,
-        }, 201
+        payload = FileWithSignedUrl(
+            id=upload_file.id,
+            name=upload_file.name,
+            size=upload_file.size,
+            extension=upload_file.extension,
+            url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
+            mime_type=upload_file.mime_type,
+            created_by=upload_file.created_by,
+            created_at=int(upload_file.created_at.timestamp()),
+        )
+        return payload.model_dump(mode="json"), 201
diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py
index 7fa02ae280..e1ea007232 100644
--- a/api/controllers/console/setup.py
+++ b/api/controllers/console/setup.py
@@ -1,20 +1,19 @@
+from typing import Literal
+
 from flask import request
-from flask_restx import Resource, fields
 from pydantic import BaseModel, Field, field_validator
 
 from configs import dify_config
+from controllers.fastopenapi import console_router
 from libs.helper import EmailStr, extract_remote_ip
 from libs.password import valid_password
 from models.model import DifySetup, db
 from services.account_service import RegisterService, TenantService
 
-from . import console_ns
 from .error import AlreadySetupError, NotInitValidateError
 from .init_validate import get_init_validate_status
 from .wraps import only_edition_self_hosted
 
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
 
 class SetupRequestPayload(BaseModel):
     email: EmailStr = Field(..., description="Admin email address")
@@ -28,77 +27,66 @@ class SetupRequestPayload(BaseModel):
         return valid_password(value)
 
 
-console_ns.schema_model(
-    SetupRequestPayload.__name__,
-    SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+class SetupStatusResponse(BaseModel):
+    step: Literal["not_started", "finished"] = Field(description="Setup step status")
+    setup_at: str | None = Field(default=None, description="Setup completion time (ISO format)")
+
+
+class SetupResponse(BaseModel):
+    result: str = Field(description="Setup result", examples=["success"])
+
+
+@console_router.get(
+    "/setup",
+    response_model=SetupStatusResponse,
+    tags=["console"],
 )
+def get_setup_status_api() -> SetupStatusResponse:
+    """Get system setup status."""
+    if dify_config.EDITION == "SELF_HOSTED":
+        setup_status = get_setup_status()
+        if setup_status and not isinstance(setup_status, bool):
+            return SetupStatusResponse(step="finished", setup_at=setup_status.setup_at.isoformat())
+        if setup_status:
+            return SetupStatusResponse(step="finished")
+        return SetupStatusResponse(step="not_started")
+    return SetupStatusResponse(step="finished")
 
-@console_ns.route("/setup")
-class SetupApi(Resource):
-    @console_ns.doc("get_setup_status")
-    @console_ns.doc(description="Get system setup status")
-    @console_ns.response(
-        200,
-        "Success",
-        console_ns.model(
-            "SetupStatusResponse",
-            {
-                "step": fields.String(description="Setup step status", enum=["not_started", "finished"]),
-                "setup_at": fields.String(description="Setup completion time (ISO format)", required=False),
-            },
-        ),
+
+@console_router.post(
+    "/setup",
+    response_model=SetupResponse,
+    tags=["console"],
+    status_code=201,
+)
+@only_edition_self_hosted
+def setup_system(payload: SetupRequestPayload) -> SetupResponse:
+    """Initialize system setup with admin account."""
+    if get_setup_status():
+        raise AlreadySetupError()
+
+    tenant_count = TenantService.get_tenant_count()
+    if tenant_count > 0:
+        raise AlreadySetupError()
+
+    if not get_init_validate_status():
+        raise NotInitValidateError()
+
+    normalized_email = payload.email.lower()
+
+    RegisterService.setup(
+        email=normalized_email,
+        name=payload.name,
+        password=payload.password,
+        ip_address=extract_remote_ip(request),
+        language=payload.language,
     )
-    def get(self):
-        """Get system setup status"""
-        if dify_config.EDITION == "SELF_HOSTED":
-            setup_status = get_setup_status()
-            # Check if setup_status is a DifySetup object rather than a bool
-            if setup_status and not isinstance(setup_status, bool):
-                return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()}
-            elif setup_status:
-                return {"step": "finished"}
-            return {"step": "not_started"}
-        return {"step": "finished"}
 
-    @console_ns.doc("setup_system")
-    @console_ns.doc(description="Initialize system setup with admin account")
-    @console_ns.expect(console_ns.models[SetupRequestPayload.__name__])
-    @console_ns.response(
-        201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
-    )
-    @console_ns.response(400, "Already setup or validation failed")
-    @only_edition_self_hosted
-    def post(self):
-        """Initialize system setup with admin account"""
-        # is set up
-        if get_setup_status():
-            raise AlreadySetupError()
-
-        # is tenant created
-        tenant_count = TenantService.get_tenant_count()
-        if tenant_count > 0:
-            raise AlreadySetupError()
-
-        if not get_init_validate_status():
-            raise NotInitValidateError()
-
-        args = SetupRequestPayload.model_validate(console_ns.payload)
-
-        # setup
-        RegisterService.setup(
-            email=args.email,
-            name=args.name,
-            password=args.password,
ip_address=extract_remote_ip(request), - language=args.language, - ) - - return {"result": "success"}, 201 + return SetupResponse(result="success") -def get_setup_status(): +def get_setup_status() -> DifySetup | bool | None: if dify_config.EDITION == "SELF_HOSTED": return db.session.query(DifySetup).first() - else: - return True + + return True diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index e9fbb515e4..9988524a80 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -30,11 +30,17 @@ class TagBindingRemovePayload(BaseModel): type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") +class TagListQueryParam(BaseModel): + type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter") + keyword: str | None = Field(None, description="Search keyword") + + register_schema_models( console_ns, TagBasePayload, TagBindingPayload, TagBindingRemovePayload, + TagListQueryParam, ) @@ -43,12 +49,15 @@ class TagListApi(Resource): @setup_required @login_required @account_initialization_required + @console_ns.doc( + params={"type": 'Tag type filter. 
Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."} + ) @marshal_with(dataset_tag_fields) def get(self): _, current_tenant_id = current_account_with_tenant() - tag_type = request.args.get("type", type=str, default="") - keyword = request.args.get("keyword", default=None, type=str) - tags = TagService.get_tags(tag_type, current_tenant_id, keyword) + raw_args = request.args.to_dict() + param = TagListQueryParam.model_validate(raw_args) + tags = TagService.get_tags(param.type, current_tenant_id, param.keyword) return tags, 200 diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 419261ba2a..fdb23acf52 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -1,15 +1,11 @@ -import json import logging import httpx -from flask import request -from flask_restx import Resource, fields from packaging import version from pydantic import BaseModel, Field from configs import dify_config - -from . import console_ns +from controllers.fastopenapi import console_router logger = logging.getLogger(__name__) @@ -18,69 +14,61 @@ class VersionQuery(BaseModel): current_version: str = Field(..., description="Current application version") -console_ns.schema_model( - VersionQuery.__name__, - VersionQuery.model_json_schema(ref_template="#/definitions/{model}"), +class VersionFeatures(BaseModel): + can_replace_logo: bool = Field(description="Whether logo replacement is supported") + model_load_balancing_enabled: bool = Field(description="Whether model load balancing is enabled") + + +class VersionResponse(BaseModel): + version: str = Field(description="Latest version number") + release_date: str = Field(description="Release date of latest version") + release_notes: str = Field(description="Release notes for latest version") + can_auto_update: bool = Field(description="Whether auto-update is supported") + features: VersionFeatures = Field(description="Feature flags and capabilities") + + 
+@console_router.get( + "/version", + response_model=VersionResponse, + tags=["console"], ) +def check_version_update(query: VersionQuery) -> VersionResponse: + """Check for application version updates.""" + check_update_url = dify_config.CHECK_UPDATE_URL - -@console_ns.route("/version") -class VersionApi(Resource): - @console_ns.doc("check_version_update") - @console_ns.doc(description="Check for application version updates") - @console_ns.expect(console_ns.models[VersionQuery.__name__]) - @console_ns.response( - 200, - "Success", - console_ns.model( - "VersionResponse", - { - "version": fields.String(description="Latest version number"), - "release_date": fields.String(description="Release date of latest version"), - "release_notes": fields.String(description="Release notes for latest version"), - "can_auto_update": fields.Boolean(description="Whether auto-update is supported"), - "features": fields.Raw(description="Feature flags and capabilities"), - }, + result = VersionResponse( + version=dify_config.project.version, + release_date="", + release_notes="", + can_auto_update=False, + features=VersionFeatures( + can_replace_logo=dify_config.CAN_REPLACE_LOGO, + model_load_balancing_enabled=dify_config.MODEL_LB_ENABLED, ), ) - def get(self): - """Check for application version updates""" - args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - check_update_url = dify_config.CHECK_UPDATE_URL - result = { - "version": dify_config.project.version, - "release_date": "", - "release_notes": "", - "can_auto_update": False, - "features": { - "can_replace_logo": dify_config.CAN_REPLACE_LOGO, - "model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED, - }, - } - - if not check_update_url: - return result - - try: - response = httpx.get( - check_update_url, - params={"current_version": args.current_version}, - timeout=httpx.Timeout(timeout=10.0, connect=3.0), - ) - except Exception as error: - logger.warning("Check update version error: %s.", 
str(error)) - result["version"] = args.current_version - return result - - content = json.loads(response.content) - if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"): - result["version"] = content["version"] - result["release_date"] = content["releaseDate"] - result["release_notes"] = content["releaseNotes"] - result["can_auto_update"] = content["canAutoUpdate"] + if not check_update_url: return result + try: + response = httpx.get( + check_update_url, + params={"current_version": query.current_version}, + timeout=httpx.Timeout(timeout=10.0, connect=3.0), + ) + content = response.json() + except Exception as error: + logger.warning("Check update version error: %s.", str(error)) + result.version = query.current_version + return result + latest_version = content.get("version", result.version) + if _has_new_version(latest_version=latest_version, current_version=f"{query.current_version}"): + result.version = latest_version + result.release_date = content.get("releaseDate", "") + result.release_notes = content.get("releaseNotes", "") + result.can_auto_update = content.get("canAutoUpdate", False) + return result + def _has_new_version(*, latest_version: str, current_version: str) -> bool: try: diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 55eaa2f09f..38c66525b3 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime from typing import Literal @@ -39,7 +41,7 @@ from fields.member_fields import account_fields from libs.datetime_utils import naive_utc_now from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required -from models import Account, AccountIntegrate, InvitationCode +from models import AccountIntegrate, InvitationCode from 
services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError @@ -99,7 +101,7 @@ class AccountPasswordPayload(BaseModel): repeat_new_password: str @model_validator(mode="after") - def check_passwords_match(self) -> "AccountPasswordPayload": + def check_passwords_match(self) -> AccountPasswordPayload: if self.new_password != self.repeat_new_password: raise RepeatPasswordNotMatchError() return self @@ -169,6 +171,19 @@ reg(ChangeEmailValidityPayload) reg(ChangeEmailResetPayload) reg(CheckEmailUniquePayload) +integrate_fields = { + "provider": fields.String, + "created_at": TimestampField, + "is_bound": fields.Boolean, + "link": fields.String, +} + +integrate_model = console_ns.model("AccountIntegrate", integrate_fields) +integrate_list_model = console_ns.model( + "AccountIntegrateList", + {"data": fields.List(fields.Nested(integrate_model))}, +) + @console_ns.route("/account/init") class AccountInitApi(Resource): @@ -334,21 +349,10 @@ class AccountPasswordApi(Resource): @console_ns.route("/account/integrates") class AccountIntegrateApi(Resource): - integrate_fields = { - "provider": fields.String, - "created_at": TimestampField, - "is_bound": fields.Boolean, - "link": fields.String, - } - - integrate_list_fields = { - "data": fields.List(fields.Nested(integrate_fields)), - } - @setup_required @login_required @account_initialization_required - @marshal_with(integrate_list_fields) + @marshal_with(integrate_list_model) def get(self): account, _ = current_account_with_tenant() @@ -534,7 +538,8 @@ class ChangeEmailSendEmailApi(Resource): else: language = "en-US" account = None - user_email = args.email + user_email = None + email_for_sending = args.email.lower() if args.phase is not None and args.phase == "new_email": if args.token is None: raise InvalidTokenError() @@ -544,16 +549,24 @@ class ChangeEmailSendEmailApi(Resource): 
raise InvalidTokenError() user_email = reset_data.get("email", "") - if user_email != current_user.email: + if user_email.lower() != current_user.email.lower(): raise InvalidEmailError() + + user_email = current_user.email else: with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) if account is None: raise AccountNotFound() + email_for_sending = account.email + user_email = account.email token = AccountService.send_change_email_email( - account=account, email=args.email, old_email=user_email, language=language, phase=args.phase + account=account, + email=email_for_sending, + old_email=user_email, + language=language, + phase=args.phase, ) return {"result": "success", "data": token} @@ -569,9 +582,9 @@ class ChangeEmailCheckApi(Resource): payload = console_ns.payload or {} args = ChangeEmailValidityPayload.model_validate(payload) - user_email = args.email + user_email = args.email.lower() - is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email) + is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(user_email) if is_change_email_error_rate_limit: raise EmailChangeLimitError() @@ -579,11 +592,13 @@ class ChangeEmailCheckApi(Resource): if token_data is None: raise InvalidTokenError() - if user_email != token_data.get("email"): + token_email = token_data.get("email") + normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email + if user_email != normalized_token_email: raise InvalidEmailError() if args.code != token_data.get("code"): - AccountService.add_change_email_error_rate_limit(args.email) + AccountService.add_change_email_error_rate_limit(user_email) raise EmailCodeError() # Verified, revoke the first token @@ -594,8 +609,8 @@ class ChangeEmailCheckApi(Resource): user_email, 
code=args.code, old_email=token_data.get("old_email"), additional_data={} ) - AccountService.reset_change_email_error_rate_limit(args.email) - return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + AccountService.reset_change_email_error_rate_limit(user_email) + return {"is_valid": True, "email": normalized_token_email, "token": new_token} @console_ns.route("/account/change-email/reset") @@ -609,11 +624,12 @@ class ChangeEmailResetApi(Resource): def post(self): payload = console_ns.payload or {} args = ChangeEmailResetPayload.model_validate(payload) + normalized_new_email = args.new_email.lower() - if AccountService.is_account_in_freeze(args.new_email): + if AccountService.is_account_in_freeze(normalized_new_email): raise AccountInFreezeError() - if not AccountService.check_email_unique(args.new_email): + if not AccountService.check_email_unique(normalized_new_email): raise EmailAlreadyInUseError() reset_data = AccountService.get_change_email_data(args.token) @@ -624,13 +640,13 @@ class ChangeEmailResetApi(Resource): old_email = reset_data.get("old_email", "") current_user, _ = current_account_with_tenant() - if current_user.email != old_email: + if current_user.email.lower() != old_email.lower(): raise AccountNotFound() - updated_account = AccountService.update_account_email(current_user, email=args.new_email) + updated_account = AccountService.update_account_email(current_user, email=normalized_new_email) AccountService.send_change_email_completed_notify_email( - email=args.new_email, + email=normalized_new_email, ) return updated_account @@ -643,8 +659,9 @@ class CheckEmailUnique(Resource): def post(self): payload = console_ns.payload or {} args = CheckEmailUniquePayload.model_validate(payload) - if AccountService.is_account_in_freeze(args.email): + normalized_email = args.email.lower() + if AccountService.is_account_in_freeze(normalized_email): raise AccountInFreezeError() - if not AccountService.check_email_unique(args.email): + if 
not AccountService.check_email_unique(normalized_email): raise EmailAlreadyInUseError() return {"result": "success"} diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 9bf393ea2e..ccb60b1461 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,6 +1,8 @@ -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel from werkzeug.exceptions import Forbidden +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.entities.model_entities import ModelType @@ -10,10 +12,20 @@ from models import TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService +class LoadBalancingCredentialPayload(BaseModel): + model: str + model_type: ModelType + credentials: dict[str, object] + + +register_schema_models(console_ns, LoadBalancingCredentialPayload) + + @console_ns.route( "/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate" ) class LoadBalancingCredentialsValidateApi(Resource): + @console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -24,20 +36,7 @@ class LoadBalancingCredentialsValidateApi(Resource): tenant_id = current_tenant_id - parser = ( - reqparse.RequestParser() - .add_argument("model", type=str, required=True, nullable=False, location="json") - .add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", - ) - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + payload = 
LoadBalancingCredentialPayload.model_validate(console_ns.payload or {}) # validate model load balancing credentials model_load_balancing_service = ModelLoadBalancingService() @@ -49,9 +48,9 @@ class LoadBalancingCredentialsValidateApi(Resource): model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, provider=provider, - model=args["model"], - model_type=args["model_type"], - credentials=args["credentials"], + model=payload.model, + model_type=payload.model_type, + credentials=payload.credentials, ) except CredentialsValidateFailedError as ex: result = False @@ -69,6 +68,7 @@ class LoadBalancingCredentialsValidateApi(Resource): "/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate" ) class LoadBalancingConfigCredentialsValidateApi(Resource): + @console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -79,20 +79,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): tenant_id = current_tenant_id - parser = ( - reqparse.RequestParser() - .add_argument("model", type=str, required=True, nullable=False, location="json") - .add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", - ) - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {}) # validate model load balancing config credentials model_load_balancing_service = ModelLoadBalancingService() @@ -104,9 +91,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, provider=provider, - model=args["model"], - model_type=args["model_type"], - credentials=args["credentials"], + model=payload.model, + model_type=payload.model_type, + 
credentials=payload.credentials, config_id=config_id, ) except CredentialsValidateFailedError as ex: diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 0142e14fb0..271cdce3c3 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,11 +1,12 @@ from urllib import parse from flask import abort, request -from flask_restx import Resource, marshal_with +from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field import services from configs import dify_config +from controllers.common.schema import get_or_create_model, register_enum_models from controllers.console import console_ns from controllers.console.auth.error import ( CannotTransferOwnerToSelfError, @@ -24,7 +25,7 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.member_fields import account_with_role_list_fields +from fields.member_fields import account_with_role_fields, account_with_role_list_fields from libs.helper import extract_remote_ip from libs.login import current_account_with_tenant, login_required from models.account import Account, TenantAccountRole @@ -67,6 +68,13 @@ reg(MemberRoleUpdatePayload) reg(OwnerTransferEmailPayload) reg(OwnerTransferCheckPayload) reg(OwnerTransferPayload) +register_enum_models(console_ns, TenantAccountRole) + +account_with_role_model = get_or_create_model("AccountWithRole", account_with_role_fields) + +account_with_role_list_fields_copy = account_with_role_list_fields.copy() +account_with_role_list_fields_copy["accounts"] = fields.List(fields.Nested(account_with_role_model)) +account_with_role_list_model = get_or_create_model("AccountWithRoleList", account_with_role_list_fields_copy) @console_ns.route("/workspaces/current/members") @@ -76,7 +84,7 @@ class MemberListApi(Resource): @setup_required @login_required @account_initialization_required - 
@marshal_with(account_with_role_list_fields) + @marshal_with(account_with_role_list_model) def get(self): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: @@ -107,6 +115,12 @@ class MemberInviteEmailApi(Resource): inviter = current_user if not inviter.current_tenant: raise ValueError("No current tenant") + + # Check workspace permission for member invitations + from libs.workspace_permission import check_workspace_member_invite_permission + + check_workspace_member_invite_permission(inviter.current_tenant.id) + invitation_results = [] console_web_url = dify_config.CONSOLE_WEB_URL @@ -116,26 +130,31 @@ class MemberInviteEmailApi(Resource): raise WorkspaceMembersLimitExceeded() for invitee_email in invitee_emails: + normalized_invitee_email = invitee_email.lower() try: if not inviter.current_tenant: raise ValueError("No current tenant") token = RegisterService.invite_new_member( - inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter + tenant=inviter.current_tenant, + email=invitee_email, + language=interface_language, + role=invitee_role, + inviter=inviter, ) - encoded_invitee_email = parse.quote(invitee_email) + encoded_invitee_email = parse.quote(normalized_invitee_email) invitation_results.append( { "status": "success", - "email": invitee_email, + "email": normalized_invitee_email, "url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}", } ) except AccountAlreadyInTenantError: invitation_results.append( - {"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"} + {"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"} ) except Exception as e: - invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)}) + invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)}) return { "result": "success", @@ -216,7 +235,7 @@ class 
DatasetOperatorMemberListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_with_role_list_fields) + @marshal_with(account_with_role_list_model) def get(self): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 2def57ed7b..583e3e3057 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -5,6 +5,7 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator +from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.model_runtime.entities.model_entities import ModelType @@ -23,12 +24,13 @@ class ParserGetDefault(BaseModel): model_type: ModelType -class ParserPostDefault(BaseModel): - class Inner(BaseModel): - model_type: ModelType - model: str | None = None - provider: str | None = None +class Inner(BaseModel): + model_type: ModelType + model: str | None = None + provider: str | None = None + +class ParserPostDefault(BaseModel): model_settings: list[Inner] @@ -105,19 +107,21 @@ class ParserParameter(BaseModel): model: str -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +register_schema_models( + console_ns, + ParserGetDefault, + ParserPostDefault, + ParserDeleteModels, + ParserPostModels, + ParserGetCredentials, + ParserCreateCredential, + ParserUpdateCredential, + ParserDeleteCredential, + ParserParameter, + Inner, +) - -reg(ParserGetDefault) -reg(ParserPostDefault) -reg(ParserDeleteModels) -reg(ParserPostModels) -reg(ParserGetCredentials) -reg(ParserCreateCredential) 
-reg(ParserUpdateCredential) -reg(ParserDeleteCredential) -reg(ParserParameter) +register_enum_models(console_ns, ModelType) @console_ns.route("/workspaces/current/default-model") diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index ea74fc0337..d1485bc1c0 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden from configs import dify_config +from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required @@ -20,57 +21,12 @@ from services.plugin.plugin_parameter_service import PluginParameterService from services.plugin.plugin_permission_service import PluginPermissionService from services.plugin.plugin_service import PluginService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - - -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -@console_ns.route("/workspaces/current/plugin/debugging-key") -class PluginDebuggingKeyApi(Resource): - @setup_required - @login_required - @account_initialization_required - @plugin_permission_required(debug_required=True) - def get(self): - _, tenant_id = current_account_with_tenant() - - try: - return { - "key": PluginService.get_debugging_key(tenant_id), - "host": dify_config.PLUGIN_REMOTE_INSTALL_HOST, - "port": dify_config.PLUGIN_REMOTE_INSTALL_PORT, - } - except PluginDaemonClientSideError as e: - raise ValueError(e) - class ParserList(BaseModel): page: int = Field(default=1, ge=1, description="Page number") page_size: int = Field(default=256, ge=1, le=256, description="Page size 
(1-256)") -reg(ParserList) - - -@console_ns.route("/workspaces/current/plugin/list") -class PluginListApi(Resource): - @console_ns.expect(console_ns.models[ParserList.__name__]) - @setup_required - @login_required - @account_initialization_required - def get(self): - _, tenant_id = current_account_with_tenant() - args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore - try: - plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size) - except PluginDaemonClientSideError as e: - raise ValueError(e) - - return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total}) - - class ParserLatest(BaseModel): plugin_ids: list[str] @@ -180,23 +136,73 @@ class ParserReadme(BaseModel): language: str = Field(default="en-US") -reg(ParserLatest) -reg(ParserIcon) -reg(ParserAsset) -reg(ParserGithubUpload) -reg(ParserPluginIdentifiers) -reg(ParserGithubInstall) -reg(ParserPluginIdentifierQuery) -reg(ParserTasks) -reg(ParserMarketplaceUpgrade) -reg(ParserGithubUpgrade) -reg(ParserUninstall) -reg(ParserPermissionChange) -reg(ParserDynamicOptions) -reg(ParserDynamicOptionsWithCredentials) -reg(ParserPreferencesChange) -reg(ParserExcludePlugin) -reg(ParserReadme) +register_schema_models( + console_ns, + ParserList, + PluginAutoUpgradeSettingsPayload, + PluginPermissionSettingsPayload, + ParserLatest, + ParserIcon, + ParserAsset, + ParserGithubUpload, + ParserPluginIdentifiers, + ParserGithubInstall, + ParserPluginIdentifierQuery, + ParserTasks, + ParserMarketplaceUpgrade, + ParserGithubUpgrade, + ParserUninstall, + ParserPermissionChange, + ParserDynamicOptions, + ParserDynamicOptionsWithCredentials, + ParserPreferencesChange, + ParserExcludePlugin, + ParserReadme, +) + +register_enum_models( + console_ns, + TenantPluginPermission.DebugPermission, + TenantPluginAutoUpgradeStrategy.UpgradeMode, + TenantPluginAutoUpgradeStrategy.StrategySetting, + TenantPluginPermission.InstallPermission, +) + + 
+@console_ns.route("/workspaces/current/plugin/debugging-key") +class PluginDebuggingKeyApi(Resource): + @setup_required + @login_required + @account_initialization_required + @plugin_permission_required(debug_required=True) + def get(self): + _, tenant_id = current_account_with_tenant() + + try: + return { + "key": PluginService.get_debugging_key(tenant_id), + "host": dify_config.PLUGIN_REMOTE_INSTALL_HOST, + "port": dify_config.PLUGIN_REMOTE_INSTALL_PORT, + } + except PluginDaemonClientSideError as e: + raise ValueError(e) + + +@console_ns.route("/workspaces/current/plugin/list") +class PluginListApi(Resource): + @console_ns.expect(console_ns.models[ParserList.__name__]) + @setup_required + @login_required + @account_initialization_required + def get(self): + _, tenant_id = current_account_with_tenant() + args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore + try: + plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size) + except PluginDaemonClientSideError as e: + raise ValueError(e) + + return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total}) @console_ns.route("/workspaces/current/plugin/list/latest-versions") diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index cb711d16e4..e9e7b72718 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,4 +1,5 @@ import io +import logging from urllib.parse import urlparse from flask import make_response, redirect, request, send_file @@ -17,8 +18,8 @@ from controllers.console.wraps import ( is_admin_or_owner_required, setup_required, ) +from core.db.session_factory import session_factory from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration -from core.helper.tool_provider_cache import ToolProviderListCache from core.mcp.auth.auth_flow import auth, handle_callback 
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError from core.mcp.mcp_client import MCPClient @@ -40,6 +41,8 @@ from services.tools.tools_manage_service import ToolCommonService from services.tools.tools_transform_service import ToolTransformService from services.tools.workflow_tools_manage_service import WorkflowToolManageService +logger = logging.getLogger(__name__) + def is_valid_url(url: str) -> bool: if not url: @@ -945,8 +948,8 @@ class ToolProviderMCPApi(Resource): configuration = MCPConfiguration.model_validate(args["configuration"]) authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None - # Create provider in transaction - with Session(db.engine) as session, session.begin(): + # 1) Create provider in a short transaction (no network I/O inside) + with session_factory.create_session() as session, session.begin(): service = MCPToolManageService(session=session) result = service.create_provider( tenant_id=tenant_id, @@ -962,8 +965,26 @@ class ToolProviderMCPApi(Resource): authentication=authentication, ) - # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations - ToolProviderListCache.invalidate_cache(tenant_id) + # 2) Try to fetch tools immediately after creation so they appear without a second save. + # Perform network I/O outside any DB session to avoid holding locks. 
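The create-provider hunk above splits the work into a short DB transaction, then network I/O with no session held, then a second short transaction. A minimal sketch of that pattern using stdlib `sqlite3` (the `providers` table and `fetch_tools` are hypothetical stand-ins for the MCP service code):

```python
import sqlite3

# Stand-in for session_factory / the MCP provider table (illustrative only).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE providers (id TEXT PRIMARY KEY, tools TEXT DEFAULT '')")
conn.commit()


def fetch_tools(server_url: str) -> str:
    # Placeholder for the network round-trip (e.g. listing MCP tools).
    return "tool_a,tool_b"


def create_then_enrich(provider_id: str, server_url: str) -> str:
    # 1) Short transaction: insert and commit, releasing locks immediately.
    with conn:
        conn.execute("INSERT INTO providers (id) VALUES (?)", (provider_id,))

    # 2) Network I/O runs with no transaction (and no DB locks) open.
    tools = fetch_tools(server_url)

    # 3) Second short transaction persists the result; best-effort, so a
    #    failure here still leaves the provider row created.
    try:
        with conn:
            conn.execute(
                "UPDATE providers SET tools = ? WHERE id = ?", (tools, provider_id)
            )
    except sqlite3.Error:
        pass
    return tools
```

The key property is that step 2 never runs inside an open transaction, mirroring the "no network I/O inside" comment in the hunk.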
+ try: + reconnect = MCPToolManageService.reconnect_with_url( + server_url=args["server_url"], + headers=args.get("headers") or {}, + timeout=configuration.timeout, + sse_read_timeout=configuration.sse_read_timeout, + ) + # Update just-created provider with authed/tools in a new short transaction + with session_factory.create_session() as session, session.begin(): + service = MCPToolManageService(session=session) + db_provider = service.get_provider(provider_id=result.id, tenant_id=tenant_id) + db_provider.authed = reconnect.authed + db_provider.tools = reconnect.tools + + result = ToolTransformService.mcp_provider_to_user_provider(db_provider, for_list=True) + except Exception: + # Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is + logger.warning("Failed to fetch MCP tools after creation", exc_info=True) return jsonable_encoder(result) @@ -1011,9 +1032,6 @@ class ToolProviderMCPApi(Resource): validation_result=validation_result, ) - # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations - ToolProviderListCache.invalidate_cache(current_tenant_id) - return {"result": "success"} @console_ns.expect(parser_mcp_delete) @@ -1028,9 +1046,6 @@ class ToolProviderMCPApi(Resource): service = MCPToolManageService(session=session) service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"]) - # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations - ToolProviderListCache.invalidate_cache(current_tenant_id) - return {"result": "success"} @@ -1081,8 +1096,6 @@ class ToolMCPAuthApi(Resource): credentials=provider_entity.credentials, authed=True, ) - # Invalidate cache after updating credentials - ToolProviderListCache.invalidate_cache(tenant_id) return {"result": "success"} except MCPAuthError as e: try: @@ -1096,22 +1109,16 @@ class ToolMCPAuthApi(Resource): with Session(db.engine) as session, session.begin(): service = 
MCPToolManageService(session=session) response = service.execute_auth_actions(auth_result) - # Invalidate cache after auth actions may have updated provider state - ToolProviderListCache.invalidate_cache(tenant_id) return response except MCPRefreshTokenError as e: with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) - # Invalidate cache after clearing credentials - ToolProviderListCache.invalidate_cache(tenant_id) raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e except (MCPError, ValueError) as e: with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) - # Invalidate cache after clearing credentials - ToolProviderListCache.invalidate_cache(tenant_id) raise ValueError(f"Failed to connect to MCP server: {e}") from e diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 497e62b790..6b642af613 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -1,15 +1,14 @@ import logging -from collections.abc import Mapping from typing import Any from flask import make_response, redirect, request -from flask_restx import Resource, reqparse -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden from configs import dify_config -from constants import HIDDEN_VALUE, UNKNOWN_VALUE +from controllers.common.schema import register_schema_models from controllers.web.error import NotFoundError from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon 
import CredentialType @@ -36,29 +35,38 @@ from ..wraps import ( logger = logging.getLogger(__name__) -class TriggerSubscriptionUpdateRequest(BaseModel): - """Request payload for updating a trigger subscription""" - - name: str | None = Field(default=None, description="The name for the subscription") - credentials: Mapping[str, Any] | None = Field(default=None, description="The credentials for the subscription") - parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription") - properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription") +class TriggerSubscriptionBuilderCreatePayload(BaseModel): + credential_type: str = CredentialType.UNAUTHORIZED -class TriggerSubscriptionVerifyRequest(BaseModel): - """Request payload for verifying subscription credentials.""" - - credentials: Mapping[str, Any] = Field(description="The credentials to verify") +class TriggerSubscriptionBuilderVerifyPayload(BaseModel): + credentials: dict[str, Any] -console_ns.schema_model( - TriggerSubscriptionUpdateRequest.__name__, - TriggerSubscriptionUpdateRequest.model_json_schema(ref_template="#/definitions/{model}"), -) +class TriggerSubscriptionBuilderUpdatePayload(BaseModel): + name: str | None = None + parameters: dict[str, Any] | None = None + properties: dict[str, Any] | None = None + credentials: dict[str, Any] | None = None -console_ns.schema_model( - TriggerSubscriptionVerifyRequest.__name__, - TriggerSubscriptionVerifyRequest.model_json_schema(ref_template="#/definitions/{model}"), + @model_validator(mode="after") + def check_at_least_one_field(self): + if all(v is None for v in self.model_dump().values()): + raise ValueError("At least one of name, credentials, parameters, or properties must be provided") + return self + + +class TriggerOAuthClientPayload(BaseModel): + client_params: dict[str, Any] | None = None + enabled: bool | None = None + + +register_schema_models( + console_ns, + 
TriggerSubscriptionBuilderCreatePayload, + TriggerSubscriptionBuilderVerifyPayload, + TriggerSubscriptionBuilderUpdatePayload, + TriggerOAuthClientPayload, ) @@ -127,16 +135,11 @@ class TriggerSubscriptionListApi(Resource): raise -parser = reqparse.RequestParser().add_argument( - "credential_type", type=str, required=False, nullable=True, location="json" -) - - @console_ns.route( "/workspaces/current/trigger-provider//subscriptions/builder/create", ) class TriggerSubscriptionBuilderCreateApi(Resource): - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderCreatePayload.__name__]) @setup_required @login_required @edit_permission_required @@ -146,10 +149,10 @@ class TriggerSubscriptionBuilderCreateApi(Resource): user = current_user assert user.current_tenant_id is not None - args = parser.parse_args() + payload = TriggerSubscriptionBuilderCreatePayload.model_validate(console_ns.payload or {}) try: - credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value) + credential_type = CredentialType.of(payload.credential_type) subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder( tenant_id=user.current_tenant_id, user_id=user.id, @@ -177,18 +180,11 @@ class TriggerSubscriptionBuilderGetApi(Resource): ) -parser_api = ( - reqparse.RequestParser() - # The credentials of the subscription builder - .add_argument("credentials", type=dict, required=False, nullable=True, location="json") -) - - @console_ns.route( "/workspaces/current/trigger-provider//subscriptions/builder/verify-and-update/", ) -class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource): - @console_ns.expect(parser_api) +class TriggerSubscriptionBuilderVerifyApi(Resource): + @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__]) @setup_required @login_required @edit_permission_required @@ -198,7 +194,7 @@ class 
TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource): user = current_user assert user.current_tenant_id is not None - args = parser_api.parse_args() + payload = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {}) try: # Use atomic update_and_verify to prevent race conditions @@ -208,7 +204,7 @@ class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource): provider_id=TriggerProviderID(provider), subscription_builder_id=subscription_builder_id, subscription_builder_updater=SubscriptionBuilderUpdater( - credentials=args.get("credentials", None), + credentials=payload.credentials, ), ) except Exception as e: @@ -216,24 +212,11 @@ class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource): raise ValueError(str(e)) from e -parser_update_api = ( - reqparse.RequestParser() - # The name of the subscription builder - .add_argument("name", type=str, required=False, nullable=True, location="json") - # The parameters of the subscription builder - .add_argument("parameters", type=dict, required=False, nullable=True, location="json") - # The properties of the subscription builder - .add_argument("properties", type=dict, required=False, nullable=True, location="json") - # The credentials of the subscription builder - .add_argument("credentials", type=dict, required=False, nullable=True, location="json") -) - - @console_ns.route( "/workspaces/current/trigger-provider//subscriptions/builder/update/", ) class TriggerSubscriptionBuilderUpdateApi(Resource): - @console_ns.expect(parser_update_api) + @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__]) @setup_required @login_required @edit_permission_required @@ -244,7 +227,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource): assert isinstance(user, Account) assert user.current_tenant_id is not None - args = parser_update_api.parse_args() + payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {}) try: return jsonable_encoder( 
TriggerSubscriptionBuilderService.update_trigger_subscription_builder( @@ -252,10 +235,10 @@ class TriggerSubscriptionBuilderUpdateApi(Resource): provider_id=TriggerProviderID(provider), subscription_builder_id=subscription_builder_id, subscription_builder_updater=SubscriptionBuilderUpdater( - name=args.get("name", None), - parameters=args.get("parameters", None), - properties=args.get("properties", None), - credentials=args.get("credentials", None), + name=payload.name, + parameters=payload.parameters, + properties=payload.properties, + credentials=payload.credentials, ), ) ) @@ -290,7 +273,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource): "/workspaces/current/trigger-provider//subscriptions/builder/build/", ) class TriggerSubscriptionBuilderBuildApi(Resource): - @console_ns.expect(parser_update_api) + @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__]) @setup_required @login_required @edit_permission_required @@ -299,7 +282,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource): """Build a subscription instance for a trigger provider""" user = current_user assert user.current_tenant_id is not None - args = parser_update_api.parse_args() + payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {}) try: # Use atomic update_and_build to prevent race conditions TriggerSubscriptionBuilderService.update_and_build_builder( @@ -308,9 +291,9 @@ class TriggerSubscriptionBuilderBuildApi(Resource): provider_id=TriggerProviderID(provider), subscription_builder_id=subscription_builder_id, subscription_builder_updater=SubscriptionBuilderUpdater( - name=args.get("name", None), - parameters=args.get("parameters", None), - properties=args.get("properties", None), + name=payload.name, + parameters=payload.parameters, + properties=payload.properties, ), ) return 200 @@ -323,7 +306,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource): "/workspaces/current/trigger-provider//subscriptions/update", ) class 
TriggerSubscriptionUpdateApi(Resource): - @console_ns.expect(console_ns.models[TriggerSubscriptionUpdateRequest.__name__]) + @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__]) @setup_required @login_required @edit_permission_required @@ -333,7 +316,7 @@ class TriggerSubscriptionUpdateApi(Resource): user = current_user assert user.current_tenant_id is not None - args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload) + request = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {}) subscription = TriggerProviderService.get_subscription_by_id( tenant_id=user.current_tenant_id, @@ -345,50 +328,32 @@ class TriggerSubscriptionUpdateApi(Resource): provider_id = TriggerProviderID(subscription.provider_id) try: - # rename only - if ( - args.name is not None - and args.credentials is None - and args.parameters is None - and args.properties is None - ): + # A rename-only request just updates the name + rename = request.name is not None and not any((request.credentials, request.parameters, request.properties)) + # When the credential type is UNAUTHORIZED, the subscription was manually created + # Manually created subscriptions don't have credentials or parameters; + # they only have a name and properties (input by the user) + manually_created = subscription.credential_type == CredentialType.UNAUTHORIZED + if rename or manually_created: TriggerProviderService.update_trigger_subscription( tenant_id=user.current_tenant_id, subscription_id=subscription_id, - name=args.name, + name=request.name, + properties=request.properties, ) return 200 - # rebuild for create automatically by the provider - match subscription.credential_type: - case CredentialType.UNAUTHORIZED: - TriggerProviderService.update_trigger_subscription( - tenant_id=user.current_tenant_id, - subscription_id=subscription_id, - name=args.name, - properties=args.properties, - ) - return 200 - case CredentialType.API_KEY |
CredentialType.OAUTH2: - if args.credentials: - new_credentials: dict[str, Any] = { - key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE) - for key, value in args.credentials.items() - } - else: - new_credentials = subscription.credentials - - TriggerProviderService.rebuild_trigger_subscription( - tenant_id=user.current_tenant_id, - name=args.name, - provider_id=provider_id, - subscription_id=subscription_id, - credentials=new_credentials, - parameters=args.parameters or subscription.parameters, - ) - return 200 - case _: - raise BadRequest("Invalid credential type") + # For the remaining cases (API_KEY, OAUTH2), + # we need to call the third-party provider (e.g. GitHub) to rebuild the subscription + TriggerProviderService.rebuild_trigger_subscription( + tenant_id=user.current_tenant_id, + name=request.name, + provider_id=provider_id, + subscription_id=subscription_id, + credentials=request.credentials or subscription.credentials, + parameters=request.parameters or subscription.parameters, + ) + return 200 except ValueError as e: raise BadRequest(str(e)) except Exception as e: @@ -581,13 +546,6 @@ class TriggerOAuthCallbackApi(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") -parser_oauth_client = ( - reqparse.RequestParser() - .add_argument("client_params", type=dict, required=False, nullable=True, location="json") - .add_argument("enabled", type=bool, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/trigger-provider//oauth/client") class TriggerOAuthClientManageApi(Resource): @setup_required @login_required @is_admin_or_owner_required @@ -635,7 +593,7 @@ class TriggerOAuthClientManageApi(Resource): logger.exception("Error getting OAuth client", exc_info=e) raise - @console_ns.expect(parser_oauth_client) + @console_ns.expect(console_ns.models[TriggerOAuthClientPayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -645,15 +603,15 @@ class TriggerOAuthClientManageApi(Resource): user =
current_user assert user.current_tenant_id is not None - args = parser_oauth_client.parse_args() + payload = TriggerOAuthClientPayload.model_validate(console_ns.payload or {}) try: provider_id = TriggerProviderID(provider) return TriggerProviderService.save_custom_oauth_client_params( tenant_id=user.current_tenant_id, provider_id=provider_id, - client_params=args.get("client_params"), - enabled=args.get("enabled"), + client_params=payload.client_params, + enabled=payload.enabled, ) except ValueError as e: @@ -689,7 +647,7 @@ class TriggerOAuthClientManageApi(Resource): "/workspaces/current/trigger-provider//subscriptions/verify/", ) class TriggerSubscriptionVerifyApi(Resource): - @console_ns.expect(console_ns.models[TriggerSubscriptionVerifyRequest.__name__]) + @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__]) @setup_required @login_required @edit_permission_required @@ -699,9 +657,7 @@ class TriggerSubscriptionVerifyApi(Resource): user = current_user assert user.current_tenant_id is not None - verify_request: TriggerSubscriptionVerifyRequest = TriggerSubscriptionVerifyRequest.model_validate( - console_ns.payload - ) + verify_request = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {}) try: result = TriggerProviderService.verify_subscription_credentials( diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 909a5ce201..94be81d94f 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -20,6 +20,7 @@ from controllers.console.error import AccountNotLinkTenantError from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, + only_edition_enterprise, setup_required, ) from enums.cloud_plan import CloudPlan @@ -28,6 +29,7 @@ from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from 
models.account import Tenant, TenantStatus from services.account_service import TenantService +from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService from services.file_service import FileService from services.workspace_service import WorkspaceService @@ -80,6 +82,9 @@ tenant_fields = { "in_trial": fields.Boolean, "trial_end_reason": fields.String, "custom_config": fields.Raw(attribute="custom_config"), + "trial_credits": fields.Integer, + "trial_credits_used": fields.Integer, + "next_credit_reset_date": fields.Integer, } tenants_fields = { @@ -285,3 +290,31 @@ class WorkspaceInfoApi(Resource): db.session.commit() return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} + + +@console_ns.route("/workspaces/current/permission") +class WorkspacePermissionApi(Resource): + """Get workspace permissions for the current workspace.""" + + @setup_required + @login_required + @account_initialization_required + @only_edition_enterprise + def get(self): + """ + Get workspace permission settings. + Returns permission flags that control workspace features like member invitations and owner transfer. 
+ """ + _, current_tenant_id = current_account_with_tenant() + + if not current_tenant_id: + raise ValueError("No current tenant") + + # Get workspace permissions from enterprise service + permission = EnterpriseService.WorkspacePermissionService.get_permission(current_tenant_id) + + return { + "workspace_id": permission.workspace_id, + "allow_member_invite": permission.allow_member_invite, + "allow_owner_transfer": permission.allow_owner_transfer, + }, 200 diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 95fc006a12..fd928b077d 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -286,13 +286,12 @@ def enable_change_email(view: Callable[P, R]): def is_allow_transfer_owner(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): - _, current_tenant_id = current_account_with_tenant() - features = FeatureService.get_features(current_tenant_id) - if features.is_allow_transfer_workspace: - return view(*args, **kwargs) + from libs.workspace_permission import check_workspace_owner_transfer_permission - # otherwise, return 403 - abort(403) + _, current_tenant_id = current_account_with_tenant() + # Check both billing/plan level and workspace policy level permissions + check_workspace_owner_transfer_permission(current_tenant_id) + return view(*args, **kwargs) return decorated diff --git a/api/controllers/fastopenapi.py b/api/controllers/fastopenapi.py new file mode 100644 index 0000000000..c13f22338b --- /dev/null +++ b/api/controllers/fastopenapi.py @@ -0,0 +1,3 @@ +from fastopenapi.routers import FlaskRouter + +console_router = FlaskRouter() diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 6096a87c56..28ec4b3935 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -4,18 +4,18 @@ from flask import request from flask_restx import Resource from flask_restx.api import HTTPStatus from pydantic import 
BaseModel, Field -from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden import services from core.file.helpers import verify_plugin_file_signature from core.tools.tool_file_manager import ToolFileManager -from fields.file_fields import build_file_model +from fields.file_fields import FileResponse from ..common.errors import ( FileTooLargeError, UnsupportedFileTypeError, ) +from ..common.schema import register_schema_models from ..console.wraps import setup_required from ..files import files_ns from ..inner_api.plugin.wraps import get_user @@ -35,6 +35,8 @@ files_ns.schema_model( PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) ) +register_schema_models(files_ns, FileResponse) + @files_ns.route("/upload/for-plugin") class PluginUploadFileApi(Resource): @@ -51,7 +53,7 @@ class PluginUploadFileApi(Resource): 415: "Unsupported file type", } ) - @files_ns.marshal_with(build_file_model(files_ns), code=HTTPStatus.CREATED) + @files_ns.response(HTTPStatus.CREATED, "File uploaded", files_ns.models[FileResponse.__name__]) def post(self): """Upload a file for plugin usage. 
@@ -69,7 +71,7 @@ class PluginUploadFileApi(Resource): """ args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - file: FileStorage | None = request.files.get("file") + file = request.files.get("file") if file is None: raise Forbidden("File is required.") @@ -80,8 +82,8 @@ class PluginUploadFileApi(Resource): user_id = args.user_id user = get_user(tenant_id, user_id) - filename: str | None = file.filename - mimetype: str | None = file.mimetype + filename = file.filename + mimetype = file.mimetype if not filename or not mimetype: raise Forbidden("Invalid request.") @@ -111,22 +113,22 @@ class PluginUploadFileApi(Resource): preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension) # Create a dictionary with all the necessary attributes - result = { - "id": tool_file.id, - "user_id": tool_file.user_id, - "tenant_id": tool_file.tenant_id, - "conversation_id": tool_file.conversation_id, - "file_key": tool_file.file_key, - "mimetype": tool_file.mimetype, - "original_url": tool_file.original_url, - "name": tool_file.name, - "size": tool_file.size, - "mime_type": mimetype, - "extension": extension, - "preview_url": preview_url, - } + result = FileResponse( + id=tool_file.id, + name=tool_file.name, + size=tool_file.size, + extension=extension, + mime_type=mimetype, + preview_url=preview_url, + source_url=tool_file.original_url, + original_url=tool_file.original_url, + user_id=tool_file.user_id, + tenant_id=tool_file.tenant_id, + conversation_id=tool_file.conversation_id, + file_key=tool_file.file_key, + ) - return result, 201 + return result.model_dump(mode="json"), 201 except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 63c373b50f..85ac9336d6 100644 --- 
a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -1,7 +1,7 @@ from typing import Literal from flask import request -from flask_restx import Api, Namespace, Resource, fields +from flask_restx import Namespace, Resource, fields from flask_restx.api import HTTPStatus from pydantic import BaseModel, Field @@ -92,7 +92,7 @@ annotation_list_fields = { } -def build_annotation_list_model(api_or_ns: Api | Namespace): +def build_annotation_list_model(api_or_ns: Namespace): """Build the annotation list model for the API or Namespace.""" copied_annotation_list_fields = annotation_list_fields.copy() copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns))) diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 25d7ccccec..562f5e33cc 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,6 +1,6 @@ from flask_restx import Resource -from controllers.common.fields import build_parameters_model +from controllers.common.fields import Parameters from controllers.service_api import service_api_ns from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.wraps import validate_app_token @@ -23,7 +23,6 @@ class AppParameterApi(Resource): } ) @validate_app_token - @service_api_ns.marshal_with(build_parameters_model(service_api_ns)) def get(self, app_model: App): """Retrieve app parameters. 
@@ -45,7 +44,8 @@ class AppParameterApi(Resource): user_input_form = features_dict.get("user_input_form", []) - return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + return Parameters.model_validate(parameters).model_dump(mode="json") @service_api_ns.route("/meta") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 40e4bde389..62e8258e25 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -3,8 +3,7 @@ from uuid import UUID from flask import request from flask_restx import Resource -from flask_restx._http import HTTPStatus -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound @@ -16,9 +15,9 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import ( - build_conversation_delete_model, - build_conversation_infinite_scroll_pagination_model, - build_simple_conversation_model, + ConversationDelete, + ConversationInfiniteScrollPagination, + SimpleConversation, ) from fields.conversation_variable_fields import ( build_conversation_variable_infinite_scroll_pagination_model, @@ -105,7 +104,6 @@ class ConversationApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - @service_api_ns.marshal_with(build_conversation_infinite_scroll_pagination_model(service_api_ns)) def get(self, app_model: App, end_user: EndUser): """List all conversations for the current user. 
@@ -120,7 +118,7 @@ class ConversationApi(Resource): try: with Session(db.engine) as session: - return ConversationService.pagination_by_last_id( + pagination = ConversationService.pagination_by_last_id( session=session, app_model=app_model, user=end_user, @@ -129,6 +127,13 @@ class ConversationApi(Resource): invoke_from=InvokeFrom.SERVICE_API, sort_by=query_args.sort_by, ) + adapter = TypeAdapter(SimpleConversation) + conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data] + return ConversationInfiniteScrollPagination( + limit=pagination.limit, + has_more=pagination.has_more, + data=conversations, + ).model_dump(mode="json") except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -146,7 +151,6 @@ class ConversationDetailApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT) def delete(self, app_model: App, end_user: EndUser, c_id): """Delete a specific conversation.""" app_mode = AppMode.value_of(app_model.mode) @@ -159,7 +163,7 @@ class ConversationDetailApi(Resource): ConversationService.delete(app_model, conversation_id, end_user) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"}, 204 + return ConversationDelete(result="success").model_dump(mode="json"), 204 @service_api_ns.route("/conversations//name") @@ -176,7 +180,6 @@ class ConversationRenameApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @service_api_ns.marshal_with(build_simple_conversation_model(service_api_ns)) def post(self, app_model: App, end_user: EndUser, c_id): """Rename a conversation or auto-generate a name.""" app_mode = AppMode.value_of(app_model.mode) @@ -188,7 +191,14 @@ class 
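The pagination hunk above converts ORM rows to Pydantic models with `TypeAdapter(...).validate_python(obj, from_attributes=True)`. A self-contained sketch of that conversion (a dataclass and `SimpleConversationSketch` stand in for the SQLAlchemy row and the real `SimpleConversation` model):

```python
from dataclasses import dataclass

from pydantic import BaseModel, TypeAdapter


@dataclass
class ConversationRow:
    # Stand-in for a SQLAlchemy ORM row; from_attributes reads plain attributes.
    id: str
    name: str


class SimpleConversationSketch(BaseModel):
    id: str
    name: str


adapter = TypeAdapter(SimpleConversationSketch)
rows = [ConversationRow(id="c1", name="hello"), ConversationRow(id="c2", name="world")]
# from_attributes=True lets validation pull fields off arbitrary objects,
# not just mappings, so ORM results convert without an intermediate dict.
conversations = [adapter.validate_python(row, from_attributes=True) for row in rows]
```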
ConversationRenameApi(Resource): payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {}) try: - return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate) + conversation = ConversationService.rename( + app_model, conversation_id, end_user, payload.name, payload.auto_generate + ) + return ( + TypeAdapter(SimpleConversation) + .validate_python(conversation, from_attributes=True) + .model_dump(mode="json") + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index ffe4e0b492..6f6dadf768 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -10,13 +10,16 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from extensions.ext_database import db -from fields.file_fields import build_file_model +from fields.file_fields import FileResponse from models import App, EndUser from services.file_service import FileService +register_schema_models(service_api_ns, FileResponse) + @service_api_ns.route("/files/upload") class FileApi(Resource): @@ -31,8 +34,8 @@ class FileApi(Resource): 415: "Unsupported file type", } ) - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) - @service_api_ns.marshal_with(build_file_model(service_api_ns), code=HTTPStatus.CREATED) + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) # type: ignore + @service_api_ns.response(HTTPStatus.CREATED, "File uploaded", service_api_ns.models[FileResponse.__name__]) def post(self, app_model: App, end_user: EndUser): """Upload a file for use in 
conversations. @@ -64,4 +67,5 @@ class FileApi(Resource): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return upload_file, 201 + response = FileResponse.model_validate(upload_file, from_attributes=True) + return response.model_dump(mode="json"), 201 diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index d342f4e661..8981bbd7d5 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,11 +1,10 @@ -import json import logging from typing import Literal from uuid import UUID from flask import request -from flask_restx import Namespace, Resource, fields -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services @@ -14,10 +13,8 @@ from controllers.service_api import service_api_ns from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom -from fields.conversation_fields import build_message_file_model -from fields.message_fields import build_agent_thought_model, build_feedback_model -from fields.raws import FilesContainedField -from libs.helper import TimestampField +from fields.conversation_fields import ResultResponse +from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem from models.model import App, AppMode, EndUser from services.errors.message import ( FirstMessageNotExistsError, @@ -48,49 +45,6 @@ class FeedbackListQuery(BaseModel): register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery) -def build_message_model(api_or_ns: Namespace): - """Build the message model for the API or Namespace.""" - # First build the nested models - feedback_model 
= build_feedback_model(api_or_ns) - agent_thought_model = build_agent_thought_model(api_or_ns) - message_file_model = build_message_file_model(api_or_ns) - - # Then build the message fields with nested models - message_fields = { - "id": fields.String, - "conversation_id": fields.String, - "parent_message_id": fields.String, - "inputs": FilesContainedField, - "query": fields.String, - "answer": fields.String(attribute="re_sign_file_url_answer"), - "message_files": fields.List(fields.Nested(message_file_model)), - "feedback": fields.Nested(feedback_model, attribute="user_feedback", allow_null=True), - "retriever_resources": fields.Raw( - attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", []) - if obj.message_metadata - else [] - ), - "created_at": TimestampField, - "agent_thoughts": fields.List(fields.Nested(agent_thought_model)), - "status": fields.String, - "error": fields.String, - } - return api_or_ns.model("Message", message_fields) - - -def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace): - """Build the message infinite scroll pagination model for the API or Namespace.""" - # Build the nested message model first - message_model = build_message_model(api_or_ns) - - message_infinite_scroll_pagination_fields = { - "limit": fields.Integer, - "has_more": fields.Boolean, - "data": fields.List(fields.Nested(message_model)), - } - return api_or_ns.model("MessageInfiniteScrollPagination", message_infinite_scroll_pagination_fields) - - @service_api_ns.route("/messages") class MessageListApi(Resource): @service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__]) @@ -104,7 +58,6 @@ class MessageListApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - @service_api_ns.marshal_with(build_message_infinite_scroll_pagination_model(service_api_ns)) def get(self, app_model: App, end_user: EndUser): """List messages in a conversation. 
@@ -119,9 +72,16 @@ class MessageListApi(Resource): first_id = str(query_args.first_id) if query_args.first_id else None try: - return MessageService.pagination_by_first_id( + pagination = MessageService.pagination_by_first_id( app_model, end_user, conversation_id, first_id, query_args.limit ) + adapter = TypeAdapter(MessageListItem) + items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data] + return MessageInfiniteScrollPagination( + limit=pagination.limit, + has_more=pagination.has_more, + data=items, + ).model_dump(mode="json") except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except FirstMessageNotExistsError: @@ -162,7 +122,7 @@ class MessageFeedbackApi(Resource): except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {"result": "success"} + return ResultResponse(result="success").model_dump(mode="json") @service_api_ns.route("/app/feedbacks") diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py index 9f8324a84e..8b47a887bb 100644 --- a/api/controllers/service_api/app/site.py +++ b/api/controllers/service_api/app/site.py @@ -1,7 +1,7 @@ from flask_restx import Resource from werkzeug.exceptions import Forbidden -from controllers.common.fields import build_site_model +from controllers.common.fields import Site as SiteResponse from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token from extensions.ext_database import db @@ -23,7 +23,6 @@ class AppSiteApi(Resource): } ) @validate_app_token - @service_api_ns.marshal_with(build_site_model(service_api_ns)) def get(self, app_model: App): """Retrieve app site info. 
@@ -38,4 +37,4 @@ class AppSiteApi(Resource): if app_model.tenant.status == TenantStatus.ARCHIVE: raise Forbidden() - return site + return SiteResponse.model_validate(site).model_dump(mode="json") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 5d10c3e8fd..2d24e0f5f9 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -3,7 +3,7 @@ from typing import Any, Literal from dateutil.parser import isoparse from flask import request -from flask_restx import Api, Namespace, Resource, fields +from flask_restx import Namespace, Resource, fields from pydantic import BaseModel, Field from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -110,7 +110,7 @@ workflow_run_fields = { } -def build_workflow_run_model(api_or_ns: Api | Namespace): +def build_workflow_run_model(api_or_ns: Namespace): """Build the workflow run model for the API or Namespace.""" return api_or_ns.model("WorkflowRun", workflow_run_fields) diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 4f91f40c55..28864a140a 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -2,7 +2,7 @@ from typing import Any, Literal, cast from flask import request from flask_restx import marshal -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import Forbidden, NotFound import services @@ -13,7 +13,6 @@ from controllers.service_api.dataset.error import DatasetInUseError, DatasetName from controllers.service_api.wraps import ( DatasetApiResource, cloud_edition_billing_rate_limit_check, - validate_dataset_token, ) from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager 
@@ -27,6 +26,14 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D from services.entities.knowledge_entities.knowledge_entities import RetrievalModel from services.tag_service import TagService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +service_api_ns.schema_model( + DatasetPermissionEnum.__name__, + TypeAdapter(DatasetPermissionEnum).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + class DatasetCreatePayload(BaseModel): name: str = Field(..., min_length=1, max_length=40) @@ -88,6 +95,14 @@ class TagUnbindingPayload(BaseModel): target_id: str +class DatasetListQuery(BaseModel): + page: int = Field(default=1, description="Page number") + limit: int = Field(default=20, description="Number of items per page") + keyword: str | None = Field(default=None, description="Search keyword") + include_all: bool = Field(default=False, description="Include all datasets") + tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs") + + register_schema_models( service_api_ns, DatasetCreatePayload, @@ -97,6 +112,7 @@ register_schema_models( TagDeletePayload, TagBindingPayload, TagUnbindingPayload, + DatasetListQuery, ) @@ -114,15 +130,11 @@ class DatasetListApi(DatasetApiResource): ) def get(self, tenant_id): """Resource for getting datasets.""" - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) + query = DatasetListQuery.model_validate(request.args.to_dict()) # provider = request.args.get("provider", default="vendor") - search = request.args.get("keyword", default=None, type=str) - tag_ids = request.args.getlist("tag_ids") - include_all = request.args.get("include_all", default="false").lower() == "true" datasets, total = DatasetService.get_datasets( - page, limit, tenant_id, current_user, search, tag_ids, include_all + query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all ) # check 
embedding setting provider_manager = ProviderManager() @@ -148,7 +160,13 @@ class DatasetListApi(DatasetApiResource): item["embedding_available"] = False else: item["embedding_available"] = True - response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} + response = { + "data": data, + "has_more": len(datasets) == query.limit, + "limit": query.limit, + "total": total, + "page": query.page, + } return response, 200 @service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__]) @@ -460,9 +478,8 @@ class DatasetTagsApi(DatasetApiResource): 401: "Unauthorized - invalid API token", } ) - @validate_dataset_token @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) - def get(self, _, dataset_id): + def get(self, _): """Get all knowledge type tags.""" assert isinstance(current_user, Account) cid = current_user.current_tenant_id @@ -482,8 +499,7 @@ class DatasetTagsApi(DatasetApiResource): } ) @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) - @validate_dataset_token - def post(self, _, dataset_id): + def post(self, _): """Add a knowledge type tag.""" assert isinstance(current_user, Account) if not (current_user.has_edit_permission or current_user.is_dataset_editor): @@ -506,8 +522,7 @@ class DatasetTagsApi(DatasetApiResource): } ) @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) - @validate_dataset_token - def patch(self, _, dataset_id): + def patch(self, _): assert isinstance(current_user, Account) if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() @@ -533,9 +548,8 @@ class DatasetTagsApi(DatasetApiResource): 403: "Forbidden - insufficient permissions", } ) - @validate_dataset_token @edit_permission_required - def delete(self, _, dataset_id): + def delete(self, _): """Delete a knowledge type tag.""" payload = TagDeletePayload.model_validate(service_api_ns.payload or {}) 
TagService.delete_tag(payload.tag_id) @@ -555,8 +569,7 @@ class DatasetTagBindingApi(DatasetApiResource): 403: "Forbidden - insufficient permissions", } ) - @validate_dataset_token - def post(self, _, dataset_id): + def post(self, _): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator assert isinstance(current_user, Account) if not (current_user.has_edit_permission or current_user.is_dataset_editor): @@ -580,8 +593,7 @@ class DatasetTagUnbindingApi(DatasetApiResource): 403: "Forbidden - insufficient permissions", } ) - @validate_dataset_token - def post(self, _, dataset_id): + def post(self, _): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator assert isinstance(current_user, Account) if not (current_user.has_edit_permission or current_user.is_dataset_editor): @@ -604,7 +616,6 @@ class DatasetTagsBindingStatusApi(DatasetApiResource): 401: "Unauthorized - invalid API token", } ) - @validate_dataset_token def get(self, _, *args, **kwargs): """Get all knowledge type tags.""" dataset_id = kwargs.get("dataset_id") diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index c800c0e4e1..c85c1cf81e 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -16,6 +16,7 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_enum_models, register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.dataset.error import ( @@ -29,12 +30,20 @@ from controllers.service_api.wraps import ( cloud_edition_billing_resource_check, ) from core.errors.error import ProviderTokenNotInitError +from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import 
db from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment from services.dataset_service import DatasetService, DocumentService -from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel +from services.entities.knowledge_entities.knowledge_entities import ( + KnowledgeConfig, + PreProcessingRule, + ProcessRule, + RetrievalModel, + Rule, + Segmentation, +) from services.file_service import FileService @@ -69,8 +78,26 @@ class DocumentTextUpdate(BaseModel): return self -for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate]: - service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore +class DocumentListQuery(BaseModel): + page: int = Field(default=1, description="Page number") + limit: int = Field(default=20, description="Number of items per page") + keyword: str | None = Field(default=None, description="Search keyword") + status: str | None = Field(default=None, description="Document status filter") + + +register_enum_models(service_api_ns, RetrievalMethod) + +register_schema_models( + service_api_ns, + ProcessRule, + RetrievalModel, + DocumentTextCreatePayload, + DocumentTextUpdate, + DocumentListQuery, + Rule, + PreProcessingRule, + Segmentation, +) @service_api_ns.route( @@ -261,17 +288,6 @@ class DocumentAddByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create document by upload file.""" - args = {} - if "data" in request.form: - args = json.loads(request.form["data"]) - if "doc_form" not in args: - args["doc_form"] = "text_model" - if "doc_language" not in args: - args["doc_language"] = "English" - - # get dataset info - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = 
db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: @@ -280,6 +296,18 @@ class DocumentAddByFileApi(DatasetApiResource): if dataset.provider == "external": raise ValueError("External datasets are not supported.") + args = {} + if "data" in request.form: + args = json.loads(request.form["data"]) + if "doc_form" not in args: + args["doc_form"] = dataset.chunk_structure or "text_model" + if "doc_language" not in args: + args["doc_language"] = "English" + + # get dataset info + dataset_id = str(dataset_id) + tenant_id = str(tenant_id) + indexing_technique = args.get("indexing_technique") or dataset.indexing_technique if not indexing_technique: raise ValueError("indexing_technique is required.") @@ -370,17 +398,6 @@ class DocumentUpdateByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by upload file.""" - args = {} - if "data" in request.form: - args = json.loads(request.form["data"]) - if "doc_form" not in args: - args["doc_form"] = "text_model" - if "doc_language" not in args: - args["doc_language"] = "English" - - # get dataset info - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: @@ -389,6 +406,18 @@ class DocumentUpdateByFileApi(DatasetApiResource): if dataset.provider == "external": raise ValueError("External datasets are not supported.") + args = {} + if "data" in request.form: + args = json.loads(request.form["data"]) + if "doc_form" not in args: + args["doc_form"] = dataset.chunk_structure or "text_model" + if "doc_language" not in args: + args["doc_language"] = "English" + + # get dataset info + dataset_id = str(dataset_id) + tenant_id = str(tenant_id) + # indexing_technique is already set in dataset since this is an update args["indexing_technique"] = 
dataset.indexing_technique @@ -458,34 +487,33 @@ class DocumentListApi(DatasetApiResource): def get(self, tenant_id, dataset_id): dataset_id = str(dataset_id) tenant_id = str(tenant_id) - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) - search = request.args.get("keyword", default=None, type=str) - status = request.args.get("status", default=None, type=str) + query_params = DocumentListQuery.model_validate(request.args.to_dict()) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) - if status: - query = DocumentService.apply_display_status_filter(query, status) + if query_params.status: + query = DocumentService.apply_display_status_filter(query, query_params.status) - if search: - search = f"%{search}%" + if query_params.keyword: + search = f"%{query_params.keyword}%" query = query.where(Document.name.like(search)) query = query.order_by(desc(Document.created_at), desc(Document.position)) - paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = db.paginate( + select=query, page=query_params.page, per_page=query_params.limit, max_per_page=100, error_out=False + ) documents = paginated_documents.items response = { "data": marshal(documents, document_fields), - "has_more": len(documents) == limit, - "limit": limit, + "has_more": len(documents) == query_params.limit, + "limit": query_params.limit, "total": paginated_documents.total, - "page": page, + "page": query_params.page, } return response diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index d81287d56f..8dbb690901 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ 
b/api/controllers/service_api/dataset/hit_testing.py @@ -24,7 +24,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): dataset_id_str = str(dataset_id) dataset = self.get_and_validate_dataset(dataset_id_str) - args = self.parse_args() + args = self.parse_args(service_api_ns.payload) self.hit_testing_args_check(args) return self.perform_hit_testing(dataset, args) diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index aab25c1af3..b8d9508004 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -11,7 +11,9 @@ from controllers.service_api.wraps import DatasetApiResource, cloud_edition_bill from fields.dataset_fields import dataset_metadata_fields from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, MetadataArgs, + MetadataDetail, MetadataOperationData, ) from services.metadata_service import MetadataService @@ -22,7 +24,13 @@ class MetadataUpdatePayload(BaseModel): register_schema_model(service_api_ns, MetadataUpdatePayload) -register_schema_models(service_api_ns, MetadataArgs, MetadataOperationData) +register_schema_models( + service_api_ns, + MetadataArgs, + MetadataDetail, + DocumentMetadataOperation, + MetadataOperationData, +) @service_api_ns.route("/datasets//metadata") diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index 0a2017e2bd..70b5030237 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -174,7 +174,7 @@ class PipelineRunApi(DatasetApiResource): pipeline=pipeline, user=current_user, args=payload.model_dump(), - invoke_from=InvokeFrom.PUBLISHED if payload.is_published else 
InvokeFrom.DEBUGGER, + invoke_from=InvokeFrom.PUBLISHED_PIPELINE if payload.is_published else InvokeFrom.DEBUGGER, streaming=payload.response_mode == "streaming", ) diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index b242fd2c3e..95679e6fcb 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -60,6 +60,7 @@ register_schema_models( service_api_ns, SegmentCreatePayload, SegmentListQuery, + SegmentUpdateArgs, SegmentUpdatePayload, ChildChunkCreatePayload, ChildChunkListQuery, diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index db3b93a4dc..62ea532eac 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,7 +1,7 @@ import logging from flask import request -from flask_restx import Resource, marshal_with +from flask_restx import Resource from pydantic import BaseModel, ConfigDict, Field from werkzeug.exceptions import Unauthorized @@ -50,7 +50,6 @@ class AppParameterApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(fields.parameters_fields) def get(self, app_model: App, end_user): """Retrieve app parameters.""" if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: @@ -69,7 +68,8 @@ class AppParameterApi(WebApiResource): user_input_form = features_dict.get("user_input_form", []) - return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + return fields.Parameters.model_validate(parameters).model_dump(mode="json") @web_ns.route("/meta") diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 86e19423e5..e76649495a 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,14 +1,21 @@ -from flask_restx import fields, marshal_with, reqparse -from 
flask_restx.inputs import int_range +from typing import Literal + +from flask import request +from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db -from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields +from fields.conversation_fields import ( + ConversationInfiniteScrollPagination, + ResultResponse, + SimpleConversation, +) from libs.helper import uuid_value from models.model import AppMode from services.conversation_service import ConversationService @@ -16,6 +23,35 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers from services.web_conversation_service import WebConversationService +class ConversationListQuery(BaseModel): + last_id: str | None = None + limit: int = Field(default=20, ge=1, le=100) + pinned: bool | None = None + sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = "-updated_at" + + @field_validator("last_id") + @classmethod + def validate_last_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class ConversationRenamePayload(BaseModel): + name: str | None = None + auto_generate: bool = False + + @model_validator(mode="after") + def validate_name_requirement(self): + if not self.auto_generate: + if self.name is None or not self.name.strip(): + raise ValueError("name is required when auto_generate is false") + return self + + +register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload) + + @web_ns.route("/conversations") class 
ConversationListApi(WebApiResource): @web_ns.doc("Get Conversation List") @@ -54,54 +90,39 @@ class ConversationListApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - .add_argument("pinned", type=str, choices=["true", "false", None], location="args") - .add_argument( - "sort_by", - type=str, - choices=["created_at", "-created_at", "updated_at", "-updated_at"], - required=False, - default="-updated_at", - location="args", - ) - ) - args = parser.parse_args() - - pinned = None - if "pinned" in args and args["pinned"] is not None: - pinned = args["pinned"] == "true" + raw_args = request.args.to_dict() + query = ConversationListQuery.model_validate(raw_args) try: with Session(db.engine) as session: - return WebConversationService.pagination_by_last_id( + pagination = WebConversationService.pagination_by_last_id( session=session, app_model=app_model, user=end_user, - last_id=args["last_id"], - limit=args["limit"], + last_id=query.last_id, + limit=query.limit, invoke_from=InvokeFrom.WEB_APP, - pinned=pinned, - sort_by=args["sort_by"], + pinned=query.pinned, + sort_by=query.sort_by, ) + adapter = TypeAdapter(SimpleConversation) + conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data] + return ConversationInfiniteScrollPagination( + limit=pagination.limit, + has_more=pagination.has_more, + data=conversations, + ).model_dump(mode="json") except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @web_ns.route("/conversations/") class ConversationApi(WebApiResource): 
- delete_response_fields = { - "result": fields.String, - } - @web_ns.doc("Delete Conversation") @web_ns.doc(description="Delete a specific conversation.") @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) @@ -115,7 +136,6 @@ class ConversationApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(delete_response_fields) def delete(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -126,7 +146,7 @@ class ConversationApi(WebApiResource): ConversationService.delete(app_model, conversation_id, end_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"}, 204 + return ResultResponse(result="success").model_dump(mode="json"), 204 @web_ns.route("/conversations//name") @@ -155,7 +175,6 @@ class ConversationRenameApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -163,25 +182,23 @@ class ConversationRenameApi(WebApiResource): conversation_id = str(c_id) - parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=False, location="json") - .add_argument("auto_generate", type=bool, required=False, default=False, location="json") - ) - args = parser.parse_args() + payload = ConversationRenamePayload.model_validate(web_ns.payload or {}) try: - return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) + conversation = ConversationService.rename( + app_model, conversation_id, end_user, payload.name, payload.auto_generate + ) + return ( + TypeAdapter(SimpleConversation) + .validate_python(conversation, from_attributes=True) + .model_dump(mode="json") + ) except 
ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @web_ns.route("/conversations//pin") class ConversationPinApi(WebApiResource): - pin_response_fields = { - "result": fields.String, - } - @web_ns.doc("Pin Conversation") @web_ns.doc(description="Pin a specific conversation to keep it at the top of the list.") @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) @@ -195,7 +212,6 @@ class ConversationPinApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(pin_response_fields) def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -208,15 +224,11 @@ class ConversationPinApi(WebApiResource): except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"} + return ResultResponse(result="success").model_dump(mode="json") @web_ns.route("/conversations//unpin") class ConversationUnPinApi(WebApiResource): - unpin_response_fields = { - "result": fields.String, - } - @web_ns.doc("Unpin Conversation") @web_ns.doc(description="Unpin a specific conversation to remove it from the top of the list.") @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) @@ -230,7 +242,6 @@ class ConversationUnPinApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(unpin_response_fields) def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -239,4 +250,4 @@ class ConversationUnPinApi(WebApiResource): conversation_id = str(c_id) WebConversationService.unpin(app_model, conversation_id, end_user) - return {"result": "success"} + return ResultResponse(result="success").model_dump(mode="json") diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py index 
cce3dae95d..2540bf02f4 100644 --- a/api/controllers/web/feature.py +++ b/api/controllers/web/feature.py @@ -17,5 +17,15 @@ class SystemFeatureApi(Resource): Returns: dict: System feature configuration object + + This endpoint is akin to the `SystemFeatureApi` endpoint in api/controllers/console/feature.py, + except it is intended for use by the web app, instead of the console dashboard. + + NOTE: This endpoint is unauthenticated by design, as it provides system features + data required for webapp initialization. + + Authentication would create circular dependency (can't authenticate without webapp loading). + + Only non-sensitive configuration data should be returned by this endpoint. """ return FeatureService.get_system_features().model_dump() diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index 80ad61e549..0036c90800 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -1,5 +1,4 @@ from flask import request -from flask_restx import marshal_with import services from controllers.common.errors import ( @@ -9,12 +8,15 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.wraps import WebApiResource from extensions.ext_database import db -from fields.file_fields import build_file_model +from fields.file_fields import FileResponse from services.file_service import FileService +register_schema_models(web_ns, FileResponse) + @web_ns.route("/files/upload") class FileApi(WebApiResource): @@ -28,7 +30,7 @@ class FileApi(WebApiResource): 415: "Unsupported file type", } ) - @marshal_with(build_file_model(web_ns)) + @web_ns.response(201, "File uploaded successfully", web_ns.models[FileResponse.__name__]) def post(self, app_model, end_user): """Upload a file for use in web applications. 
@@ -81,4 +83,5 @@ class FileApi(WebApiResource): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return upload_file, 201 + response = FileResponse.model_validate(upload_file, from_attributes=True) + return response.model_dump(mode="json"), 201 diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index 690b76655f..91d206f727 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -4,7 +4,6 @@ import secrets from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy import select from sqlalchemy.orm import Session from controllers.common.schema import register_schema_models @@ -22,7 +21,7 @@ from controllers.web import web_ns from extensions.ext_database import db from libs.helper import EmailStr, extract_remote_ip from libs.password import hash_password, valid_password -from models import Account +from models.account import Account from services.account_service import AccountService @@ -70,6 +69,9 @@ class ForgotPasswordSendEmailApi(Resource): def post(self): payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {}) + request_email = payload.email + normalized_email = request_email.lower() + ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() @@ -80,12 +82,12 @@ class ForgotPasswordSendEmailApi(Resource): language = "en-US" with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session) token = None if account is None: raise AuthenticationFailedError() else: - token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language) + token = 
AccountService.send_reset_password_email(account=account, email=normalized_email, language=language) return {"result": "success", "data": token} @@ -104,9 +106,9 @@ class ForgotPasswordCheckApi(Resource): def post(self): payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {}) - user_email = payload.email + user_email = payload.email.lower() - is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email) + is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email) if is_forgot_password_error_rate_limit: raise EmailPasswordResetLimitError() @@ -114,11 +116,16 @@ class ForgotPasswordCheckApi(Resource): if token_data is None: raise InvalidTokenError() - if user_email != token_data.get("email"): + token_email = token_data.get("email") + if not isinstance(token_email, str): + raise InvalidEmailError() + normalized_token_email = token_email.lower() + + if user_email != normalized_token_email: raise InvalidEmailError() if payload.code != token_data.get("code"): - AccountService.add_forgot_password_error_rate_limit(payload.email) + AccountService.add_forgot_password_error_rate_limit(user_email) raise EmailCodeError() # Verified, revoke the first token @@ -126,11 +133,11 @@ class ForgotPasswordCheckApi(Resource): # Refresh token data by generating a new token _, new_token = AccountService.generate_reset_password_token( - user_email, code=payload.code, additional_data={"phase": "reset"} + token_email, code=payload.code, additional_data={"phase": "reset"} ) - AccountService.reset_forgot_password_error_rate_limit(payload.email) - return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + AccountService.reset_forgot_password_error_rate_limit(user_email) + return {"is_valid": True, "email": normalized_token_email, "token": new_token} @web_ns.route("/forgot-password/resets") @@ -174,7 +181,7 @@ class ForgotPasswordResetApi(Resource): email = 
reset_data.get("email", "") with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: self._update_existing_account(account, password_hashed, salt, session) diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 538d0c44be..a824f6d487 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -1,19 +1,26 @@ from flask import make_response, request -from flask_restx import Resource, reqparse +from flask_restx import Resource from jwt import InvalidTokenError +from pydantic import BaseModel, Field, field_validator import services from configs import dify_config +from controllers.common.schema import register_schema_models from controllers.console.auth.error import ( AuthenticationFailedError, EmailCodeError, InvalidEmailError, ) from controllers.console.error import AccountBannedError -from controllers.console.wraps import only_edition_enterprise, setup_required +from controllers.console.wraps import ( + decrypt_code_field, + decrypt_password_field, + only_edition_enterprise, + setup_required, +) from controllers.web import web_ns from controllers.web.wraps import decode_jwt_token -from libs.helper import email +from libs.helper import EmailStr from libs.passport import PassportService from libs.password import valid_password from libs.token import ( @@ -25,10 +32,35 @@ from services.app_service import AppService from services.webapp_auth_service import WebAppAuthService +class LoginPayload(BaseModel): + email: EmailStr + password: str + + @field_validator("password") + @classmethod + def validate_password(cls, value: str) -> str: + return valid_password(value) + + +class EmailCodeLoginSendPayload(BaseModel): + email: EmailStr + language: str | None = None + + +class EmailCodeLoginVerifyPayload(BaseModel): + email: EmailStr + code: str + token: str = 
Field(min_length=1) + + +register_schema_models(web_ns, LoginPayload, EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload) + + @web_ns.route("/login") class LoginApi(Resource): """Resource for web app email/password login.""" + @web_ns.expect(web_ns.models[LoginPayload.__name__]) @setup_required @only_edition_enterprise @web_ns.doc("web_app_login") @@ -42,17 +74,13 @@ class LoginApi(Resource): 404: "Account not found", } ) + @decrypt_password_field def post(self): """Authenticate user and login.""" - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("password", type=valid_password, required=True, location="json") - ) - args = parser.parse_args() + payload = LoginPayload.model_validate(web_ns.payload or {}) try: - account = WebAppAuthService.authenticate(args["email"], args["password"]) + account = WebAppAuthService.authenticate(payload.email, payload.password) except services.errors.account.AccountLoginError: raise AccountBannedError() except services.errors.account.AccountPasswordError: @@ -139,6 +167,7 @@ class EmailCodeLoginSendEmailApi(Resource): @only_edition_enterprise @web_ns.doc("send_email_code_login") @web_ns.doc(description="Send email verification code for login") + @web_ns.expect(web_ns.models[EmailCodeLoginSendPayload.__name__]) @web_ns.doc( responses={ 200: "Email code sent successfully", @@ -147,19 +176,14 @@ class EmailCodeLoginSendEmailApi(Resource): } ) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - ) - args = parser.parse_args() + payload = EmailCodeLoginSendPayload.model_validate(web_ns.payload or {}) - if args["language"] is not None and args["language"] == "zh-Hans": + if payload.language == "zh-Hans": language = "zh-Hans" else: language = "en-US" - account = WebAppAuthService.get_user_through_email(args["email"]) + 
account = WebAppAuthService.get_user_through_email(payload.email) if account is None: raise AuthenticationFailedError() else: @@ -173,6 +197,7 @@ class EmailCodeLoginApi(Resource): @only_edition_enterprise @web_ns.doc("verify_email_code_login") @web_ns.doc(description="Verify email code and complete login") + @web_ns.expect(web_ns.models[EmailCodeLoginVerifyPayload.__name__]) @web_ns.doc( responses={ 200: "Email code verified and login successful", @@ -181,34 +206,33 @@ class EmailCodeLoginApi(Resource): 404: "Account not found", } ) + @decrypt_code_field def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=str, required=True, location="json") - .add_argument("code", type=str, required=True, location="json") - .add_argument("token", type=str, required=True, location="json") - ) - args = parser.parse_args() + payload = EmailCodeLoginVerifyPayload.model_validate(web_ns.payload or {}) - user_email = args["email"] + user_email = payload.email.lower() - token_data = WebAppAuthService.get_email_code_login_data(args["token"]) + token_data = WebAppAuthService.get_email_code_login_data(payload.token) if token_data is None: raise InvalidTokenError() - if token_data["email"] != args["email"]: + token_email = token_data.get("email") + if not isinstance(token_email, str): + raise InvalidEmailError() + normalized_token_email = token_email.lower() + if normalized_token_email != user_email: raise InvalidEmailError() - if token_data["code"] != args["code"]: + if token_data["code"] != payload.code: raise EmailCodeError() - WebAppAuthService.revoke_email_code_login_token(args["token"]) - account = WebAppAuthService.get_user_through_email(user_email) + WebAppAuthService.revoke_email_code_login_token(payload.token) + account = WebAppAuthService.get_user_through_email(token_email) if not account: raise AuthenticationFailedError() token = WebAppAuthService.login(account=account) - AccountService.reset_login_error_rate_limit(args["email"]) + 
AccountService.reset_login_error_rate_limit(user_email) response = make_response({"result": "success", "data": {"access_token": token}}) # set_access_token_to_cookie(request, response, token, samesite="None", httponly=False) return response diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index a02d226762..80035ba818 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -2,8 +2,7 @@ import logging from typing import Literal from flask import request -from flask_restx import fields, marshal_with -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import InternalServerError, NotFound from controllers.common.schema import register_schema_models @@ -22,11 +21,10 @@ from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from fields.conversation_fields import message_file_fields -from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields -from fields.raws import FilesContainedField +from fields.conversation_fields import ResultResponse +from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem from libs import helper -from libs.helper import TimestampField, uuid_value +from libs.helper import uuid_value from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError @@ -70,30 +68,6 @@ register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, Message @web_ns.route("/messages") class MessageListApi(WebApiResource): - message_fields = { - "id": fields.String, - "conversation_id": fields.String, - 
"parent_message_id": fields.String, - "inputs": FilesContainedField, - "query": fields.String, - "answer": fields.String(attribute="re_sign_file_url_answer"), - "message_files": fields.List(fields.Nested(message_file_fields)), - "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), - "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), - "extra_contents": fields.List(fields.Raw), - "created_at": TimestampField, - "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), - "metadata": fields.Raw(attribute="message_metadata_dict"), - "status": fields.String, - "error": fields.String, - } - - message_infinite_scroll_pagination_fields = { - "limit": fields.Integer, - "has_more": fields.Boolean, - "data": fields.List(fields.Nested(message_fields)), - } - @web_ns.doc("Get Message List") @web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.") @web_ns.doc( @@ -122,7 +96,6 @@ class MessageListApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -132,9 +105,16 @@ class MessageListApi(WebApiResource): query = MessageListQuery.model_validate(raw_args) try: - return MessageService.pagination_by_first_id( + pagination = MessageService.pagination_by_first_id( app_model, end_user, query.conversation_id, query.first_id, query.limit ) + adapter = TypeAdapter(WebMessageListItem) + items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data] + return WebMessageInfiniteScrollPagination( + limit=pagination.limit, + has_more=pagination.has_more, + data=items, + ).model_dump(mode="json") except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except FirstMessageNotExistsError: @@ -143,10 +123,6 @@ 
class MessageListApi(WebApiResource): @web_ns.route("/messages/<uuid:message_id>/feedbacks") class MessageFeedbackApi(WebApiResource): - feedback_response_fields = { - "result": fields.String, - } - @web_ns.doc("Create Message Feedback") @web_ns.doc(description="Submit feedback (like/dislike) for a specific message.") @web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}}) @@ -171,7 +147,6 @@ class MessageFeedbackApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(feedback_response_fields) def post(self, app_model, end_user, message_id): message_id = str(message_id) @@ -188,7 +163,7 @@ class MessageFeedbackApi(WebApiResource): except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {"result": "success"} + return ResultResponse(result="success").model_dump(mode="json") @web_ns.route("/messages/<uuid:message_id>/more-like-this") @@ -248,10 +223,6 @@ class MessageMoreLikeThisApi(WebApiResource): @web_ns.route("/messages/<uuid:message_id>/suggested-questions") class MessageSuggestedQuestionApi(WebApiResource): - suggested_questions_response_fields = { - "data": fields.List(fields.String), - } - @web_ns.doc("Get Suggested Questions") @web_ns.doc(description="Get suggested follow-up questions after a message (chat apps only).") @web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}}) @@ -265,7 +236,6 @@ class MessageSuggestedQuestionApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(suggested_questions_response_fields) def get(self, app_model, end_user, message_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -278,7 +248,6 @@ class MessageSuggestedQuestionApi(WebApiResource): app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP ) # questions is a list of strings, not a list of Message objects - # so we can directly return it except 
MessageNotExistsError: raise NotFound("Message not found") except ConversationNotExistsError: @@ -297,4 +266,4 @@ class MessageSuggestedQuestionApi(WebApiResource): logger.exception("internal server error.") raise InternalServerError() - return {"data": questions} + return SuggestedQuestionsResponse(data=questions).model_dump(mode="json") diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index c1f976829f..b08b3fe858 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,7 +1,6 @@ import urllib.parse import httpx -from flask_restx import marshal_with from pydantic import BaseModel, Field, HttpUrl import services @@ -14,7 +13,7 @@ from controllers.common.errors import ( from core.file import helpers as file_helpers from core.helper import ssrf_proxy from extensions.ext_database import db -from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model +from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from services.file_service import FileService from ..common.schema import register_schema_models @@ -26,7 +25,7 @@ class RemoteFileUploadPayload(BaseModel): url: HttpUrl = Field(description="Remote file URL") -register_schema_models(web_ns, RemoteFileUploadPayload) +register_schema_models(web_ns, RemoteFileUploadPayload, RemoteFileInfo, FileWithSignedUrl) @web_ns.route("/remote-files/<path:url>") @@ -41,7 +40,7 @@ class RemoteFileInfoApi(WebApiResource): 500: "Failed to fetch remote file", } ) - @marshal_with(build_remote_file_info_model(web_ns)) + @web_ns.response(200, "Remote file info", web_ns.models[RemoteFileInfo.__name__]) def get(self, app_model, end_user, url): """Get information about a remote file. 
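The `RemoteFileInfoApi` hunk that follows builds `RemoteFileInfo` from the proxied response's headers, defaulting the content type and using `-1` as an "unknown length" sentinel. A stdlib-only sketch of that header-fallback logic — `remote_file_info` is an illustrative helper name, not a function from this PR:

```python
def remote_file_info(headers: dict[str, str]) -> dict[str, object]:
    """Mirror how the endpoint degrades gracefully when a remote server
    omits Content-Type or Content-Length headers."""
    return {
        # Generic binary type when the server does not declare one
        "file_type": headers.get("Content-Type", "application/octet-stream"),
        # -1 signals "unknown length" instead of raising on a missing header
        "file_length": int(headers.get("Content-Length", -1)),
    }


print(remote_file_info({"Content-Type": "image/png", "Content-Length": "2048"}))
print(remote_file_info({}))  # both defaults kick in
```

Keeping the fallback in one place means the HEAD-then-GET retry path in the handler can feed either response's headers through the same normalization.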
@@ -65,10 +64,11 @@ class RemoteFileInfoApi(WebApiResource): # failed back to get method resp = ssrf_proxy.get(decoded_url, timeout=3) resp.raise_for_status() - return { - "file_type": resp.headers.get("Content-Type", "application/octet-stream"), - "file_length": int(resp.headers.get("Content-Length", -1)), - } + info = RemoteFileInfo( + file_type=resp.headers.get("Content-Type", "application/octet-stream"), + file_length=int(resp.headers.get("Content-Length", -1)), + ) + return info.model_dump(mode="json") @web_ns.route("/remote-files/upload") @@ -84,7 +84,7 @@ class RemoteFileUploadApi(WebApiResource): 500: "Failed to fetch remote file", } ) - @marshal_with(build_file_with_signed_url_model(web_ns)) + @web_ns.response(201, "Remote file uploaded", web_ns.models[FileWithSignedUrl.__name__]) def post(self, app_model, end_user): """Upload a file from a remote URL. @@ -139,13 +139,14 @@ class RemoteFileUploadApi(WebApiResource): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError - return { - "id": upload_file.id, - "name": upload_file.name, - "size": upload_file.size, - "extension": upload_file.extension, - "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id), - "mime_type": upload_file.mime_type, - "created_by": upload_file.created_by, - "created_at": upload_file.created_at, - }, 201 + payload1 = FileWithSignedUrl( + id=upload_file.id, + name=upload_file.name, + size=upload_file.size, + extension=upload_file.extension, + url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id), + mime_type=upload_file.mime_type, + created_by=upload_file.created_by, + created_at=int(upload_file.created_at.timestamp()), + ) + return payload1.model_dump(mode="json"), 201 diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 865f3610a7..29993100f6 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,40 +1,32 @@ -from flask_restx import 
fields, marshal_with, reqparse -from flask_restx.inputs import int_range +from flask import request +from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource -from fields.conversation_fields import message_file_fields -from libs.helper import TimestampField, uuid_value +from fields.conversation_fields import ResultResponse +from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem +from libs.helper import UUIDStrOrEmpty from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -feedback_fields = {"rating": fields.String} -message_fields = { - "id": fields.String, - "inputs": fields.Raw, - "query": fields.String, - "answer": fields.String, - "message_files": fields.List(fields.Nested(message_file_fields)), - "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), - "created_at": TimestampField, -} +class SavedMessageListQuery(BaseModel): + last_id: UUIDStrOrEmpty | None = None + limit: int = Field(default=20, ge=1, le=100) + + +class SavedMessageCreatePayload(BaseModel): + message_id: UUIDStrOrEmpty + + +register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload) @web_ns.route("/saved-messages") class SavedMessageListApi(WebApiResource): - saved_message_infinite_scroll_pagination_fields = { - "limit": fields.Integer, - "has_more": fields.Boolean, - "data": fields.List(fields.Nested(message_fields)), - } - - post_response_fields = { - "result": fields.String, - } - @web_ns.doc("Get Saved Messages") @web_ns.doc(description="Retrieve paginated list of saved messages for a completion application.") @web_ns.doc( @@ -58,19 +50,21 @@ class SavedMessageListApi(WebApiResource): 500: "Internal 
Server Error", } ) - @marshal_with(saved_message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): if app_model.mode != "completion": raise NotCompletionAppError() - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - ) - args = parser.parse_args() + raw_args = request.args.to_dict() + query = SavedMessageListQuery.model_validate(raw_args) - return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) + pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit) + adapter = TypeAdapter(SavedMessageItem) + items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data] + return SavedMessageInfiniteScrollPagination( + limit=pagination.limit, + has_more=pagination.has_more, + data=items, + ).model_dump(mode="json") @web_ns.doc("Save Message") @web_ns.doc(description="Save a specific message for later reference.") @@ -89,28 +83,22 @@ class SavedMessageListApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(post_response_fields) def post(self, app_model, end_user): if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json") - args = parser.parse_args() + payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {}) try: - SavedMessageService.save(app_model, end_user, payload.message_id) + SavedMessageService.save(app_model, end_user, payload.message_id) except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {"result": "success"} + return ResultResponse(result="success").model_dump(mode="json") @web_ns.route("/saved-messages/<uuid:message_id>") class SavedMessageApi(WebApiResource): - delete_response_fields = { - "result": fields.String, - } - 
@web_ns.doc("Delete Saved Message") @web_ns.doc(description="Remove a message from saved messages.") @web_ns.doc(params={"message_id": {"description": "Message UUID to delete", "type": "string", "required": True}}) @@ -124,7 +112,6 @@ class SavedMessageApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(delete_response_fields) def delete(self, app_model, end_user, message_id): message_id = str(message_id) @@ -133,4 +120,4 @@ class SavedMessageApi(WebApiResource): SavedMessageService.delete(app_model, end_user, message_id) - return {"result": "success"}, 204 + return ResultResponse(result="success").model_dump(mode="json"), 204 diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 3cbb07a296..95d8c6d5a5 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,8 +1,10 @@ import logging +from typing import Any -from flask_restx import reqparse +from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import ( CompletionRequestError, @@ -27,19 +29,22 @@ from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError + +class WorkflowRunPayload(BaseModel): + inputs: dict[str, Any] = Field(description="Input variables for the workflow") + files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow") + + logger = logging.getLogger(__name__) +register_schema_models(web_ns, WorkflowRunPayload) + @web_ns.route("/workflows/run") class WorkflowRunApi(WebApiResource): @web_ns.doc("Run Workflow") @web_ns.doc(description="Execute a workflow with provided inputs and files.") - @web_ns.doc( - params={ - "inputs": {"description": "Input variables for the workflow", "type": "object", "required": True}, - 
"files": {"description": "Files to be processed by the workflow", "type": "array", "required": False}, - } - ) + @web_ns.expect(web_ns.models[WorkflowRunPayload.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -58,12 +63,8 @@ class WorkflowRunApi(WebApiResource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("files", type=list, required=False, location="json") - ) - args = parser.parse_args() + payload = WorkflowRunPayload.model_validate(web_ns.payload or {}) + args = payload.model_dump(exclude_none=True) try: response = AppGenerateService.generate( diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index c196dbbdf1..3c6d36afe4 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -1,6 +1,7 @@ import json import logging import uuid +from decimal import Decimal from typing import Union, cast from sqlalchemy import select @@ -41,6 +42,7 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db from factories import file_factory +from models.enums import CreatorUserRole from models.model import Conversation, Message, MessageAgentThought, MessageFile logger = logging.getLogger(__name__) @@ -289,6 +291,7 @@ class BaseAgentRunner(AppRunner): thought = MessageAgentThought( message_id=message_id, message_chain_id=None, + tool_process_data=None, thought="", tool=tool_name, tool_labels_str="{}", @@ -296,20 +299,20 @@ class BaseAgentRunner(AppRunner): tool_input=tool_input, message=message, message_token=0, - message_unit_price=0, - message_price_unit=0, + message_unit_price=Decimal(0), + message_price_unit=Decimal("0.001"), message_files=json.dumps(messages_ids) if messages_ids else "", answer="", observation="", answer_token=0, - 
answer_unit_price=0, - answer_price_unit=0, + answer_unit_price=Decimal(0), + answer_price_unit=Decimal("0.001"), tokens=0, - total_price=0, + total_price=Decimal(0), position=self.agent_thought_count + 1, currency="USD", latency=0, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=self.user_id, ) @@ -342,7 +345,8 @@ class BaseAgentRunner(AppRunner): raise ValueError("agent thought not found") if thought: - agent_thought.thought += thought + existing_thought = agent_thought.thought or "" + agent_thought.thought = f"{existing_thought}{thought}" if tool_name: agent_thought.tool = tool_name @@ -440,21 +444,30 @@ class BaseAgentRunner(AppRunner): agent_thoughts: list[MessageAgentThought] = message.agent_thoughts if agent_thoughts: for agent_thought in agent_thoughts: - tools = agent_thought.tool - if tools: - tools = tools.split(";") + tool_names_raw = agent_thought.tool + if tool_names_raw: + tool_names = tool_names_raw.split(";") tool_calls: list[AssistantPromptMessage.ToolCall] = [] tool_call_response: list[ToolPromptMessage] = [] - try: - tool_inputs = json.loads(agent_thought.tool_input) - except Exception: - tool_inputs = {tool: {} for tool in tools} - try: - tool_responses = json.loads(agent_thought.observation) - except Exception: - tool_responses = dict.fromkeys(tools, agent_thought.observation) + tool_input_payload = agent_thought.tool_input + if tool_input_payload: + try: + tool_inputs = json.loads(tool_input_payload) + except Exception: + tool_inputs = {tool: {} for tool in tool_names} + else: + tool_inputs = {tool: {} for tool in tool_names} - for tool in tools: + observation_payload = agent_thought.observation + if observation_payload: + try: + tool_responses = json.loads(observation_payload) + except Exception: + tool_responses = dict.fromkeys(tool_names, observation_payload) + else: + tool_responses = dict.fromkeys(tool_names, observation_payload) + + for tool in tool_names: # generate a uuid for tool call tool_call_id = 
                     str(uuid.uuid4())
                 tool_calls.append(
@@ -484,7 +497,7 @@ class BaseAgentRunner(AppRunner):
                     *tool_call_response,
                 ]
             )
-        if not tools:
+        if not tool_names_raw:
             result.append(AssistantPromptMessage(content=agent_thought.thought))
         else:
             if message.answer:
diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index b32e35d0ca..a55f2d0f5f 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -22,6 +22,7 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo
 from core.tools.__base.tool import Tool
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
+from core.workflow.nodes.agent.exc import AgentMaxIterationError
 from models.model import Message

 logger = logging.getLogger(__name__)
@@ -165,6 +166,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
             self._agent_scratchpad.append(scratchpad)

+            # Check if max iteration is reached and model still wants to call tools
+            if iteration_step == max_iteration_steps and scratchpad.action:
+                if scratchpad.action.action_name.lower() != "final answer":
+                    raise AgentMaxIterationError(app_config.agent.max_iteration)
+
             # get llm usage
             if "usage" in usage_dict:
                 if usage_dict["usage"] is not None:
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index dcc1326b33..7c5c9136a7 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -25,6 +25,7 @@ from core.model_runtime.entities.message_entities import ImagePromptMessageConte
 from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
+from core.workflow.nodes.agent.exc import AgentMaxIterationError
 from models.model import Message

 logger = logging.getLogger(__name__)
@@ -187,7 +188,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 ),
             )

-        assistant_message = AssistantPromptMessage(content="", tool_calls=[])
+        assistant_message = AssistantPromptMessage(content=response, tool_calls=[])
         if tool_calls:
             assistant_message.tool_calls = [
                 AssistantPromptMessage.ToolCall(
@@ -199,8 +200,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 )
                 for tool_call in tool_calls
             ]
-        else:
-            assistant_message.content = response

         self._current_thoughts.append(assistant_message)

@@ -222,6 +221,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):

             final_answer += response + "\n"

+            # Check if max iteration is reached and model still wants to call tools
+            if iteration_step == max_iteration_steps and tool_calls:
+                raise AgentMaxIterationError(app_config.agent.max_iteration)
+
             # call tools
             tool_responses = []
             for tool_call_id, tool_call_name, tool_call_args in tool_calls:
diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py
index 307af3747c..13c51529cc 100644
--- a/api/core/app/app_config/entities.py
+++ b/api/core/app/app_config/entities.py
@@ -1,4 +1,3 @@
-import json
 from collections.abc import Sequence
 from enum import StrEnum, auto
 from typing import Any, Literal
@@ -121,7 +120,7 @@ class VariableEntity(BaseModel):
     allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
     allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
     allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
-    json_schema: str | None = Field(default=None)
+    json_schema: dict | None = Field(default=None)

     @field_validator("description", mode="before")
     @classmethod
@@ -135,17 +134,11 @@ class VariableEntity(BaseModel):

     @field_validator("json_schema")
     @classmethod
-    def validate_json_schema(cls, schema: str | None) -> str | None:
+    def validate_json_schema(cls, schema: dict | None) -> dict | None:
         if schema is None:
             return None
-        try:
-            json_schema = json.loads(schema)
-        except json.JSONDecodeError:
-            raise ValueError(f"invalid json_schema value {schema}")
-
-        try:
-            Draft7Validator.check_schema(json_schema)
+            Draft7Validator.check_schema(schema)
         except SchemaError as e:
             raise ValueError(f"Invalid JSON schema: {e.message}")
         return schema
diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py
index e4b308a6f6..c21c494efe 100644
--- a/api/core/app/apps/advanced_chat/app_config_manager.py
+++ b/api/core/app/apps/advanced_chat/app_config_manager.py
@@ -26,7 +26,6 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
     @classmethod
     def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
         features_dict = workflow.features_dict
-        app_mode = AppMode.value_of(app_model.mode)

         app_config = AdvancedChatAppConfig(
             tenant_id=app_model.tenant_id,
diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index fd913b807d..9249e3cc70 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import contextvars
 import logging
 import threading
 import uuid
 from collections.abc import Generator, Mapping, Sequence
-from typing import Any, Literal, TypeVar, Union, overload
+from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload

 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -13,6 +15,9 @@ from sqlalchemy.orm import Session, sessionmaker
 import contexts
 from configs import dify_config
 from constants import UUID_NIL
+
+if TYPE_CHECKING:
+    from controllers.console.app.workflow import LoopNodeRunPayload
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
@@ -347,7 +352,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         workflow: Workflow,
         node_id: str,
         user: Account | EndUser,
-        args: Mapping,
+        args: LoopNodeRunPayload,
         streaming: bool = True,
     ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
         """
@@ -363,7 +368,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         if not node_id:
             raise ValueError("node_id is required")

-        if args.get("inputs") is None:
+        if args.inputs is None:
             raise ValueError("inputs is required")

         # convert to app config
@@ -381,7 +386,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             stream=streaming,
             invoke_from=InvokeFrom.DEBUGGER,
             extras={"auto_generate_conversation_name": False},
-            single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
+            single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs),
         )
         contexts.plugin_tool_providers.set({})
         contexts.plugin_tool_providers_lock.set(threading.Lock())
diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index f06a6f9e9b..f8cd68a9bf 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -20,13 +20,15 @@ from core.app.entities.queue_entities import (
     QueueTextChunkEvent,
 )
 from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
+from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
+from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
+from core.db.session_factory import session_factory
 from core.moderation.base import ModerationError
 from core.moderation.input_moderation import InputModeration
-from core.variables.variables import VariableUnion
+from core.variables.variables import Variable
 from core.workflow.enums import WorkflowType
 from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
 from core.workflow.graph_engine.layers.base import GraphEngineLayer
-from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
 from core.workflow.runtime import GraphRuntimeState, VariablePool
@@ -37,9 +39,9 @@ from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.otel import WorkflowAppRunnerHandler, trace_span
 from models import Workflow
-from models.enums import UserFrom
 from models.model import App, Conversation, Message, MessageAnnotation
 from models.workflow import ConversationVariable
+from services.conversation_variable_updater import ConversationVariableUpdater

 logger = logging.getLogger(__name__)
@@ -105,6 +107,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         if not app_record:
             raise ValueError("App not found")

+        invoke_from = self.application_generate_entity.invoke_from
+        if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
+            invoke_from = InvokeFrom.DEBUGGER
+        user_from = self._resolve_user_from(invoke_from)
+
         resume_state = self._resume_graph_runtime_state

         if resume_state is not None:
@@ -156,8 +163,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             system_variables=system_inputs,
             user_inputs=inputs,
             environment_variables=self._workflow.environment_variables,
-            # Based on the definition of `VariableUnion`,
-            # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
+            # Based on the definition of `Variable`,
+            # `VariableBase` instances can be safely used as `Variable` since they are compatible.
             conversation_variables=conversation_variables,
         )
@@ -169,6 +176,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
                 workflow_id=self._workflow.id,
                 tenant_id=self._workflow.tenant_id,
                 user_id=self.application_generate_entity.user_id,
+                user_from=user_from,
+                invoke_from=invoke_from,
             )

         db.session.close()
@@ -186,12 +195,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             graph=graph,
             graph_config=self._workflow.graph_dict,
             user_id=self.application_generate_entity.user_id,
-            user_from=(
-                UserFrom.ACCOUNT
-                if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
-                else UserFrom.END_USER
-            ),
-            invoke_from=self.application_generate_entity.invoke_from,
+            user_from=user_from,
+            invoke_from=invoke_from,
             call_depth=self.application_generate_entity.call_depth,
             variable_pool=variable_pool,
             graph_runtime_state=graph_runtime_state,
@@ -214,6 +219,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         )
         workflow_entry.graph_engine.layer(persistence_layer)
+        conversation_variable_layer = ConversationVariablePersistenceLayer(
+            ConversationVariableUpdater(session_factory.get_session_maker())
+        )
+        workflow_entry.graph_engine.layer(conversation_variable_layer)
         for layer in self._graph_engine_layers:
             workflow_entry.graph_engine.layer(layer)
@@ -323,7 +332,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             trace_manager=app_generate_entity.trace_manager,
         )

-    def _initialize_conversation_variables(self) -> list[VariableUnion]:
+    def _initialize_conversation_variables(self) -> list[Variable]:
         """
         Initialize conversation variables for the current conversation.
@@ -348,7 +357,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             conversation_variables = [var.to_variable() for var in existing_variables]
             session.commit()

-        return cast(list[VariableUnion], conversation_variables)
+        return cast(list[Variable], conversation_variables)

     def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
         """
diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py
index 2760466a3b..8b6b8f227b 100644
--- a/api/core/app/apps/agent_chat/app_runner.py
+++ b/api/core/app/apps/agent_chat/app_runner.py
@@ -236,4 +236,7 @@ class AgentChatAppRunner(AppRunner):
             queue_manager=queue_manager,
             stream=application_generate_entity.stream,
             agent=True,
+            message_id=message.id,
+            user_id=application_generate_entity.user_id,
+            tenant_id=app_config.tenant_id,
         )
diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py
index a6aace168e..07bae66867 100644
--- a/api/core/app/apps/base_app_generator.py
+++ b/api/core/app/apps/base_app_generator.py
@@ -1,4 +1,3 @@
-import json
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Union, final
@@ -76,12 +75,24 @@ class BaseAppGenerator:
         user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}

         # Check if all files are converted to File
-        if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
-            raise ValueError("Invalid input type")
-        if any(
-            filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
-        ):
-            raise ValueError("Invalid input type")
+        invalid_dict_keys = [
+            k
+            for k, v in user_inputs.items()
+            if isinstance(v, dict)
+            and entity_dictionary[k].type not in {VariableEntityType.FILE, VariableEntityType.JSON_OBJECT}
+        ]
+        if invalid_dict_keys:
+            raise ValueError(f"Invalid input type for {invalid_dict_keys}")
+
+        invalid_list_dict_keys = [
+            k
+            for k, v in user_inputs.items()
+            if isinstance(v, list)
+            and any(isinstance(item, dict) for item in v)
+            and entity_dictionary[k].type != VariableEntityType.FILE_LIST
+        ]
+        if invalid_list_dict_keys:
+            raise ValueError(f"Invalid input type for {invalid_list_dict_keys}")

         return user_inputs
@@ -178,12 +189,8 @@ class BaseAppGenerator:
                 elif value == 0:
                     value = False
             case VariableEntityType.JSON_OBJECT:
-                if not isinstance(value, str):
-                    raise ValueError(f"{variable_entity.variable} in input form must be a string")
-                try:
-                    json.loads(value)
-                except json.JSONDecodeError:
-                    raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
+                if value and not isinstance(value, dict):
+                    raise ValueError(f"{variable_entity.variable} in input form must be a dict")
             case _:
                 raise AssertionError("this statement should be unreachable.")
diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py
index 698eee9894..b41bedbea4 100644
--- a/api/core/app/apps/base_app_queue_manager.py
+++ b/api/core/app/apps/base_app_queue_manager.py
@@ -90,6 +90,7 @@ class AppQueueManager:
         """
         self._clear_task_belong_cache()
         self._q.put(None)
+        self._graph_runtime_state = None  # Release reference to allow GC to reclaim memory

     def _clear_task_belong_cache(self) -> None:
         """
diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py
index e2e6c11480..617515945b 100644
--- a/api/core/app/apps/base_app_runner.py
+++ b/api/core/app/apps/base_app_runner.py
@@ -1,6 +1,8 @@
+import base64
 import logging
 import time
 from collections.abc import Generator, Mapping, Sequence
+from mimetypes import guess_extension
 from typing import TYPE_CHECKING, Any, Union

 from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
@@ -11,10 +13,16 @@ from core.app.entities.app_invoke_entities import (
     InvokeFrom,
     ModelConfigWithCredentialsEntity,
 )
-from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
+from core.app.entities.queue_entities import (
+    QueueAgentMessageEvent,
+    QueueLLMChunkEvent,
+    QueueMessageEndEvent,
+    QueueMessageFileEvent,
+)
 from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
 from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
 from core.external_data_tool.external_data_fetch import ExternalDataFetch
+from core.file.enums import FileTransferMethod, FileType
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -22,6 +30,7 @@ from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
     PromptMessage,
+    TextPromptMessageContent,
 )
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.errors.invoke import InvokeBadRequestError
@@ -29,7 +38,10 @@ from core.moderation.input_moderation import InputModeration
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
-from models.model import App, AppMode, Message, MessageAnnotation
+from core.tools.tool_file_manager import ToolFileManager
+from extensions.ext_database import db
+from models.enums import CreatorUserRole
+from models.model import App, AppMode, Message, MessageAnnotation, MessageFile

 if TYPE_CHECKING:
     from core.file.models import File
@@ -203,6 +215,9 @@ class AppRunner:
         queue_manager: AppQueueManager,
         stream: bool,
         agent: bool = False,
+        message_id: str | None = None,
+        user_id: str | None = None,
+        tenant_id: str | None = None,
     ):
         """
         Handle invoke result
@@ -210,21 +225,41 @@ class AppRunner:
         :param queue_manager: application queue manager
         :param stream: stream
         :param agent: agent
+        :param message_id: message id for multimodal output
+        :param user_id: user id for multimodal output
+        :param tenant_id: tenant id for multimodal output
         :return:
         """
         if not stream and isinstance(invoke_result, LLMResult):
-            self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
+            self._handle_invoke_result_direct(
+                invoke_result=invoke_result,
+                queue_manager=queue_manager,
+            )
         elif stream and isinstance(invoke_result, Generator):
-            self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
+            self._handle_invoke_result_stream(
+                invoke_result=invoke_result,
+                queue_manager=queue_manager,
+                agent=agent,
+                message_id=message_id,
+                user_id=user_id,
+                tenant_id=tenant_id,
+            )
         else:
             raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")

-    def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool):
+    def _handle_invoke_result_direct(
+        self,
+        invoke_result: LLMResult,
+        queue_manager: AppQueueManager,
+    ):
         """
         Handle invoke result direct
         :param invoke_result: invoke result
         :param queue_manager: application queue manager
         :param agent: agent
+        :param message_id: message id for multimodal output
+        :param user_id: user id for multimodal output
+        :param tenant_id: tenant id for multimodal output
         :return:
         """
         queue_manager.publish(
@@ -235,13 +270,22 @@ class AppRunner:
         )

     def _handle_invoke_result_stream(
-        self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
+        self,
+        invoke_result: Generator[LLMResultChunk, None, None],
+        queue_manager: AppQueueManager,
+        agent: bool,
+        message_id: str | None = None,
+        user_id: str | None = None,
+        tenant_id: str | None = None,
     ):
         """
         Handle invoke result
         :param invoke_result: invoke result
         :param queue_manager: application queue manager
         :param agent: agent
+        :param message_id: message id for multimodal output
+        :param user_id: user id for multimodal output
+        :param tenant_id: tenant id for multimodal output
         :return:
         """
         model: str = ""
@@ -259,12 +303,26 @@ class AppRunner:
                 text += message.content
             elif isinstance(message.content, list):
                 for content in message.content:
-                    if not isinstance(content, str):
-                        # TODO(QuantumGhost): Add multimodal output support for easy ui.
-                        _logger.warning("received multimodal output, type=%s", type(content))
+                    if isinstance(content, str):
+                        text += content
+                    elif isinstance(content, TextPromptMessageContent):
                         text += content.data
+                    elif isinstance(content, ImagePromptMessageContent):
+                        if message_id and user_id and tenant_id:
+                            try:
+                                self._handle_multimodal_image_content(
+                                    content=content,
+                                    message_id=message_id,
+                                    user_id=user_id,
+                                    tenant_id=tenant_id,
+                                    queue_manager=queue_manager,
+                                )
+                            except Exception:
+                                _logger.exception("Failed to handle multimodal image output")
+                        else:
+                            _logger.warning("Received multimodal output but missing required parameters")
                     else:
-                        text += content  # failback to str
+                        text += content.data if hasattr(content, "data") else str(content)

             if not model:
                 model = result.model
@@ -289,6 +347,101 @@ class AppRunner:
             PublishFrom.APPLICATION_MANAGER,
         )

+    def _handle_multimodal_image_content(
+        self,
+        content: ImagePromptMessageContent,
+        message_id: str,
+        user_id: str,
+        tenant_id: str,
+        queue_manager: AppQueueManager,
+    ):
+        """
+        Handle multimodal image content from LLM response.
+        Save the image and create a MessageFile record.
+
+        :param content: ImagePromptMessageContent instance
+        :param message_id: message id
+        :param user_id: user id
+        :param tenant_id: tenant id
+        :param queue_manager: queue manager
+        :return:
+        """
+        _logger.info("Handling multimodal image content for message %s", message_id)
+
+        image_url = content.url
+        base64_data = content.base64_data
+
+        _logger.info("Image URL: %s, Base64 data present: %s", image_url, base64_data is not None)
+
+        if not image_url and not base64_data:
+            _logger.warning("Image content has neither URL nor base64 data")
+            return
+
+        tool_file_manager = ToolFileManager()
+
+        # Save the image file
+        try:
+            if image_url:
+                # Download image from URL
+                _logger.info("Downloading image from URL: %s", image_url)
+                tool_file = tool_file_manager.create_file_by_url(
+                    user_id=user_id,
+                    tenant_id=tenant_id,
+                    file_url=image_url,
+                    conversation_id=None,
+                )
+                _logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
+            elif base64_data:
+                if base64_data.startswith("data:"):
+                    base64_data = base64_data.split(",", 1)[1]
+
+                image_binary = base64.b64decode(base64_data)
+                mimetype = content.mime_type or "image/png"
+                extension = guess_extension(mimetype) or ".png"
+
+                tool_file = tool_file_manager.create_file_by_raw(
+                    user_id=user_id,
+                    tenant_id=tenant_id,
+                    conversation_id=None,
+                    file_binary=image_binary,
+                    mimetype=mimetype,
+                    filename=f"generated_image{extension}",
+                )
+                _logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
+            else:
+                return
+        except Exception:
+            _logger.exception("Failed to save image file")
+            return
+
+        # Create MessageFile record
+        message_file = MessageFile(
+            message_id=message_id,
+            type=FileType.IMAGE,
+            transfer_method=FileTransferMethod.TOOL_FILE,
+            belongs_to="assistant",
+            url=f"/files/tools/{tool_file.id}",
+            upload_file_id=tool_file.id,
+            created_by_role=(
+                CreatorUserRole.ACCOUNT
+                if queue_manager.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}
+                else CreatorUserRole.END_USER
+            ),
+            created_by=user_id,
+        )
+
+        db.session.add(message_file)
+        db.session.commit()
+        db.session.refresh(message_file)
+
+        # Publish QueueMessageFileEvent
+        queue_manager.publish(
+            QueueMessageFileEvent(message_file_id=message_file.id),
+            PublishFrom.APPLICATION_MANAGER,
+        )
+
+        _logger.info("QueueMessageFileEvent published for message_file_id: %s", message_file.id)
+
     def moderation_for_inputs(
         self,
         *,
diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py
index f8338b226b..7d1a4c619f 100644
--- a/api/core/app/apps/chat/app_runner.py
+++ b/api/core/app/apps/chat/app_runner.py
@@ -226,5 +226,10 @@ class ChatAppRunner(AppRunner):

         # handle invoke result
         self._handle_invoke_result(
-            invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
+            invoke_result=invoke_result,
+            queue_manager=queue_manager,
+            stream=application_generate_entity.stream,
+            message_id=message.id,
+            user_id=application_generate_entity.user_id,
+            tenant_id=app_config.tenant_id,
         )
diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py
index ddfb5725b4..a872c2e1f7 100644
--- a/api/core/app/apps/completion/app_runner.py
+++ b/api/core/app/apps/completion/app_runner.py
@@ -184,5 +184,10 @@ class CompletionAppRunner(AppRunner):

         # handle invoke result
         self._handle_invoke_result(
-            invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
+            invoke_result=invoke_result,
+            queue_manager=queue_manager,
+            stream=application_generate_entity.stream,
+            message_id=message.id,
+            user_id=application_generate_entity.user_id,
+            tenant_id=app_config.tenant_id,
         )
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py
index d67c4846aa..2959771940 100644
--- a/api/core/app/apps/message_based_app_generator.py
+++ b/api/core/app/apps/message_based_app_generator.py
@@ -1,3 +1,4 @@
+import json
 import logging
 import uuid
 from collections.abc import Callable, Generator, Mapping
diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py
index 13eb40fd60..ea4441b5d8 100644
--- a/api/core/app/apps/pipeline/pipeline_generator.py
+++ b/api/core/app/apps/pipeline/pipeline_generator.py
@@ -130,7 +130,7 @@ class PipelineGenerator(BaseAppGenerator):
             pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
         )
         documents: list[Document] = []
-        if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
+        if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry and not args.get("original_document_id"):
             from services.dataset_service import DocumentService

             for datasource_info in datasource_info_list:
@@ -156,7 +156,7 @@ class PipelineGenerator(BaseAppGenerator):
         for i, datasource_info in enumerate(datasource_info_list):
             workflow_run_id = str(uuid.uuid4())
             document_id = args.get("original_document_id") or None
-            if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
+            if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry:
                 document_id = document_id or documents[i].id
                 document_pipeline_execution_log = DocumentPipelineExecutionLog(
                     document_id=document_id,
diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py
index 4be9e01fbf..8ea34344b2 100644
--- a/api/core/app/apps/pipeline/pipeline_runner.py
+++ b/api/core/app/apps/pipeline/pipeline_runner.py
@@ -9,13 +9,13 @@ from core.app.entities.app_invoke_entities import (
     InvokeFrom,
     RagPipelineGenerateEntity,
 )
+from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
+from core.app.workflow.node_factory import DifyNodeFactory
 from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
 from core.workflow.entities.graph_init_params import GraphInitParams
 from core.workflow.enums import WorkflowType
 from core.workflow.graph import Graph
-from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
 from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
-from core.workflow.nodes.node_factory import DifyNodeFactory
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
 from core.workflow.runtime import GraphRuntimeState, VariablePool
@@ -73,9 +73,15 @@ class PipelineRunner(WorkflowBasedAppRunner):
         """
         app_config = self.application_generate_entity.app_config
         app_config = cast(PipelineConfig, app_config)
+        invoke_from = self.application_generate_entity.invoke_from
+
+        if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
+            invoke_from = InvokeFrom.DEBUGGER
+
+        user_from = self._resolve_user_from(invoke_from)

         user_id = None
-        if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
+        if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
             end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
             if end_user:
                 user_id = end_user.session_id
@@ -117,7 +123,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
             dataset_id=self.application_generate_entity.dataset_id,
             datasource_type=self.application_generate_entity.datasource_type,
             datasource_info=self.application_generate_entity.datasource_info,
-            invoke_from=self.application_generate_entity.invoke_from.value,
+            invoke_from=invoke_from.value,
         )

         rag_pipeline_variables = []
@@ -149,6 +155,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
             graph_runtime_state=graph_runtime_state,
             start_node_id=self.application_generate_entity.start_node_id,
             workflow=workflow,
+            user_from=user_from,
+            invoke_from=invoke_from,
         )

         # RUN WORKFLOW
@@ -159,12 +167,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
             graph=graph,
             graph_config=workflow.graph_dict,
             user_id=self.application_generate_entity.user_id,
-            user_from=(
-                UserFrom.ACCOUNT
-                if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
-                else UserFrom.END_USER
-            ),
-            invoke_from=self.application_generate_entity.invoke_from,
+            user_from=user_from,
+            invoke_from=invoke_from,
             call_depth=self.application_generate_entity.call_depth,
             graph_runtime_state=graph_runtime_state,
             variable_pool=variable_pool,
@@ -210,7 +214,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
         return workflow

     def _init_rag_pipeline_graph(
-        self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None
+        self,
+        workflow: Workflow,
+        graph_runtime_state: GraphRuntimeState,
+        start_node_id: str | None = None,
+        user_from: UserFrom = UserFrom.ACCOUNT,
+        invoke_from: InvokeFrom = InvokeFrom.SERVICE_API,
     ) -> Graph:
         """
         Init pipeline graph
@@ -253,8 +262,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
             workflow_id=workflow.id,
             graph_config=graph_config,
             user_id=self.application_generate_entity.user_id,
-            user_from=UserFrom.ACCOUNT,
-            invoke_from=InvokeFrom.SERVICE_API,
+            user_from=user_from,
+            invoke_from=invoke_from,
             call_depth=0,
         )
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index 05ba53149b..dc5852d552 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -1,14 +1,16 @@
+from __future__ import annotations
+
 import contextvars
 import logging
 import threading
 import uuid
 from collections.abc import Generator, Mapping, Sequence
-from typing import Any, Literal, Union, overload
+from typing import TYPE_CHECKING, Any, Literal, Union, overload

 from flask import Flask, current_app
 from pydantic import ValidationError
 from sqlalchemy import select
-from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.orm import sessionmaker

 import contexts
 from configs import dify_config
@@ -24,6 +26,7 @@ from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTas
 from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
 from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
 from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
+from core.db.session_factory import session_factory
 from core.helper.trace_id_helper import extract_external_trace_id_from_args
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.ops.ops_trace_manager import TraceQueueManager
@@ -43,6 +46,9 @@ from models.model import App, EndUser
 from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
 from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService

+if TYPE_CHECKING:
+    from controllers.console.app.workflow import LoopNodeRunPayload
+
 SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"

 logger = logging.getLogger(__name__)
@@ -434,7 +440,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         workflow: Workflow,
         node_id: str,
         user: Account | EndUser,
-        args: Mapping[str, Any],
+        args: LoopNodeRunPayload,
         streaming: bool = True,
     ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
         """
@@ -450,7 +456,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         if not node_id:
             raise ValueError("node_id is required")

-        if args.get("inputs") is None:
+        if args.inputs is None:
             raise ValueError("inputs is required")

         # convert to app config
@@ -466,7 +472,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
             stream=streaming,
             invoke_from=InvokeFrom.DEBUGGER,
             extras={"auto_generate_conversation_name": False},
-            single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
+            single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs or {}),
             workflow_execution_id=str(uuid.uuid4()),
         )
         contexts.plugin_tool_providers.set({})
@@ -532,7 +538,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         :return:
         """
         with preserve_flask_contexts(flask_app, context_vars=context):
-            with Session(db.engine, expire_on_commit=False) as session:
+            with session_factory.create_session() as session:
                 workflow = session.scalar(
                     select(Workflow).where(
                         Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py
index 44e63c7c4d..077b321104 100644
--- a/api/core/app/apps/workflow/app_runner.py
+++ b/api/core/app/apps/workflow/app_runner.py
@@ -7,10 +7,10 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
 from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
+from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
 from core.workflow.enums import WorkflowType
 from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
 from core.workflow.graph_engine.layers.base import GraphEngineLayer
-from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
 from core.workflow.runtime import GraphRuntimeState, VariablePool
@@ -20,7 +20,6 @@ from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_redis import redis_client
 from extensions.otel import WorkflowAppRunnerHandler, trace_span
 from libs.datetime_utils import naive_utc_now
-from models.enums import UserFrom
 from models.workflow import Workflow

 logger = logging.getLogger(__name__)
@@ -66,6 +65,11 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
         """
         app_config = self.application_generate_entity.app_config
         app_config = cast(WorkflowAppConfig, app_config)
+        invoke_from = self.application_generate_entity.invoke_from
+        # if only single iteration or single loop run is requested
+        if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
+            invoke_from = InvokeFrom.DEBUGGER
+        user_from = self._resolve_user_from(invoke_from)

         resume_state = self._resume_graph_runtime_state
@@ -109,6 +113,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
                 workflow_id=self._workflow.id,
                 tenant_id=self._workflow.tenant_id,
                 user_id=self.application_generate_entity.user_id,
+                user_from=user_from,
+                invoke_from=invoke_from,
                 root_node_id=self._root_node_id,
             )

@@ -127,12 +133,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
             graph=graph,
             graph_config=self._workflow.graph_dict,
             user_id=self.application_generate_entity.user_id,
-            user_from=(
-                UserFrom.ACCOUNT
-                if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
-                else UserFrom.END_USER
-            ),
-            invoke_from=self.application_generate_entity.invoke_from,
+            user_from=user_from,
+            invoke_from=invoke_from,
             call_depth=self.application_generate_entity.call_depth,
             variable_pool=variable_pool,
             graph_runtime_state=graph_runtime_state,
diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py
index b09385aad3..c9d7464c17 100644
--- a/api/core/app/apps/workflow_app_runner.py
+++ b/api/core/app/apps/workflow_app_runner.py
@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
 from typing import Any, cast

 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import (
     AppQueueEvent,
     QueueAgentLogEvent,
@@ -28,6 +29,7 @@ from core.app.entities.queue_entities import (
     QueueWorkflowStartedEvent,
     QueueWorkflowSucceededEvent,
 )
+from core.app.workflow.node_factory import DifyNodeFactory
 from core.workflow.entities import GraphInitParams
 from core.workflow.entities.pause_reason import HumanInputRequired
 from core.workflow.graph import Graph
@@ -60,7 +62,6 @@ from core.workflow.graph_events import (
 )
 from core.workflow.graph_events.graph import GraphRunAbortedEvent
 from core.workflow.nodes import NodeType
-from core.workflow.nodes.node_factory import DifyNodeFactory
 from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
@@ -87,10 +88,18 @@ class WorkflowBasedAppRunner:
         self._app_id = app_id
         self._graph_engine_layers = graph_engine_layers

+    @staticmethod
+    def _resolve_user_from(invoke_from: InvokeFrom) -> UserFrom:
+        if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}:
+            return UserFrom.ACCOUNT
+        return UserFrom.END_USER
+
     def _init_graph(
         self,
         graph_config: Mapping[str, Any],
         graph_runtime_state: GraphRuntimeState,
+        user_from: UserFrom,
+        invoke_from: InvokeFrom,
         workflow_id: str = "",
         tenant_id: str = "",
         user_id: str = "",
@@ -115,8 +124,8 @@ class WorkflowBasedAppRunner:
             workflow_id=workflow_id,
             graph_config=graph_config,
             user_id=user_id,
-            user_from=UserFrom.ACCOUNT,
-            invoke_from=self._queue_manager.invoke_from,
+            user_from=user_from,
+            invoke_from=invoke_from,
             call_depth=0,
         )
@@ -159,7 +168,7 @@ class WorkflowBasedAppRunner:
         # Create initial runtime state with variable pool containing environment variables
         graph_runtime_state = GraphRuntimeState(
             variable_pool=VariablePool(
-                system_variables=SystemVariable.empty(),
+                system_variables=SystemVariable.default(),
                 user_inputs={},
                 environment_variables=workflow.environment_variables,
             ),
@@ -168,18 +177,22 @@ class WorkflowBasedAppRunner:
         # Determine which type of single node execution and get
graph/variable_pool if single_iteration_run: - graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run( workflow=workflow, node_id=single_iteration_run.node_id, user_inputs=dict(single_iteration_run.inputs), graph_runtime_state=graph_runtime_state, + node_type_filter_key="iteration_id", + node_type_label="iteration", ) elif single_loop_run: - graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( + graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run( workflow=workflow, node_id=single_loop_run.node_id, user_inputs=dict(single_loop_run.inputs), graph_runtime_state=graph_runtime_state, + node_type_filter_key="loop_id", + node_type_label="loop", ) else: raise ValueError("Neither single_iteration_run nor single_loop_run is specified") @@ -260,7 +273,7 @@ class WorkflowBasedAppRunner: graph_config=graph_config, user_id="", user_from=UserFrom.ACCOUNT, - invoke_from=self._queue_manager.invoke_from, + invoke_from=InvokeFrom.DEBUGGER, call_depth=0, ) @@ -270,7 +283,9 @@ class WorkflowBasedAppRunner: ) # init graph - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id) + graph = Graph.init( + graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True + ) if not graph: raise ValueError("graph not found in workflow") @@ -316,44 +331,6 @@ class WorkflowBasedAppRunner: return graph, variable_pool - def _get_graph_and_variable_pool_of_single_iteration( - self, - workflow: Workflow, - node_id: str, - user_inputs: dict[str, Any], - graph_runtime_state: GraphRuntimeState, - ) -> tuple[Graph, VariablePool]: - """ - Get variable pool of single iteration - """ - return self._get_graph_and_variable_pool_for_single_node_run( - workflow=workflow, - node_id=node_id, - user_inputs=user_inputs, - graph_runtime_state=graph_runtime_state, - 
node_type_filter_key="iteration_id", - node_type_label="iteration", - ) - - def _get_graph_and_variable_pool_of_single_loop( - self, - workflow: Workflow, - node_id: str, - user_inputs: dict[str, Any], - graph_runtime_state: GraphRuntimeState, - ) -> tuple[Graph, VariablePool]: - """ - Get variable pool of single loop - """ - return self._get_graph_and_variable_pool_for_single_node_run( - workflow=workflow, - node_id=node_id, - user_inputs=user_inputs, - graph_runtime_state=graph_runtime_state, - node_type_filter_key="loop_id", - node_type_label="loop", - ) - def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent): """ Handle event diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 221a69cd3f..0e68e554c8 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -42,7 +42,8 @@ class InvokeFrom(StrEnum): # DEBUGGER indicates that this invocation is from # the workflow (or chatflow) edit page. DEBUGGER = "debugger" - PUBLISHED = "published" + # PUBLISHED_PIPELINE indicates that this invocation runs a published RAG pipeline workflow. + PUBLISHED_PIPELINE = "published" # VALIDATION indicates that this invocation is from validation. 
VALIDATION = "validation" diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 79fbafe39e..3f9f3da9b2 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -75,7 +75,7 @@ class AnnotationReplyFeature: AppAnnotationService.add_annotation_history( annotation.id, app_record.id, - annotation.question, + annotation.question_text, annotation.content, query, user_id, diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py new file mode 100644 index 0000000000..c070845b73 --- /dev/null +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -0,0 +1,60 @@ +import logging + +from core.variables import VariableBase +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID +from core.workflow.conversation_variable_updater import ConversationVariableUpdater +from core.workflow.enums import NodeType +from core.workflow.graph_engine.layers.base import GraphEngineLayer +from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent +from core.workflow.nodes.variable_assigner.common import helpers as common_helpers + +logger = logging.getLogger(__name__) + + +class ConversationVariablePersistenceLayer(GraphEngineLayer): + def __init__(self, conversation_variable_updater: ConversationVariableUpdater) -> None: + super().__init__() + self._conversation_variable_updater = conversation_variable_updater + + def on_graph_start(self) -> None: + pass + + def on_event(self, event: GraphEngineEvent) -> None: + if not isinstance(event, NodeRunSucceededEvent): + return + if event.node_type != NodeType.VARIABLE_ASSIGNER: + return + if self.graph_runtime_state is None: + return + + updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or [] + if not updated_variables: + return 
+ + conversation_id = self.graph_runtime_state.system_variable.conversation_id + if conversation_id is None: + return + + updated_any = False + for item in updated_variables: + selector = item.selector + if len(selector) < 2: + logger.warning("Conversation variable selector invalid. selector=%s", selector) + continue + if selector[0] != CONVERSATION_VARIABLE_NODE_ID: + continue + variable = self.graph_runtime_state.variable_pool.get(selector) + if not isinstance(variable, VariableBase): + logger.warning( + "Conversation variable not found in variable pool. selector=%s", + selector, + ) + continue + self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable) + updated_any = True + + if updated_any: + self._conversation_variable_updater.flush() + + def on_graph_end(self, error: Exception | None) -> None: + pass diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index e1b5352c2a..1c267091a4 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -75,6 +75,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer): """ if isinstance(session_factory, Engine): session_factory = sessionmaker(session_factory) + super().__init__() self._session_maker = session_factory self._state_owner_user_id = state_owner_user_id self._generate_entity = generate_entity @@ -107,8 +108,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer): if not isinstance(event, GraphRunPausedEvent): return - assert self.graph_runtime_state is not None - entity_wrapper: _GenerateEntityUnion if isinstance(self._generate_entity, WorkflowAppGenerateEntity): entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity) diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index fe1a46a945..a7ea9ef446 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ 
b/api/core/app/layers/trigger_post_layer.py @@ -3,8 +3,8 @@ from datetime import UTC, datetime from typing import Any, ClassVar from pydantic import TypeAdapter -from sqlalchemy.orm import Session, sessionmaker +from core.db.session_factory import session_factory from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events.base import GraphEngineEvent from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent @@ -31,12 +31,11 @@ class TriggerPostLayer(GraphEngineLayer): cfs_plan_scheduler_entity: AsyncWorkflowCFSPlanEntity, start_time: datetime, trigger_log_id: str, - session_maker: sessionmaker[Session], ): + super().__init__() self.trigger_log_id = trigger_log_id self.start_time = start_time self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity - self.session_maker = session_maker def on_graph_start(self): pass @@ -46,7 +45,7 @@ class TriggerPostLayer(GraphEngineLayer): Update trigger log with success or failure. 
""" if isinstance(event, tuple(self._STATUS_MAP.keys())): - with self.session_maker() as session: + with session_factory.create_session() as session: repo = SQLAlchemyWorkflowTriggerLogRepository(session) trigger_log = repo.get_by_id(self.trigger_log_id) if not trigger_log: @@ -57,10 +56,6 @@ class TriggerPostLayer(GraphEngineLayer): elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds() # Extract relevant data from result - if not self.graph_runtime_state: - logger.exception("Graph runtime state is not set") - return - outputs = self.graph_runtime_state.outputs # BASICLY, workflow_execution_id is the same as workflow_run_id diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 5bb93fa44a..6c997753fa 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -39,6 +39,7 @@ from core.app.entities.task_entities import ( MessageAudioEndStreamResponse, MessageAudioStreamResponse, MessageEndStreamResponse, + StreamEvent, StreamResponse, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline @@ -70,6 +71,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): _task_state: EasyUITaskState _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity] + _precomputed_event_type: StreamEvent | None = None def __init__( self, @@ -342,11 +344,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): - event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id) + # Determine the event type once, on first LLM chunk, and reuse for subsequent chunks + if not hasattr(self, "_precomputed_event_type") or 
self._precomputed_event_type is None: + self._precomputed_event_type = self._message_cycle_manager.get_message_event_type( + message_id=self._message_id + ) yield self._message_cycle_manager.message_to_stream_response( answer=cast(str, delta_text), message_id=self._message_id, - event_type=event_type, + event_type=self._precomputed_event_type, ) else: yield self._agent_message_to_stream_response( diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index c6e89a7663..d682083f34 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -5,7 +5,7 @@ from threading import Thread from typing import Union from flask import Flask, current_app -from sqlalchemy import exists, select +from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config @@ -30,6 +30,7 @@ from core.app.entities.task_entities import ( StreamEvent, WorkflowTaskState, ) +from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator from core.tools.signature import sign_tool_file from extensions.ext_database import db @@ -57,13 +58,15 @@ class MessageCycleManager: self._message_has_file: set[str] = set() def get_message_event_type(self, message_id: str) -> StreamEvent: + # Fast path: cached determination from prior QueueMessageFileEvent if message_id in self._message_has_file: return StreamEvent.MESSAGE_FILE - with Session(db.engine, expire_on_commit=False) as session: - has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar() + # Use SQLAlchemy 2.x style session.scalar(select(...)) + with session_factory.create_session() as session: + message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id)) - if has_file: + if message_file: self._message_has_file.add(message_id) return StreamEvent.MESSAGE_FILE @@ -201,6 +204,8 @@ class 
MessageCycleManager: message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id)) if message_file and message_file.url is not None: + self._message_has_file.add(message_file.message_id) + # get tool file id tool_file_id = message_file.url.split("/")[-1] # trim extension diff --git a/api/core/app/workflow/__init__.py b/api/core/app/workflow/__init__.py new file mode 100644 index 0000000000..172ee5d703 --- /dev/null +++ b/api/core/app/workflow/__init__.py @@ -0,0 +1,3 @@ +from .node_factory import DifyNodeFactory + +__all__ = ["DifyNodeFactory"] diff --git a/api/core/app/workflow/layers/__init__.py b/api/core/app/workflow/layers/__init__.py new file mode 100644 index 0000000000..945f75303c --- /dev/null +++ b/api/core/app/workflow/layers/__init__.py @@ -0,0 +1,10 @@ +"""Workflow-level GraphEngine layers that depend on outer infrastructure.""" + +from .observability import ObservabilityLayer +from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer + +__all__ = [ + "ObservabilityLayer", + "PersistenceWorkflowInfo", + "WorkflowPersistenceLayer", +] diff --git a/api/core/workflow/graph_engine/layers/observability.py b/api/core/app/workflow/layers/observability.py similarity index 91% rename from api/core/workflow/graph_engine/layers/observability.py rename to api/core/app/workflow/layers/observability.py index a674816884..94839c8ae3 100644 --- a/api/core/workflow/graph_engine/layers/observability.py +++ b/api/core/app/workflow/layers/observability.py @@ -18,12 +18,15 @@ from typing_extensions import override from configs import dify_config from core.workflow.enums import NodeType from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_engine.layers.node_parsers import ( +from core.workflow.graph_events import GraphNodeEventBase +from core.workflow.nodes.base.node import Node +from extensions.otel.parser import ( DefaultNodeOTelParser, + LLMNodeOTelParser, NodeOTelParser, + 
RetrievalNodeOTelParser, ToolNodeOTelParser, ) -from core.workflow.nodes.base.node import Node from extensions.otel.runtime import is_instrument_flag_enabled logger = logging.getLogger(__name__) @@ -72,6 +75,8 @@ class ObservabilityLayer(GraphEngineLayer): """Initialize parser registry for node types.""" self._parsers = { NodeType.TOOL: ToolNodeOTelParser(), + NodeType.LLM: LLMNodeOTelParser(), + NodeType.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(), } def _get_parser(self, node: Node) -> NodeOTelParser: @@ -119,7 +124,9 @@ class ObservabilityLayer(GraphEngineLayer): logger.warning("Failed to create OpenTelemetry span for node %s: %s", node.id, e) @override - def on_node_run_end(self, node: Node, error: Exception | None) -> None: + def on_node_run_end( + self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None + ) -> None: """ Called when a node finishes execution. @@ -139,7 +146,7 @@ class ObservabilityLayer(GraphEngineLayer): span = node_context.span parser = self._get_parser(node) try: - parser.parse(node=node, span=span, error=error) + parser.parse(node=node, span=span, error=error, result_event=result_event) span.end() finally: token = node_context.token diff --git a/api/core/workflow/graph_engine/layers/persistence.py b/api/core/app/workflow/layers/persistence.py similarity index 99% rename from api/core/workflow/graph_engine/layers/persistence.py rename to api/core/app/workflow/layers/persistence.py index b70f36ec9e..41052b4f52 100644 --- a/api/core/workflow/graph_engine/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -45,7 +45,6 @@ from core.workflow.graph_events import ( from core.workflow.node_events import NodeRunResult from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.workflow_entry import WorkflowEntry from 
libs.datetime_utils import naive_utc_now @@ -316,6 +315,9 @@ class WorkflowPersistenceLayer(GraphEngineLayer): # workflow inputs stay reusable without binding future runs to this conversation. continue inputs[f"sys.{field_name}"] = value + # Local import to avoid circular dependency during app bootstrapping. + from core.workflow.workflow_entry import WorkflowEntry + handled = WorkflowEntry.handle_special_values(inputs) return handled or {} @@ -337,8 +339,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer): if update_finished: execution.finished_at = naive_utc_now() runtime_state = self.graph_runtime_state - if runtime_state is None: - return execution.total_tokens = runtime_state.total_tokens execution.total_steps = runtime_state.node_run_steps execution.outputs = execution.outputs or runtime_state.outputs @@ -404,6 +404,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer): def _system_variables(self) -> Mapping[str, Any]: runtime_state = self.graph_runtime_state - if runtime_state is None: - return {} return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) diff --git a/api/core/app/workflow/node_factory.py b/api/core/app/workflow/node_factory.py new file mode 100644 index 0000000000..e0a0059a38 --- /dev/null +++ b/api/core/app/workflow/node_factory.py @@ -0,0 +1,152 @@ +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, final + +from typing_extensions import override + +from configs import dify_config +from core.file import file_manager +from core.helper import ssrf_proxy +from core.helper.code_executor.code_executor import CodeExecutor +from core.helper.code_executor.code_node_provider import CodeNodeProvider +from core.tools.tool_file_manager import ToolFileManager +from core.workflow.enums import NodeType +from core.workflow.graph import NodeFactory +from core.workflow.nodes.base.node import Node +from core.workflow.nodes.code.code_node import CodeNode +from core.workflow.nodes.code.limits import 
CodeNodeLimits +from core.workflow.nodes.http_request.node import HttpRequestNode +from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING +from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol +from core.workflow.nodes.template_transform.template_renderer import ( + CodeExecutorJinja2TemplateRenderer, + Jinja2TemplateRenderer, +) +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from libs.typing import is_str, is_str_dict + +if TYPE_CHECKING: + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState + + +@final +class DifyNodeFactory(NodeFactory): + """ + Default implementation of NodeFactory that uses the traditional node mapping. + + This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING + and instantiating the appropriate node class. + """ + + def __init__( + self, + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + *, + code_executor: type[CodeExecutor] | None = None, + code_providers: Sequence[type[CodeNodeProvider]] | None = None, + code_limits: CodeNodeLimits | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, + http_request_http_client: HttpClientProtocol = ssrf_proxy, + http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager, + http_request_file_manager: FileManagerProtocol = file_manager, + ) -> None: + self.graph_init_params = graph_init_params + self.graph_runtime_state = graph_runtime_state + self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor + self._code_providers: tuple[type[CodeNodeProvider], ...] 
= ( + tuple(code_providers) if code_providers else CodeNode.default_code_providers() + ) + self._code_limits = code_limits or CodeNodeLimits( + max_string_length=dify_config.CODE_MAX_STRING_LENGTH, + max_number=dify_config.CODE_MAX_NUMBER, + min_number=dify_config.CODE_MIN_NUMBER, + max_precision=dify_config.CODE_MAX_PRECISION, + max_depth=dify_config.CODE_MAX_DEPTH, + max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, + max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, + max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, + ) + self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer() + self._http_request_http_client = http_request_http_client + self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory + self._http_request_file_manager = http_request_file_manager + + @override + def create_node(self, node_config: dict[str, object]) -> Node: + """ + Create a Node instance from node configuration data using the traditional mapping. 
+ + :param node_config: node configuration dictionary containing type and other data + :return: initialized Node instance + :raises ValueError: if node type is unknown or configuration is invalid + """ + # Get node_id from config + node_id = node_config.get("id") + if not is_str(node_id): + raise ValueError("Node config missing id") + + # Get node type from config + node_data = node_config.get("data", {}) + if not is_str_dict(node_data): + raise ValueError(f"Node {node_id} missing data information") + + node_type_str = node_data.get("type") + if not is_str(node_type_str): + raise ValueError(f"Node {node_id} missing or invalid type information") + + try: + node_type = NodeType(node_type_str) + except ValueError: + raise ValueError(f"Unknown node type: {node_type_str}") + + # Get node class + node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type) + if not node_mapping: + raise ValueError(f"No class mapping found for node type: {node_type}") + + latest_node_class = node_mapping.get(LATEST_VERSION) + node_version = str(node_data.get("version", "1")) + matched_node_class = node_mapping.get(node_version) + node_class = matched_node_class or latest_node_class + if not node_class: + raise ValueError(f"No latest version class found for node type: {node_type}") + + # Create node instance + if node_type == NodeType.CODE: + return CodeNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + code_executor=self._code_executor, + code_providers=self._code_providers, + code_limits=self._code_limits, + ) + + if node_type == NodeType.TEMPLATE_TRANSFORM: + return TemplateTransformNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + template_renderer=self._template_renderer, + ) + + if node_type == NodeType.HTTP_REQUEST: + return HttpRequestNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, 
+ graph_runtime_state=self.graph_runtime_state, + http_client=self._http_request_http_client, + tool_file_manager_factory=self._http_request_tool_file_manager_factory, + file_manager=self._http_request_file_manager, + ) + + return node_class( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + ) diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py index 50c7249fe4..451e4fda0e 100644 --- a/api/core/datasource/__base/datasource_plugin.py +++ b/api/core/datasource/__base/datasource_plugin.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from configs import dify_config @@ -30,7 +32,7 @@ class DatasourcePlugin(ABC): """ return DatasourceProviderType.LOCAL_FILE - def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": + def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> DatasourcePlugin: return self.__class__( entity=self.entity.model_copy(), runtime=runtime, diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 260dcf04f5..dde7d59726 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import enum from enum import StrEnum from typing import Any @@ -31,7 +33,7 @@ class DatasourceProviderType(enum.StrEnum): ONLINE_DRIVE = "online_drive" @classmethod - def value_of(cls, value: str) -> "DatasourceProviderType": + def value_of(cls, value: str) -> DatasourceProviderType: """ Get value of given mode. 
@@ -81,7 +83,7 @@ class DatasourceParameter(PluginParameter): typ: DatasourceParameterType, required: bool, options: list[str] | None = None, - ) -> "DatasourceParameter": + ) -> DatasourceParameter: """ get a simple datasource parameter @@ -187,14 +189,14 @@ class DatasourceInvokeMeta(BaseModel): tool_config: dict | None = None @classmethod - def empty(cls) -> "DatasourceInvokeMeta": + def empty(cls) -> DatasourceInvokeMeta: """ Get an empty instance of DatasourceInvokeMeta """ return cls(time_cost=0.0, error=None, tool_config={}) @classmethod - def error_instance(cls, error: str) -> "DatasourceInvokeMeta": + def error_instance(cls, error: str) -> DatasourceInvokeMeta: """ Get an instance of DatasourceInvokeMeta with error """ diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index 98ea15e3fc..ce23da1e09 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -1,4 +1,4 @@ -from collections.abc import Generator, Mapping +from collections.abc import Generator from typing import Any from core.datasource.__base.datasource_plugin import DatasourcePlugin @@ -34,7 +34,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): def get_online_document_pages( self, user_id: str, - datasource_parameters: Mapping[str, Any], + datasource_parameters: dict[str, Any], provider_type: str, ) -> Generator[OnlineDocumentPagesMessage, None, None]: manager = PluginDatasourceManager() diff --git a/api/core/db/session_factory.py b/api/core/db/session_factory.py index 1dae2eafd4..45d4bc4594 100644 --- a/api/core/db/session_factory.py +++ b/api/core/db/session_factory.py @@ -1,7 +1,7 @@ from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker -_session_maker: sessionmaker | None = None +_session_maker: sessionmaker[Session] | None = None def configure_session_factory(engine: Engine, 
expire_on_commit: bool = False): @@ -10,7 +10,7 @@ def configure_session_factory(engine: Engine, expire_on_commit: bool = False): _session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit) -def get_session_maker() -> sessionmaker: +def get_session_maker() -> sessionmaker[Session]: if _session_maker is None: raise RuntimeError("Session factory not configured. Call configure_session_factory() first.") return _session_maker @@ -27,7 +27,7 @@ class SessionFactory: configure_session_factory(engine, expire_on_commit) @staticmethod - def get_session_maker() -> sessionmaker: + def get_session_maker() -> sessionmaker[Session]: return get_session_maker() @staticmethod diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index 7fdf5e4be6..135d2a4945 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from datetime import datetime from enum import StrEnum @@ -75,7 +77,7 @@ class MCPProviderEntity(BaseModel): updated_at: datetime @classmethod - def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity": + def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity: """Create entity from database model with decryption""" return cls( diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 12431976f0..a123fb0321 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -30,7 +30,6 @@ class SimpleModelProviderEntity(BaseModel): label: I18nObject icon_small: I18nObject | None = None icon_small_dark: I18nObject | None = None - icon_large: I18nObject | None = None supported_model_types: list[ModelType] def __init__(self, provider_entity: ProviderEntity): @@ -44,7 +43,6 @@ class SimpleModelProviderEntity(BaseModel): label=provider_entity.label, icon_small=provider_entity.icon_small, icon_small_dark=provider_entity.icon_small_dark, - 
icon_large=provider_entity.icon_large, supported_model_types=provider_entity.supported_model_types, ) @@ -94,7 +92,6 @@ class DefaultModelProviderEntity(BaseModel): provider: str label: I18nObject icon_small: I18nObject | None = None - icon_large: I18nObject | None = None supported_model_types: Sequence[ModelType] = [] diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 8a8067332d..0078ec7e4f 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import StrEnum, auto from typing import Union @@ -178,7 +180,7 @@ class BasicProviderConfig(BaseModel): TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR @classmethod - def value_of(cls, value: str) -> "ProviderConfig.Type": + def value_of(cls, value: str) -> ProviderConfig.Type: """ Get value of given mode. diff --git a/api/core/file/helpers.py b/api/core/file/helpers.py index 6d553d7dc6..2ac483673a 100644 --- a/api/core/file/helpers.py +++ b/api/core/file/helpers.py @@ -8,8 +8,9 @@ import urllib.parse from configs import dify_config -def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str: - url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview" +def get_signed_file_url(upload_file_id: str, as_attachment=False, for_external: bool = True) -> str: + base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL) + url = f"{base_url}/files/{upload_file_id}/file-preview" timestamp = str(int(time.time())) nonce = os.urandom(16).hex() diff --git a/api/core/file/models.py b/api/core/file/models.py index d149205d77..6324523b22 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -112,17 +112,17 @@ class File(BaseModel): return text - def generate_url(self) -> str | None: + def generate_url(self, for_external: bool = True) -> str | None: if self.transfer_method == 
FileTransferMethod.REMOTE_URL: return self.remote_url elif self.transfer_method == FileTransferMethod.LOCAL_FILE: if self.related_id is None: raise ValueError("Missing file related_id") - return helpers.get_signed_file_url(upload_file_id=self.related_id) + return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external) elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]: assert self.related_id is not None assert self.extension is not None - return sign_tool_file(tool_file_id=self.related_id, extension=self.extension) + return sign_tool_file(tool_file_id=self.related_id, extension=self.extension, for_external=for_external) return None def to_plugin_parameter(self) -> dict[str, Any]: @@ -133,7 +133,7 @@ class File(BaseModel): "extension": self.extension, "size": self.size, "type": self.type, - "url": self.generate_url(), + "url": self.generate_url(for_external=False), } @model_validator(mode="after") diff --git a/api/core/helper/code_executor/jinja2/jinja2_transformer.py b/api/core/helper/code_executor/jinja2/jinja2_transformer.py index 969125d2f7..5e4807401e 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_transformer.py +++ b/api/core/helper/code_executor/jinja2/jinja2_transformer.py @@ -1,9 +1,14 @@ +from collections.abc import Mapping from textwrap import dedent +from typing import Any from core.helper.code_executor.template_transformer import TemplateTransformer class Jinja2TemplateTransformer(TemplateTransformer): + # Use separate placeholder for base64-encoded template to avoid confusion + _template_b64_placeholder: str = "{{template_b64}}" + @classmethod def transform_response(cls, response: str): """ @@ -13,18 +18,35 @@ class Jinja2TemplateTransformer(TemplateTransformer): """ return {"result": cls.extract_result_str_from_response(response)} + @classmethod + def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str: + """ + Override base class to use base64 
encoding for template code. + This prevents issues with special characters (quotes, newlines) in templates + breaking the generated Python script. Fixes #26818. + """ + script = cls.get_runner_script() + # Encode template as base64 to safely embed any content including quotes + code_b64 = cls.serialize_code(code) + script = script.replace(cls._template_b64_placeholder, code_b64) + inputs_str = cls.serialize_inputs(inputs) + script = script.replace(cls._inputs_placeholder, inputs_str) + return script + @classmethod def get_runner_script(cls) -> str: runner_script = dedent(f""" - # declare main function - def main(**inputs): - import jinja2 - template = jinja2.Template('''{cls._code_placeholder}''') - return template.render(**inputs) - + import jinja2 import json from base64 import b64decode + # declare main function + def main(**inputs): + # Decode base64-encoded template to handle special characters safely + template_code = b64decode('{cls._template_b64_placeholder}').decode('utf-8') + template = jinja2.Template(template_code) + return template.render(**inputs) + # decode and prepare input dict inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8')) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index 3965f8cb31..5cdea19a8d 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -13,6 +13,15 @@ class TemplateTransformer(ABC): _inputs_placeholder: str = "{{inputs}}" _result_tag: str = "<>" + @classmethod + def serialize_code(cls, code: str) -> str: + """ + Serialize template code to base64 to safely embed in generated script. + This prevents issues with special characters like quotes breaking the script. 
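The base64 round-trip that `assemble_runner_script` performs can be sketched outside the diff. This is a minimal stdlib-only sketch: `string.Template` stands in for jinja2 and the runner layout is illustrative, but the core idea is the same — base64 makes the embedded template opaque to the Python parser, so quotes and newlines in user templates can no longer break the generated script (the failure mode behind #26818):

```python
import json
from base64 import b64decode, b64encode
from string import Template  # stand-in for jinja2 so the sketch stays stdlib-only

# Illustrative runner template; $template_b64 / $inputs_b64 are filled in below.
RUNNER = """
from base64 import b64decode
from string import Template
import json

template_code = b64decode('$template_b64').decode('utf-8')
inputs = json.loads(b64decode('$inputs_b64').decode('utf-8'))
print(Template(template_code).substitute(**inputs))
"""


def assemble_runner_script(code: str, inputs: dict) -> str:
    # Base64 output contains no quotes or newlines, so embedding it inside
    # the runner's string literal is always safe, whatever `code` contains.
    code_b64 = b64encode(code.encode("utf-8")).decode("utf-8")
    inputs_b64 = b64encode(json.dumps(inputs).encode("utf-8")).decode("utf-8")
    return Template(RUNNER).substitute(template_b64=code_b64, inputs_b64=inputs_b64)
```

A template containing triple quotes — which would have terminated the old `'''…'''` embedding — now round-trips cleanly.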
+ """ + code_bytes = code.encode("utf-8") + return b64encode(code_bytes).decode("utf-8") + @classmethod def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, str]: """ @@ -67,7 +76,7 @@ class TemplateTransformer(ABC): Post-process the result to convert scientific notation strings back to numbers """ - def convert_scientific_notation(value): + def convert_scientific_notation(value: Any) -> Any: if isinstance(value, str): # Check if the string looks like scientific notation if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE): @@ -81,7 +90,7 @@ class TemplateTransformer(ABC): return [convert_scientific_notation(v) for v in value] return value - return convert_scientific_notation(result) # type: ignore[no-any-return] + return convert_scientific_notation(result) @classmethod @abstractmethod diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 0b36969cf9..128c64ff2c 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -33,6 +33,10 @@ class MaxRetriesExceededError(ValueError): pass +request_error = httpx.RequestError +max_retries_exceeded_error = MaxRetriesExceededError + + def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]: return { "http://": httpx.HTTPTransport( @@ -88,7 +92,41 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None: return None +def _inject_trace_headers(headers: dict | None) -> dict: + """ + Inject W3C traceparent header for distributed tracing. + + When OTEL is enabled, HTTPXClientInstrumentor handles trace propagation automatically. + When OTEL is disabled, we manually inject the traceparent header. 
+ """ + if headers is None: + headers = {} + + # Skip if already present (case-insensitive check) + for key in headers: + if key.lower() == "traceparent": + return headers + + # Skip if OTEL is enabled - HTTPXClientInstrumentor handles this automatically + if dify_config.ENABLE_OTEL: + return headers + + # Generate and inject traceparent for non-OTEL scenarios + try: + from core.helper.trace_id_helper import generate_traceparent_header + + traceparent = generate_traceparent_header() + if traceparent: + headers["traceparent"] = traceparent + except Exception: + # Silently ignore errors to avoid breaking requests + logger.debug("Failed to generate traceparent header", exc_info=True) + + return headers + + def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + # Convert requests-style allow_redirects to httpx-style follow_redirects if "allow_redirects" in kwargs: allow_redirects = kwargs.pop("allow_redirects") if "follow_redirects" not in kwargs: @@ -106,18 +144,21 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY) client = _get_ssrf_client(verify_option) + # Inject traceparent header for distributed tracing (when OTEL is not enabled) + headers = kwargs.get("headers") or {} + headers = _inject_trace_headers(headers) + kwargs["headers"] = headers + # Preserve user-provided Host header # When using a forward proxy, httpx may override the Host header based on the URL. # We extract and preserve any explicitly set Host header to support virtual hosting. 
- headers = kwargs.get("headers", {}) user_provided_host = _get_user_provided_host_header(headers) retries = 0 while retries <= max_retries: try: - # Build the request manually to preserve the Host header - # httpx may override the Host header when using a proxy, so we use - # the request API to explicitly set headers before sending + # Preserve the user-provided Host header + # httpx may override the Host header when using a proxy headers = {k: v for k, v in headers.items() if k.lower() != "host"} if user_provided_host is not None: headers["host"] = user_provided_host diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py deleted file mode 100644 index c5447c2b3f..0000000000 --- a/api/core/helper/tool_provider_cache.py +++ /dev/null @@ -1,58 +0,0 @@ -import json -import logging -from typing import Any, cast - -from core.tools.entities.api_entities import ToolProviderTypeApiLiteral -from extensions.ext_redis import redis_client, redis_fallback - -logger = logging.getLogger(__name__) - - -class ToolProviderListCache: - """Cache for tool provider lists""" - - CACHE_TTL = 300 # 5 minutes - - @staticmethod - def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str: - """Generate cache key for tool providers list""" - type_filter = typ or "all" - return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}" - - @staticmethod - @redis_fallback(default_return=None) - def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None: - """Get cached tool providers""" - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) - cached_data = redis_client.get(cache_key) - if cached_data: - try: - return json.loads(cached_data.decode("utf-8")) - except (json.JSONDecodeError, UnicodeDecodeError): - logger.warning("Failed to decode cached tool providers data") - return None - return None - - @staticmethod - @redis_fallback() - def 
set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]): - """Cache tool providers""" - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) - redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers)) - - @staticmethod - @redis_fallback() - def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None): - """Invalidate cache for tool providers""" - if typ: - # Invalidate specific type cache - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) - redis_client.delete(cache_key) - else: - # Invalidate all caches for this tenant - keys = ["builtin", "model", "api", "workflow", "mcp"] - pipeline = redis_client.pipeline() - for key in keys: - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, cast(ToolProviderTypeApiLiteral, key)) - pipeline.delete(cache_key) - pipeline.execute() diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py index 820502e558..e827859109 100644 --- a/api/core/helper/trace_id_helper.py +++ b/api/core/helper/trace_id_helper.py @@ -103,3 +103,60 @@ def parse_traceparent_header(traceparent: str) -> str | None: if len(parts) == 4 and len(parts[1]) == 32: return parts[1] return None + + +def get_span_id_from_otel_context() -> str | None: + """ + Retrieve the current span ID from the active OpenTelemetry trace context. + + Returns: + A 16-character hex string representing the span ID, or None if not available. 
+ """ + try: + from opentelemetry.trace import get_current_span + from opentelemetry.trace.span import INVALID_SPAN_ID + + span = get_current_span() + if not span: + return None + + span_context = span.get_span_context() + if not span_context or span_context.span_id == INVALID_SPAN_ID: + return None + + return f"{span_context.span_id:016x}" + except Exception: + return None + + +def generate_traceparent_header() -> str | None: + """ + Generate a W3C traceparent header from the current context. + + Uses OpenTelemetry context if available, otherwise uses the + ContextVar-based trace_id from the logging context. + + Format: {version}-{trace_id}-{span_id}-{flags} + Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01 + + Returns: + A valid traceparent header string, or None if generation fails. + """ + import uuid + + # Try OTEL context first + trace_id = get_trace_id_from_otel_context() + span_id = get_span_id_from_otel_context() + + if trace_id and span_id: + return f"00-{trace_id}-{span_id}-01" + + # Fallback: use ContextVar-based trace_id or generate new one + from core.logging.context import get_trace_id as get_logging_trace_id + + trace_id = get_logging_trace_id() or uuid.uuid4().hex + + # Generate a new span_id (16 hex chars) + span_id = uuid.uuid4().hex[:16] + + return f"00-{trace_id}-{span_id}-01" diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index af860a1070..370e64e385 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -56,6 +56,10 @@ class HostingConfiguration: self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax() self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark() self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai() + self.provider_map[f"{DEFAULT_PLUGIN_ID}/gemini/google"] = self.init_gemini() + self.provider_map[f"{DEFAULT_PLUGIN_ID}/x/x"] = self.init_xai() + 
self.provider_map[f"{DEFAULT_PLUGIN_ID}/deepseek/deepseek"] = self.init_deepseek() + self.provider_map[f"{DEFAULT_PLUGIN_ID}/tongyi/tongyi"] = self.init_tongyi() self.moderation_config = self.init_moderation_config() @@ -128,7 +132,7 @@ class HostingConfiguration: quotas: list[HostingQuota] = [] if dify_config.HOSTED_OPENAI_TRIAL_ENABLED: - hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT + hosted_quota_limit = 0 trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS") trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) quotas.append(trial_quota) @@ -156,18 +160,49 @@ class HostingConfiguration: quota_unit=quota_unit, ) - @staticmethod - def init_anthropic() -> HostingProvider: - quota_unit = QuotaUnit.TOKENS + def init_gemini(self) -> HostingProvider: + quota_unit = QuotaUnit.CREDITS + quotas: list[HostingQuota] = [] + + if dify_config.HOSTED_GEMINI_TRIAL_ENABLED: + hosted_quota_limit = 0 + trial_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_TRIAL_MODELS") + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) + quotas.append(trial_quota) + + if dify_config.HOSTED_GEMINI_PAID_ENABLED: + paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS") + paid_quota = PaidHostingQuota(restrict_models=paid_models) + quotas.append(paid_quota) + + if len(quotas) > 0: + credentials = { + "google_api_key": dify_config.HOSTED_GEMINI_API_KEY, + } + + if dify_config.HOSTED_GEMINI_API_BASE: + credentials["google_base_url"] = dify_config.HOSTED_GEMINI_API_BASE + + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) + + return HostingProvider( + enabled=False, + quota_unit=quota_unit, + ) + + def init_anthropic(self) -> HostingProvider: + quota_unit = QuotaUnit.CREDITS quotas: list[HostingQuota] = [] if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED: - hosted_quota_limit = 
dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
-            trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
+            hosted_quota_limit = 0
+            trial_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_TRIAL_MODELS")
+            trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
             quotas.append(trial_quota)

         if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
-            paid_quota = PaidHostingQuota()
+            paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS")
+            paid_quota = PaidHostingQuota(restrict_models=paid_models)
             quotas.append(paid_quota)

         if len(quotas) > 0:
@@ -185,6 +220,94 @@ class HostingConfiguration:
             quota_unit=quota_unit,
         )

+    def init_tongyi(self) -> HostingProvider:
+        quota_unit = QuotaUnit.CREDITS
+        quotas: list[HostingQuota] = []
+
+        if dify_config.HOSTED_TONGYI_TRIAL_ENABLED:
+            hosted_quota_limit = 0
+            trial_models = self.parse_restrict_models_from_env("HOSTED_TONGYI_TRIAL_MODELS")
+            trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
+            quotas.append(trial_quota)
+
+        if dify_config.HOSTED_TONGYI_PAID_ENABLED:
+            paid_models = self.parse_restrict_models_from_env("HOSTED_TONGYI_PAID_MODELS")
+            paid_quota = PaidHostingQuota(restrict_models=paid_models)
+            quotas.append(paid_quota)
+
+        if len(quotas) > 0:
+            credentials = {
+                "dashscope_api_key": dify_config.HOSTED_TONGYI_API_KEY,
+                "use_international_endpoint": dify_config.HOSTED_TONGYI_USE_INTERNATIONAL_ENDPOINT,
+            }
+
+            return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    def init_xai(self) -> HostingProvider:
+        quota_unit = QuotaUnit.CREDITS
+        quotas: list[HostingQuota] = []
+
+        if dify_config.HOSTED_XAI_TRIAL_ENABLED:
+            hosted_quota_limit = 0
+            trial_models = self.parse_restrict_models_from_env("HOSTED_XAI_TRIAL_MODELS")
+            trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
+            quotas.append(trial_quota)
+
+        if dify_config.HOSTED_XAI_PAID_ENABLED:
+            paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS")
+            paid_quota = PaidHostingQuota(restrict_models=paid_models)
+            quotas.append(paid_quota)
+
+        if len(quotas) > 0:
+            credentials = {
+                "api_key": dify_config.HOSTED_XAI_API_KEY,
+            }
+
+            if dify_config.HOSTED_XAI_API_BASE:
+                credentials["endpoint_url"] = dify_config.HOSTED_XAI_API_BASE
+
+            return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    def init_deepseek(self) -> HostingProvider:
+        quota_unit = QuotaUnit.CREDITS
+        quotas: list[HostingQuota] = []
+
+        if dify_config.HOSTED_DEEPSEEK_TRIAL_ENABLED:
+            hosted_quota_limit = 0
+            trial_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_TRIAL_MODELS")
+            trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
+            quotas.append(trial_quota)
+
+        if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED:
+            paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS")
+            paid_quota = PaidHostingQuota(restrict_models=paid_models)
+            quotas.append(paid_quota)
+
+        if len(quotas) > 0:
+            credentials = {
+                "api_key": dify_config.HOSTED_DEEPSEEK_API_KEY,
+            }
+
+            if dify_config.HOSTED_DEEPSEEK_API_BASE:
+                credentials["endpoint_url"] = dify_config.HOSTED_DEEPSEEK_API_BASE
+
+            return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
     @staticmethod
     def init_minimax() -> HostingProvider:
         quota_unit = QuotaUnit.TOKENS
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index b4c3ec1caf..be1e306d47 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -71,8 +71,8 @@ class LLMGenerator:
response: LLMResult = model_instance.invoke_llm( prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False ) - answer = cast(str, response.message.content) - if answer is None: + answer = response.message.get_text_content() + if answer == "": return "" try: result_dict = json.loads(answer) @@ -184,7 +184,7 @@ class LLMGenerator: prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False ) - rule_config["prompt"] = cast(str, response.message.content) + rule_config["prompt"] = response.message.get_text_content() except InvokeError as e: error = str(e) @@ -237,13 +237,11 @@ class LLMGenerator: return rule_config - rule_config["prompt"] = cast(str, prompt_content.message.content) + rule_config["prompt"] = prompt_content.message.get_text_content() - if not isinstance(prompt_content.message.content, str): - raise NotImplementedError("prompt content is not a string") parameter_generate_prompt = parameter_template.format( inputs={ - "INPUT_TEXT": prompt_content.message.content, + "INPUT_TEXT": prompt_content.message.get_text_content(), }, remove_template_variables=False, ) @@ -253,7 +251,7 @@ class LLMGenerator: statement_generate_prompt = statement_template.format( inputs={ "TASK_DESCRIPTION": instruction, - "INPUT_TEXT": prompt_content.message.content, + "INPUT_TEXT": prompt_content.message.get_text_content(), }, remove_template_variables=False, ) @@ -263,7 +261,7 @@ class LLMGenerator: parameter_content: LLMResult = model_instance.invoke_llm( prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False ) - rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content)) + rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content()) except InvokeError as e: error = str(e) error_step = "generate variables" @@ -272,7 +270,7 @@ class LLMGenerator: statement_content: LLMResult = 
model_instance.invoke_llm( prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False ) - rule_config["opening_statement"] = cast(str, statement_content.message.content) + rule_config["opening_statement"] = statement_content.message.get_text_content() except InvokeError as e: error = str(e) error_step = "generate conversation opener" @@ -315,7 +313,7 @@ class LLMGenerator: prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False ) - generated_code = cast(str, response.message.content) + generated_code = response.message.get_text_content() return {"code": generated_code, "language": code_language, "error": ""} except InvokeError as e: @@ -351,7 +349,7 @@ class LLMGenerator: raise TypeError("Expected LLMResult when stream=False") response = result - answer = cast(str, response.message.content) + answer = response.message.get_text_content() return answer.strip() @classmethod @@ -375,10 +373,7 @@ class LLMGenerator: prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False ) - raw_content = response.message.content - - if not isinstance(raw_content, str): - raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}") + raw_content = response.message.get_text_content() try: parsed_content = json.loads(raw_content) diff --git a/api/core/logging/__init__.py b/api/core/logging/__init__.py new file mode 100644 index 0000000000..db046cc9fa --- /dev/null +++ b/api/core/logging/__init__.py @@ -0,0 +1,20 @@ +"""Structured logging components for Dify.""" + +from core.logging.context import ( + clear_request_context, + get_request_id, + get_trace_id, + init_request_context, +) +from core.logging.filters import IdentityContextFilter, TraceContextFilter +from core.logging.structured_formatter import StructuredJSONFormatter + +__all__ = [ + "IdentityContextFilter", + "StructuredJSONFormatter", + "TraceContextFilter", + "clear_request_context", + "get_request_id", + 
"get_trace_id", + "init_request_context", +] diff --git a/api/core/logging/context.py b/api/core/logging/context.py new file mode 100644 index 0000000000..18633a0b05 --- /dev/null +++ b/api/core/logging/context.py @@ -0,0 +1,35 @@ +"""Request context for logging - framework agnostic. + +This module provides request-scoped context variables for logging, +using Python's contextvars for thread-safe and async-safe storage. +""" + +import uuid +from contextvars import ContextVar + +_request_id: ContextVar[str] = ContextVar("log_request_id", default="") +_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="") + + +def get_request_id() -> str: + """Get current request ID (10 hex chars).""" + return _request_id.get() + + +def get_trace_id() -> str: + """Get fallback trace ID when OTEL is unavailable (32 hex chars).""" + return _trace_id.get() + + +def init_request_context() -> None: + """Initialize request context. Call at start of each request.""" + req_id = uuid.uuid4().hex[:10] + trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex + _request_id.set(req_id) + _trace_id.set(trace_id) + + +def clear_request_context() -> None: + """Clear request context. Call at end of request (optional).""" + _request_id.set("") + _trace_id.set("") diff --git a/api/core/logging/filters.py b/api/core/logging/filters.py new file mode 100644 index 0000000000..1e8aa8d566 --- /dev/null +++ b/api/core/logging/filters.py @@ -0,0 +1,94 @@ +"""Logging filters for structured logging.""" + +import contextlib +import logging + +import flask + +from core.logging.context import get_request_id, get_trace_id + + +class TraceContextFilter(logging.Filter): + """ + Filter that adds trace_id and span_id to log records. + Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id. 
+ """ + + def filter(self, record: logging.LogRecord) -> bool: + # Get trace context from OpenTelemetry + trace_id, span_id = self._get_otel_context() + + # Set trace_id (fallback to ContextVar if no OTEL context) + if trace_id: + record.trace_id = trace_id + else: + record.trace_id = get_trace_id() + + record.span_id = span_id or "" + + # For backward compatibility, also set req_id + record.req_id = get_request_id() + + return True + + def _get_otel_context(self) -> tuple[str, str]: + """Extract trace_id and span_id from OpenTelemetry context.""" + with contextlib.suppress(Exception): + from opentelemetry.trace import get_current_span + from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID + + span = get_current_span() + if span and span.get_span_context(): + ctx = span.get_span_context() + if ctx.is_valid and ctx.trace_id != INVALID_TRACE_ID: + trace_id = f"{ctx.trace_id:032x}" + span_id = f"{ctx.span_id:016x}" if ctx.span_id != INVALID_SPAN_ID else "" + return trace_id, span_id + return "", "" + + +class IdentityContextFilter(logging.Filter): + """ + Filter that adds user identity context to log records. + Extracts tenant_id, user_id, and user_type from Flask-Login current_user. 
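The filter mechanism that `TraceContextFilter` relies on is plain stdlib `logging`: a `Filter` may mutate the record, and any attribute it sets becomes available to formatters as `%(attr)s`. A minimal sketch with a hard-coded id standing in for the OTEL/ContextVar lookup:

```python
import logging

CURRENT_TRACE_ID = "0af7651916cd43dd8448eb211c80319c"  # stand-in for the real lookup


class TraceFilter(logging.Filter):
    """Attach trace context to every record passing through the logger."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Attributes set here become usable as %(trace_id)s in formatters.
        record.trace_id = CURRENT_TRACE_ID
        return True  # True keeps the record; False would drop it


logger = logging.getLogger("trace-demo")
logger.setLevel(logging.INFO)
logger.addFilter(TraceFilter())
```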
+ """ + + def filter(self, record: logging.LogRecord) -> bool: + identity = self._extract_identity() + record.tenant_id = identity.get("tenant_id", "") + record.user_id = identity.get("user_id", "") + record.user_type = identity.get("user_type", "") + return True + + def _extract_identity(self) -> dict[str, str]: + """Extract identity from current_user if in request context.""" + try: + if not flask.has_request_context(): + return {} + from flask_login import current_user + + # Check if user is authenticated using the proxy + if not current_user.is_authenticated: + return {} + + # Access the underlying user object + user = current_user + + from models import Account + from models.model import EndUser + + identity: dict[str, str] = {} + + if isinstance(user, Account): + if user.current_tenant_id: + identity["tenant_id"] = user.current_tenant_id + identity["user_id"] = user.id + identity["user_type"] = "account" + elif isinstance(user, EndUser): + identity["tenant_id"] = user.tenant_id + identity["user_id"] = user.id + identity["user_type"] = user.type or "end_user" + + return identity + except Exception: + return {} diff --git a/api/core/logging/structured_formatter.py b/api/core/logging/structured_formatter.py new file mode 100644 index 0000000000..4295d2dd34 --- /dev/null +++ b/api/core/logging/structured_formatter.py @@ -0,0 +1,107 @@ +"""Structured JSON log formatter for Dify.""" + +import logging +import traceback +from datetime import UTC, datetime +from typing import Any + +import orjson + +from configs import dify_config + + +class StructuredJSONFormatter(logging.Formatter): + """ + JSON log formatter following the specified schema: + { + "ts": "ISO 8601 UTC", + "severity": "INFO|ERROR|WARN|DEBUG", + "service": "service name", + "caller": "file:line", + "trace_id": "hex 32", + "span_id": "hex 16", + "identity": { "tenant_id", "user_id", "user_type" }, + "message": "log message", + "attributes": { ... }, + "stack_trace": "..." 
+ } + """ + + SEVERITY_MAP: dict[int, str] = { + logging.DEBUG: "DEBUG", + logging.INFO: "INFO", + logging.WARNING: "WARN", + logging.ERROR: "ERROR", + logging.CRITICAL: "ERROR", + } + + def __init__(self, service_name: str | None = None): + super().__init__() + self._service_name = service_name or dify_config.APPLICATION_NAME + + def format(self, record: logging.LogRecord) -> str: + log_dict = self._build_log_dict(record) + try: + return orjson.dumps(log_dict).decode("utf-8") + except TypeError: + # Fallback: convert non-serializable objects to string + import json + + return json.dumps(log_dict, default=str, ensure_ascii=False) + + def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]: + # Core fields + log_dict: dict[str, Any] = { + "ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"), + "severity": self.SEVERITY_MAP.get(record.levelno, "INFO"), + "service": self._service_name, + "caller": f"{record.filename}:{record.lineno}", + "message": record.getMessage(), + } + + # Trace context (from TraceContextFilter) + trace_id = getattr(record, "trace_id", "") + span_id = getattr(record, "span_id", "") + + if trace_id: + log_dict["trace_id"] = trace_id + if span_id: + log_dict["span_id"] = span_id + + # Identity context (from IdentityContextFilter) + identity = self._extract_identity(record) + if identity: + log_dict["identity"] = identity + + # Dynamic attributes + attributes = getattr(record, "attributes", None) + if attributes: + log_dict["attributes"] = attributes + + # Stack trace for errors with exceptions + if record.exc_info and record.levelno >= logging.ERROR: + log_dict["stack_trace"] = self._format_exception(record.exc_info) + + return log_dict + + def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None: + tenant_id = getattr(record, "tenant_id", None) + user_id = getattr(record, "user_id", None) + user_type = getattr(record, "user_type", None) + + if not any([tenant_id, user_id, 
user_type]): + return None + + identity: dict[str, str] = {} + if tenant_id: + identity["tenant_id"] = tenant_id + if user_id: + identity["user_id"] = user_id + if user_type: + identity["user_type"] = user_type + return identity + + def _format_exception(self, exc_info: tuple[Any, ...]) -> str: + if exc_info and exc_info[0] is not None: + return "".join(traceback.format_exception(*exc_info)) + return "" diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py index f81e7cead8..5c3cd0d8f8 100644 --- a/api/core/mcp/client/streamable_client.py +++ b/api/core/mcp/client/streamable_client.py @@ -313,17 +313,20 @@ class StreamableHTTPTransport: if is_initialization: self._maybe_extract_session_id_from_response(response) - content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower()) + # Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications: + # The server MUST NOT send a response to notifications. + if isinstance(message.root, JSONRPCRequest): + content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower()) - if content_type.startswith(JSON): - self._handle_json_response(response, ctx.server_to_client_queue) - elif content_type.startswith(SSE): - self._handle_sse_response(response, ctx) - else: - self._handle_unexpected_content_type( - content_type, - ctx.server_to_client_queue, - ) + if content_type.startswith(JSON): + self._handle_json_response(response, ctx.server_to_client_queue) + elif content_type.startswith(SSE): + self._handle_sse_response(response, ctx) + else: + self._handle_unexpected_content_type( + content_type, + ctx.server_to_client_queue, + ) def _handle_json_response( self, diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index c97ae6eac7..84a6fd0d1f 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -68,13 +68,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): 
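The streamable-HTTP change above hinges on telling JSON-RPC requests apart from notifications: only messages carrying an `id` expect a response, and per the MCP spec the server MUST NOT respond to notifications, so the client skips content-type and body handling for them. A minimal sketch of that classification (function name is illustrative):

```python
import json


def expects_response(raw: str) -> bool:
    # A client-to-server message with both "method" and "id" is a request
    # and will receive a response; "method" without "id" is a notification,
    # for which no response body should be parsed at all.
    msg = json.loads(raw)
    return "method" in msg and "id" in msg
```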
request_id: RequestId, request_meta: RequestParams.Meta | None, request: ReceiveRequestT, - session: """BaseSession[ - SendRequestT, - SendNotificationT, - SendResultT, - ReceiveRequestT, - ReceiveNotificationT - ]""", + session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""", on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], ): self.request_id = request_id diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 89dae2dbff..9e46d72893 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC from collections.abc import Mapping, Sequence from enum import StrEnum, auto @@ -17,7 +19,7 @@ class PromptMessageRole(StrEnum): TOOL = auto() @classmethod - def value_of(cls, value: str) -> "PromptMessageRole": + def value_of(cls, value: str) -> PromptMessageRole: """ Get value of given mode. 
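The schema that `StructuredJSONFormatter` emits earlier in this diff can be reduced to a stdlib-only sketch (plain `json` in place of `orjson`, and only the core fields); the severity remapping of `WARNING`→`WARN` and `CRITICAL`→`ERROR` matches the formatter above:

```python
import json
import logging
from datetime import datetime, timezone

SEVERITY = {logging.DEBUG: "DEBUG", logging.INFO: "INFO",
            logging.WARNING: "WARN", logging.ERROR: "ERROR", logging.CRITICAL: "ERROR"}


class JSONFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        out = {
            "ts": datetime.now(timezone.utc).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
            "severity": SEVERITY.get(record.levelno, "INFO"),
            "caller": f"{record.filename}:{record.lineno}",
            "message": record.getMessage(),
        }
        # default=str keeps the formatter from raising on odd attribute types.
        return json.dumps(out, default=str, ensure_ascii=False)
```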
@@ -249,10 +251,7 @@ class AssistantPromptMessage(PromptMessage): :return: True if prompt message is empty, False otherwise """ - if not super().is_empty() and not self.tool_calls: - return False - - return True + return super().is_empty() and not self.tool_calls class SystemPromptMessage(PromptMessage): diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index aee6ce1108..19194d162c 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from decimal import Decimal from enum import StrEnum, auto from typing import Any @@ -20,7 +22,7 @@ class ModelType(StrEnum): TTS = auto() @classmethod - def value_of(cls, origin_model_type: str) -> "ModelType": + def value_of(cls, origin_model_type: str) -> ModelType: """ Get model type from origin model type. @@ -103,7 +105,7 @@ class DefaultParameterName(StrEnum): JSON_SCHEMA = auto() @classmethod - def value_of(cls, value: Any) -> "DefaultParameterName": + def value_of(cls, value: Any) -> DefaultParameterName: """ Get parameter name from value. 
diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index 648b209ef1..2d88751668 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -100,7 +100,6 @@ class SimpleProviderEntity(BaseModel): label: I18nObject icon_small: I18nObject | None = None icon_small_dark: I18nObject | None = None - icon_large: I18nObject | None = None supported_model_types: Sequence[ModelType] models: list[AIModelEntity] = [] @@ -123,7 +122,6 @@ class ProviderEntity(BaseModel): label: I18nObject description: I18nObject | None = None icon_small: I18nObject | None = None - icon_large: I18nObject | None = None icon_small_dark: I18nObject | None = None background: str | None = None help: ProviderHelpEntity | None = None @@ -157,7 +155,6 @@ class ProviderEntity(BaseModel): provider=self.provider, label=self.label, icon_small=self.icon_small, - icon_large=self.icon_large, supported_model_types=self.supported_model_types, models=self.models, ) diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index c0f4c504d9..7a0757f219 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -1,7 +1,7 @@ import logging import time import uuid -from collections.abc import Generator, Sequence +from collections.abc import Callable, Generator, Iterator, Sequence from typing import Union from pydantic import ConfigDict @@ -30,6 +30,142 @@ def _gen_tool_call_id() -> str: return f"chatcmpl-tool-{str(uuid.uuid4().hex)}" +def _run_callbacks(callbacks: Sequence[Callback] | None, *, event: str, invoke: Callable[[Callback], None]) -> None: + if not callbacks: + return + + for callback in callbacks: + try: + invoke(callback) + except Exception as e: + if callback.raise_error: + 
raise + logger.warning("Callback %s %s failed with error %s", callback.__class__.__name__, event, e) + + +def _get_or_create_tool_call( + existing_tools_calls: list[AssistantPromptMessage.ToolCall], + tool_call_id: str, +) -> AssistantPromptMessage.ToolCall: + """ + Get or create a tool call by ID. + + If `tool_call_id` is empty, returns the most recently created tool call. + """ + if not tool_call_id: + if not existing_tools_calls: + raise ValueError("tool_call_id is empty but no existing tool call is available to apply the delta") + return existing_tools_calls[-1] + + tool_call = next((tool_call for tool_call in existing_tools_calls if tool_call.id == tool_call_id), None) + if tool_call is None: + tool_call = AssistantPromptMessage.ToolCall( + id=tool_call_id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), + ) + existing_tools_calls.append(tool_call) + + return tool_call + + +def _merge_tool_call_delta( + tool_call: AssistantPromptMessage.ToolCall, + delta: AssistantPromptMessage.ToolCall, +) -> None: + if delta.id: + tool_call.id = delta.id + if delta.type: + tool_call.type = delta.type + if delta.function.name: + tool_call.function.name = delta.function.name + if delta.function.arguments: + tool_call.function.arguments += delta.function.arguments + + +def _build_llm_result_from_first_chunk( + model: str, + prompt_messages: Sequence[PromptMessage], + chunks: Iterator[LLMResultChunk], +) -> LLMResult: + """ + Build a single `LLMResult` from the first returned chunk. + + This is used for `stream=False` because the plugin side may still implement the response via a chunked stream. 
+ """ + content = "" + content_list: list[PromptMessageContentUnionTypes] = [] + usage = LLMUsage.empty_usage() + system_fingerprint: str | None = None + tools_calls: list[AssistantPromptMessage.ToolCall] = [] + + first_chunk = next(chunks, None) + if first_chunk is not None: + if isinstance(first_chunk.delta.message.content, str): + content += first_chunk.delta.message.content + elif isinstance(first_chunk.delta.message.content, list): + content_list.extend(first_chunk.delta.message.content) + + if first_chunk.delta.message.tool_calls: + _increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls) + + usage = first_chunk.delta.usage or LLMUsage.empty_usage() + system_fingerprint = first_chunk.system_fingerprint + + return LLMResult( + model=model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage( + content=content or content_list, + tool_calls=tools_calls, + ), + usage=usage, + system_fingerprint=system_fingerprint, + ) + + +def _invoke_llm_via_plugin( + *, + tenant_id: str, + user_id: str, + plugin_id: str, + provider: str, + model: str, + credentials: dict, + model_parameters: dict, + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, +) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: + from core.plugin.impl.model import PluginModelClient + + plugin_model_manager = PluginModelClient() + return plugin_model_manager.invoke_llm( + tenant_id=tenant_id, + user_id=user_id, + plugin_id=plugin_id, + provider=provider, + model=model, + credentials=credentials, + model_parameters=model_parameters, + prompt_messages=list(prompt_messages), + tools=tools, + stop=list(stop) if stop else None, + stream=stream, + ) + + +def _normalize_non_stream_plugin_result( + model: str, + prompt_messages: Sequence[PromptMessage], + result: Union[LLMResult, Iterator[LLMResultChunk]], +) -> LLMResult: + if isinstance(result, LLMResult): + return result + return 
_build_llm_result_from_first_chunk(model=model, prompt_messages=prompt_messages, chunks=result) + + def _increase_tool_call( new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall] ): @@ -40,42 +176,13 @@ def _increase_tool_call( :param existing_tools_calls: List of existing tool calls to be modified IN-PLACE. """ - def get_tool_call(tool_call_id: str): - """ - Get or create a tool call by ID - - :param tool_call_id: tool call ID - :return: existing or new tool call - """ - if not tool_call_id: - return existing_tools_calls[-1] - - _tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None) - if _tool_call is None: - _tool_call = AssistantPromptMessage.ToolCall( - id=tool_call_id, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), - ) - existing_tools_calls.append(_tool_call) - - return _tool_call - for new_tool_call in new_tool_calls: # generate ID for tool calls with function name but no ID to track them if new_tool_call.function.name and not new_tool_call.id: new_tool_call.id = _gen_tool_call_id() - # get tool call - tool_call = get_tool_call(new_tool_call.id) - # update tool call - if new_tool_call.id: - tool_call.id = new_tool_call.id - if new_tool_call.type: - tool_call.type = new_tool_call.type - if new_tool_call.function.name: - tool_call.function.name = new_tool_call.function.name - if new_tool_call.function.arguments: - tool_call.function.arguments += new_tool_call.function.arguments + + tool_call = _get_or_create_tool_call(existing_tools_calls, new_tool_call.id) + _merge_tool_call_delta(tool_call, new_tool_call) class LargeLanguageModel(AIModel): @@ -141,10 +248,7 @@ class LargeLanguageModel(AIModel): result: Union[LLMResult, Generator[LLMResultChunk, None, None]] try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - result = 
plugin_model_manager.invoke_llm( + result = _invoke_llm_via_plugin( tenant_id=self.tenant_id, user_id=user or "unknown", plugin_id=self.plugin_id, @@ -154,38 +258,13 @@ class LargeLanguageModel(AIModel): model_parameters=model_parameters, prompt_messages=prompt_messages, tools=tools, - stop=list(stop) if stop else None, + stop=stop, stream=stream, ) if not stream: - content = "" - content_list = [] - usage = LLMUsage.empty_usage() - system_fingerprint = None - tools_calls: list[AssistantPromptMessage.ToolCall] = [] - - for chunk in result: - if isinstance(chunk.delta.message.content, str): - content += chunk.delta.message.content - elif isinstance(chunk.delta.message.content, list): - content_list.extend(chunk.delta.message.content) - if chunk.delta.message.tool_calls: - _increase_tool_call(chunk.delta.message.tool_calls, tools_calls) - - usage = chunk.delta.usage or LLMUsage.empty_usage() - system_fingerprint = chunk.system_fingerprint - break - - result = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=content or content_list, - tool_calls=tools_calls, - ), - usage=usage, - system_fingerprint=system_fingerprint, + result = _normalize_non_stream_plugin_result( + model=model, prompt_messages=prompt_messages, result=result ) except Exception as e: self._trigger_invoke_error_callbacks( @@ -425,27 +504,21 @@ class LargeLanguageModel(AIModel): :param user: unique user id :param callbacks: callbacks """ - if callbacks: - for callback in callbacks: - try: - callback.on_before_invoke( - llm_instance=self, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - except Exception as e: - if callback.raise_error: - raise e - else: - logger.warning( - "Callback %s on_before_invoke failed with error %s", callback.__class__.__name__, e - ) + _run_callbacks( + callbacks, + event="on_before_invoke", + 
invoke=lambda callback: callback.on_before_invoke( + llm_instance=self, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), + ) def _trigger_new_chunk_callbacks( self, @@ -473,26 +546,22 @@ class LargeLanguageModel(AIModel): :param stream: is stream response :param user: unique user id """ - if callbacks: - for callback in callbacks: - try: - callback.on_new_chunk( - llm_instance=self, - chunk=chunk, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - except Exception as e: - if callback.raise_error: - raise e - else: - logger.warning("Callback %s on_new_chunk failed with error %s", callback.__class__.__name__, e) + _run_callbacks( + callbacks, + event="on_new_chunk", + invoke=lambda callback: callback.on_new_chunk( + llm_instance=self, + chunk=chunk, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), + ) def _trigger_after_invoke_callbacks( self, @@ -521,28 +590,22 @@ class LargeLanguageModel(AIModel): :param user: unique user id :param callbacks: callbacks """ - if callbacks: - for callback in callbacks: - try: - callback.on_after_invoke( - llm_instance=self, - result=result, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - except Exception as e: - if callback.raise_error: - raise e - else: - logger.warning( - "Callback %s on_after_invoke failed with error %s", callback.__class__.__name__, e - ) + _run_callbacks( + callbacks, + event="on_after_invoke", + invoke=lambda callback: callback.on_after_invoke( + llm_instance=self, + result=result, + model=model, + 
credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), + ) def _trigger_invoke_error_callbacks( self, @@ -571,25 +634,19 @@ class LargeLanguageModel(AIModel): :param user: unique user id :param callbacks: callbacks """ - if callbacks: - for callback in callbacks: - try: - callback.on_invoke_error( - llm_instance=self, - ex=ex, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - except Exception as e: - if callback.raise_error: - raise e - else: - logger.warning( - "Callback %s on_invoke_error failed with error %s", callback.__class__.__name__, e - ) + _run_callbacks( + callbacks, + event="on_invoke_error", + invoke=lambda callback: callback.on_invoke_error( + llm_instance=self, + ex=ex, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), + ) diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index b8704ef4ed..28f162a928 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import hashlib import logging from collections.abc import Sequence @@ -38,7 +40,7 @@ class ModelProviderFactory: plugin_providers = self.get_plugin_model_providers() return [provider.declaration for provider in plugin_providers] - def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]: + def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]: """ Get all plugin model providers :return: list of plugin model providers @@ -76,7 +78,7 @@ class ModelProviderFactory: 
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) return plugin_model_provider_entity.declaration - def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity": + def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity: """ Get plugin model provider :param provider: provider name @@ -285,7 +287,7 @@ class ModelProviderFactory: """ Get provider icon :param provider: provider name - :param icon_type: icon type (icon_small or icon_large) + :param icon_type: icon type (icon_small or icon_small_dark) :param lang: language (zh_Hans or en_US) :return: provider icon """ @@ -309,13 +311,7 @@ class ModelProviderFactory: else: file_name = provider_schema.icon_small_dark.en_US else: - if not provider_schema.icon_large: - raise ValueError(f"Provider {provider} does not have large icon.") - - if lang.lower() == "zh_hans": - file_name = provider_schema.icon_large.zh_Hans - else: - file_name = provider_schema.icon_large.en_US + raise ValueError(f"Unsupported icon type: {icon_type}.") if not file_name: raise ValueError(f"Provider {provider} does not have icon.") diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index d6bd4d2015..22ad756c91 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -1,6 +1,7 @@ import logging from collections.abc import Sequence +from opentelemetry.trace import SpanKind from sqlalchemy.orm import sessionmaker from core.ops.aliyun_trace.data_exporter.traceclient import ( @@ -54,7 +55,7 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities import WorkflowNodeExecution from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db @@ -151,6 +152,7 
@@ class AliyunDataTrace(BaseTraceInstance): ), status=status, links=trace_metadata.links, + span_kind=SpanKind.SERVER, ) self.trace_client.add_span(message_span) @@ -273,7 +275,7 @@ class AliyunDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) session_factory = sessionmaker(bind=db.engine) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, app_id=app_id, @@ -456,6 +458,7 @@ class AliyunDataTrace(BaseTraceInstance): ), status=status, links=trace_metadata.links, + span_kind=SpanKind.SERVER, ) self.trace_client.add_span(message_span) @@ -475,6 +478,7 @@ class AliyunDataTrace(BaseTraceInstance): ), status=status, links=trace_metadata.links, + span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL, ) self.trace_client.add_span(workflow_span) diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py index d3324f8f82..7624586367 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py @@ -166,7 +166,7 @@ class SpanBuilder: attributes=span_data.attributes, events=span_data.events, links=span_data.links, - kind=trace_api.SpanKind.INTERNAL, + kind=span_data.span_kind, status=span_data.status, start_time=span_data.start_time, end_time=span_data.end_time, diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py index 20ff2d0875..9078031490 100644 --- a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py +++ b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py @@ -4,7 +4,7 @@ from typing import Any from opentelemetry import trace as trace_api from opentelemetry.sdk.trace import Event -from 
opentelemetry.trace import Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode from pydantic import BaseModel, Field @@ -34,3 +34,4 @@ class SpanData(BaseModel): status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.") start_time: int | None = Field(..., description="The start time of the span in nanoseconds.") end_time: int | None = Field(..., description="The end time of the span in nanoseconds.") + span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.") diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index c36391c940..549e428f88 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -471,6 +471,9 @@ class TraceTask: if cls._workflow_run_repo is None: with cls._repo_lock: if cls._workflow_run_repo is None: + # Lazy import to avoid circular import during module initialization + from repositories.factory import DifyAPIRepositoryFactory + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) return cls._workflow_run_repo diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 3b83121357..6674228dc0 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import enum from collections.abc import Mapping, Sequence from datetime import datetime @@ -242,7 +244,7 @@ class CredentialType(enum.StrEnum): return [item.value for item in cls] @classmethod - def of(cls, credential_type: str) -> "CredentialType": + def of(cls, credential_type: str) -> CredentialType: type_name = credential_type.lower() if type_name in {"api-key", "api_key"}: return cls.API_KEY diff --git a/api/core/plugin/impl/base.py 
b/api/core/plugin/impl/base.py index 7bb2749afa..7a6a598a2f 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -103,6 +103,9 @@ class BasePluginClient: prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br") + # Inject traceparent header for distributed tracing + self._inject_trace_headers(prepared_headers) + prepared_data: bytes | dict[str, Any] | str | None = ( data if isinstance(data, (bytes, str, dict)) or data is None else None ) @@ -114,6 +117,31 @@ class BasePluginClient: return str(url), prepared_headers, prepared_data, params, files + def _inject_trace_headers(self, headers: dict[str, str]) -> None: + """ + Inject W3C traceparent header for distributed tracing. + + This ensures trace context is propagated to plugin daemon even if + HTTPXClientInstrumentor doesn't cover module-level httpx functions. + """ + if not dify_config.ENABLE_OTEL: + return + + import contextlib + + # Skip if already present (case-insensitive check) + for key in headers: + if key.lower() == "traceparent": + return + + # Inject traceparent - works as fallback when OTEL instrumentation doesn't cover this call + with contextlib.suppress(Exception): + from core.helper.trace_id_helper import generate_traceparent_header + + traceparent = generate_traceparent_header() + if traceparent: + headers["traceparent"] = traceparent + def _stream_request( self, method: str, @@ -292,18 +320,17 @@ class BasePluginClient: case PluginInvokeError.__name__: error_object = json.loads(message) invoke_error_type = error_object.get("error_type") - args = error_object.get("args") match invoke_error_type: case InvokeRateLimitError.__name__: - raise InvokeRateLimitError(description=args.get("description")) + raise InvokeRateLimitError(description=error_object.get("message")) case InvokeAuthorizationError.__name__: - raise InvokeAuthorizationError(description=args.get("description")) + raise 
InvokeAuthorizationError(description=error_object.get("message")) case InvokeBadRequestError.__name__: - raise InvokeBadRequestError(description=args.get("description")) + raise InvokeBadRequestError(description=error_object.get("message")) case InvokeConnectionError.__name__: - raise InvokeConnectionError(description=args.get("description")) + raise InvokeConnectionError(description=error_object.get("message")) case InvokeServerUnavailableError.__name__: - raise InvokeServerUnavailableError(description=args.get("description")) + raise InvokeServerUnavailableError(description=error_object.get("message")) case CredentialsValidateFailedError.__name__: raise CredentialsValidateFailedError(error_object.get("message")) case EndpointSetupFailedError.__name__: @@ -311,11 +338,11 @@ class BasePluginClient: case TriggerProviderCredentialValidationError.__name__: raise TriggerProviderCredentialValidationError(error_object.get("message")) case TriggerPluginInvokeError.__name__: - raise TriggerPluginInvokeError(description=error_object.get("description")) + raise TriggerPluginInvokeError(description=error_object.get("message")) case TriggerInvokeError.__name__: raise TriggerInvokeError(error_object.get("message")) case EventIgnoreError.__name__: - raise EventIgnoreError(description=error_object.get("description")) + raise EventIgnoreError(description=error_object.get("message")) case _: raise PluginInvokeError(description=message) case PluginDaemonInternalServerError.__name__: diff --git a/api/core/plugin/impl/endpoint.py b/api/core/plugin/impl/endpoint.py index 5b88742be5..2db5185a2c 100644 --- a/api/core/plugin/impl/endpoint.py +++ b/api/core/plugin/impl/endpoint.py @@ -1,5 +1,6 @@ from core.plugin.entities.endpoint import EndpointEntityWithInstance from core.plugin.impl.base import BasePluginClient +from core.plugin.impl.exc import PluginDaemonInternalServerError class PluginEndpointClient(BasePluginClient): @@ -70,18 +71,27 @@ class PluginEndpointClient(BasePluginClient): 
def delete_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str): """ Delete the given endpoint. + + This operation is idempotent: if the endpoint is already deleted (record not found), + it will return True instead of raising an error. """ - return self._request_with_plugin_daemon_response( - "POST", - f"plugin/{tenant_id}/endpoint/remove", - bool, - data={ - "endpoint_id": endpoint_id, - }, - headers={ - "Content-Type": "application/json", - }, - ) + try: + return self._request_with_plugin_daemon_response( + "POST", + f"plugin/{tenant_id}/endpoint/remove", + bool, + data={ + "endpoint_id": endpoint_id, + }, + headers={ + "Content-Type": "application/json", + }, + ) + except PluginDaemonInternalServerError as e: + # Make delete idempotent: if record is not found, consider it a success + if "record not found" in str(e.description).lower(): + return True + raise def enable_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str): """ diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 6c818bdc8b..fdbfca4330 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -331,7 +331,6 @@ class ProviderManager: provider=provider_schema.provider, label=provider_schema.label, icon_small=provider_schema.icon_small, - icon_large=provider_schema.icon_large, supported_model_types=provider_schema.supported_model_types, ), ) @@ -619,18 +618,18 @@ class ProviderManager: ) for quota in configuration.quotas: - if quota.quota_type == ProviderQuotaType.TRIAL: + if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID): # Init trial provider records if not exists - if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: + if quota.quota_type not in provider_quota_to_provider_record_dict: try: # FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic new_provider_record = Provider( tenant_id=tenant_id, # TODO: Use provider name with prefix after the data 
migration. provider_name=ModelProviderID(provider_name).provider_name, - provider_type=ProviderType.SYSTEM, - quota_type=ProviderQuotaType.TRIAL, - quota_limit=quota.quota_limit, # type: ignore + provider_type=ProviderType.SYSTEM.value, + quota_type=quota.quota_type, + quota_limit=0, # type: ignore quota_used=0, is_valid=True, ) @@ -642,8 +641,8 @@ class ProviderManager: stmt = select(Provider).where( Provider.tenant_id == tenant_id, Provider.provider_name == ModelProviderID(provider_name).provider_name, - Provider.provider_type == ProviderType.SYSTEM, - Provider.quota_type == ProviderQuotaType.TRIAL, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == quota.quota_type, ) existed_provider_record = db.session.scalar(stmt) if not existed_provider_record: @@ -913,6 +912,22 @@ class ProviderManager: provider_record ) quota_configurations = [] + + if dify_config.EDITION == "CLOUD": + from services.credit_pool_service import CreditPoolService + + trial_pool = CreditPoolService.get_pool( + tenant_id=tenant_id, + pool_type=ProviderQuotaType.TRIAL.value, + ) + paid_pool = CreditPoolService.get_pool( + tenant_id=tenant_id, + pool_type=ProviderQuotaType.PAID.value, + ) + else: + trial_pool = None + paid_pool = None + for provider_quota in provider_hosting_configuration.quotas: if provider_quota.quota_type not in quota_type_to_provider_records_dict: if provider_quota.quota_type == ProviderQuotaType.FREE: @@ -933,16 +948,36 @@ class ProviderManager: raise ValueError("quota_used is None") if provider_record.quota_limit is None: raise ValueError("quota_limit is None") + if provider_quota.quota_type == ProviderQuotaType.TRIAL and trial_pool is not None: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=trial_pool.quota_used, + quota_limit=trial_pool.quota_limit, + is_valid=trial_pool.quota_limit > trial_pool.quota_used or
trial_pool.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) - quota_configuration = QuotaConfiguration( - quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, - quota_used=provider_record.quota_used, - quota_limit=provider_record.quota_limit, - is_valid=provider_record.quota_limit > provider_record.quota_used - or provider_record.quota_limit == -1, - restrict_models=provider_quota.restrict_models, - ) + elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=paid_pool.quota_used, + quota_limit=paid_pool.quota_limit, + is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) + + else: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=provider_record.quota_used, + quota_limit=provider_record.quota_limit, + is_valid=provider_record.quota_limit > provider_record.quota_used + or provider_record.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) quota_configurations.append(quota_configuration) diff --git a/api/core/rag/cleaner/clean_processor.py b/api/core/rag/cleaner/clean_processor.py index 9cb009035b..e182c35b99 100644 --- a/api/core/rag/cleaner/clean_processor.py +++ b/api/core/rag/cleaner/clean_processor.py @@ -27,26 +27,44 @@ class CleanProcessor: pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)" text = re.sub(pattern, "", text) - # Remove URL but
keep Markdown image URLs and link URLs + # Replace the ENTIRE markdown link/image with a single placeholder to protect + # the link text (which might also be a URL) from being removed + markdown_link_pattern = r"\[([^\]]*)\]\((https?://[^)]+)\)" + markdown_image_pattern = r"!\[.*?\]\((https?://[^)]+)\)" + placeholders: list[tuple[str, str, str]] = [] # (type, text, url) - def replace_with_placeholder(match, placeholders=placeholders): + def replace_markdown_with_placeholder(match, placeholders=placeholders): + link_type = "link" + link_text = match.group(1) + url = match.group(2) + placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__" + placeholders.append((link_type, link_text, url)) + return placeholder + + def replace_image_with_placeholder(match, placeholders=placeholders): + link_type = "image" url = match.group(1) - placeholder = f"__MARKDOWN_IMAGE_URL_{len(placeholders)}__" - placeholders.append(url) - return f"![image]({placeholder})" + placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__" + placeholders.append((link_type, "image", url)) + return placeholder - text = re.sub(markdown_image_pattern, replace_with_placeholder, text) + # Protect markdown links first + text = re.sub(markdown_link_pattern, replace_markdown_with_placeholder, text) + # Then protect markdown images + text = re.sub(markdown_image_pattern, replace_image_with_placeholder, text) # Now remove all remaining URLs - url_pattern = r"https?://[^\s)]+" + url_pattern = r"https?://\S+" text = re.sub(url_pattern, "", text) - # Finally, restore the Markdown image URLs - for i, url in enumerate(placeholders): - text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url) + # Restore the Markdown links and images + for i, (link_type, text_or_alt, url) in enumerate(placeholders): + placeholder = f"__MARKDOWN_PLACEHOLDER_{i}__" + if link_type == "link": + text = text.replace(placeholder, f"[{text_or_alt}]({url})") + else: # image + text = text.replace(placeholder, f"![{text_or_alt}]({url})") 
return text def filter_string(self, text): diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 9807cb4e6a..8ec1ce6242 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -1,4 +1,5 @@ import concurrent.futures +import logging from concurrent.futures import ThreadPoolExecutor from typing import Any @@ -13,7 +14,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.embedding.retrieval import RetrievalSegments +from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments from core.rag.entities.metadata_entities import MetadataCondition from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType @@ -36,6 +37,8 @@ default_retrieval_model = { "score_threshold_enabled": False, } +logger = logging.getLogger(__name__) + class RetrievalService: # Cache precompiled regular expressions to avoid repeated compilation @@ -106,7 +109,12 @@ class RetrievalService: ) ) - concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED) + if futures: + for future in concurrent.futures.as_completed(futures, timeout=3600): + if exceptions: + for f in futures: + f.cancel() + break if exceptions: raise ValueError(";\n".join(exceptions)) @@ -210,6 +218,7 @@ class RetrievalService: ) all_documents.extend(documents) except Exception as e: + logger.error(e, exc_info=True) exceptions.append(str(e)) @classmethod @@ -303,6 +312,7 @@ class RetrievalService: else: all_documents.extend(documents) except Exception as e: + logger.error(e, exc_info=True) exceptions.append(str(e)) @classmethod @@ -351,6 +361,7 @@ class 
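The placeholder round trip in the URL-cleaning hunk above — protect markdown links first (their text may itself contain a URL), then images, strip every remaining bare URL, then restore — can be sketched as a standalone helper. The function name is illustrative; the regexes mirror the ones in the diff:

```python
import re


def strip_bare_urls(text: str) -> str:
    """Remove bare URLs while preserving markdown links and images."""
    link_pat = r"\[([^\]]*)\]\((https?://[^)]+)\)"
    image_pat = r"!\[.*?\]\((https?://[^)]+)\)"
    placeholders: list[tuple[str, str, str]] = []  # (kind, text, url)

    def protect(kind):
        def repl(match):
            token = f"__MD_PLACEHOLDER_{len(placeholders)}__"
            if kind == "link":
                placeholders.append(("link", match.group(1), match.group(2)))
            else:
                placeholders.append(("image", "image", match.group(1)))
            return token
        return repl

    # Links first: their visible text might contain a URL we must not strip
    text = re.sub(link_pat, protect("link"), text)
    text = re.sub(image_pat, protect("image"), text)
    # Drop every URL that is not hidden behind a placeholder
    text = re.sub(r"https?://\S+", "", text)
    # Restore links and images in order
    for i, (kind, label, url) in enumerate(placeholders):
        rendered = f"[{label}]({url})" if kind == "link" else f"![{label}]({url})"
        text = text.replace(f"__MD_PLACEHOLDER_{i}__", rendered)
    return text
```

Note that an image `![alt](url)` is usually caught by the *link* pattern (which matches the `[alt](url)` part, leaving the leading `!` in place), so the round trip still reassembles it correctly — the same property the diff relies on.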
RetrievalService: else: all_documents.extend(documents) except Exception as e: + logger.error(e, exc_info=True) exceptions.append(str(e)) @staticmethod @@ -381,10 +392,9 @@ class RetrievalService: records = [] include_segment_ids = set() segment_child_map = {} - segment_file_map = {} valid_dataset_documents = {} - image_doc_ids = [] + image_doc_ids: list[Any] = [] child_index_node_ids = [] index_node_ids = [] doc_to_document_map = {} @@ -417,28 +427,39 @@ class RetrievalService: child_index_node_ids = [i for i in child_index_node_ids if i] index_node_ids = [i for i in index_node_ids if i] - segment_ids = [] + segment_ids: list[str] = [] index_node_segments: list[DocumentSegment] = [] segments: list[DocumentSegment] = [] - attachment_map = {} - child_chunk_map = {} - doc_segment_map = {} + attachment_map: dict[str, list[dict[str, Any]]] = {} + child_chunk_map: dict[str, list[ChildChunk]] = {} + doc_segment_map: dict[str, list[str]] = {} with session_factory.create_session() as session: attachments = cls.get_segment_attachment_infos(image_doc_ids, session) for attachment in attachments: segment_ids.append(attachment["segment_id"]) - attachment_map[attachment["segment_id"]] = attachment - doc_segment_map[attachment["segment_id"]] = attachment["attachment_id"] - + if attachment["segment_id"] in attachment_map: + attachment_map[attachment["segment_id"]].append(attachment["attachment_info"]) + else: + attachment_map[attachment["segment_id"]] = [attachment["attachment_info"]] + if attachment["segment_id"] in doc_segment_map: + doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"]) + else: + doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]] child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids)) child_index_nodes = session.execute(child_chunk_stmt).scalars().all() for i in child_index_nodes: segment_ids.append(i.segment_id) - child_chunk_map[i.segment_id] = i - 
doc_segment_map[i.segment_id] = i.index_node_id + if i.segment_id in child_chunk_map: + child_chunk_map[i.segment_id].append(i) + else: + child_chunk_map[i.segment_id] = [i] + if i.segment_id in doc_segment_map: + doc_segment_map[i.segment_id].append(i.index_node_id) + else: + doc_segment_map[i.segment_id] = [i.index_node_id] if index_node_ids: document_segment_stmt = select(DocumentSegment).where( @@ -448,7 +469,7 @@ class RetrievalService: ) index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore for index_node_segment in index_node_segments: - doc_segment_map[index_node_segment.id] = index_node_segment.index_node_id + doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id] if segment_ids: document_segment_stmt = select(DocumentSegment).where( DocumentSegment.enabled == True, @@ -461,95 +482,86 @@ class RetrievalService: segments.extend(index_node_segments) for segment in segments: - doc_id = doc_segment_map.get(segment.id) - child_chunk = child_chunk_map.get(segment.id) - attachment_info = attachment_map.get(segment.id) + child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, []) + attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, []) + ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id) - if doc_id: - document = doc_to_document_map[doc_id] - ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get( - document.metadata.get("document_id") - ) - - if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - if segment.id not in include_segment_ids: - include_segment_ids.add(segment.id) - if child_chunk: + if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + if segment.id not in include_segment_ids: + include_segment_ids.add(segment.id) + if child_chunks or attachment_infos: + child_chunk_details = [] + max_score = 0.0 + for 
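The repeated `if key in map: append / else: create list` blocks in the hunks above build a multimap (segment id → list of attachments or child chunks). The explicit form is what the diff uses; `dict.setdefault` is a behavior-identical, more compact alternative — shown here side by side with illustrative data:

```python
attachments = [
    {"segment_id": "s1", "attachment_id": "a1"},
    {"segment_id": "s1", "attachment_id": "a2"},
    {"segment_id": "s2", "attachment_id": "a3"},
]

# Explicit form, as in the diff
attachment_map: dict[str, list[str]] = {}
for att in attachments:
    if att["segment_id"] in attachment_map:
        attachment_map[att["segment_id"]].append(att["attachment_id"])
    else:
        attachment_map[att["segment_id"]] = [att["attachment_id"]]

# Equivalent grouping via setdefault
compact: dict[str, list[str]] = {}
for att in attachments:
    compact.setdefault(att["segment_id"], []).append(att["attachment_id"])

assert attachment_map == compact
```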
child_chunk in child_chunks: + document = doc_to_document_map[child_chunk.index_node_id] child_chunk_detail = { "id": child_chunk.id, "content": child_chunk.content, "position": child_chunk.position, "score": document.metadata.get("score", 0.0) if document else 0.0, } - map_detail = { - "max_score": document.metadata.get("score", 0.0) if document else 0.0, - "child_chunks": [child_chunk_detail], - } - segment_child_map[segment.id] = map_detail - record = { - "segment": segment, + child_chunk_details.append(child_chunk_detail) + max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0) + for attachment_info in attachment_infos: + file_document = doc_to_document_map[attachment_info["id"]] + max_score = max( + max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0 + ) + + map_detail = { + "max_score": max_score, + "child_chunks": child_chunk_details, } - if attachment_info: - segment_file_map[segment.id] = [attachment_info] - records.append(record) - else: - if child_chunk: - child_chunk_detail = { - "id": child_chunk.id, - "content": child_chunk.content, - "position": child_chunk.position, - "score": document.metadata.get("score", 0.0), - } - if segment.id in segment_child_map: - segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) # type: ignore - segment_child_map[segment.id]["max_score"] = max( - segment_child_map[segment.id]["max_score"], - document.metadata.get("score", 0.0) if document else 0.0, - ) - else: - segment_child_map[segment.id] = { - "max_score": document.metadata.get("score", 0.0) if document else 0.0, - "child_chunks": [child_chunk_detail], - } - if attachment_info: - if segment.id in segment_file_map: - segment_file_map[segment.id].append(attachment_info) - else: - segment_file_map[segment.id] = [attachment_info] - else: - if segment.id not in include_segment_ids: - include_segment_ids.add(segment.id) - record = { - "segment": segment, - "score": document.metadata.get("score", 
0.0), # type: ignore - } - if attachment_info: - segment_file_map[segment.id] = [attachment_info] - records.append(record) - else: - if attachment_info: - attachment_infos = segment_file_map.get(segment.id, []) - if attachment_info not in attachment_infos: - attachment_infos.append(attachment_info) - segment_file_map[segment.id] = attachment_infos + segment_child_map[segment.id] = map_detail + record: dict[str, Any] = { + "segment": segment, + } + records.append(record) + else: + if segment.id not in include_segment_ids: + include_segment_ids.add(segment.id) + max_score = 0.0 + segment_document = doc_to_document_map.get(segment.index_node_id) + if segment_document: + max_score = max(max_score, segment_document.metadata.get("score", 0.0)) + for attachment_info in attachment_infos: + file_doc = doc_to_document_map.get(attachment_info["id"]) + if file_doc: + max_score = max(max_score, file_doc.metadata.get("score", 0.0)) + record = { + "segment": segment, + "score": max_score, + } + records.append(record) # Add child chunks information to records for record in records: if record["segment"].id in segment_child_map: record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore - if record["segment"].id in segment_file_map: - record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment] + if record["segment"].id in attachment_map: + record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment] - result = [] + result: list[RetrievalSegments] = [] for record in records: # Extract segment segment = record["segment"] # Extract child_chunks, ensuring it's a list or None - child_chunks = record.get("child_chunks") - if not isinstance(child_chunks, list): - child_chunks = None + raw_child_chunks = record.get("child_chunks") + child_chunks_list: list[RetrievalChildChunk] | None = None + if 
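The scoring rule in the hunk above — a parent segment inherits the maximum score seen across its child chunks and attached files — reduces to a small pure function. This is a sketch with an illustrative name and flattened inputs; in the diff the scores actually come from the retrieved `Document` metadata via `doc_to_document_map`:

```python
from typing import Any


def aggregate_segment_score(
    child_chunks: list[dict[str, Any]],
    attachments: list[dict[str, Any]],
) -> float:
    """Max score across child chunks and attachments; 0.0 when there are none."""
    scores = [c.get("score", 0.0) for c in child_chunks]
    scores += [a.get("score", 0.0) for a in attachments]
    return max(scores, default=0.0)
```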
isinstance(raw_child_chunks, list): + # Sort by score descending + sorted_chunks = sorted(raw_child_chunks, key=lambda x: x.get("score", 0.0), reverse=True) + child_chunks_list = [ + RetrievalChildChunk( + id=chunk["id"], + content=chunk["content"], + score=chunk.get("score", 0.0), + position=chunk["position"], + ) + for chunk in sorted_chunks + ] # Extract files, ensuring it's a list or None files = record.get("files") @@ -566,11 +578,11 @@ class RetrievalService: # Create RetrievalSegments object retrieval_segment = RetrievalSegments( - segment=segment, child_chunks=child_chunks, score=score, files=files + segment=segment, child_chunks=child_chunks_list, score=score, files=files ) result.append(retrieval_segment) - return result + return sorted(result, key=lambda x: x.score if x.score is not None else 0.0, reverse=True) except Exception as e: db.session.rollback() raise e @@ -662,7 +674,14 @@ class RetrievalService: document_ids_filter=document_ids_filter, ) ) - concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED) + # Use as_completed for early error propagation - cancel remaining futures on first error + if futures: + for future in concurrent.futures.as_completed(futures, timeout=300): + if future.exception(): + # Cancel remaining futures to avoid unnecessary waiting + for f in futures: + f.cancel() + break if exceptions: raise ValueError(";\n".join(exceptions)) diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index a306f9ba0c..91bb71bfa6 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import json import logging @@ -6,7 +8,7 @@ import re import threading import time import uuid -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import 
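Both retrieval hunks replace a blanket `concurrent.futures.wait(...)` with an `as_completed` loop that cancels pending futures as soon as one worker records an error. The shape of that pattern, extracted into a self-contained sketch (names and the error-collection style are illustrative — the diff appends error strings from inside the workers):

```python
import concurrent.futures
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor


def run_with_early_cancel(tasks: list[Callable[[], object]], timeout: float = 30.0) -> None:
    """Run callables concurrently; on the first failure, cancel queued work
    and raise the collected errors instead of waiting for everything."""
    exceptions: list[str] = []

    def wrap(fn):
        try:
            fn()
        except Exception as e:
            exceptions.append(str(e))

    with ThreadPoolExecutor(max_workers=4) as pool:
        futures = [pool.submit(wrap, t) for t in tasks]
        for _ in concurrent.futures.as_completed(futures, timeout=timeout):
            if exceptions:
                # cancel() is a no-op for already-running futures,
                # but skips anything still queued
                for f in futures:
                    f.cancel()
                break
    if exceptions:
        raise ValueError(";\n".join(exceptions))
```

Compared to `wait(..., return_when=ALL_COMPLETED)`, this propagates the first error without idling until every sibling query finishes — the motivation stated in the second hunk's comment.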
clickzetta # type: ignore from pydantic import BaseModel, model_validator @@ -76,7 +78,7 @@ class ClickzettaConnectionPool: Manages connection reuse across ClickzettaVector instances. """ - _instance: Optional["ClickzettaConnectionPool"] = None + _instance: ClickzettaConnectionPool | None = None _lock = threading.Lock() def __init__(self): @@ -89,7 +91,7 @@ class ClickzettaConnectionPool: self._start_cleanup_thread() @classmethod - def get_instance(cls) -> "ClickzettaConnectionPool": + def get_instance(cls) -> ClickzettaConnectionPool: """Get singleton instance of connection pool.""" if cls._instance is None: with cls._lock: @@ -104,7 +106,7 @@ class ClickzettaConnectionPool: f"{config.workspace}:{config.vcluster}:{config.schema_name}" ) - def _create_connection(self, config: ClickzettaConfig) -> "Connection": + def _create_connection(self, config: ClickzettaConfig) -> Connection: """Create a new ClickZetta connection.""" max_retries = 3 retry_delay = 1.0 @@ -134,7 +136,7 @@ class ClickzettaConnectionPool: raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts") - def _configure_connection(self, connection: "Connection"): + def _configure_connection(self, connection: Connection): """Configure connection session settings.""" try: with connection.cursor() as cursor: @@ -181,7 +183,7 @@ class ClickzettaConnectionPool: except Exception: logger.exception("Failed to configure connection, continuing with defaults") - def _is_connection_valid(self, connection: "Connection") -> bool: + def _is_connection_valid(self, connection: Connection) -> bool: """Check if connection is still valid.""" try: with connection.cursor() as cursor: @@ -190,7 +192,7 @@ class ClickzettaConnectionPool: except Exception: return False - def get_connection(self, config: ClickzettaConfig) -> "Connection": + def get_connection(self, config: ClickzettaConfig) -> Connection: """Get a connection from the pool or create a new one.""" config_key = 
self._get_config_key(config) @@ -221,7 +223,7 @@ class ClickzettaConnectionPool: # No valid connection found, create new one return self._create_connection(config) - def return_connection(self, config: ClickzettaConfig, connection: "Connection"): + def return_connection(self, config: ClickzettaConfig, connection: Connection): """Return a connection to the pool.""" config_key = self._get_config_key(config) @@ -315,22 +317,22 @@ class ClickzettaVector(BaseVector): self._connection_pool = ClickzettaConnectionPool.get_instance() self._init_write_queue() - def _get_connection(self) -> "Connection": + def _get_connection(self) -> Connection: """Get a connection from the pool.""" return self._connection_pool.get_connection(self._config) - def _return_connection(self, connection: "Connection"): + def _return_connection(self, connection: Connection): """Return a connection to the pool.""" self._connection_pool.return_connection(self._config, connection) class ConnectionContext: """Context manager for borrowing and returning connections.""" - def __init__(self, vector_instance: "ClickzettaVector"): + def __init__(self, vector_instance: ClickzettaVector): self.vector = vector_instance self.connection: Connection | None = None - def __enter__(self) -> "Connection": + def __enter__(self) -> Connection: self.connection = self.vector._get_connection() return self.connection @@ -338,7 +340,7 @@ class ClickzettaVector(BaseVector): if self.connection: self.vector._return_connection(self.connection) - def get_connection_context(self) -> "ClickzettaVector.ConnectionContext": + def get_connection_context(self) -> ClickzettaVector.ConnectionContext: """Get a connection context manager.""" return self.ConnectionContext(self) @@ -437,7 +439,7 @@ class ClickzettaVector(BaseVector): """Return the vector database type.""" return "clickzetta" - def _ensure_connection(self) -> "Connection": + def _ensure_connection(self) -> Connection: """Get a connection from the pool.""" return 
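The `ClickzettaConnectionPool` changes above mostly swap quoted annotations for `from __future__ import annotations`; the singleton itself is the classic double-checked locking idiom, which in isolation looks like this (a minimal sketch, not the pool's full implementation):

```python
from __future__ import annotations

import threading


class ConnectionPool:
    _instance: ConnectionPool | None = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls) -> ConnectionPool:
        # Fast path avoids the lock entirely once initialized;
        # the second check inside the lock closes the init race.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance
```

With the future import active, `ConnectionPool | None` and the bare `ConnectionPool` return annotation are legal without quoting, which is exactly what lets the diff drop the `"Connection"` string annotations.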
self._get_connection() @@ -984,9 +986,11 @@ class ClickzettaVector(BaseVector): # No need for dataset_id filter since each dataset has its own table - # Use simple quote escaping for LIKE clause - escaped_query = query.replace("'", "''") - filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'") + # Escape special characters for LIKE clause to prevent SQL injection + from libs.helper import escape_like_pattern + + escaped_query = escape_like_pattern(query).replace("'", "''") + filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'") where_clause = " AND ".join(filter_clauses) search_sql = f""" diff --git a/api/core/rag/datasource/vdb/iris/iris_vector.py b/api/core/rag/datasource/vdb/iris/iris_vector.py index b1bfabb76e..50bb2429ec 100644 --- a/api/core/rag/datasource/vdb/iris/iris_vector.py +++ b/api/core/rag/datasource/vdb/iris/iris_vector.py @@ -154,7 +154,7 @@ class IrisConnectionPool: # Add to cache to skip future checks self._schemas_initialized.add(schema) - except Exception as e: + except Exception: conn.rollback() logger.exception("Failed to ensure schema %s exists", schema) raise @@ -177,6 +177,9 @@ class IrisConnectionPool: class IrisVector(BaseVector): """IRIS vector database implementation using native VECTOR type and HNSW indexing.""" + # Fallback score for full-text search when Rank function unavailable or TEXT_INDEX disabled + _FULL_TEXT_FALLBACK_SCORE = 0.5 + def __init__(self, collection_name: str, config: IrisVectorConfig) -> None: super().__init__(collection_name) self.config = config @@ -272,37 +275,131 @@ class IrisVector(BaseVector): return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - """Search documents by full-text using iFind index or fallback to LIKE search.""" + """Search documents by full-text using iFind index with BM25 relevance scoring. 
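The Clickzetta hunk above starts escaping LIKE wildcards via `libs.helper.escape_like_pattern` before quoting, so user input can no longer smuggle `%`/`_` wildcards into the pattern. The helper's actual implementation is not shown in the diff; this sketch captures what such a function presumably does:

```python
def escape_like_pattern(pattern: str, escape_char: str = "\\") -> str:
    """Escape SQL LIKE metacharacters so user input matches literally.

    Sketch of what a helper such as libs.helper.escape_like_pattern
    presumably does; pair it with an ESCAPE clause in the query.
    """
    return (
        pattern.replace(escape_char, escape_char * 2)  # escape the escape char first
        .replace("%", escape_char + "%")
        .replace("_", escape_char + "_")
    )
```

The ordering matters: doubling the escape character must happen before `%` and `_` are escaped, or the freshly added escape characters would be doubled again.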
+ + When IRIS_TEXT_INDEX is enabled, this method uses the auto-generated Rank + function from %iFind.Index.Basic to calculate BM25 relevance scores. The Rank + function is automatically created with naming: {schema}.{table_name}_{index}Rank + + Args: + query: Search query string + **kwargs: Optional parameters including top_k, document_ids_filter + + Returns: + List of Document objects with relevance scores in metadata["score"] + """ top_k = kwargs.get("top_k", 5) + document_ids_filter = kwargs.get("document_ids_filter") with self._get_cursor() as cursor: if self.config.IRIS_TEXT_INDEX: - # Use iFind full-text search with index + # Use iFind full-text search with auto-generated Rank function text_index_name = f"idx_{self.table_name}_text" + # IRIS removes underscores from function names + table_no_underscore = self.table_name.replace("_", "") + index_no_underscore = text_index_name.replace("_", "") + rank_function = f"{self.schema}.{table_no_underscore}_{index_no_underscore}Rank" + + # Build WHERE clause with document ID filter if provided + where_clause = f"WHERE %ID %FIND search_index({text_index_name}, ?)" + # First param for Rank function, second for FIND + params = [query, query] + + if document_ids_filter: + # Add document ID filter + placeholders = ",".join("?" * len(document_ids_filter)) + where_clause += f" AND JSON_VALUE(meta, '$.document_id') IN ({placeholders})" + params.extend(document_ids_filter) + sql = f""" - SELECT TOP {top_k} id, text, meta + SELECT TOP {top_k} + id, + text, + meta, + {rank_function}(%ID, ?) AS score FROM {self.schema}.{self.table_name} - WHERE %ID %FIND search_index({text_index_name}, ?) 
+ {where_clause} + ORDER BY score DESC """ - cursor.execute(sql, (query,)) + + logger.debug( + "iFind search: query='%s', index='%s', rank='%s'", + query, + text_index_name, + rank_function, + ) + + try: + cursor.execute(sql, params) + except Exception: # pylint: disable=broad-exception-caught + # Fallback to query without Rank function if it fails + logger.warning( + "Rank function '%s' failed, using fixed score", + rank_function, + exc_info=True, + ) + sql_fallback = f""" + SELECT TOP {top_k} id, text, meta, {self._FULL_TEXT_FALLBACK_SCORE} AS score + FROM {self.schema}.{self.table_name} + {where_clause} + """ + # Skip first param (for Rank function) + cursor.execute(sql_fallback, params[1:]) else: - # Fallback to LIKE search (inefficient for large datasets) - query_pattern = f"%{query}%" + # Fallback to LIKE search (IRIS_TEXT_INDEX disabled) + from libs.helper import ( # pylint: disable=import-outside-toplevel + escape_like_pattern, + ) + + escaped_query = escape_like_pattern(query) + query_pattern = f"%{escaped_query}%" + + # Build WHERE clause with document ID filter if provided + where_clause = "WHERE text LIKE ? ESCAPE '\\\\'" + params = [query_pattern] + + if document_ids_filter: + placeholders = ",".join("?" * len(document_ids_filter)) + where_clause += f" AND JSON_VALUE(meta, '$.document_id') IN ({placeholders})" + params.extend(document_ids_filter) + sql = f""" - SELECT TOP {top_k} id, text, meta + SELECT TOP {top_k} id, text, meta, {self._FULL_TEXT_FALLBACK_SCORE} AS score FROM {self.schema}.{self.table_name} - WHERE text LIKE ? 
+ {where_clause} + ORDER BY LENGTH(text) ASC """ - cursor.execute(sql, (query_pattern,)) + + logger.debug( + "LIKE fallback (TEXT_INDEX disabled): query='%s'", + query_pattern, + ) + cursor.execute(sql, params) docs = [] for row in cursor.fetchall(): - if len(row) >= 3: - metadata = json.loads(row[2]) if row[2] else {} - docs.append(Document(page_content=row[1], metadata=metadata)) + # Expecting 4 columns: id, text, meta, score + if len(row) >= 4: + text_content = row[1] + meta_str = row[2] + score_value = row[3] + + metadata = json.loads(meta_str) if meta_str else {} + # Add score to metadata for hybrid search compatibility + score = float(score_value) if score_value is not None else 0.0 + metadata["score"] = score + + docs.append(Document(page_content=text_content, metadata=metadata)) + + logger.info( + "Full-text search completed: query='%s', results=%d/%d", + query, + len(docs), + top_k, + ) if not docs: - logger.info("Full-text search for '%s' returned no results", query) + logger.warning("Full-text search for '%s' returned no results", query) return docs @@ -366,7 +463,11 @@ class IrisVector(BaseVector): AS %iFind.Index.Basic (LANGUAGE = '{language}', LOWER = 1, INDEXOPTION = 0) """ - logger.info("Creating text index: %s with language: %s", text_index_name, language) + logger.info( + "Creating text index: %s with language: %s", + text_index_name, + language, + ) logger.info("SQL for text index: %s", sql_text_index) cursor.execute(sql_text_index) logger.info("Text index created successfully: %s", text_index_name) diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 445a0a7f8b..0615b8312c 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -255,7 +255,10 @@ class PGVector(BaseVector): return with self._get_cursor() as cur: - cur.execute("CREATE EXTENSION IF NOT EXISTS vector") + cur.execute("SELECT 1 FROM pg_extension WHERE 
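The IRIS full-text branch above tries the auto-generated Rank function first and, if that call fails, reruns the query assigning the fixed `_FULL_TEXT_FALLBACK_SCORE` to every row. Stripped of SQL, the control flow is a try/fallback around two query paths (the callables here are illustrative stand-ins for the two cursor executions):

```python
from collections.abc import Callable

_FULL_TEXT_FALLBACK_SCORE = 0.5


def ranked_search(
    run_ranked: Callable[[], list[tuple[str, float]]],
    run_unranked: Callable[[], list[str]],
) -> list[tuple[str, float]]:
    """Try the BM25-ranked query; on failure, fall back to a fixed score."""
    try:
        return run_ranked()
    except Exception:
        # Rank function unavailable: same rows, constant relevance
        return [(text, _FULL_TEXT_FALLBACK_SCORE) for text in run_unranked()]
```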
extname = 'vector'") + if not cur.fetchone(): + cur.execute("CREATE EXTENSION vector") + cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension)) # PG hnsw index only support 2000 dimension or less # ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 84d1e26b34..b48dd93f04 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -66,6 +66,8 @@ class WeaviateVector(BaseVector): in a Weaviate collection. """ + _DOCUMENT_ID_PROPERTY = "document_id" + def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list): """ Initializes the Weaviate vector store. @@ -353,15 +355,12 @@ class WeaviateVector(BaseVector): return [] col = self._client.collections.use(self._collection_name) - props = list({*self._attributes, "document_id", Field.TEXT_KEY.value}) + props = list({*self._attributes, self._DOCUMENT_ID_PROPERTY, Field.TEXT_KEY.value}) where = None doc_ids = kwargs.get("document_ids_filter") or [] if doc_ids: - ors = [Filter.by_property("document_id").equal(x) for x in doc_ids] - where = ors[0] - for f in ors[1:]: - where = where | f + where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids) top_k = int(kwargs.get("top_k", 4)) score_threshold = float(kwargs.get("score_threshold") or 0.0) @@ -408,10 +407,7 @@ class WeaviateVector(BaseVector): where = None doc_ids = kwargs.get("document_ids_filter") or [] if doc_ids: - ors = [Filter.by_property("document_id").equal(x) for x in doc_ids] - where = ors[0] - for f in ors[1:]: - where = where | f + where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids) top_k = int(kwargs.get("top_k", 4)) diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 1fe74d3042..69adac522d 
100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Sequence from typing import Any @@ -22,7 +24,7 @@ class DatasetDocumentStore: self._document_id = document_id @classmethod - def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore": + def from_dict(cls, config_dict: dict[str, Any]) -> DatasetDocumentStore: return cls(**config_dict) def to_dict(self) -> dict[str, Any]: diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 013c287248..6d28ce25bc 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -112,7 +112,7 @@ class ExtractProcessor: if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": - extractor = PdfExtractor(file_path) + extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension in {".md", ".markdown", ".mdx"}: extractor = ( UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key) @@ -148,7 +148,7 @@ class ExtractProcessor: if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": - extractor = PdfExtractor(file_path) + extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension in {".md", ".markdown", ".mdx"}: extractor = MarkdownExtractor(file_path, autodetect_encoding=True) elif file_extension in {".htm", ".html"}: diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 80530d99a6..6aabcac704 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -1,25 +1,57 @@ """Abstract interface for document loader implementations.""" import contextlib +import io +import logging +import uuid from 
collections.abc import Iterator +import pypdfium2 +import pypdfium2.raw as pdfium_c + +from configs import dify_config from core.rag.extractor.blob.blob import Blob from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document +from extensions.ext_database import db from extensions.ext_storage import storage +from libs.datetime_utils import naive_utc_now +from models.enums import CreatorUserRole +from models.model import UploadFile + +logger = logging.getLogger(__name__) class PdfExtractor(BaseExtractor): - """Load pdf files. - + """ + PdfExtractor is used to extract text and images from PDF files. Args: - file_path: Path to the file to load. + file_path: Path to the PDF file. + tenant_id: Workspace ID. + user_id: ID of the user performing the extraction. + file_cache_key: Optional cache key for the extracted text. """ - def __init__(self, file_path: str, file_cache_key: str | None = None): - """Initialize with file path.""" + # Magic bytes for image format detection: (magic_bytes, extension, mime_type) + IMAGE_FORMATS = [ + (b"\xff\xd8\xff", "jpg", "image/jpeg"), + (b"\x89PNG\r\n\x1a\n", "png", "image/png"), + (b"\x00\x00\x00\x0c\x6a\x50\x20\x20\x0d\x0a\x87\x0a", "jp2", "image/jp2"), + (b"GIF8", "gif", "image/gif"), + (b"BM", "bmp", "image/bmp"), + (b"II*\x00", "tiff", "image/tiff"), + (b"MM\x00*", "tiff", "image/tiff"), + (b"II+\x00", "tiff", "image/tiff"), + (b"MM\x00+", "tiff", "image/tiff"), + ] + MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS) + + def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None): + """Initialize PdfExtractor.""" self._file_path = file_path + self._tenant_id = tenant_id + self._user_id = user_id self._file_cache_key = file_cache_key def extract(self) -> list[Document]: @@ -50,7 +82,6 @@ class PdfExtractor(BaseExtractor): def parse(self, blob: Blob) -> Iterator[Document]: """Lazily parse the blob.""" - import pypdfium2 # type: ignore with 
blob.as_bytes_io() as file_path: pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True) @@ -59,8 +90,87 @@ class PdfExtractor(BaseExtractor): text_page = page.get_textpage() content = text_page.get_text_range() text_page.close() + + image_content = self._extract_images(page) + if image_content: + content += "\n" + image_content + page.close() metadata = {"source": blob.source, "page": page_number} yield Document(page_content=content, metadata=metadata) finally: pdf_reader.close() + + def _extract_images(self, page) -> str: + """ + Extract images from a PDF page, save them to storage and database, + and return markdown image links. + + Args: + page: pypdfium2 page object. + + Returns: + Markdown string containing links to the extracted images. + """ + image_content = [] + upload_files = [] + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + + try: + image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,)) + for obj in image_objects: + try: + # Extract image bytes + img_byte_arr = io.BytesIO() + # Extract DCTDecode (JPEG) and JPXDecode (JPEG 2000) images directly + # Fallback to png for other formats + obj.extract(img_byte_arr, fb_format="png") + img_bytes = img_byte_arr.getvalue() + + if not img_bytes: + continue + + header = img_bytes[: self.MAX_MAGIC_LEN] + image_ext = None + mime_type = None + for magic, ext, mime in self.IMAGE_FORMATS: + if header.startswith(magic): + image_ext = ext + mime_type = mime + break + + if not image_ext or not mime_type: + continue + + file_uuid = str(uuid.uuid4()) + file_key = "image_files/" + self._tenant_id + "/" + file_uuid + "." 
+ image_ext + + storage.save(file_key, img_bytes) + + # save file to db + upload_file = UploadFile( + tenant_id=self._tenant_id, + storage_type=dify_config.STORAGE_TYPE, + key=file_key, + name=file_key, + size=len(img_bytes), + extension=image_ext, + mime_type=mime_type, + created_by=self._user_id, + created_by_role=CreatorUserRole.ACCOUNT, + created_at=naive_utc_now(), + used=True, + used_by=self._user_id, + used_at=naive_utc_now(), + ) + upload_files.append(upload_file) + image_content.append(f"![image]({base_url}/files/{upload_file.id}/file-preview)") + except Exception as e: + logger.warning("Failed to extract image from PDF: %s", e) + continue + except Exception as e: + logger.warning("Failed to get objects from PDF page: %s", e) + if upload_files: + db.session.add_all(upload_files) + db.session.commit() + return "\n".join(image_content) diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index f67f613e9d..511f5a698d 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -7,10 +7,11 @@ import re import tempfile import uuid from urllib.parse import urlparse -from xml.etree import ElementTree import httpx from docx import Document as DocxDocument +from docx.oxml.ns import qn +from docx.text.run import Run from configs import dify_config from core.helper import ssrf_proxy @@ -229,44 +230,20 @@ class WordExtractor(BaseExtractor): image_map = self._extract_images_from_docx(doc) - hyperlinks_url = None - url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+") - for para in doc.paragraphs: - for run in para.runs: - if run.text and hyperlinks_url: - result = f" [{run.text}]({hyperlinks_url}) " - run.text = result - hyperlinks_url = None - if "HYPERLINK" in run.element.xml: - try: - xml = ElementTree.XML(run.element.xml) - x_child = [c for c in xml.iter() if c is not None] - for x in x_child: - if x is None: - continue - if x.tag.endswith("instrText"): - if x.text is 
None: - continue - for i in url_pattern.findall(x.text): - hyperlinks_url = str(i) - except Exception: - logger.exception("Failed to parse HYPERLINK xml") - def parse_paragraph(paragraph): - paragraph_content = [] - - def append_image_link(image_id, has_drawing): + def append_image_link(image_id, has_drawing, target_buffer): """Helper to append image link from image_map based on relationship type.""" rel = doc.part.rels[image_id] if rel.is_external: if image_id in image_map and not has_drawing: - paragraph_content.append(image_map[image_id]) + target_buffer.append(image_map[image_id]) else: image_part = rel.target_part if image_part in image_map and not has_drawing: - paragraph_content.append(image_map[image_part]) + target_buffer.append(image_map[image_part]) - for run in paragraph.runs: + def process_run(run, target_buffer): + # Helper to extract text and embedded images from a run element and append them to target_buffer if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"): # Process drawing type images drawing_elements = run.element.findall( @@ -287,13 +264,13 @@ class WordExtractor(BaseExtractor): # External image: use embed_id as key if embed_id in image_map: has_drawing = True - paragraph_content.append(image_map[embed_id]) + target_buffer.append(image_map[embed_id]) else: # Internal image: use target_part as key image_part = doc.part.related_parts.get(embed_id) if image_part in image_map: has_drawing = True - paragraph_content.append(image_map[image_part]) + target_buffer.append(image_map[image_part]) # Process pict type images shape_elements = run.element.findall( ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict" @@ -308,7 +285,7 @@ class WordExtractor(BaseExtractor): "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id" ) if image_id and image_id in doc.part.rels: - append_image_link(image_id, has_drawing) + append_image_link(image_id, has_drawing, target_buffer) # 
Find imagedata element in VML image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata") if image_data is not None: @@ -316,9 +293,93 @@ class WordExtractor(BaseExtractor): "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id" ) if image_id and image_id in doc.part.rels: - append_image_link(image_id, has_drawing) + append_image_link(image_id, has_drawing, target_buffer) if run.text.strip(): - paragraph_content.append(run.text.strip()) + target_buffer.append(run.text.strip()) + + def process_hyperlink(hyperlink_elem, target_buffer): + # Helper to extract text from a hyperlink element and append it to target_buffer + r_id = hyperlink_elem.get(qn("r:id")) + + # Extract text from runs inside the hyperlink + link_text_parts = [] + for run_elem in hyperlink_elem.findall(qn("w:r")): + run = Run(run_elem, paragraph) + # Hyperlink text may be split across multiple runs (e.g., with different formatting), + # so collect all run texts first + if run.text: + link_text_parts.append(run.text) + + link_text = "".join(link_text_parts).strip() + + # Resolve URL + if r_id: + try: + rel = doc.part.rels.get(r_id) + if rel and rel.is_external: + link_text = f"[{link_text or rel.target_ref}]({rel.target_ref})" + except Exception: + logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id) + + if link_text: + target_buffer.append(link_text) + + paragraph_content = [] + # State for legacy HYPERLINK fields + hyperlink_field_url = None + hyperlink_field_text_parts: list = [] + is_collecting_field_text = False + # Iterate through paragraph elements in document order + for child in paragraph._element: + tag = child.tag + if tag == qn("w:r"): + # Regular run + run = Run(child, paragraph) + + # Check for fldChar (begin/end/separate) and instrText for legacy hyperlinks + fld_chars = child.findall(qn("w:fldChar")) + instr_texts = child.findall(qn("w:instrText")) + + # Handle Fields + if fld_chars or instr_texts: + # Process instrText to find 
HYPERLINK "url" + for instr in instr_texts: + if instr.text and "HYPERLINK" in instr.text: + # Quick regex to extract URL + match = re.search(r'HYPERLINK\s+"([^"]+)"', instr.text, re.IGNORECASE) + if match: + hyperlink_field_url = match.group(1) + + # Process fldChar + for fld_char in fld_chars: + fld_char_type = fld_char.get(qn("w:fldCharType")) + if fld_char_type == "begin": + # Start of a field: reset legacy link state + hyperlink_field_url = None + hyperlink_field_text_parts = [] + is_collecting_field_text = False + elif fld_char_type == "separate": + # Separator: if we found a URL, start collecting visible text + if hyperlink_field_url: + is_collecting_field_text = True + elif fld_char_type == "end": + # End of field + if is_collecting_field_text and hyperlink_field_url: + # Create markdown link and append to main content + display_text = "".join(hyperlink_field_text_parts).strip() + if display_text: + link_md = f"[{display_text}]({hyperlink_field_url})" + paragraph_content.append(link_md) + # Reset state + hyperlink_field_url = None + hyperlink_field_text_parts = [] + is_collecting_field_text = False + + # Decide where to append content + target_buffer = hyperlink_field_text_parts if is_collecting_field_text else paragraph_content + process_run(run, target_buffer) + elif tag == qn("w:hyperlink"): + process_hyperlink(child, paragraph_content) return "".join(paragraph_content) if paragraph_content else "" paragraphs = doc.paragraphs.copy() diff --git a/api/core/rag/pipeline/queue.py b/api/core/rag/pipeline/queue.py index 7472598a7f..bf8db95b4e 100644 --- a/api/core/rag/pipeline/queue.py +++ b/api/core/rag/pipeline/queue.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from collections.abc import Sequence from typing import Any @@ -16,7 +18,7 @@ class TaskWrapper(BaseModel): return self.model_dump_json() @classmethod - def deserialize(cls, serialized_data: str) -> "TaskWrapper": + def deserialize(cls, serialized_data: str) -> TaskWrapper: 
return cls.model_validate_json(serialized_data) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index baf879df95..f8f85d141a 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping from typing import Any, Union, cast from flask import Flask, current_app -from sqlalchemy import and_, or_, select +from sqlalchemy import and_, literal, or_, select from sqlalchemy.orm import Session from core.app.app_config.entities import ( @@ -515,7 +515,11 @@ class DatasetRetrieval: 0 ].embedding_model_provider weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model + dataset_count = len(available_datasets) with measure_time() as timer: + cancel_event = threading.Event() + thread_exceptions: list[Exception] = [] + if query: query_thread = threading.Thread( target=self._multiple_retrieve_thread, @@ -534,6 +538,9 @@ class DatasetRetrieval: "score_threshold": score_threshold, "query": query, "attachment_id": None, + "dataset_count": dataset_count, + "cancel_event": cancel_event, + "thread_exceptions": thread_exceptions, }, ) all_threads.append(query_thread) @@ -557,12 +564,26 @@ class DatasetRetrieval: "score_threshold": score_threshold, "query": None, "attachment_id": attachment_id, + "dataset_count": dataset_count, + "cancel_event": cancel_event, + "thread_exceptions": thread_exceptions, }, ) all_threads.append(attachment_thread) attachment_thread.start() - for thread in all_threads: - thread.join() + + # Poll threads with short timeout to detect errors quickly (fail-fast) + while any(t.is_alive() for t in all_threads): + for thread in all_threads: + thread.join(timeout=0.1) + if thread_exceptions: + cancel_event.set() + break + if thread_exceptions: + break + + if thread_exceptions: + raise thread_exceptions[0] self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, 
user_id) if all_documents: @@ -1036,7 +1057,7 @@ class DatasetRetrieval: if automatic_metadata_filters: conditions = [] for sequence, filter in enumerate(automatic_metadata_filters): - self._process_metadata_filter_func( + self.process_metadata_filter_func( sequence, filter.get("condition"), # type: ignore filter.get("metadata_name"), # type: ignore @@ -1072,7 +1093,7 @@ class DatasetRetrieval: value=expected_value, ) ) - filters = self._process_metadata_filter_func( + filters = self.process_metadata_filter_func( sequence, condition.comparison_operator, metadata_name, @@ -1168,26 +1189,33 @@ class DatasetRetrieval: return None return automatic_metadata_filters - def _process_metadata_filter_func( - self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list + @classmethod + def process_metadata_filter_func( + cls, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list ): if value is None and condition not in ("empty", "not empty"): return filters json_field = DatasetDocument.doc_metadata[metadata_name].as_string() + from libs.helper import escape_like_pattern + match condition: case "contains": - filters.append(json_field.like(f"%{value}%")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.like(f"%{escaped_value}%", escape="\\")) case "not contains": - filters.append(json_field.notlike(f"%{value}%")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.notlike(f"%{escaped_value}%", escape="\\")) case "start with": - filters.append(json_field.like(f"{value}%")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.like(f"{escaped_value}%", escape="\\")) case "end with": - filters.append(json_field.like(f"%{value}")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.like(f"%{escaped_value}", escape="\\")) case "is" | "=": if isinstance(value, str): @@ -1218,6 +1246,20 @@ class DatasetRetrieval: case "≥" | 
">=": filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value) + case "in" | "not in": + if isinstance(value, str): + value_list = [v.strip() for v in value.split(",") if v.strip()] + elif isinstance(value, (list, tuple)): + value_list = [str(v) for v in value if v is not None] + else: + value_list = [str(value)] if value is not None else [] + + if not value_list: + # `field in []` is False, `field not in []` is True + filters.append(literal(condition == "not in")) + else: + op = json_field.in_ if condition == "in" else json_field.notin_ + filters.append(op(value_list)) case _: pass @@ -1389,69 +1431,89 @@ class DatasetRetrieval: score_threshold: float, query: str | None, attachment_id: str | None, + dataset_count: int, + cancel_event: threading.Event | None = None, + thread_exceptions: list[Exception] | None = None, ): - with flask_app.app_context(): - threads = [] - all_documents_item: list[Document] = [] - index_type = None - for dataset in available_datasets: - index_type = dataset.indexing_technique - document_ids_filter = None - if dataset.provider != "external": - if metadata_condition and not metadata_filter_document_ids: - continue - if metadata_filter_document_ids: - document_ids = metadata_filter_document_ids.get(dataset.id, []) - if document_ids: - document_ids_filter = document_ids - else: + try: + with flask_app.app_context(): + threads = [] + all_documents_item: list[Document] = [] + index_type = None + for dataset in available_datasets: + # Check for cancellation signal + if cancel_event and cancel_event.is_set(): + break + index_type = dataset.indexing_technique + document_ids_filter = None + if dataset.provider != "external": + if metadata_condition and not metadata_filter_document_ids: continue - retrieval_thread = threading.Thread( - target=self._retriever, - kwargs={ - "flask_app": flask_app, - "dataset_id": dataset.id, - "query": query, - "top_k": top_k, - "all_documents": all_documents_item, - "document_ids_filter": 
document_ids_filter, - "metadata_condition": metadata_condition, - "attachment_ids": [attachment_id] if attachment_id else None, - }, - ) - threads.append(retrieval_thread) - retrieval_thread.start() - for thread in threads: - thread.join() + if metadata_filter_document_ids: + document_ids = metadata_filter_document_ids.get(dataset.id, []) + if document_ids: + document_ids_filter = document_ids + else: + continue + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": flask_app, + "dataset_id": dataset.id, + "query": query, + "top_k": top_k, + "all_documents": all_documents_item, + "document_ids_filter": document_ids_filter, + "metadata_condition": metadata_condition, + "attachment_ids": [attachment_id] if attachment_id else None, + }, + ) + threads.append(retrieval_thread) + retrieval_thread.start() - if reranking_enable: - # do rerank for searched documents - data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) - if query: - all_documents_item = data_post_processor.invoke( - query=query, - documents=all_documents_item, - score_threshold=score_threshold, - top_n=top_k, - query_type=QueryType.TEXT_QUERY, - ) - if attachment_id: - all_documents_item = data_post_processor.invoke( - documents=all_documents_item, - score_threshold=score_threshold, - top_n=top_k, - query_type=QueryType.IMAGE_QUERY, - query=attachment_id, - ) - else: - if index_type == IndexTechniqueType.ECONOMY: - if not query: - all_documents_item = [] - else: - all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k) - elif index_type == IndexTechniqueType.HIGH_QUALITY: - all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold) + # Poll threads with short timeout to respond quickly to cancellation + while any(t.is_alive() for t in threads): + for thread in threads: + thread.join(timeout=0.1) + if cancel_event and cancel_event.is_set(): + break + if 
cancel_event and cancel_event.is_set(): + break + + # Skip second reranking when there is only one dataset + if reranking_enable and dataset_count > 1: + # do rerank for searched documents + data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) + if query: + all_documents_item = data_post_processor.invoke( + query=query, + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.TEXT_QUERY, + ) + if attachment_id: + all_documents_item = data_post_processor.invoke( + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.IMAGE_QUERY, + query=attachment_id, + ) else: - all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item - if all_documents_item: - all_documents.extend(all_documents_item) + if index_type == IndexTechniqueType.ECONOMY: + if not query: + all_documents_item = [] + else: + all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k) + elif index_type == IndexTechniqueType.HIGH_QUALITY: + all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold) + else: + all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item + if all_documents_item: + all_documents.extend(all_documents_item) + except Exception as e: + if cancel_event: + cancel_event.set() + if thread_exceptions is not None: + thread_exceptions.append(e) diff --git a/api/core/schemas/registry.py b/api/core/schemas/registry.py index 51bfae1cd3..b4ecfe47ff 100644 --- a/api/core/schemas/registry.py +++ b/api/core/schemas/registry.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import json import logging import threading from collections.abc import Mapping, MutableMapping from pathlib import Path -from typing import Any, ClassVar, Optional +from typing import Any, ClassVar class SchemaRegistry: @@ -11,7 +13,7 @@ class SchemaRegistry: logger: 
ClassVar[logging.Logger] = logging.getLogger(__name__) - _default_instance: ClassVar[Optional["SchemaRegistry"]] = None + _default_instance: ClassVar[SchemaRegistry | None] = None _lock: ClassVar[threading.Lock] = threading.Lock() def __init__(self, base_dir: str): @@ -20,7 +22,7 @@ class SchemaRegistry: self.metadata: MutableMapping[str, MutableMapping[str, Any]] = {} @classmethod - def default_registry(cls) -> "SchemaRegistry": + def default_registry(cls) -> SchemaRegistry: """Returns the default schema registry for builtin schemas (thread-safe singleton)""" if cls._default_instance is None: with cls._lock: diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 8ca4eabb7a..ebd200a822 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from collections.abc import Generator from copy import deepcopy @@ -24,7 +26,7 @@ class Tool(ABC): self.entity = entity self.runtime = runtime - def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool": + def fork_tool_runtime(self, runtime: ToolRuntime) -> Tool: """ fork a new tool with metadata :return: the new tool @@ -166,7 +168,7 @@ class Tool(ABC): type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image) ) - def create_file_message(self, file: "File") -> ToolInvokeMessage: + def create_file_message(self, file: File) -> ToolInvokeMessage: return ToolInvokeMessage( type=ToolInvokeMessage.MessageType.FILE, message=ToolInvokeMessage.FileMessage(), diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 84efefba07..51b0407886 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, 
SystemPromptMessage, UserPromptMessage from core.tools.__base.tool import Tool @@ -24,7 +26,7 @@ class BuiltinTool(Tool): super().__init__(**kwargs) self.provider = provider - def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool": + def fork_tool_runtime(self, runtime: ToolRuntime) -> BuiltinTool: """ fork a new tool with metadata :return: the new tool diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py index 0cc992155a..e2f6c00555 100644 --- a/api/core/tools/custom_tool/provider.py +++ b/api/core/tools/custom_tool/provider.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pydantic import Field from sqlalchemy import select @@ -32,7 +34,7 @@ class ApiToolProviderController(ToolProviderController): self.tools = [] @classmethod - def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController": + def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> ApiToolProviderController: credentials_schema = [ ProviderConfig( name="auth_type", diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 583a3584f7..96268d029e 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import contextlib from collections.abc import Mapping @@ -55,7 +57,7 @@ class ToolProviderType(StrEnum): MCP = auto() @classmethod - def value_of(cls, value: str) -> "ToolProviderType": + def value_of(cls, value: str) -> ToolProviderType: """ Get value of given mode. @@ -79,7 +81,7 @@ class ApiProviderSchemaType(StrEnum): OPENAI_ACTIONS = auto() @classmethod - def value_of(cls, value: str) -> "ApiProviderSchemaType": + def value_of(cls, value: str) -> ApiProviderSchemaType: """ Get value of given mode. 
@@ -102,7 +104,7 @@ class ApiProviderAuthType(StrEnum): API_KEY_QUERY = auto() @classmethod - def value_of(cls, value: str) -> "ApiProviderAuthType": + def value_of(cls, value: str) -> ApiProviderAuthType: """ Get value of given mode. @@ -128,7 +130,7 @@ class ToolInvokeMessage(BaseModel): text: str class JsonMessage(BaseModel): - json_object: dict + json_object: dict | list suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string") class BlobMessage(BaseModel): @@ -142,7 +144,14 @@ class ToolInvokeMessage(BaseModel): end: bool = Field(..., description="Whether the chunk is the last chunk") class FileMessage(BaseModel): - pass + file_marker: str = Field(default="file_marker") + + @model_validator(mode="before") + @classmethod + def validate_file_message(cls, values): + if isinstance(values, dict) and "file_marker" not in values: + raise ValueError("Invalid FileMessage: missing file_marker") + return values class VariableMessage(BaseModel): variable_name: str = Field(..., description="The name of the variable") @@ -232,10 +241,22 @@ class ToolInvokeMessage(BaseModel): @field_validator("message", mode="before") @classmethod - def decode_blob_message(cls, v): + def decode_blob_message(cls, v, info: ValidationInfo): + # Handle blob decoding if isinstance(v, dict) and "blob" in v: with contextlib.suppress(Exception): v["blob"] = base64.b64decode(v["blob"]) + + # Force correct message type based on type field + # Only wrap dict types to avoid wrapping already parsed Pydantic model objects + if info.data and isinstance(info.data, dict) and isinstance(v, dict): + msg_type = info.data.get("type") + if msg_type == cls.MessageType.JSON: + if "json_object" not in v: + v = {"json_object": v} + elif msg_type == cls.MessageType.FILE: + v = {"file_marker": "file_marker"} + return v @field_serializer("message") @@ -307,7 +328,7 @@ class ToolParameter(PluginParameter): typ: ToolParameterType, required: bool, options: list[str] | None = None, - 
) -> "ToolParameter": + ) -> ToolParameter: """ get a simple tool parameter @@ -429,14 +450,14 @@ class ToolInvokeMeta(BaseModel): tool_config: dict | None = None @classmethod - def empty(cls) -> "ToolInvokeMeta": + def empty(cls) -> ToolInvokeMeta: """ Get an empty instance of ToolInvokeMeta """ return cls(time_cost=0.0, error=None, tool_config={}) @classmethod - def error_instance(cls, error: str) -> "ToolInvokeMeta": + def error_instance(cls, error: str) -> ToolInvokeMeta: """ Get an instance of ToolInvokeMeta with error """ diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index fbaf31ad09..ef9e9c103a 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import json import logging @@ -6,7 +8,15 @@ from typing import Any from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPConnectionError -from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent +from core.mcp.types import ( + AudioContent, + BlobResourceContents, + CallToolResult, + EmbeddedResource, + ImageContent, + TextContent, + TextResourceContents, +) from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType @@ -53,10 +63,19 @@ class MCPTool(Tool): for content in result.content: if isinstance(content, TextContent): yield from self._process_text_content(content) - elif isinstance(content, ImageContent): - yield self._process_image_content(content) - elif isinstance(content, AudioContent): - yield self._process_audio_content(content) + elif isinstance(content, ImageContent | AudioContent): + yield self.create_blob_message( + blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType} + ) + elif isinstance(content, EmbeddedResource): + resource = content.resource + if 
isinstance(resource, TextResourceContents): + yield self.create_text_message(resource.text) + elif isinstance(resource, BlobResourceContents): + mime_type = resource.mimeType or "application/octet-stream" + yield self.create_blob_message(blob=base64.b64decode(resource.blob), meta={"mime_type": mime_type}) + else: + raise ToolInvokeError(f"Unsupported embedded resource type: {type(resource)}") else: logger.warning("Unsupported content type=%s", type(content)) @@ -101,15 +120,7 @@ class MCPTool(Tool): for item in json_list: yield self.create_json_message(item) - def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage: - """Process image content and return a blob message.""" - return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}) - - def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage: - """Process audio content and return a blob message.""" - return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}) - - def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool": + def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool: return MCPTool( entity=self.entity, runtime=runtime, diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index 828dc3b810..d3a2ad488c 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Generator from typing import Any @@ -46,7 +48,7 @@ class PluginTool(Tool): message_id=message_id, ) - def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool": + def fork_tool_runtime(self, runtime: ToolRuntime) -> PluginTool: return PluginTool( entity=self.entity, runtime=runtime, diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py index fef3157f27..22e099deba 100644 --- a/api/core/tools/signature.py +++ b/api/core/tools/signature.py @@ 
-7,12 +7,12 @@ import time from configs import dify_config -def sign_tool_file(tool_file_id: str, extension: str) -> str: +def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) -> str: """ sign file to get a temporary url for plugin access """ - # Use internal URL for plugin/tool file access in Docker environments - base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + # Use internal URL for plugin/tool file access in Docker environments, unless for_external is True + base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL) file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}" timestamp = str(int(time.time())) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 13fd579e20..3f57a346cd 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -1,5 +1,6 @@ import contextlib import json +import logging from collections.abc import Generator, Iterable from copy import deepcopy from datetime import UTC, datetime @@ -36,6 +37,8 @@ from extensions.ext_database import db from models.enums import CreatorUserRole from models.model import Message, MessageFile +logger = logging.getLogger(__name__) + class ToolEngine: """ @@ -123,25 +126,31 @@ class ToolEngine: # transform tool invoke message to get LLM friendly message return plain_text, message_files, meta except ToolProviderCredentialValidationError as e: + logger.error(e, exc_info=True) error_response = "Please check your tool provider credentials" agent_tool_callback.on_tool_error(e) except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e: error_response = f"there is not a tool named {tool.entity.identity.name}" + logger.error(e, exc_info=True) agent_tool_callback.on_tool_error(e) except ToolParameterValidationError as e: error_response = f"tool parameters validation error: {e}, please check your tool parameters" 
agent_tool_callback.on_tool_error(e) + logger.error(e, exc_info=True) except ToolInvokeError as e: error_response = f"tool invoke error: {e}" agent_tool_callback.on_tool_error(e) + logger.error(e, exc_info=True) except ToolEngineInvokeError as e: meta = e.meta error_response = f"tool invoke error: {meta.error}" agent_tool_callback.on_tool_error(e) + logger.error(e, exc_info=True) return error_response, [], meta except Exception as e: error_response = f"unknown error: {e}" agent_tool_callback.on_tool_error(e) + logger.error(e, exc_info=True) return error_response, [], ToolInvokeMeta.error_instance(error_response) diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 3486182192..584975de05 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -378,7 +378,7 @@ class ApiBasedToolSchemaParser: @staticmethod def auto_parse_to_tool_bundle( content: str, extra_info: dict | None = None, warning: dict | None = None - ) -> tuple[list[ApiToolBundle], str]: + ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]: """ auto parse to tool bundle diff --git a/api/core/tools/utils/text_processing_utils.py b/api/core/tools/utils/text_processing_utils.py index 0f9a91a111..4bfaa5e49b 100644 --- a/api/core/tools/utils/text_processing_utils.py +++ b/api/core/tools/utils/text_processing_utils.py @@ -4,6 +4,7 @@ import re def remove_leading_symbols(text: str) -> str: """ Remove leading punctuation or symbols from the given text. + Preserves markdown links like [text](url) at the start. Args: text (str): The input text to process. @@ -11,6 +12,11 @@ def remove_leading_symbols(text: str) -> str: Returns: str: The text with leading punctuation or symbols removed. 
""" + # Check if text starts with a markdown link - preserve it + markdown_link_pattern = r"^\[([^\]]+)\]\((https?://[^)]+)\)" + if re.match(markdown_link_pattern, text): + return text + # Match Unicode ranges for punctuation and symbols # FIXME this pattern is confused quick fix for #11868 maybe refactor it later pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+' diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 2bd973f831..a706f101ca 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Mapping from pydantic import Field @@ -47,14 +49,13 @@ class WorkflowToolProviderController(ToolProviderController): self.provider_id = provider_id @classmethod - def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController": + def from_db(cls, db_provider: WorkflowToolProvider) -> WorkflowToolProviderController: with session_factory.create_session() as session, session.begin(): app = session.get(App, db_provider.app_id) if not app: raise ValueError("app not found") user = session.get(Account, db_provider.user_id) if db_provider.user_id else None - controller = WorkflowToolProviderController( entity=ToolProviderEntity( identity=ToolProviderIdentity( @@ -67,7 +68,7 @@ class WorkflowToolProviderController(ToolProviderController): credentials_schema=[], plugin_id=None, ), - provider_id="", + provider_id=db_provider.id, ) controller.tools = [ diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 30334f5da8..9c1ceff145 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -1,12 +1,13 @@ +from __future__ import annotations + import json import logging from collections.abc import Generator, Mapping, Sequence from 
typing import Any, cast -from flask import has_request_context from sqlalchemy import select -from sqlalchemy.orm import Session +from core.db.session_factory import session_factory from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.tools.__base.tool import Tool @@ -18,9 +19,7 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.errors import ToolInvokeError -from extensions.ext_database import db from factories.file_factory import build_from_mapping -from libs.login import current_user from models import Account, Tenant from models.model import App, EndUser from models.workflow import Workflow @@ -181,7 +180,7 @@ class WorkflowTool(Tool): return found return None - def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool": + def fork_tool_runtime(self, runtime: ToolRuntime) -> WorkflowTool: """ fork a new tool with metadata @@ -208,50 +207,38 @@ class WorkflowTool(Tool): Returns: Account | EndUser | None: The resolved user object, or None if resolution fails. """ - if has_request_context(): - return self._resolve_user_from_request() - else: - return self._resolve_user_from_database(user_id=user_id) - - def _resolve_user_from_request(self) -> Account | EndUser | None: - """ - Resolve user from Flask request context. - """ - try: - # Note: `current_user` is a LocalProxy. Never compare it with None directly. - return getattr(current_user, "_get_current_object", lambda: current_user)() - except Exception as e: - logger.warning("Failed to resolve user from request context: %s", e) - return None + return self._resolve_user_from_database(user_id=user_id) def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None: """ Resolve user from database (worker/Celery context). 
""" + with session_factory.create_session() as session: + tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id) + tenant = session.scalar(tenant_stmt) + if not tenant: + return None + + user_stmt = select(Account).where(Account.id == user_id) + user = session.scalar(user_stmt) + if user: + user.current_tenant = tenant + session.expunge(user) + return user + + end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id) + end_user = session.scalar(end_user_stmt) + if end_user: + session.expunge(end_user) + return end_user - tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id) - tenant = db.session.scalar(tenant_stmt) - if not tenant: return None - user_stmt = select(Account).where(Account.id == user_id) - user = db.session.scalar(user_stmt) - if user: - user.current_tenant = tenant - return user - - end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id) - end_user = db.session.scalar(end_user_stmt) - if end_user: - return end_user - - return None - def _get_workflow(self, app_id: str, version: str) -> Workflow: """ get the workflow by app id and version """ - with Session(db.engine, expire_on_commit=False) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): if not version: stmt = ( select(Workflow) @@ -263,22 +250,24 @@ class WorkflowTool(Tool): stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version) workflow = session.scalar(stmt) - if not workflow: - raise ValueError("workflow not found or not published") + if not workflow: + raise ValueError("workflow not found or not published") - return workflow + session.expunge(workflow) + return workflow def _get_app(self, app_id: str) -> App: """ get the app by app id """ stmt = select(App).where(App.id == app_id) - with Session(db.engine, expire_on_commit=False) as session, session.begin(): + with session_factory.create_session() as session, 
session.begin(): app = session.scalar(stmt) - if not app: - raise ValueError("app not found") + if not app: + raise ValueError("app not found") - return app + session.expunge(app) + return app def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]: """ diff --git a/api/core/variables/__init__.py b/api/core/variables/__init__.py index 7a1cbf9940..7498224923 100644 --- a/api/core/variables/__init__.py +++ b/api/core/variables/__init__.py @@ -30,6 +30,7 @@ from .variables import ( SecretVariable, StringVariable, Variable, + VariableBase, ) __all__ = [ @@ -62,4 +63,5 @@ __all__ = [ "StringSegment", "StringVariable", "Variable", + "VariableBase", ] diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 406b4e6f93..8330f1fe19 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None: # - All variants in `SegmentUnion` must inherit from the `Segment` class. # - The union must include all non-abstract subclasses of `Segment`, except: # - `SegmentGroup`, which is not added to the variable pool. -# - `Variable` and its subclasses, which are handled by `VariableUnion`. +# - `VariableBase` and its subclasses, which are handled by `Variable`. 
SegmentUnion: TypeAlias = Annotated[ ( Annotated[NoneSegment, Tag(SegmentType.NONE)] diff --git a/api/core/variables/types.py b/api/core/variables/types.py index ce71711344..13b926c978 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from collections.abc import Mapping from enum import StrEnum -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from core.file.models import File @@ -52,7 +54,7 @@ class SegmentType(StrEnum): return self in _ARRAY_TYPES @classmethod - def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]: + def infer_segment_type(cls, value: Any) -> SegmentType | None: """ Attempt to infer the `SegmentType` based on the Python type of the `value` parameter. @@ -173,7 +175,7 @@ class SegmentType(StrEnum): raise AssertionError("this statement should be unreachable.") @staticmethod - def cast_value(value: Any, type_: "SegmentType"): + def cast_value(value: Any, type_: SegmentType): # Cast Python's `bool` type to `int` when the runtime type requires # an integer or number. # @@ -193,7 +195,7 @@ class SegmentType(StrEnum): return [int(i) for i in value] return value - def exposed_type(self) -> "SegmentType": + def exposed_type(self) -> SegmentType: """Returns the type exposed to the frontend. The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here. @@ -202,7 +204,7 @@ class SegmentType(StrEnum): return SegmentType.NUMBER return self - def element_type(self) -> "SegmentType | None": + def element_type(self) -> SegmentType | None: """Return the element type of the current segment type, or `None` if the element type is undefined. 
Raises: @@ -217,7 +219,7 @@ class SegmentType(StrEnum): return _ARRAY_ELEMENT_TYPES_MAPPING.get(self) @staticmethod - def get_zero_value(t: "SegmentType"): + def get_zero_value(t: SegmentType): # Lazy import to avoid circular dependency from factories import variable_factory diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index 9fd0bbc5b2..a19c53918d 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -27,7 +27,7 @@ from .segments import ( from .types import SegmentType -class Variable(Segment): +class VariableBase(Segment): """ A variable is a segment that has a name. @@ -45,23 +45,23 @@ class Variable(Segment): selector: Sequence[str] = Field(default_factory=list) -class StringVariable(StringSegment, Variable): +class StringVariable(StringSegment, VariableBase): pass -class FloatVariable(FloatSegment, Variable): +class FloatVariable(FloatSegment, VariableBase): pass -class IntegerVariable(IntegerSegment, Variable): +class IntegerVariable(IntegerSegment, VariableBase): pass -class ObjectVariable(ObjectSegment, Variable): +class ObjectVariable(ObjectSegment, VariableBase): pass -class ArrayVariable(ArraySegment, Variable): +class ArrayVariable(ArraySegment, VariableBase): pass @@ -89,16 +89,16 @@ class SecretVariable(StringVariable): return encrypter.obfuscated_token(self.value) -class NoneVariable(NoneSegment, Variable): +class NoneVariable(NoneSegment, VariableBase): value_type: SegmentType = SegmentType.NONE value: None = None -class FileVariable(FileSegment, Variable): +class FileVariable(FileSegment, VariableBase): pass -class BooleanVariable(BooleanSegment, Variable): +class BooleanVariable(BooleanSegment, VariableBase): pass @@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel): value: Any -# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic. -# Use `Variable` for type hinting when serialization is not required. 
+# The `Variable` type is used to enable serialization and deserialization with Pydantic. +# Use `VariableBase` for type hinting when serialization is not required. # # Note: -# - All variants in `VariableUnion` must inherit from the `Variable` class. -# - The union must include all non-abstract subclasses of `Segment`, except: -VariableUnion: TypeAlias = Annotated[ +# - All variants in `Variable` must inherit from the `VariableBase` class. +# - The union must include all non-abstract subclasses of `VariableBase`. +Variable: TypeAlias = Annotated[ ( Annotated[NoneVariable, Tag(SegmentType.NONE)] | Annotated[StringVariable, Tag(SegmentType.STRING)] diff --git a/api/core/workflow/README.md b/api/core/workflow/README.md index 72f5dbe1e2..9a39f976a6 100644 --- a/api/core/workflow/README.md +++ b/api/core/workflow/README.md @@ -64,6 +64,9 @@ engine.layer(DebugLoggingLayer(level="INFO")) engine.layer(ExecutionLimitsLayer(max_nodes=100)) ``` +`engine.layer()` binds the read-only runtime state before execution, so layer hooks +can assume `graph_runtime_state` is available. + ### Event-Driven Architecture All node executions emit events for monitoring and integration: diff --git a/api/core/workflow/context/__init__.py b/api/core/workflow/context/__init__.py new file mode 100644 index 0000000000..1237d6a017 --- /dev/null +++ b/api/core/workflow/context/__init__.py @@ -0,0 +1,34 @@ +""" +Execution Context - Context management for workflow execution. + +This package provides Flask-independent context management for workflow +execution in multi-threaded environments. 
+""" + +from core.workflow.context.execution_context import ( + AppContext, + ContextProviderNotFoundError, + ExecutionContext, + IExecutionContext, + NullAppContext, + capture_current_context, + read_context, + register_context, + register_context_capturer, + reset_context_provider, +) +from core.workflow.context.models import SandboxContext + +__all__ = [ + "AppContext", + "ContextProviderNotFoundError", + "ExecutionContext", + "IExecutionContext", + "NullAppContext", + "SandboxContext", + "capture_current_context", + "read_context", + "register_context", + "register_context_capturer", + "reset_context_provider", +] diff --git a/api/core/workflow/context/execution_context.py b/api/core/workflow/context/execution_context.py new file mode 100644 index 0000000000..e3007530f0 --- /dev/null +++ b/api/core/workflow/context/execution_context.py @@ -0,0 +1,284 @@ +""" +Execution Context - Abstracted context management for workflow execution. +""" + +import contextvars +import threading +from abc import ABC, abstractmethod +from collections.abc import Callable, Generator +from contextlib import AbstractContextManager, contextmanager +from typing import Any, Protocol, TypeVar, final, runtime_checkable + +from pydantic import BaseModel + + +class AppContext(ABC): + """ + Abstract application context interface. + + This abstraction allows workflow execution to work with or without Flask + by providing a common interface for application context management. + """ + + @abstractmethod + def get_config(self, key: str, default: Any = None) -> Any: + """Get configuration value by key.""" + pass + + @abstractmethod + def get_extension(self, name: str) -> Any: + """Get Flask extension by name (e.g., 'db', 'cache').""" + pass + + @abstractmethod + def enter(self) -> AbstractContextManager[None]: + """Enter the application context.""" + pass + + +@runtime_checkable +class IExecutionContext(Protocol): + """ + Protocol for execution context. 
+ + This protocol defines the interface that all execution contexts must implement, + allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably. + """ + + def __enter__(self) -> "IExecutionContext": + """Enter the execution context.""" + ... + + def __exit__(self, *args: Any) -> None: + """Exit the execution context.""" + ... + + @property + def user(self) -> Any: + """Get user object.""" + ... + + +@final +class ExecutionContext: + """ + Execution context for workflow execution in worker threads. + + This class encapsulates all context needed for workflow execution: + - Application context (Flask app or standalone) + - Context variables for Python contextvars + - User information (optional) + + It is designed to be serializable and passable to worker threads. + """ + + def __init__( + self, + app_context: AppContext | None = None, + context_vars: contextvars.Context | None = None, + user: Any = None, + ) -> None: + """ + Initialize execution context. + + Args: + app_context: Application context (Flask or standalone) + context_vars: Python contextvars to preserve + user: User object (optional) + """ + self._app_context = app_context + self._context_vars = context_vars + self._user = user + self._local = threading.local() + + @property + def app_context(self) -> AppContext | None: + """Get application context.""" + return self._app_context + + @property + def context_vars(self) -> contextvars.Context | None: + """Get context variables.""" + return self._context_vars + + @property + def user(self) -> Any: + """Get user object.""" + return self._user + + @contextmanager + def enter(self) -> Generator[None, None, None]: + """ + Enter this execution context. + + This is a convenience method that creates a context manager. 
+ """ + # Restore context variables if provided + if self._context_vars: + for var, val in self._context_vars.items(): + var.set(val) + + # Enter app context if available + if self._app_context is not None: + with self._app_context.enter(): + yield + else: + yield + + def __enter__(self) -> "ExecutionContext": + """Enter the execution context.""" + cm = self.enter() + self._local.cm = cm + cm.__enter__() + return self + + def __exit__(self, *args: Any) -> None: + """Exit the execution context.""" + cm = getattr(self._local, "cm", None) + if cm is not None: + cm.__exit__(*args) + + +class NullAppContext(AppContext): + """ + Null implementation of AppContext for non-Flask environments. + + This is used when running without Flask (e.g., in tests or standalone mode). + """ + + def __init__(self, config: dict[str, Any] | None = None) -> None: + """ + Initialize null app context. + + Args: + config: Optional configuration dictionary + """ + self._config = config or {} + self._extensions: dict[str, Any] = {} + + def get_config(self, key: str, default: Any = None) -> Any: + """Get configuration value by key.""" + return self._config.get(key, default) + + def get_extension(self, name: str) -> Any: + """Get extension by name.""" + return self._extensions.get(name) + + def set_extension(self, name: str, extension: Any) -> None: + """Set extension by name.""" + self._extensions[name] = extension + + @contextmanager + def enter(self) -> Generator[None, None, None]: + """Enter null context (no-op).""" + yield + + +class ExecutionContextBuilder: + """ + Builder for creating ExecutionContext instances. + + This provides a fluent API for building execution contexts. 
+ """ + + def __init__(self) -> None: + self._app_context: AppContext | None = None + self._context_vars: contextvars.Context | None = None + self._user: Any = None + + def with_app_context(self, app_context: AppContext) -> "ExecutionContextBuilder": + """Set application context.""" + self._app_context = app_context + return self + + def with_context_vars(self, context_vars: contextvars.Context) -> "ExecutionContextBuilder": + """Set context variables.""" + self._context_vars = context_vars + return self + + def with_user(self, user: Any) -> "ExecutionContextBuilder": + """Set user.""" + self._user = user + return self + + def build(self) -> ExecutionContext: + """Build the execution context.""" + return ExecutionContext( + app_context=self._app_context, + context_vars=self._context_vars, + user=self._user, + ) + + +_capturer: Callable[[], IExecutionContext] | None = None + +# Tenant-scoped providers using tuple keys for clarity and constant-time lookup. +# Key mapping: +# (name, tenant_id) -> provider +# - name: namespaced identifier (recommend prefixing, e.g. "workflow.sandbox") +# - tenant_id: tenant identifier string +# Value: +# provider: Callable[[], BaseModel] returning the typed context value +# Type-safety note: +# - This registry cannot enforce that all providers for a given name return the same BaseModel type. +# - Implementors SHOULD provide typed wrappers around register/read (like Go's context best practice), +# e.g. def register_sandbox_ctx(tenant_id: str, p: Callable[[], SandboxContext]) and +# def read_sandbox_ctx(tenant_id: str) -> SandboxContext. 
+_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {} + +T = TypeVar("T", bound=BaseModel) + + +class ContextProviderNotFoundError(KeyError): + """Raised when a tenant-scoped context provider is missing for a given (name, tenant_id).""" + + pass + + +def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None: + """Register a single enterable execution context capturer (e.g., Flask).""" + global _capturer + _capturer = capturer + + +def register_context(name: str, tenant_id: str, provider: Callable[[], BaseModel]) -> None: + """Register a tenant-specific provider for a named context. + + Tip: use a namespaced "name" (e.g., "workflow.sandbox") to avoid key collisions. + Consider adding a typed wrapper for this registration in your feature module. + """ + _tenant_context_providers[(name, tenant_id)] = provider + + +def read_context(name: str, *, tenant_id: str) -> BaseModel: + """ + Read a context value for a specific tenant. + + Raises KeyError if the provider for (name, tenant_id) is not registered. + """ + prov = _tenant_context_providers.get((name, tenant_id)) + if prov is None: + raise ContextProviderNotFoundError(f"Context provider '{name}' not registered for tenant '{tenant_id}'") + return prov() + + +def capture_current_context() -> IExecutionContext: + """ + Capture current execution context from the calling environment. + + If a capturer is registered (e.g., Flask), use it. Otherwise, return a minimal + context with NullAppContext + copy of current contextvars. 
+ """ + if _capturer is None: + return ExecutionContext( + app_context=NullAppContext(), + context_vars=contextvars.copy_context(), + ) + return _capturer() + + +def reset_context_provider() -> None: + """Reset the capturer and all tenant-scoped context providers (primarily for tests).""" + global _capturer + _capturer = None + _tenant_context_providers.clear() diff --git a/api/core/workflow/context/models.py b/api/core/workflow/context/models.py new file mode 100644 index 0000000000..af5a4b2614 --- /dev/null +++ b/api/core/workflow/context/models.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from pydantic import AnyHttpUrl, BaseModel + + +class SandboxContext(BaseModel): + """Typed context for sandbox integration. All fields optional by design.""" + + sandbox_url: AnyHttpUrl | None = None + sandbox_token: str | None = None # optional, if later needed for auth + + +__all__ = ["SandboxContext"] diff --git a/api/core/workflow/conversation_variable_updater.py b/api/core/workflow/conversation_variable_updater.py index fd78248c17..75f47691da 100644 --- a/api/core/workflow/conversation_variable_updater.py +++ b/api/core/workflow/conversation_variable_updater.py @@ -1,7 +1,7 @@ import abc from typing import Protocol -from core.variables import Variable +from core.variables import VariableBase class ConversationVariableUpdater(Protocol): @@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol): """ @abc.abstractmethod - def update(self, conversation_id: str, variable: "Variable"): + def update(self, conversation_id: str, variable: "VariableBase"): """ Updates the value of the specified conversation variable in the underlying storage. :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`. - :param variable: The `Variable` instance containing the updated value. + :param variable: The `VariableBase` instance containing the updated value. 
""" pass diff --git a/api/core/workflow/entities/workflow_execution.py b/api/core/workflow/entities/workflow_execution.py index a8a86d3db2..1b3fb36f1f 100644 --- a/api/core/workflow/entities/workflow_execution.py +++ b/api/core/workflow/entities/workflow_execution.py @@ -5,6 +5,8 @@ Models are independent of the storage mechanism and don't contain implementation details like tenant_id, app_id, etc. """ +from __future__ import annotations + from collections.abc import Mapping from datetime import datetime from typing import Any @@ -59,7 +61,7 @@ class WorkflowExecution(BaseModel): graph: Mapping[str, Any], inputs: Mapping[str, Any], started_at: datetime, - ) -> "WorkflowExecution": + ) -> WorkflowExecution: return WorkflowExecution( id_=id_, workflow_id=workflow_id, diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index c08b62a253..bb3b13e8c6 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -211,6 +211,10 @@ class WorkflowExecutionStatus(StrEnum): def is_ended(self) -> bool: return self in _END_STATE + @classmethod + def ended_values(cls) -> list[str]: + return [status.value for status in _END_STATE] + _END_STATE = frozenset( [ diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py index ba5a01fc94..31bf6f3b27 100644 --- a/api/core/workflow/graph/graph.py +++ b/api/core/workflow/graph/graph.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from collections import defaultdict from collections.abc import Mapping, Sequence @@ -175,7 +177,7 @@ class Graph: def _create_node_instances( cls, node_configs_map: dict[str, dict[str, object]], - node_factory: "NodeFactory", + node_factory: NodeFactory, ) -> dict[str, Node]: """ Create node instances from configurations using the node factory. 
@@ -197,7 +199,7 @@ class Graph: return nodes @classmethod - def new(cls) -> "GraphBuilder": + def new(cls) -> GraphBuilder: """Create a fluent builder for assembling a graph programmatically.""" return GraphBuilder(graph_cls=cls) @@ -284,9 +286,10 @@ class Graph: cls, *, graph_config: Mapping[str, object], - node_factory: "NodeFactory", + node_factory: NodeFactory, root_node_id: str | None = None, - ) -> "Graph": + skip_validation: bool = False, + ) -> Graph: """ Initialize graph @@ -337,8 +340,9 @@ class Graph: root_node=root_node, ) - # Validate the graph structure using built-in validators - get_graph_validator().validate(graph) + if not skip_validation: + # Validate the graph structure using built-in validators + get_graph_validator().validate(graph) return graph @@ -383,7 +387,7 @@ class GraphBuilder: self._edges: list[Edge] = [] self._edge_counter = 0 - def add_root(self, node: Node) -> "GraphBuilder": + def add_root(self, node: Node) -> GraphBuilder: """Register the root node. Must be called exactly once.""" if self._nodes: @@ -398,7 +402,7 @@ class GraphBuilder: *, from_node_id: str | None = None, source_handle: str = "source", - ) -> "GraphBuilder": + ) -> GraphBuilder: """Append a node and connect it from the specified predecessor.""" if not self._nodes: @@ -419,7 +423,7 @@ class GraphBuilder: return self - def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder": + def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder: """Connect two existing nodes without adding a new node.""" if tail not in self._nodes_by_id: diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/core/workflow/graph_engine/command_channels/redis_channel.py index 4be3adb8f8..0fccd4a0fd 100644 --- a/api/core/workflow/graph_engine/command_channels/redis_channel.py +++ b/api/core/workflow/graph_engine/command_channels/redis_channel.py @@ -9,7 +9,7 @@ Each instance uses a unique key for its 
command queue. import json from typing import TYPE_CHECKING, Any, final -from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand +from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand if TYPE_CHECKING: from extensions.ext_redis import RedisClientWrapper @@ -113,6 +113,8 @@ class RedisChannel: return AbortCommand.model_validate(data) if command_type == CommandType.PAUSE: return PauseCommand.model_validate(data) + if command_type == CommandType.UPDATE_VARIABLES: + return UpdateVariablesCommand.model_validate(data) # For other command types, use base class return GraphEngineCommand.model_validate(data) diff --git a/api/core/workflow/graph_engine/command_processing/__init__.py b/api/core/workflow/graph_engine/command_processing/__init__.py index 837f5e55fd..7b4f0dfff7 100644 --- a/api/core/workflow/graph_engine/command_processing/__init__.py +++ b/api/core/workflow/graph_engine/command_processing/__init__.py @@ -5,11 +5,12 @@ This package handles external commands sent to the engine during execution. 
""" -from .command_handlers import AbortCommandHandler, PauseCommandHandler +from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler from .command_processor import CommandProcessor __all__ = [ "AbortCommandHandler", "CommandProcessor", "PauseCommandHandler", + "UpdateVariablesCommandHandler", ] diff --git a/api/core/workflow/graph_engine/command_processing/command_handlers.py b/api/core/workflow/graph_engine/command_processing/command_handlers.py index e9f109c88c..cfe856d9e8 100644 --- a/api/core/workflow/graph_engine/command_processing/command_handlers.py +++ b/api/core/workflow/graph_engine/command_processing/command_handlers.py @@ -4,9 +4,10 @@ from typing import final from typing_extensions import override from core.workflow.entities.pause_reason import SchedulingPause +from core.workflow.runtime import VariablePool from ..domain.graph_execution import GraphExecution -from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand +from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand from .command_processor import CommandHandler logger = logging.getLogger(__name__) @@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler): reason = command.reason pause_reason = SchedulingPause(message=reason) execution.pause(pause_reason) + + +@final +class UpdateVariablesCommandHandler(CommandHandler): + def __init__(self, variable_pool: VariablePool) -> None: + self._variable_pool = variable_pool + + @override + def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: + assert isinstance(command, UpdateVariablesCommand) + for update in command.updates: + try: + variable = update.value + self._variable_pool.add(variable.selector, variable) + logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id) + except ValueError as exc: + logger.warning( + "Skipping invalid variable selector %s for workflow %s: %s", 
+ getattr(update.value, "selector", None), + execution.workflow_id, + exc, + ) diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py index 0d51b2b716..41276eb444 100644 --- a/api/core/workflow/graph_engine/entities/commands.py +++ b/api/core/workflow/graph_engine/entities/commands.py @@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine instance to control its execution flow. """ -from enum import StrEnum +from collections.abc import Sequence +from enum import StrEnum, auto from typing import Any from pydantic import BaseModel, Field +from core.variables.variables import Variable + class CommandType(StrEnum): """Types of commands that can be sent to GraphEngine.""" - ABORT = "abort" - PAUSE = "pause" + ABORT = auto() + PAUSE = auto() + UPDATE_VARIABLES = auto() class GraphEngineCommand(BaseModel): @@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand): command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command") reason: str = Field(default="unknown reason", description="reason for pause") + + +class VariableUpdate(BaseModel): + """Represents a single variable update instruction.""" + + value: Variable = Field(description="New variable value") + + +class UpdateVariablesCommand(GraphEngineCommand): + """Command to update a group of variables in the variable pool.""" + + command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command") + updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates") diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 744013cb04..187dfcf021 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -5,14 +5,15 @@ This engine uses a modular architecture with separated packages following Domain-Driven Design 
principles for improved maintainability and testability. """ -import contextvars +from __future__ import annotations + import logging import queue +import threading from collections.abc import Generator from typing import TYPE_CHECKING, cast, final -from flask import Flask, current_app - +from core.workflow.context import capture_current_context from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import NodeExecutionType from core.workflow.graph import Graph @@ -31,8 +32,13 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr if TYPE_CHECKING: # pragma: no cover - used only for static analysis from core.workflow.runtime.graph_runtime_state import GraphProtocol -from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler -from .entities.commands import AbortCommand, PauseCommand +from .command_processing import ( + AbortCommandHandler, + CommandProcessor, + PauseCommandHandler, + UpdateVariablesCommandHandler, +) +from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand from .error_handler import ErrorHandler from .event_management import EventHandler, EventManager from .graph_state_manager import GraphStateManager @@ -71,10 +77,13 @@ class GraphEngine: scale_down_idle_time: float | None = None, ) -> None: """Initialize the graph engine with all subsystems and dependencies.""" + # stop event + self._stop_event = threading.Event() # Bind runtime state to current workflow context self._graph = graph self._graph_runtime_state = graph_runtime_state + self._graph_runtime_state.stop_event = self._stop_event self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) self._command_channel = command_channel @@ -141,22 +150,16 @@ class GraphEngine: pause_handler = PauseCommandHandler() self._command_processor.register_handler(PauseCommand, pause_handler) + update_variables_handler = 
UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool) + self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler) + # === Extensibility === # Layers allow plugins to extend engine functionality self._layers: list[GraphEngineLayer] = [] # === Worker Pool Setup === - # Capture Flask app context for worker threads - flask_app: Flask | None = None - try: - app = current_app._get_current_object() # type: ignore - if isinstance(app, Flask): - flask_app = app - except RuntimeError: - pass - - # Capture context variables for worker threads - context_vars = contextvars.copy_context() + # Capture execution context for worker threads + execution_context = capture_current_context() # Create worker pool for parallel node execution self._worker_pool = WorkerPool( @@ -164,12 +167,12 @@ class GraphEngine: event_queue=self._event_queue, graph=self._graph, layers=self._layers, - flask_app=flask_app, - context_vars=context_vars, + execution_context=execution_context, min_workers=self._min_workers, max_workers=self._max_workers, scale_up_threshold=self._scale_up_threshold, scale_down_idle_time=self._scale_down_idle_time, + stop_event=self._stop_event, ) # === Orchestration === @@ -200,6 +203,7 @@ class GraphEngine: event_handler=self._event_handler_registry, execution_coordinator=self._execution_coordinator, event_emitter=self._event_manager, + stop_event=self._stop_event, ) # === Validation === @@ -213,9 +217,16 @@ class GraphEngine: if id(node.graph_runtime_state) != expected_state_id: raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance") - def layer(self, layer: GraphEngineLayer) -> "GraphEngine": + def _bind_layer_context( + self, + layer: GraphEngineLayer, + ) -> None: + layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel) + + def layer(self, layer: GraphEngineLayer) -> GraphEngine: """Add a layer for extending functionality.""" 
self._layers.append(layer) + self._bind_layer_context(layer) return self def run(self) -> Generator[GraphEngineEvent, None, None]: @@ -304,14 +315,7 @@ class GraphEngine: def _initialize_layers(self) -> None: """Initialize layers with context.""" self._event_manager.set_layers(self._layers) - # Create a read-only wrapper for the runtime state - read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state) for layer in self._layers: - try: - layer.initialize(read_only_state, self._command_channel) - except Exception: - logger.exception("Failed to initialize layer %s", layer.__class__.__name__) - try: layer.on_graph_start() except Exception: @@ -319,6 +323,7 @@ class GraphEngine: def _start_execution(self, *, resume: bool = False) -> None: """Start execution subsystems.""" + self._stop_event.clear() paused_nodes: list[str] = [] deferred_nodes: list[str] = [] if resume: @@ -352,13 +357,12 @@ class GraphEngine: def _stop_execution(self) -> None: """Stop execution subsystems.""" + self._stop_event.set() self._dispatcher.stop() self._worker_pool.stop() # Don't mark complete here as the dispatcher already does it # Notify layers - logger = logging.getLogger(__name__) - for layer in self._layers: try: layer.on_graph_end(self._graph_execution.error) diff --git a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py b/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py index 78f8ecdcdf..b9c9243963 100644 --- a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py +++ b/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py @@ -60,6 +60,7 @@ class SkipPropagator: if edge_states["has_taken"]: # Enqueue node self._state_manager.enqueue_node(downstream_node_id) + self._state_manager.start_execution(downstream_node_id) return # All edges are skipped, propagate skip to this node diff --git a/api/core/workflow/graph_engine/layers/README.md b/api/core/workflow/graph_engine/layers/README.md index 17845ee1f0..b0f295037c 
100644 --- a/api/core/workflow/graph_engine/layers/README.md +++ b/api/core/workflow/graph_engine/layers/README.md @@ -8,7 +8,7 @@ Pluggable middleware for engine extensions. Abstract base class for layers. -- `initialize()` - Receive runtime context +- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks) - `on_graph_start()` - Execution start hook - `on_event()` - Process all events - `on_graph_end()` - Execution end hook @@ -34,6 +34,9 @@ engine.layer(debug_layer) engine.run() ``` +`engine.layer()` binds the read-only runtime state before execution, so +`graph_runtime_state` is always available inside layer hooks. + ## Custom Layers ```python diff --git a/api/core/workflow/graph_engine/layers/__init__.py b/api/core/workflow/graph_engine/layers/__init__.py index 772433e48c..0a29a52993 100644 --- a/api/core/workflow/graph_engine/layers/__init__.py +++ b/api/core/workflow/graph_engine/layers/__init__.py @@ -8,11 +8,9 @@ with middleware-like components that can observe events and interact with execut from .base import GraphEngineLayer from .debug_logging import DebugLoggingLayer from .execution_limits import ExecutionLimitsLayer -from .observability import ObservabilityLayer __all__ = [ "DebugLoggingLayer", "ExecutionLimitsLayer", "GraphEngineLayer", - "ObservabilityLayer", ] diff --git a/api/core/workflow/graph_engine/layers/base.py b/api/core/workflow/graph_engine/layers/base.py index 780f92a0f4..ff4a483aed 100644 --- a/api/core/workflow/graph_engine/layers/base.py +++ b/api/core/workflow/graph_engine/layers/base.py @@ -8,11 +8,19 @@ intercept and respond to GraphEngine events. 
from abc import ABC, abstractmethod from core.workflow.graph_engine.protocols.command_channel import CommandChannel -from core.workflow.graph_events import GraphEngineEvent +from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase from core.workflow.nodes.base.node import Node from core.workflow.runtime import ReadOnlyGraphRuntimeState +class GraphEngineLayerNotInitializedError(Exception): + """Raised when a layer's runtime state is accessed before initialization.""" + + def __init__(self, layer_name: str | None = None) -> None: + name = layer_name or "GraphEngineLayer" + super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.") + + class GraphEngineLayer(ABC): """ Abstract base class for GraphEngine layers. @@ -28,22 +36,27 @@ class GraphEngineLayer(ABC): def __init__(self) -> None: """Initialize the layer. Subclasses can override with custom parameters.""" - self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None + self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None self.command_channel: CommandChannel | None = None + @property + def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState: + if self._graph_runtime_state is None: + raise GraphEngineLayerNotInitializedError(type(self).__name__) + return self._graph_runtime_state + def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None: """ Initialize the layer with engine dependencies. - Called by GraphEngine before execution starts to inject the read-only runtime state - and command channel. This allows layers to observe engine context and send - commands, but prevents direct state modification. - + Called by GraphEngine to inject the read-only runtime state and command channel. + This is invoked when the layer is registered with a `GraphEngine` instance. + Implementations should be idempotent. 
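The guard-property pattern the patch introduces for `graph_runtime_state` (store the injected dependency privately, raise a dedicated error on premature access) can be sketched in isolation. The names below are illustrative stand-ins, not the actual Dify classes:

```python
class NotInitializedError(Exception):
    """Raised when a dependency is read before it has been injected."""

    def __init__(self, owner_name: str) -> None:
        super().__init__(f"{owner_name} runtime state is not initialized.")


class Layer:
    """Minimal stand-in showing the guard-property pattern from the diff."""

    def __init__(self) -> None:
        self._state: dict | None = None  # injected later by the engine

    @property
    def state(self) -> dict:
        # Fail loudly with a specific error instead of handing None to callers.
        if self._state is None:
            raise NotInitializedError(type(self).__name__)
        return self._state

    def initialize(self, state: dict) -> None:
        # Idempotent: re-binding simply overwrites the previous reference.
        self._state = state
```

The payoff is that hook code can use `self.state` unconditionally; the `if self.graph_runtime_state:` guards removed from `DebugLoggingLayer` later in this diff follow directly from this change.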
Args: graph_runtime_state: Read-only view of the runtime state command_channel: Channel for sending commands to the engine """ - self.graph_runtime_state = graph_runtime_state + self._graph_runtime_state = graph_runtime_state self.command_channel = command_channel @abstractmethod @@ -85,7 +98,7 @@ class GraphEngineLayer(ABC): """ pass - def on_node_run_start(self, node: Node) -> None: # noqa: B027 + def on_node_run_start(self, node: Node) -> None: """ Called immediately before a node begins execution. @@ -96,9 +109,11 @@ class GraphEngineLayer(ABC): Args: node: The node instance about to be executed """ - pass + return - def on_node_run_end(self, node: Node, error: Exception | None) -> None: # noqa: B027 + def on_node_run_end( + self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None + ) -> None: """ Called after a node finishes execution. @@ -108,5 +123,6 @@ class GraphEngineLayer(ABC): Args: node: The node instance that just finished execution error: Exception instance if the node failed, otherwise None + result_event: The final result event from node execution (succeeded/failed/paused), if any """ - pass + return diff --git a/api/core/workflow/graph_engine/layers/debug_logging.py b/api/core/workflow/graph_engine/layers/debug_logging.py index 034ebcf54f..e0402cd09c 100644 --- a/api/core/workflow/graph_engine/layers/debug_logging.py +++ b/api/core/workflow/graph_engine/layers/debug_logging.py @@ -109,10 +109,8 @@ class DebugLoggingLayer(GraphEngineLayer): self.logger.info("=" * 80) self.logger.info("🚀 GRAPH EXECUTION STARTED") self.logger.info("=" * 80) - - if self.graph_runtime_state: - # Log initial state - self.logger.info("Initial State:") + # Log initial state + self.logger.info("Initial State:") @override def on_event(self, event: GraphEngineEvent) -> None: @@ -243,8 +241,7 @@ class DebugLoggingLayer(GraphEngineLayer): self.logger.info(" Node retries: %s", self.retry_count) # Log final state if available - if 
self.graph_runtime_state and self.include_outputs: - if self.graph_runtime_state.outputs: - self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs)) + if self.include_outputs and self.graph_runtime_state.outputs: + self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs)) self.logger.info("=" * 80) diff --git a/api/core/workflow/graph_engine/layers/node_parsers.py b/api/core/workflow/graph_engine/layers/node_parsers.py deleted file mode 100644 index b6bac794df..0000000000 --- a/api/core/workflow/graph_engine/layers/node_parsers.py +++ /dev/null @@ -1,61 +0,0 @@ -""" -Node-level OpenTelemetry parser interfaces and defaults. -""" - -import json -from typing import Protocol - -from opentelemetry.trace import Span -from opentelemetry.trace.status import Status, StatusCode - -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.tool.entities import ToolNodeData - - -class NodeOTelParser(Protocol): - """Parser interface for node-specific OpenTelemetry enrichment.""" - - def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: ... 
- - -class DefaultNodeOTelParser: - """Fallback parser used when no node-specific parser is registered.""" - - def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: - span.set_attribute("node.id", node.id) - if node.execution_id: - span.set_attribute("node.execution_id", node.execution_id) - if hasattr(node, "node_type") and node.node_type: - span.set_attribute("node.type", node.node_type.value) - - if error: - span.record_exception(error) - span.set_status(Status(StatusCode.ERROR, str(error))) - else: - span.set_status(Status(StatusCode.OK)) - - -class ToolNodeOTelParser: - """Parser for tool nodes that captures tool-specific metadata.""" - - def __init__(self) -> None: - self._delegate = DefaultNodeOTelParser() - - def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: - self._delegate.parse(node=node, span=span, error=error) - - tool_data = getattr(node, "_node_data", None) - if not isinstance(tool_data, ToolNodeData): - return - - span.set_attribute("tool.provider.id", tool_data.provider_id) - span.set_attribute("tool.provider.type", tool_data.provider_type.value) - span.set_attribute("tool.provider.name", tool_data.provider_name) - span.set_attribute("tool.name", tool_data.tool_name) - span.set_attribute("tool.label", tool_data.tool_label) - if tool_data.plugin_unique_identifier: - span.set_attribute("tool.plugin.id", tool_data.plugin_unique_identifier) - if tool_data.credential_id: - span.set_attribute("tool.credential.id", tool_data.credential_id) - if tool_data.tool_configurations: - span.set_attribute("tool.config", json.dumps(tool_data.tool_configurations, ensure_ascii=False)) diff --git a/api/core/workflow/graph_engine/manager.py b/api/core/workflow/graph_engine/manager.py index 0577ba8f02..d2cfa755d9 100644 --- a/api/core/workflow/graph_engine/manager.py +++ b/api/core/workflow/graph_engine/manager.py @@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel. 
This module provides a simplified interface for controlling workflow executions using the new Redis command channel, without requiring user permission checks. -Supports stop, pause, and resume operations. """ import logging +from collections.abc import Sequence from typing import final from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand +from core.workflow.graph_engine.entities.commands import ( + AbortCommand, + GraphEngineCommand, + PauseCommand, + UpdateVariablesCommand, + VariableUpdate, +) from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) @@ -23,7 +29,6 @@ class GraphEngineManager: This class provides a simple interface for controlling workflow executions by sending commands through Redis channels, without user validation. - Supports stop and pause operations. """ @staticmethod @@ -45,6 +50,16 @@ class GraphEngineManager: pause_command = PauseCommand(reason=reason or "User requested pause") GraphEngineManager._send_command(task_id, pause_command) + @staticmethod + def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None: + """Send a command to update variables in a running workflow.""" + + if not updates: + return + + update_command = UpdateVariablesCommand(updates=updates) + GraphEngineManager._send_command(task_id, update_command) + @staticmethod def _send_command(task_id: str, command: GraphEngineCommand) -> None: """Send a command to the workflow-specific Redis channel.""" diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py index 685347fbce..d40d15c545 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py @@ -44,6 +44,7 @@ class Dispatcher: event_queue: queue.Queue[GraphNodeEventBase], 
event_handler: "EventHandler", execution_coordinator: ExecutionCoordinator, + stop_event: threading.Event, event_emitter: EventManager | None = None, ) -> None: """ @@ -61,7 +62,7 @@ class Dispatcher: self._event_emitter = event_emitter self._thread: threading.Thread | None = None - self._stop_event = threading.Event() + self._stop_event = stop_event self._start_time: float | None = None def start(self) -> None: @@ -69,16 +70,14 @@ class Dispatcher: if self._thread and self._thread.is_alive(): return - self._stop_event.clear() self._start_time = time.time() self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True) self._thread.start() def stop(self) -> None: """Stop the dispatcher thread.""" - self._stop_event.set() if self._thread and self._thread.is_alive(): - self._thread.join(timeout=10.0) + self._thread.join(timeout=2.0) def _dispatcher_loop(self) -> None: """Main dispatcher loop.""" diff --git a/api/core/workflow/graph_engine/ready_queue/factory.py b/api/core/workflow/graph_engine/ready_queue/factory.py index 1144e1de69..a9d4f470e5 100644 --- a/api/core/workflow/graph_engine/ready_queue/factory.py +++ b/api/core/workflow/graph_engine/ready_queue/factory.py @@ -2,6 +2,8 @@ Factory for creating ReadyQueue instances from serialized state. """ +from __future__ import annotations + from typing import TYPE_CHECKING from .in_memory import InMemoryReadyQueue @@ -11,7 +13,7 @@ if TYPE_CHECKING: from .protocol import ReadyQueue -def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue": +def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue: """ Create a ReadyQueue instance from a serialized state. 
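Several files in this diff add `from __future__ import annotations` and drop the quotes from forward-referenced return types (`"ReadyQueue"` becomes `ReadyQueue`). A minimal sketch of why that works, using a hypothetical class rather than the actual factory:

```python
from __future__ import annotations  # PEP 563: annotations are stored as strings, evaluated lazily


class Session:
    @classmethod
    def create(cls) -> Session:  # no quotes needed even though Session is mid-definition here
        return cls()
```

With postponed evaluation the annotation is never executed at class-body time, so self-references and `TYPE_CHECKING`-only imports no longer need string quoting.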
diff --git a/api/core/workflow/graph_engine/response_coordinator/session.py b/api/core/workflow/graph_engine/response_coordinator/session.py index 8b7c2e441e..8ceaa428c3 100644 --- a/api/core/workflow/graph_engine/response_coordinator/session.py +++ b/api/core/workflow/graph_engine/response_coordinator/session.py @@ -5,6 +5,8 @@ This module contains the private ResponseSession class used internally by ResponseStreamCoordinator to manage streaming sessions. """ +from __future__ import annotations + from dataclasses import dataclass from core.workflow.nodes.answer.answer_node import AnswerNode @@ -27,7 +29,7 @@ class ResponseSession: index: int = 0 # Current position in the template segments @classmethod - def from_node(cls, node: Node) -> "ResponseSession": + def from_node(cls, node: Node) -> ResponseSession: """ Create a ResponseSession from an AnswerNode or EndNode. diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py index e37a08ae47..512df6ff86 100644 --- a/api/core/workflow/graph_engine/worker.py +++ b/api/core/workflow/graph_engine/worker.py @@ -5,26 +5,26 @@ Workers pull node IDs from the ready_queue, execute nodes, and push events to the event_queue for the dispatcher to process. 
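The worker changes below replace per-worker stop flags with a single `threading.Event` owned by the engine and handed to the dispatcher, the worker pool, and each worker. The coordination idea can be sketched as follows (a simplified model, not the actual `Worker` class):

```python
import threading
import time


def run_workers(n: int) -> list[str]:
    stop_event = threading.Event()  # one shared event, owned by the "engine"
    done: list[str] = []
    lock = threading.Lock()

    def worker(worker_id: int) -> None:
        # Each worker polls the shared event instead of a private flag.
        while not stop_event.is_set():
            time.sleep(0.01)  # simulate pulling work from the ready queue
        with lock:
            done.append(f"worker-{worker_id}")

    threads = [threading.Thread(target=worker, args=(i,), daemon=True) for i in range(n)]
    for t in threads:
        t.start()
    stop_event.set()  # engine-level stop: every worker observes it at once
    for t in threads:
        t.join(timeout=2.0)
    return sorted(done)
```

Because the event is set once at the engine level, shutdown no longer needs to iterate over workers calling `stop()`, which is why `Worker.stop()` becomes a documented no-op in this patch.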
""" -import contextvars import queue import threading import time from collections.abc import Sequence from datetime import datetime -from typing import final -from uuid import uuid4 +from typing import TYPE_CHECKING, final -from flask import Flask from typing_extensions import override +from core.workflow.context import IExecutionContext from core.workflow.graph import Graph from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent +from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event from core.workflow.nodes.base.node import Node -from libs.flask_utils import preserve_flask_contexts from .ready_queue import ReadyQueue +if TYPE_CHECKING: + pass + @final class Worker(threading.Thread): @@ -42,9 +42,9 @@ class Worker(threading.Thread): event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, layers: Sequence[GraphEngineLayer], + stop_event: threading.Event, worker_id: int = 0, - flask_app: Flask | None = None, - context_vars: contextvars.Context | None = None, + execution_context: IExecutionContext | None = None, ) -> None: """ Initialize worker thread. 
@@ -55,23 +55,24 @@ class Worker(threading.Thread): graph: Graph containing nodes to execute layers: Graph engine layers for node execution hooks worker_id: Unique identifier for this worker - flask_app: Optional Flask application for context preservation - context_vars: Optional context variables to preserve in worker thread + execution_context: Optional execution context for context preservation """ super().__init__(name=f"GraphWorker-{worker_id}", daemon=True) self._ready_queue = ready_queue self._event_queue = event_queue self._graph = graph self._worker_id = worker_id - self._flask_app = flask_app - self._context_vars = context_vars - self._stop_event = threading.Event() - self._last_task_time = time.time() + self._execution_context = execution_context + self._stop_event = stop_event self._layers = layers if layers is not None else [] + self._last_task_time = time.time() def stop(self) -> None: - """Signal the worker to stop processing.""" - self._stop_event.set() + """Worker is controlled via shared stop_event from GraphEngine. + + This method is a no-op retained for backward compatibility. 
+ """ + pass @property def is_idle(self) -> bool: @@ -111,7 +112,7 @@ class Worker(threading.Thread): self._ready_queue.task_done() except Exception as e: error_event = NodeRunFailedEvent( - id=str(uuid4()), + id=node.execution_id, node_id=node.id, node_type=node.node_type, in_iteration_id=None, @@ -130,33 +131,36 @@ class Worker(threading.Thread): node.ensure_execution_id() error: Exception | None = None + result_event: GraphNodeEventBase | None = None - if self._flask_app and self._context_vars: - with preserve_flask_contexts( - flask_app=self._flask_app, - context_vars=self._context_vars, - ): + # Execute the node with preserved context if execution context is provided + if self._execution_context is not None: + with self._execution_context: self._invoke_node_run_start_hooks(node) try: node_events = node.run() for event in node_events: self._event_queue.put(event) + if is_node_result_event(event): + result_event = event except Exception as exc: error = exc raise finally: - self._invoke_node_run_end_hooks(node, error) + self._invoke_node_run_end_hooks(node, error, result_event) else: self._invoke_node_run_start_hooks(node) try: node_events = node.run() for event in node_events: self._event_queue.put(event) + if is_node_result_event(event): + result_event = event except Exception as exc: error = exc raise finally: - self._invoke_node_run_end_hooks(node, error) + self._invoke_node_run_end_hooks(node, error, result_event) def _invoke_node_run_start_hooks(self, node: Node) -> None: """Invoke on_node_run_start hooks for all layers.""" @@ -167,11 +171,13 @@ class Worker(threading.Thread): # Silently ignore layer errors to prevent disrupting node execution continue - def _invoke_node_run_end_hooks(self, node: Node, error: Exception | None) -> None: + def _invoke_node_run_end_hooks( + self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None + ) -> None: """Invoke on_node_run_end hooks for all layers.""" for layer in self._layers: try: - 
layer.on_node_run_end(node, error) + layer.on_node_run_end(node, error, result_event) except Exception: # Silently ignore layer errors to prevent disrupting node execution continue diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py index 5b9234586b..9ce7d16e93 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_pool.py +++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py @@ -8,9 +8,10 @@ DynamicScaler, and WorkerFactory into a single class. import logging import queue import threading -from typing import TYPE_CHECKING, final +from typing import final from configs import dify_config +from core.workflow.context import IExecutionContext from core.workflow.graph import Graph from core.workflow.graph_events import GraphNodeEventBase @@ -20,11 +21,6 @@ from ..worker import Worker logger = logging.getLogger(__name__) -if TYPE_CHECKING: - from contextvars import Context - - from flask import Flask - @final class WorkerPool: @@ -41,8 +37,8 @@ class WorkerPool: event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, layers: list[GraphEngineLayer], - flask_app: "Flask | None" = None, - context_vars: "Context | None" = None, + stop_event: threading.Event, + execution_context: IExecutionContext | None = None, min_workers: int | None = None, max_workers: int | None = None, scale_up_threshold: int | None = None, @@ -56,8 +52,7 @@ class WorkerPool: event_queue: Queue for worker events graph: The workflow graph layers: Graph engine layers for node execution hooks - flask_app: Optional Flask app for context preservation - context_vars: Optional context variables + execution_context: Optional execution context for context preservation min_workers: Minimum number of workers max_workers: Maximum number of workers scale_up_threshold: Queue depth to trigger scale up @@ -66,8 +61,7 @@ class WorkerPool: self._ready_queue = ready_queue self._event_queue = 
event_queue self._graph = graph - self._flask_app = flask_app - self._context_vars = context_vars + self._execution_context = execution_context self._layers = layers # Scaling parameters with defaults @@ -81,6 +75,7 @@ class WorkerPool: self._worker_counter = 0 self._lock = threading.RLock() self._running = False + self._stop_event = stop_event # No longer tracking worker states with callbacks to avoid lock contention @@ -135,7 +130,7 @@ class WorkerPool: # Wait for workers to finish for worker in self._workers: if worker.is_alive(): - worker.join(timeout=10.0) + worker.join(timeout=2.0) self._workers.clear() @@ -150,8 +145,8 @@ class WorkerPool: graph=self._graph, layers=self._layers, worker_id=worker_id, - flask_app=self._flask_app, - context_vars=self._context_vars, + execution_context=self._execution_context, + stop_event=self._stop_event, ) worker.start() diff --git a/api/core/workflow/graph_events/__init__.py b/api/core/workflow/graph_events/__init__.py index 7bb6346cb7..56ea642092 100644 --- a/api/core/workflow/graph_events/__init__.py +++ b/api/core/workflow/graph_events/__init__.py @@ -46,6 +46,7 @@ from .node import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + is_node_result_event, ) __all__ = [ @@ -77,4 +78,5 @@ __all__ = [ "NodeRunStartedEvent", "NodeRunStreamChunkEvent", "NodeRunSucceededEvent", + "is_node_result_event", ] diff --git a/api/core/workflow/graph_events/node.py b/api/core/workflow/graph_events/node.py index 140b4de1da..975d72ad1f 100644 --- a/api/core/workflow/graph_events/node.py +++ b/api/core/workflow/graph_events/node.py @@ -72,3 +72,26 @@ class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase): class NodeRunPauseRequestedEvent(GraphNodeEventBase): reason: PauseReason = Field(..., description="pause reason") + + +def is_node_result_event(event: GraphNodeEventBase) -> bool: + """ + Check if an event is a final result event from node execution. 
+ + A result event indicates the completion of a node execution and contains + runtime information such as inputs, outputs, or error details. + + Args: + event: The event to check + + Returns: + True if the event is a node result event (succeeded/failed/paused), False otherwise + """ + return isinstance( + event, + ( + NodeRunSucceededEvent, + NodeRunFailedEvent, + NodeRunPauseRequestedEvent, + ), + ) diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 4be006de11..5a365f769d 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, cast @@ -167,7 +169,7 @@ class AgentNode(Node[AgentNodeData]): variable_pool: VariablePool, node_data: AgentNodeData, for_log: bool = False, - strategy: "PluginAgentStrategy", + strategy: PluginAgentStrategy, ) -> dict[str, Any]: """ Generate parameters based on the given tool parameters, variable pool, and node data. 
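The new `is_node_result_event` helper is an `isinstance` check against the terminal event types, letting the worker capture the final result while streaming intermediate events through untouched. A self-contained sketch of the same predicate shape, with hypothetical stand-in event classes:

```python
from dataclasses import dataclass


@dataclass
class EventBase:
    node_id: str


@dataclass
class SucceededEvent(EventBase):
    outputs: dict


@dataclass
class FailedEvent(EventBase):
    error: str


@dataclass
class StreamChunkEvent(EventBase):
    chunk: str


def is_result_event(event: EventBase) -> bool:
    # Terminal results only; stream chunks and other intermediate
    # events fall through and are not treated as a node's outcome.
    return isinstance(event, (SucceededEvent, FailedEvent))
```

In the worker loop this lets a single pass both forward every event to the queue and remember the last terminal one for the `on_node_run_end` hook.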
@@ -233,7 +235,18 @@ class AgentNode(Node[AgentNodeData]): 0, ): value_param = param.get("value", {}) - params[key] = value_param.get("value", "") if value_param is not None else None + if value_param and value_param.get("type", "") == "variable": + variable_selector = value_param.get("value") + if not variable_selector: + raise ValueError("Variable selector is missing for a variable-type parameter.") + + variable = variable_pool.get(variable_selector) + if variable is None: + raise AgentVariableNotFoundError(str(variable_selector)) + + params[key] = variable.value + else: + params[key] = value_param.get("value", "") if value_param is not None else None else: params[key] = None parameters = params @@ -328,7 +341,7 @@ class AgentNode(Node[AgentNodeData]): def _generate_credentials( self, parameters: dict[str, Any], - ) -> "InvokeCredentials": + ) -> InvokeCredentials: """ Generate credentials based on the given agent parameters. """ @@ -442,9 +455,7 @@ class AgentNode(Node[AgentNodeData]): model_schema.features.remove(feature) return model_schema - def _filter_mcp_type_tool( - self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]] - ) -> list[dict[str, Any]]: + def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: """ Filter MCP type tool :param strategy: plugin agent strategy @@ -483,7 +494,7 @@ class AgentNode(Node[AgentNodeData]): text = "" files: list[File] = [] - json_list: list[dict] = [] + json_list: list[dict | list] = [] agent_logs: list[AgentLogEvent] = [] agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} @@ -557,13 +568,18 @@ class AgentNode(Node[AgentNodeData]): elif message.type == ToolInvokeMessage.MessageType.JSON: assert isinstance(message.message, ToolInvokeMessage.JsonMessage) if node_type == NodeType.AGENT: - msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) - llm_usage = 
LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata)) - agent_execution_metadata = { - WorkflowNodeExecutionMetadataKey(key): value - for key, value in msg_metadata.items() - if key in WorkflowNodeExecutionMetadataKey.__members__.values() - } + if isinstance(message.message.json_object, dict): + msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) + llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata)) + agent_execution_metadata = { + WorkflowNodeExecutionMetadataKey(key): value + for key, value in msg_metadata.items() + if key in WorkflowNodeExecutionMetadataKey.__members__.values() + } + else: + msg_metadata = {} + llm_usage = LLMUsage.empty_usage() + agent_execution_metadata = {} if message.message.json_object: json_list.append(message.message.json_object) elif message.type == ToolInvokeMessage.MessageType.LINK: @@ -672,7 +688,7 @@ class AgentNode(Node[AgentNodeData]): yield agent_log # Add agent_logs to outputs['json'] to ensure frontend can access thinking process - json_output: list[dict[str, Any]] = [] + json_output: list[dict[str, Any] | list[Any]] = [] # Step 1: append each agent log as its own dict. if agent_logs: diff --git a/api/core/workflow/nodes/agent/exc.py b/api/core/workflow/nodes/agent/exc.py index 944f5f0b20..ba2c83d8a6 100644 --- a/api/core/workflow/nodes/agent/exc.py +++ b/api/core/workflow/nodes/agent/exc.py @@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError): self.expected_type = expected_type self.actual_type = actual_type super().__init__(message) + + +class AgentMaxIterationError(AgentNodeError): + """Exception raised when the agent exceeds the maximum iteration limit.""" + + def __init__(self, max_iteration: int): + self.max_iteration = max_iteration + super().__init__( + f"Agent exceeded the maximum iteration limit of {max_iteration}. " + f"The agent was unable to complete the task within the allowed number of iterations." 
+ ) diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 5aab6bbde4..e5a20c8e91 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from abc import ABC from builtins import type as type_ @@ -111,7 +113,7 @@ class DefaultValue(BaseModel): raise DefaultValueTypeError(f"Cannot convert to number: {value}") @model_validator(mode="after") - def validate_value_type(self) -> "DefaultValue": + def validate_value_type(self) -> DefaultValue: # Type validation configuration type_validators = { DefaultValueType.STRING: { diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index a08c6d3ed8..2b773b537c 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import importlib import logging import operator @@ -72,7 +74,7 @@ class Node(Generic[NodeDataT]): in its output. 
""" - node_type: ClassVar["NodeType"] + node_type: ClassVar[NodeType] execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE _node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData @@ -211,14 +213,14 @@ class Node(Generic[NodeDataT]): return None # Global registry populated via __init_subclass__ - _registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {} + _registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {} def __init__( self, id: str, config: Mapping[str, Any], - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, ) -> None: self._graph_init_params = graph_init_params self.id = id @@ -254,7 +256,7 @@ class Node(Generic[NodeDataT]): return @property - def graph_init_params(self) -> "GraphInitParams": + def graph_init_params(self) -> GraphInitParams: return self._graph_init_params @property @@ -300,6 +302,10 @@ class Node(Generic[NodeDataT]): """ raise NotImplementedError + def _should_stop(self) -> bool: + """Check if execution should be stopped.""" + return self.graph_runtime_state.stop_event.is_set() + def run(self) -> Generator[GraphNodeEventBase, None, None]: execution_id = self.ensure_execution_id() self._start_at = naive_utc_now() @@ -368,6 +374,21 @@ class Node(Generic[NodeDataT]): yield event else: yield event + + if self._should_stop(): + error_message = "Execution cancelled" + yield NodeRunFailedEvent( + id=self.execution_id, + node_id=self._node_id, + node_type=self.node_type, + start_at=self._start_at, + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=error_message, + ), + error=error_message, + ) + return except Exception as e: logger.exception("Node %s failed to run", self._node_id) result = NodeRunResult( @@ -474,7 +495,7 @@ class Node(Generic[NodeDataT]): raise NotImplementedError("subclasses of BaseNode must implement `version` method.") @classmethod - def 
get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]: + def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]: """Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry. Import all modules under core.workflow.nodes so subclasses register themselves on import. @@ -484,12 +505,8 @@ class Node(Generic[NodeDataT]): import core.workflow.nodes as _nodes_pkg for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."): - # Avoid importing modules that depend on the registry to prevent circular imports - # e.g. node_factory imports node_mapping which builds the mapping here. - if _modname in { - "core.workflow.nodes.node_factory", - "core.workflow.nodes.node_mapping", - }: + # Avoid importing modules that depend on the registry to prevent circular imports. + if _modname == "core.workflow.nodes.node_mapping": continue importlib.import_module(_modname) diff --git a/api/core/workflow/nodes/base/template.py b/api/core/workflow/nodes/base/template.py index ba3e2058cf..81f4b9f6fb 100644 --- a/api/core/workflow/nodes/base/template.py +++ b/api/core/workflow/nodes/base/template.py @@ -4,6 +4,8 @@ This module provides a unified template structure for both Answer and End nodes, similar to SegmentGroup but focused on template representation without values. """ +from __future__ import annotations + from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass @@ -58,7 +60,7 @@ class Template: segments: list[TemplateSegmentUnion] @classmethod - def from_answer_template(cls, template_str: str) -> "Template": + def from_answer_template(cls, template_str: str) -> Template: """Create a Template from an Answer node template string. 
Example: @@ -107,7 +109,7 @@ class Template: return cls(segments=segments) @classmethod - def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template": + def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template: """Create a Template from an End node outputs configuration. End nodes are treated as templates of concatenated variables with newlines. diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index a38e10030a..e3035d3bf0 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,8 +1,7 @@ from collections.abc import Mapping, Sequence from decimal import Decimal -from typing import Any, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast -from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider @@ -13,6 +12,7 @@ from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node from core.workflow.nodes.code.entities import CodeNodeData +from core.workflow.nodes.code.limits import CodeNodeLimits from .exc import ( CodeNodeError, @@ -20,9 +20,41 @@ from .exc import ( OutputValidationError, ) +if TYPE_CHECKING: + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState + class CodeNode(Node[CodeNodeData]): node_type = NodeType.CODE + _DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = ( + Python3CodeProvider, + JavascriptCodeProvider, + ) + _limits: CodeNodeLimits + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + *, + 
code_executor: type[CodeExecutor] | None = None, + code_providers: Sequence[type[CodeNodeProvider]] | None = None, + code_limits: CodeNodeLimits, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor + self._code_providers: tuple[type[CodeNodeProvider], ...] = ( + tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS + ) + self._limits = code_limits @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: @@ -35,11 +67,16 @@ class CodeNode(Node[CodeNodeData]): if filters: code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3)) - providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] - code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language)) + code_provider: type[CodeNodeProvider] = next( + provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language) + ) return code_provider.get_default_config() + @classmethod + def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]: + return cls._DEFAULT_CODE_PROVIDERS + @classmethod def version(cls) -> str: return "1" @@ -60,7 +97,8 @@ class CodeNode(Node[CodeNodeData]): variables[variable_name] = variable.to_object() if variable else None # Run code try: - result = CodeExecutor.execute_workflow_code_template( + _ = self._select_code_provider(code_language) + result = self._code_executor.execute_workflow_code_template( language=code_language, code=code, inputs=variables, @@ -75,6 +113,12 @@ class CodeNode(Node[CodeNodeData]): return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) + def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]: + for provider in 
self._code_providers: + if provider.is_accept_language(code_language): + return provider + raise CodeNodeError(f"Unsupported code language: {code_language}") + def _check_string(self, value: str | None, variable: str) -> str | None: """ Check string @@ -85,10 +129,10 @@ class CodeNode(Node[CodeNodeData]): if value is None: return None - if len(value) > dify_config.CODE_MAX_STRING_LENGTH: + if len(value) > self._limits.max_string_length: raise OutputValidationError( f"The length of output variable `{variable}` must be" - f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters" + f" less than {self._limits.max_string_length} characters" ) return value.replace("\x00", "") @@ -109,20 +153,20 @@ class CodeNode(Node[CodeNodeData]): if value is None: return None - if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER: + if value > self._limits.max_number or value < self._limits.min_number: raise OutputValidationError( f"Output variable `{variable}` is out of range," - f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}." + f" it must be between {self._limits.min_number} and {self._limits.max_number}." ) if isinstance(value, float): decimal_value = Decimal(str(value)).normalize() precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator] # raise error if precision is too high - if precision > dify_config.CODE_MAX_PRECISION: + if precision > self._limits.max_precision: raise OutputValidationError( f"Output variable `{variable}` has too high precision," - f" it must be less than {dify_config.CODE_MAX_PRECISION} digits." + f" it must be less than {self._limits.max_precision} digits." ) return value @@ -137,8 +181,8 @@ class CodeNode(Node[CodeNodeData]): # TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes. 
# Note that `_transform_result` may produce lists containing `None` values, # which don't conform to the type requirements of `Array*Segment` classes. - if depth > dify_config.CODE_MAX_DEPTH: - raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.") + if depth > self._limits.max_depth: + raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.") transformed_result: dict[str, Any] = {} if output_schema is None: @@ -272,10 +316,10 @@ class CodeNode(Node[CodeNodeData]): f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead." ) else: - if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH: + if len(value) > self._limits.max_number_array_length: raise OutputValidationError( f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements." + f" less than {self._limits.max_number_array_length} elements." ) for i, inner_value in enumerate(value): @@ -305,10 +349,10 @@ class CodeNode(Node[CodeNodeData]): f" got {type(result.get(output_name))} instead." ) else: - if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH: + if len(result[output_name]) > self._limits.max_string_array_length: raise OutputValidationError( f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements." + f" less than {self._limits.max_string_array_length} elements." ) transformed_result[output_name] = [ @@ -326,10 +370,10 @@ class CodeNode(Node[CodeNodeData]): f" got {type(result.get(output_name))} instead." ) else: - if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH: + if len(result[output_name]) > self._limits.max_object_array_length: raise OutputValidationError( f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements." 
+ f" less than {self._limits.max_object_array_length} elements." ) for i, value in enumerate(result[output_name]): diff --git a/api/core/workflow/nodes/code/limits.py b/api/core/workflow/nodes/code/limits.py new file mode 100644 index 0000000000..a6b9e9e68e --- /dev/null +++ b/api/core/workflow/nodes/code/limits.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass + + +@dataclass(frozen=True) +class CodeNodeLimits: + max_string_length: int + max_number: int | float + min_number: int | float + max_precision: int + max_depth: int + max_number_array_length: int + max_string_array_length: int + max_object_array_length: int diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index bb2140f42e..925561cf7c 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -301,7 +301,7 @@ class DatasourceNode(Node[DatasourceNodeData]): text = "" files: list[File] = [] - json: list[dict] = [] + json: list[dict | list] = [] variables: dict[str, Any] = {} diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 931c6113a7..429f8411a6 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -17,6 +17,7 @@ from core.helper import ssrf_proxy from core.variables.segments import ArrayFileSegment, FileSegment from core.workflow.runtime import VariablePool +from ..protocols import FileManagerProtocol, HttpClientProtocol from .entities import ( HttpRequestNodeAuthorization, HttpRequestNodeData, @@ -78,6 +79,8 @@ class Executor: timeout: HttpRequestNodeTimeout, variable_pool: VariablePool, max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES, + http_client: HttpClientProtocol = ssrf_proxy, + file_manager: FileManagerProtocol = file_manager, ): # If authorization API key is present, convert the API key using the variable pool 
if node_data.authorization.type == "api-key": @@ -104,6 +107,8 @@ class Executor: self.data = None self.json = None self.max_retries = max_retries + self._http_client = http_client + self._file_manager = file_manager # init template self.variable_pool = variable_pool @@ -200,7 +205,7 @@ class Executor: if file_variable is None: raise FileFetchError(f"cannot fetch file with selector {file_selector}") file = file_variable.value - self.content = file_manager.download(file) + self.content = self._file_manager.download(file) case "x-www-form-urlencoded": form_data = { self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( @@ -239,7 +244,7 @@ class Executor: ): file_tuple = ( file.filename, - file_manager.download(file), + self._file_manager.download(file), file.mime_type or "application/octet-stream", ) if key not in files: @@ -332,19 +337,18 @@ class Executor: do http request depending on api bundle """ _METHOD_MAP = { - "get": ssrf_proxy.get, - "head": ssrf_proxy.head, - "post": ssrf_proxy.post, - "put": ssrf_proxy.put, - "delete": ssrf_proxy.delete, - "patch": ssrf_proxy.patch, + "get": self._http_client.get, + "head": self._http_client.head, + "post": self._http_client.post, + "put": self._http_client.put, + "delete": self._http_client.delete, + "patch": self._http_client.patch, } method_lc = self.method.lower() if method_lc not in _METHOD_MAP: raise InvalidHttpMethodError(f"Invalid http method {self.method}") request_args = { - "url": self.url, "data": self.data, "files": self.files, "json": self.json, @@ -357,8 +361,12 @@ class Executor: } # request_args = {k: v for k, v in request_args.items() if v is not None} try: - response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries) - except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e: + response: httpx.Response = _METHOD_MAP[method_lc]( + url=self.url, + **request_args, + max_retries=self.max_retries, + ) + except 
(self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e: raise HttpRequestNodeError(str(e)) from e # FIXME: fix type ignore, this maybe httpx type issue return response diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 9bd1cb9761..964e53e03c 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -1,10 +1,11 @@ import logging import mimetypes -from collections.abc import Mapping, Sequence -from typing import Any +from collections.abc import Callable, Mapping, Sequence +from typing import TYPE_CHECKING, Any from configs import dify_config -from core.file import File, FileTransferMethod +from core.file import File, FileTransferMethod, file_manager +from core.helper import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.variables.segments import ArrayFileSegment from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus @@ -13,6 +14,7 @@ from core.workflow.nodes.base import variable_template_parser from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.http_request.executor import Executor +from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol from factories import file_factory from .entities import ( @@ -30,10 +32,35 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState + class HttpRequestNode(Node[HttpRequestNodeData]): node_type = NodeType.HTTP_REQUEST + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + *, + http_client: HttpClientProtocol = ssrf_proxy, + tool_file_manager_factory: Callable[[], ToolFileManager] = 
ToolFileManager, + file_manager: FileManagerProtocol = file_manager, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self._http_client = http_client + self._tool_file_manager_factory = tool_file_manager_factory + self._file_manager = file_manager + @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { @@ -71,6 +98,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]): timeout=self._get_request_timeout(self.node_data), variable_pool=self.graph_runtime_state.variable_pool, max_retries=0, + http_client=self._http_client, + file_manager=self._file_manager, ) process_data["request"] = http_executor.to_log() @@ -199,7 +228,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): mime_type = ( content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream" ) - tool_file_manager = ToolFileManager() + tool_file_manager = self._tool_file_manager_factory() tool_file = tool_file_manager.create_file_by_raw( user_id=self.user_id, diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index e5d86414c1..ced996e7e0 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,17 +1,15 @@ -import contextvars import logging from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, ThreadPoolExecutor, as_completed from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, NewType, cast -from flask import Flask, current_app from typing_extensions import TypeIs from core.model_runtime.entities.llm_entities import LLMUsage from core.variables import IntegerVariable, NoneSegment from core.variables.segments import ArrayAnySegment, ArraySegment -from core.variables.variables import VariableUnion +from 
core.variables.variables import Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.enums import ( NodeExecutionType, @@ -39,7 +37,6 @@ from core.workflow.nodes.base.node import Node from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from core.workflow.runtime import VariablePool from libs.datetime_utils import naive_utc_now -from libs.flask_utils import preserve_flask_contexts from .exc import ( InvalidIteratorValueError, @@ -51,6 +48,7 @@ from .exc import ( ) if TYPE_CHECKING: + from core.workflow.context import IExecutionContext from core.workflow.graph_engine import GraphEngine logger = logging.getLogger(__name__) @@ -240,7 +238,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): datetime, list[GraphNodeEventBase], object | None, - dict[str, VariableUnion], + dict[str, Variable], LLMUsage, ] ], @@ -252,8 +250,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): self._execute_single_iteration_parallel, index=index, item=item, - flask_app=current_app._get_current_object(), # type: ignore - context_vars=contextvars.copy_context(), + execution_context=self._capture_execution_context(), ) future_to_index[future] = index @@ -306,11 +303,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): self, index: int, item: object, - flask_app: Flask, - context_vars: contextvars.Context, - ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]: + execution_context: "IExecutionContext", + ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]: """Execute a single iteration in parallel mode and return results.""" - with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars): + with execution_context: iter_start_at = datetime.now(UTC).replace(tzinfo=None) events: list[GraphNodeEventBase] = [] outputs_temp: list[object] = [] @@ -339,6 +335,12 @@ 
class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): graph_engine.graph_runtime_state.llm_usage, ) + def _capture_execution_context(self) -> "IExecutionContext": + """Capture current execution context for parallel iterations.""" + from core.workflow.context import capture_current_context + + return capture_current_context() + def _handle_iteration_success( self, started_at: datetime, @@ -515,11 +517,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return variable_mapping - def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]: + def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]: conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()} - def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None: + def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None: parent_pool = self.graph_runtime_state.variable_pool parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) @@ -586,11 +588,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): def _create_graph_engine(self, index: int, item: object): # Import dependencies + from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel - from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState # Create GraphInitParams from node attributes diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py 
b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index adc474bd60..8670a71aa3 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -6,7 +6,7 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, cast -from sqlalchemy import and_, func, literal, or_, select +from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import sessionmaker from core.app.app_config.entities import DatasetRetrieveConfigEntity @@ -460,7 +460,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD if automatic_metadata_filters: conditions = [] for sequence, filter in enumerate(automatic_metadata_filters): - self._process_metadata_filter_func( + DatasetRetrieval.process_metadata_filter_func( sequence, filter.get("condition", ""), filter.get("metadata_name", ""), @@ -504,7 +504,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD value=expected_value, ) ) - filters = self._process_metadata_filter_func( + filters = DatasetRetrieval.process_metadata_filter_func( sequence, condition.comparison_operator, metadata_name, @@ -603,87 +603,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD return [], usage return automatic_metadata_filters, usage - def _process_metadata_filter_func( - self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any] - ) -> list[Any]: - if value is None and condition not in ("empty", "not empty"): - return filters - - json_field = Document.doc_metadata[metadata_name].as_string() - - match condition: - case "contains": - filters.append(json_field.like(f"%{value}%")) - - case "not contains": - filters.append(json_field.notlike(f"%{value}%")) - - case "start with": - filters.append(json_field.like(f"{value}%")) - - case "end with": - 
filters.append(json_field.like(f"%{value}")) - case "in": - if isinstance(value, str): - value_list = [v.strip() for v in value.split(",") if v.strip()] - elif isinstance(value, (list, tuple)): - value_list = [str(v) for v in value if v is not None] - else: - value_list = [str(value)] if value is not None else [] - - if not value_list: - filters.append(literal(False)) - else: - filters.append(json_field.in_(value_list)) - - case "not in": - if isinstance(value, str): - value_list = [v.strip() for v in value.split(",") if v.strip()] - elif isinstance(value, (list, tuple)): - value_list = [str(v) for v in value if v is not None] - else: - value_list = [str(value)] if value is not None else [] - - if not value_list: - filters.append(literal(True)) - else: - filters.append(json_field.notin_(value_list)) - - case "is" | "=": - if isinstance(value, str): - filters.append(json_field == value) - elif isinstance(value, (int, float)): - filters.append(Document.doc_metadata[metadata_name].as_float() == value) - - case "is not" | "≠": - if isinstance(value, str): - filters.append(json_field != value) - elif isinstance(value, (int, float)): - filters.append(Document.doc_metadata[metadata_name].as_float() != value) - - case "empty": - filters.append(Document.doc_metadata[metadata_name].is_(None)) - - case "not empty": - filters.append(Document.doc_metadata[metadata_name].isnot(None)) - - case "before" | "<": - filters.append(Document.doc_metadata[metadata_name].as_float() < value) - - case "after" | ">": - filters.append(Document.doc_metadata[metadata_name].as_float() > value) - - case "≤" | "<=": - filters.append(Document.doc_metadata[metadata_name].as_float() <= value) - - case "≥" | ">=": - filters.append(Document.doc_metadata[metadata_name].as_float() >= value) - - case _: - pass - - return filters - @classmethod def _extract_variable_selector_to_variable_mapping( cls, diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index 
0c545469bc..01e25cbf5c 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from configs import dify_config from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.entities.provider_entities import QuotaUnit +from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.file.models import File from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager @@ -136,21 +136,37 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs used_quota = 1 if used_quota is not None and system_configuration.current_quota_type is not None: - with Session(db.engine) as session: - stmt = ( - update(Provider) - .where( - Provider.tenant_id == tenant_id, - # TODO: Use provider name with prefix after the data migration. - Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, - Provider.provider_type == ProviderType.SYSTEM, - Provider.quota_type == system_configuration.current_quota_type.value, - Provider.quota_limit > Provider.quota_used, - ) - .values( - quota_used=Provider.quota_used + used_quota, - last_used=naive_utc_now(), - ) + if system_configuration.current_quota_type == ProviderQuotaType.TRIAL: + from services.credit_pool_service import CreditPoolService + + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, ) - session.execute(stmt) - session.commit() + elif system_configuration.current_quota_type == ProviderQuotaType.PAID: + from services.credit_pool_service import CreditPoolService + + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, + pool_type="paid", + ) + else: + with Session(db.engine) as session: + stmt = ( + update(Provider) + .where( + Provider.tenant_id == tenant_id, + # TODO: Use provider name with 
prefix after the data migration. + Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == system_configuration.current_quota_type.value, + Provider.quota_limit > Provider.quota_used, + ) + .values( + quota_used=Provider.quota_used + used_quota, + last_used=naive_utc_now(), + ) + ) + session.execute(stmt) + session.commit() diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 04e2802191..dfb55dcd80 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import io import json @@ -113,7 +115,7 @@ class LLMNode(Node[LLMNodeData]): # Instance attributes specific to LLMNode. # Output variable for file - _file_outputs: list["File"] + _file_outputs: list[File] _llm_file_saver: LLMFileSaver @@ -121,8 +123,8 @@ class LLMNode(Node[LLMNodeData]): self, id: str, config: Mapping[str, Any], - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, *, llm_file_saver: LLMFileSaver | None = None, ): @@ -361,7 +363,7 @@ class LLMNode(Node[LLMNodeData]): structured_output_enabled: bool, structured_output: Mapping[str, Any] | None = None, file_saver: LLMFileSaver, - file_outputs: list["File"], + file_outputs: list[File], node_id: str, node_type: NodeType, reasoning_format: Literal["separated", "tagged"] = "tagged", @@ -415,7 +417,7 @@ class LLMNode(Node[LLMNodeData]): *, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], file_saver: LLMFileSaver, - file_outputs: list["File"], + file_outputs: list[File], node_id: str, node_type: NodeType, reasoning_format: Literal["separated", "tagged"] = "tagged", @@ -525,7 +527,7 @@ class LLMNode(Node[LLMNodeData]): ) @staticmethod - def 
_image_file_to_markdown(file: "File", /): + def _image_file_to_markdown(file: File, /): text_chunk = f"![]({file.generate_url()})" return text_chunk @@ -774,7 +776,7 @@ class LLMNode(Node[LLMNodeData]): def fetch_prompt_messages( *, sys_query: str | None = None, - sys_files: Sequence["File"], + sys_files: Sequence[File], context: str | None = None, memory: TokenBufferMemory | None = None, model_config: ModelConfigWithCredentialsEntity, @@ -785,7 +787,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], tenant_id: str, - context_files: list["File"] | None = None, + context_files: list[File] | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] @@ -1137,7 +1139,7 @@ class LLMNode(Node[LLMNodeData]): *, invoke_result: LLMResult | LLMResultWithStructuredOutput, saver: LLMFileSaver, - file_outputs: list["File"], + file_outputs: list[File], reasoning_format: Literal["separated", "tagged"] = "tagged", request_latency: float | None = None, ) -> ModelInvokeCompletedEvent: @@ -1179,7 +1181,7 @@ class LLMNode(Node[LLMNodeData]): *, content: ImagePromptMessageContent, file_saver: LLMFileSaver, - ) -> "File": + ) -> File: """_save_multimodal_output saves multi-modal contents generated by LLM plugins. There are two kinds of multimodal outputs: @@ -1229,7 +1231,7 @@ class LLMNode(Node[LLMNodeData]): *, contents: str | list[PromptMessageContentUnionTypes] | None, file_saver: LLMFileSaver, - file_outputs: list["File"], + file_outputs: list[File], ) -> Generator[str, None, None]: """Convert intermediate prompt messages into strings and yield them to the caller. 
diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 1f9fc8a115..07d05966cc 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -413,11 +413,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): def _create_graph_engine(self, start_at: datetime, root_node_id: str): # Import dependencies + from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel - from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState # Create GraphInitParams from node attributes diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/workflow/nodes/node_factory.py deleted file mode 100644 index c55ad346bf..0000000000 --- a/api/core/workflow/nodes/node_factory.py +++ /dev/null @@ -1,80 +0,0 @@ -from typing import TYPE_CHECKING, final - -from typing_extensions import override - -from core.workflow.enums import NodeType -from core.workflow.graph import NodeFactory -from core.workflow.nodes.base.node import Node -from libs.typing import is_str, is_str_dict - -from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING - -if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState - - -@final -class DifyNodeFactory(NodeFactory): - """ - Default implementation of NodeFactory that uses the traditional node mapping. - - This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING - and instantiating the appropriate node class. 
- """ - - def __init__( - self, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - ) -> None: - self.graph_init_params = graph_init_params - self.graph_runtime_state = graph_runtime_state - - @override - def create_node(self, node_config: dict[str, object]) -> Node: - """ - Create a Node instance from node configuration data using the traditional mapping. - - :param node_config: node configuration dictionary containing type and other data - :return: initialized Node instance - :raises ValueError: if node type is unknown or configuration is invalid - """ - # Get node_id from config - node_id = node_config.get("id") - if not is_str(node_id): - raise ValueError("Node config missing id") - - # Get node type from config - node_data = node_config.get("data", {}) - if not is_str_dict(node_data): - raise ValueError(f"Node {node_id} missing data information") - - node_type_str = node_data.get("type") - if not is_str(node_type_str): - raise ValueError(f"Node {node_id} missing or invalid type information") - - try: - node_type = NodeType(node_type_str) - except ValueError: - raise ValueError(f"Unknown node type: {node_type_str}") - - # Get node class - node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type) - if not node_mapping: - raise ValueError(f"No class mapping found for node type: {node_type}") - - latest_node_class = node_mapping.get(LATEST_VERSION) - node_version = str(node_data.get("version", "1")) - matched_node_class = node_mapping.get(node_version) - node_class = matched_node_class or latest_node_class - if not node_class: - raise ValueError(f"No latest version class found for node type: {node_type}") - - # Create node instance - return node_class( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - ) diff --git a/api/core/workflow/nodes/protocols.py b/api/core/workflow/nodes/protocols.py new file mode 100644 index 0000000000..e7dcf62fcf --- /dev/null 
+++ b/api/core/workflow/nodes/protocols.py @@ -0,0 +1,29 @@ +from typing import Protocol + +import httpx + +from core.file import File + + +class HttpClientProtocol(Protocol): + @property + def max_retries_exceeded_error(self) -> type[Exception]: ... + + @property + def request_error(self) -> type[Exception]: ... + + def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + + def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + + def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + + def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + + def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + + def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + + +class FileManagerProtocol(Protocol): + def download(self, f: File, /) -> bytes: ... diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 36fc5078c5..53c1b4ee6b 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,4 +1,3 @@ -import json from typing import Any from jsonschema import Draft7Validator, ValidationError @@ -43,25 +42,22 @@ class StartNode(Node[StartNodeData]): if value is None and variable.required: raise ValueError(f"{key} is required in input form") + # If no value provided, skip further processing for this key + if not value: + continue + + if not isinstance(value, dict): + raise ValueError(f"JSON object for '{key}' must be an object") + + # Overwrite with normalized dict to ensure downstream consistency + node_inputs[key] = value + + # If schema exists, then validate against it schema = variable.json_schema if not schema: continue - if not value: - continue - try: - json_schema = json.loads(schema) - except json.JSONDecodeError as e: - raise ValueError(f"{schema} must 
be a valid JSON object") - - try: - json_value = json.loads(value) - except json.JSONDecodeError as e: - raise ValueError(f"{value} must be a valid JSON object") - - try: - Draft7Validator(json_schema).validate(json_value) + Draft7Validator(schema).validate(value) except ValidationError as e: raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}") - node_inputs[key] = json_value diff --git a/api/core/workflow/nodes/template_transform/template_renderer.py b/api/core/workflow/nodes/template_transform/template_renderer.py new file mode 100644 index 0000000000..a5f06bf2bb --- /dev/null +++ b/api/core/workflow/nodes/template_transform/template_renderer.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, Protocol + +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage + + +class TemplateRenderError(ValueError): + """Raised when rendering a Jinja2 template fails.""" + + +class Jinja2TemplateRenderer(Protocol): + """Render Jinja2 templates for template transform nodes.""" + + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + """Render a Jinja2 template with provided variables.""" + raise NotImplementedError + + +class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): + """Adapter that renders Jinja2 templates via CodeExecutor.""" + + _code_executor: type[CodeExecutor] + + def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None: + self._code_executor = code_executor or CodeExecutor + + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + try: + result = self._code_executor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, code=template, inputs=variables + ) + except CodeExecutionError as exc: + raise TemplateRenderError(str(exc)) from exc + + rendered = result.get("result") + if not isinstance(rendered, str): + raise 
TemplateRenderError("Template render result must be a string.") + return rendered diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 2274323960..f7e0bccccf 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,18 +1,44 @@ from collections.abc import Mapping, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any from configs import dify_config -from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData +from core.workflow.nodes.template_transform.template_renderer import ( + CodeExecutorJinja2TemplateRenderer, + Jinja2TemplateRenderer, + TemplateRenderError, +) + +if TYPE_CHECKING: + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH class TemplateTransformNode(Node[TemplateTransformNodeData]): node_type = NodeType.TEMPLATE_TRANSFORM + _template_renderer: Jinja2TemplateRenderer + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + *, + template_renderer: Jinja2TemplateRenderer | None = None, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer() @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> 
Mapping[str, object]: @@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): variables[variable_name] = value.to_object() if value else None # Run code try: - result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables - ) - except CodeExecutionError as e: + rendered = self._template_renderer.render_template(self.node_data.template, variables) + except TemplateRenderError as e: return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) - if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: + if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: return NodeRunResult( inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, @@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): ) return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]} + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered} ) @classmethod diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 2e7ec757b4..68ac60e4f6 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -244,7 +244,7 @@ class ToolNode(Node[ToolNodeData]): text = "" files: list[File] = [] - json: list[dict] = [] + json: list[dict | list] = [] variables: dict[str, Any] = {} @@ -400,7 +400,7 @@ class ToolNode(Node[ToolNodeData]): message.message.metadata = dict_metadata # Add agent_logs to outputs['json'] to ensure frontend can access thinking process - json_output: list[dict[str, Any]] = [] + json_output: list[dict[str, Any] | list[Any]] = [] # Step 2: normalize JSON into {"data": [...]}.change json to list[dict] if json: diff --git a/api/core/workflow/nodes/variable_assigner/common/impl.py b/api/core/workflow/nodes/variable_assigner/common/impl.py 
deleted file mode 100644 index 050e213535..0000000000 --- a/api/core/workflow/nodes/variable_assigner/common/impl.py +++ /dev/null @@ -1,28 +0,0 @@ -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.variables.variables import Variable -from extensions.ext_database import db -from models import ConversationVariable - -from .exc import VariableOperatorNodeError - - -class ConversationVariableUpdaterImpl: - def update(self, conversation_id: str, variable: Variable): - stmt = select(ConversationVariable).where( - ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id - ) - with Session(db.engine) as session: - row = session.scalar(stmt) - if not row: - raise VariableOperatorNodeError("conversation variable not found in the database") - row.data = variable.model_dump_json() - session.commit() - - def flush(self): - pass - - -def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl: - return ConversationVariableUpdaterImpl() diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index da23207b62..9f5818f4bb 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,9 +1,8 @@ -from collections.abc import Callable, Mapping, Sequence -from typing import TYPE_CHECKING, Any, TypeAlias +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any -from core.variables import SegmentType, Variable +from core.variables import SegmentType, VariableBase from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities import GraphInitParams from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult @@ -11,19 +10,14 @@ from core.workflow.nodes.base.node import 
Node from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from ..common.impl import conversation_variable_updater_factory from .node_data import VariableAssignerData, WriteMode if TYPE_CHECKING: from core.workflow.runtime import GraphRuntimeState -_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] - - class VariableAssignerNode(Node[VariableAssignerData]): node_type = NodeType.VARIABLE_ASSIGNER - _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY def __init__( self, @@ -31,7 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]): config: Mapping[str, Any], graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory, ): super().__init__( id=id, @@ -39,7 +32,15 @@ class VariableAssignerNode(Node[VariableAssignerData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._conv_var_updater_factory = conv_var_updater_factory + + def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: + """ + Check if this Variable Assigner node blocks the output of specific variables. + + Returns True if this node updates any of the requested conversation variables. 
+ """ + assigned_selector = tuple(self.node_data.assigned_variable_selector) + return assigned_selector in variable_selectors @classmethod def version(cls) -> str: @@ -72,7 +73,7 @@ class VariableAssignerNode(Node[VariableAssignerData]): assigned_variable_selector = self.node_data.assigned_variable_selector # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) - if not isinstance(original_variable, Variable): + if not isinstance(original_variable, VariableBase): raise VariableOperatorNodeError("assigned variable not found") match self.node_data.write_mode: @@ -96,16 +97,7 @@ class VariableAssignerNode(Node[VariableAssignerData]): # Over write the variable. self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) - # TODO: Move database operation to the pipeline. - # Update conversation variable. - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) - if not conversation_id: - raise VariableOperatorNodeError("conversation_id not found") - conv_var_updater = self._conv_var_updater_factory() - conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable) - conv_var_updater.flush() updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] - return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={ diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 389fb54d35..5857702e72 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,24 +1,20 @@ import json from collections.abc import Mapping, MutableMapping, Sequence -from typing import Any, cast +from typing import TYPE_CHECKING, Any -from core.app.entities.app_invoke_entities import InvokeFrom -from 
core.variables import SegmentType, Variable +from core.variables import SegmentType, VariableBase from core.variables.consts import SELECTORS_LENGTH from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from . import helpers from .entities import VariableAssignerNodeData, VariableOperationItem from .enums import InputType, Operation from .exc import ( - ConversationIDNotFoundError, InputTypeNotSupportedError, InvalidDataError, InvalidInputValueError, @@ -26,6 +22,10 @@ from .exc import ( VariableNotFoundError, ) +if TYPE_CHECKING: + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState + def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): selector_node_id = item.variable_selector[0] @@ -53,6 +53,20 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_ class VariableAssignerNode(Node[VariableAssignerNodeData]): node_type = NodeType.VARIABLE_ASSIGNER + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + ): + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: """ Check if this Variable Assigner node blocks the output of specific variables. 
@@ -70,9 +84,6 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): return False - def _conv_var_updater_factory(self) -> ConversationVariableUpdater: - return conversation_variable_updater_factory() - @classmethod def version(cls) -> str: return "2" @@ -107,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): # ==================== Validation Part # Check if variable exists - if not isinstance(variable, Variable): + if not isinstance(variable, VariableBase): raise VariableNotFoundError(variable_selector=item.variable_selector) # Check if operation is supported @@ -179,26 +190,12 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): # remove the duplicated items first. updated_variable_selectors = list(set(map(tuple, updated_variable_selectors))) - conv_var_updater = self._conv_var_updater_factory() - # Update variables for selector in updated_variable_selectors: variable = self.graph_runtime_state.variable_pool.get(selector) - if not isinstance(variable, Variable): + if not isinstance(variable, VariableBase): raise VariableNotFoundError(variable_selector=selector) process_data[variable.name] = variable.value - if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID: - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) - if not conversation_id: - if self.invoke_from != InvokeFrom.DEBUGGER: - raise ConversationIDNotFoundError - else: - conversation_id = conversation_id.value - conv_var_updater.update( - conversation_id=cast(str, conversation_id), - variable=variable, - ) - conv_var_updater.flush() updated_variables = [ common_helpers.variable_to_processed_data(selector, seg) for selector in updated_variable_selectors @@ -216,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): def _handle_item( self, *, - variable: Variable, + variable: VariableBase, operation: Operation, value: Any, ): diff --git a/api/core/workflow/repositories/draft_variable_repository.py 
b/api/core/workflow/repositories/draft_variable_repository.py index 97bfcd5666..66ef714c16 100644 --- a/api/core/workflow/repositories/draft_variable_repository.py +++ b/api/core/workflow/repositories/draft_variable_repository.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc from collections.abc import Mapping from typing import Any, Protocol @@ -23,7 +25,7 @@ class DraftVariableSaverFactory(Protocol): node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, - ) -> "DraftVariableSaver": + ) -> DraftVariableSaver: pass diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py index e060cb704d..f79230217c 100644 --- a/api/core/workflow/runtime/graph_runtime_state.py +++ b/api/core/workflow/runtime/graph_runtime_state.py @@ -2,6 +2,7 @@ from __future__ import annotations import importlib import json +import threading from collections.abc import Mapping, Sequence from copy import deepcopy from dataclasses import dataclass @@ -211,6 +212,8 @@ class GraphRuntimeState: self._pending_graph_node_states: dict[str, NodeState] | None = None self._pending_graph_edge_states: dict[str, NodeState] | None = None + self.stop_event: threading.Event = threading.Event() + if graph is not None: self.attach_graph(graph) diff --git a/api/core/workflow/runtime/graph_runtime_state_protocol.py b/api/core/workflow/runtime/graph_runtime_state_protocol.py index 5e0878e873..bfbb5ba704 100644 --- a/api/core/workflow/runtime/graph_runtime_state_protocol.py +++ b/api/core/workflow/runtime/graph_runtime_state_protocol.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import Any, Protocol from core.model_runtime.entities.llm_entities import LLMUsage @@ -9,7 +9,7 @@ from core.workflow.system_variable import SystemVariableReadOnlyView class ReadOnlyVariablePool(Protocol): """Read-only interface for VariablePool.""" - def get(self, node_id: 
str, variable_key: str) -> Segment | None: + def get(self, selector: Sequence[str], /) -> Segment | None: """Get a variable value (read-only).""" ... diff --git a/api/core/workflow/runtime/read_only_wrappers.py b/api/core/workflow/runtime/read_only_wrappers.py index 8539727fd6..d3e4c60d9b 100644 --- a/api/core/workflow/runtime/read_only_wrappers.py +++ b/api/core/workflow/runtime/read_only_wrappers.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from copy import deepcopy from typing import Any @@ -18,9 +18,9 @@ class ReadOnlyVariablePoolWrapper: def __init__(self, variable_pool: VariablePool) -> None: self._variable_pool = variable_pool - def get(self, node_id: str, variable_key: str) -> Segment | None: + def get(self, selector: Sequence[str], /) -> Segment | None: """Return a copy of a variable value if present.""" - value = self._variable_pool.get([node_id, variable_key]) + value = self._variable_pool.get(selector) return deepcopy(value) if value is not None else None def get_all_by_node(self, node_id: str) -> Mapping[str, object]: diff --git a/api/core/workflow/runtime/variable_pool.py b/api/core/workflow/runtime/variable_pool.py index 7fbaec9e70..c4b077fa69 100644 --- a/api/core/workflow/runtime/variable_pool.py +++ b/api/core/workflow/runtime/variable_pool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re from collections import defaultdict from collections.abc import Mapping, Sequence @@ -7,10 +9,10 @@ from typing import Annotated, Any, Union, cast from pydantic import BaseModel, Field from core.file import File, FileAttribute, file_manager -from core.variables import Segment, SegmentGroup, Variable +from core.variables import Segment, SegmentGroup, VariableBase from core.variables.consts import SELECTORS_LENGTH from core.variables.segments import FileSegment, ObjectSegment -from core.variables.variables import RAGPipelineVariableInput, VariableUnion 
+from core.variables.variables import RAGPipelineVariableInput, Variable from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, @@ -30,7 +32,7 @@ class VariablePool(BaseModel): # The first element of the selector is the node id, it's the first-level key in the dictionary. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. - variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field( + variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field( description="Variables mapping", default=defaultdict(dict), ) @@ -42,15 +44,15 @@ class VariablePool(BaseModel): ) system_variables: SystemVariable = Field( description="System variables", - default_factory=SystemVariable.empty, + default_factory=SystemVariable.default, ) - environment_variables: Sequence[VariableUnion] = Field( + environment_variables: Sequence[Variable] = Field( description="Environment variables.", - default_factory=list[VariableUnion], + default_factory=list[Variable], ) - conversation_variables: Sequence[VariableUnion] = Field( + conversation_variables: Sequence[Variable] = Field( description="Conversation variables.", - default_factory=list[VariableUnion], + default_factory=list[Variable], ) rag_pipeline_variables: list[RAGPipelineVariableInput] = Field( description="RAG pipeline variables.", @@ -103,7 +105,7 @@ class VariablePool(BaseModel): f"got {len(selector)} elements" ) - if isinstance(value, Variable): + if isinstance(value, VariableBase): variable = value elif isinstance(value, Segment): variable = variable_factory.segment_to_variable(segment=value, selector=selector) @@ -112,9 +114,9 @@ class VariablePool(BaseModel): variable = variable_factory.segment_to_variable(segment=segment, selector=selector) node_id, name = self._selector_to_keys(selector) - 
# Based on the definition of `VariableUnion`, - # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. - self.variable_dictionary[node_id][name] = cast(VariableUnion, variable) + # Based on the definition of `Variable`, + # `VariableBase` instances can be safely used as `Variable` since they are compatible. + self.variable_dictionary[node_id][name] = cast(Variable, variable) @classmethod def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]: @@ -267,6 +269,6 @@ class VariablePool(BaseModel): self.add(selector, value) @classmethod - def empty(cls) -> "VariablePool": + def empty(cls) -> VariablePool: """Create an empty variable pool.""" - return cls(system_variables=SystemVariable.empty()) + return cls(system_variables=SystemVariable.default()) diff --git a/api/core/workflow/system_variable.py b/api/core/workflow/system_variable.py index ad925912a4..6946e3e6ab 100644 --- a/api/core/workflow/system_variable.py +++ b/api/core/workflow/system_variable.py @@ -1,6 +1,9 @@ +from __future__ import annotations + from collections.abc import Mapping, Sequence from types import MappingProxyType from typing import Any +from uuid import uuid4 from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator @@ -70,8 +73,8 @@ class SystemVariable(BaseModel): return data @classmethod - def empty(cls) -> "SystemVariable": - return cls() + def default(cls) -> SystemVariable: + return cls(workflow_execution_id=str(uuid4())) def to_dict(self) -> dict[SystemVariableKey, Any]: # NOTE: This method is provided for compatibility with legacy code. 
@@ -114,7 +117,7 @@ class SystemVariable(BaseModel): d[SystemVariableKey.TIMESTAMP] = self.timestamp return d - def as_view(self) -> "SystemVariableReadOnlyView": + def as_view(self) -> SystemVariableReadOnlyView: return SystemVariableReadOnlyView(self) diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py index ea0bdc3537..7992785fe1 100644 --- a/api/core/workflow/variable_loader.py +++ b/api/core/workflow/variable_loader.py @@ -2,7 +2,7 @@ import abc from collections.abc import Mapping, Sequence from typing import Any, Protocol -from core.variables import Variable +from core.variables import VariableBase from core.variables.consts import SELECTORS_LENGTH from core.workflow.runtime import VariablePool @@ -26,7 +26,7 @@ class VariableLoader(Protocol): """ @abc.abstractmethod - def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: """Load variables based on the provided selectors. If the selectors are empty, this method should return an empty list. @@ -36,7 +36,7 @@ class VariableLoader(Protocol): :param: selectors: a list of string list, each inner list should have at least two elements: - the first element is the node ID, - the second element is the variable name. - :return: a list of Variable objects that match the provided selectors. + :return: a list of VariableBase objects that match the provided selectors. """ pass @@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader): Serves as a placeholder when no variable loading is needed. 
""" - def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: return [] diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index ddf545bb34..b645f29d27 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -7,6 +7,8 @@ from typing import Any from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.layers.observability import ObservabilityLayer +from core.app.workflow.node_factory import DifyNodeFactory from core.file.models import File from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID from core.workflow.entities import GraphInitParams @@ -14,7 +16,7 @@ from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer, ObservabilityLayer +from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer from core.workflow.graph_engine.protocols.command_channel import CommandChannel from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent from core.workflow.nodes import NodeType @@ -136,13 +138,11 @@ class WorkflowEntry: :param user_inputs: user inputs :return: """ - node_config = workflow.get_node_config_by_id(node_id) + node_config = dict(workflow.get_node_config_by_id(node_id)) node_config_data = node_config.get("data", {}) - # Get node class + # Get node type node_type = NodeType(node_config_data.get("type")) - node_version = node_config_data.get("version", "1") - node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] # init graph init params and runtime state 
graph_init_params = GraphInitParams( @@ -158,12 +158,12 @@ class WorkflowEntry: graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) # init workflow run state - node = node_cls( - id=str(uuid.uuid4()), - config=node_config, + node_factory = DifyNodeFactory( graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) + node = node_factory.create_node(node_config) + node_cls = type(node) try: # variable selector to variable mapping @@ -190,8 +190,7 @@ class WorkflowEntry: ) try: - # run node - generator = node.run() + generator = cls._traced_node_run(node) except Exception as e: logger.exception( "error while running node, workflow_id=%s, node_id=%s, node_type=%s, node_version=%s", @@ -278,7 +277,7 @@ class WorkflowEntry: # init variable pool variable_pool = VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), user_inputs={}, environment_variables=[], ) @@ -324,8 +323,7 @@ class WorkflowEntry: tenant_id=tenant_id, ) - # run node - generator = node.run() + generator = cls._traced_node_run(node) return node, generator except Exception as e: @@ -431,3 +429,26 @@ class WorkflowEntry: input_value = current_variable.value | input_value variable_pool.add([variable_node_id] + variable_key_list, input_value) + + @staticmethod + def _traced_node_run(node: Node) -> Generator[GraphNodeEventBase, None, None]: + """ + Wraps a node's run method with OpenTelemetry tracing and returns a generator. 
+ """ + # Wrap node.run() with ObservabilityLayer hooks to produce node-level spans + layer = ObservabilityLayer() + layer.on_graph_start() + node.ensure_execution_id() + + def _gen(): + error: Exception | None = None + layer.on_node_run_start(node) + try: + yield from node.run() + except Exception as exc: + error = exc + raise + finally: + layer.on_node_run_end(node, error) + + return _gen() diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 470d75d352..7c2b1ba181 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -3,8 +3,9 @@ set -e # Set UTF-8 encoding to address potential encoding issues in containerized environments -export LANG=${LANG:-en_US.UTF-8} -export LC_ALL=${LC_ALL:-en_US.UTF-8} +# Use C.UTF-8, which is available in all containers +export LANG=${LANG:-C.UTF-8} +export LC_ALL=${LC_ALL:-C.UTF-8} export PYTHONIOENCODING=${PYTHONIOENCODING:-utf-8} if [[ "${MIGRATION_ENABLED}" == "true" ]]; then diff --git a/api/enums/hosted_provider.py b/api/enums/hosted_provider.py new file mode 100644 index 0000000000..c6d3715dc1 --- /dev/null +++ b/api/enums/hosted_provider.py @@ -0,0 +1,21 @@ +from enum import StrEnum + + +class HostedTrialProvider(StrEnum): + """ + Enum representing hosted model provider names for trial access.
+ """ + + OPENAI = "langgenius/openai/openai" + ANTHROPIC = "langgenius/anthropic/anthropic" + GEMINI = "langgenius/gemini/google" + X = "langgenius/x/x" + DEEPSEEK = "langgenius/deepseek/deepseek" + TONGYI = "langgenius/tongyi/tongyi" + + @property + def config_key(self) -> str: + """Return the config key used in dify_config (e.g., HOSTED_{config_key}_PAID_ENABLED).""" + if self == HostedTrialProvider.X: + return "XAI" + return self.name diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index c79764983b..d37217e168 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -6,6 +6,7 @@ from .create_site_record_when_app_created import handle as handle_create_site_re from .delete_tool_parameters_cache_when_sync_draft_workflow import ( handle as handle_delete_tool_parameters_cache_when_sync_draft_workflow, ) +from .queue_credential_sync_when_tenant_created import handle as handle_queue_credential_sync_when_tenant_created from .sync_plugin_trigger_when_app_created import handle as handle_sync_plugin_trigger_when_app_created from .sync_webhook_when_app_created import handle as handle_sync_webhook_when_app_created from .sync_workflow_schedule_when_app_published import handle as handle_sync_workflow_schedule_when_app_published @@ -30,6 +31,7 @@ __all__ = [ "handle_create_installed_app_when_app_created", "handle_create_site_record_when_app_created", "handle_delete_tool_parameters_cache_when_sync_draft_workflow", + "handle_queue_credential_sync_when_tenant_created", "handle_sync_plugin_trigger_when_app_created", "handle_sync_webhook_when_app_created", "handle_sync_workflow_schedule_when_app_published", diff --git a/api/events/event_handlers/queue_credential_sync_when_tenant_created.py b/api/events/event_handlers/queue_credential_sync_when_tenant_created.py new file mode 100644 index 0000000000..6566c214b0 --- /dev/null +++ b/api/events/event_handlers/queue_credential_sync_when_tenant_created.py 
@@ -0,0 +1,19 @@ +from configs import dify_config +from events.tenant_event import tenant_was_created +from services.enterprise.workspace_sync import WorkspaceSyncService + + +@tenant_was_created.connect +def handle(sender, **kwargs): + """Queue credential sync when a tenant/workspace is created.""" + # Only queue sync tasks if the enterprise edition is enabled + if not dify_config.ENTERPRISE_ENABLED: + return + + tenant = sender + + # Determine source from kwargs if available, otherwise use a generic default + source = kwargs.get("source", "tenant_created") + + # Queue credential sync task to Redis for enterprise backend to process + WorkspaceSyncService.queue_credential_sync(tenant.id, source=source) diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index 84266ab0fa..1ddcc8f792 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -10,7 +10,7 @@ from sqlalchemy.orm import Session from configs import dify_config from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity -from core.entities.provider_entities import QuotaUnit, SystemConfiguration +from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, SystemConfiguration from events.message_event import message_was_created from extensions.ext_database import db from extensions.ext_redis import redis_client, redis_fallback @@ -134,22 +134,38 @@ def handle(sender: Message, **kwargs): system_configuration=system_configuration, model_name=model_config.model, ) - if used_quota is not None: - quota_update = _ProviderUpdateOperation( - filters=_ProviderUpdateFilters( + if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL: + from services.credit_pool_service import CreditPoolService + + CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id, - provider_name=ModelProviderID(model_config.provider).provider_name, - provider_type=ProviderType.SYSTEM, - quota_type=provider_configuration.system_configuration.current_quota_type.value, - ), - values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time), - additional_filters=_ProviderUpdateAdditionalFilters( - quota_limit_check=True # Provider.quota_limit > Provider.quota_used - ), - description="quota_deduction_update", - ) - updates_to_perform.append(quota_update) + credits_required=used_quota, + pool_type="trial", + ) + elif provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.PAID: + from services.credit_pool_service import CreditPoolService + + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, + pool_type="paid", + ) + else: + quota_update = _ProviderUpdateOperation( + filters=_ProviderUpdateFilters( + tenant_id=tenant_id, + provider_name=ModelProviderID(model_config.provider).provider_name, + provider_type=ProviderType.SYSTEM.value, + quota_type=provider_configuration.system_configuration.current_quota_type.value, + ), + values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time), + additional_filters=_ProviderUpdateAdditionalFilters( + quota_limit_check=True # Provider.quota_limit > Provider.quota_used + ), + description="quota_deduction_update", + ) + updates_to_perform.append(quota_update) # Execute all updates start_time = time_module.perf_counter() diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index cf994c11df..7d13f0c061 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -6,6 +6,7 @@ BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEAD SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization") AUTHENTICATED_HEADERS: tuple[str, ...] 
= (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN) FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN) +EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE) EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id") @@ -42,10 +43,28 @@ def init_app(app: DifyApp): _apply_cors_once( web_bp, - resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, - supports_credentials=True, - allow_headers=list(AUTHENTICATED_HEADERS), - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + resources={ + # Embedded bot endpoints (unauthenticated, cross-origin safe) + r"^/chat-messages$": { + "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS, + "supports_credentials": False, + "allow_headers": list(EMBED_HEADERS), + "methods": ["GET", "POST", "OPTIONS"], + }, + r"^/chat-messages/.*": { + "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS, + "supports_credentials": False, + "allow_headers": list(EMBED_HEADERS), + "methods": ["GET", "POST", "OPTIONS"], + }, + # Default web application endpoints (authenticated) + r"/*": { + "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS, + "supports_credentials": True, + "allow_headers": list(AUTHENTICATED_HEADERS), + "methods": ["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + }, + }, expose_headers=list(EXPOSED_HEADERS), ) app.register_blueprint(web_bp) diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 9d042a291a..decf8655da 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -12,9 +12,8 @@ from dify_app import DifyApp def _get_celery_ssl_options() -> dict[str, Any] | None: """Get SSL configuration for Celery broker/backend connections.""" - # Use REDIS_USE_SSL for consistency with the main Redis client # Only apply SSL if we're using Redis as broker/backend - if not dify_config.REDIS_USE_SSL: + if not dify_config.BROKER_USE_SSL: return None # Check if Celery is actually using Redis @@ -47,7 +46,11 @@ def 
_get_celery_ssl_options() -> dict[str, Any] | None: def init_app(app: DifyApp) -> Celery: class FlaskTask(Task): def __call__(self, *args: object, **kwargs: object) -> object: + from core.logging.context import init_request_context + with app.app_context(): + # Initialize logging context for this task (similar to before_request in Flask) + init_request_context() return self.run(*args, **kwargs) broker_transport_options = {} @@ -166,6 +169,13 @@ def init_app(app: DifyApp) -> Celery: "task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise", "schedule": crontab(minute="0", hour="2"), } + if dify_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK: + # for saas only + imports.append("schedule.clean_workflow_runs_task") + beat_schedule["clean_workflow_runs_task"] = { + "task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task", + "schedule": crontab(minute="0", hour="0"), + } if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: imports.append("schedule.workflow_schedule_task") beat_schedule["workflow_schedule_task"] = { diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index 71a63168a5..46885761a1 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -4,13 +4,18 @@ from dify_app import DifyApp def init_app(app: DifyApp): from commands import ( add_qdrant_index, + archive_workflow_runs, + clean_expired_messages, + clean_workflow_runs, cleanup_orphaned_draft_variables, clear_free_plan_tenant_expired_logs, clear_orphaned_file_records, convert_to_agent_apps, create_tenant, + delete_archived_workflow_runs, extract_plugins, extract_unique_plugins, + file_usage, fix_app_site_missing, install_plugins, install_rag_pipeline_plugins, @@ -21,6 +26,7 @@ def init_app(app: DifyApp): reset_email, reset_encrypt_key_pair, reset_password, + restore_workflow_runs, setup_datasource_oauth_client, setup_system_tool_oauth_client, setup_system_trigger_oauth_client, @@ -47,6 +53,7 @@ def init_app(app: DifyApp): 
clear_free_plan_tenant_expired_logs, clear_orphaned_file_records, remove_orphaned_files_on_storage, + file_usage, setup_system_tool_oauth_client, setup_system_trigger_oauth_client, cleanup_orphaned_draft_variables, @@ -54,6 +61,11 @@ def init_app(app: DifyApp): setup_datasource_oauth_client, transform_datasource_credentials, install_rag_pipeline_plugins, + archive_workflow_runs, + delete_archived_workflow_runs, + restore_workflow_runs, + clean_workflow_runs, + clean_expired_messages, ] for cmd in cmds_to_register: app.cli.add_command(cmd) diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py index c90b1d0a9f..2e0d4c889a 100644 --- a/api/extensions/ext_database.py +++ b/api/extensions/ext_database.py @@ -53,3 +53,10 @@ def _setup_gevent_compatibility(): def init_app(app: DifyApp): db.init_app(app) _setup_gevent_compatibility() + + # Eagerly build the engine so pool_size/max_overflow/etc. come from config + try: + with app.app_context(): + _ = db.engine # triggers engine creation with the configured options + except Exception: + logger.exception("Failed to initialize SQLAlchemy engine during app startup") diff --git a/api/extensions/ext_fastopenapi.py b/api/extensions/ext_fastopenapi.py new file mode 100644 index 0000000000..e6c1bc6bee --- /dev/null +++ b/api/extensions/ext_fastopenapi.py @@ -0,0 +1,45 @@ +from fastopenapi.routers import FlaskRouter +from flask_cors import CORS + +from configs import dify_config +from controllers.fastopenapi import console_router +from dify_app import DifyApp +from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS + +DOCS_PREFIX = "/fastopenapi" + + +def init_app(app: DifyApp) -> None: + docs_enabled = dify_config.SWAGGER_UI_ENABLED + docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None + redoc_url = f"{DOCS_PREFIX}/redoc" if docs_enabled else None + openapi_url = f"{DOCS_PREFIX}/openapi.json" if docs_enabled else None + + router = FlaskRouter( + app=app, + docs_url=docs_url, + 
redoc_url=redoc_url, + openapi_url=openapi_url, + openapi_version="3.0.0", + title="Dify API (FastOpenAPI PoC)", + version="1.0", + description="FastOpenAPI proof of concept for Dify API", + ) + + # Ensure route decorators are evaluated. + import controllers.console.ping as ping_module + from controllers.console import setup + + _ = ping_module + _ = setup + + router.include_router(console_router, prefix="/console/api") + CORS( + app, + resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, + supports_credentials=True, + allow_headers=list(AUTHENTICATED_HEADERS), + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=list(EXPOSED_HEADERS), + ) + app.extensions["fastopenapi"] = router diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index 000d03ac41..978a40c503 100644 --- a/api/extensions/ext_logging.py +++ b/api/extensions/ext_logging.py @@ -1,18 +1,19 @@ +"""Logging extension for Dify Flask application.""" + import logging import os import sys -import uuid from logging.handlers import RotatingFileHandler -import flask - from configs import dify_config -from core.helper.trace_id_helper import get_trace_id_from_otel_context from dify_app import DifyApp def init_app(app: DifyApp): + """Initialize logging with support for text or JSON format.""" log_handlers: list[logging.Handler] = [] + + # File handler log_file = dify_config.LOG_FILE if log_file: log_dir = os.path.dirname(log_file) @@ -25,27 +26,53 @@ def init_app(app: DifyApp): ) ) - # Always add StreamHandler to log to console + # Console handler sh = logging.StreamHandler(sys.stdout) log_handlers.append(sh) - # Apply RequestIdFilter to all handlers - for handler in log_handlers: - handler.addFilter(RequestIdFilter()) + # Apply filters to all handlers + from core.logging.filters import IdentityContextFilter, TraceContextFilter + for handler in log_handlers: + handler.addFilter(TraceContextFilter()) + 
handler.addFilter(IdentityContextFilter()) + + # Configure formatter based on format type + formatter = _create_formatter() + for handler in log_handlers: + handler.setFormatter(formatter) + + # Configure root logger logging.basicConfig( level=dify_config.LOG_LEVEL, - format=dify_config.LOG_FORMAT, - datefmt=dify_config.LOG_DATEFORMAT, handlers=log_handlers, force=True, ) - # Apply RequestIdFormatter to all handlers - apply_request_id_formatter() - # Disable propagation for noisy loggers to avoid duplicate logs logging.getLogger("sqlalchemy.engine").propagate = False + + # Apply timezone if specified (only for text format) + if dify_config.LOG_OUTPUT_FORMAT == "text": + _apply_timezone(log_handlers) + + +def _create_formatter() -> logging.Formatter: + """Create appropriate formatter based on configuration.""" + if dify_config.LOG_OUTPUT_FORMAT == "json": + from core.logging.structured_formatter import StructuredJSONFormatter + + return StructuredJSONFormatter() + else: + # Text format - use existing pattern with backward compatible formatter + return _TextFormatter( + fmt=dify_config.LOG_FORMAT, + datefmt=dify_config.LOG_DATEFORMAT, + ) + + +def _apply_timezone(handlers: list[logging.Handler]): + """Apply timezone conversion to text formatters.""" log_tz = dify_config.LOG_TZ if log_tz: from datetime import datetime @@ -57,34 +84,51 @@ def init_app(app: DifyApp): def time_converter(seconds): return datetime.fromtimestamp(seconds, tz=timezone).timetuple() - for handler in logging.root.handlers: + for handler in handlers: if handler.formatter: - handler.formatter.converter = time_converter + handler.formatter.converter = time_converter # type: ignore[attr-defined] -def get_request_id(): - if getattr(flask.g, "request_id", None): - return flask.g.request_id +class _TextFormatter(logging.Formatter): + """Text formatter that ensures trace_id and req_id are always present.""" - new_uuid = uuid.uuid4().hex[:10] - flask.g.request_id = new_uuid - - return new_uuid + def 
format(self, record: logging.LogRecord) -> str: + if not hasattr(record, "req_id"): + record.req_id = "" + if not hasattr(record, "trace_id"): + record.trace_id = "" + if not hasattr(record, "span_id"): + record.span_id = "" + return super().format(record) +def get_request_id() -> str: + """Get request ID for current request context. + + Deprecated: Use core.logging.context.get_request_id() directly. + """ + from core.logging.context import get_request_id as _get_request_id + + return _get_request_id() + + +# Backward compatibility aliases class RequestIdFilter(logging.Filter): - # This is a logging filter that makes the request ID available for use in - # the logging format. Note that we're checking if we're in a request - # context, as we may want to log things before Flask is fully loaded. - def filter(self, record): - trace_id = get_trace_id_from_otel_context() or "" - record.req_id = get_request_id() if flask.has_request_context() else "" - record.trace_id = trace_id + """Deprecated: Use TraceContextFilter from core.logging.filters instead.""" + + def filter(self, record: logging.LogRecord) -> bool: + from core.logging.context import get_request_id as _get_request_id + from core.logging.context import get_trace_id as _get_trace_id + + record.req_id = _get_request_id() + record.trace_id = _get_trace_id() return True class RequestIdFormatter(logging.Formatter): - def format(self, record): + """Deprecated: Use _TextFormatter instead.""" + + def format(self, record: logging.LogRecord) -> str: if not hasattr(record, "req_id"): record.req_id = "" if not hasattr(record, "trace_id"): @@ -93,6 +137,7 @@ class RequestIdFormatter(logging.Formatter): def apply_request_id_formatter(): + """Deprecated: Formatter is now applied in init_app.""" for handler in logging.root.handlers: if handler.formatter: handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT) diff --git a/api/extensions/ext_logstore.py b/api/extensions/ext_logstore.py index 
502f0bb46b..cda2d1ad1e 100644 --- a/api/extensions/ext_logstore.py +++ b/api/extensions/ext_logstore.py @@ -10,6 +10,7 @@ import os from dotenv import load_dotenv +from configs import dify_config from dify_app import DifyApp logger = logging.getLogger(__name__) @@ -19,12 +20,17 @@ def is_enabled() -> bool: """ Check if logstore extension is enabled. + Logstore is considered enabled when: + 1. All required Aliyun SLS environment variables are set + 2. At least one repository configuration points to a logstore implementation + Returns: - True if all required Aliyun SLS environment variables are set, False otherwise + True if logstore should be initialized, False otherwise """ # Load environment variables from .env file load_dotenv() + # Check if Aliyun SLS connection parameters are configured required_vars = [ "ALIYUN_SLS_ACCESS_KEY_ID", "ALIYUN_SLS_ACCESS_KEY_SECRET", @@ -33,24 +39,32 @@ def is_enabled() -> bool: "ALIYUN_SLS_PROJECT_NAME", ] - all_set = all(os.environ.get(var) for var in required_vars) + sls_vars_set = all(os.environ.get(var) for var in required_vars) - if not all_set: - logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set") + if not sls_vars_set: + return False - return all_set + # Check if any repository configuration points to logstore implementation + repository_configs = [ + dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY, + dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY, + dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY, + dify_config.API_WORKFLOW_RUN_REPOSITORY, + ] + + uses_logstore = any("logstore" in config.lower() for config in repository_configs) + + if not uses_logstore: + return False + + logger.info("Logstore extension enabled: SLS variables set and repository configured to use logstore") + return True def init_app(app: DifyApp): """ Initialize logstore on application startup. - - This function: - 1. Creates Aliyun SLS project if it doesn't exist - 2. 
Creates logstores (workflow_execution, workflow_node_execution) if they don't exist - 3. Creates indexes with field configurations based on PostgreSQL table structures - - This operation is idempotent and only executes once during application startup. + If initialization fails, the application continues running without logstore features. Args: app: The Dify application instance @@ -58,17 +72,23 @@ def init_app(app: DifyApp): try: from extensions.logstore.aliyun_logstore import AliyunLogStore - logger.info("Initializing logstore...") + logger.info("Initializing Aliyun SLS Logstore...") - # Create logstore client and initialize project/logstores/indexes + # Create logstore client and initialize resources logstore_client = AliyunLogStore() logstore_client.init_project_logstore() - # Attach to app for potential later use app.extensions["logstore"] = logstore_client logger.info("Logstore initialized successfully") + except Exception: - logger.exception("Failed to initialize logstore") - # Don't raise - allow application to continue even if logstore init fails - # This ensures that the application can still run if logstore is misconfigured + logger.exception( + "Logstore initialization failed. Configuration: endpoint=%s, region=%s, project=%s, timeout=%ss. 
" + "Application will continue but logstore features will NOT work.", + os.environ.get("ALIYUN_SLS_ENDPOINT"), + os.environ.get("ALIYUN_SLS_REGION"), + os.environ.get("ALIYUN_SLS_PROJECT_NAME"), + os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", "30"), + ) + # Don't raise - allow application to continue even if logstore setup fails diff --git a/api/extensions/logstore/aliyun_logstore.py b/api/extensions/logstore/aliyun_logstore.py index 22d1f473a3..f6a4765f14 100644 --- a/api/extensions/logstore/aliyun_logstore.py +++ b/api/extensions/logstore/aliyun_logstore.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import logging import os +import socket import threading import time from collections.abc import Sequence @@ -33,7 +36,7 @@ class AliyunLogStore: Ensures only one instance exists to prevent multiple PG connection pools. """ - _instance: "AliyunLogStore | None" = None + _instance: AliyunLogStore | None = None _initialized: bool = False # Track delayed PG connection for newly created projects @@ -66,7 +69,7 @@ class AliyunLogStore: "\t", ] - def __new__(cls) -> "AliyunLogStore": + def __new__(cls) -> AliyunLogStore: """Implement singleton pattern.""" if cls._instance is None: cls._instance = super().__new__(cls) @@ -177,9 +180,18 @@ class AliyunLogStore: self.region: str = os.environ.get("ALIYUN_SLS_REGION", "") self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "") self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365)) - self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true" + self.log_enabled: bool = ( + os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true" + or os.environ.get("LOGSTORE_SQL_ECHO", "false").lower() == "true" + ) self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true" + # Get timeout configuration + check_timeout = int(os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", 30)) + + # Pre-check endpoint connectivity 
to prevent indefinite hangs + self._check_endpoint_connectivity(self.endpoint, check_timeout) + # Initialize SDK client self.client = LogClient( self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region @@ -197,6 +209,49 @@ class AliyunLogStore: self.__class__._initialized = True + @staticmethod + def _check_endpoint_connectivity(endpoint: str, timeout: int) -> None: + """ + Check if the SLS endpoint is reachable before creating LogClient. + Prevents indefinite hangs when the endpoint is unreachable. + + Args: + endpoint: SLS endpoint URL + timeout: Connection timeout in seconds + + Raises: + ConnectionError: If endpoint is not reachable + """ + # Parse endpoint URL to extract hostname and port + from urllib.parse import urlparse + + parsed_url = urlparse(endpoint if "://" in endpoint else f"http://{endpoint}") + hostname = parsed_url.hostname + port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80) + + if not hostname: + raise ConnectionError(f"Invalid endpoint URL: {endpoint}") + + sock = None + try: + # Create socket and set timeout + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(timeout) + sock.connect((hostname, port)) + except Exception as e: + # Catch all exceptions and provide clear error message + error_type = type(e).__name__ + raise ConnectionError( + f"Cannot connect to {hostname}:{port} (timeout={timeout}s): [{error_type}] {e}" + ) from e + finally: + # Ensure socket is properly closed + if sock: + try: + sock.close() + except Exception: # noqa: S110 + pass # Ignore errors during cleanup + @property def supports_pg_protocol(self) -> bool: """Check if PG protocol is supported and enabled.""" @@ -218,19 +273,16 @@ class AliyunLogStore: try: self._use_pg_protocol = self._pg_client.init_connection() if self._use_pg_protocol: - logger.info("Successfully connected to project %s using PG protocol", self.project_name) + logger.info("Using PG protocol for project %s", 
self.project_name) # Check if scan_index is enabled for all logstores self._check_and_disable_pg_if_scan_index_disabled() return True else: - logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name) + logger.info("Using SDK mode for project %s", self.project_name) return False except Exception as e: - logger.warning( - "Failed to establish PG connection for project %s: %s. Will use SDK mode.", - self.project_name, - str(e), - ) + logger.info("Using SDK mode for project %s", self.project_name) + logger.debug("PG connection details: %s", str(e)) self._use_pg_protocol = False return False @@ -244,10 +296,6 @@ class AliyunLogStore: if self._use_pg_protocol: return - logger.info( - "Attempting delayed PG connection for newly created project %s ...", - self.project_name, - ) self._attempt_pg_connection_init() self.__class__._pg_connection_timer = None @@ -282,11 +330,7 @@ class AliyunLogStore: if project_is_new: # For newly created projects, schedule delayed PG connection self._use_pg_protocol = False - logger.info( - "Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.", - self.project_name, - self.__class__._pg_connection_delay, - ) + logger.info("Using SDK mode for project %s (newly created)", self.project_name) if self.__class__._pg_connection_timer is not None: self.__class__._pg_connection_timer.cancel() self.__class__._pg_connection_timer = threading.Timer( @@ -297,7 +341,6 @@ class AliyunLogStore: self.__class__._pg_connection_timer.start() else: # For existing projects, attempt PG connection immediately - logger.info("Project %s already exists. 
Attempting PG connection...", self.project_name) self._attempt_pg_connection_init() def _check_and_disable_pg_if_scan_index_disabled(self) -> None: @@ -316,9 +359,9 @@ class AliyunLogStore: existing_config = self.get_existing_index_config(logstore_name) if existing_config and not existing_config.scan_index: logger.info( - "Logstore %s has scan_index=false, USE SDK mode for read/write operations. " - "PG protocol requires scan_index to be enabled.", + "Logstore %s requires scan_index enabled, using SDK mode for project %s", logstore_name, + self.project_name, ) self._use_pg_protocol = False # Close PG connection if it was initialized @@ -746,7 +789,6 @@ class AliyunLogStore: reverse=reverse, ) - # Log query info if SQLALCHEMY_ECHO is enabled if self.log_enabled: logger.info( "[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | " @@ -768,7 +810,6 @@ class AliyunLogStore: for log in logs: result.append(log.get_contents()) - # Log result count if SQLALCHEMY_ECHO is enabled if self.log_enabled: logger.info( "[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d", @@ -843,7 +884,6 @@ class AliyunLogStore: query=full_query, ) - # Log query info if SQLALCHEMY_ECHO is enabled if self.log_enabled: logger.info( "[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s", @@ -851,8 +891,7 @@ class AliyunLogStore: self.project_name, from_time, to_time, - query, - sql, + full_query, ) try: @@ -863,7 +902,6 @@ class AliyunLogStore: for log in logs: result.append(log.get_contents()) - # Log result count if SQLALCHEMY_ECHO is enabled if self.log_enabled: logger.info( "[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d", diff --git a/api/extensions/logstore/aliyun_logstore_pg.py b/api/extensions/logstore/aliyun_logstore_pg.py index 35aa51ce53..874c20d144 100644 --- a/api/extensions/logstore/aliyun_logstore_pg.py +++ b/api/extensions/logstore/aliyun_logstore_pg.py @@ -7,8 +7,7 @@ from contextlib import 
contextmanager from typing import Any import psycopg2 -import psycopg2.pool -from psycopg2 import InterfaceError, OperationalError +from sqlalchemy import create_engine from configs import dify_config @@ -16,11 +15,7 @@ logger = logging.getLogger(__name__) class AliyunLogStorePG: - """ - PostgreSQL protocol support for Aliyun SLS LogStore. - - Handles PG connection pooling and operations for regions that support PG protocol. - """ + """PostgreSQL protocol support for Aliyun SLS LogStore using SQLAlchemy connection pool.""" def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str): """ @@ -36,24 +31,11 @@ class AliyunLogStorePG: self._access_key_secret = access_key_secret self._endpoint = endpoint self.project_name = project_name - self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None + self._engine: Any = None # SQLAlchemy Engine self._use_pg_protocol = False def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool: - """ - Check if a TCP port is reachable using socket connection. - - This provides a fast check before attempting full database connection, - preventing long waits when connecting to unsupported regions. - - Args: - host: Hostname or IP address - port: Port number - timeout: Connection timeout in seconds (default: 2.0) - - Returns: - True if port is reachable, False otherwise - """ + """Fast TCP port check to avoid long waits on unsupported regions.""" try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(timeout) @@ -65,166 +47,101 @@ class AliyunLogStorePG: return False def init_connection(self) -> bool: - """ - Initialize PostgreSQL connection pool for SLS PG protocol support. - - Attempts to connect to SLS using PostgreSQL protocol. If successful, sets - _use_pg_protocol to True and creates a connection pool. If connection fails - (region doesn't support PG protocol or other errors), returns False. 
- - Returns: - True if PG protocol is supported and initialized, False otherwise - """ + """Initialize SQLAlchemy connection pool with pool_recycle and TCP keepalive support.""" try: - # Extract hostname from endpoint (remove protocol if present) pg_host = self._endpoint.replace("http://", "").replace("https://", "") - # Get pool configuration - pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10)) + # Pool configuration + pool_size = int(os.environ.get("ALIYUN_SLS_PG_POOL_SIZE", 5)) + max_overflow = int(os.environ.get("ALIYUN_SLS_PG_MAX_OVERFLOW", 5)) + pool_recycle = int(os.environ.get("ALIYUN_SLS_PG_POOL_RECYCLE", 3600)) + pool_pre_ping = os.environ.get("ALIYUN_SLS_PG_POOL_PRE_PING", "false").lower() == "true" - logger.debug( - "Check PG protocol connection to SLS: host=%s, project=%s", - pg_host, - self.project_name, - ) + logger.debug("Check PG protocol connection to SLS: host=%s, project=%s", pg_host, self.project_name) - # Fast port connectivity check before attempting full connection - # This prevents long waits when connecting to unsupported regions + # Fast port check to avoid long waits if not self._check_port_connectivity(pg_host, 5432, timeout=1.0): - logger.info( - "USE SDK mode for read/write operations, host=%s", - pg_host, - ) + logger.debug("Using SDK mode for host=%s", pg_host) return False - # Create connection pool - self._pg_pool = psycopg2.pool.SimpleConnectionPool( - minconn=1, - maxconn=pg_max_connections, - host=pg_host, - port=5432, - database=self.project_name, - user=self._access_key_id, - password=self._access_key_secret, - sslmode="require", - connect_timeout=5, - application_name=f"Dify-{dify_config.project.version}", + # Build connection URL + from urllib.parse import quote_plus + + username = quote_plus(self._access_key_id) + password = quote_plus(self._access_key_secret) + database_url = ( + f"postgresql+psycopg2://{username}:{password}@{pg_host}:5432/{self.project_name}?sslmode=require" ) - # Note: Skip 
test query because SLS PG protocol only supports SELECT/INSERT on actual tables - # Connection pool creation success already indicates connectivity + # Create SQLAlchemy engine with connection pool + self._engine = create_engine( + database_url, + pool_size=pool_size, + max_overflow=max_overflow, + pool_recycle=pool_recycle, + pool_pre_ping=pool_pre_ping, + pool_timeout=30, + connect_args={ + "connect_timeout": 5, + "application_name": f"Dify-{dify_config.project.version}-fixautocommit", + "keepalives": 1, + "keepalives_idle": 60, + "keepalives_interval": 10, + "keepalives_count": 5, + }, + ) self._use_pg_protocol = True logger.info( - "PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.", + "PG protocol initialized for SLS project=%s (pool_size=%d, pool_recycle=%ds)", self.project_name, + pool_size, + pool_recycle, ) return True except Exception as e: - # PG connection failed - fallback to SDK mode self._use_pg_protocol = False - if self._pg_pool: + if self._engine: try: - self._pg_pool.closeall() + self._engine.dispose() except Exception: - logger.debug("Failed to close PG connection pool during cleanup, ignoring") - self._pg_pool = None + logger.debug("Failed to dispose engine during cleanup, ignoring") + self._engine = None - logger.info( - "PG protocol connection failed (region may not support PG protocol): %s. " - "Falling back to SDK mode for read/write operations.", - str(e), - ) - return False - - def _is_connection_valid(self, conn: Any) -> bool: - """ - Check if a connection is still valid. 
- - Args: - conn: psycopg2 connection object - - Returns: - True if connection is valid, False otherwise - """ - try: - # Check if connection is closed - if conn.closed: - return False - - # Quick ping test - execute a lightweight query - # For SLS PG protocol, we can't use SELECT 1 without FROM, - # so we just check the connection status - with conn.cursor() as cursor: - cursor.execute("SELECT 1") - cursor.fetchone() - return True - except Exception: + logger.debug("Using SDK mode for region: %s", str(e)) return False @contextmanager def _get_connection(self): - """ - Context manager to get a PostgreSQL connection from the pool. + """Get connection from SQLAlchemy pool. Pool handles recycle, invalidation, and keepalive automatically.""" + if not self._engine: + raise RuntimeError("SQLAlchemy engine is not initialized") - Automatically validates and refreshes stale connections. - - Note: Aliyun SLS PG protocol does not support transactions, so we always - use autocommit mode. - - Yields: - psycopg2 connection object - - Raises: - RuntimeError: If PG pool is not initialized - """ - if not self._pg_pool: - raise RuntimeError("PG connection pool is not initialized") - - conn = self._pg_pool.getconn() + connection = self._engine.raw_connection() try: - # Validate connection and get a fresh one if needed - if not self._is_connection_valid(conn): - logger.debug("Connection is stale, marking as bad and getting a new one") - # Mark connection as bad and get a new one - self._pg_pool.putconn(conn, close=True) - conn = self._pg_pool.getconn() - - # Aliyun SLS PG protocol does not support transactions, always use autocommit - conn.autocommit = True - yield conn + connection.autocommit = True # SLS PG protocol does not support transactions + yield connection + except Exception: + raise finally: - # Return connection to pool (or close if it's bad) - if self._is_connection_valid(conn): - self._pg_pool.putconn(conn) - else: - self._pg_pool.putconn(conn, close=True) + 
connection.close() def close(self) -> None: - """Close the PostgreSQL connection pool.""" - if self._pg_pool: + """Dispose SQLAlchemy engine and close all connections.""" + if self._engine: try: - self._pg_pool.closeall() - logger.info("PG connection pool closed") + self._engine.dispose() + logger.info("SQLAlchemy engine disposed") except Exception: - logger.exception("Failed to close PG connection pool") + logger.exception("Failed to dispose engine") def _is_retriable_error(self, error: Exception) -> bool: - """ - Check if an error is retriable (connection-related issues). - - Args: - error: Exception to check - - Returns: - True if the error is retriable, False otherwise - """ - # Retry on connection-related errors - if isinstance(error, (OperationalError, InterfaceError)): + """Check if error is retriable (connection-related issues).""" + # Check for psycopg2 connection errors directly + if isinstance(error, (psycopg2.OperationalError, psycopg2.InterfaceError)): return True - # Check error message for specific connection issues error_msg = str(error).lower() retriable_patterns = [ "connection", @@ -234,34 +151,18 @@ class AliyunLogStorePG: "reset by peer", "no route to host", "network", + "operational error", + "interface error", ] return any(pattern in error_msg for pattern in retriable_patterns) def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None: - """ - Write log to SLS using PostgreSQL protocol with automatic retry. - - Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only - writes with log_version field for versioning, same as SDK implementation. 
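The `_is_retriable_error` check above feeds a retry loop that `put_log` and `execute_sql` below both repeat: non-retriable errors fail fast, retriable ones retry up to three times with exponential backoff. Factored out, the pattern looks roughly like this (`run_with_retry` is an illustrative name, not part of the patch):

```python
import time


def run_with_retry(operation, is_retriable, max_retries=3, base_delay=0.1):
    """Run `operation`, retrying transient failures with exponential backoff.

    Mirrors the loop in put_log/execute_sql: non-retriable errors re-raise
    immediately; retriable ones retry up to `max_retries` attempts with
    delays of base_delay, 2*base_delay, 4*base_delay, ... in between.
    """
    delay = base_delay
    for attempt in range(max_retries):
        try:
            return operation()
        except Exception as e:
            # Fail fast on non-retriable errors, or when attempts are exhausted.
            if not is_retriable(e) or attempt == max_retries - 1:
                raise
            time.sleep(delay)
            delay *= 2
```

With `max_retries=3` and `base_delay=0.1`, a persistently failing call sleeps 0.1s and 0.2s before the final raise, matching the constants in the patch.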
- - Args: - logstore: Name of the logstore table - contents: List of (field_name, value) tuples - log_enabled: Whether to enable logging - - Raises: - psycopg2.Error: If database operation fails after all retries - """ + """Write log to SLS using INSERT with automatic retry (3 attempts with exponential backoff).""" if not contents: return - # Extract field names and values from contents fields = [field_name for field_name, _ in contents] values = [value for _, value in contents] - - # Build INSERT statement with literal values - # Note: Aliyun SLS PG protocol doesn't support parameterized queries, - # so we need to use mogrify to safely create literal values field_list = ", ".join([f'"{field}"' for field in fields]) if log_enabled: @@ -272,67 +173,40 @@ class AliyunLogStorePG: len(contents), ) - # Retry configuration max_retries = 3 - retry_delay = 0.1 # Start with 100ms + retry_delay = 0.1 for attempt in range(max_retries): try: with self._get_connection() as conn: with conn.cursor() as cursor: - # Use mogrify to safely convert values to SQL literals placeholders = ", ".join(["%s"] * len(fields)) values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8") insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}' cursor.execute(insert_sql) - # Success - exit retry loop return except psycopg2.Error as e: - # Check if error is retriable if not self._is_retriable_error(e): - # Not a retriable error (e.g., data validation error), fail immediately - logger.exception( - "Failed to put logs to logstore %s via PG protocol (non-retriable error)", - logstore, - ) + logger.exception("Failed to put logs to logstore %s (non-retriable error)", logstore) raise - # Retriable error - log and retry if we have attempts left if attempt < max_retries - 1: logger.warning( - "Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...", + "Failed to put logs to logstore %s (attempt %d/%d): %s. 
Retrying...", logstore, attempt + 1, max_retries, str(e), ) time.sleep(retry_delay) - retry_delay *= 2 # Exponential backoff + retry_delay *= 2 else: - # Last attempt failed - logger.exception( - "Failed to put logs to logstore %s via PG protocol after %d attempts", - logstore, - max_retries, - ) + logger.exception("Failed to put logs to logstore %s after %d attempts", logstore, max_retries) raise def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]: - """ - Execute SQL query using PostgreSQL protocol with automatic retry. - - Args: - sql: SQL query string - logstore: Name of the logstore (for logging purposes) - log_enabled: Whether to enable logging - - Returns: - List of result rows as dictionaries - - Raises: - psycopg2.Error: If database operation fails after all retries - """ + """Execute SQL query with automatic retry (3 attempts with exponential backoff).""" if log_enabled: logger.info( "[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s", @@ -341,20 +215,16 @@ class AliyunLogStorePG: sql, ) - # Retry configuration max_retries = 3 - retry_delay = 0.1 # Start with 100ms + retry_delay = 0.1 for attempt in range(max_retries): try: with self._get_connection() as conn: with conn.cursor() as cursor: cursor.execute(sql) - - # Get column names from cursor description columns = [desc[0] for desc in cursor.description] - # Fetch all results and convert to list of dicts result = [] for row in cursor.fetchall(): row_dict = {} @@ -372,36 +242,31 @@ class AliyunLogStorePG: return result except psycopg2.Error as e: - # Check if error is retriable if not self._is_retriable_error(e): - # Not a retriable error (e.g., SQL syntax error), fail immediately logger.exception( - "Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s", + "Failed to execute SQL on logstore %s (non-retriable error): sql=%s", logstore, sql, ) raise - # Retriable error - log and retry if we have attempts left if 
attempt < max_retries - 1: logger.warning( - "Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...", + "Failed to execute SQL on logstore %s (attempt %d/%d): %s. Retrying...", logstore, attempt + 1, max_retries, str(e), ) time.sleep(retry_delay) - retry_delay *= 2 # Exponential backoff + retry_delay *= 2 else: - # Last attempt failed logger.exception( - "Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s", + "Failed to execute SQL on logstore %s after %d attempts: sql=%s", logstore, max_retries, sql, ) raise - # This line should never be reached due to raise above, but makes type checker happy return [] diff --git a/api/extensions/logstore/repositories/__init__.py b/api/extensions/logstore/repositories/__init__.py index e69de29bb2..b5a4fcf844 100644 --- a/api/extensions/logstore/repositories/__init__.py +++ b/api/extensions/logstore/repositories/__init__.py @@ -0,0 +1,29 @@ +""" +LogStore repository utilities. +""" + +from typing import Any + + +def safe_float(value: Any, default: float = 0.0) -> float: + """ + Safely convert a value to float, handling 'null' strings and None. + """ + if value is None or value in ("null", ""): + return default + try: + return float(value) + except (ValueError, TypeError): + return default + + +def safe_int(value: Any, default: int = 0) -> int: + """ + Safely convert a value to int, handling 'null' strings and None.
+ """ + if value is None or value in ("null", ""): + return default + try: + return int(float(value)) + except (ValueError, TypeError): + return default diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index 1cae14b726..817c8b0448 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -15,6 +15,8 @@ from sqlalchemy.orm import sessionmaker from core.workflow.enums import WorkflowNodeExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore +from extensions.logstore.repositories import safe_float, safe_int +from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value from models.workflow import WorkflowNodeExecutionModel from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository @@ -53,9 +55,8 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode model.created_by_role = data.get("created_by_role") or "" model.created_by = data.get("created_by") or "" - # Numeric fields with defaults - model.index = int(data.get("index", 0)) - model.elapsed_time = float(data.get("elapsed_time", 0)) + model.index = safe_int(data.get("index", 0)) + model.elapsed_time = safe_float(data.get("elapsed_time", 0)) # Optional fields model.workflow_run_id = data.get("workflow_run_id") @@ -131,6 +132,12 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep node_id, ) try: + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_workflow_id = escape_identifier(workflow_id) + escaped_node_id = escape_identifier(node_id) + # Check if PG protocol is supported if
self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of each record) @@ -139,10 +146,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_node_execution_logstore}" - WHERE tenant_id = '{tenant_id}' - AND app_id = '{app_id}' - AND workflow_id = '{workflow_id}' - AND node_id = '{node_id}' + WHERE tenant_id = '{escaped_tenant_id}' + AND app_id = '{escaped_app_id}' + AND workflow_id = '{escaped_workflow_id}' + AND node_id = '{escaped_node_id}' AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 100 @@ -154,7 +161,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep else: # Use SDK with LogStore query syntax query = ( - f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}" + f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} " + f"and workflow_id: {escaped_workflow_id} and node_id: {escaped_node_id}" ) from_time = 0 to_time = int(time.time()) # now @@ -230,6 +238,11 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep workflow_run_id, ) try: + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_workflow_run_id = escape_identifier(workflow_run_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of each record) @@ -238,9 +251,9 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_node_execution_logstore}" - WHERE tenant_id = '{tenant_id}' - AND app_id = '{app_id}' - AND workflow_run_id = '{workflow_run_id}' + WHERE tenant_id = '{escaped_tenant_id}' + 
AND app_id = '{escaped_app_id}' + AND workflow_run_id = '{escaped_workflow_run_id}' AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 1000 @@ -251,7 +264,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep ) else: # Use SDK with LogStore query syntax - query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}" + query = ( + f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} " + f"and workflow_run_id: {escaped_workflow_run_id}" + ) from_time = 0 to_time = int(time.time()) # now @@ -318,16 +334,24 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep """ logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id) try: + # Escape parameters to prevent SQL injection + escaped_execution_id = escape_identifier(execution_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of record) - tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else "" + if tenant_id: + escaped_tenant_id = escape_identifier(tenant_id) + tenant_filter = f"AND tenant_id = '{escaped_tenant_id}'" + else: + tenant_filter = "" + sql_query = f""" SELECT * FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_node_execution_logstore}" - WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0 + WHERE id = '{escaped_execution_id}' {tenant_filter} AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 1 """ @@ -337,10 +361,14 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep ) else: # Use SDK with LogStore query syntax + # Note: Values must be quoted in LogStore query syntax to prevent injection if tenant_id: - query = f"id: {execution_id} and tenant_id: {tenant_id}" + query = ( + f"id:{escape_logstore_query_value(execution_id)} " + f"and 
tenant_id:{escape_logstore_query_value(tenant_id)}" + ) else: - query = f"id: {execution_id}" + query = f"id:{escape_logstore_query_value(execution_id)}" from_time = 0 to_time = int(time.time()) # now diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index 252cdcc4df..14382ed876 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -10,6 +10,7 @@ Key Features: - Optimized deduplication using finished_at IS NOT NULL filter - Window functions only when necessary (running status queries) - Multi-tenant data isolation and security +- SQL injection prevention via parameter escaping """ import logging @@ -22,6 +23,8 @@ from typing import Any, cast from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore +from extensions.logstore.repositories import safe_float, safe_int +from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowRun @@ -63,10 +66,9 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun: model.created_by_role = data.get("created_by_role") or "" model.created_by = data.get("created_by") or "" - # Numeric fields with defaults - model.total_tokens = int(data.get("total_tokens", 0)) - model.total_steps = int(data.get("total_steps", 0)) - model.exceptions_count = int(data.get("exceptions_count", 0)) + model.total_tokens = safe_int(data.get("total_tokens", 0)) + model.total_steps = safe_int(data.get("total_steps", 0)) + model.exceptions_count = safe_int(data.get("exceptions_count", 0)) # Optional fields model.graph = data.get("graph") @@ -101,7 +103,8 @@ def 
_dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun: if model.finished_at and model.created_at: model.elapsed_time = (model.finished_at - model.created_at).total_seconds() else: - model.elapsed_time = float(data.get("elapsed_time", 0)) + # Use safe conversion to handle 'null' strings and None values + model.elapsed_time = safe_float(data.get("elapsed_time", 0)) return model @@ -165,16 +168,26 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): status, ) # Convert triggered_from to list if needed - if isinstance(triggered_from, WorkflowRunTriggeredFrom): + if isinstance(triggered_from, (WorkflowRunTriggeredFrom, str)): triggered_from_list = [triggered_from] else: triggered_from_list = list(triggered_from) - # Build triggered_from filter - triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list]) + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) - # Build status filter - status_filter = f"AND status='{status}'" if status else "" + # Build triggered_from filter with escaped values + # Support both enum and string values for triggered_from + triggered_from_filter = " OR ".join( + [ + f"triggered_from='{escape_sql_string(tf.value if isinstance(tf, WorkflowRunTriggeredFrom) else tf)}'" + for tf in triggered_from_list + ] + ) + + # Build status filter with escaped value + status_filter = f"AND status='{escape_sql_string(status)}'" if status else "" # Build last_id filter for pagination # Note: This is simplified. 
In production, you'd need to track created_at from last record @@ -188,8 +201,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): SELECT * FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' AND ({triggered_from_filter}) {status_filter} {last_id_filter} @@ -232,6 +245,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id) try: + # Escape parameters to prevent SQL injection + escaped_run_id = escape_identifier(run_id) + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of record) @@ -240,7 +258,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_execution_logstore}" - WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0 + WHERE id = '{escaped_run_id}' + AND tenant_id = '{escaped_tenant_id}' + AND app_id = '{escaped_app_id}' + AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 100 """ @@ -250,7 +271,12 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): ) else: # Use SDK with LogStore query syntax - query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}" + # Note: Values must be quoted in LogStore query syntax to prevent injection + query = ( + f"id:{escape_logstore_query_value(run_id)} " + f"and tenant_id:{escape_logstore_query_value(tenant_id)} " + f"and app_id:{escape_logstore_query_value(app_id)}" + ) from_time = 0 to_time = int(time.time()) # 
now @@ -323,6 +349,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id) try: + # Escape parameter to prevent SQL injection + escaped_run_id = escape_identifier(run_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of record) @@ -331,7 +360,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_execution_logstore}" - WHERE id = '{run_id}' AND __time__ > 0 + WHERE id = '{escaped_run_id}' AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 100 """ @@ -341,7 +370,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): ) else: # Use SDK with LogStore query syntax - query = f"id: {run_id}" + # Note: Values must be quoted in LogStore query syntax + query = f"id:{escape_logstore_query_value(run_id)}" from_time = 0 to_time = int(time.time()) # now @@ -410,6 +440,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): triggered_from, status, ) + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + # Build time range filter time_filter = "" if time_range: @@ -418,6 +453,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): # If status is provided, simple count if status: + escaped_status = escape_sql_string(status) + if status == "running": # Running status requires window function sql = f""" @@ -425,9 +462,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND 
triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND status='running' {time_filter} ) t @@ -438,10 +475,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): sql = f""" SELECT COUNT(DISTINCT id) as count FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' - AND status='{status}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' + AND status='{escaped_status}' AND finished_at IS NOT NULL {time_filter} """ @@ -467,13 +504,14 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): # No status filter - get counts grouped by status # Use optimized query for finished runs, separate query for running try: + # Escape parameters (already escaped above, reuse variables) # Count finished runs grouped by status finished_sql = f""" SELECT status, COUNT(DISTINCT id) as count FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY status @@ -485,9 +523,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND status='running' {time_filter} ) t @@ -546,7 +584,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): logger.debug( 
"get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from ) - # Build time range filter + + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + + # Build time range filter (datetime.isoformat() is safe) time_filter = "" if start_date: time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" @@ -557,9 +601,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): sql = f""" SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY date @@ -601,7 +645,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): app_id, triggered_from, ) - # Build time range filter + + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + + # Build time range filter (datetime.isoformat() is safe) time_filter = "" if start_date: time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" @@ -611,9 +661,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): sql = f""" SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND 
triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY date @@ -655,7 +705,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): app_id, triggered_from, ) - # Build time range filter + + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + + # Build time range filter (datetime.isoformat() is safe) time_filter = "" if start_date: time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" @@ -665,9 +721,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): sql = f""" SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY date @@ -709,7 +765,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): app_id, triggered_from, ) - # Build time range filter + + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + + # Build time range filter (datetime.isoformat() is safe) time_filter = "" if start_date: time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" @@ -726,9 +788,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): created_by, COUNT(DISTINCT id) AS interactions FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND 
app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY date, created_by diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index 6e6631cfef..9928879a7b 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.workflow.entities import WorkflowExecution from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore from libs.helper import extract_tenant_id from models import ( @@ -67,7 +68,12 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): # Control flag for dual-write (write to both LogStore and SQL database) # Set to True to enable dual-write for safe migration, False to use LogStore only - self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true" + self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true" + + # Control flag for whether to write the `graph` field to LogStore. + # If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; + # otherwise write an empty {} instead. Defaults to writing the `graph` field. 
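Both repositories parse these feature flags (`LOGSTORE_DUAL_WRITE_ENABLED`, `LOGSTORE_ENABLE_PUT_GRAPH_FIELD`) with the same `os.environ.get(...).lower() == "true"` idiom: only the literal string `"true"`, case-insensitive, enables a flag. A small helper capturing that idiom (`env_bool` is an illustrative name, not in the codebase):

```python
import os


def env_bool(name: str, default: bool = False) -> bool:
    """Parse a boolean feature flag the way the patch does.

    Only the literal string "true" (case-insensitive) enables the flag;
    any other value -- "yes", "1", "on" -- is treated as disabled.
    """
    fallback = "true" if default else "false"
    return os.environ.get(name, fallback).lower() == "true"
```

Note the strictness cuts both ways: `LOGSTORE_DUAL_WRITE_ENABLED=1` silently leaves dual-write off, which is worth calling out in deployment docs.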
+ self._enable_put_graph_field = os.environ.get("LOGSTORE_ENABLE_PUT_GRAPH_FIELD", "true").lower() == "true" def _to_logstore_model(self, domain_model: WorkflowExecution) -> list[tuple[str, str]]: """ @@ -96,6 +102,9 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): # Generate log_version as nanosecond timestamp for record versioning log_version = str(time.time_ns()) + # Use WorkflowRuntimeTypeConverter to handle complex types (Segment, File, etc.) + json_converter = WorkflowRuntimeTypeConverter() + logstore_model = [ ("id", domain_model.id_), ("log_version", log_version), # Add log_version field for append-only writes @@ -108,9 +117,24 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): ), ("type", domain_model.workflow_type.value), ("version", domain_model.workflow_version), - ("graph", json.dumps(domain_model.graph, ensure_ascii=False) if domain_model.graph else "{}"), - ("inputs", json.dumps(domain_model.inputs, ensure_ascii=False) if domain_model.inputs else "{}"), - ("outputs", json.dumps(domain_model.outputs, ensure_ascii=False) if domain_model.outputs else "{}"), + ( + "graph", + json.dumps(json_converter.to_json_encodable(domain_model.graph), ensure_ascii=False) + if domain_model.graph and self._enable_put_graph_field + else "{}", + ), + ( + "inputs", + json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False) + if domain_model.inputs + else "{}", + ), + ( + "outputs", + json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False) + if domain_model.outputs + else "{}", + ), ("status", domain_model.status.value), ("error_message", domain_model.error_message or ""), ("total_tokens", str(domain_model.total_tokens)), diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index 400a089516..4897171b12 100644 --- 
a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -24,6 +24,8 @@ from core.workflow.enums import NodeType from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore +from extensions.logstore.repositories import safe_float, safe_int +from extensions.logstore.sql_escape import escape_identifier from libs.helper import extract_tenant_id from models import ( Account, @@ -73,7 +75,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut node_execution_id=data.get("node_execution_id"), workflow_id=data.get("workflow_id", ""), workflow_execution_id=data.get("workflow_run_id"), - index=int(data.get("index", 0)), + index=safe_int(data.get("index", 0)), predecessor_node_id=data.get("predecessor_node_id"), node_id=data.get("node_id", ""), node_type=NodeType(data.get("node_type", "start")), @@ -83,7 +85,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut outputs=outputs, status=status, error=data.get("error"), - elapsed_time=float(data.get("elapsed_time", 0.0)), + elapsed_time=safe_float(data.get("elapsed_time", 0.0)), metadata=domain_metadata, created_at=created_at, finished_at=finished_at, @@ -147,7 +149,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): # Control flag for dual-write (write to both LogStore and SQL database) # Set to True to enable dual-write for safe migration, False to use LogStore only - self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true" + self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true" def _to_logstore_model(self, domain_model: 
WorkflowNodeExecution) -> Sequence[tuple[str, str]]: logger.debug( @@ -274,16 +276,34 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): Save or update the inputs, process_data, or outputs associated with a specific node_execution record. - For LogStore implementation, this is similar to save() since we always write - complete records. We append a new record with updated data fields. + For LogStore implementation, this is a no-op for the LogStore write because save() + already writes all fields including inputs, process_data, and outputs. The caller + typically calls save() first to persist status/metadata, then calls save_execution_data() + to persist data fields. Since LogStore writes complete records atomically, we don't + need a separate write here to avoid duplicate records. + + However, if dual-write is enabled, we still need to call the SQL repository's + save_execution_data() method to properly update the SQL database. Args: execution: The NodeExecution instance with data to save """ - logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id) - # In LogStore, we simply write a new complete record with the data - # The log_version timestamp will ensure this is treated as the latest version - self.save(execution) + logger.debug( + "save_execution_data: no-op for LogStore (data already saved by save()): id=%s, node_execution_id=%s", + execution.id, + execution.node_execution_id, + ) + # No-op for LogStore: save() already writes all fields including inputs, process_data, and outputs + # Calling save() again would create a duplicate record in the append-only LogStore + + # Dual-write to SQL database if enabled (for safe migration) + if self._enable_dual_write: + try: + self.sql_repository.save_execution_data(execution) + logger.debug("Dual-write: saved node execution data to SQL database: id=%s", execution.id) + except Exception: + logger.exception("Failed to dual-write node 
execution data to SQL database: id=%s", execution.id) + # Don't raise - LogStore write succeeded, SQL is just a backup def get_by_workflow_run( self, @@ -292,8 +312,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): ) -> Sequence[WorkflowNodeExecution]: """ Retrieve all NodeExecution instances for a specific workflow run. - Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication. - This ensures we only get the final version of each node execution. + Uses LogStore SQL query with window function to get the latest version of each node execution. + This ensures we only get the most recent version of each node execution record. Args: workflow_run_id: The workflow run ID order_config: Optional configuration for ordering results @@ -304,16 +324,19 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): A list of NodeExecution instances Note: - This method filters by finished_at IS NOT NULL to avoid duplicates from - version updates. For complete history including intermediate states, - a different query strategy would be needed. + This method uses ROW_NUMBER() window function partitioned by node_execution_id + to get the latest version (highest log_version) of each node execution. 
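The deduplication strategy described above (`ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC)` keeping only `rn = 1`) can be expressed in plain Python, which may help clarify what the window function selects from the append-only log:

```python
from itertools import groupby
from operator import itemgetter

# Python analogue of:
#   ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC)
#   ... WHERE rn = 1
records = [
    {"node_execution_id": "n1", "log_version": "100", "status": "running"},
    {"node_execution_id": "n1", "log_version": "200", "status": "succeeded"},
    {"node_execution_id": "n2", "log_version": "150", "status": "succeeded"},
]

def latest_per_node(rows):
    """Keep only the record with the highest log_version per node_execution_id."""
    ordered = sorted(rows, key=lambda r: (r["node_execution_id"], -int(r["log_version"])))
    return [next(group) for _, group in groupby(ordered, key=itemgetter("node_execution_id"))]

latest = latest_per_node(records)
```

Because each `save()` appends a complete record with a fresh nanosecond `log_version`, picking the maximum per `node_execution_id` yields the final state of every node execution.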
""" logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config) - # Build SQL query with deduplication using finished_at IS NOT NULL - # This optimization avoids window functions for common case where we only - # want the final state of each node execution + # Build SQL query with deduplication using window function + # ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) + # ensures we get the latest version of each node execution - # Build ORDER BY clause + # Escape parameters to prevent SQL injection + escaped_workflow_run_id = escape_identifier(workflow_run_id) + escaped_tenant_id = escape_identifier(self._tenant_id) + + # Build ORDER BY clause for outer query order_clause = "" if order_config and order_config.order_by: order_fields = [] @@ -327,16 +350,23 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): if order_fields: order_clause = "ORDER BY " + ", ".join(order_fields) - sql = f""" - SELECT * - FROM {AliyunLogStore.workflow_node_execution_logstore} - WHERE workflow_run_id='{workflow_run_id}' - AND tenant_id='{self._tenant_id}' - AND finished_at IS NOT NULL - """ - + # Build app_id filter for subquery + app_id_filter = "" if self._app_id: - sql += f" AND app_id='{self._app_id}'" + escaped_app_id = escape_identifier(self._app_id) + app_id_filter = f" AND app_id='{escaped_app_id}'" + + # Use window function to get latest version of each node execution + sql = f""" + SELECT * FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn + FROM {AliyunLogStore.workflow_node_execution_logstore} + WHERE workflow_run_id='{escaped_workflow_run_id}' + AND tenant_id='{escaped_tenant_id}' + {app_id_filter} + ) t + WHERE rn = 1 + """ if order_clause: sql += f" {order_clause}" diff --git a/api/extensions/logstore/sql_escape.py b/api/extensions/logstore/sql_escape.py new file mode 100644 index 0000000000..d88d6bd959 --- /dev/null 
+++ b/api/extensions/logstore/sql_escape.py @@ -0,0 +1,134 @@ +""" +SQL Escape Utility for LogStore Queries + +This module provides escaping utilities to prevent injection attacks in LogStore queries. + +LogStore supports two query modes: +1. PG Protocol Mode: Uses SQL syntax with single quotes for strings +2. SDK Mode: Uses LogStore query syntax (key: value) with double quotes + +Key Security Concerns: +- Prevent tenant A from accessing tenant B's data via injection +- SLS queries are read-only, so we focus on data access control +- Different escaping strategies for SQL vs LogStore query syntax +""" + + +def escape_sql_string(value: str) -> str: + """ + Escape a string value for safe use in SQL queries. + + This function escapes single quotes by doubling them, which is the standard + SQL escaping method. This prevents SQL injection by ensuring that user input + cannot break out of string literals. + + Args: + value: The string value to escape + + Returns: + Escaped string safe for use in SQL queries + + Examples: + >>> escape_sql_string("normal_value") + "normal_value" + >>> escape_sql_string("value' OR '1'='1") + "value'' OR ''1''=''1" + >>> escape_sql_string("tenant's_id") + "tenant''s_id" + + Security: + - Prevents breaking out of string literals + - Stops injection attacks like: ' OR '1'='1 + - Protects against cross-tenant data access + """ + if not value: + return value + + # Escape single quotes by doubling them (standard SQL escaping) + # This prevents breaking out of string literals in SQL queries + return value.replace("'", "''") + + +def escape_identifier(value: str) -> str: + """ + Escape an identifier (tenant_id, app_id, run_id, etc.) for safe SQL use. + + This function is for PG protocol mode (SQL syntax). + For SDK mode, use escape_logstore_query_value() instead. 
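The quote-doubling performed by `escape_sql_string` above can be shown end-to-end against an injection attempt; after escaping, the attacker's quotes are literal characters inside one string value rather than delimiters that terminate it:

```python
def escape_sql_string(value: str) -> str:
    """Double single quotes so input cannot terminate a SQL string literal."""
    if not value:
        return value
    return value.replace("'", "''")

malicious = "x' OR '1'='1"
clause = f"WHERE tenant_id='{escape_sql_string(malicious)}'"
# The whole payload is now a single (harmless) string literal in the query
```
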
+ + Args: + value: The identifier value to escape + + Returns: + Escaped identifier safe for use in SQL queries + + Examples: + >>> escape_identifier("550e8400-e29b-41d4-a716-446655440000") + "550e8400-e29b-41d4-a716-446655440000" + >>> escape_identifier("tenant_id' OR '1'='1") + "tenant_id'' OR ''1''=''1" + + Security: + - Prevents SQL injection via identifiers + - Stops cross-tenant access attempts + - Works for UUIDs, alphanumeric IDs, and similar identifiers + """ + # For identifiers, use the same escaping as strings + # This is simple and effective for preventing injection + return escape_sql_string(value) + + +def escape_logstore_query_value(value: str) -> str: + """ + Escape value for LogStore query syntax (SDK mode). + + LogStore query syntax rules: + 1. Keywords (and/or/not) are case-insensitive + 2. Single quotes are ordinary characters (no special meaning) + 3. Double quotes wrap values: key:"value" + 4. Backslash is the escape character: + - \" for double quote inside value + - \\ for backslash itself + 5. 
Parentheses can change query structure + + To prevent injection: + - Wrap value in double quotes to treat special chars as literals + - Escape backslashes and double quotes using backslash + + Args: + value: The value to escape for LogStore query syntax + + Returns: + Quoted and escaped value safe for LogStore query syntax (includes the quotes) + + Examples: + >>> escape_logstore_query_value("normal_value") + '"normal_value"' + >>> escape_logstore_query_value("value or field:evil") + '"value or field:evil"' # 'or' and ':' are now literals + >>> escape_logstore_query_value('value"test') + '"value\\"test"' # Internal double quote escaped + >>> escape_logstore_query_value('value\\test') + '"value\\\\test"' # Backslash escaped + + Security: + - Prevents injection via and/or/not keywords + - Prevents injection via colons (:) + - Prevents injection via parentheses + - Protects against cross-tenant data access + + Note: + Escape order is critical: backslash first, then double quotes. + Otherwise, we'd double-escape the escape character itself. + """ + if not value: + return '""' + + # IMPORTANT: Escape backslashes FIRST, then double quotes + # This prevents double-escaping (e.g., " -> \" -> \\" incorrectly) + escaped = value.replace("\\", "\\\\") # \ -> \\ + escaped = escaped.replace('"', '\\"') # " -> \" + + # Wrap in double quotes to treat as literal string + # This prevents and/or/not/:/() from being interpreted as operators + return f'"{escaped}"' diff --git a/api/extensions/otel/instrumentation.py b/api/extensions/otel/instrumentation.py index 3597110cba..6617f69513 100644 --- a/api/extensions/otel/instrumentation.py +++ b/api/extensions/otel/instrumentation.py @@ -19,26 +19,43 @@ logger = logging.getLogger(__name__) class ExceptionLoggingHandler(logging.Handler): + """ + Handler that records exceptions to the current OpenTelemetry span. 
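The escape-order rule documented in `escape_logstore_query_value` (backslashes first, then double quotes) can be demonstrated concretely; reversing the order would turn `"` into `\"` and then re-escape the backslash into `\\"`, corrupting the value:

```python
def escape_logstore_query_value(value: str) -> str:
    """Quote a value for LogStore query syntax; backslashes first, then quotes."""
    if not value:
        return '""'
    escaped = value.replace("\\", "\\\\")  # \ -> \\ (must come first)
    escaped = escaped.replace('"', '\\"')  # " -> \"
    return f'"{escaped}"'

# Operators (or, :) and quotes all become literal characters inside the value
quoted = escape_logstore_query_value('a"b\\c or tenant_id:other')
```
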
+ + Unlike creating a new span, this records exceptions on the existing span + to maintain trace context consistency throughout the request lifecycle. + """ + def emit(self, record: logging.LogRecord): with contextlib.suppress(Exception): - if record.exc_info: - tracer = get_tracer_provider().get_tracer("dify.exception.logging") - with tracer.start_as_current_span( - "log.exception", - attributes={ - "log.level": record.levelname, - "log.message": record.getMessage(), - "log.logger": record.name, - "log.file.path": record.pathname, - "log.file.line": record.lineno, - }, - ) as span: - span.set_status(StatusCode.ERROR) - if record.exc_info[1]: - span.record_exception(record.exc_info[1]) - span.set_attribute("exception.message", str(record.exc_info[1])) - if record.exc_info[0]: - span.set_attribute("exception.type", record.exc_info[0].__name__) + if not record.exc_info: + return + + from opentelemetry.trace import get_current_span + + span = get_current_span() + if not span or not span.is_recording(): + return + + # Record exception on the current span instead of creating a new one + span.set_status(StatusCode.ERROR, record.getMessage()) + + # Add log context as span events/attributes + span.add_event( + "log.exception", + attributes={ + "log.level": record.levelname, + "log.message": record.getMessage(), + "log.logger": record.name, + "log.file.path": record.pathname, + "log.file.line": record.lineno, + }, + ) + + if record.exc_info[1]: + span.record_exception(record.exc_info[1]) + if record.exc_info[0]: + span.set_attribute("exception.type", record.exc_info[0].__name__) def instrument_exception_logging() -> None: diff --git a/api/extensions/otel/parser/__init__.py b/api/extensions/otel/parser/__init__.py new file mode 100644 index 0000000000..164db7c275 --- /dev/null +++ b/api/extensions/otel/parser/__init__.py @@ -0,0 +1,20 @@ +""" +OpenTelemetry node parsers for workflow nodes. 
+ +This module provides parsers that extract node-specific metadata and set +OpenTelemetry span attributes according to semantic conventions. +""" + +from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps +from extensions.otel.parser.llm import LLMNodeOTelParser +from extensions.otel.parser.retrieval import RetrievalNodeOTelParser +from extensions.otel.parser.tool import ToolNodeOTelParser + +__all__ = [ + "DefaultNodeOTelParser", + "LLMNodeOTelParser", + "NodeOTelParser", + "RetrievalNodeOTelParser", + "ToolNodeOTelParser", + "safe_json_dumps", +] diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py new file mode 100644 index 0000000000..f4db26e840 --- /dev/null +++ b/api/extensions/otel/parser/base.py @@ -0,0 +1,117 @@ +""" +Base parser interface and utilities for OpenTelemetry node parsers. +""" + +import json +from typing import Any, Protocol + +from opentelemetry.trace import Span +from opentelemetry.trace.status import Status, StatusCode +from pydantic import BaseModel + +from core.file.models import File +from core.variables import Segment +from core.workflow.enums import NodeType +from core.workflow.graph_events import GraphNodeEventBase +from core.workflow.nodes.base.node import Node +from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes + + +def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str: + """ + Safely serialize objects to JSON, handling non-serializable types. + + Handles: + - Segment types (ArrayFileSegment, FileSegment, etc.) 
- converts to their value + - File objects - converts to dict using to_dict() + - BaseModel objects - converts using model_dump() + - Other types - falls back to str() representation + + Args: + obj: Object to serialize + ensure_ascii: Whether to ensure ASCII encoding + + Returns: + JSON string representation of the object + """ + + def _convert_value(value: Any) -> Any: + """Recursively convert non-serializable values.""" + if value is None: + return None + if isinstance(value, (bool, int, float, str)): + return value + if isinstance(value, Segment): + # Convert Segment to its underlying value + return _convert_value(value.value) + if isinstance(value, File): + # Convert File to dict + return value.to_dict() + if isinstance(value, BaseModel): + # Convert Pydantic model to dict + return _convert_value(value.model_dump(mode="json")) + if isinstance(value, dict): + return {k: _convert_value(v) for k, v in value.items()} + if isinstance(value, (list, tuple)): + return [_convert_value(item) for item in value] + # Fallback to string representation for unknown types + return str(value) + + try: + converted = _convert_value(obj) + return json.dumps(converted, ensure_ascii=ensure_ascii) + except (TypeError, ValueError) as e: + # If conversion still fails, return error message as string + return json.dumps( + {"error": f"Failed to serialize: {type(obj).__name__}", "message": str(e)}, ensure_ascii=ensure_ascii + ) + + +class NodeOTelParser(Protocol): + """Parser interface for node-specific OpenTelemetry enrichment.""" + + def parse( + self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None + ) -> None: ... 
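`NodeOTelParser` above is a `typing.Protocol`, so parser classes satisfy the interface structurally without inheriting from it. A simplified sketch of that pattern, using a hypothetical `Parser` protocol rather than the real one:

```python
from typing import Protocol, runtime_checkable

# Simplified version of the NodeOTelParser idea: any class with a matching
# parse() signature satisfies the protocol, no base class required.
@runtime_checkable
class Parser(Protocol):
    def parse(self, *, name: str) -> str: ...

class GreetingParser:  # note: no inheritance from Parser
    def parse(self, *, name: str) -> str:
        return f"hello {name}"

parser: Parser = GreetingParser()
result = parser.parse(name="world")
```

This keeps `DefaultNodeOTelParser`, `LLMNodeOTelParser`, and the rest decoupled: each only needs the right `parse` signature to be registered interchangeably.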
+ + +class DefaultNodeOTelParser: + """Fallback parser used when no node-specific parser is registered.""" + + def parse( + self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None + ) -> None: + span.set_attribute("node.id", node.id) + if node.execution_id: + span.set_attribute("node.execution_id", node.execution_id) + if hasattr(node, "node_type") and node.node_type: + span.set_attribute("node.type", node.node_type.value) + + span.set_attribute(GenAIAttributes.FRAMEWORK, "dify") + + node_type = getattr(node, "node_type", None) + if isinstance(node_type, NodeType): + if node_type == NodeType.LLM: + span.set_attribute(GenAIAttributes.SPAN_KIND, "LLM") + elif node_type == NodeType.KNOWLEDGE_RETRIEVAL: + span.set_attribute(GenAIAttributes.SPAN_KIND, "RETRIEVER") + elif node_type == NodeType.TOOL: + span.set_attribute(GenAIAttributes.SPAN_KIND, "TOOL") + else: + span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK") + else: + span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK") + + # Extract inputs and outputs from result_event + if result_event and result_event.node_run_result: + node_run_result = result_event.node_run_result + if node_run_result.inputs: + span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs)) + if node_run_result.outputs: + span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs)) + + if error: + span.record_exception(error) + span.set_status(Status(StatusCode.ERROR, str(error))) + else: + span.set_status(Status(StatusCode.OK)) diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py new file mode 100644 index 0000000000..8556974080 --- /dev/null +++ b/api/extensions/otel/parser/llm.py @@ -0,0 +1,155 @@ +""" +Parser for LLM nodes that captures LLM-specific metadata. 
+""" + +import logging +from collections.abc import Mapping +from typing import Any + +from opentelemetry.trace import Span + +from core.workflow.graph_events import GraphNodeEventBase +from core.workflow.nodes.base.node import Node +from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps +from extensions.otel.semconv.gen_ai import LLMAttributes + +logger = logging.getLogger(__name__) + + +def _format_input_messages(process_data: Mapping[str, Any]) -> str: + """ + Format input messages from process_data for LLM spans. + + Args: + process_data: Process data containing prompts + + Returns: + JSON string of formatted input messages + """ + try: + if not isinstance(process_data, dict): + return safe_json_dumps([]) + + prompts = process_data.get("prompts", []) + if not prompts: + return safe_json_dumps([]) + + valid_roles = {"system", "user", "assistant", "tool"} + input_messages = [] + for prompt in prompts: + if not isinstance(prompt, dict): + continue + + role = prompt.get("role", "") + text = prompt.get("text", "") + + if not role or role not in valid_roles: + continue + + if text: + message = {"role": role, "parts": [{"type": "text", "content": text}]} + input_messages.append(message) + + return safe_json_dumps(input_messages) + except Exception as e: + logger.warning("Failed to format input messages: %s", e, exc_info=True) + return safe_json_dumps([]) + + +def _format_output_messages(outputs: Mapping[str, Any]) -> str: + """ + Format output messages from outputs for LLM spans. 
+ + Args: + outputs: Output data containing text and finish_reason + + Returns: + JSON string of formatted output messages + """ + try: + if not isinstance(outputs, dict): + return safe_json_dumps([]) + + text = outputs.get("text", "") + finish_reason = outputs.get("finish_reason", "") + + if not text: + return safe_json_dumps([]) + + valid_finish_reasons = {"stop", "length", "content_filter", "tool_call", "error"} + if finish_reason not in valid_finish_reasons: + finish_reason = "stop" + + output_message = { + "role": "assistant", + "parts": [{"type": "text", "content": text}], + "finish_reason": finish_reason, + } + + return safe_json_dumps([output_message]) + except Exception as e: + logger.warning("Failed to format output messages: %s", e, exc_info=True) + return safe_json_dumps([]) + + +class LLMNodeOTelParser: + """Parser for LLM nodes that captures LLM-specific metadata.""" + + def __init__(self) -> None: + self._delegate = DefaultNodeOTelParser() + + def parse( + self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None + ) -> None: + self._delegate.parse(node=node, span=span, error=error, result_event=result_event) + + if not result_event or not result_event.node_run_result: + return + + node_run_result = result_event.node_run_result + process_data = node_run_result.process_data or {} + outputs = node_run_result.outputs or {} + + # Extract usage data (from process_data or outputs) + usage_data = process_data.get("usage") or outputs.get("usage") or {} + + # Model and provider information + model_name = process_data.get("model_name") or "" + model_provider = process_data.get("model_provider") or "" + + if model_name: + span.set_attribute(LLMAttributes.REQUEST_MODEL, model_name) + if model_provider: + span.set_attribute(LLMAttributes.PROVIDER_NAME, model_provider) + + # Token usage + if usage_data: + prompt_tokens = usage_data.get("prompt_tokens", 0) + completion_tokens = usage_data.get("completion_tokens", 0) 
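The `_format_output_messages` helper above coerces unrecognized `finish_reason` values to `"stop"` so downstream consumers always see a known value. That normalization, isolated as a sketch:

```python
# Mirrors the normalization in _format_output_messages: consumers of the
# structured messages expect a known finish_reason, so anything else
# falls back to "stop".
VALID_FINISH_REASONS = {"stop", "length", "content_filter", "tool_call", "error"}

def normalize_finish_reason(reason: str) -> str:
    return reason if reason in VALID_FINISH_REASONS else "stop"

message = {
    "role": "assistant",
    "parts": [{"type": "text", "content": "42"}],
    "finish_reason": normalize_finish_reason("weird_value"),
}
```
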
+ total_tokens = usage_data.get("total_tokens", 0) + + span.set_attribute(LLMAttributes.USAGE_INPUT_TOKENS, prompt_tokens) + span.set_attribute(LLMAttributes.USAGE_OUTPUT_TOKENS, completion_tokens) + span.set_attribute(LLMAttributes.USAGE_TOTAL_TOKENS, total_tokens) + + # Prompts and completion + prompts = process_data.get("prompts", []) + if prompts: + prompts_json = safe_json_dumps(prompts) + span.set_attribute(LLMAttributes.PROMPT, prompts_json) + + text_output = str(outputs.get("text", "")) + if text_output: + span.set_attribute(LLMAttributes.COMPLETION, text_output) + + # Finish reason + finish_reason = outputs.get("finish_reason") or "" + if finish_reason: + span.set_attribute(LLMAttributes.RESPONSE_FINISH_REASON, finish_reason) + + # Structured input/output messages + gen_ai_input_message = _format_input_messages(process_data) + gen_ai_output_message = _format_output_messages(outputs) + + span.set_attribute(LLMAttributes.INPUT_MESSAGE, gen_ai_input_message) + span.set_attribute(LLMAttributes.OUTPUT_MESSAGE, gen_ai_output_message) diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py new file mode 100644 index 0000000000..fc151af691 --- /dev/null +++ b/api/extensions/otel/parser/retrieval.py @@ -0,0 +1,105 @@ +""" +Parser for knowledge retrieval nodes that captures retrieval-specific metadata. +""" + +import logging +from collections.abc import Sequence +from typing import Any + +from opentelemetry.trace import Span + +from core.variables import Segment +from core.workflow.graph_events import GraphNodeEventBase +from core.workflow.nodes.base.node import Node +from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps +from extensions.otel.semconv.gen_ai import RetrieverAttributes + +logger = logging.getLogger(__name__) + + +def _format_retrieval_documents(retrieval_documents: list[Any]) -> list: + """ + Format retrieval documents for semantic conventions. 
+ + Args: + retrieval_documents: List of retrieval document dictionaries + + Returns: + List of formatted semantic documents + """ + try: + if not isinstance(retrieval_documents, list): + return [] + + semantic_documents = [] + for doc in retrieval_documents: + if not isinstance(doc, dict): + continue + + metadata = doc.get("metadata", {}) + content = doc.get("content", "") + title = doc.get("title", "") + score = metadata.get("score", 0.0) + document_id = metadata.get("document_id", "") + + semantic_metadata = {} + if title: + semantic_metadata["title"] = title + if metadata.get("source"): + semantic_metadata["source"] = metadata["source"] + elif metadata.get("_source"): + semantic_metadata["source"] = metadata["_source"] + if metadata.get("doc_metadata"): + doc_metadata = metadata["doc_metadata"] + if isinstance(doc_metadata, dict): + semantic_metadata.update(doc_metadata) + + semantic_doc = { + "document": {"content": content, "metadata": semantic_metadata, "score": score, "id": document_id} + } + semantic_documents.append(semantic_doc) + + return semantic_documents + except Exception as e: + logger.warning("Failed to format retrieval documents: %s", e, exc_info=True) + return [] + + +class RetrievalNodeOTelParser: + """Parser for knowledge retrieval nodes that captures retrieval-specific metadata.""" + + def __init__(self) -> None: + self._delegate = DefaultNodeOTelParser() + + def parse( + self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None + ) -> None: + self._delegate.parse(node=node, span=span, error=error, result_event=result_event) + + if not result_event or not result_event.node_run_result: + return + + node_run_result = result_event.node_run_result + inputs = node_run_result.inputs or {} + outputs = node_run_result.outputs or {} + + # Extract query from inputs + query = str(inputs.get("query", "")) if inputs else "" + if query: + span.set_attribute(RetrieverAttributes.QUERY, query) + + # Extract and 
format retrieval documents from outputs + result_value = outputs.get("result") if outputs else None + retrieval_documents: list[Any] = [] + if result_value: + value_to_check = result_value + if isinstance(result_value, Segment): + value_to_check = result_value.value + + if isinstance(value_to_check, (list, Sequence)): + retrieval_documents = list(value_to_check) + + if retrieval_documents: + semantic_retrieval_documents = _format_retrieval_documents(retrieval_documents) + semantic_retrieval_documents_json = safe_json_dumps(semantic_retrieval_documents) + span.set_attribute(RetrieverAttributes.DOCUMENT, semantic_retrieval_documents_json) diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py new file mode 100644 index 0000000000..b99180722b --- /dev/null +++ b/api/extensions/otel/parser/tool.py @@ -0,0 +1,47 @@ +""" +Parser for tool nodes that captures tool-specific metadata. +""" + +from opentelemetry.trace import Span + +from core.workflow.enums import WorkflowNodeExecutionMetadataKey +from core.workflow.graph_events import GraphNodeEventBase +from core.workflow.nodes.base.node import Node +from core.workflow.nodes.tool.entities import ToolNodeData +from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps +from extensions.otel.semconv.gen_ai import ToolAttributes + + +class ToolNodeOTelParser: + """Parser for tool nodes that captures tool-specific metadata.""" + + def __init__(self) -> None: + self._delegate = DefaultNodeOTelParser() + + def parse( + self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None + ) -> None: + self._delegate.parse(node=node, span=span, error=error, result_event=result_event) + + tool_data = getattr(node, "_node_data", None) + if not isinstance(tool_data, ToolNodeData): + return + + span.set_attribute(ToolAttributes.TOOL_NAME, node.title) + span.set_attribute(ToolAttributes.TOOL_TYPE, tool_data.provider_type.value) + + # Extract tool 
info from metadata (consistent with aliyun_trace) + tool_info = {} + if result_event and result_event.node_run_result: + node_run_result = result_event.node_run_result + if node_run_result.metadata: + tool_info = node_run_result.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {}) + + if tool_info: + span.set_attribute(ToolAttributes.TOOL_DESCRIPTION, safe_json_dumps(tool_info)) + + if result_event and result_event.node_run_result and result_event.node_run_result.inputs: + span.set_attribute(ToolAttributes.TOOL_CALL_ARGUMENTS, safe_json_dumps(result_event.node_run_result.inputs)) + + if result_event and result_event.node_run_result and result_event.node_run_result.outputs: + span.set_attribute(ToolAttributes.TOOL_CALL_RESULT, safe_json_dumps(result_event.node_run_result.outputs)) diff --git a/api/extensions/otel/semconv/__init__.py b/api/extensions/otel/semconv/__init__.py index dc79dee222..0db3075815 100644 --- a/api/extensions/otel/semconv/__init__.py +++ b/api/extensions/otel/semconv/__init__.py @@ -1,6 +1,13 @@ """Semantic convention shortcuts for Dify-specific spans.""" from .dify import DifySpanAttributes -from .gen_ai import GenAIAttributes +from .gen_ai import ChainAttributes, GenAIAttributes, LLMAttributes, RetrieverAttributes, ToolAttributes -__all__ = ["DifySpanAttributes", "GenAIAttributes"] +__all__ = [ + "ChainAttributes", + "DifySpanAttributes", + "GenAIAttributes", + "LLMAttributes", + "RetrieverAttributes", + "ToolAttributes", +] diff --git a/api/extensions/otel/semconv/gen_ai.py b/api/extensions/otel/semconv/gen_ai.py index 83c52ed34f..88c2058c06 100644 --- a/api/extensions/otel/semconv/gen_ai.py +++ b/api/extensions/otel/semconv/gen_ai.py @@ -62,3 +62,37 @@ class ToolAttributes: TOOL_CALL_RESULT = "gen_ai.tool.call.result" """Tool invocation result.""" + + +class LLMAttributes: + """LLM operation attribute keys.""" + + REQUEST_MODEL = "gen_ai.request.model" + """Model identifier.""" + + PROVIDER_NAME = "gen_ai.provider.name" + """Provider 
name.""" + + USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens" + """Number of input tokens.""" + + USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens" + """Number of output tokens.""" + + USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens" + """Total number of tokens.""" + + PROMPT = "gen_ai.prompt" + """Prompt text.""" + + COMPLETION = "gen_ai.completion" + """Completion text.""" + + RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason" + """Finish reason for the response.""" + + INPUT_MESSAGE = "gen_ai.input.messages" + """Input messages in structured format.""" + + OUTPUT_MESSAGE = "gen_ai.output.messages" + """Output messages in structured format.""" diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index 51a97b20f8..1d9911465b 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -5,6 +5,8 @@ automatic cleanup, backup and restore. Supports complete lifecycle management for knowledge base files. 
""" +from __future__ import annotations + import json import logging import operator @@ -48,7 +50,7 @@ class FileMetadata: return data @classmethod - def from_dict(cls, data: dict) -> "FileMetadata": + def from_dict(cls, data: dict) -> FileMetadata: """Create instance from dictionary""" data = data.copy() data["created_at"] = datetime.fromisoformat(data["created_at"]) diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py index ea5d982efc..cf092c6973 100644 --- a/api/extensions/storage/tencent_cos_storage.py +++ b/api/extensions/storage/tencent_cos_storage.py @@ -13,12 +13,20 @@ class TencentCosStorage(BaseStorage): super().__init__() self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME - config = CosConfig( - Region=dify_config.TENCENT_COS_REGION, - SecretId=dify_config.TENCENT_COS_SECRET_ID, - SecretKey=dify_config.TENCENT_COS_SECRET_KEY, - Scheme=dify_config.TENCENT_COS_SCHEME, - ) + if dify_config.TENCENT_COS_CUSTOM_DOMAIN: + config = CosConfig( + Domain=dify_config.TENCENT_COS_CUSTOM_DOMAIN, + SecretId=dify_config.TENCENT_COS_SECRET_ID, + SecretKey=dify_config.TENCENT_COS_SECRET_KEY, + Scheme=dify_config.TENCENT_COS_SCHEME, + ) + else: + config = CosConfig( + Region=dify_config.TENCENT_COS_REGION, + SecretId=dify_config.TENCENT_COS_SECRET_ID, + SecretKey=dify_config.TENCENT_COS_SECRET_KEY, + Scheme=dify_config.TENCENT_COS_SCHEME, + ) self.client = CosS3Client(config) def save(self, filename, data): diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index bd71f18af2..0be836c8f1 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -115,7 +115,18 @@ def build_from_mappings( # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query. # Implement batch processing to reduce database load when handling multiple files. 
# Filter out None/empty mappings to avoid errors - valid_mappings = [m for m in mappings if m and m.get("transfer_method")] + def is_valid_mapping(m: Mapping[str, Any]) -> bool: + if not m or not m.get("transfer_method"): + return False + # For REMOTE_URL transfer method, ensure url or remote_url is provided and not None + transfer_method = m.get("transfer_method") + if transfer_method == FileTransferMethod.REMOTE_URL: + url = m.get("url") or m.get("remote_url") + if not url: + return False + return True + + valid_mappings = [m for m in mappings if is_valid_mapping(m)] files = [ build_from_mapping( mapping=mapping, diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 494194369a..3f030ae127 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -38,7 +38,7 @@ from core.variables.variables import ( ObjectVariable, SecretVariable, StringVariable, - Variable, + VariableBase, ) from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, @@ -72,25 +72,25 @@ SEGMENT_TO_VARIABLE_MAP = { } -def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: +def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase: if not mapping.get("name"): raise VariableError("missing name") return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]]) -def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: +def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase: if not mapping.get("name"): raise VariableError("missing name") return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]]) -def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: +def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase: if not 
mapping.get("variable"): raise VariableError("missing variable") return mapping["variable"] -def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable: +def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> VariableBase: """ This factory function is used to create the environment variable or the conversation variable, not support the File type. @@ -100,7 +100,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen if (value := mapping.get("value")) is None: raise VariableError("missing value") - result: Variable + result: VariableBase match value_type: case SegmentType.STRING: result = StringVariable.model_validate(mapping) @@ -134,7 +134,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") if not result.selector: result = result.model_copy(update={"selector": selector}) - return cast(Variable, result) + return cast(VariableBase, result) def build_segment(value: Any, /) -> Segment: @@ -285,8 +285,8 @@ def segment_to_variable( id: str | None = None, name: str | None = None, description: str = "", -) -> Variable: - if isinstance(segment, Variable): +) -> VariableBase: + if isinstance(segment, VariableBase): return segment name = name or selector[-1] id = id or str(uuid4()) @@ -297,7 +297,7 @@ def segment_to_variable( variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] return cast( - Variable, + VariableBase, variable_class( id=id, name=name, diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index 38835d5ac7..e69306dcb2 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields from libs.helper import TimestampField @@ -12,7 +12,7 @@ annotation_fields = { } -def 
build_annotation_model(api_or_ns: Api | Namespace): +def build_annotation_model(api_or_ns: Namespace): """Build the annotation model for the API or Namespace.""" return api_or_ns.model("Annotation", annotation_fields) diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 0bcc797ec7..cda46f2339 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -1,241 +1,339 @@ -from flask_restx import Api, Namespace, fields +from __future__ import annotations -from fields.member_fields import simple_account_fields -from libs.helper import TimestampField +from datetime import datetime +from typing import Any, TypeAlias -from .raws import FilesContainedField +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator + +from core.file import File + +JSONValue: TypeAlias = Any -class MessageTextField(fields.Raw): - def format(self, value): - return value[0]["text"] if value else "" +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) -feedback_fields = { - "rating": fields.String, - "content": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account": fields.Nested(simple_account_fields, allow_null=True), -} +class MessageFile(ResponseModel): + id: str + filename: str + type: str + url: str | None = None + mime_type: str | None = None + size: int | None = None + transfer_method: str + belongs_to: str | None = None + upload_file_id: str | None = None -annotation_fields = { - "id": fields.String, - "question": fields.String, - "content": fields.String, - "account": fields.Nested(simple_account_fields, allow_null=True), - "created_at": TimestampField, -} - -annotation_hit_history_fields = { - "annotation_id": fields.String(attribute="id"), - "annotation_create_account": fields.Nested(simple_account_fields, 
allow_null=True), - "created_at": TimestampField, -} - -message_file_fields = { - "id": fields.String, - "filename": fields.String, - "type": fields.String, - "url": fields.String, - "mime_type": fields.String, - "size": fields.Integer, - "transfer_method": fields.String, - "belongs_to": fields.String(default="user"), - "upload_file_id": fields.String(default=None), -} + @field_validator("transfer_method", mode="before") + @classmethod + def _normalize_transfer_method(cls, value: object) -> str: + if isinstance(value, str): + return value + return str(value) -def build_message_file_model(api_or_ns: Api | Namespace): - """Build the message file fields for the API or Namespace.""" - return api_or_ns.model("MessageFile", message_file_fields) +class SimpleConversation(ResponseModel): + id: str + name: str + inputs: dict[str, JSONValue] + status: str + introduction: str | None = None + created_at: int | None = None + updated_at: int | None = None + + @field_validator("inputs", mode="before") + @classmethod + def _normalize_inputs(cls, value: JSONValue) -> JSONValue: + return format_files_contained(value) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value -agent_thought_fields = { - "id": fields.String, - "chain_id": fields.String, - "message_id": fields.String, - "position": fields.Integer, - "thought": fields.String, - "tool": fields.String, - "tool_labels": fields.Raw, - "tool_input": fields.String, - "created_at": TimestampField, - "observation": fields.String, - "files": fields.List(fields.String), -} - -message_detail_fields = { - "id": fields.String, - "conversation_id": fields.String, - "inputs": FilesContainedField, - "query": fields.String, - "message": fields.Raw, - "message_tokens": fields.Integer, - "answer": fields.String(attribute="re_sign_file_url_answer"), - "answer_tokens": 
fields.Integer, - "provider_response_latency": fields.Float, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account_id": fields.String, - "feedbacks": fields.List(fields.Nested(feedback_fields)), - "workflow_run_id": fields.String, - "annotation": fields.Nested(annotation_fields, allow_null=True), - "annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True), - "created_at": TimestampField, - "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), - "message_files": fields.List(fields.Nested(message_file_fields)), - "metadata": fields.Raw(attribute="message_metadata_dict"), - "status": fields.String, - "error": fields.String, - "parent_message_id": fields.String, -} - -feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer} -status_count_fields = { - "success": fields.Integer, - "failed": fields.Integer, - "partial_success": fields.Integer, - "paused": fields.Integer, -} -model_config_fields = { - "opening_statement": fields.String, - "suggested_questions": fields.Raw, - "model": fields.Raw, - "user_input_form": fields.Raw, - "pre_prompt": fields.String, - "agent_mode": fields.Raw, -} - -simple_model_config_fields = { - "model": fields.Raw(attribute="model_dict"), - "pre_prompt": fields.String, -} - -simple_message_detail_fields = { - "inputs": FilesContainedField, - "query": fields.String, - "message": MessageTextField, - "answer": fields.String, -} - -conversation_fields = { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_end_user_session_id": fields.String(), - "from_account_id": fields.String, - "from_account_name": fields.String, - "read_at": TimestampField, - "created_at": TimestampField, - "updated_at": TimestampField, - "annotation": fields.Nested(annotation_fields, allow_null=True), - "model_config": fields.Nested(simple_model_config_fields), - "user_feedback_stats": 
fields.Nested(feedback_stat_fields), - "admin_feedback_stats": fields.Nested(feedback_stat_fields), - "message": fields.Nested(simple_message_detail_fields, attribute="first_message"), -} - -conversation_pagination_fields = { - "page": fields.Integer, - "limit": fields.Integer(attribute="per_page"), - "total": fields.Integer, - "has_more": fields.Boolean(attribute="has_next"), - "data": fields.List(fields.Nested(conversation_fields), attribute="items"), -} - -conversation_message_detail_fields = { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account_id": fields.String, - "created_at": TimestampField, - "model_config": fields.Nested(model_config_fields), - "message": fields.Nested(message_detail_fields, attribute="first_message"), -} - -conversation_with_summary_fields = { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_end_user_session_id": fields.String, - "from_account_id": fields.String, - "from_account_name": fields.String, - "name": fields.String, - "summary": fields.String(attribute="summary_or_query"), - "read_at": TimestampField, - "created_at": TimestampField, - "updated_at": TimestampField, - "annotated": fields.Boolean, - "model_config": fields.Nested(simple_model_config_fields), - "message_count": fields.Integer, - "user_feedback_stats": fields.Nested(feedback_stat_fields), - "admin_feedback_stats": fields.Nested(feedback_stat_fields), - "status_count": fields.Nested(status_count_fields), -} - -conversation_with_summary_pagination_fields = { - "page": fields.Integer, - "limit": fields.Integer(attribute="per_page"), - "total": fields.Integer, - "has_more": fields.Boolean(attribute="has_next"), - "data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"), -} - -conversation_detail_fields = { - "id": fields.String, - "status": fields.String, - "from_source": 
fields.String, - "from_end_user_id": fields.String, - "from_account_id": fields.String, - "created_at": TimestampField, - "updated_at": TimestampField, - "annotated": fields.Boolean, - "introduction": fields.String, - "model_config": fields.Nested(model_config_fields), - "message_count": fields.Integer, - "user_feedback_stats": fields.Nested(feedback_stat_fields), - "admin_feedback_stats": fields.Nested(feedback_stat_fields), -} - -simple_conversation_fields = { - "id": fields.String, - "name": fields.String, - "inputs": FilesContainedField, - "status": fields.String, - "introduction": fields.String, - "created_at": TimestampField, - "updated_at": TimestampField, -} - -conversation_delete_fields = { - "result": fields.String, -} - -conversation_infinite_scroll_pagination_fields = { - "limit": fields.Integer, - "has_more": fields.Boolean, - "data": fields.List(fields.Nested(simple_conversation_fields)), -} +class ConversationInfiniteScrollPagination(ResponseModel): + limit: int + has_more: bool + data: list[SimpleConversation] -def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): - """Build the conversation infinite scroll pagination model for the API or Namespace.""" - simple_conversation_model = build_simple_conversation_model(api_or_ns) - - copied_fields = conversation_infinite_scroll_pagination_fields.copy() - copied_fields["data"] = fields.List(fields.Nested(simple_conversation_model)) - return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields) +class ConversationDelete(ResponseModel): + result: str -def build_conversation_delete_model(api_or_ns: Api | Namespace): - """Build the conversation delete model for the API or Namespace.""" - return api_or_ns.model("ConversationDelete", conversation_delete_fields) +class ResultResponse(ResponseModel): + result: str -def build_simple_conversation_model(api_or_ns: Api | Namespace): - """Build the simple conversation model for the API or Namespace.""" - return 
api_or_ns.model("SimpleConversation", simple_conversation_fields) +class SimpleAccount(ResponseModel): + id: str + name: str + email: str + + +class Feedback(ResponseModel): + rating: str + content: str | None = None + from_source: str + from_end_user_id: str | None = None + from_account: SimpleAccount | None = None + + +class Annotation(ResponseModel): + id: str + question: str | None = None + content: str + account: SimpleAccount | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + + +class AnnotationHitHistory(ResponseModel): + annotation_id: str = Field(validation_alias="id") + annotation_create_account: SimpleAccount | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + + +class AgentThought(ResponseModel): + id: str + chain_id: str | None = None + message_chain_id: str | None = Field(default=None, exclude=True, validation_alias="message_chain_id") + message_id: str + position: int + thought: str | None = None + tool: str | None = None + tool_labels: JSONValue + tool_input: str | None = None + created_at: int | None = None + observation: str | None = None + files: list[str] + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + + @model_validator(mode="after") + def _fallback_chain_id(self): + if self.chain_id is None and self.message_chain_id: + self.chain_id = self.message_chain_id + return self + + +class MessageDetail(ResponseModel): + id: str + conversation_id: str + inputs: dict[str, JSONValue] + 
query: str + message: JSONValue + message_tokens: int + answer: str = Field(validation_alias="re_sign_file_url_answer") + answer_tokens: int + provider_response_latency: float + from_source: str + from_end_user_id: str | None = None + from_account_id: str | None = None + feedbacks: list[Feedback] + workflow_run_id: str | None = None + annotation: Annotation | None = None + annotation_hit_history: AnnotationHitHistory | None = None + created_at: int | None = None + agent_thoughts: list[AgentThought] + message_files: list[MessageFile] + metadata: JSONValue = Field(default=None, validation_alias="message_metadata_dict") + status: str + error: str | None = None + parent_message_id: str | None = None + + @field_validator("inputs", mode="before") + @classmethod + def _normalize_inputs(cls, value: JSONValue) -> JSONValue: + return format_files_contained(value) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + + +class FeedbackStat(ResponseModel): + like: int + dislike: int + + +class StatusCount(ResponseModel): + success: int + failed: int + partial_success: int + paused: int + + +class ModelConfig(ResponseModel): + opening_statement: str | None = None + suggested_questions: JSONValue | None = None + model: JSONValue | None = None + user_input_form: JSONValue | None = None + pre_prompt: str | None = None + agent_mode: JSONValue | None = None + + +class SimpleModelConfig(ResponseModel): + model: JSONValue | None = Field(default=None, validation_alias="model_dict") + pre_prompt: str | None = None + + +class SimpleMessageDetail(ResponseModel): + inputs: dict[str, JSONValue] + query: str + message: str + answer: str + + @field_validator("inputs", mode="before") + @classmethod + def _normalize_inputs(cls, value: JSONValue) -> JSONValue: + return format_files_contained(value) + + @field_validator("message", mode="before") + @classmethod + def _normalize_message(cls, value: JSONValue) -> str: + return message_text(value) + + +class Conversation(ResponseModel): + id: str + status: str + from_source: str + from_end_user_id: str | None = None + from_end_user_session_id: str | None = None + from_account_id: str | 
None = None + from_account_name: str | None = None + read_at: int | None = None + created_at: int | None = None + updated_at: int | None = None + annotation: Annotation | None = None + model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config") + user_feedback_stats: FeedbackStat | None = None + admin_feedback_stats: FeedbackStat | None = None + message: SimpleMessageDetail | None = Field(default=None, validation_alias="first_message") + + +class ConversationPagination(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[Conversation] + + +class ConversationMessageDetail(ResponseModel): + id: str + status: str + from_source: str + from_end_user_id: str | None = None + from_account_id: str | None = None + created_at: int | None = None + model_config_: ModelConfig | None = Field(default=None, alias="model_config") + message: MessageDetail | None = Field(default=None, validation_alias="first_message") + + +class ConversationWithSummary(ResponseModel): + id: str + status: str + from_source: str + from_end_user_id: str | None = None + from_end_user_session_id: str | None = None + from_account_id: str | None = None + from_account_name: str | None = None + name: str + summary: str = Field(validation_alias="summary_or_query") + read_at: int | None = None + created_at: int | None = None + updated_at: int | None = None + annotated: bool + model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config") + message_count: int + user_feedback_stats: FeedbackStat | None = None + admin_feedback_stats: FeedbackStat | None = None + status_count: StatusCount | None = None + + +class ConversationWithSummaryPagination(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[ConversationWithSummary] + + +class ConversationDetail(ResponseModel): + id: str + status: str + from_source: str + from_end_user_id: str | None = None + from_account_id: str | None = None + created_at: int | None = None + updated_at: int | None = None + annotated: bool + introduction: str | None = None + model_config_: ModelConfig | None = 
Field(default=None, alias="model_config") + message_count: int + user_feedback_stats: FeedbackStat | None = None + admin_feedback_stats: FeedbackStat | None = None + + +def to_timestamp(value: datetime | None) -> int | None: + if value is None: + return None + return int(value.timestamp()) + + +def format_files_contained(value: JSONValue) -> JSONValue: + if isinstance(value, File): + return value.model_dump() + if isinstance(value, dict): + return {k: format_files_contained(v) for k, v in value.items()} + if isinstance(value, list): + return [format_files_contained(v) for v in value] + return value + + +def message_text(value: JSONValue) -> str: + if isinstance(value, list) and value: + first = value[0] + if isinstance(first, dict): + text = first.get("text") + if isinstance(text, str): + return text + return "" + + +def extract_model_config(value: object | None) -> dict[str, JSONValue]: + if value is None: + return {} + if isinstance(value, dict): + return value + if hasattr(value, "to_dict"): + return value.to_dict() + return {} diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index 7d5e311591..c55014a368 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields from libs.helper import TimestampField @@ -29,12 +29,12 @@ conversation_variable_infinite_scroll_pagination_fields = { } -def build_conversation_variable_model(api_or_ns: Api | Namespace): +def build_conversation_variable_model(api_or_ns: Namespace): """Build the conversation variable model for the API or Namespace.""" return api_or_ns.model("ConversationVariable", conversation_variable_fields) -def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): +def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Namespace): """Build the conversation 
variable infinite scroll pagination model for the API or Namespace.""" # Build the nested variable model first conversation_variable_model = build_conversation_variable_model(api_or_ns) diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index ea43e3b5fd..5389b0213a 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields simple_end_user_fields = { "id": fields.String, @@ -8,5 +8,5 @@ simple_end_user_fields = { } -def build_simple_end_user_model(api_or_ns: Api | Namespace): +def build_simple_end_user_model(api_or_ns: Namespace): return api_or_ns.model("SimpleEndUser", simple_end_user_fields) diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index a707500445..913fb675f9 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -1,93 +1,85 @@ -from flask_restx import Api, Namespace, fields +from __future__ import annotations -from libs.helper import TimestampField +from datetime import datetime -upload_config_fields = { - "file_size_limit": fields.Integer, - "batch_count_limit": fields.Integer, - "image_file_size_limit": fields.Integer, - "video_file_size_limit": fields.Integer, - "audio_file_size_limit": fields.Integer, - "workflow_file_upload_limit": fields.Integer, - "image_file_batch_limit": fields.Integer, - "single_chunk_attachment_limit": fields.Integer, -} +from pydantic import BaseModel, ConfigDict, field_validator -def build_upload_config_model(api_or_ns: Api | Namespace): - """Build the upload config model for the API or Namespace. 
- - Args: - api_or_ns: Flask-RestX Api or Namespace instance - - Returns: - The registered model - """ - return api_or_ns.model("UploadConfig", upload_config_fields) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) -file_fields = { - "id": fields.String, - "name": fields.String, - "size": fields.Integer, - "extension": fields.String, - "mime_type": fields.String, - "created_by": fields.String, - "created_at": TimestampField, - "preview_url": fields.String, - "source_url": fields.String, -} +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -def build_file_model(api_or_ns: Api | Namespace): - """Build the file model for the API or Namespace. - - Args: - api_or_ns: Flask-RestX Api or Namespace instance - - Returns: - The registered model - """ - return api_or_ns.model("File", file_fields) +class UploadConfig(ResponseModel): + file_size_limit: int + batch_count_limit: int + file_upload_limit: int | None = None + image_file_size_limit: int + video_file_size_limit: int + audio_file_size_limit: int + workflow_file_upload_limit: int + image_file_batch_limit: int + single_chunk_attachment_limit: int + attachment_image_file_size_limit: int | None = None -remote_file_info_fields = { - "file_type": fields.String(attribute="file_type"), - "file_length": fields.Integer(attribute="file_length"), -} +class FileResponse(ResponseModel): + id: str + name: str + size: int + extension: str | None = None + mime_type: str | None = None + created_by: str | None = None + created_at: int | None = None + preview_url: str | None = None + source_url: str | None = None + original_url: str | None = None + user_id: str | None = None + tenant_id: str | None = None + conversation_id: str | None = None + file_key: str | None = None + + @field_validator("created_at", 
mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -def build_remote_file_info_model(api_or_ns: Api | Namespace): - """Build the remote file info model for the API or Namespace. - - Args: - api_or_ns: Flask-RestX Api or Namespace instance - - Returns: - The registered model - """ - return api_or_ns.model("RemoteFileInfo", remote_file_info_fields) +class RemoteFileInfo(ResponseModel): + file_type: str + file_length: int -file_fields_with_signed_url = { - "id": fields.String, - "name": fields.String, - "size": fields.Integer, - "extension": fields.String, - "url": fields.String, - "mime_type": fields.String, - "created_by": fields.String, - "created_at": TimestampField, -} +class FileWithSignedUrl(ResponseModel): + id: str + name: str + size: int + extension: str | None = None + url: str | None = None + mime_type: str | None = None + created_by: str | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -def build_file_with_signed_url_model(api_or_ns: Api | Namespace): - """Build the file with signed URL model for the API or Namespace. 
- - Args: - api_or_ns: Flask-RestX Api or Namespace instance - - Returns: - The registered model - """ - return api_or_ns.model("FileWithSignedUrl", file_fields_with_signed_url) +__all__ = [ + "FileResponse", + "FileWithSignedUrl", + "RemoteFileInfo", + "UploadConfig", +] diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 08e38a6931..25160927e6 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields from libs.helper import AvatarUrlField, TimestampField @@ -9,7 +9,7 @@ simple_account_fields = { } -def build_simple_account_model(api_or_ns: Api | Namespace): +def build_simple_account_model(api_or_ns: Namespace): return api_or_ns.model("SimpleAccount", simple_account_fields) diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 552f0b598f..53911a18dd 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -1,78 +1,140 @@ -from flask_restx import Api, Namespace, fields +from __future__ import annotations -from fields.conversation_fields import message_file_fields -from libs.helper import TimestampField +from datetime import datetime +from typing import TypeAlias +from uuid import uuid4 -from .raws import FilesContainedField +from pydantic import BaseModel, ConfigDict, Field, field_validator -feedback_fields = { - "rating": fields.String, -} +from core.entities.execution_extra_content import ExecutionExtraContentDomainModel +from core.file import File +from fields.conversation_fields import AgentThought, JSONValue, MessageFile + +JSONValueType: TypeAlias = JSONValue -def build_feedback_model(api_or_ns: Api | Namespace): - """Build the feedback model for the API or Namespace.""" - return api_or_ns.model("Feedback", feedback_fields) +class ResponseModel(BaseModel): + model_config = ConfigDict(from_attributes=True, extra="ignore") -agent_thought_fields = { - "id": 
fields.String, - "chain_id": fields.String, - "message_id": fields.String, - "position": fields.Integer, - "thought": fields.String, - "tool": fields.String, - "tool_labels": fields.Raw, - "tool_input": fields.String, - "created_at": TimestampField, - "observation": fields.String, - "files": fields.List(fields.String), -} +class SimpleFeedback(ResponseModel): + rating: str | None = None -def build_agent_thought_model(api_or_ns: Api | Namespace): - """Build the agent thought model for the API or Namespace.""" - return api_or_ns.model("AgentThought", agent_thought_fields) +class RetrieverResource(ResponseModel): + id: str = Field(default_factory=lambda: str(uuid4())) + message_id: str = Field(default_factory=lambda: str(uuid4())) + position: int + dataset_id: str | None = None + dataset_name: str | None = None + document_id: str | None = None + document_name: str | None = None + data_source_type: str | None = None + segment_id: str | None = None + score: float | None = None + hit_count: int | None = None + word_count: int | None = None + segment_position: int | None = None + index_node_hash: str | None = None + content: str | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value -retriever_resource_fields = { - "id": fields.String, - "message_id": fields.String, - "position": fields.Integer, - "dataset_id": fields.String, - "dataset_name": fields.String, - "document_id": fields.String, - "document_name": fields.String, - "data_source_type": fields.String, - "segment_id": fields.String, - "score": fields.Float, - "hit_count": fields.Integer, - "word_count": fields.Integer, - "segment_position": fields.Integer, - "index_node_hash": fields.String, - "content": fields.String, - "created_at": TimestampField, -} +class MessageListItem(ResponseModel): + id: str + 
conversation_id: str + parent_message_id: str | None = None + inputs: dict[str, JSONValueType] + query: str + answer: str = Field(validation_alias="re_sign_file_url_answer") + feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback") + retriever_resources: list[RetrieverResource] + created_at: int | None = None + agent_thoughts: list[AgentThought] + message_files: list[MessageFile] + status: str + error: str | None = None + extra_contents: list[ExecutionExtraContentDomainModel] -message_fields = { - "id": fields.String, - "conversation_id": fields.String, - "parent_message_id": fields.String, - "inputs": FilesContainedField, - "query": fields.String, - "answer": fields.String(attribute="re_sign_file_url_answer"), - "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), - "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), - "extra_contents": fields.List(cls_or_instance=fields.Raw), - "created_at": TimestampField, - "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), - "message_files": fields.List(fields.Nested(message_file_fields)), - "status": fields.String, - "error": fields.String, -} + @field_validator("inputs", mode="before") + @classmethod + def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType: + return format_files_contained(value) -message_infinite_scroll_pagination_fields = { - "limit": fields.Integer, - "has_more": fields.Boolean, - "data": fields.List(fields.Nested(message_fields)), -} + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + + +class WebMessageListItem(MessageListItem): + metadata: JSONValueType | None = Field(default=None, validation_alias="message_metadata_dict") + + +class MessageInfiniteScrollPagination(ResponseModel): + limit: int + has_more: bool + data: 
list[MessageListItem] + + +class WebMessageInfiniteScrollPagination(ResponseModel): + limit: int + has_more: bool + data: list[WebMessageListItem] + + +class SavedMessageItem(ResponseModel): + id: str + inputs: dict[str, JSONValueType] + query: str + answer: str + message_files: list[MessageFile] + feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback") + created_at: int | None = None + + @field_validator("inputs", mode="before") + @classmethod + def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType: + return format_files_contained(value) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + + +class SavedMessageInfiniteScrollPagination(ResponseModel): + limit: int + has_more: bool + data: list[SavedMessageItem] + + +class SuggestedQuestionsResponse(ResponseModel): + data: list[str] + + +def to_timestamp(value: datetime | None) -> int | None: + if value is None: + return None + return int(value.timestamp()) + + +def format_files_contained(value: JSONValueType) -> JSONValueType: + if isinstance(value, File): + return value.model_dump() + if isinstance(value, dict): + return {k: format_files_contained(v) for k, v in value.items()} + if isinstance(value, list): + return [format_files_contained(v) for v in value] + return value diff --git a/api/fields/rag_pipeline_fields.py b/api/fields/rag_pipeline_fields.py index f9e858c68b..97c02e7085 100644 --- a/api/fields/rag_pipeline_fields.py +++ b/api/fields/rag_pipeline_fields.py @@ -1,4 +1,4 @@ -from flask_restx import fields # type: ignore +from flask_restx import fields from fields.workflow_fields import workflow_partial_fields from libs.helper import AppIconUrlField, TimestampField diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index d5b7c86a04..e359a4408c 100644 --- a/api/fields/tag_fields.py 
+++ b/api/fields/tag_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields dataset_tag_fields = { "id": fields.String, @@ -8,5 +8,5 @@ dataset_tag_fields = { } -def build_dataset_tag_fields(api_or_ns: Api | Namespace): +def build_dataset_tag_fields(api_or_ns: Namespace): return api_or_ns.model("DataSetTag", dataset_tag_fields) diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index 4cbdf6f0ca..ae70356322 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,8 +1,13 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields from fields.member_fields import build_simple_account_model, simple_account_fields -from fields.workflow_run_fields import build_workflow_run_for_log_model, workflow_run_for_log_fields +from fields.workflow_run_fields import ( + build_workflow_run_for_archived_log_model, + build_workflow_run_for_log_model, + workflow_run_for_archived_log_fields, + workflow_run_for_log_fields, +) from libs.helper import TimestampField workflow_app_log_partial_fields = { @@ -17,7 +22,7 @@ workflow_app_log_partial_fields = { } -def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace): +def build_workflow_app_log_partial_model(api_or_ns: Namespace): """Build the workflow app log partial model for the API or Namespace.""" workflow_run_model = build_workflow_run_for_log_model(api_or_ns) simple_account_model = build_simple_account_model(api_or_ns) @@ -34,6 +39,33 @@ def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace): return api_or_ns.model("WorkflowAppLogPartial", copied_fields) +workflow_archived_log_partial_fields = { + "id": fields.String, + "workflow_run": fields.Nested(workflow_run_for_archived_log_fields, allow_null=True), + "trigger_metadata": fields.Raw, + 
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), + "created_at": TimestampField, +} + + +def build_workflow_archived_log_partial_model(api_or_ns: Namespace): + """Build the workflow archived log partial model for the API or Namespace.""" + workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns) + simple_account_model = build_simple_account_model(api_or_ns) + simple_end_user_model = build_simple_end_user_model(api_or_ns) + + copied_fields = workflow_archived_log_partial_fields.copy() + copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True) + copied_fields["created_by_account"] = fields.Nested( + simple_account_model, attribute="created_by_account", allow_null=True + ) + copied_fields["created_by_end_user"] = fields.Nested( + simple_end_user_model, attribute="created_by_end_user", allow_null=True + ) + return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields) + + workflow_app_log_pagination_fields = { "page": fields.Integer, "limit": fields.Integer, @@ -43,7 +75,7 @@ workflow_app_log_pagination_fields = { } -def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace): +def build_workflow_app_log_pagination_model(api_or_ns: Namespace): """Build the workflow app log pagination model for the API or Namespace.""" # Build the nested partial model first workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns) @@ -51,3 +83,21 @@ def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace): copied_fields = workflow_app_log_pagination_fields.copy() copied_fields["data"] = fields.List(fields.Nested(workflow_app_log_partial_model)) return api_or_ns.model("WorkflowAppLogPagination", copied_fields) + + +workflow_archived_log_pagination_fields = { + "page": fields.Integer, + "limit": fields.Integer, + 
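Each `build_*` helper above copies its module-level field template before swapping in the registered nested models. Why the copy matters can be shown with plain dicts (flask_restx field types replaced by strings purely for illustration):

```python
# Module-level template shared by every call site, mirroring the
# workflow_*_partial_fields dicts in the diff.
workflow_log_partial_fields = {"id": "String", "workflow_run": "raw-fields"}


def build_partial_model(registered_run_model):
    # Copy first: mutating the shared template would leak this override
    # into every later caller of the builder.
    copied = workflow_log_partial_fields.copy()
    copied["workflow_run"] = registered_run_model
    return copied


model_a = build_partial_model("RunModelA")
model_b = build_partial_model("RunModelB")
```

The template stays pristine between calls, so each namespace gets its own nested-model wiring.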
"total": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(workflow_archived_log_partial_fields)), +} + + +def build_workflow_archived_log_pagination_model(api_or_ns: Namespace): + """Build the workflow archived log pagination model for the API or Namespace.""" + workflow_archived_log_partial_model = build_workflow_archived_log_partial_model(api_or_ns) + + copied_fields = workflow_archived_log_pagination_fields.copy() + copied_fields["data"] = fields.List(fields.Nested(workflow_archived_log_partial_model)) + return api_or_ns.model("WorkflowArchivedLogPagination", copied_fields) diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index d037b0c442..2755f77f61 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,7 +1,7 @@ from flask_restx import fields from core.helper import encrypter -from core.variables import SecretVariable, SegmentType, Variable +from core.variables import SecretVariable, SegmentType, VariableBase from fields.member_fields import simple_account_fields from libs.helper import TimestampField @@ -21,7 +21,7 @@ class EnvironmentVariableField(fields.Raw): "value_type": value.value_type.value, "description": value.description, } - if isinstance(value, Variable): + if isinstance(value, VariableBase): return { "id": value.id, "name": value.name, diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 821ce62ecc..35bb442c59 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields @@ -19,10 +19,23 @@ workflow_run_for_log_fields = { } -def build_workflow_run_for_log_model(api_or_ns: Api | Namespace): +def build_workflow_run_for_log_model(api_or_ns: Namespace): return 
api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields) +workflow_run_for_archived_log_fields = { + "id": fields.String, + "status": fields.String, + "triggered_from": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, +} + + +def build_workflow_run_for_archived_log_model(api_or_ns: Namespace): + return api_or_ns.model("WorkflowRunForArchivedLog", workflow_run_for_archived_log_fields) + + workflow_run_for_list_fields = { "id": fields.String, "version": fields.String, diff --git a/api/libs/archive_storage.py b/api/libs/archive_storage.py new file mode 100644 index 0000000000..66b57ac661 --- /dev/null +++ b/api/libs/archive_storage.py @@ -0,0 +1,353 @@ +""" +Archive Storage Client for S3-compatible storage. + +This module provides a dedicated storage client for archiving or exporting logs +to S3-compatible object storage. +""" + +import base64 +import datetime +import hashlib +import logging +from collections.abc import Generator +from typing import Any, cast + +import boto3 +import orjson +from botocore.client import Config +from botocore.exceptions import ClientError + +from configs import dify_config + +logger = logging.getLogger(__name__) + + +class ArchiveStorageError(Exception): + """Base exception for archive storage operations.""" + + pass + + +class ArchiveStorageNotConfiguredError(ArchiveStorageError): + """Raised when archive storage is not properly configured.""" + + pass + + +class ArchiveStorage: + """ + S3-compatible storage client for archiving or exporting. + + This client provides methods for storing and retrieving archived data in JSONL format. 
+ """ + + def __init__(self, bucket: str): + if not dify_config.ARCHIVE_STORAGE_ENABLED: + raise ArchiveStorageNotConfiguredError("Archive storage is not enabled") + + if not bucket: + raise ArchiveStorageNotConfiguredError("Archive storage bucket is not configured") + if not all( + [ + dify_config.ARCHIVE_STORAGE_ENDPOINT, + bucket, + dify_config.ARCHIVE_STORAGE_ACCESS_KEY, + dify_config.ARCHIVE_STORAGE_SECRET_KEY, + ] + ): + raise ArchiveStorageNotConfiguredError( + "Archive storage configuration is incomplete. " + "Required: ARCHIVE_STORAGE_ENDPOINT, ARCHIVE_STORAGE_ACCESS_KEY, " + "ARCHIVE_STORAGE_SECRET_KEY, and a bucket name" + ) + + self.bucket = bucket + self.client = boto3.client( + "s3", + endpoint_url=dify_config.ARCHIVE_STORAGE_ENDPOINT, + aws_access_key_id=dify_config.ARCHIVE_STORAGE_ACCESS_KEY, + aws_secret_access_key=dify_config.ARCHIVE_STORAGE_SECRET_KEY, + region_name=dify_config.ARCHIVE_STORAGE_REGION, + config=Config( + s3={"addressing_style": "path"}, + max_pool_connections=64, + ), + ) + + # Verify bucket accessibility + try: + self.client.head_bucket(Bucket=self.bucket) + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code") + if error_code == "404": + raise ArchiveStorageNotConfiguredError(f"Archive bucket '{self.bucket}' does not exist") + elif error_code == "403": + raise ArchiveStorageNotConfiguredError(f"Access denied to archive bucket '{self.bucket}'") + else: + raise ArchiveStorageError(f"Failed to access archive bucket: {e}") + + def put_object(self, key: str, data: bytes) -> str: + """ + Upload an object to the archive storage. 
+ + Args: + key: Object key (path) within the bucket + data: Binary data to upload + + Returns: + MD5 checksum of the uploaded data + + Raises: + ArchiveStorageError: If upload fails + """ + checksum = hashlib.md5(data).hexdigest() + try: + response = self.client.put_object( + Bucket=self.bucket, + Key=key, + Body=data, + ContentMD5=self._content_md5(data), + ) + etag = response.get("ETag") + if not etag: + raise ArchiveStorageError(f"Missing ETag for '{key}'") + normalized_etag = etag.strip('"') + if normalized_etag != checksum: + raise ArchiveStorageError(f"ETag mismatch for '{key}': expected={checksum}, actual={normalized_etag}") + logger.debug("Uploaded object: %s (size=%d, checksum=%s)", key, len(data), checksum) + return checksum + except ClientError as e: + raise ArchiveStorageError(f"Failed to upload object '{key}': {e}") + + def get_object(self, key: str) -> bytes: + """ + Download an object from the archive storage. + + Args: + key: Object key (path) within the bucket + + Returns: + Binary data of the object + + Raises: + ArchiveStorageError: If download fails + FileNotFoundError: If object does not exist + """ + try: + response = self.client.get_object(Bucket=self.bucket, Key=key) + return response["Body"].read() + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code") + if error_code == "NoSuchKey": + raise FileNotFoundError(f"Archive object not found: {key}") + raise ArchiveStorageError(f"Failed to download object '{key}': {e}") + + def get_object_stream(self, key: str) -> Generator[bytes, None, None]: + """ + Stream an object from the archive storage. 
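`put_object` above sends a `Content-MD5` header (base64 of the raw digest) and then compares the returned ETag against the hex digest — two encodings of the same MD5. The stdlib is enough to sketch both:

```python
import base64
import hashlib


def content_md5(data: bytes) -> str:
    """Base64-encoded raw MD5 digest, the format the Content-MD5 header requires."""
    return base64.b64encode(hashlib.md5(data).digest()).decode()


def etag_matches(data: bytes, etag: str) -> bool:
    """S3-compatible stores return the quoted hex MD5 as the ETag for single-part uploads."""
    return etag.strip('"') == hashlib.md5(data).hexdigest()
```

One caveat: for multipart uploads the ETag is not a plain MD5 of the body, so this integrity check only holds for single `put_object` calls like the one in the diff.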
+ + Args: + key: Object key (path) within the bucket + + Yields: + Chunks of binary data + + Raises: + ArchiveStorageError: If download fails + FileNotFoundError: If object does not exist + """ + try: + response = self.client.get_object(Bucket=self.bucket, Key=key) + yield from response["Body"].iter_chunks() + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code") + if error_code == "NoSuchKey": + raise FileNotFoundError(f"Archive object not found: {key}") + raise ArchiveStorageError(f"Failed to stream object '{key}': {e}") + + def object_exists(self, key: str) -> bool: + """ + Check if an object exists in the archive storage. + + Args: + key: Object key (path) within the bucket + + Returns: + True if object exists, False otherwise + """ + try: + self.client.head_object(Bucket=self.bucket, Key=key) + return True + except ClientError: + return False + + def delete_object(self, key: str) -> None: + """ + Delete an object from the archive storage. + + Args: + key: Object key (path) within the bucket + + Raises: + ArchiveStorageError: If deletion fails + """ + try: + self.client.delete_object(Bucket=self.bucket, Key=key) + logger.debug("Deleted object: %s", key) + except ClientError as e: + raise ArchiveStorageError(f"Failed to delete object '{key}': {e}") + + def generate_presigned_url(self, key: str, expires_in: int = 3600) -> str: + """ + Generate a pre-signed URL for downloading an object. + + Args: + key: Object key (path) within the bucket + expires_in: URL validity duration in seconds (default: 1 hour) + + Returns: + Pre-signed URL string. 
+ + Raises: + ArchiveStorageError: If generation fails + """ + try: + return self.client.generate_presigned_url( + ClientMethod="get_object", + Params={"Bucket": self.bucket, "Key": key}, + ExpiresIn=expires_in, + ) + except ClientError as e: + raise ArchiveStorageError(f"Failed to generate pre-signed URL for '{key}': {e}") + + def list_objects(self, prefix: str) -> list[str]: + """ + List objects under a given prefix. + + Args: + prefix: Object key prefix to filter by + + Returns: + List of object keys matching the prefix + """ + keys = [] + paginator = self.client.get_paginator("list_objects_v2") + + try: + for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix): + for obj in page.get("Contents", []): + keys.append(obj["Key"]) + except ClientError as e: + raise ArchiveStorageError(f"Failed to list objects with prefix '{prefix}': {e}") + + return keys + + @staticmethod + def _content_md5(data: bytes) -> str: + """Calculate base64-encoded MD5 for Content-MD5 header.""" + return base64.b64encode(hashlib.md5(data).digest()).decode() + + @staticmethod + def serialize_to_jsonl(records: list[dict[str, Any]]) -> bytes: + """ + Serialize records to JSONL format. + + Args: + records: List of dictionaries to serialize + + Returns: + JSONL bytes + """ + lines = [] + for record in records: + serialized = ArchiveStorage._serialize_record(record) + lines.append(orjson.dumps(serialized)) + + jsonl_content = b"\n".join(lines) + if jsonl_content: + jsonl_content += b"\n" + + return jsonl_content + + @staticmethod + def deserialize_from_jsonl(data: bytes) -> list[dict[str, Any]]: + """ + Deserialize JSONL data to records. 
+ + Args: + data: JSONL bytes + + Returns: + List of dictionaries + """ + records = [] + + for line in data.splitlines(): + if line: + records.append(orjson.loads(line)) + + return records + + @staticmethod + def _serialize_record(record: dict[str, Any]) -> dict[str, Any]: + """Serialize a single record, converting special types.""" + + def _serialize(item: Any) -> Any: + if isinstance(item, datetime.datetime): + return item.isoformat() + if isinstance(item, dict): + return {key: _serialize(value) for key, value in item.items()} + if isinstance(item, list): + return [_serialize(value) for value in item] + return item + + return cast(dict[str, Any], _serialize(record)) + + @staticmethod + def compute_checksum(data: bytes) -> str: + """Compute MD5 checksum of data.""" + return hashlib.md5(data).hexdigest() + + +# Singleton instance (lazy initialization) +_archive_storage: ArchiveStorage | None = None +_export_storage: ArchiveStorage | None = None + + +def get_archive_storage() -> ArchiveStorage: + """ + Get the archive storage singleton instance. + + Returns: + ArchiveStorage instance + + Raises: + ArchiveStorageNotConfiguredError: If archive storage is not configured + """ + global _archive_storage + if _archive_storage is None: + archive_bucket = dify_config.ARCHIVE_STORAGE_ARCHIVE_BUCKET + if not archive_bucket: + raise ArchiveStorageNotConfiguredError( + "Archive storage bucket is not configured. Required: ARCHIVE_STORAGE_ARCHIVE_BUCKET" + ) + _archive_storage = ArchiveStorage(bucket=archive_bucket) + return _archive_storage + + +def get_export_storage() -> ArchiveStorage: + """ + Get the export storage singleton instance. + + Returns: + ArchiveStorage instance + """ + global _export_storage + if _export_storage is None: + export_bucket = dify_config.ARCHIVE_STORAGE_EXPORT_BUCKET + if not export_bucket: + raise ArchiveStorageNotConfiguredError( + "Archive export bucket is not configured. 
Required: ARCHIVE_STORAGE_EXPORT_BUCKET" + ) + _export_storage = ArchiveStorage(bucket=export_bucket) + return _export_storage diff --git a/api/libs/broadcast_channel/channel.py b/api/libs/broadcast_channel/channel.py index 5bbf0c79a3..d4cb3e9971 100644 --- a/api/libs/broadcast_channel/channel.py +++ b/api/libs/broadcast_channel/channel.py @@ -2,6 +2,8 @@ Broadcast channel for Pub/Sub messaging. """ +from __future__ import annotations + import types from abc import abstractmethod from collections.abc import Iterator @@ -129,6 +131,6 @@ class BroadcastChannel(Protocol): """ @abstractmethod - def topic(self, topic: str) -> "Topic": + def topic(self, topic: str) -> Topic: """topic returns a `Topic` instance for the given topic name.""" ... diff --git a/api/libs/broadcast_channel/redis/channel.py b/api/libs/broadcast_channel/redis/channel.py index 1fc3db8156..5bb4f579c1 100644 --- a/api/libs/broadcast_channel/redis/channel.py +++ b/api/libs/broadcast_channel/redis/channel.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from redis import Redis @@ -20,7 +22,7 @@ class BroadcastChannel: ): self._client = redis_client - def topic(self, topic: str) -> "Topic": + def topic(self, topic: str) -> Topic: return Topic(self._client, topic) diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py index 991dee64af..9e8ab90e8e 100644 --- a/api/libs/broadcast_channel/redis/sharded_channel.py +++ b/api/libs/broadcast_channel/redis/sharded_channel.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from redis import Redis @@ -18,7 +20,7 @@ class ShardedRedisBroadcastChannel: ): self._client = redis_client - def topic(self, topic: str) -> "ShardedTopic": + def topic(self, topic: str) -> ShardedTopic: return ShardedTopic(self._client, topic) diff --git 
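`get_archive_storage` and `get_export_storage` above are module-level lazy singletons: construction (and its configuration checks) is deferred until first use, and later calls reuse the same instance. The shape of the pattern, with a hypothetical stand-in class replacing `ArchiveStorage`:

```python
class _Client:
    """Hypothetical stand-in for ArchiveStorage; the real constructor talks to S3."""

    def __init__(self, bucket: str):
        self.bucket = bucket


_archive_client = None


def get_archive_client(bucket: str = "archive-bucket"):
    # Construct on first call only: configuration errors surface at first use
    # rather than at import time, and subsequent calls return the cached instance.
    global _archive_client
    if _archive_client is None:
        if not bucket:
            raise RuntimeError("bucket is not configured")
        _archive_client = _Client(bucket)
    return _archive_client
```

Worth knowing: without a lock, two threads racing through the `is None` check can each build a client; for a stateless S3 wrapper that is usually harmless, but it is a property of this pattern.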
a/api/libs/email_i18n.py b/api/libs/email_i18n.py index ff74ccbe8e..0828cf80bf 100644 --- a/api/libs/email_i18n.py +++ b/api/libs/email_i18n.py @@ -6,6 +6,8 @@ in Dify. It follows Domain-Driven Design principles with proper type hints and eliminates the need for repetitive language switching logic. """ +from __future__ import annotations + from dataclasses import dataclass from enum import StrEnum, auto from typing import Any, Protocol @@ -53,7 +55,7 @@ class EmailLanguage(StrEnum): ZH_HANS = "zh-Hans" @classmethod - def from_language_code(cls, language_code: str) -> "EmailLanguage": + def from_language_code(cls, language_code: str) -> EmailLanguage: """Convert a language code to EmailLanguage with fallback to English.""" if language_code == "zh-Hans": return cls.ZH_HANS diff --git a/api/libs/external_api.py b/api/libs/external_api.py index 61a90ee4a9..e8592407c3 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -1,5 +1,4 @@ import re -import sys from collections.abc import Mapping from typing import Any @@ -109,11 +108,8 @@ def register_external_error_handlers(api: Api): data.setdefault("code", "unknown") data.setdefault("status", status_code) - # Log stack - exc_info: Any = sys.exc_info() - if exc_info[1] is None: - exc_info = (None, None, None) - current_app.log_exception(exc_info) + # Note: Exception logging is handled by Flask/Flask-RESTX framework automatically + # Explicit log_exception call removed to avoid duplicate log entries return data, status_code diff --git a/api/libs/helper.py b/api/libs/helper.py index 26dd6fdfa6..94e1770810 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -32,6 +32,38 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def escape_like_pattern(pattern: str) -> str: + """ + Escape special characters in a string for safe use in SQL LIKE patterns. 
+ + This function escapes the special characters used in SQL LIKE patterns: + - Backslash (\\) -> \\ + - Percent (%) -> \\% + - Underscore (_) -> \\_ + + The escaped pattern can then be safely used in SQL LIKE queries with the + ESCAPE '\\' clause to prevent SQL injection via LIKE wildcards. + + Args: + pattern: The string pattern to escape + + Returns: + Escaped string safe for use in SQL LIKE queries + + Examples: + >>> escape_like_pattern("50% discount") + '50\\% discount' + >>> escape_like_pattern("test_data") + 'test\\_data' + >>> escape_like_pattern("path\\to\\file") + 'path\\\\to\\\\file' + """ + if not pattern: + return pattern + # Escape backslash first, then percent and underscore + return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + + def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None: """ Extract tenant_id from Account or EndUser object. diff --git a/api/libs/login.py b/api/libs/login.py index 4b8ee2d1f8..73caa492fe 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from collections.abc import Callable from functools import wraps -from typing import Any +from typing import TYPE_CHECKING, Any from flask import current_app, g, has_request_context, request from flask_login.config import EXEMPT_METHODS @@ -9,7 +11,9 @@ from werkzeug.local import LocalProxy from configs import dify_config from libs.token import check_csrf_token from models import Account -from models.model import EndUser + +if TYPE_CHECKING: + from models.model import EndUser def current_account_with_tenant(): diff --git a/api/libs/smtp.py b/api/libs/smtp.py index 4044c6f7ed..6f82f1440a 100644 --- a/api/libs/smtp.py +++ b/api/libs/smtp.py @@ -3,6 +3,8 @@ import smtplib from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +from configs import dify_config + logger = logging.getLogger(__name__) @@ -19,20 +21,21 @@ class SMTPClient: self.opportunistic_tls = 
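The docstring examples for `escape_like_pattern` above can be run directly; the implementation mirrors the diff, and the replacement order is the whole trick — the backslash must be escaped first, or the escapes added for `%` and `_` would themselves be doubled:

```python
def escape_like_pattern(pattern: str) -> str:
    if not pattern:
        return pattern
    # Order matters: escape the escape character itself before % and _.
    return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
```

The escaped value is then safe inside a LIKE query that declares `ESCAPE '\'`, e.g. (assumed SQLAlchemy usage) `column.like(f"%{escape_like_pattern(term)}%", escape="\\")`.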
opportunistic_tls def send(self, mail: dict): - smtp = None + smtp: smtplib.SMTP | None = None + local_host = dify_config.SMTP_LOCAL_HOSTNAME try: - if self.use_tls: - if self.opportunistic_tls: - smtp = smtplib.SMTP(self.server, self.port, timeout=10) - # Send EHLO command with the HELO domain name as the server address - smtp.ehlo(self.server) - smtp.starttls() - # Resend EHLO command to identify the TLS session - smtp.ehlo(self.server) - else: - smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10) + if self.use_tls and not self.opportunistic_tls: + # SMTP with SSL (implicit TLS) + smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10, local_hostname=local_host) else: - smtp = smtplib.SMTP(self.server, self.port, timeout=10) + # Plain SMTP or SMTP with STARTTLS (explicit TLS) + smtp = smtplib.SMTP(self.server, self.port, timeout=10, local_hostname=local_host) + + assert smtp is not None + if self.use_tls and self.opportunistic_tls: + smtp.ehlo(self.server) + smtp.starttls() + smtp.ehlo(self.server) # Only authenticate if both username and password are non-empty if self.username and self.password and self.username.strip() and self.password.strip(): diff --git a/api/libs/workspace_permission.py b/api/libs/workspace_permission.py new file mode 100644 index 0000000000..dd42a7facf --- /dev/null +++ b/api/libs/workspace_permission.py @@ -0,0 +1,74 @@ +""" +Workspace permission helper functions. + +These helpers check both billing/plan level and workspace-specific policy level permissions. +Checks are performed at two levels: +1. Billing/plan level - via FeatureService (e.g., SANDBOX plan restrictions) +2. 
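The rewritten `send` above collapses three connection branches into two constructors plus an optional STARTTLS upgrade. The mode selection is easiest to see (and test) as a pure function, separated from the network calls:

```python
def smtp_mode(use_tls: bool, opportunistic_tls: bool) -> str:
    """Pick the connection strategy the rewritten send() uses."""
    if use_tls and not opportunistic_tls:
        return "smtp_ssl"  # implicit TLS from the first byte (smtplib.SMTP_SSL)
    if use_tls and opportunistic_tls:
        return "starttls"  # plain connect, then EHLO + STARTTLS upgrade
    return "plain"         # no TLS at all
```

Both plain and STARTTLS paths share the same `smtplib.SMTP` constructor, which is why the diff can thread `local_hostname` through a single variable instead of three call sites.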
Workspace policy level - via EnterpriseService (admin-configured per workspace) +""" + +import logging + +from werkzeug.exceptions import Forbidden + +from configs import dify_config +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService + +logger = logging.getLogger(__name__) + + +def check_workspace_member_invite_permission(workspace_id: str) -> None: + """ + Check if workspace allows member invitations at both billing and policy levels. + + Checks performed: + 1. Billing/plan level - For future expansion (currently no plan-level restriction) + 2. Enterprise policy level - Admin-configured workspace permission + + Args: + workspace_id: The workspace ID to check permissions for + + Raises: + Forbidden: If either billing plan or workspace policy prohibits member invitations + """ + # Check enterprise workspace policy level (only if enterprise enabled) + if dify_config.ENTERPRISE_ENABLED: + try: + permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id) + if not permission.allow_member_invite: + raise Forbidden("Workspace policy prohibits member invitations") + except Forbidden: + raise + except Exception: + logger.exception("Failed to check workspace invite permission for %s", workspace_id) + + +def check_workspace_owner_transfer_permission(workspace_id: str) -> None: + """ + Check if workspace allows owner transfer at both billing and policy levels. + + Checks performed: + 1. Billing/plan level - SANDBOX plan blocks owner transfer + 2. 
Enterprise policy level - Admin-configured workspace permission + + Args: + workspace_id: The workspace ID to check permissions for + + Raises: + Forbidden: If either billing plan or workspace policy prohibits ownership transfer + """ + features = FeatureService.get_features(workspace_id) + if not features.is_allow_transfer_workspace: + raise Forbidden("Your current plan does not allow workspace ownership transfer") + + # Check enterprise workspace policy level (only if enterprise enabled) + if dify_config.ENTERPRISE_ENABLED: + try: + permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id) + if not permission.allow_owner_transfer: + raise Forbidden("Workspace policy prohibits ownership transfer") + except Forbidden: + raise + except Exception: + logger.exception("Failed to check workspace transfer permission for %s", workspace_id) diff --git a/api/migrations/versions/00bacef91f18_rename_api_provider_description.py b/api/migrations/versions/00bacef91f18_rename_api_provider_description.py index 17ed067d81..657d28f896 100644 --- a/api/migrations/versions/00bacef91f18_rename_api_provider_description.py +++ b/api/migrations/versions/00bacef91f18_rename_api_provider_description.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '00bacef91f18' down_revision = '8ec536f3c800' @@ -23,31 +20,17 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('description', sa.Text(), nullable=False)) - batch_op.drop_column('description_str') - else: - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False)) - batch_op.drop_column('description_str') + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False)) + batch_op.drop_column('description_str') # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('description_str', sa.TEXT(), autoincrement=False, nullable=False)) - batch_op.drop_column('description') - else: - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False)) - batch_op.drop_column('description') + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False)) + batch_op.drop_column('description') # ### end Alembic commands ### diff --git a/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py b/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py index ed70bf5d08..912d9dbfa4 100644 --- a/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py +++ b/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py @@ -7,14 +7,10 @@ Create Date: 2024-01-10 04:40:57.257824 """ import sqlalchemy as sa from alembic import op -from 
sqlalchemy.dialects import postgresql import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '114eed84c228' down_revision = 'c71211c8f604' @@ -32,13 +28,7 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: - batch_op.add_column(sa.Column('tool_id', postgresql.UUID(), autoincrement=False, nullable=False)) - else: - with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: - batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False)) + with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: + batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False)) # ### end Alembic commands ### diff --git a/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py b/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py index 509bd5d0e8..0ca905129d 100644 --- a/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py +++ b/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '161cadc1af8d' down_revision = '7e6a8693e07a' @@ -23,16 +20,9 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
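The migration cleanups above replace per-dialect branches (`_is_pg`) with portable column types such as `models.types.StringUUID` and `models.types.LongText`, which render different DDL per dialect. A toy stand-in shows the idea (the real types are SQLAlchemy `TypeDecorator`s; this sketch avoids the dependency and the type names in the mapping are illustrative):

```python
class LongText:
    """Toy stand-in for models.types.LongText: one Python type, per-dialect DDL."""

    _ddl = {"postgresql": "TEXT", "mysql": "LONGTEXT"}

    def compile(self, dialect: str) -> str:
        # Fall back to TEXT for dialects without a specific mapping.
        return self._ddl.get(dialect, "TEXT")
```

With the dialect choice pushed into the type itself, each migration keeps a single `batch_alter_table` code path instead of an `if _is_pg(conn)` fork.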
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: - # Step 1: Add column without NOT NULL constraint - op.add_column('dataset_permissions', sa.Column('tenant_id', sa.UUID(), nullable=False)) - else: - with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: - # Step 1: Add column without NOT NULL constraint - op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False)) + with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: + # Step 1: Add column without NOT NULL constraint + op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False)) # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py index 0767b725f6..be1b42f883 100644 --- a/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py +++ b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py @@ -9,11 +9,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - # revision identifiers, used by Alembic. revision = '6af6a521a53e' down_revision = 'd57ba9ebb251' @@ -23,58 +18,30 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: - batch_op.alter_column('document_id', - existing_type=sa.UUID(), - nullable=True) - batch_op.alter_column('data_source_type', - existing_type=sa.TEXT(), - nullable=True) - batch_op.alter_column('segment_id', - existing_type=sa.UUID(), - nullable=True) - else: - with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: - batch_op.alter_column('document_id', - existing_type=models.types.StringUUID(), - nullable=True) - batch_op.alter_column('data_source_type', - existing_type=models.types.LongText(), - nullable=True) - batch_op.alter_column('segment_id', - existing_type=models.types.StringUUID(), - nullable=True) + with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: + batch_op.alter_column('document_id', + existing_type=models.types.StringUUID(), + nullable=True) + batch_op.alter_column('data_source_type', + existing_type=models.types.LongText(), + nullable=True) + batch_op.alter_column('segment_id', + existing_type=models.types.StringUUID(), + nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: - batch_op.alter_column('segment_id', - existing_type=sa.UUID(), - nullable=False) - batch_op.alter_column('data_source_type', - existing_type=sa.TEXT(), - nullable=False) - batch_op.alter_column('document_id', - existing_type=sa.UUID(), - nullable=False) - else: - with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: - batch_op.alter_column('segment_id', - existing_type=models.types.StringUUID(), - nullable=False) - batch_op.alter_column('data_source_type', - existing_type=models.types.LongText(), - nullable=False) - batch_op.alter_column('document_id', - existing_type=models.types.StringUUID(), - nullable=False) + with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: + batch_op.alter_column('segment_id', + existing_type=models.types.StringUUID(), + nullable=False) + batch_op.alter_column('data_source_type', + existing_type=models.types.LongText(), + nullable=False) + batch_op.alter_column('document_id', + existing_type=models.types.StringUUID(), + nullable=False) # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py index a749c8bddf..5d12419bf7 100644 --- a/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py +++ b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py @@ -8,7 +8,6 @@ Create Date: 2024-11-01 04:34:23.816198 from alembic import op import models as models import sqlalchemy as sa -from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
revision = 'd3f6769a94a3' diff --git a/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py index 45842295ea..a49d6a52f6 100644 --- a/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py +++ b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py @@ -28,85 +28,45 @@ def upgrade(): op.execute("UPDATE sites SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") op.execute("UPDATE tool_api_providers SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") - if _is_pg(conn): - with op.batch_alter_table('recommended_apps', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=sa.TEXT(), - nullable=False) + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=models.types.LongText(), + nullable=False) - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=sa.TEXT(), - nullable=False) + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=models.types.LongText(), + nullable=False) - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=sa.TEXT(), - nullable=False) - else: - with op.batch_alter_table('recommended_apps', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=models.types.LongText(), - nullable=False) - - with op.batch_alter_table('sites', schema=None) as batch_op: - 
batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=models.types.LongText(), - nullable=False) - - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=models.types.LongText(), - nullable=False) + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=models.types.LongText(), + nullable=False) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.TEXT(), - type_=sa.VARCHAR(length=255), - nullable=True) + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=models.types.LongText(), + type_=sa.VARCHAR(length=255), + nullable=True) - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.TEXT(), - type_=sa.VARCHAR(length=255), - nullable=True) + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=models.types.LongText(), + type_=sa.VARCHAR(length=255), + nullable=True) - with op.batch_alter_table('recommended_apps', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.TEXT(), - type_=sa.VARCHAR(length=255), - nullable=True) - else: - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=models.types.LongText(), - type_=sa.VARCHAR(length=255), - nullable=True) - - with op.batch_alter_table('sites', schema=None) as batch_op: - 
batch_op.alter_column('custom_disclaimer', - existing_type=models.types.LongText(), - type_=sa.VARCHAR(length=255), - nullable=True) - - with op.batch_alter_table('recommended_apps', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=models.types.LongText(), - type_=sa.VARCHAR(length=255), - nullable=True) + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=models.types.LongText(), + type_=sa.VARCHAR(length=255), + nullable=True) # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py index fdd8984029..8a36c9c4a5 100644 --- a/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py +++ b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py @@ -49,57 +49,33 @@ def upgrade(): op.execute("UPDATE workflows SET updated_at = created_at WHERE updated_at IS NULL") op.execute("UPDATE workflows SET graph = '' WHERE graph IS NULL") op.execute("UPDATE workflows SET features = '' WHERE features IS NULL") - if _is_pg(conn): - with op.batch_alter_table('workflows', schema=None) as batch_op: - batch_op.alter_column('graph', - existing_type=sa.TEXT(), - nullable=False) - batch_op.alter_column('features', - existing_type=sa.TEXT(), - nullable=False) - batch_op.alter_column('updated_at', - existing_type=postgresql.TIMESTAMP(), - nullable=False) - else: - with op.batch_alter_table('workflows', schema=None) as batch_op: - batch_op.alter_column('graph', - existing_type=models.types.LongText(), - nullable=False) - batch_op.alter_column('features', - existing_type=models.types.LongText(), - nullable=False) - batch_op.alter_column('updated_at', - existing_type=sa.TIMESTAMP(), - nullable=False) + + with 
op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.alter_column('graph', + existing_type=models.types.LongText(), + nullable=False) + batch_op.alter_column('features', + existing_type=models.types.LongText(), + nullable=False) + batch_op.alter_column('updated_at', + existing_type=sa.TIMESTAMP(), + nullable=False) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('workflows', schema=None) as batch_op: - batch_op.alter_column('updated_at', - existing_type=postgresql.TIMESTAMP(), - nullable=True) - batch_op.alter_column('features', - existing_type=sa.TEXT(), - nullable=True) - batch_op.alter_column('graph', - existing_type=sa.TEXT(), - nullable=True) - else: - with op.batch_alter_table('workflows', schema=None) as batch_op: - batch_op.alter_column('updated_at', - existing_type=sa.TIMESTAMP(), - nullable=True) - batch_op.alter_column('features', - existing_type=models.types.LongText(), - nullable=True) - batch_op.alter_column('graph', - existing_type=models.types.LongText(), - nullable=True) + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.alter_column('updated_at', + existing_type=sa.TIMESTAMP(), + nullable=True) + batch_op.alter_column('features', + existing_type=models.types.LongText(), + nullable=True) + batch_op.alter_column('graph', + existing_type=models.types.LongText(), + nullable=True) if _is_pg(conn): with op.batch_alter_table('messages', schema=None) as batch_op: diff --git a/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py b/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py index 16ca902726..1fc4a64df1 100644 --- a/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py +++ 
b/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py @@ -86,57 +86,30 @@ def upgrade(): def migrate_existing_provider_models_data(): """migrate provider_models table data to provider_model_credentials""" - conn = op.get_bind() - # Define table structure for data manipulation - if _is_pg(conn): - provider_models_table = table('provider_models', - column('id', models.types.StringUUID()), - column('tenant_id', models.types.StringUUID()), - column('provider_name', sa.String()), - column('model_name', sa.String()), - column('model_type', sa.String()), - column('encrypted_config', sa.Text()), - column('created_at', sa.DateTime()), - column('updated_at', sa.DateTime()), - column('credential_id', models.types.StringUUID()), - ) - else: - provider_models_table = table('provider_models', - column('id', models.types.StringUUID()), - column('tenant_id', models.types.StringUUID()), - column('provider_name', sa.String()), - column('model_name', sa.String()), - column('model_type', sa.String()), - column('encrypted_config', models.types.LongText()), - column('created_at', sa.DateTime()), - column('updated_at', sa.DateTime()), - column('credential_id', models.types.StringUUID()), - ) + # Define table structure for data manipulation + provider_models_table = table('provider_models', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('model_name', sa.String()), + column('model_type', sa.String()), + column('encrypted_config', models.types.LongText()), + column('created_at', sa.DateTime()), + column('updated_at', sa.DateTime()), + column('credential_id', models.types.StringUUID()), + ) - if _is_pg(conn): - provider_model_credentials_table = table('provider_model_credentials', - column('id', models.types.StringUUID()), - column('tenant_id', models.types.StringUUID()), - column('provider_name', sa.String()), - column('model_name', sa.String()), - 
column('model_type', sa.String()), - column('credential_name', sa.String()), - column('encrypted_config', sa.Text()), - column('created_at', sa.DateTime()), - column('updated_at', sa.DateTime()) - ) - else: - provider_model_credentials_table = table('provider_model_credentials', - column('id', models.types.StringUUID()), - column('tenant_id', models.types.StringUUID()), - column('provider_name', sa.String()), - column('model_name', sa.String()), - column('model_type', sa.String()), - column('credential_name', sa.String()), - column('encrypted_config', models.types.LongText()), - column('created_at', sa.DateTime()), - column('updated_at', sa.DateTime()) - ) + provider_model_credentials_table = table('provider_model_credentials', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('model_name', sa.String()), + column('model_type', sa.String()), + column('credential_name', sa.String()), + column('encrypted_config', models.types.LongText()), + column('created_at', sa.DateTime()), + column('updated_at', sa.DateTime()) + ) # Get database connection @@ -183,14 +156,8 @@ def migrate_existing_provider_models_data(): def downgrade(): # Re-add encrypted_config column to provider_models table - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('provider_models', schema=None) as batch_op: - batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('provider_models', schema=None) as batch_op: - batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True)) + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True)) if not context.is_offline_mode(): # Migrate data back from provider_model_credentials to provider_models diff --git 
a/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py b/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py index 75b4d61173..79fe9d9bba 100644 --- a/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py +++ b/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py @@ -8,7 +8,6 @@ Create Date: 2025-08-20 17:47:17.015695 from alembic import op import models as models import sqlalchemy as sa -from libs.uuid_utils import uuidv7 def _is_pg(conn): diff --git a/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py b/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py index 4f472fe4b4..cf2b973d2d 100644 --- a/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py +++ b/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py @@ -9,8 +9,6 @@ from alembic import op import models as models -def _is_pg(conn): - return conn.dialect.name == "postgresql" import sqlalchemy as sa @@ -23,12 +21,7 @@ depends_on = None def upgrade(): # Add encrypted_headers column to tool_mcp_providers table - conn = op.get_bind() - - if _is_pg(conn): - op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', sa.Text(), nullable=True)) - else: - op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True)) + op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True)) def downgrade(): diff --git a/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py b/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py index 8eac0dee10..bad516dcac 100644 --- a/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py +++ b/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py @@ 
-44,6 +44,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='datasource_oauth_config_pkey'), sa.UniqueConstraint('plugin_id', 'provider', name='datasource_oauth_config_datasource_id_provider_idx') ) + if _is_pg(conn): op.create_table('datasource_oauth_tenant_params', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -70,6 +71,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='datasource_oauth_tenant_config_pkey'), sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='datasource_oauth_tenant_config_unique') ) + if _is_pg(conn): op.create_table('datasource_providers', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -104,6 +106,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='datasource_provider_pkey'), sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', 'name', name='datasource_provider_unique_name') ) + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: batch_op.create_index('datasource_provider_auth_type_provider_idx', ['tenant_id', 'plugin_id', 'provider'], unique=False) @@ -133,6 +136,7 @@ def upgrade(): sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), sa.PrimaryKeyConstraint('id', name='document_pipeline_execution_log_pkey') ) + with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op: batch_op.create_index('document_pipeline_execution_logs_document_id_idx', ['document_id'], unique=False) @@ -174,6 +178,7 @@ def upgrade(): sa.Column('updated_by', models.types.StringUUID(), nullable=True), sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey') ) + if _is_pg(conn): op.create_table('pipeline_customized_templates', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -193,7 +198,6 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey') ) 
else: - # MySQL: Use compatible syntax op.create_table('pipeline_customized_templates', sa.Column('id', models.types.StringUUID(), nullable=False), sa.Column('tenant_id', models.types.StringUUID(), nullable=False), @@ -211,6 +215,7 @@ def upgrade(): sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey') ) + with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op: batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False) @@ -236,6 +241,7 @@ def upgrade(): sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), sa.PrimaryKeyConstraint('id', name='pipeline_recommended_plugin_pkey') ) + if _is_pg(conn): op.create_table('pipelines', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -266,6 +272,7 @@ def upgrade(): sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), sa.PrimaryKeyConstraint('id', name='pipeline_pkey') ) + if _is_pg(conn): op.create_table('workflow_draft_variable_files', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -292,6 +299,7 @@ def upgrade(): sa.Column('value_type', sa.String(20), nullable=False), sa.PrimaryKeyConstraint('id', name=op.f('workflow_draft_variable_files_pkey')) ) + if _is_pg(conn): op.create_table('workflow_node_execution_offload', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -316,6 +324,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name=op.f('workflow_node_execution_offload_pkey')), sa.UniqueConstraint('node_execution_id', 'type', name=op.f('workflow_node_execution_offload_node_execution_id_key')) ) + if _is_pg(conn): with op.batch_alter_table('datasets', schema=None) as batch_op: 
batch_op.add_column(sa.Column('keyword_number', sa.Integer(), server_default=sa.text('10'), nullable=True)) @@ -342,6 +351,7 @@ def upgrade(): comment='Indicates whether the current value is the default for a conversation variable. Always `FALSE` for other types of variables.',) ) batch_op.create_index('workflow_draft_variable_file_id_idx', ['file_id'], unique=False) + if _is_pg(conn): with op.batch_alter_table('workflows', schema=None) as batch_op: batch_op.add_column(sa.Column('rag_pipeline_variables', sa.Text(), server_default='{}', nullable=False)) diff --git a/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py b/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py index 0776ab0818..ec0cfbd11d 100644 --- a/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py +++ b/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py @@ -9,8 +9,6 @@ from alembic import op import models as models -def _is_pg(conn): - return conn.dialect.name == "postgresql" import sqlalchemy as sa @@ -33,15 +31,9 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: - batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False)) - batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True)) - else: - with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: - batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False)) - batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True)) + + with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False)) + batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py b/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py index 627219cc4b..12905b3674 100644 --- a/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py +++ b/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py @@ -9,7 +9,6 @@ Create Date: 2025-10-22 16:11:31.805407 from alembic import op import models as models import sqlalchemy as sa -from libs.uuid_utils import uuidv7 def _is_pg(conn): return conn.dialect.name == "postgresql" diff --git a/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py b/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py index 9641a15c89..c27c1058d1 100644 --- a/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py +++ b/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py @@ -105,6 +105,7 @@ def upgrade(): 
sa.PrimaryKeyConstraint('id', name='trigger_oauth_tenant_client_pkey'), sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_trigger_oauth_tenant_client') ) + if _is_pg(conn): op.create_table('trigger_subscriptions', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), @@ -143,6 +144,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='trigger_provider_pkey'), sa.UniqueConstraint('tenant_id', 'provider_id', 'name', name='unique_trigger_provider') ) + with op.batch_alter_table('trigger_subscriptions', schema=None) as batch_op: batch_op.create_index('idx_trigger_providers_endpoint', ['endpoint_id'], unique=True) batch_op.create_index('idx_trigger_providers_tenant_endpoint', ['tenant_id', 'endpoint_id'], unique=False) @@ -176,6 +178,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='workflow_plugin_trigger_pkey'), sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node_subscription') ) + with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op: batch_op.create_index('workflow_plugin_trigger_tenant_subscription_idx', ['tenant_id', 'subscription_id', 'event_name'], unique=False) @@ -207,6 +210,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='workflow_schedule_plan_pkey'), sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node') ) + with op.batch_alter_table('workflow_schedule_plans', schema=None) as batch_op: batch_op.create_index('workflow_schedule_plan_next_idx', ['next_run_at'], unique=False) @@ -264,6 +268,7 @@ def upgrade(): sa.Column('finished_at', sa.DateTime(), nullable=True), sa.PrimaryKeyConstraint('id', name='workflow_trigger_log_pkey') ) + with op.batch_alter_table('workflow_trigger_logs', schema=None) as batch_op: batch_op.create_index('workflow_trigger_log_created_at_idx', ['created_at'], unique=False) batch_op.create_index('workflow_trigger_log_status_idx', ['status'], unique=False) @@ -299,6 +304,7 @@ def upgrade(): 
sa.UniqueConstraint('app_id', 'node_id', name='uniq_node'), sa.UniqueConstraint('webhook_id', name='uniq_webhook_id') ) + with op.batch_alter_table('workflow_webhook_triggers', schema=None) as batch_op: batch_op.create_index('workflow_webhook_trigger_tenant_idx', ['tenant_id'], unique=False) diff --git a/api/migrations/versions/2025_11_06_1603-9e6fa5cbcd80_make_message_annotation_question_not_.py b/api/migrations/versions/2025_11_06_1603-9e6fa5cbcd80_make_message_annotation_question_not_.py new file mode 100644 index 0000000000..624be1d073 --- /dev/null +++ b/api/migrations/versions/2025_11_06_1603-9e6fa5cbcd80_make_message_annotation_question_not_.py @@ -0,0 +1,60 @@ +"""make message annotation question not nullable + +Revision ID: 9e6fa5cbcd80 +Revises: 03f8dcbc611e +Create Date: 2025-11-06 16:03:54.549378 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '9e6fa5cbcd80' +down_revision = '288345cd01d1' +branch_labels = None +depends_on = None + + +def upgrade(): + bind = op.get_bind() + message_annotations = sa.table( + "message_annotations", + sa.column("id", sa.String), + sa.column("message_id", sa.String), + sa.column("question", sa.Text), + ) + messages = sa.table( + "messages", + sa.column("id", sa.String), + sa.column("query", sa.Text), + ) + update_question_from_message = ( + sa.update(message_annotations) + .where( + sa.and_( + message_annotations.c.question.is_(None), + message_annotations.c.message_id.isnot(None), + ) + ) + .values( + question=sa.select(sa.func.coalesce(messages.c.query, "")) + .where(messages.c.id == message_annotations.c.message_id) + .scalar_subquery() + ) + ) + bind.execute(update_question_from_message) + + fill_remaining_questions = ( + sa.update(message_annotations) + .where(message_annotations.c.question.is_(None)) + .values(question="") + ) + bind.execute(fill_remaining_questions) + with op.batch_alter_table('message_annotations', schema=None) as batch_op: + 
batch_op.alter_column('question', existing_type=sa.TEXT(), nullable=False) + + +def downgrade(): + with op.batch_alter_table('message_annotations', schema=None) as batch_op: + batch_op.alter_column('question', existing_type=sa.TEXT(), nullable=True) diff --git a/api/migrations/versions/2025_12_25_1039-7df29de0f6be_add_credit_pool.py b/api/migrations/versions/2025_12_25_1039-7df29de0f6be_add_credit_pool.py new file mode 100644 index 0000000000..e89fcee7e5 --- /dev/null +++ b/api/migrations/versions/2025_12_25_1039-7df29de0f6be_add_credit_pool.py @@ -0,0 +1,46 @@ +"""add credit pool + +Revision ID: 7df29de0f6be +Revises: 03ea244985ce +Create Date: 2025-12-25 10:39:15.139304 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '7df29de0f6be' +down_revision = '03ea244985ce' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('tenant_credit_pools', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False), + sa.Column('quota_limit', sa.BigInteger(), nullable=False), + sa.Column('quota_used', sa.BigInteger(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey') + ) + with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op: + batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False) + batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + + with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op: + batch_op.drop_index('tenant_credit_pool_tenant_id_idx') + batch_op.drop_index('tenant_credit_pool_pool_type_idx') + + op.drop_table('tenant_credit_pools') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py b/api/migrations/versions/2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py new file mode 100644 index 0000000000..7e0cc8ec9d --- /dev/null +++ b/api/migrations/versions/2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py @@ -0,0 +1,30 @@ +"""add workflow_run_created_at_id_idx + +Revision ID: 905527cc8fd3 +Revises: 7df29de0f6be +Create Date: 2025-01-09 16:30:02.462084 + +""" +from alembic import op +import models as models + +# revision identifiers, used by Alembic. 
+revision = '905527cc8fd3' +down_revision = '7df29de0f6be' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.create_index('workflow_run_created_at_id_idx', ['created_at', 'id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_index('workflow_run_created_at_id_idx') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py b/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py new file mode 100644 index 0000000000..758369ba99 --- /dev/null +++ b/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py @@ -0,0 +1,33 @@ +"""feat: add created_at id index to messages + +Revision ID: 3334862ee907 +Revises: 905527cc8fd3 +Create Date: 2026-01-12 17:29:44.846544 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '3334862ee907' +down_revision = '905527cc8fd3' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.create_index('message_created_at_id_idx', ['created_at', 'id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_index('message_created_at_id_idx') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2026_01_16_1715-288345cd01d1_change_workflow_node_execution_run_index.py b/api/migrations/versions/2026_01_16_1715-288345cd01d1_change_workflow_node_execution_run_index.py new file mode 100644 index 0000000000..2e1af0c83f --- /dev/null +++ b/api/migrations/versions/2026_01_16_1715-288345cd01d1_change_workflow_node_execution_run_index.py @@ -0,0 +1,35 @@ +"""change workflow node execution workflow_run index + +Revision ID: 288345cd01d1 +Revises: 3334862ee907 +Create Date: 2026-01-16 17:15:00.000000 + +""" +from alembic import op + + +# revision identifiers, used by Alembic. +revision = "288345cd01d1" +down_revision = "3334862ee907" +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op: + batch_op.drop_index("workflow_node_execution_workflow_run_idx") + batch_op.create_index( + "workflow_node_execution_workflow_run_id_idx", + ["workflow_run_id"], + unique=False, + ) + + +def downgrade(): + with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op: + batch_op.drop_index("workflow_node_execution_workflow_run_id_idx") + batch_op.create_index( + "workflow_node_execution_workflow_run_idx", + ["tenant_id", "app_id", "workflow_id", "triggered_from", "workflow_run_id"], + unique=False, + ) diff --git a/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py b/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py new file mode 100644 index 0000000000..b99ca04e3f --- /dev/null +++ b/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py @@ -0,0 +1,73 @@ +"""add table explore banner and trial + +Revision ID: f9f6d18a37f9 +Revises: 9e6fa5cbcd80 +Create Date: 2026-01-17 
11:10:18.079355 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'f9f6d18a37f9' +down_revision = '9e6fa5cbcd80' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('account_trial_app_records', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('account_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('count', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='user_trial_app_pkey'), + sa.UniqueConstraint('account_id', 'app_id', name='unique_account_trial_app_record') + ) + with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op: + batch_op.create_index('account_trial_app_record_account_id_idx', ['account_id'], unique=False) + batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False) + + op.create_table('exporle_banners', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('content', sa.JSON(), nullable=False), + sa.Column('link', sa.String(length=255), nullable=False), + sa.Column('sort', sa.Integer(), nullable=False), + sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False), + sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey') + ) + op.create_table('trial_apps', + sa.Column('id', 
models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('trial_limit', sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint('id', name='trial_app_pkey'), + sa.UniqueConstraint('app_id', name='unique_trail_app_id') + ) + with op.batch_alter_table('trial_apps', schema=None) as batch_op: + batch_op.create_index('trial_app_app_id_idx', ['app_id'], unique=False) + batch_op.create_index('trial_app_tenant_id_idx', ['tenant_id'], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('trial_apps', schema=None) as batch_op: + batch_op.drop_index('trial_app_tenant_id_idx') + batch_op.drop_index('trial_app_app_id_idx') + + op.drop_table('trial_apps') + op.drop_table('exporle_banners') + with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op: + batch_op.drop_index('account_trial_app_record_app_id_idx') + batch_op.drop_index('account_trial_app_record_account_id_idx') + + op.drop_table('account_trial_app_records') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py b/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py new file mode 100644 index 0000000000..5e7298af54 --- /dev/null +++ b/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py @@ -0,0 +1,95 @@ +"""create workflow_archive_logs + +Revision ID: 9d77545f524e +Revises: f9f6d18a37f9 +Create Date: 2026-01-06 17:18:56.292479 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +def _is_pg(conn): + return conn.dialect.name == "postgresql" + +# 
revision identifiers, used by Alembic. +revision = '9d77545f524e' +down_revision = 'f9f6d18a37f9' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + conn = op.get_bind() + if _is_pg(conn): + op.create_table('workflow_archive_logs', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), + sa.Column('log_id', models.types.StringUUID(), nullable=True), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('log_created_at', sa.DateTime(), nullable=True), + sa.Column('log_created_from', sa.String(length=255), nullable=True), + sa.Column('run_version', sa.String(length=255), nullable=False), + sa.Column('run_status', sa.String(length=255), nullable=False), + sa.Column('run_triggered_from', sa.String(length=255), nullable=False), + sa.Column('run_error', models.types.LongText(), nullable=True), + sa.Column('run_elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('run_total_tokens', sa.BigInteger(), server_default=sa.text('0'), nullable=False), + sa.Column('run_total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('run_created_at', sa.DateTime(), nullable=False), + sa.Column('run_finished_at', sa.DateTime(), nullable=True), + sa.Column('run_exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('trigger_metadata', models.types.LongText(), nullable=True), + sa.Column('archived_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + 
sa.PrimaryKeyConstraint('id', name='workflow_archive_log_pkey') + ) + else: + op.create_table('workflow_archive_logs', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('log_id', models.types.StringUUID(), nullable=True), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('log_created_at', sa.DateTime(), nullable=True), + sa.Column('log_created_from', sa.String(length=255), nullable=True), + sa.Column('run_version', sa.String(length=255), nullable=False), + sa.Column('run_status', sa.String(length=255), nullable=False), + sa.Column('run_triggered_from', sa.String(length=255), nullable=False), + sa.Column('run_error', models.types.LongText(), nullable=True), + sa.Column('run_elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('run_total_tokens', sa.BigInteger(), server_default=sa.text('0'), nullable=False), + sa.Column('run_total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('run_created_at', sa.DateTime(), nullable=False), + sa.Column('run_finished_at', sa.DateTime(), nullable=True), + sa.Column('run_exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('trigger_metadata', models.types.LongText(), nullable=True), + sa.Column('archived_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='workflow_archive_log_pkey') + ) + with op.batch_alter_table('workflow_archive_logs', schema=None) as batch_op: + batch_op.create_index('workflow_archive_log_app_idx', ['tenant_id', 'app_id'], unique=False) + 
batch_op.create_index('workflow_archive_log_run_created_at_idx', ['run_created_at'], unique=False) + batch_op.create_index('workflow_archive_log_workflow_run_id_idx', ['workflow_run_id'], unique=False) + + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_archive_logs', schema=None) as batch_op: + batch_op.drop_index('workflow_archive_log_workflow_run_id_idx') + batch_op.drop_index('workflow_archive_log_run_created_at_idx') + batch_op.drop_index('workflow_archive_log_app_idx') + + op.drop_table('workflow_archive_logs') + # ### end Alembic commands ### diff --git a/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py b/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py index fae506906b..127ffd5599 100644 --- a/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py +++ b/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '23db93619b9d' down_revision = '8ae9bc661daa' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.add_column(sa.Column('message_files', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True)) + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py index 2676ef0b94..31829d8e58 100644 --- a/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py +++ b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py @@ -62,14 +62,8 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_reply', sa.TEXT(), autoincrement=False, nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True)) with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op: batch_op.drop_index('app_annotation_settings_app_idx') diff --git a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py index 3362a3a09f..07a8cd86b1 100644 --- a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py +++ b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py @@ -11,9 +11,6 @@ from alembic import op import models as models -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '2a3aebbbf4bb' down_revision = 'c031d46af369' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('apps', schema=None) as batch_op: - batch_op.add_column(sa.Column('tracing', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('apps', schema=None) as batch_op: - batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True)) + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py index 40bd727f66..211b2d8882 100644 --- a/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py +++ b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py @@ -7,14 +7,10 @@ Create Date: 2023-09-22 15:41:01.243183 """ import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import postgresql import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '2e9819ca5b28' down_revision = 'ab23c11305d4' @@ -24,35 +20,19 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('api_tokens', schema=None) as batch_op: - batch_op.add_column(sa.Column('tenant_id', postgresql.UUID(), nullable=True)) - batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False) - batch_op.drop_column('dataset_id') - else: - with op.batch_alter_table('api_tokens', schema=None) as batch_op: - batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True)) - batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False) - batch_op.drop_column('dataset_id') + with op.batch_alter_table('api_tokens', schema=None) as batch_op: + batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True)) + batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False) + batch_op.drop_column('dataset_id') # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('api_tokens', schema=None) as batch_op: - batch_op.add_column(sa.Column('dataset_id', postgresql.UUID(), autoincrement=False, nullable=True)) - batch_op.drop_index('api_token_tenant_idx') - batch_op.drop_column('tenant_id') - else: - with op.batch_alter_table('api_tokens', schema=None) as batch_op: - batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True)) - batch_op.drop_index('api_token_tenant_idx') - batch_op.drop_column('tenant_id') + with op.batch_alter_table('api_tokens', schema=None) as batch_op: + batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True)) + batch_op.drop_index('api_token_tenant_idx') + batch_op.drop_column('tenant_id') # ### end Alembic commands ### diff --git a/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py 
b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py index 76056a9460..3491c85e2f 100644 --- a/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py +++ b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py @@ -7,14 +7,10 @@ Create Date: 2024-03-07 08:30:29.133614 """ import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import postgresql import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '42e85ed5564d' down_revision = 'f9107f83abab' @@ -24,59 +20,31 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('conversations', schema=None) as batch_op: - batch_op.alter_column('app_model_config_id', - existing_type=postgresql.UUID(), - nullable=True) - batch_op.alter_column('model_provider', - existing_type=sa.VARCHAR(length=255), - nullable=True) - batch_op.alter_column('model_id', - existing_type=sa.VARCHAR(length=255), - nullable=True) - else: - with op.batch_alter_table('conversations', schema=None) as batch_op: - batch_op.alter_column('app_model_config_id', - existing_type=models.types.StringUUID(), - nullable=True) - batch_op.alter_column('model_provider', - existing_type=sa.VARCHAR(length=255), - nullable=True) - batch_op.alter_column('model_id', - existing_type=sa.VARCHAR(length=255), - nullable=True) + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('app_model_config_id', + existing_type=models.types.StringUUID(), + nullable=True) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('conversations', schema=None) as batch_op: - batch_op.alter_column('model_id', - existing_type=sa.VARCHAR(length=255), - nullable=False) - batch_op.alter_column('model_provider', - existing_type=sa.VARCHAR(length=255), - nullable=False) - batch_op.alter_column('app_model_config_id', - existing_type=postgresql.UUID(), - nullable=False) - else: - with op.batch_alter_table('conversations', schema=None) as batch_op: - batch_op.alter_column('model_id', - existing_type=sa.VARCHAR(length=255), - nullable=False) - batch_op.alter_column('model_provider', - existing_type=sa.VARCHAR(length=255), - nullable=False) - batch_op.alter_column('app_model_config_id', - existing_type=models.types.StringUUID(), - nullable=False) + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('app_model_config_id', + existing_type=models.types.StringUUID(), + nullable=False) # ### end Alembic commands ### diff --git a/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py index ef066587b7..8537a87233 100644 --- a/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py +++ b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py @@ -6,14 +6,10 @@ Create Date: 2024-01-12 03:42:27.362415 """ from alembic import op -from sqlalchemy.dialects import postgresql import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '4829e54d2fee' down_revision = '114eed84c228' @@ -23,39 +19,21 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - # PostgreSQL: Keep original syntax - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.alter_column('message_chain_id', - existing_type=postgresql.UUID(), - nullable=True) - else: - # MySQL: Use compatible syntax - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.alter_column('message_chain_id', - existing_type=models.types.StringUUID(), - nullable=True) + + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.alter_column('message_chain_id', + existing_type=models.types.StringUUID(), + nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - # PostgreSQL: Keep original syntax - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.alter_column('message_chain_id', - existing_type=postgresql.UUID(), - nullable=False) - else: - # MySQL: Use compatible syntax - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.alter_column('message_chain_id', - existing_type=models.types.StringUUID(), - nullable=False) + + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.alter_column('message_chain_id', + existing_type=models.types.StringUUID(), + nullable=False) # ### end Alembic commands ### diff --git a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py index b080e7680b..22405e3cc8 100644 --- a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py +++ b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py @@ -6,14 +6,10 @@ Create Date: 2024-03-14 04:54:56.679506 """ from alembic import op -from sqlalchemy.dialects import postgresql 
import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '563cf8bf777b' down_revision = 'b5429b71023c' @@ -23,35 +19,19 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_files', schema=None) as batch_op: - batch_op.alter_column('conversation_id', - existing_type=postgresql.UUID(), - nullable=True) - else: - with op.batch_alter_table('tool_files', schema=None) as batch_op: - batch_op.alter_column('conversation_id', - existing_type=models.types.StringUUID(), - nullable=True) + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('conversation_id', + existing_type=models.types.StringUUID(), + nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_files', schema=None) as batch_op: - batch_op.alter_column('conversation_id', - existing_type=postgresql.UUID(), - nullable=False) - else: - with op.batch_alter_table('tool_files', schema=None) as batch_op: - batch_op.alter_column('conversation_id', - existing_type=models.types.StringUUID(), - nullable=False) + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('conversation_id', + existing_type=models.types.StringUUID(), + nullable=False) # ### end Alembic commands ### diff --git a/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py index 1ace8ea5a0..01d7d5ba21 100644 --- a/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py +++ b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py @@ -48,12 +48,9 @@ def upgrade(): with op.batch_alter_table('dataset_collection_bindings', 
schema=None) as batch_op: batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False) - if _is_pg(conn): - with op.batch_alter_table('datasets', schema=None) as batch_op: - batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True)) - else: - with op.batch_alter_table('datasets', schema=None) as batch_op: - batch_op.add_column(sa.Column('collection_binding_id', models.types.StringUUID(), nullable=True)) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('collection_binding_id', models.types.StringUUID(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py index 457338ef42..0faa48f535 100644 --- a/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py +++ b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '714aafe25d39' down_revision = 'f2a6fc85e260' @@ -23,16 +20,9 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_question', sa.Text(), nullable=False)) - batch_op.add_column(sa.Column('annotation_content', sa.Text(), nullable=False)) - else: - with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_question', models.types.LongText(), nullable=False)) - batch_op.add_column(sa.Column('annotation_content', models.types.LongText(), nullable=False)) + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.add_column(sa.Column('annotation_question', models.types.LongText(), nullable=False)) + batch_op.add_column(sa.Column('annotation_content', models.types.LongText(), nullable=False)) # ### end Alembic commands ### diff --git a/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py index 7bcd1a1be3..aa7b4a21e2 100644 --- a/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py +++ b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '77e83833755c' down_revision = '6dcb43972bdc' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('retriever_resource', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('retriever_resource', models.types.LongText(), nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('retriever_resource', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py b/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py index 3c0aa082d5..34a17697d3 100644 --- a/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py +++ b/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py @@ -27,7 +27,6 @@ def upgrade(): conn = op.get_bind() if _is_pg(conn): - # PostgreSQL: Keep original syntax op.create_table('tool_providers', sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), sa.Column('tenant_id', postgresql.UUID(), nullable=False), @@ -40,7 +39,6 @@ def upgrade(): sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') ) else: - # MySQL: Use compatible syntax op.create_table('tool_providers', sa.Column('id', models.types.StringUUID(), nullable=False), sa.Column('tenant_id', models.types.StringUUID(), nullable=False), @@ -52,12 +50,9 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'), sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') ) - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('sensitive_word_avoidance', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - 
batch_op.add_column(sa.Column('sensitive_word_avoidance', models.types.LongText(), nullable=True)) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('sensitive_word_avoidance', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py b/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py index beea90b384..884839c010 100644 --- a/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py +++ b/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '88072f0caa04' down_revision = '246ba09cbbdb' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tenants', schema=None) as batch_op: - batch_op.add_column(sa.Column('custom_config', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('tenants', schema=None) as batch_op: - batch_op.add_column(sa.Column('custom_config', models.types.LongText(), nullable=True)) + with op.batch_alter_table('tenants', schema=None) as batch_op: + batch_op.add_column(sa.Column('custom_config', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/89c7899ca936_.py b/api/migrations/versions/89c7899ca936_.py index 2420710e74..d26f1e82d6 100644 --- a/api/migrations/versions/89c7899ca936_.py +++ b/api/migrations/versions/89c7899ca936_.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. 
revision = '89c7899ca936' down_revision = '187385f442fc' @@ -23,39 +20,21 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('description', - existing_type=sa.VARCHAR(length=255), - type_=sa.Text(), - existing_nullable=True) - else: - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('description', - existing_type=sa.VARCHAR(length=255), - type_=models.types.LongText(), - existing_nullable=True) + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('description', + existing_type=sa.VARCHAR(length=255), + type_=models.types.LongText(), + existing_nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('description', - existing_type=sa.Text(), - type_=sa.VARCHAR(length=255), - existing_nullable=True) - else: - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('description', - existing_type=models.types.LongText(), - type_=sa.VARCHAR(length=255), - existing_nullable=True) + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('description', + existing_type=models.types.LongText(), + type_=sa.VARCHAR(length=255), + existing_nullable=True) # ### end Alembic commands ### diff --git a/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py b/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py index 111e81240b..6022ea2c20 100644 --- a/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py +++ b/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py @@ -11,9 +11,6 @@ from alembic import op import 
models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '8ec536f3c800' down_revision = 'ad472b61a054' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('credentials_str', sa.Text(), nullable=False)) - else: - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('credentials_str', models.types.LongText(), nullable=False)) + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('credentials_str', models.types.LongText(), nullable=False)) # ### end Alembic commands ### diff --git a/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py index 1c1c6cacbb..9d6d40114d 100644 --- a/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py +++ b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py @@ -57,12 +57,9 @@ def upgrade(): batch_op.create_index('message_file_created_by_idx', ['created_by'], unique=False) batch_op.create_index('message_file_message_idx', ['message_id'], unique=False) - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('file_upload', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('file_upload', models.types.LongText(), nullable=True)) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('file_upload', models.types.LongText(), nullable=True)) if _is_pg(conn): with op.batch_alter_table('upload_files', schema=None) as batch_op: diff --git 
a/api/migrations/versions/9f4e3427ea84_add_created_by_role.py b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py index 5d29d354f3..0b3f92a12e 100644 --- a/api/migrations/versions/9f4e3427ea84_add_created_by_role.py +++ b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py @@ -24,7 +24,6 @@ def upgrade(): conn = op.get_bind() if _is_pg(conn): - # PostgreSQL: Keep original syntax with op.batch_alter_table('pinned_conversations', schema=None) as batch_op: batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'::character varying"), nullable=False)) batch_op.drop_index('pinned_conversation_conversation_idx') @@ -35,7 +34,6 @@ def upgrade(): batch_op.drop_index('saved_message_message_idx') batch_op.create_index('saved_message_message_idx', ['app_id', 'message_id', 'created_by_role', 'created_by'], unique=False) else: - # MySQL: Use compatible syntax with op.batch_alter_table('pinned_conversations', schema=None) as batch_op: batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'"), nullable=False)) batch_op.drop_index('pinned_conversation_conversation_idx') diff --git a/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py b/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py index 616cb2f163..c8747a51f7 100644 --- a/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py +++ b/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = 'a5b56fb053ef' down_revision = 'd3d503a3471c' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('speech_to_text', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('speech_to_text', models.types.LongText(), nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('speech_to_text', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py index 900ff78036..f56aeb7e66 100644 --- a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py +++ b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = 'a9836e3baeee' down_revision = '968fff4c0ab9' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('external_data_tools', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('external_data_tools', models.types.LongText(), nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('external_data_tools', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/b24be59fbb04_.py b/api/migrations/versions/b24be59fbb04_.py index b0a6d10d8c..ae91eaf1bc 100644 --- a/api/migrations/versions/b24be59fbb04_.py +++ b/api/migrations/versions/b24be59fbb04_.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = 'b24be59fbb04' down_revision = 'de95f5c77138' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('text_to_speech', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('text_to_speech', models.types.LongText(), nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('text_to_speech', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py index 772395c25b..c02c24c23f 100644 --- a/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py +++ b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = 'b3a09c049e8e' down_revision = '2e9819ca5b28' @@ -23,20 +20,11 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple')) - batch_op.add_column(sa.Column('chat_prompt_config', sa.Text(), nullable=True)) - batch_op.add_column(sa.Column('completion_prompt_config', sa.Text(), nullable=True)) - batch_op.add_column(sa.Column('dataset_configs', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple')) - batch_op.add_column(sa.Column('chat_prompt_config', models.types.LongText(), nullable=True)) - batch_op.add_column(sa.Column('completion_prompt_config', models.types.LongText(), nullable=True)) - batch_op.add_column(sa.Column('dataset_configs', models.types.LongText(), nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple')) + batch_op.add_column(sa.Column('chat_prompt_config', models.types.LongText(), nullable=True)) + batch_op.add_column(sa.Column('completion_prompt_config', models.types.LongText(), nullable=True)) + batch_op.add_column(sa.Column('dataset_configs', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py index 76be794ff4..fe51d1c78d 100644 --- a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py +++ b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py @@ -7,7 +7,6 @@ Create Date: 2024-06-17 10:01:00.255189 """ import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import postgresql 
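The migration hunks above all make the same move: the per-dialect `_is_pg` branches are collapsed into a single code path by using `models.types.LongText`, a column type that compiles differently per backend. Its definition is not shown in this diff; a minimal sketch of how such a type can be built with SQLAlchemy's `with_variant` (assuming that is the mechanism used, which this diff does not confirm):

```python
import sqlalchemy as sa
from sqlalchemy.dialects import mysql, postgresql

# Hypothetical stand-in for models.types.LongText: plain TEXT on every
# backend, except MySQL, where the variant compiles to LONGTEXT.
LongText = sa.Text().with_variant(mysql.LONGTEXT(), "mysql")

metadata = sa.MetaData()
sites = sa.Table(
    "sites",
    metadata,
    sa.Column("description", LongText, nullable=True),
)

# Compiling the column type against each dialect shows the variant resolving.
pg_type = sites.c.description.type.compile(dialect=postgresql.dialect())
my_type = sites.c.description.type.compile(dialect=mysql.dialect())
```

With a type like this, each migration can call `batch_op.alter_column(..., type_=models.types.LongText())` once, instead of branching on `conn.dialect.name` as the deleted code did.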
import models.types diff --git a/api/migrations/versions/e1901f623fd0_add_annotation_reply.py b/api/migrations/versions/e1901f623fd0_add_annotation_reply.py index 9e02ec5d84..36e934f0fc 100644 --- a/api/migrations/versions/e1901f623fd0_add_annotation_reply.py +++ b/api/migrations/versions/e1901f623fd0_add_annotation_reply.py @@ -54,12 +54,9 @@ def upgrade(): batch_op.create_index('app_annotation_hit_histories_annotation_idx', ['annotation_id'], unique=False) batch_op.create_index('app_annotation_hit_histories_app_idx', ['app_id'], unique=False) - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_reply', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), nullable=True)) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), nullable=True)) if _is_pg(conn): with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: @@ -68,54 +65,31 @@ def upgrade(): with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: batch_op.add_column(sa.Column('type', sa.String(length=40), server_default=sa.text("'dataset'"), nullable=False)) - if _is_pg(conn): - with op.batch_alter_table('message_annotations', schema=None) as batch_op: - batch_op.add_column(sa.Column('question', sa.Text(), nullable=True)) - batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False)) - batch_op.alter_column('conversation_id', - existing_type=postgresql.UUID(), - nullable=True) - batch_op.alter_column('message_id', - existing_type=postgresql.UUID(), - nullable=True) - else: - with op.batch_alter_table('message_annotations', schema=None) as batch_op: - batch_op.add_column(sa.Column('question', 
models.types.LongText(), nullable=True)) - batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False)) - batch_op.alter_column('conversation_id', - existing_type=models.types.StringUUID(), - nullable=True) - batch_op.alter_column('message_id', - existing_type=models.types.StringUUID(), - nullable=True) + with op.batch_alter_table('message_annotations', schema=None) as batch_op: + batch_op.add_column(sa.Column('question', models.types.LongText(), nullable=True)) + batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False)) + batch_op.alter_column('conversation_id', + existing_type=models.types.StringUUID(), + nullable=True) + batch_op.alter_column('message_id', + existing_type=models.types.StringUUID(), + nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - if _is_pg(conn): - with op.batch_alter_table('message_annotations', schema=None) as batch_op: - batch_op.alter_column('message_id', - existing_type=postgresql.UUID(), - nullable=False) - batch_op.alter_column('conversation_id', - existing_type=postgresql.UUID(), - nullable=False) - batch_op.drop_column('hit_count') - batch_op.drop_column('question') - else: - with op.batch_alter_table('message_annotations', schema=None) as batch_op: - batch_op.alter_column('message_id', - existing_type=models.types.StringUUID(), - nullable=False) - batch_op.alter_column('conversation_id', - existing_type=models.types.StringUUID(), - nullable=False) - batch_op.drop_column('hit_count') - batch_op.drop_column('question') + with op.batch_alter_table('message_annotations', schema=None) as batch_op: + batch_op.alter_column('message_id', + existing_type=models.types.StringUUID(), + nullable=False) + batch_op.alter_column('conversation_id', + existing_type=models.types.StringUUID(), + nullable=False) + batch_op.drop_column('hit_count') + 
batch_op.drop_column('question') with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: batch_op.drop_column('type') diff --git a/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py b/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py index 02098e91c1..ac1c14e50c 100644 --- a/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py +++ b/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py @@ -12,9 +12,6 @@ from sqlalchemy.dialects import postgresql import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = 'f2a6fc85e260' down_revision = '46976cc39132' @@ -24,16 +21,9 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: - batch_op.add_column(sa.Column('message_id', postgresql.UUID(), nullable=False)) - batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False) - else: - with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: - batch_op.add_column(sa.Column('message_id', models.types.StringUUID(), nullable=False)) - batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False) + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.add_column(sa.Column('message_id', models.types.StringUUID(), nullable=False)) + batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False) # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index f648a60ace..1d5d604ba7 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -37,6 +37,7 @@ from .enums import ( from .execution_extra_content import 
ExecutionExtraContent, HumanInputContent from .human_input import HumanInputForm from .model import ( + AccountTrialAppRecord, ApiRequest, ApiToken, App, @@ -49,6 +50,7 @@ from .model import ( DatasetRetrieverResource, DifySetup, EndUser, + ExporleBanner, IconType, InstalledApp, Message, @@ -62,7 +64,9 @@ from .model import ( Site, Tag, TagBinding, + TenantCreditPool, TraceAppConfig, + TrialApp, UploadFile, ) from .oauth import DatasourceOauthParamConfig, DatasourceProvider @@ -101,6 +105,7 @@ from .workflow import ( Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, + WorkflowArchiveLog, WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom, @@ -115,6 +120,7 @@ __all__ = [ "Account", "AccountIntegrate", "AccountStatus", + "AccountTrialAppRecord", "ApiRequest", "ApiToken", "ApiToolProvider", @@ -152,6 +158,7 @@ __all__ = [ "Embedding", "EndUser", "ExecutionExtraContent", + "ExporleBanner", "ExternalKnowledgeApis", "ExternalKnowledgeBindings", "HumanInputContent", @@ -182,6 +189,7 @@ __all__ = [ "Tenant", "TenantAccountJoin", "TenantAccountRole", + "TenantCreditPool", "TenantDefaultModel", "TenantPreferredModelProvider", "TenantStatus", @@ -191,6 +199,7 @@ __all__ = [ "ToolLabelBinding", "ToolModelInvoke", "TraceAppConfig", + "TrialApp", "TriggerOAuthSystemClient", "TriggerOAuthTenantClient", "TriggerSubscription", @@ -200,6 +209,7 @@ __all__ = [ "Workflow", "WorkflowAppLog", "WorkflowAppLogCreatedFrom", + "WorkflowArchiveLog", "WorkflowNodeExecutionModel", "WorkflowNodeExecutionOffload", "WorkflowNodeExecutionTriggeredFrom", diff --git a/api/models/account.py b/api/models/account.py index 420e6adc6c..f7a9c20026 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -8,7 +8,7 @@ from uuid import uuid4 import sqlalchemy as sa from flask_login import UserMixin from sqlalchemy import DateTime, String, func, select -from sqlalchemy.orm import Mapped, Session, mapped_column +from sqlalchemy.orm import Mapped, Session, 
mapped_column, validates from typing_extensions import deprecated from .base import TypeBase @@ -116,6 +116,12 @@ class Account(UserMixin, TypeBase): role: TenantAccountRole | None = field(default=None, init=False) _current_tenant: "Tenant | None" = field(default=None, init=False) + @validates("status") + def _normalize_status(self, _key: str, value: str | AccountStatus) -> str: + if isinstance(value, AccountStatus): + return value.value + return value + @property def is_password_set(self): return self.password is not None diff --git a/api/models/dataset.py b/api/models/dataset.py index 445ac6086f..62f11b8c72 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1149,7 +1149,7 @@ class DatasetCollectionBinding(TypeBase): ) -class TidbAuthBinding(Base): +class TidbAuthBinding(TypeBase): __tablename__ = "tidb_auth_bindings" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), @@ -1158,7 +1158,13 @@ class TidbAuthBinding(Base): sa.Index("tidb_auth_bindings_created_at_idx", "created_at"), sa.Index("tidb_auth_bindings_status_idx", "status"), ) - id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4())) + id: Mapped[str] = mapped_column( + StringUUID, + primary_key=True, + insert_default=lambda: str(uuid4()), + default_factory=lambda: str(uuid4()), + init=False, + ) tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) cluster_id: Mapped[str] = mapped_column(String(255), nullable=False) cluster_name: Mapped[str] = mapped_column(String(255), nullable=False) @@ -1166,7 +1172,9 @@ class TidbAuthBinding(Base): status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'")) account: Mapped[str] = mapped_column(String(255), nullable=False) password: Mapped[str] = mapped_column(String(255), nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: 
Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) class Whitelist(TypeBase): diff --git a/api/models/model.py b/api/models/model.py index 76efa0d989..39bb52d668 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import re import uuid @@ -5,13 +7,13 @@ from collections.abc import Mapping, Sequence from datetime import datetime from decimal import Decimal from enum import StrEnum, auto -from typing import TYPE_CHECKING, Any, Literal, Optional, cast +from typing import TYPE_CHECKING, Any, Literal, cast from uuid import uuid4 import sqlalchemy as sa from flask import request -from flask_login import UserMixin -from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text +from flask_login import UserMixin # type: ignore[import-untyped] +from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config @@ -54,7 +56,7 @@ class AppMode(StrEnum): RAG_PIPELINE = "rag-pipeline" @classmethod - def value_of(cls, value: str) -> "AppMode": + def value_of(cls, value: str) -> AppMode: """ Get value of given mode. 
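The unquoted `-> AppMode` return annotation in the `value_of` hunk works because this file now starts with `from __future__ import annotations` (PEP 563), which defers evaluation of all annotations, so a class can name itself without string quoting. The body of `value_of` is not part of this hunk; a plausible sketch of the lookup, using a `str`-backed `Enum` in place of Dify's `StrEnum` so it also runs on interpreters older than 3.11:

```python
from __future__ import annotations

from enum import Enum


class AppMode(str, Enum):  # the real model uses StrEnum (3.11+); str + Enum behaves the same here
    CHAT = "chat"
    WORKFLOW = "workflow"

    @classmethod
    def value_of(cls, value: str) -> AppMode:
        # With postponed annotation evaluation, no "AppMode" quoting is needed.
        for mode in cls:
            if mode.value == value:
                return mode
        raise ValueError(f"invalid AppMode value: {value}")
```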
@@ -70,6 +72,7 @@ class AppMode(StrEnum): class IconType(StrEnum): IMAGE = auto() EMOJI = auto() + LINK = auto() class App(Base): @@ -81,7 +84,7 @@ class App(Base): name: Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(LongText, default=sa.text("''")) mode: Mapped[str] = mapped_column(String(255)) - icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji + icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji, link icon = mapped_column(String(255)) icon_background: Mapped[str | None] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) @@ -120,19 +123,19 @@ class App(Base): return "" @property - def site(self) -> Optional["Site"]: + def site(self) -> Site | None: site = db.session.query(Site).where(Site.app_id == self.id).first() return site @property - def app_model_config(self) -> Optional["AppModelConfig"]: + def app_model_config(self) -> AppModelConfig | None: if self.app_model_config_id: return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() return None @property - def workflow(self) -> Optional["Workflow"]: + def workflow(self) -> Workflow | None: if self.workflow_id: from .workflow import Workflow @@ -287,7 +290,7 @@ class App(Base): return deleted_tools @property - def tags(self) -> list["Tag"]: + def tags(self) -> list[Tag]: tags = ( db.session.query(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) @@ -312,40 +315,48 @@ class App(Base): return None -class AppModelConfig(Base): +class AppModelConfig(TypeBase): __tablename__ = "app_model_configs" __table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id")) - id = mapped_column(StringUUID, default=lambda: str(uuid4())) - app_id = mapped_column(StringUUID, nullable=False) - provider = mapped_column(String(255), nullable=True) - model_id = mapped_column(String(255), nullable=True) - configs = 
mapped_column(sa.JSON, nullable=True) - created_by = mapped_column(StringUUID, nullable=True) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + model_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + configs: Mapped[Any | None] = mapped_column(sa.JSON, nullable=True, default=None) + created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) - opening_statement = mapped_column(LongText) - suggested_questions = mapped_column(LongText) - suggested_questions_after_answer = mapped_column(LongText) - speech_to_text = mapped_column(LongText) - text_to_speech = mapped_column(LongText) - more_like_this = mapped_column(LongText) - model = mapped_column(LongText) - user_input_form = mapped_column(LongText) - dataset_query_variable = mapped_column(String(255)) - pre_prompt = mapped_column(LongText) - agent_mode = mapped_column(LongText) - sensitive_word_avoidance = mapped_column(LongText) - retriever_resource = mapped_column(LongText) - prompt_type = mapped_column(String(255), nullable=False, server_default=sa.text("'simple'")) - chat_prompt_config = mapped_column(LongText) - completion_prompt_config = mapped_column(LongText) - dataset_configs = mapped_column(LongText) - external_data_tools = mapped_column(LongText) - file_upload = mapped_column(LongText) + updated_by: Mapped[str | None] = 
mapped_column(StringUUID, nullable=True, default=None) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, + ) + opening_statement: Mapped[str | None] = mapped_column(LongText, default=None) + suggested_questions: Mapped[str | None] = mapped_column(LongText, default=None) + suggested_questions_after_answer: Mapped[str | None] = mapped_column(LongText, default=None) + speech_to_text: Mapped[str | None] = mapped_column(LongText, default=None) + text_to_speech: Mapped[str | None] = mapped_column(LongText, default=None) + more_like_this: Mapped[str | None] = mapped_column(LongText, default=None) + model: Mapped[str | None] = mapped_column(LongText, default=None) + user_input_form: Mapped[str | None] = mapped_column(LongText, default=None) + dataset_query_variable: Mapped[str | None] = mapped_column(String(255), default=None) + pre_prompt: Mapped[str | None] = mapped_column(LongText, default=None) + agent_mode: Mapped[str | None] = mapped_column(LongText, default=None) + sensitive_word_avoidance: Mapped[str | None] = mapped_column(LongText, default=None) + retriever_resource: Mapped[str | None] = mapped_column(LongText, default=None) + prompt_type: Mapped[str] = mapped_column( + String(255), nullable=False, server_default=sa.text("'simple'"), default="simple" + ) + chat_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None) + completion_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None) + dataset_configs: Mapped[str | None] = mapped_column(LongText, default=None) + external_data_tools: Mapped[str | None] = mapped_column(LongText, default=None) + file_upload: Mapped[str | None] = mapped_column(LongText, default=None) @property def app(self) -> App | None: @@ -600,6 +611,64 @@ class InstalledApp(TypeBase): return tenant +class TrialApp(Base): + __tablename__ = "trial_apps" + __table_args__ = ( + 
sa.PrimaryKeyConstraint("id", name="trial_app_pkey"), + sa.Index("trial_app_app_id_idx", "app_id"), + sa.Index("trial_app_tenant_id_idx", "tenant_id"), + sa.UniqueConstraint("app_id", name="unique_trail_app_id"), + ) + + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + tenant_id = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + trial_limit = mapped_column(sa.Integer, nullable=False, default=3) + + @property + def app(self) -> App | None: + app = db.session.query(App).where(App.id == self.app_id).first() + return app + + +class AccountTrialAppRecord(Base): + __tablename__ = "account_trial_app_records" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"), + sa.Index("account_trial_app_record_account_id_idx", "account_id"), + sa.Index("account_trial_app_record_app_id_idx", "app_id"), + sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"), + ) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + account_id = mapped_column(StringUUID, nullable=False) + app_id = mapped_column(StringUUID, nullable=False) + count = mapped_column(sa.Integer, nullable=False, default=0) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + + @property + def app(self) -> App | None: + app = db.session.query(App).where(App.id == self.app_id).first() + return app + + @property + def user(self) -> Account | None: + user = db.session.query(Account).where(Account.id == self.account_id).first() + return user + + +class ExporleBanner(Base): + __tablename__ = "exporle_banners" + __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + content = mapped_column(sa.JSON, nullable=False) + link = 
mapped_column(String(255), nullable=False) + sort = mapped_column(sa.Integer, nullable=False) + status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying")) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying")) + + class OAuthProviderApp(TypeBase): """ Globally shared OAuth provider app information. @@ -749,8 +818,8 @@ class Conversation(Base): override_model_configs = json.loads(self.override_model_configs) if "model" in override_model_configs: - app_model_config = AppModelConfig() - app_model_config = app_model_config.from_model_config_dict(override_model_configs) + # where is app_id? + app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs) model_config = app_model_config.to_dict() else: model_config["configs"] = override_model_configs @@ -967,6 +1036,7 @@ class Message(Base): Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"), Index("message_created_at_idx", "created_at"), Index("message_app_mode_idx", "app_mode"), + Index("message_created_at_id_idx", "created_at", "id"), ) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) @@ -1195,7 +1265,7 @@ class Message(Base): return json.loads(self.message_metadata) if self.message_metadata else {} @property - def agent_thoughts(self) -> list["MessageAgentThought"]: + def agent_thoughts(self) -> list[MessageAgentThought]: return ( db.session.query(MessageAgentThought) .where(MessageAgentThought.message_id == self.id) @@ -1316,7 +1386,7 @@ class Message(Base): } @classmethod - def from_dict(cls, data: dict[str, Any]) -> "Message": + def from_dict(cls, data: dict[str, Any]) -> Message: return cls( id=data["id"], app_id=data["app_id"], @@ -1429,15 +1499,20 @@ class MessageAnnotation(Base): app_id: Mapped[str] = 
mapped_column(StringUUID) conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id")) message_id: Mapped[str | None] = mapped_column(StringUUID) - question = mapped_column(LongText, nullable=True) - content = mapped_column(LongText, nullable=False) + question: Mapped[str] = mapped_column(LongText, nullable=False) + content: Mapped[str] = mapped_column(LongText, nullable=False) hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) - account_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column( + account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) + @property + def question_text(self) -> str: + """Return a non-null question string, falling back to the answer content.""" + return self.question or self.content + @property def account(self): account = db.session.query(Account).where(Account.id == self.account_id).first() @@ -1449,7 +1524,7 @@ class MessageAnnotation(Base): return account -class AppAnnotationHitHistory(Base): +class AppAnnotationHitHistory(TypeBase): __tablename__ = "app_annotation_hit_histories" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), @@ -1459,17 +1534,19 @@ class AppAnnotationHitHistory(Base): sa.Index("app_annotation_hit_histories_message_idx", "message_id"), ) - id = mapped_column(StringUUID, default=lambda: str(uuid4())) - app_id = mapped_column(StringUUID, nullable=False) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + app_id: Mapped[str] = mapped_column(StringUUID, 
nullable=False) annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - source = mapped_column(LongText, nullable=False) - question = mapped_column(LongText, nullable=False) - account_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - score = mapped_column(Float, nullable=False, server_default=sa.text("0")) - message_id = mapped_column(StringUUID, nullable=False) - annotation_question = mapped_column(LongText, nullable=False) - annotation_content = mapped_column(LongText, nullable=False) + source: Mapped[str] = mapped_column(LongText, nullable=False) + question: Mapped[str] = mapped_column(LongText, nullable=False) + account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + score: Mapped[float] = mapped_column(Float, nullable=False, server_default=sa.text("0")) + message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + annotation_question: Mapped[str] = mapped_column(LongText, nullable=False) + annotation_content: Mapped[str] = mapped_column(LongText, nullable=False) @property def account(self): @@ -1538,7 +1615,7 @@ class OperationLog(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) action: Mapped[str] = mapped_column(String(255), nullable=False) - content: Mapped[Any] = mapped_column(sa.JSON) + content: Mapped[Any | None] = mapped_column(sa.JSON, nullable=True) created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) @@ -1845,7 +1922,7 @@ class MessageChain(TypeBase): ) -class MessageAgentThought(Base): +class MessageAgentThought(TypeBase): __tablename__ = "message_agent_thoughts" __table_args__ = ( sa.PrimaryKeyConstraint("id", 
name="message_agent_thought_pkey"), @@ -1853,34 +1930,42 @@ class MessageAgentThought(Base): sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), ) - id = mapped_column(StringUUID, default=lambda: str(uuid4())) - message_id = mapped_column(StringUUID, nullable=False) - message_chain_id = mapped_column(StringUUID, nullable=True) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False + ) + message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False) - thought = mapped_column(LongText, nullable=True) - tool = mapped_column(LongText, nullable=True) - tool_labels_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'")) - tool_meta_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'")) - tool_input = mapped_column(LongText, nullable=True) - observation = mapped_column(LongText, nullable=True) + created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + tool: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + tool_labels_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'")) + tool_meta_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'")) + tool_input: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + observation: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) # plugin_id = mapped_column(StringUUID, nullable=True) ## for future design - tool_process_data = mapped_column(LongText, nullable=True) - message = mapped_column(LongText, nullable=True) - message_token: Mapped[int 
| None] = mapped_column(sa.Integer, nullable=True) - message_unit_price = mapped_column(sa.Numeric, nullable=True) - message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) - message_files = mapped_column(LongText, nullable=True) - answer = mapped_column(LongText, nullable=True) - answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - answer_unit_price = mapped_column(sa.Numeric, nullable=True) - answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) - tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - total_price = mapped_column(sa.Numeric, nullable=True) - currency = mapped_column(String(255), nullable=True) - latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True) - created_by_role = mapped_column(String(255), nullable=False) - created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp()) + tool_process_data: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + message: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None) + message_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None) + message_price_unit: Mapped[Decimal] = mapped_column( + sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001") + ) + message_files: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + answer: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None) + answer_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None) + answer_price_unit: Mapped[Decimal] = 
mapped_column( + sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001") + ) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None) + total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None) + currency: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, init=False, server_default=sa.func.current_timestamp() + ) @property def files(self) -> list[Any]: @@ -2075,3 +2160,35 @@ class TraceAppConfig(TypeBase): "created_at": str(self.created_at) if self.created_at else None, "updated_at": str(self.updated_at) if self.updated_at else None, } + + +class TenantCreditPool(TypeBase): + __tablename__ = "tenant_credit_pools" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="tenant_credit_pool_pkey"), + sa.Index("tenant_credit_pool_tenant_id_idx", "tenant_id"), + sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"), + ) + + id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"), init=False) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial") + quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) + quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), init=False + ) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, + ) + + @property + def remaining_credits(self) -> int: + return max(0, self.quota_limit - 
self.quota_used) + + def has_sufficient_credits(self, required_credits: int) -> bool: + return self.remaining_credits >= required_credits diff --git a/api/models/provider.py b/api/models/provider.py index 2afd8c5329..441b54c797 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime from enum import StrEnum, auto from functools import cached_property @@ -19,7 +21,7 @@ class ProviderType(StrEnum): SYSTEM = auto() @staticmethod - def value_of(value: str) -> "ProviderType": + def value_of(value: str) -> ProviderType: for member in ProviderType: if member.value == value: return member @@ -37,7 +39,7 @@ class ProviderQuotaType(StrEnum): """hosted trial quota""" @staticmethod - def value_of(value: str) -> "ProviderQuotaType": + def value_of(value: str) -> ProviderQuotaType: for member in ProviderQuotaType: if member.value == value: return member @@ -76,7 +78,7 @@ class Provider(TypeBase): quota_type: Mapped[str | None] = mapped_column(String(40), nullable=True, server_default=text("''"), default="") quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=None) - quota_used: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, default=0) + quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=0) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False diff --git a/api/models/tools.py b/api/models/tools.py index e4f9bcb582..e7b98dcf27 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from datetime import datetime from decimal import Decimal @@ -167,11 +169,11 @@ class ApiToolProvider(TypeBase): ) @property - def schema_type(self) -> "ApiProviderSchemaType": + def schema_type(self) -> ApiProviderSchemaType: return ApiProviderSchemaType.value_of(self.schema_type_str) @property - 
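`TenantCreditPool`'s helper methods reduce to simple clamped arithmetic; a dependency-free sketch using a hypothetical dataclass stand-in (not the ORM model):

```python
from dataclasses import dataclass


@dataclass
class CreditPool:
    # Mirrors TenantCreditPool's quota fields; plain dataclass, no SQLAlchemy.
    quota_limit: int = 0
    quota_used: int = 0

    @property
    def remaining_credits(self) -> int:
        # Clamped at zero so an over-consumed pool never reports negative credits.
        return max(0, self.quota_limit - self.quota_used)

    def has_sufficient_credits(self, required_credits: int) -> bool:
        return self.remaining_credits >= required_credits


pool = CreditPool(quota_limit=100, quota_used=130)
print(pool.remaining_credits)          # 0, not -30
print(pool.has_sufficient_credits(0))  # True
```

The `max(0, ...)` clamp means `has_sufficient_credits(0)` is always true, even for an exhausted pool; callers that need to detect over-consumption would have to compare `quota_used` against `quota_limit` directly.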
def tools(self) -> list["ApiToolBundle"]: + def tools(self) -> list[ApiToolBundle]: return [ApiToolBundle.model_validate(tool) for tool in json.loads(self.tools_str)] @property @@ -267,7 +269,7 @@ class WorkflowToolProvider(TypeBase): return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() @property - def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]: + def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: return [ WorkflowToolParameterConfiguration.model_validate(config) for config in json.loads(self.parameter_configuration) @@ -359,7 +361,7 @@ class MCPToolProvider(TypeBase): except (json.JSONDecodeError, TypeError): return [] - def to_entity(self) -> "MCPProviderEntity": + def to_entity(self) -> MCPProviderEntity: """Convert to domain entity""" from core.entities.mcp_provider import MCPProviderEntity @@ -533,5 +535,5 @@ class DeprecatedPublishedAppTool(TypeBase): ) @property - def description_i18n(self) -> "I18nObject": + def description_i18n(self) -> I18nObject: return I18nObject.model_validate(json.loads(self.description)) diff --git a/api/models/trigger.py b/api/models/trigger.py index 87e2a5ccfc..209345eb84 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -415,7 +415,7 @@ class AppTrigger(TypeBase): node_id: Mapped[str | None] = mapped_column(String(64), nullable=False) trigger_type: Mapped[str] = mapped_column(EnumText(AppTriggerType, length=50), nullable=False) title: Mapped[str] = mapped_column(String(255), nullable=False) - provider_name: Mapped[str] = mapped_column(String(255), server_default="", default="") # why it is nullable? 
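The unquoting of return types in provider.py and tools.py (`"ProviderType"` becoming `ProviderType`, and so on) is enabled by the `from __future__ import annotations` line added at the top of each file: with postponed evaluation of annotations (PEP 563), a method may name its own enclosing class without string quoting. A minimal illustration with a hypothetical class:

```python
from __future__ import annotations


class Node:
    # Without the __future__ import, this annotation would have to be the
    # string "Node", since the class is not yet bound when the method is
    # defined; with postponed evaluation it is stored unevaluated.
    def clone(self) -> Node:
        return Node()


print(type(Node().clone()).__name__)  # Node
```

The runtime behavior is unchanged; only how annotations are stored differs, which is why the diff can drop the quotes without touching any logic.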
+ provider_name: Mapped[str | None] = mapped_column(String(255), nullable=True, server_default="", default="") status: Mapped[str] = mapped_column( EnumText(AppTriggerStatus, length=50), nullable=False, default=AppTriggerStatus.ENABLED ) diff --git a/api/models/workflow.py b/api/models/workflow.py index 5522a16ab7..94e0881bd1 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -45,7 +45,7 @@ if TYPE_CHECKING: from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter -from core.variables import SecretVariable, Segment, SegmentType, Variable +from core.variables import SecretVariable, Segment, SegmentType, VariableBase from factories import variable_factory from libs import helper @@ -177,8 +177,8 @@ class Workflow(Base): # bug graph: str, features: str, created_by: str, - environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable], + environment_variables: Sequence[VariableBase], + conversation_variables: Sequence[VariableBase], rag_pipeline_variables: list[dict], marked_name: str = "", marked_comment: str = "", @@ -227,8 +227,7 @@ class Workflow(Base): # bug # # Currently, the following functions / methods would mutate the returned dict: # - # - `_get_graph_and_variable_pool_of_single_iteration`. - # - `_get_graph_and_variable_pool_of_single_loop`. + # - `_get_graph_and_variable_pool_for_single_node_run`. 
return json.loads(self.graph) if self.graph else {} def get_node_config_by_id(self, node_id: str) -> Mapping[str, Any]: @@ -451,7 +450,7 @@ class Workflow(Base): # bug # decrypt secret variables value def decrypt_func( - var: Variable, + var: VariableBase, ) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable: if isinstance(var, SecretVariable): return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) @@ -467,7 +466,7 @@ class Workflow(Base): # bug return decrypted_results @environment_variables.setter - def environment_variables(self, value: Sequence[Variable]): + def environment_variables(self, value: Sequence[VariableBase]): if not value: self._environment_variables = "{}" return @@ -491,7 +490,7 @@ class Workflow(Base): # bug value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) # encrypt secret variables value - def encrypt_func(var: Variable) -> Variable: + def encrypt_func(var: VariableBase) -> VariableBase: if isinstance(var, SecretVariable): return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) else: @@ -521,7 +520,7 @@ class Workflow(Base): # bug return result @property - def conversation_variables(self) -> Sequence[Variable]: + def conversation_variables(self) -> Sequence[VariableBase]: # TODO: find some way to init `self._conversation_variables` when instance created. 
if self._conversation_variables is None: self._conversation_variables = "{}" @@ -531,7 +530,7 @@ class Workflow(Base): # bug return results @conversation_variables.setter - def conversation_variables(self, value: Sequence[Variable]): + def conversation_variables(self, value: Sequence[VariableBase]): self._conversation_variables = json.dumps( {var.name: var.model_dump() for var in value}, ensure_ascii=False, @@ -601,6 +600,7 @@ class WorkflowRun(Base): __table_args__ = ( sa.PrimaryKeyConstraint("id", name="workflow_run_pkey"), sa.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), + sa.Index("workflow_run_created_at_id_idx", "created_at", "id"), ) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) @@ -793,11 +793,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo return ( PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), Index( - "workflow_node_execution_workflow_run_idx", - "tenant_id", - "app_id", - "workflow_id", - "triggered_from", + "workflow_node_execution_workflow_run_id_idx", "workflow_run_id", ), Index( @@ -1179,6 +1175,69 @@ class WorkflowAppLog(TypeBase): } +class WorkflowArchiveLog(TypeBase): + """ + Workflow archive log. + + Stores essential workflow run snapshot data for archived app logs. + + Field sources: + - Shared fields (tenant/app/workflow/run ids, created_by*): from WorkflowRun for consistency. + - log_* fields: from WorkflowAppLog when present; null if the run has no app log. + - run_* fields: workflow run snapshot fields from WorkflowRun. + - trigger_metadata: snapshot from WorkflowTriggerLog when present. 
+ """ + + __tablename__ = "workflow_archive_logs" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="workflow_archive_log_pkey"), + sa.Index("workflow_archive_log_app_idx", "tenant_id", "app_id"), + sa.Index("workflow_archive_log_workflow_run_id_idx", "workflow_run_id"), + sa.Index("workflow_archive_log_run_created_at_idx", "run_created_at"), + ) + + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False + ) + + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + + log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True) + + run_version: Mapped[str] = mapped_column(String(255), nullable=False) + run_status: Mapped[str] = mapped_column(String(255), nullable=False) + run_triggered_from: Mapped[str] = mapped_column(String(255), nullable=False) + run_error: Mapped[str | None] = mapped_column(LongText, nullable=True) + run_elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) + run_total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) + run_total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) + run_created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + run_finished_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + run_exceptions_count: Mapped[int] = mapped_column(sa.Integer, 
server_default=sa.text("0"), nullable=True) + + trigger_metadata: Mapped[str | None] = mapped_column(LongText, nullable=True) + archived_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + + @property + def workflow_run_summary(self) -> dict[str, Any]: + return { + "id": self.workflow_run_id, + "status": self.run_status, + "triggered_from": self.run_triggered_from, + "elapsed_time": self.run_elapsed_time, + "total_tokens": self.run_total_tokens, + } + + class ConversationVariable(TypeBase): __tablename__ = "workflow_conversation_variables" @@ -1194,7 +1253,7 @@ class ConversationVariable(TypeBase): ) @classmethod - def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> "ConversationVariable": + def from_variable(cls, *, app_id: str, conversation_id: str, variable: VariableBase) -> "ConversationVariable": obj = cls( id=variable.id, app_id=app_id, @@ -1203,7 +1262,7 @@ class ConversationVariable(TypeBase): ) return obj - def to_variable(self) -> Variable: + def to_variable(self) -> VariableBase: mapping = json.loads(self.data) return variable_factory.build_conversation_variable_from_mapping(mapping) @@ -1519,6 +1578,7 @@ class WorkflowDraftVariable(Base): file_id: str | None = None, ) -> "WorkflowDraftVariable": variable = WorkflowDraftVariable() + variable.id = str(uuid4()) variable.created_at = naive_utc_now() variable.updated_at = naive_utc_now() variable.description = description diff --git a/api/pyproject.toml b/api/pyproject.toml index dbc6a2eb83..575c1434c5 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.11.2" +version = "1.11.4" requires-python = ">=3.11,<3.13" dependencies = [ @@ -31,7 +31,7 @@ dependencies = [ "gunicorn~=23.0.0", "httpx[socks]~=0.27.0", "jieba==0.42.1", - "json-repair>=0.41.1", + "json-repair>=0.55.1", "jsonschema>=4.25.1", "langfuse~=2.51.3", "langsmith~=0.1.77", @@ -64,7 
+64,7 @@ dependencies = [ "pandas[excel,output-formatting,performance]~=2.2.2", "psycogreen~=1.0.2", "psycopg2-binary~=2.9.6", - "pycryptodome==3.19.1", + "pycryptodome==3.23.0", "pydantic~=2.11.4", "pydantic-extra-types~=2.10.3", "pydantic-settings~=2.11.0", @@ -93,6 +93,7 @@ dependencies = [ "weaviate-client==4.17.0", "apscheduler>=3.11.0", "weave>=0.52.16", + "fastopenapi[flask]>=0.7.0", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. @@ -189,7 +190,7 @@ storage = [ "opendal~=0.46.0", "oss2==2.18.5", "supabase~=2.18.1", - "tos~=2.7.1", + "tos~=2.9.0", ] ############################################################ diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index 6a689b96df..007c49ddb0 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -8,6 +8,7 @@ ], "typeCheckingMode": "strict", "allowedUntypedLibraries": [ + "fastopenapi", "flask_restx", "flask_login", "opentelemetry.instrumentation.celery", diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 641e8a488e..6446eb0d6e 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -14,8 +14,10 @@ from dataclasses import dataclass from datetime import datetime from typing import Protocol +from sqlalchemy.orm import Session + from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from models.workflow import WorkflowNodeExecutionModel +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload @dataclass(frozen=True) @@ -175,6 +177,18 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr """ ... + def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]: + """ + Count node executions and offloads for the given workflow run ids. + """ + ... 
+ + def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]: + """ + Delete node executions and offloads for the given workflow run ids. + """ + ... + def delete_executions_by_app( self, tenant_id: str, @@ -240,3 +254,23 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr The number of executions deleted """ ... + + def get_offloads_by_execution_ids( + self, + session: Session, + node_execution_ids: Sequence[str], + ) -> Sequence[WorkflowNodeExecutionOffload]: + """ + Get offload records by node execution IDs. + + This method retrieves workflow node execution offload records + that belong to the given node execution IDs. + + Args: + session: The database session to use + node_execution_ids: List of node execution IDs to filter by + + Returns: + A sequence of WorkflowNodeExecutionOffload instances + """ + ... diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index c1ed45812e..17e01a6e18 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -34,15 +34,18 @@ Example: ``` """ -from collections.abc import Sequence +from collections.abc import Callable, Sequence from datetime import datetime from typing import Protocol +from sqlalchemy.orm import Session + from core.workflow.entities.pause_reason import PauseReason +from core.workflow.enums import WorkflowType from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom -from models.workflow import WorkflowRun +from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.types import ( AverageInteractionStats, @@ -253,6 +256,151 @@ class 
APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): """ ... + def get_runs_batch_by_time_range( + self, + start_from: datetime | None, + end_before: datetime, + last_seen: tuple[datetime, str] | None, + batch_size: int, + run_types: Sequence[WorkflowType] | None = None, + tenant_ids: Sequence[str] | None = None, + ) -> Sequence[WorkflowRun]: + """ + Fetch ended workflow runs in a time window for archival and clean batching. + """ + ... + + def get_archived_run_ids( + self, + session: Session, + run_ids: Sequence[str], + ) -> set[str]: + """ + Fetch workflow run IDs that already have archive log records. + """ + ... + + def get_archived_logs_by_time_range( + self, + session: Session, + tenant_ids: Sequence[str] | None, + start_date: datetime, + end_date: datetime, + limit: int, + ) -> Sequence[WorkflowArchiveLog]: + """ + Fetch archived workflow logs by time range for restore. + """ + ... + + def get_archived_log_by_run_id( + self, + run_id: str, + ) -> WorkflowArchiveLog | None: + """ + Fetch a workflow archive log by workflow run ID. + """ + ... + + def delete_archive_log_by_run_id( + self, + session: Session, + run_id: str, + ) -> int: + """ + Delete archive log by workflow run ID. + + Used after restoring a workflow run to remove the archive log record, + allowing the run to be archived again if needed. + + Args: + session: Database session + run_id: Workflow run ID + + Returns: + Number of records deleted (0 or 1) + """ + ... + + def delete_runs_with_related( + self, + runs: Sequence[WorkflowRun], + delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, + delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, + ) -> dict[str, int]: + """ + Delete workflow runs and their related records (node executions, offloads, app logs, + trigger logs, pauses, pause reasons). + """ + ... 
+ + def get_pause_records_by_run_id( + self, + session: Session, + run_id: str, + ) -> Sequence[WorkflowPause]: + """ + Fetch workflow pause records by workflow run ID. + """ + ... + + def get_pause_reason_records_by_run_id( + self, + session: Session, + pause_ids: Sequence[str], + ) -> Sequence[WorkflowPauseReason]: + """ + Fetch workflow pause reason records by pause IDs. + """ + ... + + def get_app_logs_by_run_id( + self, + session: Session, + run_id: str, + ) -> Sequence[WorkflowAppLog]: + """ + Fetch workflow app logs by workflow run ID. + """ + ... + + def create_archive_logs( + self, + session: Session, + run: WorkflowRun, + app_logs: Sequence[WorkflowAppLog], + trigger_metadata: str | None, + ) -> int: + """ + Create archive log records for a workflow run. + """ + ... + + def get_archived_runs_by_time_range( + self, + session: Session, + tenant_ids: Sequence[str] | None, + start_date: datetime, + end_date: datetime, + limit: int, + ) -> Sequence[WorkflowRun]: + """ + Return workflow runs that already have archive logs, for cleanup of `workflow_runs`. + """ + ... + + def count_runs_with_related( + self, + runs: Sequence[WorkflowRun], + count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, + count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, + ) -> dict[str, int]: + """ + Count workflow runs and their related records (node executions, offloads, app logs, + trigger logs, pauses, pause reasons) without deleting data. + """ + ... 
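The repository interfaces these methods are added to are `typing.Protocol` classes, so any implementation with matching method signatures satisfies them structurally, with no inheritance required. A slimmed-down, hypothetical analogue of that arrangement:

```python
from collections.abc import Sequence
from typing import Protocol


class RunRepository(Protocol):
    # Hypothetical one-method analogue of APIWorkflowRunRepository.
    def get_archived_run_ids(self, run_ids: Sequence[str]) -> set[str]: ...


class InMemoryRunRepository:
    """Satisfies RunRepository structurally; note it does not inherit from it."""

    def __init__(self, archived: set[str]):
        self._archived = archived

    def get_archived_run_ids(self, run_ids: Sequence[str]) -> set[str]:
        return {r for r in run_ids if r in self._archived}


def count_archived(repo: RunRepository, run_ids: Sequence[str]) -> int:
    # Type checkers accept any structural match for the repo parameter.
    return len(repo.get_archived_run_ids(run_ids))


repo = InMemoryRunRepository({"r1", "r3"})
print(count_archived(repo, ["r1", "r2", "r3"]))  # 2
```

This is why the SQLAlchemy classes later in the diff can implement the new methods without any registration step: matching the signatures is the whole contract.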
+ def create_workflow_pause( self, workflow_run_id: str, diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 1512c162d7..e74e6e26f1 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -9,12 +9,12 @@ import json from collections.abc import Sequence from datetime import datetime -from sqlalchemy import asc, delete, desc, select +from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from models.workflow import WorkflowNodeExecutionModel +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload from repositories.api_workflow_node_execution_repository import ( DifyAPIWorkflowNodeExecutionRepository, WorkflowNodeExecutionSnapshot, @@ -369,3 +369,85 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut result = cast(CursorResult, session.execute(stmt)) session.commit() return result.rowcount + + def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]: + """ + Delete node executions (and offloads) for the given workflow runs using workflow_run_id. 
+ """ + if not run_ids: + return 0, 0 + + run_ids = list(run_ids) + run_id_filter = WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids) + node_execution_ids = select(WorkflowNodeExecutionModel.id).where(run_id_filter) + + offloads_deleted = ( + cast( + CursorResult, + session.execute( + delete(WorkflowNodeExecutionOffload).where( + WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids) + ) + ), + ).rowcount + or 0 + ) + + node_executions_deleted = ( + cast( + CursorResult, + session.execute(delete(WorkflowNodeExecutionModel).where(run_id_filter)), + ).rowcount + or 0 + ) + + return node_executions_deleted, offloads_deleted + + def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]: + """ + Count node executions (and offloads) for the given workflow runs using workflow_run_id. + """ + if not run_ids: + return 0, 0 + + run_ids = list(run_ids) + run_id_filter = WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids) + + node_executions_count = ( + session.scalar(select(func.count()).select_from(WorkflowNodeExecutionModel).where(run_id_filter)) or 0 + ) + node_execution_ids = select(WorkflowNodeExecutionModel.id).where(run_id_filter) + offloads_count = ( + session.scalar( + select(func.count()) + .select_from(WorkflowNodeExecutionOffload) + .where(WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids)) + ) + or 0 + ) + + return int(node_executions_count), int(offloads_count) + + @staticmethod + def get_by_run( + session: Session, + run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Fetch node executions for a run using workflow_run_id. 
+ """ + stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.workflow_run_id == run_id) + return list(session.scalars(stmt)) + + def get_offloads_by_execution_ids( + self, + session: Session, + node_execution_ids: Sequence[str], + ) -> Sequence[WorkflowNodeExecutionOffload]: + if not node_execution_ids: + return [] + + stmt = select(WorkflowNodeExecutionOffload).where( + WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids) + ) + return list(session.scalars(stmt)) diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index cbf495dce1..00cb979e17 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -22,7 +22,7 @@ Implementation Notes: import json import logging import uuid -from collections.abc import Sequence +from collections.abc import Callable, Sequence from datetime import datetime from decimal import Decimal from typing import Any, cast @@ -34,7 +34,7 @@ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from core.workflow.enums import WorkflowExecutionStatus +from core.workflow.enums import WorkflowExecutionStatus, WorkflowType from core.workflow.nodes.human_input.entities import FormDefinition from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now @@ -44,8 +44,7 @@ from libs.time_parser import get_time_threshold from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType -from models.workflow import WorkflowPause as WorkflowPauseModel -from models.workflow import WorkflowPauseReason, WorkflowRun +from models.workflow import WorkflowAppLog, 
WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.types import ( @@ -379,6 +378,335 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id) return total_deleted + def get_runs_batch_by_time_range( + self, + start_from: datetime | None, + end_before: datetime, + last_seen: tuple[datetime, str] | None, + batch_size: int, + run_types: Sequence[WorkflowType] | None = None, + tenant_ids: Sequence[str] | None = None, + ) -> Sequence[WorkflowRun]: + """ + Fetch ended workflow runs in a time window for archival and clean batching. + + Query scope: + - created_at in [start_from, end_before) + - type in run_types (when provided) + - status is an ended state + - optional tenant_id filter and cursor (last_seen) for pagination + """ + with self._session_maker() as session: + stmt = ( + select(WorkflowRun) + .where( + WorkflowRun.created_at < end_before, + WorkflowRun.status.in_(WorkflowExecutionStatus.ended_values()), + ) + .order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc()) + .limit(batch_size) + ) + if run_types is not None: + if not run_types: + return [] + stmt = stmt.where(WorkflowRun.type.in_(run_types)) + + if start_from: + stmt = stmt.where(WorkflowRun.created_at >= start_from) + + if tenant_ids: + stmt = stmt.where(WorkflowRun.tenant_id.in_(tenant_ids)) + + if last_seen: + stmt = stmt.where( + or_( + WorkflowRun.created_at > last_seen[0], + and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]), + ) + ) + + return session.scalars(stmt).all() + + def get_archived_run_ids( + self, + session: Session, + run_ids: Sequence[str], + ) -> set[str]: + if not run_ids: + return set() + + stmt = 
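The `last_seen` condition in `get_runs_batch_by_time_range` is keyset (cursor) pagination over `(created_at, id)` rather than `OFFSET`: the `or_`/`and_` pair is strict lexicographic comparison on that tuple, which the new `workflow_run_created_at_id_idx` index can serve directly. The predicate in plain Python, with hypothetical rows:

```python
from datetime import datetime


def after_cursor(
    created_at: datetime, run_id: str, last_seen: tuple[datetime, str]
) -> bool:
    # Same condition as the SQL:
    #   created_at > :seen_at OR (created_at = :seen_at AND id > :seen_id)
    # i.e. strict lexicographic order on the (created_at, id) keyset.
    seen_at, seen_id = last_seen
    return created_at > seen_at or (created_at == seen_at and run_id > seen_id)


rows = [
    (datetime(2024, 1, 1), "a"),
    (datetime(2024, 1, 1), "b"),
    (datetime(2024, 1, 2), "a"),
]
cursor = (datetime(2024, 1, 1), "a")
print([rid for ts, rid in rows if after_cursor(ts, rid, cursor)])  # ['b', 'a']
```

Unlike `OFFSET`, the cursor stays correct when earlier rows are deleted between batches, which matters here since the caller is archiving and deleting as it pages.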
select(WorkflowArchiveLog.workflow_run_id).where(WorkflowArchiveLog.workflow_run_id.in_(run_ids)) + return set(session.scalars(stmt).all()) + + def get_archived_log_by_run_id( + self, + run_id: str, + ) -> WorkflowArchiveLog | None: + with self._session_maker() as session: + stmt = select(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id).limit(1) + return session.scalar(stmt) + + def delete_archive_log_by_run_id( + self, + session: Session, + run_id: str, + ) -> int: + stmt = delete(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id) + result = session.execute(stmt) + return cast(CursorResult, result).rowcount or 0 + + def get_pause_records_by_run_id( + self, + session: Session, + run_id: str, + ) -> Sequence[WorkflowPause]: + stmt = select(WorkflowPause).where(WorkflowPause.workflow_run_id == run_id) + return list(session.scalars(stmt)) + + def get_pause_reason_records_by_run_id( + self, + session: Session, + pause_ids: Sequence[str], + ) -> Sequence[WorkflowPauseReason]: + if not pause_ids: + return [] + + stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids)) + return list(session.scalars(stmt)) + + def delete_runs_with_related( + self, + runs: Sequence[WorkflowRun], + delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, + delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, + ) -> dict[str, int]: + if not runs: + return { + "runs": 0, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + + with self._session_maker() as session: + run_ids = [run.id for run in runs] + if delete_node_executions: + node_executions_deleted, offloads_deleted = delete_node_executions(session, runs) + else: + node_executions_deleted, offloads_deleted = 0, 0 + + app_logs_result = session.execute(delete(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids))) + 
app_logs_deleted = cast(CursorResult, app_logs_result).rowcount or 0 + + pause_stmt = select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(run_ids)) + pause_ids = session.scalars(pause_stmt).all() + pause_reasons_deleted = 0 + pauses_deleted = 0 + + if pause_ids: + pause_reasons_result = session.execute( + delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids)) + ) + pause_reasons_deleted = cast(CursorResult, pause_reasons_result).rowcount or 0 + pauses_result = session.execute(delete(WorkflowPause).where(WorkflowPause.id.in_(pause_ids))) + pauses_deleted = cast(CursorResult, pauses_result).rowcount or 0 + + trigger_logs_deleted = delete_trigger_logs(session, run_ids) if delete_trigger_logs else 0 + + runs_result = session.execute(delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))) + runs_deleted = cast(CursorResult, runs_result).rowcount or 0 + + session.commit() + + return { + "runs": runs_deleted, + "node_executions": node_executions_deleted, + "offloads": offloads_deleted, + "app_logs": app_logs_deleted, + "trigger_logs": trigger_logs_deleted, + "pauses": pauses_deleted, + "pause_reasons": pause_reasons_deleted, + } + + def get_app_logs_by_run_id( + self, + session: Session, + run_id: str, + ) -> Sequence[WorkflowAppLog]: + stmt = select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == run_id) + return list(session.scalars(stmt)) + + def create_archive_logs( + self, + session: Session, + run: WorkflowRun, + app_logs: Sequence[WorkflowAppLog], + trigger_metadata: str | None, + ) -> int: + if not app_logs: + archive_log = WorkflowArchiveLog( + log_id=None, + log_created_at=None, + log_created_from=None, + tenant_id=run.tenant_id, + app_id=run.app_id, + workflow_id=run.workflow_id, + workflow_run_id=run.id, + created_by_role=run.created_by_role, + created_by=run.created_by, + run_version=run.version, + run_status=run.status, + run_triggered_from=run.triggered_from, + run_error=run.error, + 
run_elapsed_time=run.elapsed_time, + run_total_tokens=run.total_tokens, + run_total_steps=run.total_steps, + run_created_at=run.created_at, + run_finished_at=run.finished_at, + run_exceptions_count=run.exceptions_count, + trigger_metadata=trigger_metadata, + ) + session.add(archive_log) + return 1 + + archive_logs = [ + WorkflowArchiveLog( + log_id=app_log.id, + log_created_at=app_log.created_at, + log_created_from=app_log.created_from, + tenant_id=run.tenant_id, + app_id=run.app_id, + workflow_id=run.workflow_id, + workflow_run_id=run.id, + created_by_role=run.created_by_role, + created_by=run.created_by, + run_version=run.version, + run_status=run.status, + run_triggered_from=run.triggered_from, + run_error=run.error, + run_elapsed_time=run.elapsed_time, + run_total_tokens=run.total_tokens, + run_total_steps=run.total_steps, + run_created_at=run.created_at, + run_finished_at=run.finished_at, + run_exceptions_count=run.exceptions_count, + trigger_metadata=trigger_metadata, + ) + for app_log in app_logs + ] + session.add_all(archive_logs) + return len(archive_logs) + + def get_archived_runs_by_time_range( + self, + session: Session, + tenant_ids: Sequence[str] | None, + start_date: datetime, + end_date: datetime, + limit: int, + ) -> Sequence[WorkflowRun]: + """ + Retrieves WorkflowRun records by joining workflow_archive_logs. + + Used to identify runs that are already archived and ready for deletion. 
+ """ + stmt = ( + select(WorkflowRun) + .join(WorkflowArchiveLog, WorkflowArchiveLog.workflow_run_id == WorkflowRun.id) + .where( + WorkflowArchiveLog.run_created_at >= start_date, + WorkflowArchiveLog.run_created_at < end_date, + ) + .order_by(WorkflowArchiveLog.run_created_at.asc(), WorkflowArchiveLog.workflow_run_id.asc()) + .limit(limit) + ) + if tenant_ids: + stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids)) + return list(session.scalars(stmt)) + + def get_archived_logs_by_time_range( + self, + session: Session, + tenant_ids: Sequence[str] | None, + start_date: datetime, + end_date: datetime, + limit: int, + ) -> Sequence[WorkflowArchiveLog]: + # Returns WorkflowArchiveLog rows directly; use this when workflow_runs may be deleted. + stmt = ( + select(WorkflowArchiveLog) + .where( + WorkflowArchiveLog.run_created_at >= start_date, + WorkflowArchiveLog.run_created_at < end_date, + ) + .order_by(WorkflowArchiveLog.run_created_at.asc(), WorkflowArchiveLog.workflow_run_id.asc()) + .limit(limit) + ) + if tenant_ids: + stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids)) + return list(session.scalars(stmt)) + + def count_runs_with_related( + self, + runs: Sequence[WorkflowRun], + count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, + count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, + ) -> dict[str, int]: + if not runs: + return { + "runs": 0, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + + with self._session_maker() as session: + run_ids = [run.id for run in runs] + if count_node_executions: + node_executions_count, offloads_count = count_node_executions(session, runs) + else: + node_executions_count, offloads_count = 0, 0 + + app_logs_count = ( + session.scalar( + select(func.count()).select_from(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids)) + ) + or 0 + ) + + pause_ids = 
session.scalars( + select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(run_ids)) + ).all() + pauses_count = len(pause_ids) + pause_reasons_count = 0 + if pause_ids: + pause_reasons_count = ( + session.scalar( + select(func.count()) + .select_from(WorkflowPauseReason) + .where(WorkflowPauseReason.pause_id.in_(pause_ids)) + ) + or 0 + ) + + trigger_logs_count = count_trigger_logs(session, run_ids) if count_trigger_logs else 0 + + return { + "runs": len(runs), + "node_executions": node_executions_count, + "offloads": offloads_count, + "app_logs": int(app_logs_count), + "trigger_logs": trigger_logs_count, + "pauses": pauses_count, + "pause_reasons": int(pause_reasons_count), + } + def create_workflow_pause( self, workflow_run_id: str, @@ -405,9 +733,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): ValueError: If workflow_run_id is invalid or workflow run doesn't exist RuntimeError: If workflow is already paused or in invalid state """ - previous_pause_model_query = select(WorkflowPauseModel).where( - WorkflowPauseModel.workflow_run_id == workflow_run_id - ) + previous_pause_model_query = select(WorkflowPause).where(WorkflowPause.workflow_run_id == workflow_run_id) with self._session_maker() as session, session.begin(): # Get the workflow run workflow_run = session.get(WorkflowRun, workflow_run_id) @@ -434,7 +760,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): # Upload the state file # Create the pause record - pause_model = WorkflowPauseModel() + pause_model = WorkflowPause() pause_model.id = str(uuidv7()) pause_model.workflow_id = workflow_run.workflow_id pause_model.workflow_run_id = workflow_run.id @@ -643,13 +969,13 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): """ with self._session_maker() as session, session.begin(): # Get the pause model by ID - pause_model = session.get(WorkflowPauseModel, pause_entity.id) + pause_model = session.get(WorkflowPause, 
pause_entity.id) if pause_model is None: raise _WorkflowRunError(f"WorkflowPause not found: {pause_entity.id}") self._delete_pause_model(session, pause_model) @staticmethod - def _delete_pause_model(session: Session, pause_model: WorkflowPauseModel): + def _delete_pause_model(session: Session, pause_model: WorkflowPause): storage.delete(pause_model.state_object_key) # Delete the pause record @@ -684,15 +1010,15 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): _limit: int = limit or 1000 pruned_record_ids: list[str] = [] cond = or_( - WorkflowPauseModel.created_at < expiration, + WorkflowPause.created_at < expiration, and_( - WorkflowPauseModel.resumed_at.is_not(null()), - WorkflowPauseModel.resumed_at < resumption_expiration, + WorkflowPause.resumed_at.is_not(null()), + WorkflowPause.resumed_at < resumption_expiration, ), ) # First, collect pause records to delete with their state files # Expired pauses (created before expiration time) - stmt = select(WorkflowPauseModel).where(cond).limit(_limit) + stmt = select(WorkflowPause).where(cond).limit(_limit) with self._session_maker(expire_on_commit=False) as session: # Old resumed pauses (resumed more than resumption_duration ago) @@ -703,7 +1029,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): # Delete state files from storage for pause in pauses_to_delete: with self._session_maker(expire_on_commit=False) as session, session.begin(): - # todo: this issues a separate query for each WorkflowPauseModel record. + # todo: this issues a separate query for each WorkflowPause record. # consider batching this lookup. 
try: storage.delete(pause.state_object_key) @@ -964,7 +1290,7 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity): def __init__( self, *, - pause_model: WorkflowPauseModel, + pause_model: WorkflowPause, reason_models: Sequence[WorkflowPauseReason], pause_reasons: Sequence[PauseReason] | None = None, human_input_form: Sequence = (), diff --git a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py index c828cc60c2..1f6740b066 100644 --- a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py +++ b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @@ -4,8 +4,10 @@ SQLAlchemy implementation of WorkflowTriggerLogRepository. from collections.abc import Sequence from datetime import UTC, datetime, timedelta +from typing import cast -from sqlalchemy import and_, select +from sqlalchemy import and_, delete, func, select +from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session from models.enums import WorkflowTriggerStatus @@ -44,6 +46,11 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): return self.session.scalar(query) + def list_by_run_id(self, run_id: str) -> Sequence[WorkflowTriggerLog]: + """List trigger logs for a workflow run.""" + query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id == run_id) + return list(self.session.scalars(query).all()) + def get_failed_for_retry( self, tenant_id: str, max_retry_count: int = 3, limit: int = 100 ) -> Sequence[WorkflowTriggerLog]: @@ -94,3 +101,37 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): .limit(1) ) return self.session.scalar(query) + + def delete_by_run_ids(self, run_ids: Sequence[str]) -> int: + """ + Delete trigger logs associated with the given workflow run ids. + + Args: + run_ids: Collection of workflow run identifiers. + + Returns: + Number of rows deleted. 
+ """ + if not run_ids: + return 0 + + result = self.session.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids))) + return cast(CursorResult, result).rowcount or 0 + + def count_by_run_ids(self, run_ids: Sequence[str]) -> int: + """ + Count trigger logs associated with the given workflow run ids. + + Args: + run_ids: Collection of workflow run identifiers. + + Returns: + Number of rows matched. + """ + if not run_ids: + return 0 + + count = self.session.scalar( + select(func.count()).select_from(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids)) + ) + return int(count or 0) diff --git a/api/repositories/workflow_trigger_log_repository.py b/api/repositories/workflow_trigger_log_repository.py index e78f1db532..7f9e6b7b68 100644 --- a/api/repositories/workflow_trigger_log_repository.py +++ b/api/repositories/workflow_trigger_log_repository.py @@ -121,3 +121,15 @@ class WorkflowTriggerLogRepository(Protocol): The matching WorkflowTriggerLog if present, None otherwise """ ... + + def delete_by_run_ids(self, run_ids: Sequence[str]) -> int: + """ + Delete trigger logs for workflow run IDs. + + Args: + run_ids: Workflow run IDs to delete + + Returns: + Number of rows deleted + """ + ... 
diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index 352a84b592..be5f483b95 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -1,90 +1,78 @@ -import datetime import logging import time import click -from sqlalchemy.exc import SQLAlchemyError +from redis.exceptions import LockError import app from configs import dify_config -from enums.cloud_plan import CloudPlan -from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.model import ( - App, - Message, - MessageAgentThought, - MessageAnnotation, - MessageChain, - MessageFeedback, - MessageFile, -) -from models.web import SavedMessage -from services.feature_service import FeatureService +from services.retention.conversation.messages_clean_policy import create_message_clean_policy +from services.retention.conversation.messages_clean_service import MessagesCleanService logger = logging.getLogger(__name__) -@app.celery.task(queue="dataset") +@app.celery.task(queue="retention") def clean_messages(): - click.echo(click.style("Start clean messages.", fg="green")) - start_at = time.perf_counter() - plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta( - days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING - ) - while True: - try: - # Main query with join and filter - messages = ( - db.session.query(Message) - .where(Message.created_at < plan_sandbox_clean_message_day) - .order_by(Message.created_at.desc()) - .limit(100) - .all() - ) + """ + Clean expired messages based on clean policy. 
- except SQLAlchemyError: - raise - if not messages: - break - for message in messages: - app = db.session.query(App).filter_by(id=message.app_id).first() - if not app: - logger.warning( - "Expected App record to exist, but none was found, app_id=%s, message_id=%s", - message.app_id, - message.id, - ) - continue - features_cache_key = f"features:{app.tenant_id}" - plan_cache = redis_client.get(features_cache_key) - if plan_cache is None: - features = FeatureService.get_features(app.tenant_id) - redis_client.setex(features_cache_key, 600, features.billing.subscription.plan) - plan = features.billing.subscription.plan - else: - plan = plan_cache.decode() - if plan == CloudPlan.SANDBOX: - # clean related message - db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(Message).where(Message.id == message.id).delete() - db.session.commit() - end_at = time.perf_counter() - click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green")) + This task uses MessagesCleanService to efficiently clean messages in batches. 
+ The behavior depends on BILLING_ENABLED configuration: + - BILLING_ENABLED=True: only delete messages from sandbox tenants (with whitelist/grace period) + - BILLING_ENABLED=False: delete all messages within the time range + """ + click.echo(click.style("clean_messages: start clean messages.", fg="green")) + start_at = time.perf_counter() + + try: + # Create policy based on billing configuration + policy = create_message_clean_policy( + graceful_period_days=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD, + ) + + # Create and run the cleanup service + # lock the task to avoid concurrent execution in case of the future data volume growth + with redis_client.lock( + "retention:clean_messages", timeout=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL, blocking=False + ): + service = MessagesCleanService.from_days( + policy=policy, + days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS, + batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE, + ) + stats = service.run() + + end_at = time.perf_counter() + click.echo( + click.style( + f"clean_messages: completed successfully\n" + f" - Latency: {end_at - start_at:.2f}s\n" + f" - Batches processed: {stats['batches']}\n" + f" - Total messages scanned: {stats['total_messages']}\n" + f" - Messages filtered: {stats['filtered_messages']}\n" + f" - Messages deleted: {stats['total_deleted']}", + fg="green", + ) + ) + except LockError: + end_at = time.perf_counter() + logger.exception("clean_messages: acquire task lock failed, skip current execution") + click.echo( + click.style( + f"clean_messages: skipped (lock already held) - latency: {end_at - start_at:.2f}s", + fg="yellow", + ) + ) + raise + except Exception as e: + end_at = time.perf_counter() + logger.exception("clean_messages failed") + click.echo( + click.style( + f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}", + fg="red", + ) + ) + raise diff --git a/api/schedule/clean_workflow_runs_task.py 
b/api/schedule/clean_workflow_runs_task.py new file mode 100644 index 0000000000..ff45a3ddf2 --- /dev/null +++ b/api/schedule/clean_workflow_runs_task.py @@ -0,0 +1,79 @@ +import logging +from datetime import UTC, datetime + +import click +from redis.exceptions import LockError + +import app +from configs import dify_config +from extensions.ext_redis import redis_client +from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup + +logger = logging.getLogger(__name__) + + +@app.celery.task(queue="retention") +def clean_workflow_runs_task() -> None: + """ + Scheduled cleanup for workflow runs and related records (sandbox tenants only). + """ + click.echo( + click.style( + ( + "Scheduled workflow run cleanup starting: " + f"cutoff={dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS} days, " + f"batch={dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE}" + ), + fg="green", + ) + ) + + start_time = datetime.now(UTC) + + try: + # lock the task to avoid concurrent execution in case of the future data volume growth + with redis_client.lock( + "retention:clean_workflow_runs_task", + timeout=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL, + blocking=False, + ): + WorkflowRunCleanup( + days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS, + batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE, + start_from=None, + end_before=None, + ).run() + + end_time = datetime.now(UTC) + elapsed = end_time - start_time + click.echo( + click.style( + f"Scheduled workflow run cleanup finished. start={start_time.isoformat()} " + f"end={end_time.isoformat()} duration={elapsed}", + fg="green", + ) + ) + except LockError: + end_time = datetime.now(UTC) + elapsed = end_time - start_time + logger.exception("clean_workflow_runs_task: acquire task lock failed, skip current execution") + click.echo( + click.style( + f"Scheduled workflow run cleanup skipped (lock already held). 
" + f"start={start_time.isoformat()} end={end_time.isoformat()} duration={elapsed}", + fg="yellow", + ) + ) + raise + except Exception as e: + end_time = datetime.now(UTC) + elapsed = end_time - start_time + logger.exception("clean_workflow_runs_task failed") + click.echo( + click.style( + f"Scheduled workflow run cleanup failed. start={start_time.isoformat()} " + f"end={end_time.isoformat()} duration={elapsed} - {str(e)}", + fg="red", + ) + ) + raise diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index c343063fae..ed46c1c70a 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -50,10 +50,13 @@ def create_clusters(batch_size): ) for new_cluster in new_clusters: tidb_auth_binding = TidbAuthBinding( + tenant_id=None, cluster_id=new_cluster["cluster_id"], cluster_name=new_cluster["cluster_name"], account=new_cluster["account"], password=new_cluster["password"], + active=False, + status="CREATING", ) db.session.add(tidb_auth_binding) db.session.commit() diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py index db610df290..77d6b5a138 100644 --- a/api/schedule/queue_monitor_task.py +++ b/api/schedule/queue_monitor_task.py @@ -16,6 +16,11 @@ celery_redis = Redis( port=redis_config.get("port") or 6379, password=redis_config.get("password") or None, db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1, + ssl=bool(dify_config.BROKER_USE_SSL), + ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS if dify_config.BROKER_USE_SSL else None, + ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None, + ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None, + ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None, ) logger = logging.getLogger(__name__) diff --git 
a/api/services/account_service.py b/api/services/account_service.py index 5a549dc318..35e4a505af 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -8,7 +8,7 @@ from hashlib import sha256 from typing import Any, cast from pydantic import BaseModel -from sqlalchemy import func +from sqlalchemy import func, select from sqlalchemy.orm import Session from werkzeug.exceptions import Unauthorized @@ -748,6 +748,21 @@ class AccountService: cls.email_code_login_rate_limiter.increment_rate_limit(email) return token + @staticmethod + def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None: + """ + Retrieve an account by email and fall back to the lowercase email if the original lookup fails. + + This keeps backward compatibility for older records that stored uppercase emails while the + rest of the system gradually normalizes new inputs. + """ + query_session = session or db.session + account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + if account or email == email.lower(): + return account + + return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() + @classmethod def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: return TokenManager.get_token_data(token, "email_code_login") @@ -999,6 +1014,11 @@ class TenantService: tenant.encrypt_public_key = generate_key_pair(tenant.id) db.session.commit() + + from services.credit_pool_service import CreditPoolService + + CreditPoolService.create_default_pool(tenant.id) + return tenant @staticmethod @@ -1358,16 +1378,27 @@ class RegisterService: if not inviter: raise ValueError("Inviter is required") + normalized_email = email.lower() + """Invite new member""" + # Check workspace permission for member invitations + from libs.workspace_permission import check_workspace_member_invite_permission + + check_workspace_member_invite_permission(tenant.id) + 
with Session(db.engine) as session: - account = session.query(Account).filter_by(email=email).first() + account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if not account: TenantService.check_member_permission(tenant, inviter, None, "add") - name = email.split("@")[0] + name = normalized_email.split("@")[0] account = cls.register( - email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True + email=normalized_email, + name=name, + language=language, + status=AccountStatus.PENDING, + is_setup=True, ) # Create new tenant member for invited tenant TenantService.create_tenant_member(tenant, account, role) @@ -1389,7 +1420,7 @@ class RegisterService: # send email send_invite_member_mail_task.delay( language=language, - to=email, + to=account.email, token=token, inviter_name=inviter.name if inviter else "Dify", workspace_name=tenant.name, @@ -1488,6 +1519,16 @@ class RegisterService: invitation: dict = json.loads(data) return invitation + @classmethod + def get_invitation_with_case_fallback( + cls, workspace_id: str | None, email: str | None, token: str + ) -> dict[str, Any] | None: + invitation = cls.get_invitation_if_token_valid(workspace_id, email, token) + if invitation or not email or email == email.lower(): + return invitation + normalized_email = email.lower() + return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token) + def _generate_refresh_token(length: int = 64): token = secrets.token_hex(length) diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index d03cbddceb..56e9cc6a00 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -77,7 +77,7 @@ class AppAnnotationService: if annotation_setting: add_annotation_to_index_task.delay( annotation.id, - annotation.question, + question, current_tenant_id, app_id, annotation_setting.collection_binding_id, @@ -137,13 +137,16 @@ class AppAnnotationService: if not 
app: raise NotFound("App not found") if keyword: + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(keyword) stmt = ( select(MessageAnnotation) .where(MessageAnnotation.app_id == app_id) .where( or_( - MessageAnnotation.question.ilike(f"%{keyword}%"), - MessageAnnotation.content.ilike(f"%{keyword}%"), + MessageAnnotation.question.ilike(f"%{escaped_keyword}%", escape="\\"), + MessageAnnotation.content.ilike(f"%{escaped_keyword}%", escape="\\"), ) ) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) @@ -206,8 +209,12 @@ class AppAnnotationService: if not app: raise NotFound("App not found") + question = args.get("question") + if question is None: + raise ValueError("'question' is required") + annotation = MessageAnnotation( - app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id + app_id=app.id, content=args["answer"], question=question, account_id=current_user.id ) db.session.add(annotation) db.session.commit() @@ -216,7 +223,7 @@ class AppAnnotationService: if annotation_setting: add_annotation_to_index_task.delay( annotation.id, - args["question"], + question, current_tenant_id, app_id, annotation_setting.collection_binding_id, @@ -241,8 +248,12 @@ class AppAnnotationService: if not annotation: raise NotFound("Annotation not found") + question = args.get("question") + if question is None: + raise ValueError("'question' is required") + annotation.content = args["answer"] - annotation.question = args["question"] + annotation.question = question db.session.commit() # if annotation reply is enabled , add annotation to index @@ -253,7 +264,7 @@ class AppAnnotationService: if app_annotation_setting: update_annotation_to_index_task.delay( annotation.id, - annotation.question, + annotation.question_text, current_tenant_id, app_id, app_annotation_setting.collection_binding_id, diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 
deba0b79e8..0f42c99246 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -32,7 +32,7 @@ from extensions.ext_redis import redis_client from factories import variable_factory from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode -from models.model import AppModelConfig +from models.model import AppModelConfig, IconType from models.workflow import Workflow from services.plugin.dependencies_analysis import DependenciesAnalysisService from services.workflow_draft_variable_service import WorkflowDraftVariableService @@ -428,10 +428,10 @@ class AppDslService: # Set icon type icon_type_value = icon_type or app_data.get("icon_type") - if icon_type_value in ["emoji", "link", "image"]: + if icon_type_value in [IconType.EMOJI, IconType.IMAGE, IconType.LINK]: icon_type = icon_type_value else: - icon_type = "emoji" + icon_type = IconType.EMOJI icon = icon or str(app_data.get("icon", "")) if app: @@ -521,12 +521,10 @@ class AppDslService: raise ValueError("Missing model_config for chat/agent-chat/completion app") # Initialize or update model config if not app.app_model_config: - app_model_config = AppModelConfig().from_model_config_dict(model_config) + app_model_config = AppModelConfig( + app_id=app.id, created_by=account.id, updated_by=account.id + ).from_model_config_dict(model_config) app_model_config.id = str(uuid4()) - app_model_config.app_id = app.id - app_model_config.created_by = account.id - app_model_config.updated_by = account.id - app.app_model_config_id = app_model_config.id self._session.add(app_model_config) @@ -783,15 +781,16 @@ class AppDslService: return dependencies @classmethod - def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]: + def get_leaked_dependencies( + cls, tenant_id: str, dsl_dependencies: list[PluginDependency] + ) -> list[PluginDependency]: """ Returns the leaked dependencies in current workspace """ - dependencies = 
[PluginDependency.model_validate(dep) for dep in dsl_dependencies] - if not dependencies: + if not dsl_dependencies: return [] - return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies) + return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dsl_dependencies) @staticmethod def _generate_aes_key(tenant_id: str) -> bytes: diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index ce64f6ac84..a3c3471982 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import logging import threading import uuid from collections.abc import Callable, Generator, Mapping -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Union from configs import dify_config from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator @@ -18,7 +20,8 @@ from enums.quota_type import QuotaType, unlimited from extensions.otel import AppGenerateHandler, trace_span from models.model import Account, App, AppMode, EndUser from models.workflow import Workflow, WorkflowRun -from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError +from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError +from services.errors.llm import InvokeRateLimitError from services.workflow_service import WorkflowService from tasks.app_generate.workflow_execute_task import AppExecutionParams, chatflow_execute_task @@ -26,6 +29,9 @@ logger = logging.getLogger(__name__) SSE_TASK_START_FALLBACK_MS = 200 +if TYPE_CHECKING: + from controllers.console.app.workflow import LoopNodeRunPayload + class AppGenerateService: @staticmethod @@ -243,7 +249,9 @@ class AppGenerateService: raise ValueError(f"Invalid app mode {app_model.mode}") @classmethod - def generate_single_loop(cls, app_model: App, 
user: Account, node_id: str, args: Any, streaming: bool = True): + def generate_single_loop( + cls, app_model: App, user: Account, node_id: str, args: LoopNodeRunPayload, streaming: bool = True + ): if app_model.mode == AppMode.ADVANCED_CHAT: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator.convert_to_event_stream( diff --git a/api/services/app_service.py b/api/services/app_service.py index ef89a4fd10..af458ff618 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -55,8 +55,11 @@ class AppService: if args.get("is_created_by_me", False): filters.append(App.created_by == user_id) if args.get("name"): + from libs.helper import escape_like_pattern + name = args["name"][:30] - filters.append(App.name.ilike(f"%{name}%")) + escaped_name = escape_like_pattern(name) + filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\")) # Check if tag_ids is not empty to avoid WHERE false condition if args.get("tag_ids") and len(args["tag_ids"]) > 0: target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"]) @@ -147,10 +150,9 @@ class AppService: db.session.flush() if default_model_config: - app_model_config = AppModelConfig(**default_model_config) - app_model_config.app_id = app.id - app_model_config.created_by = account.id - app_model_config.updated_by = account.id + app_model_config = AppModelConfig( + **default_model_config, app_id=app.id, created_by=account.id, updated_by=account.id + ) db.session.add(app_model_config) db.session.flush() diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index e100582511..bc73b7c8c2 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -21,7 +21,7 @@ from models.model import App, EndUser from models.trigger import WorkflowTriggerLog from models.workflow import Workflow from repositories.sqlalchemy_workflow_trigger_log_repository import 
SQLAlchemyWorkflowTriggerLogRepository -from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowNotFoundError +from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority from services.workflow_service import WorkflowService @@ -141,7 +141,7 @@ class AsyncWorkflowService: trigger_log_repo.update(trigger_log) session.commit() - raise InvokeRateLimitError( + raise WorkflowQuotaLimitError( f"Workflow execution quota limit reached for tenant {trigger_data.tenant_id}" ) from e diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 3d7cb6cc8d..946b8cdfdb 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,3 +1,4 @@ +import json import logging import os from collections.abc import Sequence @@ -31,6 +32,11 @@ class BillingService: compliance_download_rate_limiter = RateLimiter("compliance_download_rate_limiter", 4, 60) + # Redis key prefix for tenant plan cache + _PLAN_CACHE_KEY_PREFIX = "tenant_plan:" + # Cache TTL: 10 minutes + _PLAN_CACHE_TTL = 600 + @classmethod def get_info(cls, tenant_id: str): params = {"tenant_id": tenant_id} @@ -125,7 +131,7 @@ class BillingService: headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" - response = httpx.request(method, url, json=json, params=params, headers=headers) + response = httpx.request(method, url, json=json, params=params, headers=headers, follow_redirects=True) if method == "GET" and response.status_code != httpx.codes.OK: raise ValueError("Unable to retrieve billing information. 
Please try again later or contact support.") if method == "PUT": @@ -137,6 +143,9 @@ class BillingService: raise ValueError("Invalid arguments.") if method == "POST" and response.status_code != httpx.codes.OK: raise ValueError(f"Unable to send request to {url}. Please try again later or contact support.") + if method == "DELETE" and response.status_code != httpx.codes.OK: + logger.error("billing_service: DELETE response: %s %s", response.status_code, response.text) + raise ValueError(f"Unable to process delete request {url}. Please try again later or contact support.") return response.json() @staticmethod @@ -159,7 +168,7 @@ class BillingService: def delete_account(cls, account_id: str): """Delete account.""" params = {"account_id": account_id} - return cls._send_request("DELETE", "/account/", params=params) + return cls._send_request("DELETE", "/account", params=params) @classmethod def is_email_in_freeze(cls, email: str) -> bool: @@ -272,14 +281,110 @@ class BillingService: data = resp.get("data", {}) for tenant_id, plan in data.items(): - subscription_plan = subscription_adapter.validate_python(plan) - results[tenant_id] = subscription_plan + try: + subscription_plan = subscription_adapter.validate_python(plan) + results[tenant_id] = subscription_plan + except Exception: + logger.exception( + "get_plan_bulk: failed to validate subscription plan for tenant(%s)", tenant_id + ) + continue except Exception: - logger.exception("Failed to fetch billing info batch for tenants: %s", chunk) + logger.exception("get_plan_bulk: failed to fetch billing info batch for tenants: %s", chunk) continue return results + @classmethod + def _make_plan_cache_key(cls, tenant_id: str) -> str: + return f"{cls._PLAN_CACHE_KEY_PREFIX}{tenant_id}" + + @classmethod + def get_plan_bulk_with_cache(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]: + """ + Bulk fetch billing subscription plan with cache to reduce billing API loads in batch job scenarios. 
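Stripped of logging and error fallbacks, `get_plan_bulk_with_cache` is a cache-aside pattern: batch-read the cache, fetch only the misses from the billing API, then write the fresh plans back for the next caller. A minimal sketch of that flow, with a plain dict standing in for Redis and an injected fetcher (the names here are illustrative, not the real API):

```python
def get_plans_with_cache(tenant_ids, cache, fetch_bulk):
    """Cache-aside bulk lookup: read every key first, fetch only the
    misses, then write fresh values back for subsequent callers."""
    plans = {}
    misses = []
    for tid in tenant_ids:
        cached = cache.get(f"tenant_plan:{tid}")  # mget in the real code
        if cached is not None:
            plans[tid] = cached
        else:
            misses.append(tid)
    if misses:
        # Analogous to BillingService.get_plan_bulk(cache_misses).
        for tid, plan in fetch_bulk(misses).items():
            plans[tid] = plan
            cache[f"tenant_plan:{tid}"] = plan  # pipelined setex with TTL in the real code
    return plans
```

Only the missing tenants hit the backend; a second call with the same ids is served entirely from the cache until the TTL expires.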
+ + NOTE: if you need high data consistency, use get_plan_bulk instead. + + Returns: + Mapping of tenant_id -> {plan: str, expiration_date: int} + """ + tenant_plans: dict[str, SubscriptionPlan] = {} + + if not tenant_ids: + return tenant_plans + + subscription_adapter = TypeAdapter(SubscriptionPlan) + + # Step 1: Batch fetch from Redis cache using mget + redis_keys = [cls._make_plan_cache_key(tenant_id) for tenant_id in tenant_ids] + try: + cached_values = redis_client.mget(redis_keys) + + if len(cached_values) != len(tenant_ids): + raise Exception( + "get_plan_bulk_with_cache: unexpected error: redis mget failed: cached values length mismatch" + ) + + # Map cached values back to tenant_ids + cache_misses: list[str] = [] + + for tenant_id, cached_value in zip(tenant_ids, cached_values): + if cached_value: + try: + # Redis returns bytes, decode to string and parse JSON + json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value + plan_dict = json.loads(json_str) + subscription_plan = subscription_adapter.validate_python(plan_dict) + tenant_plans[tenant_id] = subscription_plan + except Exception: + logger.exception( + "get_plan_bulk_with_cache: process tenant(%s) failed, add to cache misses", tenant_id + ) + cache_misses.append(tenant_id) + else: + cache_misses.append(tenant_id) + + logger.info( + "get_plan_bulk_with_cache: cache hits=%s, cache misses=%s", + len(tenant_plans), + len(cache_misses), + ) + except Exception: + logger.exception("get_plan_bulk_with_cache: redis mget failed, falling back to API") + cache_misses = list(tenant_ids) + + # Step 2: Fetch missing plans from billing API + if cache_misses: + bulk_plans = BillingService.get_plan_bulk(cache_misses) + + if bulk_plans: + plans_to_cache: dict[str, SubscriptionPlan] = {} + + for tenant_id, subscription_plan in bulk_plans.items(): + tenant_plans[tenant_id] = subscription_plan + plans_to_cache[tenant_id] = subscription_plan + + # Step 3: Batch update Redis cache using
pipeline + if plans_to_cache: + try: + pipe = redis_client.pipeline() + for tenant_id, subscription_plan in plans_to_cache.items(): + redis_key = cls._make_plan_cache_key(tenant_id) + # Serialize dict to JSON string + json_str = json.dumps(subscription_plan) + pipe.setex(redis_key, cls._PLAN_CACHE_TTL, json_str) + pipe.execute() + + logger.info( + "get_plan_bulk_with_cache: cached %s new tenant plans to Redis", + len(plans_to_cache), + ) + except Exception: + logger.exception("get_plan_bulk_with_cache: redis pipeline failed") + + return tenant_plans + @classmethod def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]: resp = cls._send_request("GET", "/subscription/cleanup/whitelist") diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 659e7406fb..295d48d8a1 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -11,13 +11,13 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator from core.variables.types import SegmentType -from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from extensions.ext_database import db from factories import variable_factory from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable from models.model import App, Conversation, EndUser, Message +from services.conversation_variable_updater import ConversationVariableUpdater from services.errors.conversation import ( ConversationNotExistsError, ConversationVariableNotExistsError, @@ -218,7 +218,9 @@ class ConversationService: # Apply variable_name filter if provided if variable_name: # Filter using JSON extraction to match variable names case-insensitively - escaped_variable_name = variable_name.replace("\\", 
"\\\\").replace("%", "\\%").replace("_", "\\_") + from libs.helper import escape_like_pattern + + escaped_variable_name = escape_like_pattern(variable_name) # Filter using JSON extraction to match variable names case-insensitively if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]: stmt = stmt.where( @@ -335,7 +337,7 @@ class ConversationService: updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict) # Use the conversation variable updater to persist the changes - updater = conversation_variable_updater_factory() + updater = ConversationVariableUpdater(session_factory.get_session_maker()) updater.update(conversation_id, updated_variable) updater.flush() diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py new file mode 100644 index 0000000000..92008d5ff1 --- /dev/null +++ b/api/services/conversation_variable_updater.py @@ -0,0 +1,28 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session, sessionmaker + +from core.variables.variables import VariableBase +from models import ConversationVariable + + +class ConversationVariableNotFoundError(Exception): + pass + + +class ConversationVariableUpdater: + def __init__(self, session_maker: sessionmaker[Session]) -> None: + self._session_maker: sessionmaker[Session] = session_maker + + def update(self, conversation_id: str, variable: VariableBase) -> None: + stmt = select(ConversationVariable).where( + ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id + ) + with self._session_maker() as session: + row = session.scalar(stmt) + if not row: + raise ConversationVariableNotFoundError("conversation variable not found in the database") + row.data = variable.model_dump_json() + session.commit() + + def flush(self) -> None: + pass diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py new file mode 100644 index 0000000000..1954602571 --- 
/dev/null +++ b/api/services/credit_pool_service.py @@ -0,0 +1,85 @@ +import logging + +from sqlalchemy import update +from sqlalchemy.orm import Session + +from configs import dify_config +from core.errors.error import QuotaExceededError +from extensions.ext_database import db +from models import TenantCreditPool + +logger = logging.getLogger(__name__) + + +class CreditPoolService: + @classmethod + def create_default_pool(cls, tenant_id: str) -> TenantCreditPool: + """create default credit pool for new tenant""" + credit_pool = TenantCreditPool( + tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial" + ) + db.session.add(credit_pool) + db.session.commit() + return credit_pool + + @classmethod + def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None: + """get tenant credit pool""" + return ( + db.session.query(TenantCreditPool) + .filter_by( + tenant_id=tenant_id, + pool_type=pool_type, + ) + .first() + ) + + @classmethod + def check_credits_available( + cls, + tenant_id: str, + credits_required: int, + pool_type: str = "trial", + ) -> bool: + """check if credits are available without deducting""" + pool = cls.get_pool(tenant_id, pool_type) + if not pool: + return False + return pool.remaining_credits >= credits_required + + @classmethod + def check_and_deduct_credits( + cls, + tenant_id: str, + credits_required: int, + pool_type: str = "trial", + ) -> int: + """check and deduct credits, returns actual credits deducted""" + + pool = cls.get_pool(tenant_id, pool_type) + if not pool: + raise QuotaExceededError("Credit pool not found") + + if pool.remaining_credits <= 0: + raise QuotaExceededError("No credits remaining") + + # deduct all remaining credits if less than required + actual_credits = min(credits_required, pool.remaining_credits) + + try: + with Session(db.engine) as session: + stmt = ( + update(TenantCreditPool) + .where( + TenantCreditPool.tenant_id == tenant_id, + 
TenantCreditPool.pool_type == pool_type, + ) + .values(quota_used=TenantCreditPool.quota_used + actual_credits) + ) + session.execute(stmt) + session.commit() + except Exception: + logger.exception("Failed to deduct credits for tenant %s", tenant_id) + raise QuotaExceededError("Failed to deduct credits") + + return actual_credits diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ac4b25c5dc..be9a0e9279 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -13,10 +13,11 @@ import sqlalchemy as sa from redis.exceptions import LockNotOwnedError from sqlalchemy import exists, func, select from sqlalchemy.orm import Session -from werkzeug.exceptions import NotFound +from werkzeug.exceptions import Forbidden, NotFound from configs import dify_config from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.file import helpers as file_helpers from core.helper.name_generator import generate_incremental_name from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelFeature, ModelType @@ -73,6 +74,7 @@ from services.errors.document import DocumentIndexingError from services.errors.file import FileNotExistsError from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureModel, FeatureService +from services.file_service import FileService from services.rag_pipeline.rag_pipeline import RagPipelineService from services.tag_service import TagService from services.vector_service import VectorService @@ -144,7 +146,8 @@ class DatasetService: query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) if search: - query = query.where(Dataset.name.ilike(f"%{search}%")) + escaped_search = helper.escape_like_pattern(search) + query = query.where(Dataset.name.ilike(f"%{escaped_search}%", escape="\\")) # Check if tag_ids is not empty to avoid WHERE false condition if tag_ids and 
len(tag_ids) > 0: @@ -1161,6 +1164,7 @@ class DocumentService: Document.archived.is_(True), ), } + DOCUMENT_BATCH_DOWNLOAD_ZIP_FILENAME_EXTENSION = ".zip" @classmethod def normalize_display_status(cls, status: str | None) -> str | None: @@ -1287,6 +1291,143 @@ class DocumentService: else: return None + @staticmethod + def get_documents_by_ids(dataset_id: str, document_ids: Sequence[str]) -> Sequence[Document]: + """Fetch documents for a dataset in a single batch query.""" + if not document_ids: + return [] + document_id_list: list[str] = [str(document_id) for document_id in document_ids] + # Fetch all requested documents in one query to avoid N+1 lookups. + documents: Sequence[Document] = db.session.scalars( + select(Document).where( + Document.dataset_id == dataset_id, + Document.id.in_(document_id_list), + ) + ).all() + return documents + + @staticmethod + def get_document_download_url(document: Document) -> str: + """ + Return a signed download URL for an upload-file document. + """ + upload_file = DocumentService._get_upload_file_for_upload_file_document(document) + return file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True) + + @staticmethod + def prepare_document_batch_download_zip( + *, + dataset_id: str, + document_ids: Sequence[str], + tenant_id: str, + current_user: Account, + ) -> tuple[list[UploadFile], str]: + """ + Resolve upload files for batch ZIP downloads and generate a client-visible filename. 
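`get_documents_by_ids` and the ZIP-download helpers around it share one lookup pattern: a single `IN (...)` query, a dict keyed by id, and a set difference to surface missing rows instead of N+1 point queries. A generic sketch of that pattern (the `fetch_many` callable and the error type are illustrative, not code from this repository):

```python
def batch_lookup(ids, fetch_many):
    """Resolve many ids with one IN (...)-style query via fetch_many,
    then use a set difference to detect ids that were not returned."""
    rows = fetch_many(ids)
    by_id = {row["id"]: row for row in rows}
    missing = set(ids) - set(by_id)
    if missing:
        raise LookupError(f"not found: {sorted(missing)}")
    # Preserve the caller's requested order.
    return [by_id[i] for i in ids]
```

The service methods raise `NotFound` at this point; the shape of the check (fetch once, diff the key sets) is the same.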
+ """ + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + try: + DatasetService.check_dataset_permission(dataset, current_user) + except NoPermissionError as e: + raise Forbidden(str(e)) + + upload_files_by_document_id = DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset_id, + document_ids=document_ids, + tenant_id=tenant_id, + ) + upload_files = [upload_files_by_document_id[document_id] for document_id in document_ids] + download_name = DocumentService._generate_document_batch_download_zip_filename() + return upload_files, download_name + + @staticmethod + def _generate_document_batch_download_zip_filename() -> str: + """ + Generate a random attachment filename for the batch download ZIP. + """ + return f"{uuid.uuid4().hex}{DocumentService.DOCUMENT_BATCH_DOWNLOAD_ZIP_FILENAME_EXTENSION}" + + @staticmethod + def _get_upload_file_id_for_upload_file_document( + document: Document, + *, + invalid_source_message: str, + missing_file_message: str, + ) -> str: + """ + Normalize and validate `Document -> UploadFile` linkage for download flows. + """ + if document.data_source_type != "upload_file": + raise NotFound(invalid_source_message) + + data_source_info: dict[str, Any] = document.data_source_info_dict or {} + upload_file_id: str | None = data_source_info.get("upload_file_id") + if not upload_file_id: + raise NotFound(missing_file_message) + + return str(upload_file_id) + + @staticmethod + def _get_upload_file_for_upload_file_document(document: Document) -> UploadFile: + """ + Load the `UploadFile` row for an upload-file document. 
+ """ + upload_file_id = DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="Document does not have an uploaded file to download.", + missing_file_message="Uploaded file not found.", + ) + upload_files_by_id = FileService.get_upload_files_by_ids(document.tenant_id, [upload_file_id]) + upload_file = upload_files_by_id.get(upload_file_id) + if not upload_file: + raise NotFound("Uploaded file not found.") + return upload_file + + @staticmethod + def _get_upload_files_by_document_id_for_zip_download( + *, + dataset_id: str, + document_ids: Sequence[str], + tenant_id: str, + ) -> dict[str, UploadFile]: + """ + Batch load upload files keyed by document id for ZIP downloads. + """ + document_id_list: list[str] = [str(document_id) for document_id in document_ids] + + documents = DocumentService.get_documents_by_ids(dataset_id, document_id_list) + documents_by_id: dict[str, Document] = {str(document.id): document for document in documents} + + missing_document_ids: set[str] = set(document_id_list) - set(documents_by_id.keys()) + if missing_document_ids: + raise NotFound("Document not found.") + + upload_file_ids: list[str] = [] + upload_file_ids_by_document_id: dict[str, str] = {} + for document_id, document in documents_by_id.items(): + if document.tenant_id != tenant_id: + raise Forbidden("No permission.") + + upload_file_id = DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="Only uploaded-file documents can be downloaded as ZIP.", + missing_file_message="Only uploaded-file documents can be downloaded as ZIP.", + ) + upload_file_ids.append(upload_file_id) + upload_file_ids_by_document_id[document_id] = upload_file_id + + upload_files_by_id = FileService.get_upload_files_by_ids(tenant_id, upload_file_ids) + missing_upload_file_ids: set[str] = set(upload_file_ids) - set(upload_files_by_id.keys()) + if missing_upload_file_ids: + raise NotFound("Only uploaded-file documents can 
be downloaded as ZIP.") + + return { + document_id: upload_files_by_id[upload_file_id] + for document_id, upload_file_id in upload_file_ids_by_document_id.items() + } + @staticmethod def get_document_by_id(document_id: str) -> Document | None: document = db.session.query(Document).where(Document.id == document_id).first() @@ -3423,7 +3564,8 @@ class SegmentService: .order_by(ChildChunk.position.asc()) ) if keyword: - query = query.where(ChildChunk.content.ilike(f"%{keyword}%")) + escaped_keyword = helper.escape_like_pattern(keyword) + query = query.where(ChildChunk.content.ilike(f"%{escaped_keyword}%", escape="\\")) return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @classmethod @@ -3456,7 +3598,8 @@ class SegmentService: query = query.where(DocumentSegment.status.in_(status_list)) if keyword: - query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) + escaped_keyword = helper.escape_like_pattern(keyword) + query = query.where(DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\")) query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc()) paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index bdc960aa2d..e3832475aa 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -1,9 +1,14 @@ +import logging import os from collections.abc import Mapping from typing import Any import httpx +from core.helper.trace_id_helper import generate_traceparent_header + +logger = logging.getLogger(__name__) + class BaseRequest: proxies: Mapping[str, str] | None = { @@ -38,6 +43,15 @@ class BaseRequest: headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key} url = f"{cls.base_url}{endpoint}" mounts = cls._build_mounts() + + try: + # ensure traceparent even when OTEL is disabled + traceparent = 
generate_traceparent_header() + if traceparent: + headers["traceparent"] = traceparent + except Exception: + logger.debug("Failed to generate traceparent header", exc_info=True) + with httpx.Client(mounts=mounts) as client: response = client.request(method, url, json=json, params=params, headers=headers) return response.json() diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 83d0fcf296..a5133dfcb4 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -13,6 +13,23 @@ class WebAppSettings(BaseModel): ) +class WorkspacePermission(BaseModel): + workspace_id: str = Field( + description="The ID of the workspace.", + alias="workspaceId", + ) + allow_member_invite: bool = Field( + description="Whether to allow members to invite new members to the workspace.", + default=False, + alias="allowMemberInvite", + ) + allow_owner_transfer: bool = Field( + description="Whether to allow owners to transfer ownership of the workspace.", + default=False, + alias="allowOwnerTransfer", + ) + + class EnterpriseService: @classmethod def get_info(cls): @@ -44,6 +61,16 @@ class EnterpriseService: except ValueError as e: raise ValueError(f"Invalid date format: {data}") from e + class WorkspacePermissionService: + @classmethod + def get_permission(cls, workspace_id: str): + if not workspace_id: + raise ValueError("workspace_id must be provided.") + data = EnterpriseRequest.send_request("GET", f"/workspaces/{workspace_id}/permission") + if not data or "permission" not in data: + raise ValueError("No data found.") + return WorkspacePermission.model_validate(data["permission"]) + class WebAppAuth: @classmethod def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str): @@ -110,5 +137,5 @@ class EnterpriseService: if not app_id: raise ValueError("app_id must be provided.") - body = {"appId": app_id} - EnterpriseRequest.send_request("DELETE", "/webapp/clean", 
json=body) + params = {"appId": app_id} + EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params) diff --git a/api/services/enterprise/workspace_sync.py b/api/services/enterprise/workspace_sync.py new file mode 100644 index 0000000000..acfe325397 --- /dev/null +++ b/api/services/enterprise/workspace_sync.py @@ -0,0 +1,58 @@ +import json +import logging +import uuid +from datetime import UTC, datetime + +from redis import RedisError + +from extensions.ext_redis import redis_client + +logger = logging.getLogger(__name__) + +WORKSPACE_SYNC_QUEUE = "enterprise:workspace:sync:queue" +WORKSPACE_SYNC_PROCESSING = "enterprise:workspace:sync:processing" + + +class WorkspaceSyncService: + """Service to publish workspace sync tasks to Redis queue for enterprise backend consumption""" + + @staticmethod + def queue_credential_sync(workspace_id: str, *, source: str) -> bool: + """ + Queue a credential sync task for a newly created workspace. + + This publishes a task to Redis that will be consumed by the enterprise backend + worker to sync credentials with the plugin-manager. 
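`queue_credential_sync` LPUSHes JSON tasks onto the head of a Redis list, and its comments state the enterprise worker drains from the tail with RPOP, which together give FIFO order. A minimal sketch of that consumer contract, using an in-memory stub in place of a real Redis client (both the stub and `consume_one` are illustrative, not part of the codebase):

```python
import json
from collections import deque


class FakeRedisList:
    """In-memory stand-in for the two Redis list ops the queue relies on."""

    def __init__(self):
        self._items = deque()

    def lpush(self, key, value):
        # Producer side: add to the head of the list.
        self._items.appendleft(value)

    def rpop(self, key):
        # Worker side: consume from the tail (oldest item first).
        return self._items.pop() if self._items else None


WORKSPACE_SYNC_QUEUE = "enterprise:workspace:sync:queue"


def consume_one(client):
    """Pop the oldest queued sync task; LPUSH + RPOP yields FIFO order."""
    raw = client.rpop(WORKSPACE_SYNC_QUEUE)
    return json.loads(raw) if raw is not None else None
```

Because tasks carry a `retry_count` and `created_at`, a worker that fails mid-task can requeue the payload rather than lose it, which matches the "scheduled task will catch it later" note in the producer.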
+ + Args: + workspace_id: The workspace/tenant ID to sync credentials for + source: Source of the sync request (for debugging/tracking) + + Returns: + bool: True if task was queued successfully, False otherwise + """ + try: + task = { + "task_id": str(uuid.uuid4()), + "workspace_id": workspace_id, + "retry_count": 0, + "created_at": datetime.now(UTC).isoformat(), + "source": source, + } + + # Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP + redis_client.lpush(WORKSPACE_SYNC_QUEUE, json.dumps(task)) + + logger.info( + "Queued credential sync task for workspace %s, task_id: %s, source: %s", + workspace_id, + task["task_id"], + source, + ) + return True + + except (RedisError, TypeError) as e: + logger.error("Failed to queue credential sync for workspace %s: %s", workspace_id, str(e), exc_info=True) + # Don't raise - we don't want to fail workspace creation if queueing fails + # The scheduled task will catch it later + return False diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index f405546909..a29d848ac5 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -70,7 +70,6 @@ class ProviderResponse(BaseModel): description: I18nObject | None = None icon_small: I18nObject | None = None icon_small_dark: I18nObject | None = None - icon_large: I18nObject | None = None background: str | None = None help: ProviderHelpEntity | None = None supported_model_types: Sequence[ModelType] @@ -98,11 +97,6 @@ class ProviderResponse(BaseModel): en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans", ) - - if self.icon_large is not None: - self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" - ) return self @@ -116,7 +110,6 @@ class ProviderWithModelsResponse(BaseModel): label: I18nObject icon_small: I18nObject | None = 
None icon_small_dark: I18nObject | None = None - icon_large: I18nObject | None = None status: CustomConfigurationStatus models: list[ProviderModelWithStatusEntity] @@ -134,11 +127,6 @@ class ProviderWithModelsResponse(BaseModel): self.icon_small_dark = I18nObject( en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans" ) - - if self.icon_large is not None: - self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" - ) return self @@ -163,11 +151,6 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): self.icon_small_dark = I18nObject( en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans" ) - - if self.icon_large is not None: - self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" - ) return self diff --git a/api/services/errors/app.py b/api/services/errors/app.py index 24e4760acc..60e59e97dc 100644 --- a/api/services/errors/app.py +++ b/api/services/errors/app.py @@ -18,8 +18,8 @@ class WorkflowIdFormatError(Exception): pass -class InvokeRateLimitError(Exception): - """Raised when rate limit is exceeded for workflow invocations.""" +class WorkflowQuotaLimitError(Exception): + """Raised when workflow execution quota is exceeded (for async/background workflows).""" pass diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 40faa85b9a..65dd41af43 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -35,7 +35,10 @@ class ExternalDatasetService: .order_by(ExternalKnowledgeApis.created_at.desc()) ) if search: - query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%")) + from libs.helper import escape_like_pattern + + escaped_search = escape_like_pattern(search) + query = query.where(ExternalKnowledgeApis.name.ilike(f"%{escaped_search}%", escape="\\")) 
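Several hunks in this diff replace inline `.replace(...)` chains with a shared `escape_like_pattern` helper plus `escape="\\"` on `ilike`. Judging from the inline version it supersedes in `conversation_service.py`, the helper plausibly looks like this (a sketch, not the actual `libs/helper.py` code):

```python
def escape_like_pattern(value: str) -> str:
    """Escape SQL LIKE wildcards so user input matches literally.

    Backslash is escaped first; otherwise the backslashes added for
    % and _ would themselves be doubled.
    """
    value = value.replace("\\", "\\\\")
    value = value.replace("%", "\\%")
    return value.replace("_", "\\_")
```

With `ilike(f"%{escape_like_pattern(term)}%", escape="\\")`, a search term like `50%_off` matches only rows containing that literal text, rather than treating `%` and `_` as wildcards.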
external_knowledge_apis = db.paginate( select=query, page=page, per_page=per_page, max_per_page=100, error_out=False diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 112855a748..fda3a15144 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, ConfigDict, Field from configs import dify_config from enums.cloud_plan import CloudPlan +from enums.hosted_provider import HostedTrialProvider from services.billing_service import BillingService from services.enterprise.enterprise_service import EnterpriseService @@ -142,6 +143,7 @@ class FeatureModel(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) knowledge_pipeline: KnowledgePipeline = KnowledgePipeline() + next_credit_reset_date: int = 0 class KnowledgeRateLimitModel(BaseModel): @@ -171,6 +173,9 @@ class SystemFeatureModel(BaseModel): plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel() enable_change_email: bool = True plugin_manager: PluginManagerModel = PluginManagerModel() + trial_models: list[str] = [] + enable_trial_app: bool = False + enable_explore_banner: bool = False class FeatureService: @@ -217,7 +222,7 @@ class FeatureService: ) @classmethod - def get_system_features(cls) -> SystemFeatureModel: + def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel: system_features = SystemFeatureModel() cls._fulfill_system_params_from_env(system_features) @@ -227,7 +232,7 @@ class FeatureService: system_features.webapp_auth.enabled = True system_features.enable_change_email = False system_features.plugin_manager.enabled = True - cls._fulfill_params_from_enterprise(system_features) + cls._fulfill_params_from_enterprise(system_features, is_authenticated) if dify_config.MARKETPLACE_ENABLED: system_features.enable_marketplace = True @@ -242,6 +247,20 @@ class FeatureService: 
system_features.is_allow_register = dify_config.ALLOW_REGISTER system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != "" + system_features.trial_models = cls._fulfill_trial_models_from_env() + system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP + system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER + + @classmethod + def _fulfill_trial_models_from_env(cls) -> list[str]: + return [ + provider.value + for provider in HostedTrialProvider + if ( + getattr(dify_config, f"HOSTED_{provider.config_key}_PAID_ENABLED", False) + and getattr(dify_config, f"HOSTED_{provider.config_key}_TRIAL_ENABLED", False) + ) + ] @classmethod def _fulfill_params_from_env(cls, features: FeatureModel): @@ -319,8 +338,11 @@ class FeatureService: if "knowledge_pipeline_publish_enabled" in billing_info: features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"] + if "next_credit_reset_date" in billing_info: + features.next_credit_reset_date = billing_info["next_credit_reset_date"] + @classmethod - def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel): + def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel, is_authenticated: bool = False): enterprise_info = EnterpriseService.get_info() if "SSOEnforcedForSignin" in enterprise_info: @@ -357,19 +379,14 @@ class FeatureService: ) features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "") - if "License" in enterprise_info: - license_info = enterprise_info["License"] + if is_authenticated and (license_info := enterprise_info.get("License")): + features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) + features.license.expired_at = license_info.get("expiredAt", "") - if "status" in license_info: - features.license.status = LicenseStatus(license_info.get("status", 
LicenseStatus.INACTIVE)) - - if "expiredAt" in license_info: - features.license.expired_at = license_info["expiredAt"] - - if "workspaces" in license_info: - features.license.workspaces.enabled = license_info["workspaces"]["enabled"] - features.license.workspaces.limit = license_info["workspaces"]["limit"] - features.license.workspaces.size = license_info["workspaces"]["used"] + if workspaces_info := license_info.get("workspaces"): + features.license.workspaces.enabled = workspaces_info.get("enabled", False) + features.license.workspaces.limit = workspaces_info.get("limit", 0) + features.license.workspaces.size = workspaces_info.get("used", 0) if "PluginInstallationPermission" in enterprise_info: plugin_installation_info = enterprise_info["PluginInstallationPermission"] diff --git a/api/services/file_service.py b/api/services/file_service.py index 0911cf38c4..a0a99f3f82 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -2,7 +2,11 @@ import base64 import hashlib import os import uuid +from collections.abc import Iterator, Sequence +from contextlib import contextmanager, suppress +from tempfile import NamedTemporaryFile from typing import Literal, Union +from zipfile import ZIP_DEFLATED, ZipFile from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -17,6 +21,7 @@ from constants import ( ) from core.file import helpers as file_helpers from core.rag.extractor.extract_processor import ExtractProcessor +from extensions.ext_database import db from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id @@ -167,6 +172,9 @@ class FileService: return upload_file def get_file_preview(self, file_id: str): + """ + Return a short text preview extracted from a document file. 
+ """ with self._session_maker(expire_on_commit=False) as session: upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() @@ -253,3 +261,101 @@ class FileService: return storage.delete(upload_file.key) session.delete(upload_file) + + @staticmethod + def get_upload_files_by_ids(tenant_id: str, upload_file_ids: Sequence[str]) -> dict[str, UploadFile]: + """ + Fetch `UploadFile` rows for a tenant in a single batch query. + + This is a generic `UploadFile` lookup helper (not dataset/document specific), so it lives in `FileService`. + """ + if not upload_file_ids: + return {} + + # Normalize and deduplicate ids before using them in the IN clause. + upload_file_id_list: list[str] = [str(upload_file_id) for upload_file_id in upload_file_ids] + unique_upload_file_ids: list[str] = list(set(upload_file_id_list)) + + # Fetch upload files in one query for efficient batch access. + upload_files: Sequence[UploadFile] = db.session.scalars( + select(UploadFile).where( + UploadFile.tenant_id == tenant_id, + UploadFile.id.in_(unique_upload_file_ids), + ) + ).all() + return {str(upload_file.id): upload_file for upload_file in upload_files} + + @staticmethod + def _sanitize_zip_entry_name(name: str) -> str: + """ + Sanitize a ZIP entry name to avoid path traversal and weird separators. + + We keep this conservative: the upload flow already rejects `/` and `\\`, but older rows (or imported data) + could still contain unsafe names. + """ + # Drop any directory components and prevent empty names. + base = os.path.basename(name).strip() or "file" + + # ZIP uses forward slashes as separators; remove any residual separator characters. + return base.replace("/", "_").replace("\\", "_") + + @staticmethod + def _dedupe_zip_entry_name(original_name: str, used_names: set[str]) -> str: + """ + Return a unique ZIP entry name, inserting suffixes before the extension. + """ + # Keep the original name when it's not already used. 
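The sanitize-then-dedupe flow used for ZIP entry names can be sketched standalone. This is an illustrative re-implementation under the same rules the diff describes (basename only, separators replaced, " (n)" suffix before the extension), not the actual `FileService` helpers:

```python
import os


def sanitize_zip_entry_name(name: str) -> str:
    # Drop directory components and prevent empty names.
    base = os.path.basename(name).strip() or "file"
    # ZIP uses forward slashes as separators; strip any residual separators.
    return base.replace("/", "_").replace("\\", "_")


def dedupe_zip_entry_name(original_name: str, used_names: set[str]) -> str:
    # Keep the original name when it's not already taken.
    if original_name not in used_names:
        return original_name
    # Insert a numeric suffix before the extension until the name is unique.
    stem, extension = os.path.splitext(original_name)
    suffix = 1
    while f"{stem} ({suffix}){extension}" in used_names:
        suffix += 1
    return f"{stem} ({suffix}){extension}"


used: set[str] = set()
for raw in ["../etc/passwd", "doc.txt", "doc.txt"]:
    entry = dedupe_zip_entry_name(sanitize_zip_entry_name(raw), used)
    used.add(entry)
print(sorted(used))  # ['doc (1).txt', 'doc.txt', 'passwd']
```

The traversal candidate `../etc/passwd` collapses to a bare `passwd`, and the second `doc.txt` becomes `doc (1).txt`.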
+ if original_name not in used_names: + return original_name + + # Insert suffixes before the extension (e.g., "doc.txt" -> "doc (1).txt"). + stem, extension = os.path.splitext(original_name) + suffix = 1 + while True: + candidate = f"{stem} ({suffix}){extension}" + if candidate not in used_names: + return candidate + suffix += 1 + + @staticmethod + @contextmanager + def build_upload_files_zip_tempfile( + *, + upload_files: Sequence[UploadFile], + ) -> Iterator[str]: + """ + Build a ZIP from `UploadFile`s and yield a tempfile path. + + We yield a path (rather than an open file handle) to avoid "read of closed file" issues when Flask/Werkzeug + streams responses. The caller is expected to keep this context open until the response is fully sent, then + close it (e.g., via `response.call_on_close(...)`) to delete the tempfile. + """ + used_names: set[str] = set() + + # Build a ZIP in a temp file and keep it on disk until the caller finishes streaming it. + tmp_path: str | None = None + try: + with NamedTemporaryFile(mode="w+b", suffix=".zip", delete=False) as tmp: + tmp_path = tmp.name + with ZipFile(tmp, mode="w", compression=ZIP_DEFLATED) as zf: + for upload_file in upload_files: + # Ensure the entry name is safe and unique. + safe_name = FileService._sanitize_zip_entry_name(upload_file.name) + arcname = FileService._dedupe_zip_entry_name(safe_name, used_names) + used_names.add(arcname) + + # Stream file bytes from storage into the ZIP entry. + with zf.open(arcname, "w") as entry: + for chunk in storage.load(upload_file.key, stream=True): + entry.write(chunk) + + # Flush so `send_file(path, ...)` can re-open it safely on all platforms. + tmp.flush() + + assert tmp_path is not None + yield tmp_path + finally: + # Remove the temp file when the context is closed (typically after the response finishes streaming). 
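The yield-a-path pattern that `build_upload_files_zip_tempfile` documents can be reduced to a generic sketch. This version takes in-memory entries rather than streaming from `storage`, so it is an assumption-laden stand-in, not the Dify implementation; the key points it preserves are `delete=False` (the file outlives the handle), yielding the path for the caller to re-open, and removal in `finally`:

```python
import os
from collections.abc import Iterator
from contextlib import contextmanager, suppress
from tempfile import NamedTemporaryFile
from zipfile import ZIP_DEFLATED, ZipFile


@contextmanager
def zip_tempfile(entries: dict[str, bytes]) -> Iterator[str]:
    # Yield a path (not an open handle) so the caller can re-open it,
    # e.g. via send_file(path), without "read of closed file" issues.
    tmp_path: str | None = None
    try:
        with NamedTemporaryFile(suffix=".zip", delete=False) as tmp:
            tmp_path = tmp.name
            with ZipFile(tmp, mode="w", compression=ZIP_DEFLATED) as zf:
                for arcname, data in entries.items():
                    # zf.open(name, "w") writes one entry as a stream.
                    with zf.open(arcname, "w") as entry:
                        entry.write(data)
            tmp.flush()
        yield tmp_path
    finally:
        # Remove the temp file once the caller closes the context.
        if tmp_path is not None:
            with suppress(FileNotFoundError):
                os.remove(tmp_path)


with zip_tempfile({"a.txt": b"hello"}) as path:
    with ZipFile(path) as zf:
        assert zf.read("a.txt") == b"hello"
assert not os.path.exists(path)  # cleaned up on context exit
```

In the real service the context is expected to stay open until the response finishes streaming, for example by closing it from `response.call_on_close(...)`.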
+ if tmp_path is not None: + with suppress(FileNotFoundError): + os.remove(tmp_path) diff --git a/api/services/message_service.py b/api/services/message_service.py index 8c9e820e80..8cfa6512be 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -286,10 +286,9 @@ class MessageService: else: conversation_override_model_configs = json.loads(conversation.override_model_configs) app_model_config = AppModelConfig( - id=conversation.app_model_config_id, app_id=app_model.id, ) - + app_model_config.id = conversation.app_model_config_id app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs) if not app_model_config: raise ValueError("did not find app model config") diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index eea382febe..edd1004b82 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -99,7 +99,6 @@ class ModelProviderService: description=provider_configuration.provider.description, icon_small=provider_configuration.provider.icon_small, icon_small_dark=provider_configuration.provider.icon_small_dark, - icon_large=provider_configuration.provider.icon_large, background=provider_configuration.provider.background, help=provider_configuration.provider.help, supported_model_types=provider_configuration.provider.supported_model_types, @@ -423,7 +422,6 @@ class ModelProviderService: label=first_model.provider.label, icon_small=first_model.provider.icon_small, icon_small_dark=first_model.provider.icon_small_dark, - icon_large=first_model.provider.icon_large, status=CustomConfigurationStatus.ACTIVE, models=[ ProviderModelWithStatusEntity( @@ -488,7 +486,6 @@ class ModelProviderService: provider=result.provider.provider, label=result.provider.label, icon_small=result.provider.icon_small, - icon_large=result.provider.icon_large, supported_model_types=result.provider.supported_model_types, ), ) @@ -522,7 +519,7 @@ 
class ModelProviderService: :param tenant_id: workspace id :param provider: provider name - :param icon_type: icon type (icon_small or icon_large) + :param icon_type: icon type (icon_small or icon_small_dark) :param lang: language (zh_Hans or en_US) :return: """ diff --git a/api/services/plugin/plugin_parameter_service.py b/api/services/plugin/plugin_parameter_service.py index 5dcbf5fec5..40565c56ed 100644 --- a/api/services/plugin/plugin_parameter_service.py +++ b/api/services/plugin/plugin_parameter_service.py @@ -146,7 +146,7 @@ class PluginParameterService: provider, action, resolved_credentials, - CredentialType.API_KEY.value, + original_subscription.credential_type or CredentialType.UNAUTHORIZED.value, parameter, ) .options diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index b8303eb724..411c335c17 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence from mimetypes import guess_type from pydantic import BaseModel +from sqlalchemy import select from yarl import URL from configs import dify_config @@ -25,7 +26,9 @@ from core.plugin.entities.plugin_daemon import ( from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.debugging import PluginDebuggingClient from core.plugin.impl.plugin import PluginInstaller +from extensions.ext_database import db from extensions.ext_redis import redis_client +from models.provider import ProviderCredential from models.provider_ids import GenericProviderID from services.errors.plugin import PluginInstallationForbiddenError from services.feature_service import FeatureService, PluginInstallationScope @@ -506,6 +509,33 @@ class PluginService: @staticmethod def uninstall(tenant_id: str, plugin_installation_id: str) -> bool: manager = PluginInstaller() + + # Get plugin info before uninstalling to delete associated credentials + try: + plugins = 
manager.list_plugins(tenant_id) + plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None) + + if plugin: + plugin_id = plugin.plugin_id + logger.info("Deleting credentials for plugin: %s", plugin_id) + + # Delete provider credentials that match this plugin + credentials = db.session.scalars( + select(ProviderCredential).where( + ProviderCredential.tenant_id == tenant_id, + ProviderCredential.provider_name.like(f"{plugin_id}/%"), + ) + ).all() + + for cred in credentials: + db.session.delete(cred) + + db.session.commit() + logger.info("Deleted %d credentials for plugin: %s", len(credentials), plugin_id) + except Exception as e: + logger.warning("Failed to delete credentials: %s", e) + # Continue with uninstall even if credential deletion fails + return manager.uninstall(tenant_id, plugin_installation_id) @staticmethod diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index f53448e7fe..ccc6abcc06 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -36,7 +36,7 @@ from core.rag.entities.event import ( ) from core.repositories.factory import DifyCoreRepositoryFactory from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.variables.variables import Variable +from core.variables.variables import VariableBase from core.workflow.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, @@ -270,8 +270,8 @@ class RagPipelineService: graph: dict, unique_hash: str | None, account: Account, - environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable], + environment_variables: Sequence[VariableBase], + conversation_variables: Sequence[VariableBase], rag_pipeline_variables: list, ) -> Workflow: """ @@ -436,7 +436,7 @@ class RagPipelineService: user_inputs=user_inputs, user_id=account.id, 
variable_pool=VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), user_inputs=user_inputs, environment_variables=[], conversation_variables=[], @@ -874,7 +874,7 @@ class RagPipelineService: variable_pool = node_instance.graph_runtime_state.variable_pool invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) if invoke_from: - if invoke_from.value == InvokeFrom.PUBLISHED: + if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE: document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) if document_id: document = db.session.query(Document).where(Document.id == document_id.value).first() @@ -1318,7 +1318,7 @@ class RagPipelineService: "datasource_info_list": [json.loads(document_pipeline_execution_log.datasource_info)], "original_document_id": document.id, }, - invoke_from=InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, streaming=False, call_depth=0, workflow_thread_pool_id=None, diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 06f294863d..c1c6e204fb 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -870,15 +870,16 @@ class RagPipelineDslService: return dependencies @classmethod - def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]: + def get_leaked_dependencies( + cls, tenant_id: str, dsl_dependencies: list[PluginDependency] + ) -> list[PluginDependency]: """ Returns the leaked dependencies in current workspace """ - dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies] - if not dependencies: + if not dsl_dependencies: return [] - return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies) + return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, 
dependencies=dsl_dependencies) def _generate_aes_key(self, tenant_id: str) -> bytes: """Generate AES key based on tenant_id""" diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 84f97907c0..8ea365e907 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -44,7 +44,7 @@ class RagPipelineTransformService: doc_form = dataset.doc_form if not doc_form: return self._transform_to_empty_pipeline(dataset) - retrieval_model = dataset.retrieval_model + retrieval_model = RetrievalSetting.model_validate(dataset.retrieval_model) if dataset.retrieval_model else None pipeline_yaml = self._get_transform_yaml(doc_form, datasource_type, indexing_technique) # deal dependencies self._deal_dependencies(pipeline_yaml, dataset.tenant_id) @@ -154,7 +154,12 @@ class RagPipelineTransformService: return node def _deal_knowledge_index( - self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict + self, + dataset: Dataset, + doc_form: str, + indexing_technique: str | None, + retrieval_model: RetrievalSetting | None, + node: dict, ): knowledge_configuration_dict = node.get("data", {}) knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict) @@ -163,10 +168,9 @@ class RagPipelineTransformService: knowledge_configuration.embedding_model = dataset.embedding_model knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider if retrieval_model: - retrieval_setting = RetrievalSetting.model_validate(retrieval_model) if indexing_technique == "economy": - retrieval_setting.search_method = RetrievalMethod.KEYWORD_SEARCH - knowledge_configuration.retrieval_model = retrieval_setting + retrieval_model.search_method = RetrievalMethod.KEYWORD_SEARCH + knowledge_configuration.retrieval_model = retrieval_model else: 
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 544383a106..6b211a5632 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -1,4 +1,7 @@ from configs import dify_config +from extensions.ext_database import db +from models.model import AccountTrialAppRecord, TrialApp +from services.feature_service import FeatureService from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory @@ -20,6 +23,15 @@ class RecommendedAppService: ) ) + if FeatureService.get_system_features().enable_trial_app: + apps = result["recommended_apps"] + for app in apps: + app_id = app["app_id"] + trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first() + if trial_app_model: + app["can_trial"] = True + else: + app["can_trial"] = False return result @classmethod @@ -32,4 +44,30 @@ class RecommendedAppService: mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)() result: dict = retrieval_instance.get_recommend_app_detail(app_id) + if FeatureService.get_system_features().enable_trial_app: + app_id = result["id"] + trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first() + if trial_app_model: + result["can_trial"] = True + else: + result["can_trial"] = False return result + + @classmethod + def add_trial_app_record(cls, app_id: str, account_id: str): + """ + Add trial app record. 
+ :param app_id: app id + :param account_id: account id + :return: + """ + account_trial_app_record = ( + db.session.query(AccountTrialAppRecord) + .where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id) + .first() + ) + if account_trial_app_record: + account_trial_app_record.count += 1 + db.session.commit() + else: + db.session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id)) + db.session.commit() diff --git a/api/services/retention/__init__.py b/api/services/retention/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/retention/conversation/messages_clean_policy.py b/api/services/retention/conversation/messages_clean_policy.py new file mode 100644 index 0000000000..6e647b983b --- /dev/null +++ b/api/services/retention/conversation/messages_clean_policy.py @@ -0,0 +1,216 @@ +import datetime +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence +from dataclasses import dataclass + +from configs import dify_config +from enums.cloud_plan import CloudPlan +from services.billing_service import BillingService, SubscriptionPlan + +logger = logging.getLogger(__name__) + + +@dataclass +class SimpleMessage: + id: str + app_id: str + created_at: datetime.datetime + + +class MessagesCleanPolicy(ABC): + """ + Abstract base class for message cleanup policies. + + A policy determines which messages from a batch should be deleted. + """ + + @abstractmethod + def filter_message_ids( + self, + messages: Sequence[SimpleMessage], + app_to_tenant: dict[str, str], + ) -> Sequence[str]: + """ + Filter messages and return IDs of messages that should be deleted. + + Args: + messages: Batch of messages to evaluate + app_to_tenant: Mapping from app_id to tenant_id + + Returns: + List of message IDs that should be deleted + """ + ... + + +class BillingDisabledPolicy(MessagesCleanPolicy): + """ + Policy for community or enterprise edition (billing disabled).
+ + No special filter logic, just return all message ids. + """ + + def filter_message_ids( + self, + messages: Sequence[SimpleMessage], + app_to_tenant: dict[str, str], + ) -> Sequence[str]: + return [msg.id for msg in messages] + + +class BillingSandboxPolicy(MessagesCleanPolicy): + """ + Policy for sandbox plan tenants in cloud edition (billing enabled). + + Filters messages based on sandbox plan expiration rules: + - Skip tenants in the whitelist + - Only delete messages from sandbox plan tenants + - Respect grace period after subscription expiration + - Safe default: if tenant mapping or plan is missing, do NOT delete + """ + + def __init__( + self, + plan_provider: Callable[[Sequence[str]], dict[str, SubscriptionPlan]], + graceful_period_days: int = 21, + tenant_whitelist: Sequence[str] | None = None, + current_timestamp: int | None = None, + ) -> None: + self._graceful_period_days = graceful_period_days + self._tenant_whitelist: Sequence[str] = tenant_whitelist or [] + self._plan_provider = plan_provider + self._current_timestamp = current_timestamp + + def filter_message_ids( + self, + messages: Sequence[SimpleMessage], + app_to_tenant: dict[str, str], + ) -> Sequence[str]: + """ + Filter messages based on sandbox plan expiration rules. 
+ + Args: + messages: Batch of messages to evaluate + app_to_tenant: Mapping from app_id to tenant_id + + Returns: + List of message IDs that should be deleted + """ + if not messages or not app_to_tenant: + return [] + + # Get unique tenant_ids and fetch subscription plans + tenant_ids = list(set(app_to_tenant.values())) + tenant_plans = self._plan_provider(tenant_ids) + + if not tenant_plans: + return [] + + # Apply sandbox deletion rules + return self._filter_expired_sandbox_messages( + messages=messages, + app_to_tenant=app_to_tenant, + tenant_plans=tenant_plans, + ) + + def _filter_expired_sandbox_messages( + self, + messages: Sequence[SimpleMessage], + app_to_tenant: dict[str, str], + tenant_plans: dict[str, SubscriptionPlan], + ) -> list[str]: + """ + Filter messages that should be deleted based on sandbox plan expiration. + + A message should be deleted if: + 1. It belongs to a sandbox tenant AND + 2. Either: + a) The tenant has no previous subscription (expiration_date == -1), OR + b) The subscription expired more than graceful_period_days ago + + Args: + messages: List of message objects with id and app_id attributes + app_to_tenant: Mapping from app_id to tenant_id + tenant_plans: Mapping from tenant_id to subscription plan info + + Returns: + List of message IDs that should be deleted + """ + current_timestamp = self._current_timestamp + if current_timestamp is None: + current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + + sandbox_message_ids: list[str] = [] + graceful_period_seconds = self._graceful_period_days * 24 * 60 * 60 + + for msg in messages: + # Get tenant_id for this message's app + tenant_id = app_to_tenant.get(msg.app_id) + if not tenant_id: + continue + + # Skip tenant messages in whitelist + if tenant_id in self._tenant_whitelist: + continue + + # Get subscription plan for this tenant + tenant_plan = tenant_plans.get(tenant_id) + if not tenant_plan: + continue + + plan = str(tenant_plan["plan"]) + expiration_date = 
int(tenant_plan["expiration_date"]) + + # Only process sandbox plans + if plan != CloudPlan.SANDBOX: + continue + + # Case 1: No previous subscription (-1 means never had a paid subscription) + if expiration_date == -1: + sandbox_message_ids.append(msg.id) + continue + + # Case 2: Subscription expired beyond grace period + if current_timestamp - expiration_date > graceful_period_seconds: + sandbox_message_ids.append(msg.id) + + return sandbox_message_ids + + +def create_message_clean_policy( + graceful_period_days: int = 21, + current_timestamp: int | None = None, +) -> MessagesCleanPolicy: + """ + Factory function to create the appropriate message clean policy. + + Determines which policy to use based on BILLING_ENABLED configuration: + - If BILLING_ENABLED is True: returns BillingSandboxPolicy + - If BILLING_ENABLED is False: returns BillingDisabledPolicy + + Args: + graceful_period_days: Grace period in days after subscription expiration (default: 21) + current_timestamp: Current Unix timestamp for testing (default: None, uses current time) + """ + if not dify_config.BILLING_ENABLED: + logger.info("create_message_clean_policy: billing disabled, using BillingDisabledPolicy") + return BillingDisabledPolicy() + + # Billing enabled - fetch whitelist from BillingService + tenant_whitelist = BillingService.get_expired_subscription_cleanup_whitelist() + plan_provider = BillingService.get_plan_bulk_with_cache + + logger.info( + "create_message_clean_policy: billing enabled, using BillingSandboxPolicy " + "(graceful_period_days=%s, whitelist=%s)", + graceful_period_days, + tenant_whitelist, + ) + + return BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=graceful_period_days, + tenant_whitelist=tenant_whitelist, + current_timestamp=current_timestamp, + ) diff --git a/api/services/retention/conversation/messages_clean_service.py b/api/services/retention/conversation/messages_clean_service.py new file mode 100644 index 0000000000..3ca5d82860 --- 
/dev/null +++ b/api/services/retention/conversation/messages_clean_service.py @@ -0,0 +1,334 @@ +import datetime +import logging +import random +from collections.abc import Sequence +from typing import cast + +from sqlalchemy import delete, select +from sqlalchemy.engine import CursorResult +from sqlalchemy.orm import Session + +from extensions.ext_database import db +from models.model import ( + App, + AppAnnotationHitHistory, + DatasetRetrieverResource, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.web import SavedMessage +from services.retention.conversation.messages_clean_policy import ( + MessagesCleanPolicy, + SimpleMessage, +) + +logger = logging.getLogger(__name__) + + +class MessagesCleanService: + """ + Service for cleaning expired messages based on retention policies. + + Compatible with non cloud edition (billing disabled): all messages in the time range will be deleted. + If billing is enabled: only sandbox plan tenant messages are deleted (with whitelist and grace period support). + """ + + def __init__( + self, + policy: MessagesCleanPolicy, + end_before: datetime.datetime, + start_from: datetime.datetime | None = None, + batch_size: int = 1000, + dry_run: bool = False, + ) -> None: + """ + Initialize the service with cleanup parameters. 
+ + Args: + policy: The policy that determines which messages to delete + end_before: End time (exclusive) of the range + start_from: Optional start time (inclusive) of the range + batch_size: Number of messages to process per batch + dry_run: Whether to perform a dry run (no actual deletion) + """ + self._policy = policy + self._end_before = end_before + self._start_from = start_from + self._batch_size = batch_size + self._dry_run = dry_run + + @classmethod + def from_time_range( + cls, + policy: MessagesCleanPolicy, + start_from: datetime.datetime, + end_before: datetime.datetime, + batch_size: int = 1000, + dry_run: bool = False, + ) -> "MessagesCleanService": + """ + Create a service instance for cleaning messages within a specific time range. + + Time range is [start_from, end_before). + + Args: + policy: The policy that determines which messages to delete + start_from: Start time (inclusive) of the range + end_before: End time (exclusive) of the range + batch_size: Number of messages to process per batch + dry_run: Whether to perform a dry run (no actual deletion) + + Returns: + MessagesCleanService instance + + Raises: + ValueError: If start_from >= end_before or invalid parameters + """ + if start_from >= end_before: + raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})") + + if batch_size <= 0: + raise ValueError(f"batch_size ({batch_size}) must be greater than 0") + + logger.info( + "clean_messages: start_from=%s, end_before=%s, batch_size=%s, policy=%s", + start_from, + end_before, + batch_size, + policy.__class__.__name__, + ) + + return cls( + policy=policy, + end_before=end_before, + start_from=start_from, + batch_size=batch_size, + dry_run=dry_run, + ) + + @classmethod + def from_days( + cls, + policy: MessagesCleanPolicy, + days: int = 30, + batch_size: int = 1000, + dry_run: bool = False, + ) -> "MessagesCleanService": + """ + Create a service instance for cleaning messages older than specified days. 
+ + Args: + policy: The policy that determines which messages to delete + days: Number of days to look back from now + batch_size: Number of messages to process per batch + dry_run: Whether to perform a dry run (no actual deletion) + + Returns: + MessagesCleanService instance + + Raises: + ValueError: If invalid parameters + """ + if days < 0: + raise ValueError(f"days ({days}) must be greater than or equal to 0") + + if batch_size <= 0: + raise ValueError(f"batch_size ({batch_size}) must be greater than 0") + + end_before = datetime.datetime.now() - datetime.timedelta(days=days) + + logger.info( + "clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s", + days, + end_before, + batch_size, + policy.__class__.__name__, + ) + + return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run) + + def run(self) -> dict[str, int]: + """ + Execute the message cleanup operation. + + Returns: + Dict with statistics: batches, filtered_messages, total_deleted + """ + return self._clean_messages_by_time_range() + + def _clean_messages_by_time_range(self) -> dict[str, int]: + """ + Clean messages within a time range using cursor-based pagination. + + Time range is [start_from, end_before) + + Steps: + 1. Iterate messages using cursor pagination (by created_at, id) + 2. Query app_id -> tenant_id mapping + 3. Delegate to policy to determine which messages to delete + 4. 
Batch delete messages and their relations + + Returns: + Dict with statistics: batches, filtered_messages, total_deleted + """ + stats = { + "batches": 0, + "total_messages": 0, + "filtered_messages": 0, + "total_deleted": 0, + } + + # Cursor-based pagination using (created_at, id) to avoid infinite loops + # and ensure proper ordering with time-based filtering + _cursor: tuple[datetime.datetime, str] | None = None + + logger.info( + "clean_messages: start cleaning messages (dry_run=%s), start_from=%s, end_before=%s", + self._dry_run, + self._start_from, + self._end_before, + ) + + while True: + stats["batches"] += 1 + + # Step 1: Fetch a batch of messages using cursor + with Session(db.engine, expire_on_commit=False) as session: + msg_stmt = ( + select(Message.id, Message.app_id, Message.created_at) + .where(Message.created_at < self._end_before) + .order_by(Message.created_at, Message.id) + .limit(self._batch_size) + ) + + if self._start_from: + msg_stmt = msg_stmt.where(Message.created_at >= self._start_from) + + # Apply cursor condition: (created_at, id) > (last_created_at, last_message_id) + # This translates to: + # created_at > last_created_at OR (created_at = last_created_at AND id > last_message_id) + if _cursor: + # Continuing from previous batch + msg_stmt = msg_stmt.where( + (Message.created_at > _cursor[0]) + | ((Message.created_at == _cursor[0]) & (Message.id > _cursor[1])) + ) + + raw_messages = list(session.execute(msg_stmt).all()) + messages = [ + SimpleMessage(id=msg_id, app_id=app_id, created_at=msg_created_at) + for msg_id, app_id, msg_created_at in raw_messages + ] + + # Track total messages fetched across all batches + stats["total_messages"] += len(messages) + + if not messages: + logger.info("clean_messages (batch %s): no more messages to process", stats["batches"]) + break + + # Update cursor to the last message's (created_at, id) + _cursor = (messages[-1].created_at, messages[-1].id) + + # Step 2: Extract app_ids and query tenant_ids + 
app_ids = list({msg.app_id for msg in messages}) + + if not app_ids: + logger.info("clean_messages (batch %s): no app_ids found, skip", stats["batches"]) + continue + + app_stmt = select(App.id, App.tenant_id).where(App.id.in_(app_ids)) + apps = list(session.execute(app_stmt).all()) + + if not apps: + logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"]) + continue + + # Build app_id -> tenant_id mapping + app_to_tenant: dict[str, str] = {app.id: app.tenant_id for app in apps} + + # Step 3: Delegate to policy to determine which messages to delete + message_ids_to_delete = self._policy.filter_message_ids(messages, app_to_tenant) + + if not message_ids_to_delete: + logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"]) + continue + + stats["filtered_messages"] += len(message_ids_to_delete) + + # Step 4: Batch delete messages and their relations + if not self._dry_run: + with Session(db.engine, expire_on_commit=False) as session: + # Delete related records first + self._batch_delete_message_relations(session, message_ids_to_delete) + + # Delete messages + delete_stmt = delete(Message).where(Message.id.in_(message_ids_to_delete)) + delete_result = cast(CursorResult, session.execute(delete_stmt)) + messages_deleted = delete_result.rowcount + session.commit() + + stats["total_deleted"] += messages_deleted + + logger.info( + "clean_messages (batch %s): processed %s messages, deleted %s messages", + stats["batches"], + len(messages), + messages_deleted, + ) + else: + # Log random sample of message IDs that would be deleted (up to 10) + sample_size = min(10, len(message_ids_to_delete)) + sampled_ids = random.sample(list(message_ids_to_delete), sample_size) + + logger.info( + "clean_messages (batch %s, dry_run): would delete %s messages, sampling %s ids:", + stats["batches"], + len(message_ids_to_delete), + sample_size, + ) + for msg_id in sampled_ids: + logger.info("clean_messages (batch %s, dry_run) sample: 
message_id=%s", stats["batches"], msg_id) + + logger.info( + "clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s", + stats["batches"], + stats["total_messages"], + stats["filtered_messages"], + stats["total_deleted"], + ) + + return stats + + @staticmethod + def _batch_delete_message_relations(session: Session, message_ids: Sequence[str]) -> None: + """ + Batch delete all related records for given message IDs. + + Args: + session: Database session + message_ids: List of message IDs to delete relations for + """ + if not message_ids: + return + + # Delete all related records in batch + session.execute(delete(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids))) + + session.execute(delete(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_ids))) + + session.execute(delete(MessageChain).where(MessageChain.message_id.in_(message_ids))) + + session.execute(delete(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_ids))) + + session.execute(delete(MessageFile).where(MessageFile.message_id.in_(message_ids))) + + session.execute(delete(SavedMessage).where(SavedMessage.message_id.in_(message_ids))) + + session.execute(delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id.in_(message_ids))) + + session.execute(delete(DatasetRetrieverResource).where(DatasetRetrieverResource.message_id.in_(message_ids))) diff --git a/api/services/retention/workflow_run/__init__.py b/api/services/retention/workflow_run/__init__.py new file mode 100644 index 0000000000..18dd42c91e --- /dev/null +++ b/api/services/retention/workflow_run/__init__.py @@ -0,0 +1 @@ +"""Workflow run retention services.""" diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py new file mode 100644 index 0000000000..ea5cbb7740 --- /dev/null +++ 
b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py @@ -0,0 +1,531 @@ +""" +Archive Paid Plan Workflow Run Logs Service. + +This service archives workflow run logs for paid plan users older than the configured +retention period (default: 90 days) to S3-compatible storage. + +Archived tables: +- workflow_runs +- workflow_app_logs +- workflow_node_executions +- workflow_node_execution_offload +- workflow_pauses +- workflow_pause_reasons +- workflow_trigger_logs + +""" + +import datetime +import io +import json +import logging +import time +import zipfile +from collections.abc import Sequence +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import Any + +import click +from sqlalchemy import inspect +from sqlalchemy.orm import Session, sessionmaker + +from configs import dify_config +from core.workflow.enums import WorkflowType +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from libs.archive_storage import ( + ArchiveStorage, + ArchiveStorageNotConfiguredError, + get_archive_storage, +) +from models.workflow import WorkflowAppLog, WorkflowRun +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository +from services.billing_service import BillingService +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME, ARCHIVE_SCHEMA_VERSION + +logger = logging.getLogger(__name__) + + +@dataclass +class TableStats: + """Statistics for a single archived table.""" + + table_name: str + row_count: int + checksum: str + size_bytes: int + + +@dataclass +class ArchiveResult: + """Result of archiving a single workflow run.""" + + run_id: str + tenant_id: str + success: bool + tables: list[TableStats] = field(default_factory=list) + 
error: str | None = None + elapsed_time: float = 0.0 + + +@dataclass +class ArchiveSummary: + """Summary of the entire archive operation.""" + + total_runs_processed: int = 0 + runs_archived: int = 0 + runs_skipped: int = 0 + runs_failed: int = 0 + total_elapsed_time: float = 0.0 + + +class WorkflowRunArchiver: + """ + Archive workflow run logs for paid plan users. + + Storage Layout: + {tenant_id}/app_id={app_id}/year={YYYY}/month={MM}/workflow_run_id={run_id}/ + └── archive.v1.0.zip + ├── manifest.json + ├── workflow_runs.jsonl + ├── workflow_app_logs.jsonl + ├── workflow_node_executions.jsonl + ├── workflow_node_execution_offload.jsonl + ├── workflow_pauses.jsonl + ├── workflow_pause_reasons.jsonl + └── workflow_trigger_logs.jsonl + """ + + ARCHIVED_TYPE = [ + WorkflowType.WORKFLOW, + WorkflowType.RAG_PIPELINE, + ] + ARCHIVED_TABLES = [ + "workflow_runs", + "workflow_app_logs", + "workflow_node_executions", + "workflow_node_execution_offload", + "workflow_pauses", + "workflow_pause_reasons", + "workflow_trigger_logs", + ] + + start_from: datetime.datetime | None + end_before: datetime.datetime + + def __init__( + self, + days: int = 90, + batch_size: int = 100, + start_from: datetime.datetime | None = None, + end_before: datetime.datetime | None = None, + workers: int = 1, + tenant_ids: Sequence[str] | None = None, + limit: int | None = None, + dry_run: bool = False, + delete_after_archive: bool = False, + workflow_run_repo: APIWorkflowRunRepository | None = None, + ): + """ + Initialize the archiver. 
+ + Args: + days: Archive runs older than this many days + batch_size: Number of runs to process per batch + start_from: Optional start time (inclusive) for archiving + end_before: Optional end time (exclusive) for archiving + workers: Number of concurrent workflow runs to archive + tenant_ids: Optional tenant IDs for grayscale rollout + limit: Maximum number of runs to archive (None for unlimited) + dry_run: If True, only preview without making changes + delete_after_archive: If True, delete runs and related data after archiving + """ + self.days = days + self.batch_size = batch_size + if start_from or end_before: + if start_from is None or end_before is None: + raise ValueError("start_from and end_before must be provided together") + if start_from >= end_before: + raise ValueError("start_from must be earlier than end_before") + self.start_from = start_from.replace(tzinfo=datetime.UTC) + self.end_before = end_before.replace(tzinfo=datetime.UTC) + else: + self.start_from = None + self.end_before = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days) + if workers < 1: + raise ValueError("workers must be at least 1") + self.workers = workers + self.tenant_ids = sorted(set(tenant_ids)) if tenant_ids else [] + self.limit = limit + self.dry_run = dry_run + self.delete_after_archive = delete_after_archive + self.workflow_run_repo = workflow_run_repo + + def run(self) -> ArchiveSummary: + """ + Main archiving loop. 
+ + Returns: + ArchiveSummary with statistics about the operation + """ + summary = ArchiveSummary() + start_time = time.time() + + click.echo( + click.style( + self._build_start_message(), + fg="white", + ) + ) + + # Initialize archive storage (will raise if not configured) + try: + if not self.dry_run: + storage = get_archive_storage() + else: + storage = None + except ArchiveStorageNotConfiguredError as e: + click.echo(click.style(f"Archive storage not configured: {e}", fg="red")) + return summary + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + repo = self._get_workflow_run_repo() + + def _archive_with_session(run: WorkflowRun) -> ArchiveResult: + with session_maker() as session: + return self._archive_run(session, storage, run) + + last_seen: tuple[datetime.datetime, str] | None = None + archived_count = 0 + + with ThreadPoolExecutor(max_workers=self.workers) as executor: + while True: + # Check limit + if self.limit and archived_count >= self.limit: + click.echo(click.style(f"Reached limit of {self.limit} runs", fg="yellow")) + break + + # Fetch batch of runs + runs = self._get_runs_batch(last_seen) + + if not runs: + break + + run_ids = [run.id for run in runs] + with session_maker() as session: + archived_run_ids = repo.get_archived_run_ids(session, run_ids) + + last_seen = (runs[-1].created_at, runs[-1].id) + + # Filter to paid tenants only + tenant_ids = {run.tenant_id for run in runs} + paid_tenants = self._filter_paid_tenants(tenant_ids) + + runs_to_process: list[WorkflowRun] = [] + for run in runs: + summary.total_runs_processed += 1 + + # Skip non-paid tenants + if run.tenant_id not in paid_tenants: + summary.runs_skipped += 1 + continue + + # Skip already archived runs + if run.id in archived_run_ids: + summary.runs_skipped += 1 + continue + + # Check limit + if self.limit and archived_count + len(runs_to_process) >= self.limit: + break + + runs_to_process.append(run) + + if not runs_to_process: + continue + + results = 
list(executor.map(_archive_with_session, runs_to_process)) + + for run, result in zip(runs_to_process, results): + if result.success: + summary.runs_archived += 1 + archived_count += 1 + click.echo( + click.style( + f"{'[DRY RUN] Would archive' if self.dry_run else 'Archived'} " + f"run {run.id} (tenant={run.tenant_id}, " + f"tables={len(result.tables)}, time={result.elapsed_time:.2f}s)", + fg="green", + ) + ) + else: + summary.runs_failed += 1 + click.echo( + click.style( + f"Failed to archive run {run.id}: {result.error}", + fg="red", + ) + ) + + summary.total_elapsed_time = time.time() - start_time + click.echo( + click.style( + f"{'[DRY RUN] ' if self.dry_run else ''}Archive complete: " + f"processed={summary.total_runs_processed}, archived={summary.runs_archived}, " + f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, " + f"time={summary.total_elapsed_time:.2f}s", + fg="white", + ) + ) + + return summary + + def _get_runs_batch( + self, + last_seen: tuple[datetime.datetime, str] | None, + ) -> Sequence[WorkflowRun]: + """Fetch a batch of workflow runs to archive.""" + repo = self._get_workflow_run_repo() + return repo.get_runs_batch_by_time_range( + start_from=self.start_from, + end_before=self.end_before, + last_seen=last_seen, + batch_size=self.batch_size, + run_types=self.ARCHIVED_TYPE, + tenant_ids=self.tenant_ids or None, + ) + + def _build_start_message(self) -> str: + range_desc = f"before {self.end_before.isoformat()}" + if self.start_from: + range_desc = f"between {self.start_from.isoformat()} and {self.end_before.isoformat()}" + return ( + f"{'[DRY RUN] ' if self.dry_run else ''}Starting workflow run archiving " + f"for runs {range_desc} " + f"(batch_size={self.batch_size}, tenant_ids={','.join(self.tenant_ids) or 'all'})" + ) + + def _filter_paid_tenants(self, tenant_ids: set[str]) -> set[str]: + """Filter tenant IDs to only include paid tenants.""" + if not dify_config.BILLING_ENABLED: + # If billing is not enabled, treat all tenants 
as paid + return tenant_ids + + if not tenant_ids: + return set() + + try: + bulk_info = BillingService.get_plan_bulk_with_cache(list(tenant_ids)) + except Exception: + logger.exception("Failed to fetch billing plans for tenants") + # On error, skip all tenants in this batch + return set() + + # Filter to paid tenants (PROFESSIONAL or TEAM plans) + paid = set() + for tid, info in bulk_info.items(): + if info and info.get("plan") in (CloudPlan.PROFESSIONAL, CloudPlan.TEAM): + paid.add(tid) + + return paid + + def _archive_run( + self, + session: Session, + storage: ArchiveStorage | None, + run: WorkflowRun, + ) -> ArchiveResult: + """Archive a single workflow run.""" + start_time = time.time() + result = ArchiveResult(run_id=run.id, tenant_id=run.tenant_id, success=False) + + try: + # Extract data from all tables + table_data, app_logs, trigger_metadata = self._extract_data(session, run) + + if self.dry_run: + # In dry run, just report what would be archived + for table_name in self.ARCHIVED_TABLES: + records = table_data.get(table_name, []) + result.tables.append( + TableStats( + table_name=table_name, + row_count=len(records), + checksum="", + size_bytes=0, + ) + ) + result.success = True + else: + if storage is None: + raise ArchiveStorageNotConfiguredError("Archive storage not configured") + archive_key = self._get_archive_key(run) + + # Serialize tables for the archive bundle + table_stats: list[TableStats] = [] + table_payloads: dict[str, bytes] = {} + for table_name in self.ARCHIVED_TABLES: + records = table_data.get(table_name, []) + data = ArchiveStorage.serialize_to_jsonl(records) + table_payloads[table_name] = data + checksum = ArchiveStorage.compute_checksum(data) + + table_stats.append( + TableStats( + table_name=table_name, + row_count=len(records), + checksum=checksum, + size_bytes=len(data), + ) + ) + + # Generate and upload archive bundle + manifest = self._generate_manifest(run, table_stats) + manifest_data = json.dumps(manifest, indent=2,
default=str).encode("utf-8") + archive_data = self._build_archive_bundle(manifest_data, table_payloads) + storage.put_object(archive_key, archive_data) + + repo = self._get_workflow_run_repo() + archived_log_count = repo.create_archive_logs(session, run, app_logs, trigger_metadata) + session.commit() + + deleted_counts = None + if self.delete_after_archive: + deleted_counts = repo.delete_runs_with_related( + [run], + delete_node_executions=self._delete_node_executions, + delete_trigger_logs=self._delete_trigger_logs, + ) + + logger.info( + "Archived workflow run %s: tables=%s, archived_logs=%s, deleted=%s", + run.id, + {s.table_name: s.row_count for s in table_stats}, + archived_log_count, + deleted_counts, + ) + + result.tables = table_stats + result.success = True + + except Exception as e: + logger.exception("Failed to archive workflow run %s", run.id) + result.error = str(e) + session.rollback() + + result.elapsed_time = time.time() - start_time + return result + + def _extract_data( + self, + session: Session, + run: WorkflowRun, + ) -> tuple[dict[str, list[dict[str, Any]]], Sequence[WorkflowAppLog], str | None]: + table_data: dict[str, list[dict[str, Any]]] = {} + table_data["workflow_runs"] = [self._row_to_dict(run)] + repo = self._get_workflow_run_repo() + app_logs = repo.get_app_logs_by_run_id(session, run.id) + table_data["workflow_app_logs"] = [self._row_to_dict(row) for row in app_logs] + node_exec_repo = self._get_workflow_node_execution_repo(session) + node_exec_records = node_exec_repo.get_executions_by_workflow_run( + tenant_id=run.tenant_id, + app_id=run.app_id, + workflow_run_id=run.id, + ) + node_exec_ids = [record.id for record in node_exec_records] + offload_records = node_exec_repo.get_offloads_by_execution_ids(session, node_exec_ids) + table_data["workflow_node_executions"] = [self._row_to_dict(row) for row in node_exec_records] + table_data["workflow_node_execution_offload"] = [self._row_to_dict(row) for row in offload_records] + repo = 
self._get_workflow_run_repo() + pause_records = repo.get_pause_records_by_run_id(session, run.id) + pause_ids = [pause.id for pause in pause_records] + pause_reason_records = repo.get_pause_reason_records_by_run_id( + session, + pause_ids, + ) + table_data["workflow_pauses"] = [self._row_to_dict(row) for row in pause_records] + table_data["workflow_pause_reasons"] = [self._row_to_dict(row) for row in pause_reason_records] + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + trigger_records = trigger_repo.list_by_run_id(run.id) + table_data["workflow_trigger_logs"] = [self._row_to_dict(row) for row in trigger_records] + trigger_metadata = trigger_records[0].trigger_metadata if trigger_records else None + return table_data, app_logs, trigger_metadata + + @staticmethod + def _row_to_dict(row: Any) -> dict[str, Any]: + mapper = inspect(row).mapper + return {str(column.name): getattr(row, mapper.get_property_by_column(column).key) for column in mapper.columns} + + def _get_archive_key(self, run: WorkflowRun) -> str: + """Get the storage key for the archive bundle.""" + created_at = run.created_at + prefix = ( + f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/" + f"month={created_at.strftime('%m')}/workflow_run_id={run.id}" + ) + return f"{prefix}/{ARCHIVE_BUNDLE_NAME}" + + def _generate_manifest( + self, + run: WorkflowRun, + table_stats: list[TableStats], + ) -> dict[str, Any]: + """Generate a manifest for the archived workflow run.""" + return { + "schema_version": ARCHIVE_SCHEMA_VERSION, + "workflow_run_id": run.id, + "tenant_id": run.tenant_id, + "app_id": run.app_id, + "workflow_id": run.workflow_id, + "created_at": run.created_at.isoformat(), + "archived_at": datetime.datetime.now(datetime.UTC).isoformat(), + "tables": { + stat.table_name: { + "row_count": stat.row_count, + "checksum": stat.checksum, + "size_bytes": stat.size_bytes, + } + for stat in table_stats + }, + } + + def _build_archive_bundle(self, manifest_data: bytes, 
table_payloads: dict[str, bytes]) -> bytes: + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive: + archive.writestr("manifest.json", manifest_data) + for table_name in self.ARCHIVED_TABLES: + data = table_payloads.get(table_name) + if data is None: + raise ValueError(f"Missing archive payload for {table_name}") + archive.writestr(f"{table_name}.jsonl", data) + return buffer.getvalue() + + def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int: + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + return trigger_repo.delete_by_run_ids(run_ids) + + def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]: + run_ids = [run.id for run in runs] + return self._get_workflow_node_execution_repo(session).delete_by_runs(session, run_ids) + + def _get_workflow_node_execution_repo( + self, + session: Session, + ) -> DifyAPIWorkflowNodeExecutionRepository: + from repositories.factory import DifyAPIRepositoryFactory + + session_maker = sessionmaker(bind=session.get_bind(), expire_on_commit=False) + return DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) + + def _get_workflow_run_repo(self) -> APIWorkflowRunRepository: + if self.workflow_run_repo is not None: + return self.workflow_run_repo + + from repositories.factory import DifyAPIRepositoryFactory + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + return self.workflow_run_repo diff --git a/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py new file mode 100644 index 0000000000..c3e0dce399 --- /dev/null +++ b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py @@ -0,0 +1,293 @@ +import 
datetime +import logging +from collections.abc import Iterable, Sequence + +import click +from sqlalchemy.orm import Session, sessionmaker + +from configs import dify_config +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from models.workflow import WorkflowRun +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.factory import DifyAPIRepositoryFactory +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository +from services.billing_service import BillingService, SubscriptionPlan + +logger = logging.getLogger(__name__) + + +class WorkflowRunCleanup: + def __init__( + self, + days: int, + batch_size: int, + start_from: datetime.datetime | None = None, + end_before: datetime.datetime | None = None, + workflow_run_repo: APIWorkflowRunRepository | None = None, + dry_run: bool = False, + ): + if (start_from is None) ^ (end_before is None): + raise ValueError("start_from and end_before must be both set or both omitted.") + + computed_cutoff = datetime.datetime.now() - datetime.timedelta(days=days) + self.window_start = start_from + self.window_end = end_before or computed_cutoff + + if self.window_start and self.window_end <= self.window_start: + raise ValueError("end_before must be greater than start_from.") + + if batch_size <= 0: + raise ValueError("batch_size must be greater than 0.") + + self.batch_size = batch_size + self._cleanup_whitelist: set[str] | None = None + self.dry_run = dry_run + self.free_plan_grace_period_days = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD + self.workflow_run_repo: APIWorkflowRunRepository + if workflow_run_repo: + self.workflow_run_repo = workflow_run_repo + else: + # Lazy import to avoid circular dependencies during module import + from repositories.factory import DifyAPIRepositoryFactory + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self.workflow_run_repo = 
DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + + def run(self) -> None: + click.echo( + click.style( + f"{'Inspecting' if self.dry_run else 'Cleaning'} workflow runs " + f"{'between ' + self.window_start.isoformat() + ' and ' if self.window_start else 'before '}" + f"{self.window_end.isoformat()} (batch={self.batch_size})", + fg="white", + ) + ) + if self.dry_run: + click.echo(click.style("Dry run mode enabled. No data will be deleted.", fg="yellow")) + + total_runs_deleted = 0 + total_runs_targeted = 0 + related_totals = self._empty_related_counts() if self.dry_run else None + batch_index = 0 + last_seen: tuple[datetime.datetime, str] | None = None + + while True: + run_rows = self.workflow_run_repo.get_runs_batch_by_time_range( + start_from=self.window_start, + end_before=self.window_end, + last_seen=last_seen, + batch_size=self.batch_size, + ) + if not run_rows: + break + + batch_index += 1 + last_seen = (run_rows[-1].created_at, run_rows[-1].id) + tenant_ids = {row.tenant_id for row in run_rows} + free_tenants = self._filter_free_tenants(tenant_ids) + free_runs = [row for row in run_rows if row.tenant_id in free_tenants] + paid_or_skipped = len(run_rows) - len(free_runs) + + if not free_runs: + skipped_message = ( + f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)" + ) + click.echo( + click.style( + skipped_message, + fg="yellow", + ) + ) + continue + + total_runs_targeted += len(free_runs) + + if self.dry_run: + batch_counts = self.workflow_run_repo.count_runs_with_related( + free_runs, + count_node_executions=self._count_node_executions, + count_trigger_logs=self._count_trigger_logs, + ) + if related_totals is not None: + for key in related_totals: + related_totals[key] += batch_counts.get(key, 0) + sample_ids = ", ".join(run.id for run in free_runs[:5]) + click.echo( + click.style( + f"[batch #{batch_index}] would delete {len(free_runs)} runs " + f"(sample ids: {sample_ids}) and skip 
{paid_or_skipped} paid/unknown", + fg="yellow", + ) + ) + continue + + try: + counts = self.workflow_run_repo.delete_runs_with_related( + free_runs, + delete_node_executions=self._delete_node_executions, + delete_trigger_logs=self._delete_trigger_logs, + ) + except Exception: + logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0]) + raise + + total_runs_deleted += counts["runs"] + click.echo( + click.style( + f"[batch #{batch_index}] deleted runs: {counts['runs']} " + f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, " + f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, " + f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); " + f"skipped {paid_or_skipped} paid/unknown", + fg="green", + ) + ) + + if self.dry_run: + if self.window_start: + summary_message = ( + f"Dry run complete. Would delete {total_runs_targeted} workflow runs " + f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}" + ) + else: + summary_message = ( + f"Dry run complete. Would delete {total_runs_targeted} workflow runs " + f"before {self.window_end.isoformat()}" + ) + if related_totals is not None: + summary_message = f"{summary_message}; related records: {self._format_related_counts(related_totals)}" + summary_color = "yellow" + else: + if self.window_start: + summary_message = ( + f"Cleanup complete. Deleted {total_runs_deleted} workflow runs " + f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}" + ) + else: + summary_message = ( + f"Cleanup complete. 
Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}" + ) + summary_color = "white" + + click.echo(click.style(summary_message, fg=summary_color)) + + def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]: + tenant_id_list = list(tenant_ids) + + if not dify_config.BILLING_ENABLED: + return set(tenant_id_list) + + if not tenant_id_list: + return set() + + cleanup_whitelist = self._get_cleanup_whitelist() + + try: + bulk_info = BillingService.get_plan_bulk_with_cache(tenant_id_list) + except Exception: + bulk_info = {} + logger.exception("Failed to fetch billing plans in bulk for tenants: %s", tenant_id_list) + + eligible_free_tenants: set[str] = set() + for tenant_id in tenant_id_list: + if tenant_id in cleanup_whitelist: + continue + + info = bulk_info.get(tenant_id) + if info is None: + logger.warning("Missing billing info for tenant %s in bulk resp; treating as non-free", tenant_id) + continue + + if info.get("plan") != CloudPlan.SANDBOX: + continue + + if self._is_within_grace_period(tenant_id, info): + continue + + eligible_free_tenants.add(tenant_id) + + return eligible_free_tenants + + def _expiration_datetime(self, tenant_id: str, expiration_value: int) -> datetime.datetime | None: + if expiration_value < 0: + return None + + try: + return datetime.datetime.fromtimestamp(expiration_value, datetime.UTC) + except (OverflowError, OSError, ValueError): + logger.exception("Failed to parse expiration timestamp for tenant %s", tenant_id) + return None + + def _is_within_grace_period(self, tenant_id: str, info: SubscriptionPlan) -> bool: + if self.free_plan_grace_period_days <= 0: + return False + + expiration_value = info.get("expiration_date", -1) + expiration_at = self._expiration_datetime(tenant_id, expiration_value) + if expiration_at is None: + return False + + grace_deadline = expiration_at + datetime.timedelta(days=self.free_plan_grace_period_days) + return datetime.datetime.now(datetime.UTC) < grace_deadline + + 
def _get_cleanup_whitelist(self) -> set[str]: + if self._cleanup_whitelist is not None: + return self._cleanup_whitelist + + if not dify_config.BILLING_ENABLED: + self._cleanup_whitelist = set() + return self._cleanup_whitelist + + try: + whitelist_ids = BillingService.get_expired_subscription_cleanup_whitelist() + except Exception: + logger.exception("Failed to fetch cleanup whitelist from billing service") + whitelist_ids = [] + + self._cleanup_whitelist = set(whitelist_ids) + return self._cleanup_whitelist + + def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int: + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + return trigger_repo.delete_by_run_ids(run_ids) + + def _count_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int: + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + return trigger_repo.count_by_run_ids(run_ids) + + @staticmethod + def _empty_related_counts() -> dict[str, int]: + return { + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + + @staticmethod + def _format_related_counts(counts: dict[str, int]) -> str: + return ( + f"node_executions {counts['node_executions']}, " + f"offloads {counts['offloads']}, " + f"app_logs {counts['app_logs']}, " + f"trigger_logs {counts['trigger_logs']}, " + f"pauses {counts['pauses']}, " + f"pause_reasons {counts['pause_reasons']}" + ) + + def _count_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]: + run_ids = [run.id for run in runs] + repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False) + ) + return repo.count_by_runs(session, run_ids) + + def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]: + run_ids = [run.id for run in runs] + repo = 
DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False) + ) + return repo.delete_by_runs(session, run_ids) diff --git a/api/services/retention/workflow_run/constants.py b/api/services/retention/workflow_run/constants.py new file mode 100644 index 0000000000..162bb4947d --- /dev/null +++ b/api/services/retention/workflow_run/constants.py @@ -0,0 +1,2 @@ +ARCHIVE_SCHEMA_VERSION = "1.0" +ARCHIVE_BUNDLE_NAME = f"archive.v{ARCHIVE_SCHEMA_VERSION}.zip" diff --git a/api/services/retention/workflow_run/delete_archived_workflow_run.py b/api/services/retention/workflow_run/delete_archived_workflow_run.py new file mode 100644 index 0000000000..11873bf1b9 --- /dev/null +++ b/api/services/retention/workflow_run/delete_archived_workflow_run.py @@ -0,0 +1,134 @@ +""" +Delete Archived Workflow Run Service. + +This service deletes archived workflow run data from the database while keeping +archive logs intact. +""" + +import time +from collections.abc import Sequence +from dataclasses import dataclass, field +from datetime import datetime + +from sqlalchemy.orm import Session, sessionmaker + +from extensions.ext_database import db +from models.workflow import WorkflowRun +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository + + +@dataclass +class DeleteResult: + run_id: str + tenant_id: str + success: bool + deleted_counts: dict[str, int] = field(default_factory=dict) + error: str | None = None + elapsed_time: float = 0.0 + + +class ArchivedWorkflowRunDeletion: + def __init__(self, dry_run: bool = False): + self.dry_run = dry_run + self.workflow_run_repo: APIWorkflowRunRepository | None = None + + def delete_by_run_id(self, run_id: str) -> DeleteResult: + start_time = time.time() + result = DeleteResult(run_id=run_id, tenant_id="", success=False) + + repo = 
self._get_workflow_run_repo() + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + with session_maker() as session: + run = session.get(WorkflowRun, run_id) + if not run: + result.error = f"Workflow run {run_id} not found" + result.elapsed_time = time.time() - start_time + return result + + result.tenant_id = run.tenant_id + if not repo.get_archived_run_ids(session, [run.id]): + result.error = f"Workflow run {run_id} is not archived" + result.elapsed_time = time.time() - start_time + return result + + result = self._delete_run(run) + result.elapsed_time = time.time() - start_time + return result + + def delete_batch( + self, + tenant_ids: list[str] | None, + start_date: datetime, + end_date: datetime, + limit: int = 100, + ) -> list[DeleteResult]: + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + results: list[DeleteResult] = [] + + repo = self._get_workflow_run_repo() + with session_maker() as session: + runs = list( + repo.get_archived_runs_by_time_range( + session=session, + tenant_ids=tenant_ids, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + ) + for run in runs: + results.append(self._delete_run(run)) + + return results + + def _delete_run(self, run: WorkflowRun) -> DeleteResult: + start_time = time.time() + result = DeleteResult(run_id=run.id, tenant_id=run.tenant_id, success=False) + if self.dry_run: + result.success = True + result.elapsed_time = time.time() - start_time + return result + + repo = self._get_workflow_run_repo() + try: + deleted_counts = repo.delete_runs_with_related( + [run], + delete_node_executions=self._delete_node_executions, + delete_trigger_logs=self._delete_trigger_logs, + ) + result.deleted_counts = deleted_counts + result.success = True + except Exception as e: + result.error = str(e) + result.elapsed_time = time.time() - start_time + return result + + @staticmethod + def _delete_trigger_logs(session: Session, run_ids: Sequence[str]) -> int: + trigger_repo = 
SQLAlchemyWorkflowTriggerLogRepository(session) + return trigger_repo.delete_by_run_ids(run_ids) + + @staticmethod + def _delete_node_executions( + session: Session, + runs: Sequence[WorkflowRun], + ) -> tuple[int, int]: + from repositories.factory import DifyAPIRepositoryFactory + + run_ids = [run.id for run in runs] + repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False) + ) + return repo.delete_by_runs(session, run_ids) + + def _get_workflow_run_repo(self) -> APIWorkflowRunRepository: + if self.workflow_run_repo is not None: + return self.workflow_run_repo + + from repositories.factory import DifyAPIRepositoryFactory + + self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository( + sessionmaker(bind=db.engine, expire_on_commit=False) + ) + return self.workflow_run_repo diff --git a/api/services/retention/workflow_run/restore_archived_workflow_run.py b/api/services/retention/workflow_run/restore_archived_workflow_run.py new file mode 100644 index 0000000000..d4a6e87585 --- /dev/null +++ b/api/services/retention/workflow_run/restore_archived_workflow_run.py @@ -0,0 +1,481 @@ +""" +Restore Archived Workflow Run Service. + +This service restores archived workflow run data from S3-compatible storage +back to the database. 
+""" + +import io +import json +import logging +import time +import zipfile +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from datetime import datetime +from typing import Any, cast + +import click +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.engine import CursorResult +from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker + +from extensions.ext_database import db +from libs.archive_storage import ( + ArchiveStorage, + ArchiveStorageNotConfiguredError, + get_archive_storage, +) +from models.trigger import WorkflowTriggerLog +from models.workflow import ( + WorkflowAppLog, + WorkflowArchiveLog, + WorkflowNodeExecutionModel, + WorkflowNodeExecutionOffload, + WorkflowPause, + WorkflowPauseReason, + WorkflowRun, +) +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.factory import DifyAPIRepositoryFactory +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME + +logger = logging.getLogger(__name__) + + +# Mapping of table names to SQLAlchemy models +TABLE_MODELS = { + "workflow_runs": WorkflowRun, + "workflow_app_logs": WorkflowAppLog, + "workflow_node_executions": WorkflowNodeExecutionModel, + "workflow_node_execution_offload": WorkflowNodeExecutionOffload, + "workflow_pauses": WorkflowPause, + "workflow_pause_reasons": WorkflowPauseReason, + "workflow_trigger_logs": WorkflowTriggerLog, +} + +SchemaMapper = Callable[[dict[str, Any]], dict[str, Any]] + +SCHEMA_MAPPERS: dict[str, dict[str, SchemaMapper]] = { + "1.0": {}, +} + + +@dataclass +class RestoreResult: + """Result of restoring a single workflow run.""" + + run_id: str + tenant_id: str + success: bool + restored_counts: dict[str, int] + error: str | None = None + elapsed_time: float = 0.0 + + +class WorkflowRunRestore: + """ + Restore archived workflow run data from storage to database. 
+ + This service reads archived data from storage and restores it to the + database tables. It handles idempotency by skipping records that already + exist in the database. + """ + + def __init__(self, dry_run: bool = False, workers: int = 1): + """ + Initialize the restore service. + + Args: + dry_run: If True, only preview without making changes + workers: Number of concurrent workflow runs to restore + """ + self.dry_run = dry_run + if workers < 1: + raise ValueError("workers must be at least 1") + self.workers = workers + self.workflow_run_repo: APIWorkflowRunRepository | None = None + + def _restore_from_run( + self, + run: WorkflowRun | WorkflowArchiveLog, + *, + session_maker: sessionmaker, + ) -> RestoreResult: + start_time = time.time() + run_id = run.workflow_run_id if isinstance(run, WorkflowArchiveLog) else run.id + created_at = run.run_created_at if isinstance(run, WorkflowArchiveLog) else run.created_at + result = RestoreResult( + run_id=run_id, + tenant_id=run.tenant_id, + success=False, + restored_counts={}, + ) + + if not self.dry_run: + click.echo( + click.style( + f"Starting restore for workflow run {run_id} (tenant={run.tenant_id})", + fg="white", + ) + ) + + try: + storage = get_archive_storage() + except ArchiveStorageNotConfiguredError as e: + result.error = str(e) + click.echo(click.style(f"Archive storage not configured: {e}", fg="red")) + result.elapsed_time = time.time() - start_time + return result + + prefix = ( + f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/" + f"month={created_at.strftime('%m')}/workflow_run_id={run_id}" + ) + archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}" + try: + archive_data = storage.get_object(archive_key) + except FileNotFoundError: + result.error = f"Archive bundle not found: {archive_key}" + click.echo(click.style(result.error, fg="red")) + result.elapsed_time = time.time() - start_time + return result + + with session_maker() as session: + try: + with 
zipfile.ZipFile(io.BytesIO(archive_data), mode="r") as archive: + try: + manifest = self._load_manifest_from_zip(archive) + except ValueError as e: + result.error = f"Archive bundle invalid: {e}" + click.echo(click.style(result.error, fg="red")) + return result + + tables = manifest.get("tables", {}) + schema_version = self._get_schema_version(manifest) + for table_name, info in tables.items(): + row_count = info.get("row_count", 0) + if row_count == 0: + result.restored_counts[table_name] = 0 + continue + + if self.dry_run: + result.restored_counts[table_name] = row_count + continue + + member_path = f"{table_name}.jsonl" + try: + data = archive.read(member_path) + except KeyError: + click.echo( + click.style( + f" Warning: Table data not found in archive: {member_path}", + fg="yellow", + ) + ) + result.restored_counts[table_name] = 0 + continue + + records = ArchiveStorage.deserialize_from_jsonl(data) + restored = self._restore_table_records( + session, + table_name, + records, + schema_version=schema_version, + ) + result.restored_counts[table_name] = restored + if not self.dry_run: + click.echo( + click.style( + f" Restored {restored}/{len(records)} records to {table_name}", + fg="white", + ) + ) + + # Verify row counts match manifest + manifest_total = sum(info.get("row_count", 0) for info in tables.values()) + restored_total = sum(result.restored_counts.values()) + + if not self.dry_run: + # Note: restored count might be less than manifest count if records already exist + logger.info( + "Restore verification: manifest_total=%d, restored_total=%d", + manifest_total, + restored_total, + ) + + # Delete the archive log record after successful restore + repo = self._get_workflow_run_repo() + repo.delete_archive_log_by_run_id(session, run_id) + + session.commit() + + result.success = True + if not self.dry_run: + click.echo( + click.style( + f"Completed restore for workflow run {run_id}: restored={result.restored_counts}", + fg="green", + ) + ) + + except Exception 
as e: + logger.exception("Failed to restore workflow run %s", run_id) + result.error = str(e) + session.rollback() + click.echo(click.style(f"Restore failed: {e}", fg="red")) + + result.elapsed_time = time.time() - start_time + return result + + def _get_workflow_run_repo(self) -> APIWorkflowRunRepository: + if self.workflow_run_repo is not None: + return self.workflow_run_repo + + self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository( + sessionmaker(bind=db.engine, expire_on_commit=False) + ) + return self.workflow_run_repo + + @staticmethod + def _load_manifest_from_zip(archive: zipfile.ZipFile) -> dict[str, Any]: + try: + data = archive.read("manifest.json") + except KeyError as e: + raise ValueError("manifest.json missing from archive bundle") from e + return json.loads(data.decode("utf-8")) + + def _restore_table_records( + self, + session: Session, + table_name: str, + records: list[dict[str, Any]], + *, + schema_version: str, + ) -> int: + """ + Restore records to a table. + + Uses INSERT ... ON CONFLICT DO NOTHING for idempotency. 
+ + Args: + session: Database session + table_name: Name of the table + records: List of record dictionaries + schema_version: Archived schema version from manifest + + Returns: + Number of records actually inserted + """ + if not records: + return 0 + + model = TABLE_MODELS.get(table_name) + if not model: + logger.warning("Unknown table: %s", table_name) + return 0 + + column_names, required_columns, non_nullable_with_default = self._get_model_column_info(model) + unknown_fields: set[str] = set() + + # Apply schema mapping, filter to current columns, then convert datetimes + converted_records = [] + for record in records: + mapped = self._apply_schema_mapping(table_name, schema_version, record) + unknown_fields.update(set(mapped.keys()) - column_names) + filtered = {key: value for key, value in mapped.items() if key in column_names} + for key in non_nullable_with_default: + if key in filtered and filtered[key] is None: + filtered.pop(key) + missing_required = [key for key in required_columns if key not in filtered or filtered.get(key) is None] + if missing_required: + missing_cols = ", ".join(sorted(missing_required)) + raise ValueError( + f"Missing required columns for {table_name} (schema_version={schema_version}): {missing_cols}" + ) + converted = self._convert_datetime_fields(filtered, model) + converted_records.append(converted) + if unknown_fields: + logger.warning( + "Dropped unknown columns for %s (schema_version=%s): %s", + table_name, + schema_version, + ", ".join(sorted(unknown_fields)), + ) + + # Use INSERT ... 
ON CONFLICT DO NOTHING for idempotency + stmt = pg_insert(model).values(converted_records) + stmt = stmt.on_conflict_do_nothing(index_elements=["id"]) + + result = session.execute(stmt) + return cast(CursorResult, result).rowcount or 0 + + def _convert_datetime_fields( + self, + record: dict[str, Any], + model: type[DeclarativeBase] | Any, + ) -> dict[str, Any]: + """Convert ISO datetime strings to datetime objects.""" + from sqlalchemy import DateTime + + result = dict(record) + + for column in model.__table__.columns: + if isinstance(column.type, DateTime): + value = result.get(column.key) + if isinstance(value, str): + try: + result[column.key] = datetime.fromisoformat(value) + except ValueError: + pass + + return result + + def _get_schema_version(self, manifest: dict[str, Any]) -> str: + schema_version = manifest.get("schema_version") + if not schema_version: + logger.warning("Manifest missing schema_version; defaulting to 1.0") + schema_version = "1.0" + schema_version = str(schema_version) + if schema_version not in SCHEMA_MAPPERS: + raise ValueError(f"Unsupported schema_version {schema_version}. Add a mapping before restoring.") + return schema_version + + def _apply_schema_mapping( + self, + table_name: str, + schema_version: str, + record: dict[str, Any], + ) -> dict[str, Any]: + # Keep hook for forward/backward compatibility when schema evolves. 
+ mapper = SCHEMA_MAPPERS.get(schema_version, {}).get(table_name) + if mapper is None: + return dict(record) + return mapper(record) + + def _get_model_column_info( + self, + model: type[DeclarativeBase] | Any, + ) -> tuple[set[str], set[str], set[str]]: + columns = list(model.__table__.columns) + column_names = {column.key for column in columns} + required_columns = { + column.key + for column in columns + if not column.nullable + and column.default is None + and column.server_default is None + and not column.autoincrement + } + non_nullable_with_default = { + column.key + for column in columns + if not column.nullable + and (column.default is not None or column.server_default is not None or column.autoincrement) + } + return column_names, required_columns, non_nullable_with_default + + def restore_batch( + self, + tenant_ids: list[str] | None, + start_date: datetime, + end_date: datetime, + limit: int = 100, + ) -> list[RestoreResult]: + """ + Restore multiple workflow runs by time range. 
+ + Args: + tenant_ids: Optional tenant IDs + start_date: Start date filter + end_date: End date filter + limit: Maximum number of runs to restore (default: 100) + + Returns: + List of RestoreResult objects + """ + results: list[RestoreResult] = [] + if tenant_ids is not None and not tenant_ids: + return results + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + repo = self._get_workflow_run_repo() + + with session_maker() as session: + archive_logs = repo.get_archived_logs_by_time_range( + session=session, + tenant_ids=tenant_ids, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + + click.echo( + click.style( + f"Found {len(archive_logs)} archived workflow runs to restore", + fg="white", + ) + ) + + def _restore_with_session(archive_log: WorkflowArchiveLog) -> RestoreResult: + return self._restore_from_run( + archive_log, + session_maker=session_maker, + ) + + with ThreadPoolExecutor(max_workers=self.workers) as executor: + results = list(executor.map(_restore_with_session, archive_logs)) + + total_counts: dict[str, int] = {} + for result in results: + for table_name, count in result.restored_counts.items(): + total_counts[table_name] = total_counts.get(table_name, 0) + count + success_count = sum(1 for result in results if result.success) + + if self.dry_run: + click.echo( + click.style( + f"[DRY RUN] Would restore {len(results)} workflow runs: totals={total_counts}", + fg="yellow", + ) + ) + else: + click.echo( + click.style( + f"Restored {success_count}/{len(results)} workflow runs: totals={total_counts}", + fg="green", + ) + ) + + return results + + def restore_by_run_id( + self, + run_id: str, + ) -> RestoreResult: + """ + Restore a single workflow run by run ID. 
+ """ + repo = self._get_workflow_run_repo() + archive_log = repo.get_archived_log_by_run_id(run_id) + + if not archive_log: + click.echo(click.style(f"Workflow run archive {run_id} not found", fg="red")) + return RestoreResult( + run_id=run_id, + tenant_id="", + success=False, + restored_counts={}, + error=f"Workflow run archive {run_id} not found", + ) + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + result = self._restore_from_run(archive_log, session_maker=session_maker) + if self.dry_run and result.success: + click.echo( + click.style( + f"[DRY RUN] Would restore workflow run {run_id}: totals={result.restored_counts}", + fg="yellow", + ) + ) + return result diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 937e6593fe..bd3585acf4 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -19,7 +19,10 @@ class TagService: .where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) ) if keyword: - query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%"))) + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(keyword) + query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\"))) query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) results: list = query.order_by(Tag.created_at.desc()).all() return results diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index b3b6e36346..c32157919b 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -7,7 +7,6 @@ from httpx import get from sqlalchemy import select from core.entities.provider_entities import ProviderConfig -from core.helper.tool_provider_cache import ToolProviderListCache from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_runtime import ToolRuntime from core.tools.custom_tool.provider import 
ApiToolProviderController @@ -86,7 +85,9 @@ class ApiToolManageService: raise ValueError(f"invalid schema: {str(e)}") @staticmethod - def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]: + def convert_schema_to_tool_bundles( + schema: str, extra_info: dict | None = None + ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]: """ convert schema to tool bundles @@ -104,7 +105,7 @@ class ApiToolManageService: provider_name: str, icon: dict, credentials: dict, - schema_type: str, + schema_type: ApiProviderSchemaType, schema: str, privacy_policy: str, custom_disclaimer: str, @@ -113,9 +114,6 @@ class ApiToolManageService: """ create api tool provider """ - if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f"invalid schema type {schema}") - provider_name = provider_name.strip() # check if the provider exists @@ -178,9 +176,6 @@ class ApiToolManageService: # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @staticmethod @@ -245,18 +240,15 @@ class ApiToolManageService: original_provider: str, icon: dict, credentials: dict, - schema_type: str, + _schema_type: ApiProviderSchemaType, schema: str, - privacy_policy: str, + privacy_policy: str | None, custom_disclaimer: str, labels: list[str], ): """ update api tool provider """ - if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f"invalid schema type {schema}") - provider_name = provider_name.strip() # check if the provider exists @@ -281,7 +273,7 @@ class ApiToolManageService: provider.icon = json.dumps(icon) provider.schema = schema provider.description = extra_info.get("description", "") - provider.schema_type_str = ApiProviderSchemaType.OPENAPI + provider.schema_type_str = schema_type provider.tools_str = 
json.dumps(jsonable_encoder(tool_bundles)) provider.privacy_policy = privacy_policy provider.custom_disclaimer = custom_disclaimer @@ -322,9 +314,6 @@ class ApiToolManageService: # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @staticmethod @@ -347,9 +336,6 @@ class ApiToolManageService: db.session.delete(provider) db.session.commit() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @staticmethod @@ -366,7 +352,7 @@ class ApiToolManageService: tool_name: str, credentials: dict, parameters: dict, - schema_type: str, + schema_type: ApiProviderSchemaType, schema: str, ): """ diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 87951d53e6..6797a67dde 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -12,7 +12,6 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.helper.name_generator import generate_incremental_name from core.helper.position_helper import is_filtered from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache -from core.helper.tool_provider_cache import ToolProviderListCache from core.plugin.entities.plugin_daemon import CredentialType from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort @@ -205,9 +204,6 @@ class BuiltinToolManageService: db_provider.name = name session.commit() - - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) except Exception as e: session.rollback() raise ValueError(str(e)) @@ -290,8 +286,6 @@ class BuiltinToolManageService: session.rollback() raise ValueError(str(e)) - # Invalidate 
tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id, "builtin") return {"result": "success"} @staticmethod @@ -409,9 +403,6 @@ class BuiltinToolManageService: ) cache.delete() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @staticmethod @@ -434,8 +425,6 @@ class BuiltinToolManageService: target_provider.is_default = True session.commit() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) return {"result": "success"} @staticmethod diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 252be77b27..0be106f597 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -319,8 +319,14 @@ class MCPToolManageService: except MCPError as e: raise ValueError(f"Failed to connect to MCP server: {e}") - # Update database with retrieved tools - db_provider.tools = json.dumps([tool.model_dump() for tool in tools]) + # Update database with retrieved tools (ensure description is a non-null string) + tools_payload = [] + for tool in tools: + data = tool.model_dump() + if data.get("description") is None: + data["description"] = "" + tools_payload.append(data) + db_provider.tools = json.dumps(tools_payload) db_provider.authed = True db_provider.updated_at = datetime.now() self._session.flush() @@ -620,6 +626,21 @@ class MCPToolManageService: server_url_hash=new_server_url_hash, ) + @staticmethod + def reconnect_with_url( + *, + server_url: str, + headers: dict[str, str], + timeout: float | None, + sse_read_timeout: float | None, + ) -> ReconnectResult: + return MCPToolManageService._reconnect_with_url( + server_url=server_url, + headers=headers, + timeout=timeout, + sse_read_timeout=sse_read_timeout, + ) + @staticmethod def _reconnect_with_url( *, @@ -642,9 +663,16 @@ class MCPToolManageService: sse_read_timeout=sse_read_timeout, ) as 
mcp_client: tools = mcp_client.list_tools() + # Ensure tool descriptions are non-null in payload + tools_payload = [] + for t in tools: + d = t.model_dump() + if d.get("description") is None: + d["description"] = "" + tools_payload.append(d) return ReconnectResult( authed=True, - tools=json.dumps([tool.model_dump() for tool in tools]), + tools=json.dumps(tools_payload), encrypted_credentials=EMPTY_CREDENTIALS_JSON, ) except MCPAuthError: diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py index 038c462f15..51e9120b8d 100644 --- a/api/services/tools/tools_manage_service.py +++ b/api/services/tools/tools_manage_service.py @@ -1,6 +1,5 @@ import logging -from core.helper.tool_provider_cache import ToolProviderListCache from core.tools.entities.api_entities import ToolProviderTypeApiLiteral from core.tools.tool_manager import ToolManager from services.tools.tools_transform_service import ToolTransformService @@ -16,14 +15,6 @@ class ToolCommonService: :return: the list of tool providers """ - # Try to get from cache first - cached_result = ToolProviderListCache.get_cached_providers(tenant_id, typ) - if cached_result is not None: - logger.debug("Returning cached tool providers for tenant %s, type %s", tenant_id, typ) - return cached_result - - # Cache miss - fetch from database - logger.debug("Cache miss for tool providers, fetching from database for tenant %s, type %s", tenant_id, typ) providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ) # add icon @@ -32,7 +23,4 @@ class ToolCommonService: result = [provider.to_dict() for provider in providers] - # Cache the result - ToolProviderListCache.set_cached_providers(tenant_id, typ, result) - return result diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 20cb258277..0ae40199ab 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ 
b/api/services/tools/workflow_tools_manage_service.py @@ -5,9 +5,8 @@ from datetime import datetime from typing import Any from sqlalchemy import or_, select +from sqlalchemy.orm import Session -from core.db.session_factory import session_factory -from core.helper.tool_provider_cache import ToolProviderListCache from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_provider import ToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity @@ -88,17 +87,13 @@ class WorkflowToolManageService: except Exception as e: raise ValueError(str(e)) - with session_factory.create_session() as session, session.begin(): + with Session(db.engine, expire_on_commit=False) as session, session.begin(): session.add(workflow_tool_provider) if labels is not None: ToolLabelManager.update_tool_labels( ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) - - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @classmethod @@ -188,9 +183,6 @@ class WorkflowToolManageService: ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @classmethod @@ -253,9 +245,6 @@ class WorkflowToolManageService: db.session.commit() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @classmethod diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 57de9b3cee..688993c798 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -799,7 +799,7 @@ class TriggerProviderService: user_id: str, provider_id: TriggerProviderID, subscription_id: str, - credentials: Mapping[str, Any], + credentials: 
dict[str, Any], ) -> dict[str, Any]: """ Verify credentials for an existing subscription without updating it. @@ -853,7 +853,7 @@ class TriggerProviderService: """ Create a subscription builder for rebuilding an existing subscription. - This method creates a builder pre-filled with data from the rebuild request, + This method rebuilds the subscription by calling the DELETE and CREATE APIs of the third-party provider (e.g. GitHub), keeping the same subscription_id and endpoint_id so the webhook URL remains unchanged. :param tenant_id: Tenant ID @@ -876,16 +876,12 @@ raise ValueError(f"Subscription {subscription_id} not found") credential_type = CredentialType.of(subscription.credential_type) - if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]: - raise ValueError("Credential type not supported for rebuild") - - # TODO: Trying to invoke update api of the plugin trigger provider - - # FALLBACK: If the update api is not implemented, delete the previous subscription and create a new one + if credential_type not in {CredentialType.OAUTH2, CredentialType.API_KEY}: + raise ValueError(f"Credential type {credential_type} not supported for auto creation") # Delete the previous subscription user_id = subscription.user_id - TriggerManager.unsubscribe_trigger( + unsubscribe_result = TriggerManager.unsubscribe_trigger( tenant_id=tenant_id, user_id=user_id, provider_id=provider_id, @@ -893,15 +889,21 @@ credentials=subscription.credentials, credential_type=credential_type, ) + if not unsubscribe_result.success: + raise ValueError(f"Failed to delete previous subscription: {unsubscribe_result.message}") # Create a new subscription with the same subscription_id and endpoint_id + new_credentials: dict[str, Any] = { + key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE) + for key, value in credentials.items() + } + new_subscription: TriggerSubscriptionEntity = 
TriggerManager.subscribe_trigger( tenant_id=tenant_id, user_id=user_id, provider_id=provider_id, endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id), parameters=parameters, - credentials=credentials, + credentials=new_credentials, credential_type=credential_type, ) TriggerProviderService.update_trigger_subscription( @@ -909,7 +911,7 @@ class TriggerProviderService: subscription_id=subscription.id, name=name, parameters=parameters, - credentials=credentials, + credentials=new_credentials, properties=new_subscription.properties, expires_at=new_subscription.expires_at, ) diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 5c4607d400..4159f5f8f4 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -863,10 +863,18 @@ class WebhookService: not_found_in_cache.append(node_id) continue - with Session(db.engine) as session: - try: - # lock the concurrent webhook trigger creation - redis_client.lock(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10) + lock_key = f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock" + lock = redis_client.lock(lock_key, timeout=10) + lock_acquired = False + + try: + # acquire the lock with blocking and timeout + lock_acquired = lock.acquire(blocking=True, blocking_timeout=10) + if not lock_acquired: + logger.warning("Failed to acquire lock for webhook sync, app %s", app.id) + raise RuntimeError("Failed to acquire lock for webhook trigger synchronization") + + with Session(db.engine) as session: # fetch the non-cached nodes from DB all_records = session.scalars( select(WorkflowWebhookTrigger).where( @@ -903,11 +911,16 @@ class WebhookService: session.delete(nodes_id_in_db[node_id]) redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}") session.commit() - except Exception: - logger.exception("Failed to sync webhook relationships for app %s", app.id) - raise - finally: - 
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock") + except Exception: + logger.exception("Failed to sync webhook relationships for app %s", app.id) + raise + finally: + # release the lock only if it was acquired + if lock_acquired: + try: + lock.release() + except Exception: + logger.exception("Failed to release lock for webhook sync, app %s", app.id) @classmethod def generate_webhook_id(cls) -> str: diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 0f969207cf..f973361341 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses from abc import ABC, abstractmethod from collections.abc import Mapping @@ -106,7 +108,7 @@ class VariableTruncator(BaseTruncator): self._max_size_bytes = max_size_bytes @classmethod - def default(cls) -> "VariableTruncator": + def default(cls) -> VariableTruncator: return VariableTruncator( max_size_bytes=dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE, array_element_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH, diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 9bd797a45f..5ca0b63001 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -12,6 +12,7 @@ from libs.passport import PassportService from libs.password import compare_password from models import Account, AccountStatus from models.model import App, EndUser, Site +from services.account_service import AccountService from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError @@ -32,7 +33,7 @@ class WebAppAuthService: @staticmethod def authenticate(email: str, password: str) -> Account: """authenticate account with email and password""" - account = 
db.session.query(Account).filter_by(email=email).first() + account = AccountService.get_account_by_email_with_case_fallback(email) if not account: raise AccountNotFoundError() @@ -52,7 +53,7 @@ class WebAppAuthService: @classmethod def get_user_through_email(cls, email: str): - account = db.session.query(Account).where(Account.email == email).first() + account = AccountService.get_account_by_email_with_case_fallback(email) if not account: return None diff --git a/api/services/website_service.py b/api/services/website_service.py index a23f01ec71..fe48c3b08e 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import json from dataclasses import dataclass @@ -78,7 +80,7 @@ class WebsiteCrawlApiRequest: return CrawlRequest(url=self.url, provider=self.provider, options=options) @classmethod - def from_args(cls, args: dict) -> "WebsiteCrawlApiRequest": + def from_args(cls, args: dict) -> WebsiteCrawlApiRequest: """Create from Flask-RESTful parsed arguments.""" provider = args.get("provider") url = args.get("url") @@ -102,7 +104,7 @@ class WebsiteCrawlStatusApiRequest: job_id: str @classmethod - def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest": + def from_args(cls, args: dict, job_id: str) -> WebsiteCrawlStatusApiRequest: """Create from Flask-RESTful parsed arguments.""" provider = args.get("provider") if not provider: diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 01f0c7a55a..efc76c33bc 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -7,7 +7,7 @@ from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session from core.workflow.enums import WorkflowExecutionStatus -from models import Account, App, EndUser, WorkflowAppLog, WorkflowRun +from models import Account, App, EndUser, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from 
models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog from services.plugin.plugin_service import PluginService @@ -86,12 +86,19 @@ class WorkflowAppService: # Join to workflow run for filtering when needed. if keyword: - keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u") + from libs.helper import escape_like_pattern + + # Escape special characters in keyword to prevent SQL injection via LIKE wildcards + escaped_keyword = escape_like_pattern(keyword[:30]) + keyword_like_val = f"%{escaped_keyword}%" keyword_conditions = [ - WorkflowRun.inputs.ilike(keyword_like_val), - WorkflowRun.outputs.ilike(keyword_like_val), + WorkflowRun.inputs.ilike(keyword_like_val, escape="\\"), + WorkflowRun.outputs.ilike(keyword_like_val, escape="\\"), # filter keyword by end user session id if created by end user role - and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)), + and_( + WorkflowRun.created_by_role == "end_user", + EndUser.session_id.ilike(keyword_like_val, escape="\\"), + ), ] # filter keyword by workflow run id @@ -166,7 +173,80 @@ class WorkflowAppService: "data": items, } - def handle_trigger_metadata(self, tenant_id: str, meta_val: str) -> dict[str, Any]: + def get_paginate_workflow_archive_logs( + self, + *, + session: Session, + app_model: App, + page: int = 1, + limit: int = 20, + ): + """ + Get paginated workflow archive logs using SQLAlchemy 2.0 style. 
+ """ + stmt = select(WorkflowArchiveLog).where( + WorkflowArchiveLog.tenant_id == app_model.tenant_id, + WorkflowArchiveLog.app_id == app_model.id, + WorkflowArchiveLog.log_id.isnot(None), + ) + + stmt = stmt.order_by(WorkflowArchiveLog.run_created_at.desc()) + + count_stmt = select(func.count()).select_from(stmt.subquery()) + total = session.scalar(count_stmt) or 0 + + offset_stmt = stmt.offset((page - 1) * limit).limit(limit) + + logs = list(session.scalars(offset_stmt).all()) + account_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.ACCOUNT} + end_user_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.END_USER} + + accounts_by_id = {} + if account_ids: + accounts_by_id = { + account.id: account + for account in session.scalars(select(Account).where(Account.id.in_(account_ids))).all() + } + + end_users_by_id = {} + if end_user_ids: + end_users_by_id = { + end_user.id: end_user + for end_user in session.scalars(select(EndUser).where(EndUser.id.in_(end_user_ids))).all() + } + + items = [] + for log in logs: + if log.created_by_role == CreatorUserRole.ACCOUNT: + created_by_account = accounts_by_id.get(log.created_by) + created_by_end_user = None + elif log.created_by_role == CreatorUserRole.END_USER: + created_by_account = None + created_by_end_user = end_users_by_id.get(log.created_by) + else: + created_by_account = None + created_by_end_user = None + + items.append( + { + "id": log.id, + "workflow_run": log.workflow_run_summary, + "trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, log.trigger_metadata), + "created_by_account": created_by_account, + "created_by_end_user": created_by_end_user, + "created_at": log.log_created_at, + } + ) + + return { + "page": page, + "limit": limit, + "total": total, + "has_more": total > page * limit, + "data": items, + } + + def handle_trigger_metadata(self, tenant_id: str, meta_val: str | None) -> dict[str, Any]: metadata: dict[str, Any] | None = 
self._safe_json_loads(meta_val) if not metadata: return {} diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index f299ce3baa..70b0190231 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import and_, or_ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File -from core.variables import Segment, StringSegment, Variable +from core.variables import Segment, StringSegment, VariableBase from core.variables.consts import SELECTORS_LENGTH from core.variables.segments import ( ArrayFileSegment, @@ -77,14 +77,14 @@ class DraftVarLoader(VariableLoader): # Application ID for which variables are being loaded. _app_id: str _tenant_id: str - _fallback_variables: Sequence[Variable] + _fallback_variables: Sequence[VariableBase] def __init__( self, engine: Engine, app_id: str, tenant_id: str, - fallback_variables: Sequence[Variable] | None = None, + fallback_variables: Sequence[VariableBase] | None = None, ): self._engine = engine self._app_id = app_id @@ -94,12 +94,12 @@ class DraftVarLoader(VariableLoader): def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]: return (selector[0], selector[1]) - def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: if not selectors: return [] - # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance. - variable_by_selector: dict[tuple[str, str], Variable] = {} + # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding variable instance. 
+ variable_by_selector: dict[tuple[str, str], VariableBase] = {} with Session(bind=self._engine, expire_on_commit=False) as session: srv = WorkflowDraftVariableService(session) @@ -145,7 +145,7 @@ class DraftVarLoader(VariableLoader): return list(variable_by_selector.values()) - def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], Variable]: + def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], VariableBase]: # This logic is closely tied to `WorkflowDraftVaribleService._try_offload_large_variable` # and must remain synchronized with it. # Ideally, these should be co-located for better maintainability. @@ -679,6 +679,7 @@ def _batch_upsert_draft_variable( def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]: d: dict[str, Any] = { + "id": model.id, "app_id": model.app_id, "last_edited_at": None, "node_id": model.node_id, diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 634dd3e2c3..d02f55b504 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,4 +1,5 @@ import json +import logging import time import uuid from collections.abc import Callable, Generator, Mapping, Sequence @@ -15,8 +16,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.file import File from core.repositories import DifyCoreRepositoryFactory from core.repositories.human_input_reposotiry import HumanInputFormRepositoryImpl -from core.variables import Variable -from core.variables.variables import VariableUnion +from core.variables import VariableBase +from core.variables.variables import Variable from core.workflow.entities import GraphInitParams, WorkflowNodeExecution from core.workflow.entities.pause_reason import HumanInputRequired from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -220,8 +221,8 @@ class WorkflowService: features: dict, 
unique_hash: str | None, account: Account, - environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable], + environment_variables: Sequence[VariableBase], + conversation_variables: Sequence[VariableBase], ) -> Workflow: """ Sync draft workflow @@ -697,7 +698,7 @@ class WorkflowService: else: variable_pool = VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), user_inputs=user_inputs, environment_variables=draft_workflow.environment_variables, conversation_variables=[], @@ -1007,6 +1008,8 @@ class WorkflowService: @staticmethod def _load_email_recipients(form_id: str) -> list[DeliveryTestEmailRecipient]: + logger = logging.getLogger(__name__) + with Session(bind=db.engine) as session: recipients = session.scalars( select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_id) @@ -1426,7 +1429,7 @@ def _setup_variable_pool( workflow: Workflow, node_type: NodeType, conversation_id: str, - conversation_variables: list[Variable], + conversation_variables: list[VariableBase], ): # Only inject system variables for START node type. if node_type == NodeType.START or node_type.is_trigger_node: @@ -1445,16 +1448,16 @@ def _setup_variable_pool( system_variable.conversation_id = conversation_id system_variable.dialogue_count = 1 else: - system_variable = SystemVariable.empty() + system_variable = SystemVariable.default() # init variable pool variable_pool = VariablePool( system_variables=system_variable, user_inputs=user_inputs, environment_variables=workflow.environment_variables, - # Based on the definition of `VariableUnion`, - # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. - conversation_variables=cast(list[VariableUnion], conversation_variables), # + # Based on the definition of `Variable`, + # `VariableBase` instances can be safely used as `Variable` since they are compatible. 
+ conversation_variables=cast(list[Variable], conversation_variables), # ) return variable_pool diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 292ac6e008..3ee41c2e8d 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -31,7 +31,8 @@ class WorkspaceService: assert tenant_account_join is not None, "TenantAccountJoin not found" tenant_info["role"] = tenant_account_join.role - can_replace_logo = FeatureService.get_features(tenant.id).can_replace_logo + feature = FeatureService.get_features(tenant.id) + can_replace_logo = feature.can_replace_logo if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]): base_url = dify_config.FILES_URL @@ -46,5 +47,19 @@ class WorkspaceService: "remove_webapp_brand": remove_webapp_brand, "replace_webapp_logo": replace_webapp_logo, } + if dify_config.EDITION == "CLOUD": + tenant_info["next_credit_reset_date"] = feature.next_credit_reset_date + + from services.credit_pool_service import CreditPoolService + + paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid") + if paid_pool: + tenant_info["trial_credits"] = paid_pool.quota_limit + tenant_info["trial_credits_used"] = paid_pool.quota_used + else: + trial_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="trial") + if trial_pool: + tenant_info["trial_credits"] = trial_pool.quota_limit + tenant_info["trial_credits_used"] = trial_pool.quota_used return tenant_info diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index e7dead8a56..62e6497e9d 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -4,11 +4,11 @@ import time import click from celery import shared_task +from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import 
IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, ChildDocument, Document -from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import DatasetAutoDisableLog, DocumentSegment @@ -28,106 +28,106 @@ def add_document_to_index_task(dataset_document_id: str): logger.info(click.style(f"Start add document to index: {dataset_document_id}", fg="green")) start_at = time.perf_counter() - dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first() - if not dataset_document: - logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red")) - db.session.close() - return + with session_factory.create_session() as session: + dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first() + if not dataset_document: + logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red")) + return - if dataset_document.indexing_status != "completed": - db.session.close() - return + if dataset_document.indexing_status != "completed": + return - indexing_cache_key = f"document_{dataset_document.id}_indexing" + indexing_cache_key = f"document_{dataset_document.id}_indexing" - try: - dataset = dataset_document.dataset - if not dataset: - raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.") + try: + dataset = dataset_document.dataset + if not dataset: + raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.") - segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == "completed", + segments = ( + session.query(DocumentSegment) + .where( + DocumentSegment.document_id == dataset_document.id, 
+ DocumentSegment.status == "completed", + ) + .order_by(DocumentSegment.position.asc()) + .all() ) - .order_by(DocumentSegment.position.asc()) - .all() - ) - documents = [] - multimodal_documents = [] - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, + documents = [] + multimodal_documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + child_chunks = segment.get_child_chunks() + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) + documents.append(document) + + index_type = dataset.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) + + # delete auto disable log + session.query(DatasetAutoDisableLog).where( + DatasetAutoDisableLog.document_id == dataset_document.id + ).delete() + + # update segment to 
enable + session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update( + { + DocumentSegment.enabled: True, + DocumentSegment.disabled_at: None, + DocumentSegment.disabled_by: None, + DocumentSegment.updated_at: naive_utc_now(), + } ) - if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - child_chunks = segment.get_child_chunks() - if child_chunks: - child_documents = [] - for child_chunk in child_chunks: - child_document = ChildDocument( - page_content=child_chunk.content, - metadata={ - "doc_id": child_chunk.index_node_id, - "doc_hash": child_chunk.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - child_documents.append(child_document) - document.children = child_documents - if dataset.is_multimodal: - for attachment in segment.attachments: - multimodal_documents.append( - AttachmentDocument( - page_content=attachment["name"], - metadata={ - "doc_id": attachment["id"], - "doc_hash": "", - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - "doc_type": DocType.IMAGE, - }, - ) - ) - documents.append(document) + session.commit() - index_type = dataset.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) - - # delete auto disable log - db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete() - - # update segment to enable - db.session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update( - { - DocumentSegment.enabled: True, - DocumentSegment.disabled_at: None, - DocumentSegment.disabled_by: None, - DocumentSegment.updated_at: naive_utc_now(), - } - ) - db.session.commit() - - end_at = time.perf_counter() - logger.info( - click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green") - ) - 
except Exception as e: - logger.exception("add document to index failed") - dataset_document.enabled = False - dataset_document.disabled_at = naive_utc_now() - dataset_document.indexing_status = "error" - dataset_document.error = str(e) - db.session.commit() - finally: - redis_client.delete(indexing_cache_key) - db.session.close() + end_at = time.perf_counter() + logger.info( + click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green") + ) + except Exception as e: + logger.exception("add document to index failed") + dataset_document.enabled = False + dataset_document.disabled_at = naive_utc_now() + dataset_document.indexing_status = "error" + dataset_document.error = str(e) + session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 775814318b..fc6bf03454 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -5,9 +5,9 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound +from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset from models.model import App, AppAnnotationSetting, MessageAnnotation @@ -32,74 +32,72 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" active_jobs_key = f"annotation_import_active:{tenant_id}" - # get app info - app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + with session_factory.create_session() as session: + # get app info + app = session.query(App).where(App.id == 
app_id, App.tenant_id == tenant_id, App.status == "normal").first() - if app: - try: - documents = [] - for content in content_list: - annotation = MessageAnnotation( - app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id + if app: + try: + documents = [] + for content in content_list: + annotation = MessageAnnotation( + app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id + ) + session.add(annotation) + session.flush() + + document = Document( + page_content=content["question"], + metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, + ) + documents.append(document) + # if annotation reply is enabled , batch add annotations' index + app_annotation_setting = ( + session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() ) - db.session.add(annotation) - db.session.flush() - document = Document( - page_content=content["question"], - metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, - ) - documents.append(document) - # if annotation reply is enabled , batch add annotations' index - app_annotation_setting = ( - db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() - ) + if app_annotation_setting: + dataset_collection_binding = ( + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + app_annotation_setting.collection_binding_id, "annotation" + ) + ) + if not dataset_collection_binding: + raise NotFound("App annotation setting not found") + dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + embedding_model_provider=dataset_collection_binding.provider_name, + embedding_model=dataset_collection_binding.model_name, + collection_binding_id=dataset_collection_binding.id, + ) - if app_annotation_setting: - dataset_collection_binding = ( - 
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - app_annotation_setting.collection_binding_id, "annotation" + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.create(documents, duplicate_check=True) + + session.commit() + redis_client.setex(indexing_cache_key, 600, "completed") + end_at = time.perf_counter() + logger.info( + click.style( + "Build index successful for batch import annotation: {} latency: {}".format( + job_id, end_at - start_at + ), + fg="green", ) ) - if not dataset_collection_binding: - raise NotFound("App annotation setting not found") - dataset = Dataset( - id=app_id, - tenant_id=tenant_id, - indexing_technique="high_quality", - embedding_model_provider=dataset_collection_binding.provider_name, - embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id, - ) - - vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) - vector.create(documents, duplicate_check=True) - - db.session.commit() - redis_client.setex(indexing_cache_key, 600, "completed") - end_at = time.perf_counter() - logger.info( - click.style( - "Build index successful for batch import annotation: {} latency: {}".format( - job_id, end_at - start_at - ), - fg="green", - ) - ) - except Exception as e: - db.session.rollback() - redis_client.setex(indexing_cache_key, 600, "error") - indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}" - redis_client.setex(indexing_error_msg_key, 600, str(e)) - logger.exception("Build index for batch import annotations failed") - finally: - # Clean up active job tracking to release concurrency slot - try: - redis_client.zrem(active_jobs_key, job_id) - logger.debug("Released concurrency slot for job: %s", job_id) - except Exception as cleanup_error: - # Log but don't fail if cleanup fails - the job will be auto-expired - logger.warning("Failed to clean up active job tracking for %s: %s", job_id, 
cleanup_error) - - # Close database session - db.session.close() + except Exception as e: + session.rollback() + redis_client.setex(indexing_cache_key, 600, "error") + indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}" + redis_client.setex(indexing_error_msg_key, 600, str(e)) + logger.exception("Build index for batch import annotations failed") + finally: + # Clean up active job tracking to release concurrency slot + try: + redis_client.zrem(active_jobs_key, job_id) + logger.debug("Released concurrency slot for job: %s", job_id) + except Exception as cleanup_error: + # Log but don't fail if cleanup fails - the job will be auto-expired + logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error) diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index c0020b29ed..7b5cd46b00 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -5,8 +5,8 @@ import click from celery import shared_task from sqlalchemy import exists, select +from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset from models.model import App, AppAnnotationSetting, MessageAnnotation @@ -22,50 +22,55 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): logger.info(click.style(f"Start delete app annotations index: {app_id}", fg="green")) start_at = time.perf_counter() # get app info - app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() - annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id))) - if not app: - logger.info(click.style(f"App not found: {app_id}", fg="red")) - db.session.close() - return + 
with session_factory.create_session() as session: + app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + annotations_exists = session.scalar(select(exists().where(MessageAnnotation.app_id == app_id))) + if not app: + logger.info(click.style(f"App not found: {app_id}", fg="red")) + return - app_annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() - - if not app_annotation_setting: - logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red")) - db.session.close() - return - - disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}" - disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}" - - try: - dataset = Dataset( - id=app_id, - tenant_id=tenant_id, - indexing_technique="high_quality", - collection_binding_id=app_annotation_setting.collection_binding_id, + app_annotation_setting = ( + session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() ) + if not app_annotation_setting: + logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red")) + return + + disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}" + disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}" + try: - if annotations_exists: - vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) - vector.delete() - except Exception: - logger.exception("Delete annotation index failed when annotation deleted.") - redis_client.setex(disable_app_annotation_job_key, 600, "completed") + dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + collection_binding_id=app_annotation_setting.collection_binding_id, + ) - # delete annotation setting - db.session.delete(app_annotation_setting) - db.session.commit() + try: + if annotations_exists: + vector = Vector(dataset, attributes=["doc_id", "annotation_id", 
"app_id"]) + vector.delete() + except Exception: + logger.exception("Delete annotation index failed when annotation deleted.") + redis_client.setex(disable_app_annotation_job_key, 600, "completed") - end_at = time.perf_counter() - logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green")) - except Exception as e: - logger.exception("Annotation batch deleted index failed") - redis_client.setex(disable_app_annotation_job_key, 600, "error") - disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}" - redis_client.setex(disable_app_annotation_error_key, 600, str(e)) - finally: - redis_client.delete(disable_app_annotation_key) - db.session.close() + # delete annotation setting + session.delete(app_annotation_setting) + session.commit() + + end_at = time.perf_counter() + logger.info( + click.style( + f"App annotations index deleted : {app_id} latency: {end_at - start_at}", + fg="green", + ) + ) + except Exception as e: + logger.exception("Annotation batch deleted index failed") + redis_client.setex(disable_app_annotation_job_key, 600, "error") + disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}" + redis_client.setex(disable_app_annotation_error_key, 600, str(e)) + finally: + redis_client.delete(disable_app_annotation_key) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index cdc07c77a8..4f8e2fec7a 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -5,9 +5,9 @@ import click from celery import shared_task from sqlalchemy import select +from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document -from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now 
from models.dataset import Dataset @@ -33,92 +33,98 @@ def enable_annotation_reply_task( logger.info(click.style(f"Start add app annotation to index: {app_id}", fg="green")) start_at = time.perf_counter() # get app info - app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + with session_factory.create_session() as session: + app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() - if not app: - logger.info(click.style(f"App not found: {app_id}", fg="red")) - db.session.close() - return + if not app: + logger.info(click.style(f"App not found: {app_id}", fg="red")) + return - annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all() - enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}" - enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" + annotations = session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all() + enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}" + enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" - try: - documents = [] - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, embedding_model_name, "annotation" - ) - annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() - if annotation_setting: - if dataset_collection_binding.id != annotation_setting.collection_binding_id: - old_dataset_collection_binding = ( - DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - annotation_setting.collection_binding_id, "annotation" - ) - ) - if old_dataset_collection_binding and annotations: - old_dataset = Dataset( - id=app_id, - tenant_id=tenant_id, - indexing_technique="high_quality", - 
embedding_model_provider=old_dataset_collection_binding.provider_name, - embedding_model=old_dataset_collection_binding.model_name, - collection_binding_id=old_dataset_collection_binding.id, - ) - - old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"]) - try: - old_vector.delete() - except Exception as e: - logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red")) - annotation_setting.score_threshold = score_threshold - annotation_setting.collection_binding_id = dataset_collection_binding.id - annotation_setting.updated_user_id = user_id - annotation_setting.updated_at = naive_utc_now() - db.session.add(annotation_setting) - else: - new_app_annotation_setting = AppAnnotationSetting( - app_id=app_id, - score_threshold=score_threshold, - collection_binding_id=dataset_collection_binding.id, - created_user_id=user_id, - updated_user_id=user_id, + try: + documents = [] + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_provider_name, embedding_model_name, "annotation" ) - db.session.add(new_app_annotation_setting) + annotation_setting = ( + session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + ) + if annotation_setting: + if dataset_collection_binding.id != annotation_setting.collection_binding_id: + old_dataset_collection_binding = ( + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + annotation_setting.collection_binding_id, "annotation" + ) + ) + if old_dataset_collection_binding and annotations: + old_dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + embedding_model_provider=old_dataset_collection_binding.provider_name, + embedding_model=old_dataset_collection_binding.model_name, + collection_binding_id=old_dataset_collection_binding.id, + ) - dataset = Dataset( - id=app_id, - tenant_id=tenant_id, - indexing_technique="high_quality", - 
embedding_model_provider=embedding_provider_name, - embedding_model=embedding_model_name, - collection_binding_id=dataset_collection_binding.id, - ) - if annotations: - for annotation in annotations: - document = Document( - page_content=annotation.question, - metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, + old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"]) + try: + old_vector.delete() + except Exception as e: + logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red")) + annotation_setting.score_threshold = score_threshold + annotation_setting.collection_binding_id = dataset_collection_binding.id + annotation_setting.updated_user_id = user_id + annotation_setting.updated_at = naive_utc_now() + session.add(annotation_setting) + else: + new_app_annotation_setting = AppAnnotationSetting( + app_id=app_id, + score_threshold=score_threshold, + collection_binding_id=dataset_collection_binding.id, + created_user_id=user_id, + updated_user_id=user_id, ) - documents.append(document) + session.add(new_app_annotation_setting) - vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) - try: - vector.delete_by_metadata_field("app_id", app_id) - except Exception as e: - logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red")) - vector.create(documents) - db.session.commit() - redis_client.setex(enable_app_annotation_job_key, 600, "completed") - end_at = time.perf_counter() - logger.info(click.style(f"App annotations added to index: {app_id} latency: {end_at - start_at}", fg="green")) - except Exception as e: - logger.exception("Annotation batch created index failed") - redis_client.setex(enable_app_annotation_job_key, 600, "error") - enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}" - redis_client.setex(enable_app_annotation_error_key, 600, str(e)) - db.session.rollback() - finally: - 
redis_client.delete(enable_app_annotation_key) - db.session.close() + dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + embedding_model_provider=embedding_provider_name, + embedding_model=embedding_model_name, + collection_binding_id=dataset_collection_binding.id, + ) + if annotations: + for annotation in annotations: + document = Document( + page_content=annotation.question_text, + metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, + ) + documents.append(document) + + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + try: + vector.delete_by_metadata_field("app_id", app_id) + except Exception as e: + logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red")) + vector.create(documents) + session.commit() + redis_client.setex(enable_app_annotation_job_key, 600, "completed") + end_at = time.perf_counter() + logger.info( + click.style( + f"App annotations added to index: {app_id} latency: {end_at - start_at}", + fg="green", + ) + ) + except Exception as e: + logger.exception("Annotation batch created index failed") + redis_client.setex(enable_app_annotation_job_key, 600, "error") + enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}" + redis_client.setex(enable_app_annotation_error_key, 600, str(e)) + session.rollback() + finally: + redis_client.delete(enable_app_annotation_key) diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index dd7faae7a3..cc96542d4b 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -19,11 +19,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.app.layers.timeslice_layer import TimeSliceLayer from core.app.layers.trigger_post_layer import TriggerPostLayer +from 
core.db.session_factory import session_factory from core.repositories import DifyCoreRepositoryFactory from core.workflow.runtime import GraphRuntimeState from extensions.ext_database import db from models.account import Account -from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus from models.model import App, EndUser, Tenant from models.trigger import WorkflowTriggerLog from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun @@ -107,10 +108,7 @@ def _execute_workflow_common( ): """Execute workflow with common logic and trigger log updates.""" - # Create a new session for this task - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - - with session_factory() as session: + with session_factory.create_session() as session: trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) # Get trigger log @@ -154,7 +152,7 @@ def _execute_workflow_common( args["workflow_id"] = str(trigger_data.workflow_id) pause_config = PauseStateLayerConfig( - session_factory=session_factory, + session_factory=session_factory.get_session_maker(), state_owner_user_id=workflow.created_by, ) @@ -171,7 +169,7 @@ def _execute_workflow_common( root_node_id=trigger_data.root_node_id, graph_engine_layers=[ # TODO: Re-enable TimeSliceLayer after the HITL release. 
- TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory), + TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id), ], pause_state_config=pause_config, ) @@ -271,7 +269,7 @@ def resume_workflow_execution(task_data_dict: dict[str, Any]) -> None: graph_engine_layers.extend( [ TimeSliceLayer(cfs_plan_scheduler), - TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory), + TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id), ] ) @@ -291,7 +289,7 @@ def resume_workflow_execution(task_data_dict: dict[str, Any]) -> None: workflow_run_repo.delete_workflow_pause(pause_entity) -def _get_user(session: Session, workflow_run: WorkflowRun) -> Account | EndUser: +def _get_user(session: Session, workflow_run: WorkflowRun | WorkflowTriggerLog) -> Account | EndUser: """Compose user from trigger log""" tenant = session.scalar(select(Tenant).where(Tenant.id == workflow_run.tenant_id)) if not tenant: @@ -326,19 +324,9 @@ def _query_trigger_log_info(session_factory: sessionmaker[Session], workflow_run trigger_log = trigger_log_repo.get_by_workflow_run_id(workflow_run_id) if not trigger_log: logger.debug("Trigger log not found for workflow_run: workflow_run_id=%s", workflow_run_id) - return + return None - cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity( - queue=trigger_log.queue_name, - schedule_strategy=AsyncWorkflowSystemStrategy, - granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY, - ) - cfs_plan_scheduler = AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity) - - try: - trigger_type = AppTriggerType(trigger_log.trigger_type) - except ValueError: - trigger_type = AppTriggerType.UNKNOWN + return trigger_log class _TenantNotFoundError(Exception): diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 3e1bd16cc7..74b939e84d 100644 --- a/api/tasks/batch_clean_document_task.py +++ 
b/api/tasks/batch_clean_document_task.py @@ -3,11 +3,11 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import delete, select +from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids -from extensions.ext_database import db from extensions.ext_storage import storage from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment from models.model import UploadFile @@ -28,65 +28,64 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form """ logger.info(click.style("Start batch clean documents when documents deleted", fg="green")) start_at = time.perf_counter() + if not doc_form: + raise ValueError("doc_form is required") - try: - if not doc_form: - raise ValueError("doc_form is required") - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() + with session_factory.create_session() as session: + try: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - raise Exception("Document has no dataset") + if not dataset: + raise Exception("Document has no dataset") - db.session.query(DatasetMetadataBinding).where( - DatasetMetadataBinding.dataset_id == dataset_id, - DatasetMetadataBinding.document_id.in_(document_ids), - ).delete(synchronize_session=False) + session.query(DatasetMetadataBinding).where( + DatasetMetadataBinding.dataset_id == dataset_id, + DatasetMetadataBinding.document_id.in_(document_ids), + ).delete(synchronize_session=False) - segments = db.session.scalars( - select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) - ).all() - # check segment is exist - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - 
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + segments = session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) + ).all() + # check segments exist + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - for segment in segments: - image_upload_file_ids = get_image_upload_file_ids(segment.content) - for upload_file_id in image_upload_file_ids: - image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() + for segment in segments: + image_upload_file_ids = get_image_upload_file_ids(segment.content) + image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all() + for image_file in image_files: + try: + if image_file and image_file.key: + storage.delete(image_file.key) + except Exception: + logger.exception( + "Delete image_files failed when storage deleted, \ image_upload_file_id: %s", + image_file.id, + ) + stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) + session.execute(stmt) + session.delete(segment) + if file_ids: + files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all() + for file in files: try: - if image_file and image_file.key: - storage.delete(image_file.key) + storage.delete(file.key) except Exception: - logger.exception( - "Delete image_files failed when storage deleted, \ image_upload_file_is: %s", - upload_file_id, - ) - db.session.delete(image_file) - db.session.delete(segment) + logger.exception("Delete file failed when document deleted, file_id: %s", file.id) + stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids)) + session.execute(stmt) - db.session.commit() - if file_ids: - files = 
db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all() - for file in files: - try: - storage.delete(file.key) - except Exception: - logger.exception("Delete file failed when document deleted, file_id: %s", file.id) - db.session.delete(file) + session.commit() - db.session.commit() - - end_at = time.perf_counter() - logger.info( - click.style( - f"Cleaned documents when documents deleted latency: {end_at - start_at}", - fg="green", + end_at = time.perf_counter() + logger.info( + click.style( + f"Cleaned documents when documents deleted latency: {end_at - start_at}", + fg="green", + ) ) - ) - except Exception: - logger.exception("Cleaned documents when documents deleted failed") - finally: - db.session.close() + except Exception: + logger.exception("Cleaned documents when documents deleted failed") diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index bd95af2614..8ee09d5738 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -9,9 +9,9 @@ import pandas as pd from celery import shared_task from sqlalchemy import func +from core.db.session_factory import session_factory from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper @@ -48,104 +48,107 @@ def batch_create_segment_to_index_task( indexing_cache_key = f"segment_batch_import_{job_id}" - try: - dataset = db.session.get(Dataset, dataset_id) - if not dataset: - raise ValueError("Dataset not exist.") + with session_factory.create_session() as session: + try: + dataset = session.get(Dataset, dataset_id) + if not dataset: + raise ValueError("Dataset not exist.") - dataset_document = db.session.get(Document, document_id) - if not dataset_document: - raise 
ValueError("Document not exist.") + dataset_document = session.get(Document, document_id) + if not dataset_document: + raise ValueError("Document not exist.") - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - raise ValueError("Document is not available.") + if ( + not dataset_document.enabled + or dataset_document.archived + or dataset_document.indexing_status != "completed" + ): + raise ValueError("Document is not available.") - upload_file = db.session.get(UploadFile, upload_file_id) - if not upload_file: - raise ValueError("UploadFile not found.") + upload_file = session.get(UploadFile, upload_file_id) + if not upload_file: + raise ValueError("UploadFile not found.") - with tempfile.TemporaryDirectory() as temp_dir: - suffix = Path(upload_file.key).suffix - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore - storage.download(upload_file.key, file_path) + with tempfile.TemporaryDirectory() as temp_dir: + suffix = Path(upload_file.key).suffix + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore + storage.download(upload_file.key, file_path) - df = pd.read_csv(file_path) - content = [] - for _, row in df.iterrows(): + df = pd.read_csv(file_path) + content = [] + for _, row in df.iterrows(): + if dataset_document.doc_form == "qa_model": + data = {"content": row.iloc[0], "answer": row.iloc[1]} + else: + data = {"content": row.iloc[0]} + content.append(data) + if len(content) == 0: + raise ValueError("The CSV file is empty.") + + document_segments = [] + embedding_model = None + if dataset.indexing_technique == "high_quality": + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + + word_count_change = 0 + if embedding_model: + tokens_list = 
embedding_model.get_text_embedding_num_tokens( + texts=[segment["content"] for segment in content] + ) + else: + tokens_list = [0] * len(content) + + for segment, tokens in zip(content, tokens_list): + content = segment["content"] + doc_id = str(uuid.uuid4()) + segment_hash = helper.generate_text_hash(content) + max_position = ( + session.query(func.max(DocumentSegment.position)) + .where(DocumentSegment.document_id == dataset_document.id) + .scalar() + ) + segment_document = DocumentSegment( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + index_node_id=doc_id, + index_node_hash=segment_hash, + position=max_position + 1 if max_position else 1, + content=content, + word_count=len(content), + tokens=tokens, + created_by=user_id, + indexing_at=naive_utc_now(), + status="completed", + completed_at=naive_utc_now(), + ) if dataset_document.doc_form == "qa_model": - data = {"content": row.iloc[0], "answer": row.iloc[1]} - else: - data = {"content": row.iloc[0]} - content.append(data) - if len(content) == 0: - raise ValueError("The CSV file is empty.") + segment_document.answer = segment["answer"] + segment_document.word_count += len(segment["answer"]) + word_count_change += segment_document.word_count + session.add(segment_document) + document_segments.append(segment_document) - document_segments = [] - embedding_model = None - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model, - ) + assert dataset_document.word_count is not None + dataset_document.word_count += word_count_change + session.add(dataset_document) - word_count_change = 0 - if embedding_model: - tokens_list = embedding_model.get_text_embedding_num_tokens( - texts=[segment["content"] for segment in content] + VectorService.create_segments_vector(None, 
document_segments, dataset, dataset_document.doc_form) + session.commit() + redis_client.setex(indexing_cache_key, 600, "completed") + end_at = time.perf_counter() + logger.info( + click.style( + f"Segment batch created job: {job_id} latency: {end_at - start_at}", + fg="green", + ) ) - else: - tokens_list = [0] * len(content) - - for segment, tokens in zip(content, tokens_list): - content = segment["content"] - doc_id = str(uuid.uuid4()) - segment_hash = helper.generate_text_hash(content) - max_position = ( - db.session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == dataset_document.id) - .scalar() - ) - segment_document = DocumentSegment( - tenant_id=tenant_id, - dataset_id=dataset_id, - document_id=document_id, - index_node_id=doc_id, - index_node_hash=segment_hash, - position=max_position + 1 if max_position else 1, - content=content, - word_count=len(content), - tokens=tokens, - created_by=user_id, - indexing_at=naive_utc_now(), - status="completed", - completed_at=naive_utc_now(), - ) - if dataset_document.doc_form == "qa_model": - segment_document.answer = segment["answer"] - segment_document.word_count += len(segment["answer"]) - word_count_change += segment_document.word_count - db.session.add(segment_document) - document_segments.append(segment_document) - - assert dataset_document.word_count is not None - dataset_document.word_count += word_count_change - db.session.add(dataset_document) - - VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form) - db.session.commit() - redis_client.setex(indexing_cache_key, 600, "completed") - end_at = time.perf_counter() - logger.info( - click.style( - f"Segment batch created job: {job_id} latency: {end_at - start_at}", - fg="green", - ) - ) - except Exception: - logger.exception("Segments batch created index failed") - redis_client.setex(indexing_cache_key, 600, "error") - finally: - db.session.close() + except Exception: + 
logger.exception("Segments batch created index failed") + redis_client.setex(indexing_cache_key, 600, "error") diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index b4d82a150d..0d51a743ad 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -3,11 +3,11 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import delete, select +from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids -from extensions.ext_database import db from extensions.ext_storage import storage from models import WorkflowType from models.dataset import ( @@ -53,135 +53,155 @@ def clean_dataset_task( logger.info(click.style(f"Start clean dataset when dataset deleted: {dataset_id}", fg="green")) start_at = time.perf_counter() - try: - dataset = Dataset( - id=dataset_id, - tenant_id=tenant_id, - indexing_technique=indexing_technique, - index_struct=index_struct, - collection_binding_id=collection_binding_id, - ) - documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all() - segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all() - # Use JOIN to fetch attachments with bindings in a single query - attachments_with_bindings = db.session.execute( - select(SegmentAttachmentBinding, UploadFile) - .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) - .where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id) - ).all() - - # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace - # This ensures all invalid doc_form values are properly handled - if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()): - # Use default paragraph index 
type for empty/invalid datasets to enable vector database cleanup - from core.rag.index_processor.constant.index_type import IndexStructureType - - doc_form = IndexStructureType.PARAGRAPH_INDEX - logger.info( - click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow") - ) - - # Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure - # This ensures Document/Segment deletion can continue even if vector database cleanup fails + with session_factory.create_session() as session: try: - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True) - logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green")) - except Exception: - logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red")) - # Continue with document and segment deletion even if vector cleanup fails - logger.info( - click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow") + dataset = Dataset( + id=dataset_id, + tenant_id=tenant_id, + indexing_technique=indexing_technique, + index_struct=index_struct, + collection_binding_id=collection_binding_id, ) + documents = session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all() + segments = session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all() + # Use JOIN to fetch attachments with bindings in a single query + attachments_with_bindings = session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.tenant_id == tenant_id, + SegmentAttachmentBinding.dataset_id == dataset_id, + ) + ).all() - if documents is None or len(documents) == 0: - logger.info(click.style(f"No documents found 
for dataset: {dataset_id}", fg="green")) - else: - logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green")) + # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace + # This ensures all invalid doc_form values are properly handled + if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()): + # Use default paragraph index type for empty/invalid datasets to enable vector database cleanup + from core.rag.index_processor.constant.index_type import IndexStructureType - for document in documents: - db.session.delete(document) - # delete document file + doc_form = IndexStructureType.PARAGRAPH_INDEX + logger.info( + click.style( + f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", + fg="yellow", + ) + ) - for segment in segments: - image_upload_file_ids = get_image_upload_file_ids(segment.content) - for upload_file_id in image_upload_file_ids: - image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() - if image_file is None: - continue + # Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure + # This ensures Document/Segment deletion can continue even if vector database cleanup fails + try: + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True) + logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green")) + except Exception: + logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red")) + # Continue with document and segment deletion even if vector cleanup fails + logger.info( + click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow") + ) + + if documents is None or len(documents) == 0: + logger.info(click.style(f"No documents found for dataset: 
{dataset_id}", fg="green")) + else: + logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green")) + + for document in documents: + session.delete(document) + + segment_ids = [segment.id for segment in segments] + for segment in segments: + image_upload_file_ids = get_image_upload_file_ids(segment.content) + image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all() + for image_file in image_files: + if image_file is None: + continue + try: + storage.delete(image_file.key) + except Exception: + logger.exception( + "Delete image_files failed when storage deleted, \ image_upload_file_id: %s", + image_file.id, + ) + stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) + session.execute(stmt) + + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) + session.execute(segment_delete_stmt) + # delete segment attachments + if attachments_with_bindings: + attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings] + binding_ids = [binding.id for binding, _ in attachments_with_bindings] + for binding, attachment_file in attachments_with_bindings: try: - storage.delete(image_file.key) + storage.delete(attachment_file.key) except Exception: logger.exception( - "Delete image_files failed when storage deleted, \ - image_upload_file_is: %s", - upload_file_id, + "Delete attachment_file failed when storage deleted, \ + attachment_file_id: %s", + binding.attachment_id, ) - db.session.delete(image_file) - db.session.delete(segment) + 
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids)) + session.execute(attachment_file_delete_stmt) - db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete() - db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete() - db.session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete() - # delete dataset metadata - db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete() - db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete() - # delete pipeline and workflow - if pipeline_id: - db.session.query(Pipeline).where(Pipeline.id == pipeline_id).delete() - db.session.query(Workflow).where( - Workflow.tenant_id == tenant_id, - Workflow.app_id == pipeline_id, - Workflow.type == WorkflowType.RAG_PIPELINE, - ).delete() - # delete files - if documents: - for document in documents: - try: + binding_delete_stmt = delete(SegmentAttachmentBinding).where( + SegmentAttachmentBinding.id.in_(binding_ids) + ) + session.execute(binding_delete_stmt) + + session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete() + session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete() + session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete() + # delete dataset metadata + session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete() + session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete() + # delete pipeline and workflow + if pipeline_id: + session.query(Pipeline).where(Pipeline.id == pipeline_id).delete() + session.query(Workflow).where( + Workflow.tenant_id == tenant_id, + Workflow.app_id == pipeline_id, + Workflow.type == WorkflowType.RAG_PIPELINE, + ).delete() + # delete files + if documents: + file_ids = [] + for document in 
 documents:
             if document.data_source_type == "upload_file":
                 if document.data_source_info:
                     data_source_info = document.data_source_info_dict
                     if data_source_info and "upload_file_id" in data_source_info:
                         file_id = data_source_info["upload_file_id"]
-                        file = (
-                            db.session.query(UploadFile)
-                            .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
-                            .first()
-                        )
-                        if not file:
-                            continue
-                        storage.delete(file.key)
-                        db.session.delete(file)
-        except Exception:
-            continue
+                        file_ids.append(file_id)
+        files = session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
+        for file in files:
+            storage.delete(file.key)
-        db.session.commit()
-        end_at = time.perf_counter()
-        logger.info(
-            click.style(f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}", fg="green")
-        )
-    except Exception:
-        # Add rollback to prevent dirty session state in case of exceptions
-        # This ensures the database session is properly cleaned up
-        try:
-            db.session.rollback()
-            logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
+        file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
+        session.execute(file_delete_stmt)
+
+        session.commit()
+        end_at = time.perf_counter()
+        logger.info(
+            click.style(
+                f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}",
+                fg="green",
+            )
+        )
     except Exception:
-        logger.exception("Failed to rollback database session")
+        # Add rollback to prevent dirty session state in case of exceptions
+        # This ensures the database session is properly cleaned up
+        try:
+            session.rollback()
+            logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
+        except Exception:
+            logger.exception("Failed to rollback database session")
-        logger.exception("Cleaned dataset when dataset deleted failed")
-    finally:
-        db.session.close()
+        logger.exception("Cleaned dataset when dataset deleted failed")
+    finally:
+        # Explicitly close the session for test expectations and safety
+        try:
+            session.close()
+        except Exception:
+            logger.exception("Failed to close database session")
diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py
index 6d2feb1da3..86e7cc7160 100644
--- a/api/tasks/clean_document_task.py
+++ b/api/tasks/clean_document_task.py
@@ -3,11 +3,11 @@ import time
 
 import click
 from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
 
+from core.db.session_factory import session_factory
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.tools.utils.web_reader_tool import get_image_upload_file_ids
-from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding
 from models.model import UploadFile
@@ -29,85 +29,94 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
     logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
     start_at = time.perf_counter()
 
-    try:
-        dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+    with session_factory.create_session() as session:
+        try:
+            dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
 
-        if not dataset:
-            raise Exception("Document has no dataset")
+            if not dataset:
+                raise Exception("Document has no dataset")
 
-        segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
-        # Use JOIN to fetch attachments with bindings in a single query
-        attachments_with_bindings = db.session.execute(
-            select(SegmentAttachmentBinding, UploadFile)
-            .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
-            .where(
-                SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
-                SegmentAttachmentBinding.dataset_id == dataset_id,
-                SegmentAttachmentBinding.document_id == document_id,
-            )
-        ).all()
-        # check segment is exist
-        if segments:
-            index_node_ids = [segment.index_node_id for segment in segments]
-            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
-            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+            segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
+            # Use JOIN to fetch attachments with bindings in a single query
+            attachments_with_bindings = session.execute(
+                select(SegmentAttachmentBinding, UploadFile)
+                .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
+                .where(
+                    SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
+                    SegmentAttachmentBinding.dataset_id == dataset_id,
+                    SegmentAttachmentBinding.document_id == document_id,
+                )
+            ).all()
+            # check segment is exist
+            if segments:
+                index_node_ids = [segment.index_node_id for segment in segments]
+                index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
 
-            for segment in segments:
-                image_upload_file_ids = get_image_upload_file_ids(segment.content)
-                for upload_file_id in image_upload_file_ids:
-                    image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
-                    if image_file is None:
-                        continue
+                for segment in segments:
+                    image_upload_file_ids = get_image_upload_file_ids(segment.content)
+                    image_files = session.scalars(
+                        select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
+                    ).all()
+                    for image_file in image_files:
+                        if image_file is None:
+                            continue
+                        try:
+                            storage.delete(image_file.key)
+                        except Exception:
+                            logger.exception(
+                                "Delete image_files failed when storage deleted, \
+                                image_upload_file_is: %s",
+                                image_file.id,
+                            )
+
+                    image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
+                    session.execute(image_file_delete_stmt)
+                    session.delete(segment)
+
+                session.commit()
+            if file_id:
+                file = session.query(UploadFile).where(UploadFile.id == file_id).first()
+                if file:
                     try:
-                        storage.delete(image_file.key)
+                        storage.delete(file.key)
+                    except Exception:
+                        logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
+                    session.delete(file)
+            # delete segment attachments
+            if attachments_with_bindings:
+                attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
+                binding_ids = [binding.id for binding, _ in attachments_with_bindings]
+                for binding, attachment_file in attachments_with_bindings:
+                    try:
+                        storage.delete(attachment_file.key)
                     except Exception:
                         logger.exception(
-                            "Delete image_files failed when storage deleted, \
-                            image_upload_file_is: %s",
-                            upload_file_id,
+                            "Delete attachment_file failed when storage deleted, \
+                            attachment_file_id: %s",
+                            binding.attachment_id,
                         )
-                    db.session.delete(image_file)
-                db.session.delete(segment)
+                attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
+                session.execute(attachment_file_delete_stmt)
 
-            db.session.commit()
-        if file_id:
-            file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
-            if file:
-                try:
-                    storage.delete(file.key)
-                except Exception:
-                    logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
-                db.session.delete(file)
-                db.session.commit()
-        # delete segment attachments
-        if attachments_with_bindings:
-            for binding, attachment_file in attachments_with_bindings:
-                try:
-                    storage.delete(attachment_file.key)
-                except Exception:
-                    logger.exception(
-                        "Delete attachment_file failed when storage deleted, \
-                        attachment_file_id: %s",
-                        binding.attachment_id,
-                    )
-                db.session.delete(attachment_file)
-                db.session.delete(binding)
+                binding_delete_stmt = delete(SegmentAttachmentBinding).where(
+                    SegmentAttachmentBinding.id.in_(binding_ids)
+                )
+                session.execute(binding_delete_stmt)
 
-        # delete dataset metadata binding
-        db.session.query(DatasetMetadataBinding).where(
-            DatasetMetadataBinding.dataset_id == dataset_id,
-            DatasetMetadataBinding.document_id == document_id,
-        ).delete()
-        db.session.commit()
+            # delete dataset metadata binding
+            session.query(DatasetMetadataBinding).where(
+                DatasetMetadataBinding.dataset_id == dataset_id,
+                DatasetMetadataBinding.document_id == document_id,
+            ).delete()
+            session.commit()
 
-        end_at = time.perf_counter()
-        logger.info(
-            click.style(
-                f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
-                fg="green",
+            end_at = time.perf_counter()
+            logger.info(
+                click.style(
+                    f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
+                    fg="green",
+                )
             )
-        )
-    except Exception:
-        logger.exception("Cleaned document when document deleted failed")
-    finally:
-        db.session.close()
+        except Exception:
+            logger.exception("Cleaned document when document deleted failed")
diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py
index 771b43f9b0..bcca1bf49f 100644
--- a/api/tasks/clean_notion_document_task.py
+++ b/api/tasks/clean_notion_document_task.py
@@ -3,10 +3,10 @@ import time
 
 import click
 from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
 
+from core.db.session_factory import session_factory
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
 
 logger = logging.getLogger(__name__)
@@ -24,37 +24,37 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
     logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
     start_at = time.perf_counter()
 
-    try:
-        dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+    with session_factory.create_session() as session:
+        try:
+            dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
 
-        if not dataset:
-            raise Exception("Document has no dataset")
-        index_type = dataset.doc_form
-        index_processor = IndexProcessorFactory(index_type).init_index_processor()
-        for document_id in document_ids:
-            document = db.session.query(Document).where(Document.id == document_id).first()
-            db.session.delete(document)
+            if not dataset:
+                raise Exception("Document has no dataset")
+            index_type = dataset.doc_form
+            index_processor = IndexProcessorFactory(index_type).init_index_processor()
 
-            segments = db.session.scalars(
-                select(DocumentSegment).where(DocumentSegment.document_id == document_id)
-            ).all()
-            index_node_ids = [segment.index_node_id for segment in segments]
+            document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
+            session.execute(document_delete_stmt)
 
-            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+            for document_id in document_ids:
+                segments = session.scalars(
+                    select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+                ).all()
+                index_node_ids = [segment.index_node_id for segment in segments]
 
-            for segment in segments:
-                db.session.delete(segment)
-        db.session.commit()
-        end_at = time.perf_counter()
-        logger.info(
-            click.style(
-                "Clean document when import form notion document deleted end :: {} latency: {}".format(
-                    dataset_id, end_at - start_at
-                ),
-                fg="green",
+                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+                segment_ids = [segment.id for segment in segments]
+                segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+                session.execute(segment_delete_stmt)
+            session.commit()
+            end_at = time.perf_counter()
+            logger.info(
+                click.style(
+                    "Clean document when import form notion document deleted end :: {} latency: {}".format(
                        dataset_id, end_at - start_at
+                    ),
+                    fg="green",
+                )
             )
-        )
-    except Exception:
-        logger.exception("Cleaned document when import form notion document deleted failed")
-    finally:
-        db.session.close()
+        except Exception:
+            logger.exception("Cleaned document when import form notion document deleted failed")
diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py
index 6b2907cffd..b5e472d71e 100644
--- a/api/tasks/create_segment_to_index_task.py
+++ b/api/tasks/create_segment_to_index_task.py
@@ -4,9 +4,9 @@ import time
 
 import click
 from celery import shared_task
 
+from core.db.session_factory import session_factory
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.models.document import Document
-from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs.datetime_utils import naive_utc_now
 from models.dataset import DocumentSegment
@@ -25,75 +25,77 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
     logger.info(click.style(f"Start create segment to index: {segment_id}", fg="green"))
     start_at = time.perf_counter()
 
-    segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
-    if not segment:
-        logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
-        db.session.close()
-        return
-
-    if segment.status != "waiting":
-        db.session.close()
-        return
-
-    indexing_cache_key = f"segment_{segment.id}_indexing"
-
-    try:
-        # update segment status to indexing
-        db.session.query(DocumentSegment).filter_by(id=segment.id).update(
-            {
-                DocumentSegment.status: "indexing",
-                DocumentSegment.indexing_at: naive_utc_now(),
-            }
-        )
-        db.session.commit()
-        document = Document(
-            page_content=segment.content,
-            metadata={
-                "doc_id": segment.index_node_id,
-                "doc_hash": segment.index_node_hash,
-                "document_id": segment.document_id,
-                "dataset_id": segment.dataset_id,
-            },
-        )
-
-        dataset = segment.dataset
-
-        if not dataset:
-            logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+    with session_factory.create_session() as session:
+        segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
+        if not segment:
+            logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
             return
 
-        dataset_document = segment.document
-
-        if not dataset_document:
-            logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+        if segment.status != "waiting":
             return
 
-        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
-            logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
-            return
+        indexing_cache_key = f"segment_{segment.id}_indexing"
 
-        index_type = dataset.doc_form
-        index_processor = IndexProcessorFactory(index_type).init_index_processor()
-        index_processor.load(dataset, [document])
+        try:
+            # update segment status to indexing
+            session.query(DocumentSegment).filter_by(id=segment.id).update(
+                {
+                    DocumentSegment.status: "indexing",
+                    DocumentSegment.indexing_at: naive_utc_now(),
+                }
+            )
+            session.commit()
+            document = Document(
+                page_content=segment.content,
+                metadata={
+                    "doc_id": segment.index_node_id,
+                    "doc_hash": segment.index_node_hash,
+                    "document_id": segment.document_id,
+                    "dataset_id": segment.dataset_id,
+                },
+            )
 
-        # update segment to completed
-        db.session.query(DocumentSegment).filter_by(id=segment.id).update(
-            {
-                DocumentSegment.status: "completed",
-                DocumentSegment.completed_at: naive_utc_now(),
-            }
-        )
-        db.session.commit()
+            dataset = segment.dataset
 
-        end_at = time.perf_counter()
-        logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green"))
-    except Exception as e:
-        logger.exception("create segment to index failed")
-        segment.enabled = False
-        segment.disabled_at = naive_utc_now()
-        segment.status = "error"
-        segment.error = str(e)
-        db.session.commit()
-    finally:
-        redis_client.delete(indexing_cache_key)
-        db.session.close()
+            if not dataset:
+                logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+                return
+
+            dataset_document = segment.document
+
+            if not dataset_document:
+                logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+                return
+
+            if (
+                not dataset_document.enabled
+                or dataset_document.archived
+                or dataset_document.indexing_status != "completed"
+            ):
+                logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
+                return
+
+            index_type = dataset.doc_form
+            index_processor = IndexProcessorFactory(index_type).init_index_processor()
+            index_processor.load(dataset, [document])
+
+            # update segment to completed
+            session.query(DocumentSegment).filter_by(id=segment.id).update(
+                {
+                    DocumentSegment.status: "completed",
+                    DocumentSegment.completed_at: naive_utc_now(),
+                }
+            )
+            session.commit()
+
+            end_at = time.perf_counter()
+            logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green"))
+        except Exception as e:
+            logger.exception("create segment to index failed")
+            segment.enabled = False
+            segment.disabled_at = naive_utc_now()
+            segment.status = "error"
+            segment.error = str(e)
+            session.commit()
+        finally:
+            redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py
index 3d13afdec0..fa844a8647 100644
--- a/api/tasks/deal_dataset_index_update_task.py
+++ b/api/tasks/deal_dataset_index_update_task.py
@@ -4,11 +4,11 @@ import time
 
 import click
 from celery import shared_task  # type: ignore
 
+from core.db.session_factory import session_factory
 from core.rag.index_processor.constant.doc_type import DocType
 from core.rag.index_processor.constant.index_type import IndexStructureType
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment
 from models.dataset import Document as DatasetDocument
@@ -24,166 +24,174 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
     logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green"))
     start_at = time.perf_counter()
-    try:
-        dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
+    with session_factory.create_session() as session:
+        try:
+            dataset = session.query(Dataset).filter_by(id=dataset_id).first()
 
-        if not dataset:
-            raise Exception("Dataset not found")
-        index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
-        index_processor = IndexProcessorFactory(index_type).init_index_processor()
-        if action == "upgrade":
-            dataset_documents = (
-                db.session.query(DatasetDocument)
-                .where(
-                    DatasetDocument.dataset_id == dataset_id,
-                    DatasetDocument.indexing_status == "completed",
-                    DatasetDocument.enabled == True,
-                    DatasetDocument.archived == False,
+            if not dataset:
+                raise Exception("Dataset not found")
+            index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
+            index_processor = IndexProcessorFactory(index_type).init_index_processor()
+            if action == "upgrade":
+                dataset_documents = (
+                    session.query(DatasetDocument)
+                    .where(
+                        DatasetDocument.dataset_id == dataset_id,
+                        DatasetDocument.indexing_status == "completed",
+                        DatasetDocument.enabled == True,
+                        DatasetDocument.archived == False,
+                    )
+                    .all()
                 )
-                .all()
-            )
-            if dataset_documents:
-                dataset_documents_ids = [doc.id for doc in dataset_documents]
-                db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
-                    {"indexing_status": "indexing"}, synchronize_session=False
-                )
-                db.session.commit()
+                if dataset_documents:
+                    dataset_documents_ids = [doc.id for doc in dataset_documents]
+                    session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
+                        {"indexing_status": "indexing"}, synchronize_session=False
+                    )
+                    session.commit()
 
-                for dataset_document in dataset_documents:
-                    try:
-                        # add from vector index
-                        segments = (
-                            db.session.query(DocumentSegment)
-                            .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
-                            .order_by(DocumentSegment.position.asc())
-                            .all()
-                        )
-                        if segments:
-                            documents = []
-                            for segment in segments:
-                                document = Document(
-                                    page_content=segment.content,
-                                    metadata={
-                                        "doc_id": segment.index_node_id,
-                                        "doc_hash": segment.index_node_hash,
-                                        "document_id": segment.document_id,
-                                        "dataset_id": segment.dataset_id,
-                                    },
+                    for dataset_document in dataset_documents:
+                        try:
+                            # add from vector index
+                            segments = (
+                                session.query(DocumentSegment)
+                                .where(
+                                    DocumentSegment.document_id == dataset_document.id,
+                                    DocumentSegment.enabled == True,
                                 )
-
-                                documents.append(document)
-                            # save vector index
-                            # clean keywords
-                            index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
-                            index_processor.load(dataset, documents, with_keywords=False)
-                        db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
-                            {"indexing_status": "completed"}, synchronize_session=False
-                        )
-                        db.session.commit()
-                    except Exception as e:
-                        db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
-                            {"indexing_status": "error", "error": str(e)}, synchronize_session=False
-                        )
-                        db.session.commit()
-        elif action == "update":
-            dataset_documents = (
-                db.session.query(DatasetDocument)
-                .where(
-                    DatasetDocument.dataset_id == dataset_id,
-                    DatasetDocument.indexing_status == "completed",
-                    DatasetDocument.enabled == True,
-                    DatasetDocument.archived == False,
-                )
-                .all()
-            )
-            # add new index
-            if dataset_documents:
-                # update document status
-                dataset_documents_ids = [doc.id for doc in dataset_documents]
-                db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
-                    {"indexing_status": "indexing"}, synchronize_session=False
-                )
-                db.session.commit()
-
-                # clean index
-                index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
-
-                for dataset_document in dataset_documents:
-                    # update from vector index
-                    try:
-                        segments = (
-                            db.session.query(DocumentSegment)
-                            .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
-                            .order_by(DocumentSegment.position.asc())
-                            .all()
-                        )
-                        if segments:
-                            documents = []
-                            multimodal_documents = []
-                            for segment in segments:
-                                document = Document(
-                                    page_content=segment.content,
-                                    metadata={
-                                        "doc_id": segment.index_node_id,
-                                        "doc_hash": segment.index_node_hash,
-                                        "document_id": segment.document_id,
-                                        "dataset_id": segment.dataset_id,
-                                    },
-                                )
-                                if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
-                                    child_chunks = segment.get_child_chunks()
-                                    if child_chunks:
-                                        child_documents = []
-                                        for child_chunk in child_chunks:
-                                            child_document = ChildDocument(
-                                                page_content=child_chunk.content,
-                                                metadata={
-                                                    "doc_id": child_chunk.index_node_id,
-                                                    "doc_hash": child_chunk.index_node_hash,
-                                                    "document_id": segment.document_id,
-                                                    "dataset_id": segment.dataset_id,
-                                                },
-                                            )
-                                            child_documents.append(child_document)
-                                        document.children = child_documents
-                                if dataset.is_multimodal:
-                                    for attachment in segment.attachments:
-                                        multimodal_documents.append(
-                                            AttachmentDocument(
-                                                page_content=attachment["name"],
-                                                metadata={
-                                                    "doc_id": attachment["id"],
-                                                    "doc_hash": "",
-                                                    "document_id": segment.document_id,
-                                                    "dataset_id": segment.dataset_id,
-                                                    "doc_type": DocType.IMAGE,
-                                                },
-                                            )
-                                        )
-                                documents.append(document)
-                            # save vector index
-                            index_processor.load(
-                                dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+                                .order_by(DocumentSegment.position.asc())
+                                .all()
                             )
-                        db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
-                            {"indexing_status": "completed"}, synchronize_session=False
-                        )
-                        db.session.commit()
-                    except Exception as e:
-                        db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
-                            {"indexing_status": "error", "error": str(e)}, synchronize_session=False
-                        )
-                        db.session.commit()
-        else:
-            # clean collection
-            index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+                            if segments:
+                                documents = []
+                                for segment in segments:
+                                    document = Document(
+                                        page_content=segment.content,
+                                        metadata={
+                                            "doc_id": segment.index_node_id,
+                                            "doc_hash": segment.index_node_hash,
+                                            "document_id": segment.document_id,
+                                            "dataset_id": segment.dataset_id,
+                                        },
+                                    )
 
-        end_at = time.perf_counter()
-        logging.info(
-            click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
-        )
-    except Exception:
-        logging.exception("Deal dataset vector index failed")
-    finally:
-        db.session.close()
+                                    documents.append(document)
+                                # save vector index
+                                # clean keywords
+                                index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
+                                index_processor.load(dataset, documents, with_keywords=False)
+                            session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+                                {"indexing_status": "completed"}, synchronize_session=False
+                            )
+                            session.commit()
+                        except Exception as e:
+                            session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+                                {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+                            )
+                            session.commit()
+            elif action == "update":
+                dataset_documents = (
+                    session.query(DatasetDocument)
+                    .where(
+                        DatasetDocument.dataset_id == dataset_id,
+                        DatasetDocument.indexing_status == "completed",
+                        DatasetDocument.enabled == True,
+                        DatasetDocument.archived == False,
+                    )
+                    .all()
+                )
+                # add new index
+                if dataset_documents:
+                    # update document status
+                    dataset_documents_ids = [doc.id for doc in dataset_documents]
+                    session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
+                        {"indexing_status": "indexing"}, synchronize_session=False
+                    )
+                    session.commit()
+
+                    # clean index
+                    index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+                    for dataset_document in dataset_documents:
+                        # update from vector index
+                        try:
+                            segments = (
+                                session.query(DocumentSegment)
+                                .where(
+                                    DocumentSegment.document_id == dataset_document.id,
+                                    DocumentSegment.enabled == True,
+                                )
+                                .order_by(DocumentSegment.position.asc())
+                                .all()
+                            )
+                            if segments:
+                                documents = []
+                                multimodal_documents = []
+                                for segment in segments:
+                                    document = Document(
+                                        page_content=segment.content,
+                                        metadata={
+                                            "doc_id": segment.index_node_id,
+                                            "doc_hash": segment.index_node_hash,
+                                            "document_id": segment.document_id,
+                                            "dataset_id": segment.dataset_id,
+                                        },
+                                    )
+                                    if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+                                        child_chunks = segment.get_child_chunks()
+                                        if child_chunks:
+                                            child_documents = []
+                                            for child_chunk in child_chunks:
+                                                child_document = ChildDocument(
+                                                    page_content=child_chunk.content,
+                                                    metadata={
+                                                        "doc_id": child_chunk.index_node_id,
+                                                        "doc_hash": child_chunk.index_node_hash,
+                                                        "document_id": segment.document_id,
+                                                        "dataset_id": segment.dataset_id,
+                                                    },
+                                                )
+                                                child_documents.append(child_document)
+                                            document.children = child_documents
+                                    if dataset.is_multimodal:
+                                        for attachment in segment.attachments:
+                                            multimodal_documents.append(
+                                                AttachmentDocument(
+                                                    page_content=attachment["name"],
+                                                    metadata={
+                                                        "doc_id": attachment["id"],
+                                                        "doc_hash": "",
+                                                        "document_id": segment.document_id,
+                                                        "dataset_id": segment.dataset_id,
+                                                        "doc_type": DocType.IMAGE,
+                                                    },
+                                                )
+                                            )
+                                    documents.append(document)
+                                # save vector index
+                                index_processor.load(
+                                    dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+                                )
+                            session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+                                {"indexing_status": "completed"}, synchronize_session=False
+                            )
+                            session.commit()
+                        except Exception as e:
+                            session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+                                {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+                            )
+                            session.commit()
+            else:
+                # clean collection
+                index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+            end_at = time.perf_counter()
+            logging.info(
+                click.style(
+                    "Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at),
+                    fg="green",
+                )
+            )
+        except Exception:
+            logging.exception("Deal dataset vector index failed")
diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py
index 1c7de3b1ce..0047e04a17 100644
--- a/api/tasks/deal_dataset_vector_index_task.py
+++ b/api/tasks/deal_dataset_vector_index_task.py
@@ -5,11 +5,11 @@
 import click
 from celery import shared_task
 from sqlalchemy import select
 
+from core.db.session_factory import session_factory
 from core.rag.index_processor.constant.doc_type import DocType
 from core.rag.index_processor.constant.index_type import IndexStructureType
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment
 from models.dataset import Document as DatasetDocument
@@ -27,160 +27,170 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
     logger.info(click.style(f"Start deal dataset vector index: {dataset_id}", fg="green"))
     start_at = time.perf_counter()
 
-    try:
-        dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
+    with session_factory.create_session() as session:
+        try:
+            dataset = session.query(Dataset).filter_by(id=dataset_id).first()
 
-        if not dataset:
-            raise Exception("Dataset not found")
-        index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
-        index_processor = IndexProcessorFactory(index_type).init_index_processor()
-        if action == "remove":
-            index_processor.clean(dataset, None, with_keywords=False)
-        elif action == "add":
-            dataset_documents = db.session.scalars(
-                select(DatasetDocument).where(
-                    DatasetDocument.dataset_id == dataset_id,
-                    DatasetDocument.indexing_status == "completed",
-                    DatasetDocument.enabled == True,
-                    DatasetDocument.archived == False,
-                )
-            ).all()
+            if not dataset:
+                raise Exception("Dataset not found")
+            index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
+            index_processor = IndexProcessorFactory(index_type).init_index_processor()
+            if action == "remove":
+                index_processor.clean(dataset, None, with_keywords=False)
+            elif action == "add":
+                dataset_documents = session.scalars(
+                    select(DatasetDocument).where(
+                        DatasetDocument.dataset_id == dataset_id,
+                        DatasetDocument.indexing_status == "completed",
+                        DatasetDocument.enabled == True,
+                        DatasetDocument.archived == False,
+                    )
+                ).all()
 
-            if dataset_documents:
-                dataset_documents_ids = [doc.id for doc in dataset_documents]
-                db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
-                    {"indexing_status": "indexing"}, synchronize_session=False
-                )
-                db.session.commit()
+                if dataset_documents:
+                    dataset_documents_ids = [doc.id for doc in dataset_documents]
+                    session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
+                        {"indexing_status": "indexing"}, synchronize_session=False
+                    )
+                    session.commit()
 
-                for dataset_document in dataset_documents:
-                    try:
-                        # add from vector index
-                        segments = (
-                            db.session.query(DocumentSegment)
-                            .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
-                            .order_by(DocumentSegment.position.asc())
-                            .all()
-                        )
-                        if segments:
-                            documents = []
-                            for segment in segments:
-                                document = Document(
-                                    page_content=segment.content,
-                                    metadata={
-                                        "doc_id": segment.index_node_id,
-                                        "doc_hash": segment.index_node_hash,
-                                        "document_id": segment.document_id,
-                                        "dataset_id": segment.dataset_id,
-                                    },
+                    for dataset_document in dataset_documents:
+                        try:
+                            # add from vector index
+                            segments = (
+                                session.query(DocumentSegment)
+                                .where(
+                                    DocumentSegment.document_id == dataset_document.id,
+                                    DocumentSegment.enabled == True,
                                 )
-
-                                documents.append(document)
-                            # save vector index
-                            index_processor.load(dataset, documents, with_keywords=False)
-                            db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
-                                {"indexing_status": "completed"}, synchronize_session=False
-                            )
-                            db.session.commit()
-                    except Exception as e:
-                        db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
-                            {"indexing_status": "error", "error": str(e)}, synchronize_session=False
-                        )
-                        db.session.commit()
-        elif action == "update":
-            dataset_documents = db.session.scalars(
-                select(DatasetDocument).where(
-                    DatasetDocument.dataset_id == dataset_id,
-                    DatasetDocument.indexing_status == "completed",
-                    DatasetDocument.enabled == True,
-                    DatasetDocument.archived == False,
-                )
-            ).all()
-            # add new index
-            if dataset_documents:
-                # update document status
-                dataset_documents_ids = [doc.id for doc in dataset_documents]
-                db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
-                    {"indexing_status": "indexing"}, synchronize_session=False
-                )
-                db.session.commit()
-
-                # clean index
-                index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
-
-                for dataset_document in dataset_documents:
-                    # update from vector index
-                    try:
-                        segments = (
-                            db.session.query(DocumentSegment)
-                            .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
-                            .order_by(DocumentSegment.position.asc())
-                            .all()
-                        )
-                        if segments:
-                            documents = []
-                            multimodal_documents = []
-                            for segment in segments:
-                                document = Document(
-                                    page_content=segment.content,
-                                    metadata={
-                                        "doc_id": segment.index_node_id,
-                                        "doc_hash": segment.index_node_hash,
-                                        "document_id": segment.document_id,
-                                        "dataset_id": segment.dataset_id,
-                                    },
-                                )
-                                if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
-                                    child_chunks = segment.get_child_chunks()
-                                    if child_chunks:
-                                        child_documents = []
-                                        for child_chunk in child_chunks:
-                                            child_document = ChildDocument(
-                                                page_content=child_chunk.content,
-                                                metadata={
-                                                    "doc_id": child_chunk.index_node_id,
-                                                    "doc_hash": child_chunk.index_node_hash,
-                                                    "document_id": segment.document_id,
-                                                    "dataset_id": segment.dataset_id,
-                                                },
-                                            )
-                                            child_documents.append(child_document)
-                                        document.children = child_documents
-                                if dataset.is_multimodal:
-                                    for attachment in segment.attachments:
-                                        multimodal_documents.append(
-                                            AttachmentDocument(
-                                                page_content=attachment["name"],
-                                                metadata={
-                                                    "doc_id": attachment["id"],
-                                                    "doc_hash": "",
-                                                    "document_id": segment.document_id,
-                                                    "dataset_id": segment.dataset_id,
-                                                    "doc_type": DocType.IMAGE,
-                                                },
-                                            )
-                                        )
-                                documents.append(document)
-                            # save vector index
-                            index_processor.load(
-                                dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+                                .order_by(DocumentSegment.position.asc())
+                                .all()
                             )
-                            db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
-                                {"indexing_status": "completed"}, synchronize_session=False
-                            )
-                            db.session.commit()
-                    except Exception as e:
-                        db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
-                            {"indexing_status": "error", "error": str(e)}, synchronize_session=False
-                        )
-                        db.session.commit()
-        else:
-            # clean collection
-            index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+                            if segments:
+                                documents = []
+                                for segment in segments:
+                                    document = Document(
+                                        page_content=segment.content,
+                                        metadata={
+                                            "doc_id": segment.index_node_id,
+                                            "doc_hash": segment.index_node_hash,
+                                            "document_id": segment.document_id,
+                                            "dataset_id": segment.dataset_id,
+                                        },
+                                    )
 
-        end_at = time.perf_counter()
-        logger.info(click.style(f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}", fg="green"))
-    except Exception:
-        logger.exception("Deal dataset vector index failed")
-    finally:
-        db.session.close()
+                                    documents.append(document)
+                                # save vector index
+                                index_processor.load(dataset, documents, with_keywords=False)
+                                session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+                                    {"indexing_status": "completed"}, synchronize_session=False
+                                )
+                                session.commit()
+                        except Exception as e:
+                            session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+                                {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+                            )
+                            session.commit()
+            elif action == "update":
+                dataset_documents = session.scalars(
+                    select(DatasetDocument).where(
+                        DatasetDocument.dataset_id == dataset_id,
+                        DatasetDocument.indexing_status == "completed",
+                        DatasetDocument.enabled == True,
+                        DatasetDocument.archived == False,
+                    )
+                ).all()
+                # add new index
+                if dataset_documents:
+                    # update document status
+                    dataset_documents_ids = [doc.id for doc in dataset_documents]
+                    session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
+                        {"indexing_status": "indexing"}, synchronize_session=False
+                    )
+                    session.commit()
+
+                    # clean index
+                    index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+                    for dataset_document in dataset_documents:
+                        # update from vector index
+                        try:
+                            segments = (
+                                session.query(DocumentSegment)
+                                .where(
+                                    DocumentSegment.document_id == dataset_document.id,
+                                    DocumentSegment.enabled == True,
+                                )
+                                .order_by(DocumentSegment.position.asc())
+                                .all()
+                            )
+                            if segments:
+                                documents = []
+                                multimodal_documents = []
+                                for segment in segments:
+                                    document = Document(
+                                        page_content=segment.content,
+                                        metadata={
+                                            "doc_id": segment.index_node_id,
+                                            "doc_hash": segment.index_node_hash,
+                                            "document_id": segment.document_id,
+                                            "dataset_id": segment.dataset_id,
+                                        },
+                                    )
+                                    if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+                                        child_chunks = segment.get_child_chunks()
+                                        if child_chunks:
+                                            child_documents = []
+                                            for child_chunk in child_chunks:
+                                                child_document = ChildDocument(
+                                                    page_content=child_chunk.content,
+                                                    metadata={
+                                                        "doc_id": child_chunk.index_node_id,
+                                                        "doc_hash": child_chunk.index_node_hash,
+                                                        "document_id": segment.document_id,
+                                                        "dataset_id": segment.dataset_id,
+                                                    },
+                                                )
+                                                child_documents.append(child_document)
+                                            document.children = child_documents
+                                    if dataset.is_multimodal:
+                                        for attachment in segment.attachments:
+                                            multimodal_documents.append(
+                                                AttachmentDocument(
+                                                    page_content=attachment["name"],
+                                                    metadata={
+                                                        "doc_id": attachment["id"],
+                                                        "doc_hash": "",
+                                                        "document_id": segment.document_id,
+                                                        "dataset_id": segment.dataset_id,
+                                                        "doc_type": DocType.IMAGE,
+                                                    },
+                                                )
+                                            )
+                                    documents.append(document)
+                                # save vector index
+                                index_processor.load(
+                                    dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+                                )
+                                session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+                                    {"indexing_status": "completed"}, synchronize_session=False
+                                )
+                                session.commit()
+                        except Exception as e:
+                            session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+                                {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+                            )
+                            session.commit()
+            else:
+                # clean collection
+                index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+            end_at = time.perf_counter()
+            logger.info(
+                click.style(
+                    f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}",
+                    fg="green",
+                )
+            )
+        except Exception:
+            logger.exception("Deal dataset vector index failed")
diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py
index cb703cc263..ecf6f9cb39 100644
--- a/api/tasks/delete_account_task.py
+++ b/api/tasks/delete_account_task.py
@@ -3,7 +3,7 @@ import logging
 from celery import shared_task
 
 from configs import dify_config
-from extensions.ext_database import db
+from core.db.session_factory import session_factory
 from models import Account
 from services.billing_service import BillingService
 from tasks.mail_account_deletion_task import send_deletion_success_task
@@ -13,16 +13,17 @@ logger = logging.getLogger(__name__)
 
 @shared_task(queue="dataset")
 def delete_account_task(account_id):
-    account = db.session.query(Account).where(Account.id == account_id).first()
-    try:
-        if dify_config.BILLING_ENABLED:
-            BillingService.delete_account(account_id)
-    except Exception:
-        logger.exception("Failed to delete account %s from billing service.", account_id)
-        raise
+    with session_factory.create_session() as session:
+        account = session.query(Account).where(Account.id == account_id).first()
+        try:
+            if dify_config.BILLING_ENABLED:
+                BillingService.delete_account(account_id)
+        except Exception:
+            logger.exception("Failed to delete account %s from billing service.", account_id)
+            raise
 
-    if not account:
-        logger.error("Account %s not found.", account_id)
-        return
-    # send success email
-    send_deletion_success_task.delay(account.email)
+        if not account:
+            logger.error("Account %s not found.", account_id)
+            return
+        # send success email
+        send_deletion_success_task.delay(account.email)
diff --git a/api/tasks/delete_conversation_task.py b/api/tasks/delete_conversation_task.py
index 756b67c93e..9664b8ac73 100644
--- a/api/tasks/delete_conversation_task.py
+++ b/api/tasks/delete_conversation_task.py
@@ -4,7 +4,7 @@ import time
 
 import click
 from celery import shared_task
 
-from extensions.ext_database import db
+from core.db.session_factory import session_factory
 from models import ConversationVariable
 from models.model import
Message, MessageAnnotation, MessageFeedback from models.tools import ToolConversationVariables, ToolFile @@ -27,44 +27,46 @@ def delete_conversation_related_data(conversation_id: str): ) start_at = time.perf_counter() - try: - db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete( - synchronize_session=False - ) - - db.session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete( - synchronize_session=False - ) - - db.session.query(ToolConversationVariables).where( - ToolConversationVariables.conversation_id == conversation_id - ).delete(synchronize_session=False) - - db.session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False) - - db.session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete( - synchronize_session=False - ) - - db.session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False) - - db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( - synchronize_session=False - ) - - db.session.commit() - - end_at = time.perf_counter() - logger.info( - click.style( - f"Succeeded cleaning data from db for conversation_id {conversation_id} latency: {end_at - start_at}", - fg="green", + with session_factory.create_session() as session: + try: + session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete( + synchronize_session=False ) - ) - except Exception as e: - logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id) - db.session.rollback() - raise e - finally: - db.session.close() + session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete( + synchronize_session=False + ) + + session.query(ToolConversationVariables).where( + ToolConversationVariables.conversation_id == 
conversation_id + ).delete(synchronize_session=False) + + session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False) + + session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete( + synchronize_session=False + ) + + session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False) + + session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( + synchronize_session=False + ) + + session.commit() + + end_at = time.perf_counter() + logger.info( + click.style( + ( + f"Succeeded cleaning data from db for conversation_id {conversation_id} " + f"latency: {end_at - start_at}" + ), + fg="green", + ) + ) + + except Exception: + logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id) + session.rollback() + raise diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index bea5c952cf..bfa709502c 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -4,8 +4,8 @@ import time import click from celery import shared_task +from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from extensions.ext_database import db from models.dataset import Dataset, Document, SegmentAttachmentBinding from models.model import UploadFile @@ -26,49 +26,52 @@ def delete_segment_from_index_task( """ logger.info(click.style("Start delete segment from index", fg="green")) start_at = time.perf_counter() - try: - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - logging.warning("Dataset %s not found, skipping index cleanup", dataset_id) - return + with session_factory.create_session() as session: + try: + dataset = session.query(Dataset).where(Dataset.id == 
dataset_id).first() + if not dataset: + logging.warning("Dataset %s not found, skipping index cleanup", dataset_id) + return - dataset_document = db.session.query(Document).where(Document.id == document_id).first() - if not dataset_document: - return + dataset_document = session.query(Document).where(Document.id == document_id).first() + if not dataset_document: + return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logging.info("Document not in valid state for index operations, skipping") - return - doc_form = dataset_document.doc_form + if ( + not dataset_document.enabled + or dataset_document.archived + or dataset_document.indexing_status != "completed" + ): + logging.info("Document not in valid state for index operations, skipping") + return + doc_form = dataset_document.doc_form - # Proceed with index cleanup using the index_node_ids directly - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean( - dataset, - index_node_ids, - with_keywords=True, - delete_child_chunks=True, - precomputed_child_node_ids=child_node_ids, - ) - if dataset.is_multimodal: - # delete segment attachment binding - segment_attachment_bindings = ( - db.session.query(SegmentAttachmentBinding) - .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) - .all() + # Proceed with index cleanup using the index_node_ids directly + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean( + dataset, + index_node_ids, + with_keywords=True, + delete_child_chunks=True, + precomputed_child_node_ids=child_node_ids, ) - if segment_attachment_bindings: - attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] - index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False) - for binding in segment_attachment_bindings: - db.session.delete(binding) - # delete upload file - 
db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False) - db.session.commit() + if dataset.is_multimodal: + # delete segment attachment binding + segment_attachment_bindings = ( + session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + .all() + ) + if segment_attachment_bindings: + attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] + index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False) + for binding in segment_attachment_bindings: + session.delete(binding) + # delete upload file + session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False) + session.commit() - end_at = time.perf_counter() - logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green")) - except Exception: - logger.exception("delete segment from index failed") - finally: - db.session.close() + end_at = time.perf_counter() + logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green")) + except Exception: + logger.exception("delete segment from index failed") diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 6b5f01b416..0ce6429a94 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -4,8 +4,8 @@ import time import click from celery import shared_task +from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import DocumentSegment @@ -23,46 +23,53 @@ def disable_segment_from_index_task(segment_id: str): logger.info(click.style(f"Start disable segment from index: {segment_id}", fg="green")) start_at = time.perf_counter() - segment = 
db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() - if not segment: - logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) - db.session.close() - return - - if segment.status != "completed": - logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red")) - db.session.close() - return - - indexing_cache_key = f"segment_{segment.id}_indexing" - - try: - dataset = segment.dataset - - if not dataset: - logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan")) + with session_factory.create_session() as session: + segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() + if not segment: + logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) return - dataset_document = segment.document - - if not dataset_document: - logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan")) + if segment.status != "completed": + logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) - return + indexing_cache_key = f"segment_{segment.id}_indexing" - index_type = dataset_document.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.clean(dataset, [segment.index_node_id]) + try: + dataset = segment.dataset - end_at = time.perf_counter() - logger.info(click.style(f"Segment removed from index: {segment.id} latency: {end_at - start_at}", fg="green")) - except Exception: - logger.exception("remove segment from index failed") - segment.enabled = True - db.session.commit() - finally: - redis_client.delete(indexing_cache_key) - db.session.close() + if not dataset: + logger.info(click.style(f"Segment 
{segment.id} has no dataset, pass.", fg="cyan")) + return + + dataset_document = segment.document + + if not dataset_document: + logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan")) + return + + if ( + not dataset_document.enabled + or dataset_document.archived + or dataset_document.indexing_status != "completed" + ): + logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) + return + + index_type = dataset_document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + index_processor.clean(dataset, [segment.index_node_id]) + + end_at = time.perf_counter() + logger.info( + click.style( + f"Segment removed from index: {segment.id} latency: {end_at - start_at}", + fg="green", + ) + ) + except Exception: + logger.exception("remove segment from index failed") + segment.enabled = True + session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index c2a3de29f4..03635902d1 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -5,8 +5,8 @@ import click from celery import shared_task from sqlalchemy import select +from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument @@ -26,69 +26,65 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen """ start_at = time.perf_counter() - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan")) - db.session.close() - 
return + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan")) + return - dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() + dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() - if not dataset_document: - logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan")) - db.session.close() - return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan")) - db.session.close() - return - # sync index processor - index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + if not dataset_document: + logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan")) + return + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan")) + return + # sync index processor + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() - segments = db.session.scalars( - select(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.document_id == document_id, - ) - ).all() - - if not segments: - db.session.close() - return - - try: - index_node_ids = [segment.index_node_id for segment in segments] - if dataset.is_multimodal: - segment_ids = [segment.id for segment in segments] - segment_attachment_bindings = ( - db.session.query(SegmentAttachmentBinding) - .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) - .all() + segments = session.scalars( + 
select(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, ) - if segment_attachment_bindings: - attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] - index_node_ids.extend(attachment_ids) - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) + ).all() - end_at = time.perf_counter() - logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green")) - except Exception: - # update segment error msg - db.session.query(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.document_id == document_id, - ).update( - { - "disabled_at": None, - "disabled_by": None, - "enabled": True, - } - ) - db.session.commit() - finally: - for segment in segments: - indexing_cache_key = f"segment_{segment.id}_indexing" - redis_client.delete(indexing_cache_key) - db.session.close() + if not segments: + return + + try: + index_node_ids = [segment.index_node_id for segment in segments] + if dataset.is_multimodal: + segment_ids = [segment.id for segment in segments] + segment_attachment_bindings = ( + session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + .all() + ) + if segment_attachment_bindings: + attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] + index_node_ids.extend(attachment_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) + + end_at = time.perf_counter() + logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green")) + except Exception: + # update segment error msg + session.query(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + 
).update( + { + "disabled_at": None, + "disabled_by": None, + "enabled": True, + } + ) + session.commit() + finally: + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 5fc2597c92..149185f6e2 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -3,12 +3,12 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import delete, select +from core.db.session_factory import session_factory from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment from services.datasource_provider_service import DatasourceProviderService @@ -28,105 +28,103 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logger.info(click.style(f"Start sync document: {document_id}", fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + with session_factory.create_session() as session: + document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() - if not document: - logger.info(click.style(f"Document not found: {document_id}", fg="red")) - db.session.close() - return - - data_source_info = document.data_source_info_dict - if document.data_source_type == "notion_import": - if ( - not data_source_info - or "notion_page_id" not in data_source_info - or "notion_workspace_id" not in data_source_info - ): - raise ValueError("no notion page found") - 
workspace_id = data_source_info["notion_workspace_id"] - page_id = data_source_info["notion_page_id"] - page_type = data_source_info["type"] - page_edited_time = data_source_info["last_edited_time"] - credential_id = data_source_info.get("credential_id") - - # Get credentials from datasource provider - datasource_provider_service = DatasourceProviderService() - credential = datasource_provider_service.get_datasource_credentials( - tenant_id=document.tenant_id, - credential_id=credential_id, - provider="notion_datasource", - plugin_id="langgenius/notion_datasource", - ) - - if not credential: - logger.error( - "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s", - document_id, - document.tenant_id, - credential_id, - ) - document.indexing_status = "error" - document.error = "Datasource credential not found. Please reconnect your Notion workspace." - document.stopped_at = naive_utc_now() - db.session.commit() - db.session.close() + if not document: + logger.info(click.style(f"Document not found: {document_id}", fg="red")) return - loader = NotionExtractor( - notion_workspace_id=workspace_id, - notion_obj_id=page_id, - notion_page_type=page_type, - notion_access_token=credential.get("integration_secret"), - tenant_id=document.tenant_id, - ) + data_source_info = document.data_source_info_dict + if document.data_source_type == "notion_import": + if ( + not data_source_info + or "notion_page_id" not in data_source_info + or "notion_workspace_id" not in data_source_info + ): + raise ValueError("no notion page found") + workspace_id = data_source_info["notion_workspace_id"] + page_id = data_source_info["notion_page_id"] + page_type = data_source_info["type"] + page_edited_time = data_source_info["last_edited_time"] + credential_id = data_source_info.get("credential_id") - last_edited_time = loader.get_notion_last_edited_time() + # Get credentials from datasource provider + datasource_provider_service = DatasourceProviderService() + 
credential = datasource_provider_service.get_datasource_credentials( + tenant_id=document.tenant_id, + credential_id=credential_id, + provider="notion_datasource", + plugin_id="langgenius/notion_datasource", + ) - # check the page is updated - if last_edited_time != page_edited_time: - document.indexing_status = "parsing" - document.processing_started_at = naive_utc_now() - db.session.commit() - - # delete all document segment and index - try: - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - raise Exception("Dataset not found") - index_type = document.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() - - segments = db.session.scalars( - select(DocumentSegment).where(DocumentSegment.document_id == document_id) - ).all() - index_node_ids = [segment.index_node_id for segment in segments] - - # delete from vector index - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - - for segment in segments: - db.session.delete(segment) - - end_at = time.perf_counter() - logger.info( - click.style( - "Cleaned document when document update data source or process rule: {} latency: {}".format( - document_id, end_at - start_at - ), - fg="green", - ) + if not credential: + logger.error( + "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s", + document_id, + document.tenant_id, + credential_id, ) - except Exception: - logger.exception("Cleaned document when document update data source or process rule failed") + document.indexing_status = "error" + document.error = "Datasource credential not found. Please reconnect your Notion workspace." 
+ document.stopped_at = naive_utc_now() + session.commit() + return - try: - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - end_at = time.perf_counter() - logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("document_indexing_sync_task failed, document_id: %s", document_id) - finally: - db.session.close() + loader = NotionExtractor( + notion_workspace_id=workspace_id, + notion_obj_id=page_id, + notion_page_type=page_type, + notion_access_token=credential.get("integration_secret"), + tenant_id=document.tenant_id, + ) + + last_edited_time = loader.get_notion_last_edited_time() + + # check the page is updated + if last_edited_time != page_edited_time: + document.indexing_status = "parsing" + document.processing_started_at = naive_utc_now() + session.commit() + + # delete all document segment and index + try: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + raise Exception("Dataset not found") + index_type = document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + + segments = session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document_id) + ).all() + index_node_ids = [segment.index_node_id for segment in segments] + + # delete from vector index + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + + segment_ids = [segment.id for segment in segments] + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) + session.execute(segment_delete_stmt) + + end_at = time.perf_counter() + logger.info( + click.style( + "Cleaned document when document update data source or process rule: {} latency: {}".format( + document_id, end_at - start_at + ), + fg="green", + ) + ) + except Exception: + 
logger.exception("Cleaned document when document update data source or process rule failed") + + try: + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + end_at = time.perf_counter() + logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + except Exception: + logger.exception("document_indexing_sync_task failed, document_id: %s", document_id) diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index acbdab631b..3bdff60196 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -6,11 +6,11 @@ import click from celery import shared_task from configs import dify_config +from core.db.session_factory import session_factory from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document from services.feature_service import FeatureService @@ -46,66 +46,63 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): documents = [] start_at = time.perf_counter() - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow")) - db.session.close() - return - # check document limit - features = FeatureService.get_features(dataset.tenant_id) - try: - if features.billing.enabled: - vector_space = features.vector_space - count = len(document_ids) - batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) - if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1: - raise ValueError("Your current plan does not support batch upload, 
please upgrade your plan.") - if count > batch_upload_limit: - raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") - if 0 < vector_space.limit <= vector_space.size: - raise ValueError( - "Your total number of documents plus the number of uploads have over the limit of " - "your subscription." + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow")) + return + # check document limit + features = FeatureService.get_features(dataset.tenant_id) + try: + if features.billing.enabled: + vector_space = features.vector_space + count = len(document_ids) + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1: + raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + if 0 < vector_space.limit <= vector_space.size: + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." 
+ ) + except Exception as e: + for document_id in document_ids: + document = ( + session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) - except Exception as e: + if document: + document.indexing_status = "error" + document.error = str(e) + document.stopped_at = naive_utc_now() + session.add(document) + session.commit() + return + for document_id in document_ids: + logger.info(click.style(f"Start process document: {document_id}", fg="green")) + document = ( - db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) + if document: - document.indexing_status = "error" - document.error = str(e) - document.stopped_at = naive_utc_now() - db.session.add(document) - db.session.commit() - db.session.close() - return + document.indexing_status = "parsing" + document.processing_started_at = naive_utc_now() + documents.append(document) + session.add(document) + session.commit() - for document_id in document_ids: - logger.info(click.style(f"Start process document: {document_id}", fg="green")) - - document = ( - db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() - ) - - if document: - document.indexing_status = "parsing" - document.processing_started_at = naive_utc_now() - documents.append(document) - db.session.add(document) - db.session.commit() - - try: - indexing_runner = IndexingRunner() - indexing_runner.run(documents) - end_at = time.perf_counter() - logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("Document indexing task failed, dataset_id: %s", dataset_id) - finally: - db.session.close() + try: + indexing_runner = IndexingRunner() + 
indexing_runner.run(documents) + end_at = time.perf_counter() + logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + except Exception: + logger.exception("Document indexing task failed, dataset_id: %s", dataset_id) def _document_indexing_with_tenant_queue( diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 161502a228..67a23be952 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -3,8 +3,9 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import delete, select +from core.db.session_factory import session_factory from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -26,56 +27,54 @@ def document_indexing_update_task(dataset_id: str, document_id: str): logger.info(click.style(f"Start update document: {document_id}", fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + with session_factory.create_session() as session: + document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() - if not document: - logger.info(click.style(f"Document not found: {document_id}", fg="red")) - db.session.close() - return + if not document: + logger.info(click.style(f"Document not found: {document_id}", fg="red")) + return - document.indexing_status = "parsing" - document.processing_started_at = naive_utc_now() - db.session.commit() + document.indexing_status = "parsing" + document.processing_started_at = naive_utc_now() + session.commit() - # delete all document segment and index - try: - 
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - raise Exception("Dataset not found") + # delete all document segment and index + try: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + raise Exception("Dataset not found") - index_type = document.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() + index_type = document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() - segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] + segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] - # delete from vector index - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - - for segment in segments: - db.session.delete(segment) - db.session.commit() - end_at = time.perf_counter() - logger.info( - click.style( - "Cleaned document when document update data source or process rule: {} latency: {}".format( - document_id, end_at - start_at - ), - fg="green", + # delete from vector index + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + segment_ids = [segment.id for segment in segments] + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) + session.execute(segment_delete_stmt) + session.commit() + end_at = time.perf_counter() + logger.info( + click.style( + "Cleaned document when document update data source or process rule: {} latency: {}".format( + document_id, end_at - start_at + ), + fg="green", + ) ) - ) - except Exception: - logger.exception("Cleaned document when document update data source or process rule failed") +
except Exception: + logger.exception("Cleaned document when document update data source or process rule failed") - try: - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - end_at = time.perf_counter() - logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("document_indexing_update_task failed, document_id: %s", document_id) - finally: - db.session.close() + try: + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + end_at = time.perf_counter() + logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + except Exception: + logger.exception("document_indexing_update_task failed, document_id: %s", document_id) diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 4078c8910e..00a963255b 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -4,15 +4,15 @@ from collections.abc import Callable, Sequence import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import delete, select from configs import dify_config +from core.db.session_factory import session_factory from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment from services.feature_service import FeatureService @@ -76,63 +76,64 @@ 
def _duplicate_document_indexing_task_with_tenant_queue( def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[str]): - documents = [] + documents: list[Document] = [] start_at = time.perf_counter() - try: - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if dataset is None: - logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red")) - db.session.close() - return - - # check document limit - features = FeatureService.get_features(dataset.tenant_id) + with session_factory.create_session() as session: try: - if features.billing.enabled: - vector_space = features.vector_space - count = len(document_ids) - if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1: - raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") - batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) - if count > batch_upload_limit: - raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") - current = int(getattr(vector_space, "size", 0) or 0) - limit = int(getattr(vector_space, "limit", 0) or 0) - if limit > 0 and (current + count) > limit: - raise ValueError( - "Your total number of documents plus the number of uploads have exceeded the limit of " - "your subscription." 
- ) - except Exception as e: - for document_id in document_ids: - document = ( - db.session.query(Document) - .where(Document.id == document_id, Document.dataset_id == dataset_id) - .first() + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if dataset is None: + logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red")) + return + + # check document limit + features = FeatureService.get_features(dataset.tenant_id) + try: + if features.billing.enabled: + vector_space = features.vector_space + count = len(document_ids) + if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1: + raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + current = int(getattr(vector_space, "size", 0) or 0) + limit = int(getattr(vector_space, "limit", 0) or 0) + if limit > 0 and (current + count) > limit: + raise ValueError( + "Your total number of documents plus the number of uploads have exceeded the limit of " + "your subscription." 
+ ) + except Exception as e: + documents = list( + session.scalars( + select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) + ).all() ) - if document: - document.indexing_status = "error" - document.error = str(e) - document.stopped_at = naive_utc_now() - db.session.add(document) - db.session.commit() - return + for document in documents: + if document: + document.indexing_status = "error" + document.error = str(e) + document.stopped_at = naive_utc_now() + session.add(document) + session.commit() + return - for document_id in document_ids: - logger.info(click.style(f"Start process document: {document_id}", fg="green")) - - document = ( - db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + documents = list( + session.scalars( + select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) + ).all() ) - if document: + for document in documents: + logger.info(click.style(f"Start process document: {document.id}", fg="green")) + # clean old data index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - segments = db.session.scalars( - select(DocumentSegment).where(DocumentSegment.document_id == document_id) + segments = session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id) ).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] @@ -140,26 +141,24 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st # delete from vector index index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - for segment in segments: - db.session.delete(segment) - db.session.commit() + segment_ids = [segment.id for segment in segments] + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) + session.execute(segment_delete_stmt) + session.commit() 
document.indexing_status = "parsing" document.processing_started_at = naive_utc_now() - documents.append(document) - db.session.add(document) - db.session.commit() + session.add(document) + session.commit() - indexing_runner = IndexingRunner() - indexing_runner.run(documents) - end_at = time.perf_counter() - logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id) - finally: - db.session.close() + indexing_runner = IndexingRunner() + indexing_runner.run(list(documents)) + end_at = time.perf_counter() + logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + except Exception: + logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id) @shared_task(queue="dataset") diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 7615469ed0..1f9f21aa7e 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -4,11 +4,11 @@ import time import click from celery import shared_task +from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, ChildDocument, Document -from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import DocumentSegment @@ -27,91 +27,93 @@ def enable_segment_to_index_task(segment_id: str): 
logger.info(click.style(f"Start enable segment to index: {segment_id}", fg="green")) start_at = time.perf_counter() - segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() - if not segment: - logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) - db.session.close() - return - - if segment.status != "completed": - logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red")) - db.session.close() - return - - indexing_cache_key = f"segment_{segment.id}_indexing" - - try: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - - dataset = segment.dataset - - if not dataset: - logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan")) + with session_factory.create_session() as session: + segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() + if not segment: + logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) return - dataset_document = segment.document - - if not dataset_document: - logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan")) + if segment.status != "completed": + logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) - return + indexing_cache_key = f"segment_{segment.id}_indexing" - index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() - if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - child_chunks = segment.get_child_chunks() - if child_chunks: - child_documents = [] - 
for child_chunk in child_chunks: - child_document = ChildDocument( - page_content=child_chunk.content, - metadata={ - "doc_id": child_chunk.index_node_id, - "doc_hash": child_chunk.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, + try: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + + dataset = segment.dataset + + if not dataset: + logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan")) + return + + dataset_document = segment.document + + if not dataset_document: + logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan")) + return + + if ( + not dataset_document.enabled + or dataset_document.archived + or dataset_document.indexing_status != "completed" + ): + logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) + return + + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + child_chunks = segment.get_child_chunks() + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + multimodel_documents = [] + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodel_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + 
"doc_type": DocType.IMAGE, + }, + ) ) - child_documents.append(child_document) - document.children = child_documents - multimodel_documents = [] - if dataset.is_multimodal: - for attachment in segment.attachments: - multimodel_documents.append( - AttachmentDocument( - page_content=attachment["name"], - metadata={ - "doc_id": attachment["id"], - "doc_hash": "", - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - "doc_type": DocType.IMAGE, - }, - ) - ) - # save vector index - index_processor.load(dataset, [document], multimodal_documents=multimodel_documents) + # save vector index + index_processor.load(dataset, [document], multimodal_documents=multimodel_documents) - end_at = time.perf_counter() - logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green")) - except Exception as e: - logger.exception("enable segment to index failed") - segment.enabled = False - segment.disabled_at = naive_utc_now() - segment.status = "error" - segment.error = str(e) - db.session.commit() - finally: - redis_client.delete(indexing_cache_key) - db.session.close() + end_at = time.perf_counter() + logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green")) + except Exception as e: + logger.exception("enable segment to index failed") + segment.enabled = False + segment.disabled_at = naive_utc_now() + segment.status = "error" + segment.error = str(e) + session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index 9f17d09e18..48d3c8e178 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -5,11 +5,11 @@ import click from celery import shared_task from sqlalchemy import select +from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType from 
core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, ChildDocument, Document -from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DocumentSegment @@ -29,105 +29,102 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id) """ start_at = time.perf_counter() - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan")) - return + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan")) + return - dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() + dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() - if not dataset_document: - logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan")) - db.session.close() - return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan")) - db.session.close() - return - # sync index processor - index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + if not dataset_document: + logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan")) + return + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + 
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan")) + return + # sync index processor + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() - segments = db.session.scalars( - select(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.document_id == document_id, - ) - ).all() - if not segments: - logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan")) - db.session.close() - return - - try: - documents = [] - multimodal_documents = [] - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": document_id, - "dataset_id": dataset_id, - }, + segments = session.scalars( + select(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, ) + ).all() + if not segments: + logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan")) + return - if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - child_chunks = segment.get_child_chunks() - if child_chunks: - child_documents = [] - for child_chunk in child_chunks: - child_document = ChildDocument( - page_content=child_chunk.content, - metadata={ - "doc_id": child_chunk.index_node_id, - "doc_hash": child_chunk.index_node_hash, - "document_id": document_id, - "dataset_id": dataset_id, - }, + try: + documents = [] + multimodal_documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": document_id, + "dataset_id": dataset_id, + }, + ) + + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + child_chunks = segment.get_child_chunks() + if child_chunks: + 
child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": document_id, + "dataset_id": dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) ) - child_documents.append(child_document) - document.children = child_documents + documents.append(document) + # save vector index + index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) - if dataset.is_multimodal: - for attachment in segment.attachments: - multimodal_documents.append( - AttachmentDocument( - page_content=attachment["name"], - metadata={ - "doc_id": attachment["id"], - "doc_hash": "", - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - "doc_type": DocType.IMAGE, - }, - ) - ) - documents.append(document) - # save vector index - index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) - - end_at = time.perf_counter() - logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green")) - except Exception as e: - logger.exception("enable segments to index failed") - # update segment error msg - db.session.query(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.document_id == document_id, - ).update( - { - "error": str(e), - "status": "error", - "disabled_at": naive_utc_now(), - "enabled": False, - } - ) - db.session.commit() - finally: - for segment in segments: - indexing_cache_key = 
f"segment_{segment.id}_indexing" - redis_client.delete(indexing_cache_key) - db.session.close() + end_at = time.perf_counter() + logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green")) + except Exception as e: + logger.exception("enable segments to index failed") + # update segment error msg + session.query(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ).update( + { + "error": str(e), + "status": "error", + "disabled_at": naive_utc_now(), + "enabled": False, + } + ) + session.commit() + finally: + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index e6492c230d..b5e6508006 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3 CACHE_REDIS_KEY_PREFIX = "plugin_autoupgrade_check_task:cached_plugin_manifests:" -CACHE_REDIS_TTL = 60 * 15 # 15 minutes +CACHE_REDIS_TTL = 60 * 60 # 1 hour def _get_redis_cache_key(plugin_id: str) -> str: diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index 1eef361a92..3c5e152520 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -178,7 +178,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], workflow_id=workflow_id, user=account, application_generate_entity=entity, - invoke_from=InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, workflow_execution_repository=workflow_execution_repository, 
workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 275f5abe6e..093342d1a3 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -178,7 +178,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], workflow_id=workflow_id, user=account, application_generate_entity=entity, - invoke_from=InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 1b2a653c01..af72023da1 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -4,8 +4,8 @@ import time import click from celery import shared_task +from core.db.session_factory import session_factory from core.indexing_runner import DocumentIsPausedError, IndexingRunner -from extensions.ext_database import db from models.dataset import Document logger = logging.getLogger(__name__) @@ -23,26 +23,24 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): logger.info(click.style(f"Recover document: {document_id}", fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + with session_factory.create_session() as session: + document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() - if not document: - logger.info(click.style(f"Document not found: {document_id}", fg="red")) - db.session.close() - return + if not document: + logger.info(click.style(f"Document not found: {document_id}", fg="red")) 
+ return - try: - indexing_runner = IndexingRunner() - if document.indexing_status in {"waiting", "parsing", "cleaning"}: - indexing_runner.run([document]) - elif document.indexing_status == "splitting": - indexing_runner.run_in_splitting_status(document) - elif document.indexing_status == "indexing": - indexing_runner.run_in_indexing_status(document) - end_at = time.perf_counter() - logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green")) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("recover_document_indexing_task failed, document_id: %s", document_id) - finally: - db.session.close() + try: + indexing_runner = IndexingRunner() + if document.indexing_status in {"waiting", "parsing", "cleaning"}: + indexing_runner.run([document]) + elif document.indexing_status == "splitting": + indexing_runner.run_in_splitting_status(document) + elif document.indexing_status == "indexing": + indexing_runner.run_in_indexing_status(document) + end_at = time.perf_counter() + logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + except Exception: + logger.exception("recover_document_indexing_task failed, document_id: %s", document_id) diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 3227f6da96..817249845a 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -1,15 +1,20 @@ import logging import time from collections.abc import Callable +from typing import Any, cast import click import sqlalchemy as sa from celery import shared_task from sqlalchemy import delete +from sqlalchemy.engine import CursorResult from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import sessionmaker +from configs 
import dify_config +from core.db.session_factory import session_factory from extensions.ext_database import db +from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage from models import ( ApiToken, AppAnnotationHitHistory, @@ -40,6 +45,7 @@ from models.workflow import ( ConversationVariable, Workflow, WorkflowAppLog, + WorkflowArchiveLog, ) from repositories.factory import DifyAPIRepositoryFactory @@ -64,6 +70,9 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): _delete_app_workflow_runs(tenant_id, app_id) _delete_app_workflow_node_executions(tenant_id, app_id) _delete_app_workflow_app_logs(tenant_id, app_id) + if dify_config.BILLING_ENABLED and dify_config.ARCHIVE_STORAGE_ENABLED: + _delete_app_workflow_archive_logs(tenant_id, app_id) + _delete_archived_workflow_run_files(tenant_id, app_id) _delete_app_conversations(tenant_id, app_id) _delete_app_messages(tenant_id, app_id) _delete_workflow_tool_providers(tenant_id, app_id) @@ -77,7 +86,6 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): _delete_workflow_webhook_triggers(tenant_id, app_id) _delete_workflow_schedule_plans(tenant_id, app_id) _delete_workflow_trigger_logs(tenant_id, app_id) - end_at = time.perf_counter() logger.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green")) except SQLAlchemyError as e: @@ -89,8 +97,8 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): def _delete_app_model_configs(tenant_id: str, app_id: str): - def del_model_config(model_config_id: str): - db.session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False) + def del_model_config(session, model_config_id: str): + session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False) _delete_records( """select id from app_model_configs where app_id=:app_id limit 1000""", @@ -101,8 +109,8 @@ def 
_delete_app_model_configs(tenant_id: str, app_id: str): def _delete_app_site(tenant_id: str, app_id: str): - def del_site(site_id: str): - db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False) + def del_site(session, site_id: str): + session.query(Site).where(Site.id == site_id).delete(synchronize_session=False) _delete_records( """select id from sites where app_id=:app_id limit 1000""", @@ -113,8 +121,8 @@ def _delete_app_site(tenant_id: str, app_id: str): def _delete_app_mcp_servers(tenant_id: str, app_id: str): - def del_mcp_server(mcp_server_id: str): - db.session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False) + def del_mcp_server(session, mcp_server_id: str): + session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False) _delete_records( """select id from app_mcp_servers where app_id=:app_id limit 1000""", @@ -125,8 +133,8 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str): def _delete_app_api_tokens(tenant_id: str, app_id: str): - def del_api_token(api_token_id: str): - db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False) + def del_api_token(session, api_token_id: str): + session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False) _delete_records( """select id from api_tokens where app_id=:app_id limit 1000""", @@ -137,8 +145,8 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str): def _delete_installed_apps(tenant_id: str, app_id: str): - def del_installed_app(installed_app_id: str): - db.session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False) + def del_installed_app(session, installed_app_id: str): + session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False) _delete_records( """select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 
1000""", @@ -149,10 +157,8 @@ def _delete_installed_apps(tenant_id: str, app_id: str): def _delete_recommended_apps(tenant_id: str, app_id: str): - def del_recommended_app(recommended_app_id: str): - db.session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete( - synchronize_session=False - ) + def del_recommended_app(session, recommended_app_id: str): + session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False) _delete_records( """select id from recommended_apps where app_id=:app_id limit 1000""", @@ -163,8 +169,8 @@ def _delete_recommended_apps(tenant_id: str, app_id: str): def _delete_app_annotation_data(tenant_id: str, app_id: str): - def del_annotation_hit_history(annotation_hit_history_id: str): - db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete( + def del_annotation_hit_history(session, annotation_hit_history_id: str): + session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete( synchronize_session=False ) @@ -175,8 +181,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str): "annotation hit history", ) - def del_annotation_setting(annotation_setting_id: str): - db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete( + def del_annotation_setting(session, annotation_setting_id: str): + session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete( synchronize_session=False ) @@ -189,8 +195,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str): def _delete_app_dataset_joins(tenant_id: str, app_id: str): - def del_dataset_join(dataset_join_id: str): - db.session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False) + def del_dataset_join(session, dataset_join_id: str): + 
session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False) _delete_records( """select id from app_dataset_joins where app_id=:app_id limit 1000""", @@ -201,8 +207,8 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str): def _delete_app_workflows(tenant_id: str, app_id: str): - def del_workflow(workflow_id: str): - db.session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False) + def del_workflow(session, workflow_id: str): + session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False) _delete_records( """select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -241,10 +247,8 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): - def del_workflow_app_log(workflow_app_log_id: str): - db.session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete( - synchronize_session=False - ) + def del_workflow_app_log(session, workflow_app_log_id: str): + session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False) _delete_records( """select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -254,12 +258,51 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): ) -def _delete_app_conversations(tenant_id: str, app_id: str): - def del_conversation(conversation_id: str): - db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( +def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str): + def del_workflow_archive_log(workflow_archive_log_id: str): + db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( synchronize_session=False ) - db.session.query(Conversation).where(Conversation.id == 
conversation_id).delete(synchronize_session=False) + + _delete_records( + """select id from workflow_archive_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", + {"tenant_id": tenant_id, "app_id": app_id}, + del_workflow_archive_log, + "workflow archive log", + ) + + +def _delete_archived_workflow_run_files(tenant_id: str, app_id: str): + prefix = f"{tenant_id}/app_id={app_id}/" + try: + archive_storage = get_archive_storage() + except ArchiveStorageNotConfiguredError as e: + logger.info("Archive storage not configured, skipping archive file cleanup: %s", e) + return + + try: + keys = archive_storage.list_objects(prefix) + except Exception: + logger.exception("Failed to list archive files for app %s", app_id) + return + + deleted = 0 + for key in keys: + try: + archive_storage.delete_object(key) + deleted += 1 + except Exception: + logger.exception("Failed to delete archive object %s", key) + + logger.info("Deleted %s archive objects for app %s", deleted, app_id) + + +def _delete_app_conversations(tenant_id: str, app_id: str): + def del_conversation(session, conversation_id: str): + session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( + synchronize_session=False + ) + session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False) _delete_records( """select id from conversations where app_id=:app_id limit 1000""", @@ -270,28 +313,26 @@ def _delete_app_conversations(tenant_id: str, app_id: str): def _delete_conversation_variables(*, app_id: str): - stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id) - with db.engine.connect() as conn: - conn.execute(stmt) - conn.commit() + with session_factory.create_session() as session: + stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id) + session.execute(stmt) + session.commit() logger.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green")) 
def _delete_app_messages(tenant_id: str, app_id: str): - def del_message(message_id: str): - db.session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete( + def del_message(session, message_id: str): + session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False) + session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete( synchronize_session=False ) - db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete( + session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False) + session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete( synchronize_session=False ) - db.session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False) - db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete( - synchronize_session=False - ) - db.session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False) - db.session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False) - db.session.query(Message).where(Message.id == message_id).delete() + session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False) + session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False) + session.query(Message).where(Message.id == message_id).delete() _delete_records( """select id from messages where app_id=:app_id limit 1000""", @@ -302,8 +343,8 @@ def _delete_app_messages(tenant_id: str, app_id: str): def _delete_workflow_tool_providers(tenant_id: str, app_id: str): - def del_tool_provider(tool_provider_id: str): - db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete( + def 
del_tool_provider(session, tool_provider_id: str): + session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete( synchronize_session=False ) @@ -316,8 +357,8 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str): def _delete_app_tag_bindings(tenant_id: str, app_id: str): - def del_tag_binding(tag_binding_id: str): - db.session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False) + def del_tag_binding(session, tag_binding_id: str): + session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False) _delete_records( """select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""", @@ -328,8 +369,8 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str): def _delete_end_users(tenant_id: str, app_id: str): - def del_end_user(end_user_id: str): - db.session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False) + def del_end_user(session, end_user_id: str): + session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False) _delete_records( """select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -340,10 +381,8 @@ def _delete_end_users(tenant_id: str, app_id: str): def _delete_trace_app_configs(tenant_id: str, app_id: str): - def del_trace_app_config(trace_app_config_id: str): - db.session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete( - synchronize_session=False - ) + def del_trace_app_config(session, trace_app_config_id: str): + session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False) _delete_records( """select id from trace_app_config where app_id=:app_id limit 1000""", @@ -381,14 +420,14 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: total_files_deleted = 0 while True: - with db.engine.begin() as conn: + with 
session_factory.create_session() as session: # Get a batch of draft variable IDs along with their file_ids query_sql = """ SELECT id, file_id FROM workflow_draft_variables WHERE app_id = :app_id LIMIT :batch_size """ - result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size}) + result = session.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size}) rows = list(result) if not rows: @@ -399,7 +438,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: # Clean up associated Offload data first if file_ids: - files_deleted = _delete_draft_variable_offload_data(conn, file_ids) + files_deleted = _delete_draft_variable_offload_data(session, file_ids) total_files_deleted += files_deleted # Delete the draft variables @@ -407,8 +446,11 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: DELETE FROM workflow_draft_variables WHERE id IN :ids """ - deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)}) - batch_deleted = deleted_result.rowcount + deleted_result = cast( + CursorResult[Any], + session.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)}), + ) + batch_deleted: int = int(getattr(deleted_result, "rowcount", 0) or 0) total_deleted += batch_deleted logger.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green")) @@ -423,7 +465,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: return total_deleted -def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int: +def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int: """ Delete Offload data associated with WorkflowDraftVariable file_ids. @@ -434,7 +476,7 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int: 4. 
Deletes WorkflowDraftVariableFile records Args: - conn: Database connection + session: Database session file_ids: List of WorkflowDraftVariableFile IDs Returns: @@ -450,12 +492,12 @@ try: # Get WorkflowDraftVariableFile records and their associated UploadFile keys query_sql = """ - SELECT wdvf.id, uf.key, uf.id as upload_file_id - FROM workflow_draft_variable_files wdvf - JOIN upload_files uf ON wdvf.upload_file_id = uf.id - WHERE wdvf.id IN :file_ids - """ - result = conn.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)}) + SELECT wdvf.id, uf.key, uf.id as upload_file_id + FROM workflow_draft_variable_files wdvf + JOIN upload_files uf ON wdvf.upload_file_id = uf.id + WHERE wdvf.id IN :file_ids \ + """ + result = session.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)}) file_records = list(result) # Delete from object storage and collect upload file IDs @@ -473,17 +515,19 @@ # Delete UploadFile records if upload_file_ids: delete_upload_files_sql = """ - DELETE FROM upload_files - WHERE id IN :upload_file_ids - """ - conn.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)}) + DELETE \ + FROM upload_files + WHERE id IN :upload_file_ids \ + """ + session.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)}) # Delete WorkflowDraftVariableFile records delete_variable_files_sql = """ - DELETE FROM workflow_draft_variable_files - WHERE id IN :file_ids - """ - conn.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)}) + DELETE \ + FROM workflow_draft_variable_files + WHERE id IN :file_ids \ + """ + session.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)}) except Exception: logging.exception("Error deleting draft variable offload data:") @@ -493,8 +537,8 @@ def
_delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int: def _delete_app_triggers(tenant_id: str, app_id: str): - def del_app_trigger(trigger_id: str): - db.session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False) + def del_app_trigger(session, trigger_id: str): + session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False) _delete_records( """select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -505,8 +549,8 @@ def _delete_app_triggers(tenant_id: str, app_id: str): def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str): - def del_plugin_trigger(trigger_id: str): - db.session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete( + def del_plugin_trigger(session, trigger_id: str): + session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete( synchronize_session=False ) @@ -519,8 +563,8 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str): def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str): - def del_webhook_trigger(trigger_id: str): - db.session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete( + def del_webhook_trigger(session, trigger_id: str): + session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete( synchronize_session=False ) @@ -533,10 +577,8 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str): def _delete_workflow_schedule_plans(tenant_id: str, app_id: str): - def del_schedule_plan(plan_id: str): - db.session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete( - synchronize_session=False - ) + def del_schedule_plan(session, plan_id: str): + session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False) _delete_records( """select id from workflow_schedule_plans where 
tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -547,8 +589,8 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str): def _delete_workflow_trigger_logs(tenant_id: str, app_id: str): - def del_trigger_log(log_id: str): - db.session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False) + def del_trigger_log(session, log_id: str): + session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False) _delete_records( """select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -560,18 +602,22 @@ def _delete_workflow_trigger_logs(tenant_id: str, app_id: str): def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None: while True: - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query_sql), params) - if rs.rowcount == 0: + with session_factory.create_session() as session: + rs = session.execute(sa.text(query_sql), params) + rows = rs.fetchall() + if not rows: break - for i in rs: + for i in rows: record_id = str(i.id) try: - delete_func(record_id) - db.session.commit() + delete_func(session, record_id) logger.info(click.style(f"Deleted {name} {record_id}", fg="green")) except Exception: logger.exception("Error occurred while deleting %s %s", name, record_id) - continue + # on failure, roll back this batch and let the outer loop re-run the query + session.rollback() + break + session.commit() + rs.close() diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index c0ab2d0b41..c3c255fb17 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -5,8 +5,8 @@ import click from celery import shared_task from sqlalchemy import select +from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from extensions.ext_database
import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import Document, DocumentSegment @@ -25,52 +25,55 @@ def remove_document_from_index_task(document_id: str): logger.info(click.style(f"Start remove document segments from index: {document_id}", fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).where(Document.id == document_id).first() - if not document: - logger.info(click.style(f"Document not found: {document_id}", fg="red")) - db.session.close() - return + with session_factory.create_session() as session: + document = session.query(Document).where(Document.id == document_id).first() + if not document: + logger.info(click.style(f"Document not found: {document_id}", fg="red")) + return - if document.indexing_status != "completed": - logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red")) - db.session.close() - return + if document.indexing_status != "completed": + logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red")) + return - indexing_cache_key = f"document_{document.id}_indexing" + indexing_cache_key = f"document_{document.id}_indexing" - try: - dataset = document.dataset + try: + dataset = document.dataset - if not dataset: - raise Exception("Document has no dataset") + if not dataset: + raise Exception("Document has no dataset") - index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() + index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all() - index_node_ids = [segment.index_node_id for segment in segments] - if index_node_ids: - try: - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) - except Exception: - logger.exception("clean dataset %s from index 
failed", dataset.id) - # update segment to disable - db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update( - { - DocumentSegment.enabled: False, - DocumentSegment.disabled_at: naive_utc_now(), - DocumentSegment.disabled_by: document.disabled_by, - DocumentSegment.updated_at: naive_utc_now(), - } - ) - db.session.commit() + segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all() + index_node_ids = [segment.index_node_id for segment in segments] + if index_node_ids: + try: + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) + except Exception: + logger.exception("clean dataset %s from index failed", dataset.id) + # update segment to disable + session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update( + { + DocumentSegment.enabled: False, + DocumentSegment.disabled_at: naive_utc_now(), + DocumentSegment.disabled_by: document.disabled_by, + DocumentSegment.updated_at: naive_utc_now(), + } + ) + session.commit() - end_at = time.perf_counter() - logger.info(click.style(f"Document removed from index: {document.id} latency: {end_at - start_at}", fg="green")) - except Exception: - logger.exception("remove document from index failed") - if not document.archived: - document.enabled = True - db.session.commit() - finally: - redis_client.delete(indexing_cache_key) - db.session.close() + end_at = time.perf_counter() + logger.info( + click.style( + f"Document removed from index: {document.id} latency: {end_at - start_at}", + fg="green", + ) + ) + except Exception: + logger.exception("remove document from index failed") + if not document.archived: + document.enabled = True + session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 9d208647e6..f20b15ac83 100644 --- 
a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -3,11 +3,11 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import delete, select +from core.db.session_factory import session_factory from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models import Account, Tenant @@ -29,97 +29,97 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_ Usage: retry_document_indexing_task.delay(dataset_id, document_ids, user_id) """ start_at = time.perf_counter() - try: - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red")) - return - user = db.session.query(Account).where(Account.id == user_id).first() - if not user: - logger.info(click.style(f"User not found: {user_id}", fg="red")) - return - tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first() - if not tenant: - raise ValueError("Tenant not found") - user.current_tenant = tenant + with session_factory.create_session() as session: + try: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red")) + return + user = session.query(Account).where(Account.id == user_id).first() + if not user: + logger.info(click.style(f"User not found: {user_id}", fg="red")) + return + tenant = session.query(Tenant).where(Tenant.id == dataset.tenant_id).first() + if not tenant: + raise ValueError("Tenant not found") + user.current_tenant = tenant - for document_id in document_ids: - retry_indexing_cache_key = f"document_{document_id}_is_retried" - # check document 
limit - features = FeatureService.get_features(tenant.id) - try: - if features.billing.enabled: - vector_space = features.vector_space - if 0 < vector_space.limit <= vector_space.size: - raise ValueError( - "Your total number of documents plus the number of uploads have over the limit of " - "your subscription." - ) - except Exception as e: + for document_id in document_ids: + retry_indexing_cache_key = f"document_{document_id}_is_retried" + # check document limit + features = FeatureService.get_features(tenant.id) + try: + if features.billing.enabled: + vector_space = features.vector_space + if 0 < vector_space.limit <= vector_space.size: + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." + ) + except Exception as e: + document = ( + session.query(Document) + .where(Document.id == document_id, Document.dataset_id == dataset_id) + .first() + ) + if document: + document.indexing_status = "error" + document.error = str(e) + document.stopped_at = naive_utc_now() + session.add(document) + session.commit() + redis_client.delete(retry_indexing_cache_key) + return + + logger.info(click.style(f"Start retry document: {document_id}", fg="green")) document = ( - db.session.query(Document) - .where(Document.id == document_id, Document.dataset_id == dataset_id) - .first() + session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) - if document: + if not document: + logger.info(click.style(f"Document not found: {document_id}", fg="yellow")) + return + try: + # clean old data + index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() + + segments = session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document_id) + ).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + # delete from vector index + index_processor.clean(dataset, index_node_ids, with_keywords=True, 
delete_child_chunks=True) + + segment_ids = [segment.id for segment in segments] + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) + session.execute(segment_delete_stmt) + session.commit() + + document.indexing_status = "parsing" + document.processing_started_at = naive_utc_now() + session.add(document) + session.commit() + + if dataset.runtime_mode == "rag_pipeline": + rag_pipeline_service = RagPipelineService() + rag_pipeline_service.retry_error_document(dataset, document, user) + else: + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + redis_client.delete(retry_indexing_cache_key) + except Exception as ex: document.indexing_status = "error" - document.error = str(e) + document.error = str(ex) document.stopped_at = naive_utc_now() - db.session.add(document) - db.session.commit() - redis_client.delete(retry_indexing_cache_key) - return - - logger.info(click.style(f"Start retry document: {document_id}", fg="green")) - document = ( - db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + session.add(document) + session.commit() + logger.info(click.style(str(ex), fg="yellow")) + redis_client.delete(retry_indexing_cache_key) + logger.exception("retry_document_indexing_task failed, document_id: %s", document_id) + end_at = time.perf_counter() + logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + except Exception as e: + logger.exception( + "retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids ) - if not document: - logger.info(click.style(f"Document not found: {document_id}", fg="yellow")) - return - try: - # clean old data - index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - - segments = db.session.scalars( - select(DocumentSegment).where(DocumentSegment.document_id == document_id) - ).all() - if segments: - index_node_ids = 
[segment.index_node_id for segment in segments] - # delete from vector index - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - - for segment in segments: - db.session.delete(segment) - db.session.commit() - - document.indexing_status = "parsing" - document.processing_started_at = naive_utc_now() - db.session.add(document) - db.session.commit() - - if dataset.runtime_mode == "rag_pipeline": - rag_pipeline_service = RagPipelineService() - rag_pipeline_service.retry_error_document(dataset, document, user) - else: - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - redis_client.delete(retry_indexing_cache_key) - except Exception as ex: - document.indexing_status = "error" - document.error = str(ex) - document.stopped_at = naive_utc_now() - db.session.add(document) - db.session.commit() - logger.info(click.style(str(ex), fg="yellow")) - redis_client.delete(retry_indexing_cache_key) - logger.exception("retry_document_indexing_task failed, document_id: %s", document_id) - end_at = time.perf_counter() - logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) - except Exception as e: - logger.exception( - "retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids - ) - raise e - finally: - db.session.close() + raise e diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 0dc1d841f4..f1c8c56995 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -3,11 +3,11 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import delete, select +from core.db.session_factory import session_factory from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from extensions.ext_database 
import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment @@ -27,69 +27,71 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): """ start_at = time.perf_counter() - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if dataset is None: - raise ValueError("Dataset not found") + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if dataset is None: + raise ValueError("Dataset not found") - sync_indexing_cache_key = f"document_{document_id}_is_sync" - # check document limit - features = FeatureService.get_features(dataset.tenant_id) - try: - if features.billing.enabled: - vector_space = features.vector_space - if 0 < vector_space.limit <= vector_space.size: - raise ValueError( - "Your total number of documents plus the number of uploads have over the limit of " - "your subscription." - ) - except Exception as e: - document = ( - db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() - ) - if document: + sync_indexing_cache_key = f"document_{document_id}_is_sync" + # check document limit + features = FeatureService.get_features(dataset.tenant_id) + try: + if features.billing.enabled: + vector_space = features.vector_space + if 0 < vector_space.limit <= vector_space.size: + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." 
+ ) + except Exception as e: + document = ( + session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) + if document: + document.indexing_status = "error" + document.error = str(e) + document.stopped_at = naive_utc_now() + session.add(document) + session.commit() + redis_client.delete(sync_indexing_cache_key) + return + + logger.info(click.style(f"Start sync website document: {document_id}", fg="green")) + document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + if not document: + logger.info(click.style(f"Document not found: {document_id}", fg="yellow")) + return + try: + # clean old data + index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() + + segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + # delete from vector index + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + + segment_ids = [segment.id for segment in segments] + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) + session.execute(segment_delete_stmt) + session.commit() + + document.indexing_status = "parsing" + document.processing_started_at = naive_utc_now() + session.add(document) + session.commit() + + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + redis_client.delete(sync_indexing_cache_key) + except Exception as ex: document.indexing_status = "error" - document.error = str(e) + document.error = str(ex) document.stopped_at = naive_utc_now() - db.session.add(document) - db.session.commit() - redis_client.delete(sync_indexing_cache_key) - return - - logger.info(click.style(f"Start sync website document: {document_id}", fg="green")) - document = db.session.query(Document).where(Document.id == document_id, 
Document.dataset_id == dataset_id).first() - if not document: - logger.info(click.style(f"Document not found: {document_id}", fg="yellow")) - return - try: - # clean old data - index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - - segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - # delete from vector index - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - - for segment in segments: - db.session.delete(segment) - db.session.commit() - - document.indexing_status = "parsing" - document.processing_started_at = naive_utc_now() - db.session.add(document) - db.session.commit() - - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - redis_client.delete(sync_indexing_cache_key) - except Exception as ex: - document.indexing_status = "error" - document.error = str(ex) - document.stopped_at = naive_utc_now() - db.session.add(document) - db.session.commit() - logger.info(click.style(str(ex), fg="yellow")) - redis_client.delete(sync_indexing_cache_key) - logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id) - end_at = time.perf_counter() - logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green")) + session.add(document) + session.commit() + logger.info(click.style(str(ex), fg="yellow")) + redis_client.delete(sync_indexing_cache_key) + logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id) + end_at = time.perf_counter() + logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green")) diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index ee1d31aa91..d18ea2c23c 100644 --- a/api/tasks/trigger_processing_tasks.py +++ 
b/api/tasks/trigger_processing_tasks.py @@ -16,6 +16,7 @@ from sqlalchemy import func, select from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom +from core.db.session_factory import session_factory from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.request import TriggerInvokeEventResponse from core.plugin.impl.exc import PluginInvokeError @@ -27,7 +28,6 @@ from core.trigger.trigger_manager import TriggerManager from core.workflow.enums import NodeType, WorkflowExecutionStatus from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from enums.quota_type import QuotaType, unlimited -from extensions.ext_database import db from models.enums import ( AppTriggerType, CreatorUserRole, @@ -257,7 +257,7 @@ def dispatch_triggered_workflow( tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id) ) trigger_entity: TriggerProviderEntity = provider_controller.entity - with Session(db.engine) as session: + with session_factory.create_session() as session: workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers) end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch( diff --git a/api/tasks/trigger_subscription_refresh_tasks.py b/api/tasks/trigger_subscription_refresh_tasks.py index ed92f3f3c5..7698a1a6b8 100644 --- a/api/tasks/trigger_subscription_refresh_tasks.py +++ b/api/tasks/trigger_subscription_refresh_tasks.py @@ -7,9 +7,9 @@ from celery import shared_task from sqlalchemy.orm import Session from configs import dify_config +from core.db.session_factory import session_factory from core.plugin.entities.plugin_daemon import CredentialType from core.trigger.utils.locks import build_trigger_refresh_lock_key -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.trigger import TriggerSubscription from services.trigger.trigger_provider_service import 
TriggerProviderService @@ -92,7 +92,7 @@ def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None: logger.info("Begin subscription refresh: tenant=%s id=%s", tenant_id, subscription_id) try: now: int = _now_ts() - with Session(db.engine) as session: + with session_factory.create_session() as session: subscription: TriggerSubscription | None = _load_subscription(session, tenant_id, subscription_id) if not subscription: diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index 7d145fb50c..3b3c6e5313 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -10,11 +10,10 @@ import logging from celery import shared_task from sqlalchemy import select -from sqlalchemy.orm import sessionmaker +from core.db.session_factory import session_factory from core.workflow.entities.workflow_execution import WorkflowExecution from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter -from extensions.ext_database import db from models import CreatorUserRole, WorkflowRun from models.enums import WorkflowRunTriggeredFrom @@ -46,10 +45,7 @@ def save_workflow_execution_task( True if successful, False otherwise """ try: - # Create a new session for this task - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - - with session_factory() as session: + with session_factory.create_session() as session: # Deserialize execution data execution = WorkflowExecution.model_validate(execution_data) diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index 8f5127670f..b30a4ff15b 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -10,13 +10,12 @@ import logging from celery import shared_task from sqlalchemy import select -from sqlalchemy.orm import sessionmaker +from core.db.session_factory import session_factory from 
core.workflow.entities.workflow_node_execution import ( WorkflowNodeExecution, ) from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter -from extensions.ext_database import db from models import CreatorUserRole, WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -48,10 +47,7 @@ def save_workflow_node_execution_task( True if successful, False otherwise """ try: - # Create a new session for this task - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - - with session_factory() as session: + with session_factory.create_session() as session: # Deserialize execution data execution = WorkflowNodeExecution.model_validate(execution_data) diff --git a/api/tasks/workflow_schedule_tasks.py b/api/tasks/workflow_schedule_tasks.py index f54e02a219..8c64d3ab27 100644 --- a/api/tasks/workflow_schedule_tasks.py +++ b/api/tasks/workflow_schedule_tasks.py @@ -1,15 +1,14 @@ import logging from celery import shared_task -from sqlalchemy.orm import sessionmaker +from core.db.session_factory import session_factory from core.workflow.nodes.trigger_schedule.exc import ( ScheduleExecutionError, ScheduleNotFoundError, TenantOwnerNotFoundError, ) from enums.quota_type import QuotaType, unlimited -from extensions.ext_database import db from models.trigger import WorkflowSchedulePlan from services.async_workflow_service import AsyncWorkflowService from services.errors.app import QuotaExceededError @@ -33,10 +32,7 @@ def run_schedule_trigger(schedule_id: str) -> None: TenantOwnerNotFoundError: If no owner/admin for tenant ScheduleExecutionError: If workflow trigger fails """ - - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - - with session_factory() as session: + with session_factory.create_session() as session: schedule = session.get(WorkflowSchedulePlan, schedule_id) if not schedule: raise ScheduleNotFoundError(f"Schedule {schedule_id} not found") diff --git 
a/api/templates/invite_member_mail_template_en-US.html b/api/templates/invite_member_mail_template_en-US.html index a07c5f4b16..7b296519f0 100644 --- a/api/templates/invite_member_mail_template_en-US.html +++ b/api/templates/invite_member_mail_template_en-US.html @@ -83,7 +83,30 @@

Dear {{ to }},

{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.

Click the button below to log in to Dify and join the workspace.

-

Login Here

+
+ Login Here +

+ If the button doesn't work, copy and paste this link into your browser:
+ + {{ url }} + +

+

Best regards,

Dify Team
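Each of the template changes below follows the same pattern: keep the styled login button, then add a plain-text copy of the same `{{ url }}` (or `{{ login_url }}`) link for mail clients that strip or block anchor styling. A rough sketch of the placeholder substitution these templates rely on — the real rendering is done by Jinja2 on the server, so the `render` helper here is only an illustrative stand-in:

```python
import re

TEMPLATE = """\
<a href="{{ url }}">Login Here</a>
<p>If the button doesn't work, copy and paste this link into your browser:</p>
<p>{{ url }}</p>
"""


def render(template: str, **context: str) -> str:
    # Minimal {{ var }} substitution mimicking the Jinja-style placeholders
    # in the mail templates; not the project's actual rendering code.
    return re.sub(r"\{\{\s*(\w+)\s*\}\}", lambda m: context[m.group(1)], template)


html = render(TEMPLATE, url="https://example.com/activate?token=abc")
# The same URL now appears twice: once as the button href, once as
# copy-pasteable plain text.
assert html.count("https://example.com/activate?token=abc") == 2
```

Duplicating the raw link is a common HTML-email accessibility fallback, since many clients render buttons inconsistently or block them outright.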

diff --git a/api/templates/invite_member_mail_template_zh-CN.html b/api/templates/invite_member_mail_template_zh-CN.html index 27709a3c6d..c05b3ddb67 100644 --- a/api/templates/invite_member_mail_template_zh-CN.html +++ b/api/templates/invite_member_mail_template_zh-CN.html @@ -83,7 +83,30 @@

尊敬的 {{ to }},

{{ inviter_name }} 现邀请您加入我们在 Dify 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 Dify 上,您可以探索、创造和合作,构建和运营 AI 应用。

点击下方按钮即可登录 Dify 并且加入空间。

-

在此登录

+
+ 在此登录 +

+ 如果按钮无法使用,请将以下链接复制到浏览器打开:
+ + {{ url }} + +

+

此致,

Dify 团队

diff --git a/api/templates/register_email_when_account_exist_template_en-US.html b/api/templates/register_email_when_account_exist_template_en-US.html index ac5042c274..e2bb99c989 100644 --- a/api/templates/register_email_when_account_exist_template_en-US.html +++ b/api/templates/register_email_when_account_exist_template_en-US.html @@ -115,7 +115,30 @@ We noticed you tried to sign up, but this email is already registered with an existing account. Please log in here:

- Log In +
+ Log In +

+ If the button doesn't work, copy and paste this link into your browser:
+ + {{ login_url }} + +

+

If you forgot your password, you can reset it here: Reset Password diff --git a/api/templates/register_email_when_account_exist_template_zh-CN.html b/api/templates/register_email_when_account_exist_template_zh-CN.html index 326b58343a..6a5bbd135b 100644 --- a/api/templates/register_email_when_account_exist_template_zh-CN.html +++ b/api/templates/register_email_when_account_exist_template_zh-CN.html @@ -115,7 +115,30 @@ 我们注意到您尝试注册,但此电子邮件已注册。 请在此登录:

- 登录 +
+ 登录 +

+ 如果按钮无法使用,请将以下链接复制到浏览器打开:
+ + {{ login_url }} + +

+

如果您忘记了密码,可以在此重置: 重置密码

diff --git a/api/templates/without-brand/invite_member_mail_template_en-US.html b/api/templates/without-brand/invite_member_mail_template_en-US.html index f9157284fa..687ece617a 100644 --- a/api/templates/without-brand/invite_member_mail_template_en-US.html +++ b/api/templates/without-brand/invite_member_mail_template_en-US.html @@ -92,12 +92,34 @@ platform specifically designed for LLM application development. On {{application_title}}, you can explore, create, and collaborate to build and operate AI applications.

Click the button below to log in to {{application_title}} and join the workspace.

-

Login Here

+
+ Login Here +

+ If the button doesn't work, copy and paste this link into your browser:
+ + {{ url }} + +

+

Best regards,

{{application_title}} Team

- \ No newline at end of file + diff --git a/api/templates/without-brand/invite_member_mail_template_zh-CN.html b/api/templates/without-brand/invite_member_mail_template_zh-CN.html index e787c90914..9ca1ef62cd 100644 --- a/api/templates/without-brand/invite_member_mail_template_zh-CN.html +++ b/api/templates/without-brand/invite_member_mail_template_zh-CN.html @@ -81,7 +81,30 @@

尊敬的 {{ to }},

{{ inviter_name }} 现邀请您加入我们在 {{application_title}} 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 {{application_title}} 上,您可以探索、创造和合作,构建和运营 AI 应用。

点击下方按钮即可登录 {{application_title}} 并且加入空间。

-

在此登录

+
+ 在此登录 +

+ 如果按钮无法使用,请将以下链接复制到浏览器打开:
+ + {{ url }} + +

+

此致,

{{application_title}} 团队

diff --git a/api/templates/without-brand/register_email_when_account_exist_template_en-US.html b/api/templates/without-brand/register_email_when_account_exist_template_en-US.html index 2e74956e14..2a26aac5b9 100644 --- a/api/templates/without-brand/register_email_when_account_exist_template_en-US.html +++ b/api/templates/without-brand/register_email_when_account_exist_template_en-US.html @@ -111,7 +111,30 @@ We noticed you tried to sign up, but this email is already registered with an existing account. Please log in here:

- Log In +
+ Log In +

+ If the button doesn't work, copy and paste this link into your browser:
+ + {{ login_url }} + +

+

If you forgot your password, you can reset it here: Reset Password diff --git a/api/templates/without-brand/register_email_when_account_exist_template_zh-CN.html b/api/templates/without-brand/register_email_when_account_exist_template_zh-CN.html index a315f9154d..b9f93eb0fc 100644 --- a/api/templates/without-brand/register_email_when_account_exist_template_zh-CN.html +++ b/api/templates/without-brand/register_email_when_account_exist_template_zh-CN.html @@ -111,7 +111,30 @@ 我们注意到您尝试注册,但此电子邮件已注册。 请在此登录:

- 登录 +
+ 登录 +

+ 如果按钮无法使用,请将以下链接复制到浏览器打开:
+ + {{ login_url }} + +

+

如果您忘记了密码,可以在此重置: 重置密码

diff --git a/api/tests/fixtures/workflow/test_streaming_conversation_variables_v1_overwrite.yml b/api/tests/fixtures/workflow/test_streaming_conversation_variables_v1_overwrite.yml new file mode 100644 index 0000000000..5b2a6260e9 --- /dev/null +++ b/api/tests/fixtures/workflow/test_streaming_conversation_variables_v1_overwrite.yml @@ -0,0 +1,158 @@ +app: + description: Validate v1 Variable Assigner blocks streaming until conversation variable is updated. + icon: 🤖 + icon_background: '#FFEAD5' + mode: advanced-chat + name: test_streaming_conversation_variables_v1_overwrite + use_icon_as_answer_icon: false +dependencies: [] +kind: app +version: 0.5.0 +workflow: + conversation_variables: + - description: '' + id: 6ddf2d7f-3d1b-4bb0-9a5e-9b0c87c7b5e6 + name: conv_var + selector: + - conversation + - conv_var + value: default + value_type: string + environment_variables: [] + features: + file_upload: + allowed_file_extensions: + - .JPG + - .JPEG + - .PNG + - .GIF + - .WEBP + - .SVG + allowed_file_types: + - image + allowed_file_upload_methods: + - local_file + - remote_url + enabled: false + fileUploadConfig: + audio_file_size_limit: 50 + batch_count_limit: 5 + file_size_limit: 15 + image_file_size_limit: 10 + video_file_size_limit: 100 + workflow_file_upload_limit: 10 + image: + enabled: false + number_limits: 3 + transfer_methods: + - local_file + - remote_url + number_limits: 3 + opening_statement: '' + retriever_resource: + enabled: true + sensitive_word_avoidance: + enabled: false + speech_to_text: + enabled: false + suggested_questions: [] + suggested_questions_after_answer: + enabled: false + text_to_speech: + enabled: false + language: '' + voice: '' + graph: + edges: + - data: + isInIteration: false + isInLoop: false + sourceType: start + targetType: assigner + id: start-source-assigner-target + source: start + sourceHandle: source + target: assigner + targetHandle: target + type: custom + zIndex: 0 + - data: + isInLoop: false + sourceType: assigner + 
targetType: answer + id: assigner-source-answer-target + source: assigner + sourceHandle: source + target: answer + targetHandle: target + type: custom + zIndex: 0 + nodes: + - data: + desc: '' + selected: false + title: Start + type: start + variables: [] + height: 54 + id: start + position: + x: 30 + y: 253 + positionAbsolute: + x: 30 + y: 253 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 244 + - data: + answer: 'Current Value Of `conv_var` is:{{#conversation.conv_var#}}' + desc: '' + selected: false + title: Answer + type: answer + variables: [] + height: 106 + id: answer + position: + x: 638 + y: 253 + positionAbsolute: + x: 638 + y: 253 + selected: true + sourcePosition: right + targetPosition: left + type: custom + width: 244 + - data: + assigned_variable_selector: + - conversation + - conv_var + desc: '' + input_variable_selector: + - sys + - query + selected: false + title: Variable Assigner + type: assigner + write_mode: over-write + height: 84 + id: assigner + position: + x: 334 + y: 253 + positionAbsolute: + x: 334 + y: 253 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 244 + viewport: + x: 0 + y: 0 + zoom: 0.7 diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index acc268f1d4..39effbab58 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -103,6 +103,8 @@ SMTP_USERNAME=123 SMTP_PASSWORD=abc SMTP_USE_TLS=true SMTP_OPPORTUNISTIC_TLS=false +# Optional: override the local hostname used for SMTP HELO/EHLO +SMTP_LOCAL_HOSTNAME= # Sentry configuration SENTRY_DSN= diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index d59d5dc0fe..5012defdad 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py 
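The new fixture wires start → Variable Assigner (v1, `write_mode: over-write`) → answer, and exists to validate that streaming of `{{#conversation.conv_var#}}` is blocked until the assigner has overwritten the conversation variable with `sys.query`. A toy in-memory simulation of that ordering guarantee — every name here is illustrative, not a Dify API:

```python
# Illustrative simulation only: these dicts/functions are NOT Dify APIs.
conversation_vars = {"conv_var": "default"}  # the fixture's declared default


def run_assigner(sys_query: str) -> None:
    # v1 Variable Assigner with write_mode=over-write replaces the
    # conversation variable's value wholesale with the input selector.
    conversation_vars["conv_var"] = sys_query


def render_answer() -> str:
    # Answer node template from the fixture:
    # 'Current Value Of `conv_var` is:{{#conversation.conv_var#}}'
    return f"Current Value Of `conv_var` is:{conversation_vars['conv_var']}"


# If the answer node streamed before the assigner ran, it would emit
# "default"; the fixture asserts the updated value is what streams.
run_assigner("hello")
assert render_answer() == "Current Value Of `conv_var` is:hello"
```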
@@ -48,10 +48,6 @@ class MockModelClass(PluginModelClient): en_US="https://example.com/icon_small.png", zh_Hans="https://example.com/icon_small.png", ), - icon_large=I18nObject( - en_US="https://example.com/icon_large.png", - zh_Hans="https://example.com/icon_large.png", - ), supported_model_types=[ModelType.LLM], configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], models=[ diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index 7cdc3cb205..f46d1bf5db 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -4,8 +4,8 @@ from unittest.mock import patch import pytest from sqlalchemy import delete +from core.db.session_factory import session_factory from core.variables.segments import StringSegment -from extensions.ext_database import db from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile @@ -16,26 +16,23 @@ from tasks.remove_app_and_related_data_task import _delete_draft_variables, dele @pytest.fixture def app_and_tenant(flask_req_ctx): tenant_id = uuid.uuid4() - tenant = Tenant( - id=tenant_id, - name="test_tenant", - ) - db.session.add(tenant) + with session_factory.create_session() as session: + tenant = Tenant(name="test_tenant") + session.add(tenant) + session.flush() - app = App( - tenant_id=tenant_id, # Now tenant.id will have a value - name=f"Test App for tenant {tenant.id}", - mode="workflow", - enable_site=True, - enable_api=True, - ) - db.session.add(app) - db.session.flush() - yield (tenant, app) + app = App( + tenant_id=tenant.id, + name=f"Test App for tenant {tenant.id}", + mode="workflow", + enable_site=True, + enable_api=True, + ) + session.add(app) + session.flush() - # Cleanup with proper error handling - db.session.delete(app) - db.session.delete(tenant) + # 
return detached objects (ids will be used by tests) + return (tenant, app) class TestDeleteDraftVariablesIntegration: @@ -44,334 +41,285 @@ class TestDeleteDraftVariablesIntegration: """Create test data with apps and draft variables.""" tenant, app = app_and_tenant - # Create a second app for testing - app2 = App( - tenant_id=tenant.id, - name="Test App 2", - mode="workflow", - enable_site=True, - enable_api=True, - ) - db.session.add(app2) - db.session.commit() - - # Create draft variables for both apps - variables_app1 = [] - variables_app2 = [] - - for i in range(5): - var1 = WorkflowDraftVariable.new_node_variable( - app_id=app.id, - node_id=f"node_{i}", - name=f"var_{i}", - value=StringSegment(value="test_value"), - node_execution_id=str(uuid.uuid4()), + with session_factory.create_session() as session: + app2 = App( + tenant_id=tenant.id, + name="Test App 2", + mode="workflow", + enable_site=True, + enable_api=True, ) - db.session.add(var1) - variables_app1.append(var1) + session.add(app2) + session.flush() - var2 = WorkflowDraftVariable.new_node_variable( - app_id=app2.id, - node_id=f"node_{i}", - name=f"var_{i}", - value=StringSegment(value="test_value"), - node_execution_id=str(uuid.uuid4()), - ) - db.session.add(var2) - variables_app2.append(var2) + variables_app1 = [] + variables_app2 = [] + for i in range(5): + var1 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(var1) + variables_app1.append(var1) - # Commit all the variables to the database - db.session.commit() + var2 = WorkflowDraftVariable.new_node_variable( + app_id=app2.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(var2) + variables_app2.append(var2) + session.commit() + + app2_id = app2.id yield { "app1": app, - "app2": app2, + "app2": 
App(id=app2_id), # dummy with id to avoid open session "tenant": tenant, "variables_app1": variables_app1, "variables_app2": variables_app2, } - # Cleanup - refresh session and check if objects still exist - db.session.rollback() # Clear any pending changes - - # Clean up remaining variables - cleanup_query = ( - delete(WorkflowDraftVariable) - .where( - WorkflowDraftVariable.app_id.in_([app.id, app2.id]), + with session_factory.create_session() as session: + cleanup_query = ( + delete(WorkflowDraftVariable) + .where(WorkflowDraftVariable.app_id.in_([app.id, app2_id])) + .execution_options(synchronize_session=False) ) - .execution_options(synchronize_session=False) - ) - db.session.execute(cleanup_query) - - # Clean up app2 - app2_obj = db.session.get(App, app2.id) - if app2_obj: - db.session.delete(app2_obj) - - db.session.commit() + session.execute(cleanup_query) + app2_obj = session.get(App, app2_id) + if app2_obj: + session.delete(app2_obj) + session.commit() def test_delete_draft_variables_batch_removes_correct_variables(self, setup_test_data): - """Test that batch deletion only removes variables for the specified app.""" data = setup_test_data app1_id = data["app1"].id app2_id = data["app2"].id - # Verify initial state - app1_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() - app2_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() + with session_factory.create_session() as session: + app1_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + app2_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() assert app1_vars_before == 5 assert app2_vars_before == 5 - # Delete app1 variables deleted_count = delete_draft_variables_batch(app1_id, batch_size=10) - - # Verify results assert deleted_count == 5 - app1_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() - app2_vars_after = 
db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() - - assert app1_vars_after == 0 # All app1 variables deleted - assert app2_vars_after == 5 # App2 variables unchanged + with session_factory.create_session() as session: + app1_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + app2_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() + assert app1_vars_after == 0 + assert app2_vars_after == 5 def test_delete_draft_variables_batch_with_small_batch_size(self, setup_test_data): - """Test batch deletion with small batch size processes all records.""" data = setup_test_data app1_id = data["app1"].id - # Use small batch size to force multiple batches deleted_count = delete_draft_variables_batch(app1_id, batch_size=2) - assert deleted_count == 5 - # Verify all variables are deleted - remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + with session_factory.create_session() as session: + remaining_vars = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() assert remaining_vars == 0 def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data): - """Test that deleting variables for nonexistent app returns 0.""" - nonexistent_app_id = str(uuid.uuid4()) # Use a valid UUID format - + nonexistent_app_id = str(uuid.uuid4()) deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=100) - assert deleted_count == 0 def test_delete_draft_variables_wrapper_function(self, setup_test_data): - """Test that _delete_draft_variables wrapper function works correctly.""" data = setup_test_data app1_id = data["app1"].id - # Verify initial state - vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + with session_factory.create_session() as session: + vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() assert vars_before == 5 - # Call wrapper 
function deleted_count = _delete_draft_variables(app1_id) - - # Verify results assert deleted_count == 5 - vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + with session_factory.create_session() as session: + vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() assert vars_after == 0 def test_batch_deletion_handles_large_dataset(self, app_and_tenant): - """Test batch deletion with larger dataset to verify batching logic.""" tenant, app = app_and_tenant - - # Create many draft variables - variables = [] - for i in range(25): - var = WorkflowDraftVariable.new_node_variable( - app_id=app.id, - node_id=f"node_{i}", - name=f"var_{i}", - value=StringSegment(value="test_value"), - node_execution_id=str(uuid.uuid4()), - ) - db.session.add(var) - variables.append(var) - variable_ids = [i.id for i in variables] - - # Commit the variables to the database - db.session.commit() + variable_ids: list[str] = [] + with session_factory.create_session() as session: + variables = [] + for i in range(25): + var = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(var) + variables.append(var) + session.commit() + variable_ids = [v.id for v in variables] try: - # Use small batch size to force multiple batches deleted_count = delete_draft_variables_batch(app.id, batch_size=8) - assert deleted_count == 25 - - # Verify all variables are deleted - remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count() - assert remaining_vars == 0 - + with session_factory.create_session() as session: + remaining = session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count() + assert remaining == 0 finally: - query = ( - delete(WorkflowDraftVariable) - .where( - WorkflowDraftVariable.id.in_(variable_ids), + with session_factory.create_session() as 
session:
+            query = (
+                delete(WorkflowDraftVariable)
+                .where(WorkflowDraftVariable.id.in_(variable_ids))
+                .execution_options(synchronize_session=False)
+            )
-            .execution_options(synchronize_session=False)
-        )
-        db.session.execute(query)
+            session.execute(query)
+            session.commit()


 class TestDeleteDraftVariablesWithOffloadIntegration:
-    """Integration tests for draft variable deletion with Offload data."""
-
     @pytest.fixture
     def setup_offload_test_data(self, app_and_tenant):
-        """Create test data with draft variables that have associated Offload files."""
         tenant, app = app_and_tenant
-
-        # Create UploadFile records
+        from core.variables.types import SegmentType
         from libs.datetime_utils import naive_utc_now

-        upload_file1 = UploadFile(
-            tenant_id=tenant.id,
-            storage_type="local",
-            key="test/file1.json",
-            name="file1.json",
-            size=1024,
-            extension="json",
-            mime_type="application/json",
-            created_by_role=CreatorUserRole.ACCOUNT,
-            created_by=str(uuid.uuid4()),
-            created_at=naive_utc_now(),
-            used=False,
-        )
-        upload_file2 = UploadFile(
-            tenant_id=tenant.id,
-            storage_type="local",
-            key="test/file2.json",
-            name="file2.json",
-            size=2048,
-            extension="json",
-            mime_type="application/json",
-            created_by_role=CreatorUserRole.ACCOUNT,
-            created_by=str(uuid.uuid4()),
-            created_at=naive_utc_now(),
-            used=False,
-        )
-        db.session.add(upload_file1)
-        db.session.add(upload_file2)
-        db.session.flush()
+        with session_factory.create_session() as session:
+            upload_file1 = UploadFile(
+                tenant_id=tenant.id,
+                storage_type="local",
+                key="test/file1.json",
+                name="file1.json",
+                size=1024,
+                extension="json",
+                mime_type="application/json",
+                created_by_role=CreatorUserRole.ACCOUNT,
+                created_by=str(uuid.uuid4()),
+                created_at=naive_utc_now(),
+                used=False,
+            )
+            upload_file2 = UploadFile(
+                tenant_id=tenant.id,
+                storage_type="local",
+                key="test/file2.json",
+                name="file2.json",
+                size=2048,
+                extension="json",
+                mime_type="application/json",
+                created_by_role=CreatorUserRole.ACCOUNT,
+                created_by=str(uuid.uuid4()),
+                created_at=naive_utc_now(),
+                used=False,
+            )
+            session.add(upload_file1)
+            session.add(upload_file2)
+            session.flush()

-        # Create WorkflowDraftVariableFile records
-        from core.variables.types import SegmentType
+            var_file1 = WorkflowDraftVariableFile(
+                tenant_id=tenant.id,
+                app_id=app.id,
+                user_id=str(uuid.uuid4()),
+                upload_file_id=upload_file1.id,
+                size=1024,
+                length=10,
+                value_type=SegmentType.STRING,
+            )
+            var_file2 = WorkflowDraftVariableFile(
+                tenant_id=tenant.id,
+                app_id=app.id,
+                user_id=str(uuid.uuid4()),
+                upload_file_id=upload_file2.id,
+                size=2048,
+                length=20,
+                value_type=SegmentType.OBJECT,
+            )
+            session.add(var_file1)
+            session.add(var_file2)
+            session.flush()

-        var_file1 = WorkflowDraftVariableFile(
-            tenant_id=tenant.id,
-            app_id=app.id,
-            user_id=str(uuid.uuid4()),
-            upload_file_id=upload_file1.id,
-            size=1024,
-            length=10,
-            value_type=SegmentType.STRING,
-        )
-        var_file2 = WorkflowDraftVariableFile(
-            tenant_id=tenant.id,
-            app_id=app.id,
-            user_id=str(uuid.uuid4()),
-            upload_file_id=upload_file2.id,
-            size=2048,
-            length=20,
-            value_type=SegmentType.OBJECT,
-        )
-        db.session.add(var_file1)
-        db.session.add(var_file2)
-        db.session.flush()
+            draft_var1 = WorkflowDraftVariable.new_node_variable(
+                app_id=app.id,
+                node_id="node_1",
+                name="large_var_1",
+                value=StringSegment(value="truncated..."),
+                node_execution_id=str(uuid.uuid4()),
+                file_id=var_file1.id,
+            )
+            draft_var2 = WorkflowDraftVariable.new_node_variable(
+                app_id=app.id,
+                node_id="node_2",
+                name="large_var_2",
+                value=StringSegment(value="truncated..."),
+                node_execution_id=str(uuid.uuid4()),
+                file_id=var_file2.id,
+            )
+            draft_var3 = WorkflowDraftVariable.new_node_variable(
+                app_id=app.id,
+                node_id="node_3",
+                name="regular_var",
+                value=StringSegment(value="regular_value"),
+                node_execution_id=str(uuid.uuid4()),
+            )
+            session.add(draft_var1)
+            session.add(draft_var2)
+            session.add(draft_var3)
+            session.commit()

-        # Create WorkflowDraftVariable records with file associations
-        draft_var1 = WorkflowDraftVariable.new_node_variable(
-            app_id=app.id,
-            node_id="node_1",
-            name="large_var_1",
-            value=StringSegment(value="truncated..."),
-            node_execution_id=str(uuid.uuid4()),
-            file_id=var_file1.id,
-        )
-        draft_var2 = WorkflowDraftVariable.new_node_variable(
-            app_id=app.id,
-            node_id="node_2",
-            name="large_var_2",
-            value=StringSegment(value="truncated..."),
-            node_execution_id=str(uuid.uuid4()),
-            file_id=var_file2.id,
-        )
-        # Create a regular variable without Offload data
-        draft_var3 = WorkflowDraftVariable.new_node_variable(
-            app_id=app.id,
-            node_id="node_3",
-            name="regular_var",
-            value=StringSegment(value="regular_value"),
-            node_execution_id=str(uuid.uuid4()),
-        )
+        data = {
+            "app": app,
+            "tenant": tenant,
+            "upload_files": [upload_file1, upload_file2],
+            "variable_files": [var_file1, var_file2],
+            "draft_variables": [draft_var1, draft_var2, draft_var3],
+        }

-        db.session.add(draft_var1)
-        db.session.add(draft_var2)
-        db.session.add(draft_var3)
-        db.session.commit()
+        yield data

-        yield {
-            "app": app,
-            "tenant": tenant,
-            "upload_files": [upload_file1, upload_file2],
-            "variable_files": [var_file1, var_file2],
-            "draft_variables": [draft_var1, draft_var2, draft_var3],
-        }
-
-        # Cleanup
-        db.session.rollback()
-
-        # Clean up any remaining records
-        for table, ids in [
-            (WorkflowDraftVariable, [v.id for v in [draft_var1, draft_var2, draft_var3]]),
-            (WorkflowDraftVariableFile, [vf.id for vf in [var_file1, var_file2]]),
-            (UploadFile, [uf.id for uf in [upload_file1, upload_file2]]),
-        ]:
-            cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
-            db.session.execute(cleanup_query)
-
-        db.session.commit()
+        with session_factory.create_session() as session:
+            session.rollback()
+            for table, ids in [
+                (WorkflowDraftVariable, [v.id for v in data["draft_variables"]]),
+                (WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]),
+                (UploadFile, [uf.id for uf in data["upload_files"]]),
+            ]:
+                cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
+                session.execute(cleanup_query)
+            session.commit()

     @patch("extensions.ext_storage.storage")
     def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
-        """Test that deleting draft variables also cleans up associated Offload data."""
         data = setup_offload_test_data
         app_id = data["app"].id
-
-        # Mock storage deletion to succeed
         mock_storage.delete.return_value = None

-        # Verify initial state
-        draft_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
-        var_files_before = db.session.query(WorkflowDraftVariableFile).count()
-        upload_files_before = db.session.query(UploadFile).count()
-
-        assert draft_vars_before == 3  # 2 with files + 1 regular
+        with session_factory.create_session() as session:
+            draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
+            var_files_before = session.query(WorkflowDraftVariableFile).count()
+            upload_files_before = session.query(UploadFile).count()
+            assert draft_vars_before == 3
         assert var_files_before == 2
         assert upload_files_before == 2

-        # Delete draft variables
         deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
-
-        # Verify results
         assert deleted_count == 3

-        # Check that all draft variables are deleted
-        draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
+        with session_factory.create_session() as session:
+            draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
         assert draft_vars_after == 0

-        # Check that associated Offload data is cleaned up
-        var_files_after = db.session.query(WorkflowDraftVariableFile).count()
-        upload_files_after = db.session.query(UploadFile).count()
+        with session_factory.create_session() as session:
+            var_files_after = session.query(WorkflowDraftVariableFile).count()
+            upload_files_after = session.query(UploadFile).count()
+            assert var_files_after == 0
+            assert upload_files_after == 0

-        assert var_files_after == 0  # All variable files should be deleted
-        assert upload_files_after == 0  # All upload files should be deleted
-
-        # Verify storage deletion was called for both files
         assert mock_storage.delete.call_count == 2
         storage_keys_deleted = [call.args[0] for call in mock_storage.delete.call_args_list]
         assert "test/file1.json" in storage_keys_deleted
@@ -379,92 +327,71 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
     @patch("extensions.ext_storage.storage")
     def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
-        """Test that database cleanup continues even when storage deletion fails."""
         data = setup_offload_test_data
         app_id = data["app"].id
-
-        # Mock storage deletion to fail for first file, succeed for second
         mock_storage.delete.side_effect = [Exception("Storage error"), None]

-        # Delete draft variables
         deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
-
-        # Verify that all draft variables are still deleted
         assert deleted_count == 3

-        draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
+        with session_factory.create_session() as session:
+            draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
         assert draft_vars_after == 0

-        # Database cleanup should still succeed even with storage errors
-        var_files_after = db.session.query(WorkflowDraftVariableFile).count()
-        upload_files_after = db.session.query(UploadFile).count()
-
+        with session_factory.create_session() as session:
+            var_files_after = session.query(WorkflowDraftVariableFile).count()
+            upload_files_after = session.query(UploadFile).count()
         assert var_files_after == 0
         assert upload_files_after == 0

-        # Verify storage deletion was attempted for both files
         assert mock_storage.delete.call_count == 2

     @patch("extensions.ext_storage.storage")
     def test_delete_draft_variables_partial_offload_data(self, mock_storage, setup_offload_test_data):
-        """Test deletion with mix of variables with and without Offload data."""
         data = setup_offload_test_data
         app_id = data["app"].id
-
-        # Create additional app with only regular variables (no offload data)
         tenant = data["tenant"]
-        app2 = App(
-            tenant_id=tenant.id,
-            name="Test App 2",
-            mode="workflow",
-            enable_site=True,
-            enable_api=True,
-        )
-        db.session.add(app2)
-        db.session.flush()

-        # Add regular variables to app2
-        regular_vars = []
-        for i in range(3):
-            var = WorkflowDraftVariable.new_node_variable(
-                app_id=app2.id,
-                node_id=f"node_{i}",
-                name=f"var_{i}",
-                value=StringSegment(value="regular_value"),
-                node_execution_id=str(uuid.uuid4()),
+        with session_factory.create_session() as session:
+            app2 = App(
+                tenant_id=tenant.id,
+                name="Test App 2",
+                mode="workflow",
+                enable_site=True,
+                enable_api=True,
             )
-            db.session.add(var)
-            regular_vars.append(var)
-        db.session.commit()
+            session.add(app2)
+            session.flush()
+
+            for i in range(3):
+                var = WorkflowDraftVariable.new_node_variable(
+                    app_id=app2.id,
+                    node_id=f"node_{i}",
+                    name=f"var_{i}",
+                    value=StringSegment(value="regular_value"),
+                    node_execution_id=str(uuid.uuid4()),
+                )
+                session.add(var)
+            session.commit()

         try:
-            # Mock storage deletion
             mock_storage.delete.return_value = None
-
-            # Delete variables for app2 (no offload data)
             deleted_count_app2 = delete_draft_variables_batch(app2.id, batch_size=10)
             assert deleted_count_app2 == 3
-
-            # Verify storage wasn't called for app2 (no offload files)
             mock_storage.delete.assert_not_called()

-            # Delete variables for original app (with offload data)
             deleted_count_app1 = delete_draft_variables_batch(app_id, batch_size=10)
             assert deleted_count_app1 == 3
-
-            # Now storage should be called for the offload files
             assert mock_storage.delete.call_count == 2
-
         finally:
-            # Cleanup app2 and its variables
-            cleanup_vars_query = (
-                delete(WorkflowDraftVariable)
-                .where(WorkflowDraftVariable.app_id == app2.id)
-                .execution_options(synchronize_session=False)
-            )
-            db.session.execute(cleanup_vars_query)
-
-            app2_obj = db.session.get(App, app2.id)
-            if app2_obj:
-                db.session.delete(app2_obj)
-            db.session.commit()
+            with session_factory.create_session() as session:
+                cleanup_vars_query = (
+                    delete(WorkflowDraftVariable)
+                    .where(WorkflowDraftVariable.app_id == app2.id)
+                    .execution_options(synchronize_session=False)
+                )
+                session.execute(cleanup_vars_query)
+                app2_obj = session.get(App, app2.id)
+                if app2_obj:
+                    session.delete(app2_obj)
+                session.commit()
diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py
index 94903cf796..c8eb9ec3e4 100644
--- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py
+++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py
@@ -7,11 +7,14 @@ CODE_LANGUAGE = CodeLanguage.JINJA2


 def test_jinja2():
+    """Test basic Jinja2 template rendering."""
     template = "Hello {{template}}"
+    # Template must be base64 encoded to match the new safe embedding approach
+    template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
     inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
     code = (
         Jinja2TemplateTransformer.get_runner_script()
-        .replace(Jinja2TemplateTransformer._code_placeholder, template)
+        .replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64)
         .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
     )
     result = CodeExecutor.execute_code(
@@ -21,6 +24,7 @@


 def test_jinja2_with_code_template():
+    """Test template rendering via the high-level workflow API."""
     result = CodeExecutor.execute_workflow_code_template(
         language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"}
     )
@@ -28,7 +32,64 @@ def test_jinja2_with_code_template():


 def test_jinja2_get_runner_script():
+    """Test that runner script contains required placeholders."""
     runner_script = Jinja2TemplateTransformer.get_runner_script()
-    assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
+    assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1
     assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
     assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
+
+
+def test_jinja2_template_with_special_characters():
+    """
+    Test that templates with special characters (quotes, newlines) render correctly.
+    This is a regression test for issue #26818 where textarea pre-fill values
+    containing special characters would break template rendering.
+    """
+    # Template with triple quotes, single quotes, double quotes, and newlines
+    template = """
+
+
+
+Status: "{{ status }}"
+
+'''code block'''
+
+"""
+    inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"}
+
+    result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs)
+
+    # Verify the template rendered correctly with all special characters
+    output = result["result"]
+    assert 'value="TASK-123"' in output
+    assert "" in output
+    assert 'Status: "completed"' in output
+    assert "'''code block'''" in output
+
+
+def test_jinja2_template_with_html_textarea_prefill():
+    """
+    Specific test for HTML textarea with Jinja2 variable pre-fill.
+    Verifies fix for issue #26818.
+    """
+    template = ""
+    notes_content = "This is a multi-line note.\nWith special chars: 'single' and \"double\" quotes."
+    inputs = {"notes": notes_content}
+
+    result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs)
+
+    expected_output = f""
+    assert result["result"] == expected_output
+
+
+def test_jinja2_assemble_runner_script_encodes_template():
+    """Test that assemble_runner_script properly base64 encodes the template."""
+    template = "Hello {{ name }}!"
+    inputs = {"name": "World"}
+
+    script = Jinja2TemplateTransformer.assemble_runner_script(template, inputs)
+
+    # The template should be base64 encoded in the script
+    template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
+    assert template_b64 in script
+    # The raw template should NOT appear in the script (it's encoded)
+    assert "Hello {{ name }}!" not in script
diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py
index e421e4ff36..1a9d69b2d2 100644
--- a/api/tests/integration_tests/workflow/nodes/test_code.py
+++ b/api/tests/integration_tests/workflow/nodes/test_code.py
@@ -5,12 +5,13 @@ import pytest

 from configs import dify_config
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.workflow.node_factory import DifyNodeFactory
 from core.workflow.entities import GraphInitParams
 from core.workflow.enums import WorkflowNodeExecutionStatus
 from core.workflow.graph import Graph
 from core.workflow.node_events import NodeRunResult
 from core.workflow.nodes.code.code_node import CodeNode
-from core.workflow.nodes.node_factory import DifyNodeFactory
+from core.workflow.nodes.code.limits import CodeNodeLimits
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
 from models.enums import UserFrom
@@ -67,6 +68,16 @@ def init_code_node(code_config: dict):
         config=code_config,
         graph_init_params=init_params,
         graph_runtime_state=graph_runtime_state,
+        code_limits=CodeNodeLimits(
+            max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
+            max_number=dify_config.CODE_MAX_NUMBER,
+            min_number=dify_config.CODE_MIN_NUMBER,
+            max_precision=dify_config.CODE_MAX_PRECISION,
+            max_depth=dify_config.CODE_MAX_DEPTH,
+            max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
+            max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
+            max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
+        ),
     )

     return node
diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py
index d814da8ec7..1bcac3b5fe 100644
--- a/api/tests/integration_tests/workflow/nodes/test_http.py
+++ b/api/tests/integration_tests/workflow/nodes/test_http.py
@@ -5,11 +5,11 @@ from urllib.parse import urlencode

 import pytest

 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.workflow.node_factory import DifyNodeFactory
 from core.workflow.entities import GraphInitParams
 from core.workflow.enums import WorkflowNodeExecutionStatus
 from core.workflow.graph import Graph
 from core.workflow.nodes.http_request.node import HttpRequestNode
-from core.workflow.nodes.node_factory import DifyNodeFactory
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
 from models.enums import UserFrom
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
index d268c5da22..c361bfcc6f 100644
--- a/api/tests/integration_tests/workflow/nodes/test_llm.py
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -5,13 +5,13 @@ from collections.abc import Generator
 from unittest.mock import MagicMock, patch

 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.workflow.node_factory import DifyNodeFactory
 from core.llm_generator.output_parser.structured_output import _parse_structured_output
 from core.workflow.entities import GraphInitParams
 from core.workflow.enums import WorkflowNodeExecutionStatus
 from core.workflow.graph import Graph
 from core.workflow.node_events import StreamCompletedEvent
 from core.workflow.nodes.llm.node import LLMNode
-from core.workflow.nodes.node_factory import DifyNodeFactory
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
 from extensions.ext_database import db
diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
index 654db59bec..7445699a86 100644
--- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
+++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
@@ -4,11 +4,11 @@
 import uuid
 from unittest.mock import MagicMock

 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.workflow.node_factory import DifyNodeFactory
 from core.model_runtime.entities import AssistantPromptMessage
 from core.workflow.entities import GraphInitParams
 from core.workflow.enums import WorkflowNodeExecutionStatus
 from core.workflow.graph import Graph
-from core.workflow.nodes.node_factory import DifyNodeFactory
 from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
index 3bcb9a3a34..bc03ce1b96 100644
--- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py
+++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
@@ -4,10 +4,10 @@ import uuid

 import pytest

 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.workflow.node_factory import DifyNodeFactory
 from core.workflow.entities import GraphInitParams
 from core.workflow.enums import WorkflowNodeExecutionStatus
 from core.workflow.graph import Graph
-from core.workflow.nodes.node_factory import DifyNodeFactory
 from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py
index d666f0ebe2..cfbef52c93 100644
--- a/api/tests/integration_tests/workflow/nodes/test_tool.py
+++ b/api/tests/integration_tests/workflow/nodes/test_tool.py
@@ -3,12 +3,12 @@ import uuid
 from unittest.mock import MagicMock

 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.workflow.node_factory import DifyNodeFactory
 from core.tools.utils.configuration import ToolParameterConfigurationManager
 from core.workflow.entities import GraphInitParams
 from core.workflow.enums import WorkflowNodeExecutionStatus
 from core.workflow.graph import Graph
 from core.workflow.node_events import StreamCompletedEvent
-from core.workflow.nodes.node_factory import DifyNodeFactory
 from core.workflow.nodes.tool.tool_node import ToolNode
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py
index 72469ad646..dcf31aeca7 100644
--- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py
+++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py
@@ -35,6 +35,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
 from core.workflow.entities.pause_reason import SchedulingPause
 from core.workflow.enums import WorkflowExecutionStatus
 from core.workflow.graph_engine.entities.commands import GraphEngineCommand
+from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError
 from core.workflow.graph_events.graph import GraphRunPausedEvent
 from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
 from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
@@ -569,10 +570,10 @@ class TestPauseStatePersistenceLayerTestContainers:
         """Test that layer requires proper initialization before handling events."""
         # Arrange
         layer = self._create_pause_state_persistence_layer()
-        # Don't initialize - graph_runtime_state should not be set
+        # Don't initialize - graph_runtime_state should be uninitialized
         event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])

-        # Act & Assert - Should raise AttributeError
-        with pytest.raises(AttributeError):
+        # Act & Assert - Should raise GraphEngineLayerNotInitializedError
+        with pytest.raises(GraphEngineLayerNotInitializedError):
             layer.on_event(event)
diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
index 07fa0f7c98..06d55177eb 100644
--- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
+++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session

 from core.app.app_config.entities import WorkflowUIBasedAppConfig
 from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
+from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer
 from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
 from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
 from core.workflow.entities import GraphInitParams
@@ -16,7 +17,6 @@ from core.workflow.enums import WorkflowType
 from core.workflow.graph import Graph
 from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
 from core.workflow.graph_engine.graph_engine import GraphEngine
-from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
 from core.workflow.nodes.end.end_node import EndNode
 from core.workflow.nodes.end.entities import EndNodeData
 from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py
index 3be2798085..6eedbd6cfa 100644
--- a/api/tests/test_containers_integration_tests/services/test_agent_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py
@@ -172,7 +172,6 @@ class TestAgentService:

         # Create app model config
         app_model_config = AppModelConfig(
-            id=fake.uuid4(),
             app_id=app.id,
             provider="openai",
             model_id="gpt-3.5-turbo",
@@ -180,6 +179,7 @@ class TestAgentService:
             model="gpt-3.5-turbo",
             agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
         )
+        app_model_config.id = fake.uuid4()
         db.session.add(app_model_config)
         db.session.commit()
@@ -230,7 +230,6 @@ class TestAgentService:

         # Create first agent thought
         thought1 = MessageAgentThought(
-            id=fake.uuid4(),
             message_id=message.id,
             position=1,
             thought="I need to analyze the user's request",
@@ -257,7 +256,6 @@ class TestAgentService:

         # Create second agent thought
         thought2 = MessageAgentThought(
-            id=fake.uuid4(),
             message_id=message.id,
             position=2,
             thought="Based on the analysis, I can provide a response",
@@ -415,7 +413,6 @@ class TestAgentService:

         # Create app model config
         app_model_config = AppModelConfig(
-            id=fake.uuid4(),
             app_id=app.id,
             provider="openai",
             model_id="gpt-3.5-turbo",
@@ -423,6 +420,7 @@ class TestAgentService:
             model="gpt-3.5-turbo",
             agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
         )
+        app_model_config.id = fake.uuid4()
         db.session.add(app_model_config)
         db.session.commit()
@@ -487,7 +485,6 @@ class TestAgentService:

         # Create app model config
         app_model_config = AppModelConfig(
-            id=fake.uuid4(),
             app_id=app.id,
             provider="openai",
             model_id="gpt-3.5-turbo",
@@ -495,6 +492,7 @@ class TestAgentService:
             model="gpt-3.5-turbo",
             agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
         )
+        app_model_config.id = fake.uuid4()
         db.session.add(app_model_config)
         db.session.commit()
@@ -545,7 +543,6 @@ class TestAgentService:

         # Create agent thought with tool error
         thought_with_error = MessageAgentThought(
-            id=fake.uuid4(),
             message_id=message.id,
             position=1,
             thought="I need to analyze the user's request",
@@ -759,7 +756,6 @@ class TestAgentService:

         # Create agent thought with multiple tools
         complex_thought = MessageAgentThought(
-            id=fake.uuid4(),
             message_id=message.id,
             position=1,
             thought="I need to use multiple tools to complete this task",
@@ -877,7 +873,6 @@ class TestAgentService:

         # Create agent thought with files
         thought_with_files = MessageAgentThought(
-            id=fake.uuid4(),
             message_id=message.id,
             position=1,
             thought="I need to process some files",
@@ -957,7 +952,6 @@ class TestAgentService:

         # Create agent thought with empty tool data
         empty_thought = MessageAgentThought(
-            id=fake.uuid4(),
             message_id=message.id,
             position=1,
             thought="I need to analyze the user's request",
@@ -999,7 +993,6 @@ class TestAgentService:

         # Create agent thought with malformed JSON
         malformed_thought = MessageAgentThought(
-            id=fake.uuid4(),
             message_id=message.id,
             position=1,
             thought="I need to analyze the user's request",
diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py
index da73122cd7..4f5190e533 100644
--- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py
@@ -220,6 +220,23 @@ class TestAnnotationService:
         # Note: In this test, no annotation setting exists, so task should not be called
         mock_external_service_dependencies["add_task"].delay.assert_not_called()

+    def test_insert_app_annotation_directly_requires_question(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Question must be provided when inserting annotations directly.
+        """
+        fake = Faker()
+        app, _ = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
+
+        annotation_args = {
+            "question": None,
+            "answer": fake.text(max_nb_chars=200),
+        }
+
+        with pytest.raises(ValueError):
+            AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id)
+
     def test_insert_app_annotation_directly_app_not_found(
         self, db_session_with_containers, mock_external_service_dependencies
     ):
@@ -444,6 +461,78 @@ class TestAnnotationService:
         assert total == 1
         assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content

+    def test_get_annotation_list_by_app_id_with_special_characters_in_keyword(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        r"""
+        Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention.
+
+        This test verifies:
+        - Special characters (%, _, \) in keyword are properly escaped
+        - Search treats special characters as literal characters, not wildcards
+        - SQL injection via LIKE wildcards is prevented
+        """
+        fake = Faker()
+        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
+
+        # Create annotations with special characters in content
+        annotation_with_percent = {
+            "question": "Question with 50% discount",
+            "answer": "Answer about 50% discount offer",
+        }
+        AppAnnotationService.insert_app_annotation_directly(annotation_with_percent, app.id)
+
+        annotation_with_underscore = {
+            "question": "Question with test_data",
+            "answer": "Answer about test_data value",
+        }
+        AppAnnotationService.insert_app_annotation_directly(annotation_with_underscore, app.id)
+
+        annotation_with_backslash = {
+            "question": "Question with path\\to\\file",
+            "answer": "Answer about path\\to\\file location",
+        }
+        AppAnnotationService.insert_app_annotation_directly(annotation_with_backslash, app.id)
+
+        # Create annotation that should NOT match (contains % but as part of different text)
+        annotation_no_match = {
+            "question": "Question with 100% different",
+            "answer": "Answer about 100% different content",
+        }
+        AppAnnotationService.insert_app_annotation_directly(annotation_no_match, app.id)
+
+        # Test 1: Search with % character - should find exact match only
+        annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
+            app.id, page=1, limit=10, keyword="50%"
+        )
+        assert total == 1
+        assert len(annotation_list) == 1
+        assert "50%" in annotation_list[0].question or "50%" in annotation_list[0].content
+
+        # Test 2: Search with _ character - should find exact match only
+        annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
+            app.id, page=1, limit=10, keyword="test_data"
+        )
+        assert total == 1
+        assert len(annotation_list) == 1
+        assert "test_data" in annotation_list[0].question or "test_data" in annotation_list[0].content
+
+        # Test 3: Search with \ character - should find exact match only
+        annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
+            app.id, page=1, limit=10, keyword="path\\to\\file"
+        )
+        assert total == 1
+        assert len(annotation_list) == 1
+        assert "path\\to\\file" in annotation_list[0].question or "path\\to\\file" in annotation_list[0].content
+
+        # Test 4: Search with % should NOT match 100% (verifies escaping works)
+        annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
+            app.id, page=1, limit=10, keyword="50%"
+        )
+        # Should only find the 50% annotation, not the 100% one
+        assert total == 1
+        assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list)
+
     def test_get_annotation_list_by_app_id_app_not_found(
         self, db_session_with_containers, mock_external_service_dependencies
     ):
diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py
index 119f92d772..e2a450b90c 100644
--- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py
@@ -226,26 +226,27 @@ class TestAppDslService:
         app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

         # Create model config for the app
-        model_config = AppModelConfig()
-        model_config.id = fake.uuid4()
-        model_config.app_id = app.id
-        model_config.provider = "openai"
-        model_config.model_id = "gpt-3.5-turbo"
-        model_config.model = json.dumps(
-            {
-                "provider": "openai",
-                "name": "gpt-3.5-turbo",
-                "mode": "chat",
-                "completion_params": {
-                    "max_tokens": 1000,
-                    "temperature": 0.7,
-                },
-            }
+        model_config = AppModelConfig(
+            app_id=app.id,
+            provider="openai",
+            model_id="gpt-3.5-turbo",
+            model=json.dumps(
+                {
+                    "provider": "openai",
+                    "name": "gpt-3.5-turbo",
+                    "mode": "chat",
+                    "completion_params": {
+                        "max_tokens": 1000,
+                        "temperature": 0.7,
+                    },
+                }
+            ),
+            pre_prompt="You are a helpful assistant.",
+            prompt_type="simple",
+            created_by=account.id,
+            updated_by=account.id,
         )
-        model_config.pre_prompt = "You are a helpful assistant."
-        model_config.prompt_type = "simple"
-        model_config.created_by = account.id
-        model_config.updated_by = account.id
+        model_config.id = fake.uuid4()

         # Set the app_model_config_id to link the config
         app.app_model_config_id = model_config.id
diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py
index e53392bcef..745d6c97b0 100644
--- a/api/tests/test_containers_integration_tests/services/test_app_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_app_service.py
@@ -7,7 +7,9 @@ from constants.model_template import default_app_templates
 from models import Account
 from models.model import App, Site
 from services.account_service import AccountService, TenantService
-from services.app_service import AppService
+
+# Delay import of AppService to avoid circular dependency
+# from services.app_service import AppService


 class TestAppService:
@@ -71,6 +73,9 @@ class TestAppService:
         }

         # Create app
+        # Import here to avoid circular dependency
+        from services.app_service import AppService
+
         app_service = AppService()
         app = app_service.create_app(tenant.id, app_args, account)
@@ -109,6 +114,9 @@ class TestAppService:
         TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
         tenant = account.current_tenant

+        # Import here to avoid circular dependency
+        from services.app_service import AppService
+
         app_service = AppService()

         # Test different app modes
@@ -159,6 +167,9 @@ class TestAppService:
             "icon_background": "#45B7D1",
         }

+        # Import here to avoid circular dependency
+        from services.app_service import AppService
+
         app_service = AppService()
         created_app = app_service.create_app(tenant.id, app_args, account)
@@ -194,6 +205,9 @@ class TestAppService:
         TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
         tenant = account.current_tenant

+        # Import here to avoid circular dependency
+        from services.app_service import AppService
+
         app_service = AppService()

         # Create multiple apps
@@ -245,6 +259,9 @@ class TestAppService:
         TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
         tenant = account.current_tenant

+        # Import here to avoid circular dependency
+        from services.app_service import AppService
+
         app_service = AppService()

         # Create apps with different modes
@@ -315,6 +332,9 @@ class TestAppService:
         TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
         tenant = account.current_tenant

+        # Import here to avoid circular dependency
+        from services.app_service import AppService
+
         app_service = AppService()

         # Create an app
@@ -392,6 +412,9 @@ class TestAppService:
             "icon_background": "#45B7D1",
         }

+        # Import here to avoid circular dependency
+        from services.app_service import AppService
+
         app_service = AppService()
         app = app_service.create_app(tenant.id, app_args, account)
@@ -458,6 +481,9 @@ class TestAppService:
             "icon_background": "#45B7D1",
         }

+        # Import here to avoid circular dependency
+        from services.app_service import AppService
+
         app_service = AppService()
         app = app_service.create_app(tenant.id, app_args, account)
@@ -508,6 +534,9 @@ class TestAppService:
             "icon_background": "#45B7D1",
         }

+        # Import here to avoid circular dependency
+        from services.app_service import AppService
+
         app_service = AppService()
         app = app_service.create_app(tenant.id, app_args, account)
@@ -562,6 +591,9 @@ class TestAppService:
             "icon_background": "#74B9FF",
         }

+        # Import here to avoid circular dependency
+        from services.app_service import AppService
+
         app_service = AppService()
         app = app_service.create_app(tenant.id, app_args, account)
@@ -617,6 +649,9 @@ class TestAppService:
             "icon_background": "#A29BFE",
         }

+        # Import here to avoid circular dependency
+        from services.app_service import AppService
+
         app_service = AppService()
         app = app_service.create_app(tenant.id, app_args, account)
@@ -672,6 +707,9 @@ class TestAppService:
             "icon_background": "#FD79A8",
         }

+        #
Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -720,6 +758,9 @@ class TestAppService: "icon_background": "#E17055", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -768,6 +809,9 @@ class TestAppService: "icon_background": "#00B894", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -826,6 +870,9 @@ class TestAppService: "icon_background": "#6C5CE7", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -862,6 +909,9 @@ class TestAppService: "icon_background": "#FDCB6E", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -899,6 +949,9 @@ class TestAppService: "icon_background": "#E84393", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -947,8 +1000,132 @@ class TestAppService: "icon_background": "#D63031", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Attempt to create app with invalid mode with pytest.raises(ValueError, match="invalid mode value"): app_service.create_app(tenant.id, app_args, account) + + def test_get_apps_with_special_characters_in_name( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test app retrieval with special characters in name search to verify 
SQL injection prevention. + + This test verifies: + - Special characters (%, _, \) in name search are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Import here to avoid circular dependency + from services.app_service import AppService + + app_service = AppService() + + # Create apps with special characters in names + app_with_percent = app_service.create_app( + tenant.id, + { + "name": "App with 50% discount", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + app_with_underscore = app_service.create_app( + tenant.id, + { + "name": "test_data_app", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + app_with_backslash = app_service.create_app( + tenant.id, + { + "name": "path\\to\\app", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + # Create app that should NOT match + app_no_match = app_service.create_app( + tenant.id, + { + "name": "100% different", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + # Test 1: Search with % character + args = {"name": "50%", "mode": "chat", 
"page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "App with 50% discount" + + # Test 2: Search with _ character + args = {"name": "test_data", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "test_data_app" + + # Test 3: Search with \ character + args = {"name": "path\\to\\app", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "path\\to\\app" + + # Test 4: Search with % should NOT match 100% (verifies escaping works) + args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert all("50%" in app.name for app in paginated_apps.items) diff --git a/api/tests/test_containers_integration_tests/services/test_billing_service.py b/api/tests/test_containers_integration_tests/services/test_billing_service.py new file mode 100644 index 0000000000..76708b36b1 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_billing_service.py @@ -0,0 +1,365 @@ +import json +from unittest.mock import patch + +import pytest + +from extensions.ext_redis import redis_client +from services.billing_service import BillingService + + +class TestBillingServiceGetPlanBulkWithCache: + """ + Comprehensive integration tests for get_plan_bulk_with_cache using testcontainers. 
+ + This test class covers all major scenarios: + - Cache hit/miss scenarios + - Redis operation failures and fallback behavior + - Invalid cache data handling + - TTL expiration handling + - Error recovery and logging + """ + + @pytest.fixture(autouse=True) + def setup_redis_cleanup(self, flask_app_with_containers): + """Clean up Redis cache before and after each test.""" + with flask_app_with_containers.app_context(): + # Delete any stale test cache keys before the test runs + pattern = f"{BillingService._PLAN_CACHE_KEY_PREFIX}*" + keys = redis_client.keys(pattern) + if keys: + redis_client.delete(*keys) + yield + # Delete all test cache keys again after the test + keys = redis_client.keys(pattern) + if keys: + redis_client.delete(*keys) + + def _create_test_plan_data(self, plan: str = "sandbox", expiration_date: int = 1735689600): + """Helper to create test SubscriptionPlan data.""" + return {"plan": plan, "expiration_date": expiration_date} + + def _set_cache(self, tenant_id: str, plan_data: dict, ttl: int = 600): + """Helper to set cache data in Redis.""" + cache_key = BillingService._make_plan_cache_key(tenant_id) + json_str = json.dumps(plan_data) + redis_client.setex(cache_key, ttl, json_str) + + def _get_cache(self, tenant_id: str): + """Helper to get cache data from Redis.""" + cache_key = BillingService._make_plan_cache_key(tenant_id) + value = redis_client.get(cache_key) + if value: + if isinstance(value, bytes): + return value.decode("utf-8") + return value + return None + + def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers): + """Test bulk plan retrieval when all tenants are in cache.""" + with flask_app_with_containers.app_context(): + # Arrange + tenant_ids = ["tenant-1", "tenant-2", "tenant-3"] + expected_plans = { + "tenant-1": self._create_test_plan_data("sandbox", 1735689600), + "tenant-2": self._create_test_plan_data("professional", 1767225600), + "tenant-3": self._create_test_plan_data("team", 1798761600), + } + + # Pre-populate cache + for tenant_id, plan_data in expected_plans.items(): + self._set_cache(tenant_id,
plan_data) + + # Act + with patch.object(BillingService, "get_plan_bulk") as mock_get_plan_bulk: + result = BillingService.get_plan_bulk_with_cache(tenant_ids) + + # Assert + assert len(result) == 3 + assert result["tenant-1"]["plan"] == "sandbox" + assert result["tenant-1"]["expiration_date"] == 1735689600 + assert result["tenant-2"]["plan"] == "professional" + assert result["tenant-2"]["expiration_date"] == 1767225600 + assert result["tenant-3"]["plan"] == "team" + assert result["tenant-3"]["expiration_date"] == 1798761600 + + # Verify API was not called + mock_get_plan_bulk.assert_not_called() + + def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers): + """Test bulk plan retrieval when none of the tenants are in cache.""" + with flask_app_with_containers.app_context(): + # Arrange + tenant_ids = ["tenant-1", "tenant-2"] + expected_plans = { + "tenant-1": self._create_test_plan_data("sandbox", 1735689600), + "tenant-2": self._create_test_plan_data("professional", 1767225600), + } + + # Act + with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk: + result = BillingService.get_plan_bulk_with_cache(tenant_ids) + + # Assert + assert len(result) == 2 + assert result["tenant-1"]["plan"] == "sandbox" + assert result["tenant-2"]["plan"] == "professional" + + # Verify API was called with correct tenant_ids + mock_get_plan_bulk.assert_called_once_with(tenant_ids) + + # Verify data was written to cache + cached_1 = self._get_cache("tenant-1") + cached_2 = self._get_cache("tenant-2") + assert cached_1 is not None + assert cached_2 is not None + + # Verify cache content + cached_data_1 = json.loads(cached_1) + cached_data_2 = json.loads(cached_2) + assert cached_data_1 == expected_plans["tenant-1"] + assert cached_data_2 == expected_plans["tenant-2"] + + # Verify TTL is set + cache_key_1 = BillingService._make_plan_cache_key("tenant-1") + ttl_1 = redis_client.ttl(cache_key_1) + assert ttl_1 > 0 + assert
ttl_1 <= 600 # Should be <= 600 seconds + + def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers): + """Test bulk plan retrieval when some tenants are in cache, some are not.""" + with flask_app_with_containers.app_context(): + # Arrange + tenant_ids = ["tenant-1", "tenant-2", "tenant-3"] + # Pre-populate cache for tenant-1 and tenant-2 + self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600)) + self._set_cache("tenant-2", self._create_test_plan_data("professional", 1767225600)) + + # tenant-3 is not in cache + missing_plan = {"tenant-3": self._create_test_plan_data("team", 1798761600)} + + # Act + with patch.object(BillingService, "get_plan_bulk", return_value=missing_plan) as mock_get_plan_bulk: + result = BillingService.get_plan_bulk_with_cache(tenant_ids) + + # Assert + assert len(result) == 3 + assert result["tenant-1"]["plan"] == "sandbox" + assert result["tenant-2"]["plan"] == "professional" + assert result["tenant-3"]["plan"] == "team" + + # Verify API was called only for missing tenant + mock_get_plan_bulk.assert_called_once_with(["tenant-3"]) + + # Verify tenant-3 data was written to cache + cached_3 = self._get_cache("tenant-3") + assert cached_3 is not None + cached_data_3 = json.loads(cached_3) + assert cached_data_3 == missing_plan["tenant-3"] + + def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers): + """Test fallback to API when Redis mget fails.""" + with flask_app_with_containers.app_context(): + # Arrange + tenant_ids = ["tenant-1", "tenant-2"] + expected_plans = { + "tenant-1": self._create_test_plan_data("sandbox", 1735689600), + "tenant-2": self._create_test_plan_data("professional", 1767225600), + } + + # Act + with ( + patch.object(redis_client, "mget", side_effect=Exception("Redis connection error")), + patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk, + ): + result = 
BillingService.get_plan_bulk_with_cache(tenant_ids) + + # Assert + assert len(result) == 2 + assert result["tenant-1"]["plan"] == "sandbox" + assert result["tenant-2"]["plan"] == "professional" + + # Verify API was called for all tenants (fallback) + mock_get_plan_bulk.assert_called_once_with(tenant_ids) + + # Verify data was written to cache after fallback + cached_1 = self._get_cache("tenant-1") + cached_2 = self._get_cache("tenant-2") + assert cached_1 is not None + assert cached_2 is not None + + def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers): + """Test fallback to API when cache contains invalid JSON.""" + with flask_app_with_containers.app_context(): + # Arrange + tenant_ids = ["tenant-1", "tenant-2", "tenant-3"] + + # Set valid cache for tenant-1 + self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600)) + + # Set invalid JSON for tenant-2 + cache_key_2 = BillingService._make_plan_cache_key("tenant-2") + redis_client.setex(cache_key_2, 600, "invalid json {") + + # tenant-3 is not in cache + expected_plans = { + "tenant-2": self._create_test_plan_data("professional", 1767225600), + "tenant-3": self._create_test_plan_data("team", 1798761600), + } + + # Act + with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk: + result = BillingService.get_plan_bulk_with_cache(tenant_ids) + + # Assert + assert len(result) == 3 + assert result["tenant-1"]["plan"] == "sandbox" # From cache + assert result["tenant-2"]["plan"] == "professional" # From API (fallback) + assert result["tenant-3"]["plan"] == "team" # From API + + # Verify API was called for tenant-2 and tenant-3 + mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"]) + + # Verify tenant-2's invalid JSON was replaced with correct data in cache + cached_2 = self._get_cache("tenant-2") + assert cached_2 is not None + cached_data_2 = json.loads(cached_2) + assert cached_data_2 == 
expected_plans["tenant-2"] + assert cached_data_2["plan"] == "professional" + assert cached_data_2["expiration_date"] == 1767225600 + + # Verify tenant-2 cache has correct TTL + cache_key_2_new = BillingService._make_plan_cache_key("tenant-2") + ttl_2 = redis_client.ttl(cache_key_2_new) + assert ttl_2 > 0 + assert ttl_2 <= 600 + + # Verify tenant-3 data was also written to cache + cached_3 = self._get_cache("tenant-3") + assert cached_3 is not None + cached_data_3 = json.loads(cached_3) + assert cached_data_3 == expected_plans["tenant-3"] + + def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers): + """Test fallback to API when cache data doesn't match SubscriptionPlan schema.""" + with flask_app_with_containers.app_context(): + # Arrange + tenant_ids = ["tenant-1", "tenant-2", "tenant-3"] + + # Set valid cache for tenant-1 + self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600)) + + # Set invalid plan data for tenant-2 (missing expiration_date) + cache_key_2 = BillingService._make_plan_cache_key("tenant-2") + invalid_data = json.dumps({"plan": "professional"}) # Missing expiration_date + redis_client.setex(cache_key_2, 600, invalid_data) + + # tenant-3 is not in cache + expected_plans = { + "tenant-2": self._create_test_plan_data("professional", 1767225600), + "tenant-3": self._create_test_plan_data("team", 1798761600), + } + + # Act + with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk: + result = BillingService.get_plan_bulk_with_cache(tenant_ids) + + # Assert + assert len(result) == 3 + assert result["tenant-1"]["plan"] == "sandbox" # From cache + assert result["tenant-2"]["plan"] == "professional" # From API (fallback) + assert result["tenant-3"]["plan"] == "team" # From API + + # Verify API was called for tenant-2 and tenant-3 + mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"]) + + def 
test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers): + """Test that a pipeline failure doesn't affect the return value.""" + with flask_app_with_containers.app_context(): + # Arrange + tenant_ids = ["tenant-1", "tenant-2"] + expected_plans = { + "tenant-1": self._create_test_plan_data("sandbox", 1735689600), + "tenant-2": self._create_test_plan_data("professional", 1767225600), + } + + # Act + with ( + patch.object(BillingService, "get_plan_bulk", return_value=expected_plans), + patch.object(redis_client, "pipeline") as mock_pipeline, + ): + # Create a mock pipeline that fails on execute + mock_pipe = mock_pipeline.return_value + mock_pipe.execute.side_effect = Exception("Pipeline execution failed") + + result = BillingService.get_plan_bulk_with_cache(tenant_ids) + + # Assert - the function should still return the correct result despite the pipeline failure + assert len(result) == 2 + assert result["tenant-1"]["plan"] == "sandbox" + assert result["tenant-2"]["plan"] == "professional" + + # Verify pipeline was attempted + mock_pipeline.assert_called_once() + + def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers): + """Test that an empty tenant_ids list returns an empty dict without any API call.""" + with flask_app_with_containers.app_context(): + # Act + with patch.object(BillingService, "get_plan_bulk") as mock_get_plan_bulk: + result = BillingService.get_plan_bulk_with_cache([]) + + # Assert + assert result == {} + + # Verify no API calls + mock_get_plan_bulk.assert_not_called() + + # No Redis assertions are needed here: with an empty tenant list there is nothing + # to look up, so checking the empty result and the absence of API calls is sufficient + + def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers): + """Test that expired cache keys are treated as cache misses.""" + with flask_app_with_containers.app_context(): + # Arrange + tenant_ids = ["tenant-1",
"tenant-2"] + + # Set cache for tenant-1 with very short TTL (1 second) to simulate expiration + self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600), ttl=1) + + # Wait for TTL to expire (key will be deleted by Redis) + import time + + time.sleep(2) + + # Verify cache is expired (key doesn't exist) + cache_key_1 = BillingService._make_plan_cache_key("tenant-1") + exists = redis_client.exists(cache_key_1) + assert exists == 0 # Key doesn't exist (expired) + + # tenant-2 is not in cache + expected_plans = { + "tenant-1": self._create_test_plan_data("sandbox", 1735689600), + "tenant-2": self._create_test_plan_data("professional", 1767225600), + } + + # Act + with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk: + result = BillingService.get_plan_bulk_with_cache(tenant_ids) + + # Assert + assert len(result) == 2 + assert result["tenant-1"]["plan"] == "sandbox" + assert result["tenant-2"]["plan"] == "professional" + + # Verify API was called for both tenants (tenant-1 expired, tenant-2 missing) + mock_get_plan_bulk.assert_called_once_with(tenant_ids) + + # Verify both were written to cache with correct TTL + cache_key_1_new = BillingService._make_plan_cache_key("tenant-1") + cache_key_2 = BillingService._make_plan_cache_key("tenant-2") + ttl_1_new = redis_client.ttl(cache_key_1_new) + ttl_2 = redis_client.ttl(cache_key_2) + assert ttl_1_new > 0 + assert ttl_1_new <= 600 + assert ttl_2 > 0 + assert ttl_2 <= 600 diff --git a/api/tests/test_containers_integration_tests/services/test_feature_service.py b/api/tests/test_containers_integration_tests/services/test_feature_service.py index 40380b09d2..bd2fd14ffa 100644 --- a/api/tests/test_containers_integration_tests/services/test_feature_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feature_service.py @@ -4,7 +4,13 @@ import pytest from faker import Faker from enums.cloud_plan import CloudPlan -from services.feature_service 
import FeatureModel, FeatureService, KnowledgeRateLimitModel, SystemFeatureModel +from services.feature_service import ( + FeatureModel, + FeatureService, + KnowledgeRateLimitModel, + LicenseStatus, + SystemFeatureModel, +) class TestFeatureService: @@ -274,7 +280,7 @@ class TestFeatureService: mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 # Act: Execute the method under test - result = FeatureService.get_system_features() + result = FeatureService.get_system_features(is_authenticated=True) # Assert: Verify the expected outcomes assert result is not None @@ -324,6 +330,61 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + def test_get_system_features_unauthenticated(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test system features retrieval for an unauthenticated user. + + This test verifies that: + - The response payload is minimized (e.g., verbose license details are excluded). + - Essential UI configuration (Branding, SSO, Marketplace) remains available. + - The response structure adheres to the public schema for unauthenticated clients. + """ + # Arrange: Setup test data with exact same config as success test + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = True + mock_config.ENABLE_EMAIL_CODE_LOGIN = True + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + # Act: Execute with is_authenticated=False + result = FeatureService.get_system_features(is_authenticated=False) + + # Assert: Basic structure + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # --- 1. 
Verify Response Payload Optimization (Data Minimization) --- + # Ensure only essential UI flags are returned to unauthenticated clients + # to keep the payload lightweight and adhere to architectural boundaries. + assert result.license.status == LicenseStatus.NONE + assert result.license.expired_at == "" + assert result.license.workspaces.enabled is False + assert result.license.workspaces.limit == 0 + assert result.license.workspaces.size == 0 + + # --- 2. Verify Public UI Configuration Availability --- + # Ensure that data required for frontend rendering remains accessible. + + # Branding should match the mock data + assert result.branding.enabled is True + assert result.branding.application_title == "Test Enterprise" + assert result.branding.login_page_logo == "https://example.com/logo.png" + + # SSO settings should be visible for login page rendering + assert result.sso_enforced_for_signin is True + assert result.sso_enforced_for_signin_protocol == "saml" + + # General auth settings should be visible + assert result.enable_email_code_login is True + + # Marketplace should be visible + assert result.enable_marketplace is True + def test_get_system_features_basic_config(self, db_session_with_containers, mock_external_service_dependencies): """ Test system features retrieval with basic configuration (no enterprise). 
@@ -1031,7 +1092,7 @@ class TestFeatureService: } # Act: Execute the method under test - result = FeatureService.get_system_features() + result = FeatureService.get_system_features(is_authenticated=True) # Assert: Verify the expected outcomes assert result is not None @@ -1400,7 +1461,7 @@ class TestFeatureService: } # Act: Execute the method under test - result = FeatureService.get_system_features() + result = FeatureService.get_system_features(is_authenticated=True) # Assert: Verify the expected outcomes assert result is not None diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py new file mode 100644 index 0000000000..5b6db64c09 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -0,0 +1,1071 @@ +import datetime +import json +import uuid +from decimal import Decimal +from unittest.mock import patch + +import pytest +from faker import Faker + +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.model import ( + App, + AppAnnotationHitHistory, + Conversation, + DatasetRetrieverResource, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.web import SavedMessage +from services.billing_service import BillingService +from services.retention.conversation.messages_clean_policy import ( + BillingDisabledPolicy, + BillingSandboxPolicy, + create_message_clean_policy, +) +from services.retention.conversation.messages_clean_service import MessagesCleanService + + +class TestMessagesCleanServiceIntegration: + """Integration tests for MessagesCleanService.run() and _clean_messages_by_time_range().""" + + # Redis cache key prefix from BillingService + 
PLAN_CACHE_KEY_PREFIX = BillingService._PLAN_CACHE_KEY_PREFIX # "tenant_plan:" + + @pytest.fixture(autouse=True) + def cleanup_database(self, db_session_with_containers): + """Clean up database after each test to ensure isolation.""" + yield + # Clear all test data in an order that respects foreign key constraints + db.session.query(DatasetRetrieverResource).delete() + db.session.query(AppAnnotationHitHistory).delete() + db.session.query(SavedMessage).delete() + db.session.query(MessageFile).delete() + db.session.query(MessageAgentThought).delete() + db.session.query(MessageChain).delete() + db.session.query(MessageAnnotation).delete() + db.session.query(MessageFeedback).delete() + db.session.query(Message).delete() + db.session.query(Conversation).delete() + db.session.query(App).delete() + db.session.query(TenantAccountJoin).delete() + db.session.query(Tenant).delete() + db.session.query(Account).delete() + db.session.commit() + + @pytest.fixture(autouse=True) + def cleanup_redis(self): + """Clean up Redis cache before and after each test.""" + # Clear tenant plan cache using BillingService key prefix + try: + keys = redis_client.keys(f"{self.PLAN_CACHE_KEY_PREFIX}*") + if keys: + redis_client.delete(*keys) + except Exception: + pass # Redis might not be available in some test environments + yield + # Clean up after test + try: + keys = redis_client.keys(f"{self.PLAN_CACHE_KEY_PREFIX}*") + if keys: + redis_client.delete(*keys) + except Exception: + pass + + @pytest.fixture + def mock_whitelist(self): + """Mock whitelist to return empty list by default.""" + with patch( + "services.retention.conversation.messages_clean_policy.BillingService.get_expired_subscription_cleanup_whitelist" + ) as mock: + mock.return_value = [] + yield mock + + @pytest.fixture + def mock_billing_enabled(self): + """Mock BILLING_ENABLED to be True.""" + with patch("services.retention.conversation.messages_clean_policy.dify_config.BILLING_ENABLED", True): + yield + + @pytest.fixture +
def mock_billing_disabled(self):
+        """Mock BILLING_ENABLED to be False."""
+        with patch("services.retention.conversation.messages_clean_policy.dify_config.BILLING_ENABLED", False):
+            yield
+
+    def _create_account_and_tenant(self, plan: str = CloudPlan.SANDBOX):
+        """Helper to create account and tenant."""
+        fake = Faker()
+
+        account = Account(
+            email=fake.email(),
+            name=fake.name(),
+            interface_language="en-US",
+            status="active",
+        )
+        db.session.add(account)
+        db.session.flush()
+
+        tenant = Tenant(
+            name=fake.company(),
+            plan=str(plan),
+            status="normal",
+        )
+        db.session.add(tenant)
+        db.session.flush()
+
+        tenant_account_join = TenantAccountJoin(
+            tenant_id=tenant.id,
+            account_id=account.id,
+            role=TenantAccountRole.OWNER,
+        )
+        db.session.add(tenant_account_join)
+        db.session.commit()
+
+        return account, tenant
+
+    def _create_app(self, tenant, account):
+        """Helper to create an app."""
+        fake = Faker()
+
+        app = App(
+            tenant_id=tenant.id,
+            name=fake.company(),
+            description="Test app",
+            mode="chat",
+            enable_site=True,
+            enable_api=True,
+            api_rpm=60,
+            api_rph=3600,
+            is_demo=False,
+            is_public=False,
+            created_by=account.id,
+            updated_by=account.id,
+        )
+        db.session.add(app)
+        db.session.commit()
+
+        return app
+
+    def _create_conversation(self, app):
+        """Helper to create a conversation."""
+        conversation = Conversation(
+            app_id=app.id,
+            app_model_config_id=str(uuid.uuid4()),
+            model_provider="openai",
+            model_id="gpt-3.5-turbo",
+            mode="chat",
+            name="Test conversation",
+            inputs={},
+            status="normal",
+            from_source="api",
+            from_end_user_id=str(uuid.uuid4()),
+        )
+        db.session.add(conversation)
+        db.session.commit()
+
+        return conversation
+
+    def _create_message(self, app, conversation, created_at=None, with_relations=True):
+        """Helper to create a message with optional related records."""
+        if created_at is None:
+            created_at = datetime.datetime.now()
+
+        message = Message(
+            app_id=app.id,
+            conversation_id=conversation.id,
+            model_provider="openai",
+            model_id="gpt-3.5-turbo",
+            inputs={},
+            query="Test query",
+            answer="Test answer",
+            message=[{"role": "user", "text": "Test message"}],
+            message_tokens=10,
+            message_unit_price=Decimal("0.001"),
+            answer_tokens=20,
+            answer_unit_price=Decimal("0.002"),
+            total_price=Decimal("0.003"),
+            currency="USD",
+            from_source="api",
+            from_account_id=conversation.from_end_user_id,
+            created_at=created_at,
+        )
+        db.session.add(message)
+        db.session.flush()
+
+        if with_relations:
+            self._create_message_relations(message)
+
+        db.session.commit()
+        return message
+
+    def _create_message_relations(self, message):
+        """Helper to create all message-related records."""
+        # MessageFeedback
+        feedback = MessageFeedback(
+            app_id=message.app_id,
+            conversation_id=message.conversation_id,
+            message_id=message.id,
+            rating="like",
+            from_source="api",
+            from_end_user_id=str(uuid.uuid4()),
+        )
+        db.session.add(feedback)
+
+        # MessageAnnotation
+        annotation = MessageAnnotation(
+            app_id=message.app_id,
+            conversation_id=message.conversation_id,
+            message_id=message.id,
+            question="Test question",
+            content="Test annotation",
+            account_id=message.from_account_id,
+        )
+        db.session.add(annotation)
+
+        # MessageChain
+        chain = MessageChain(
+            message_id=message.id,
+            type="system",
+            input=json.dumps({"test": "input"}),
+            output=json.dumps({"test": "output"}),
+        )
+        db.session.add(chain)
+        db.session.flush()
+
+        # MessageFile
+        file = MessageFile(
+            message_id=message.id,
+            type="image",
+            transfer_method="local_file",
+            url="http://example.com/test.jpg",
+            belongs_to="user",
+            created_by_role="end_user",
+            created_by=str(uuid.uuid4()),
+        )
+        db.session.add(file)
+
+        # SavedMessage
+        saved = SavedMessage(
+            app_id=message.app_id,
+            message_id=message.id,
+            created_by_role="end_user",
+            created_by=str(uuid.uuid4()),
+        )
+        db.session.add(saved)
+
+        db.session.flush()
+
+        # AppAnnotationHitHistory
+        hit = AppAnnotationHitHistory(
+            app_id=message.app_id,
+            annotation_id=annotation.id,
+            message_id=message.id,
+            source="annotation",
+            question="Test question",
+            account_id=message.from_account_id,
+            score=0.9,
+            annotation_question="Test annotation question",
+            annotation_content="Test annotation content",
+        )
+        db.session.add(hit)
+
+        # DatasetRetrieverResource
+        resource = DatasetRetrieverResource(
+            message_id=message.id,
+            position=1,
+            dataset_id=str(uuid.uuid4()),
+            dataset_name="Test dataset",
+            document_id=str(uuid.uuid4()),
+            document_name="Test document",
+            data_source_type="upload_file",
+            segment_id=str(uuid.uuid4()),
+            score=0.9,
+            content="Test content",
+            hit_count=1,
+            word_count=10,
+            segment_position=1,
+            index_node_hash="test_hash",
+            retriever_from="dataset",
+            created_by=message.from_account_id,
+        )
+        db.session.add(resource)
+
+    def test_billing_disabled_deletes_all_messages_in_time_range(
+        self, db_session_with_containers, mock_billing_disabled
+    ):
+        """Test that BillingDisabledPolicy deletes all messages within time range regardless of tenant plan."""
+        # Arrange - Create tenant with messages (plan doesn't matter for billing disabled)
+        account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
+        app = self._create_app(tenant, account)
+        conv = self._create_conversation(app)
+
+        # Create messages: in-range (should be deleted) and out-of-range (should be kept)
+        in_range_date = datetime.datetime(2024, 1, 15, 12, 0, 0)
+        out_of_range_date = datetime.datetime(2024, 1, 25, 12, 0, 0)
+
+        in_range_msg = self._create_message(app, conv, created_at=in_range_date, with_relations=True)
+        in_range_msg_id = in_range_msg.id
+
+        out_of_range_msg = self._create_message(app, conv, created_at=out_of_range_date, with_relations=True)
+        out_of_range_msg_id = out_of_range_msg.id
+
+        # Act - create_message_clean_policy should return BillingDisabledPolicy
+        policy = create_message_clean_policy()
+
+        assert isinstance(policy, BillingDisabledPolicy)
+
+        service = MessagesCleanService.from_time_range(
+            policy=policy,
+            start_from=datetime.datetime(2024, 1, 10, 0, 0, 0),
+            end_before=datetime.datetime(2024, 1, 20, 0, 0, 0),
+            batch_size=100,
+        )
+        stats = service.run()
+
+        # Assert
+        assert stats["total_messages"] == 1  # Only in-range message fetched
+        assert stats["filtered_messages"] == 1
+        assert stats["total_deleted"] == 1
+
+        # In-range message deleted
+        assert db.session.query(Message).where(Message.id == in_range_msg_id).count() == 0
+        # Out-of-range message kept
+        assert db.session.query(Message).where(Message.id == out_of_range_msg_id).count() == 1
+
+        # Related records of in-range message deleted
+        assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == in_range_msg_id).count() == 0
+        assert db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == in_range_msg_id).count() == 0
+        # Related records of out-of-range message kept
+        assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == out_of_range_msg_id).count() == 1
+
+    def test_no_messages_returns_empty_stats(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
+        """Test cleaning when there are no messages to delete (B1)."""
+        # Arrange
+        end_before = datetime.datetime.now() - datetime.timedelta(days=30)
+        start_from = datetime.datetime.now() - datetime.timedelta(days=60)
+
+        with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
+            mock_billing.return_value = {}
+
+            # Act
+            policy = create_message_clean_policy(graceful_period_days=21)
+            service = MessagesCleanService.from_time_range(
+                policy=policy,
+                start_from=start_from,
+                end_before=end_before,
+                batch_size=100,
+            )
+            stats = service.run()
+
+        # Assert - loop runs once to check, finds nothing
+        assert stats["batches"] == 1
+        assert stats["total_messages"] == 0
+        assert stats["filtered_messages"] == 0
+        assert stats["total_deleted"] == 0
+
+    def test_mixed_sandbox_and_paid_tenants(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
+        """Test cleaning with mixed sandbox and paid tenants (B2)."""
+        # Arrange - Create sandbox tenants with expired messages
+        sandbox_tenants = []
+        sandbox_message_ids = []
+        for i in range(2):
+            account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
+            sandbox_tenants.append(tenant)
+            app = self._create_app(tenant, account)
+            conv = self._create_conversation(app)
+
+            # Create 3 expired messages per sandbox tenant
+            expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
+            for j in range(3):
+                msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j))
+                sandbox_message_ids.append(msg.id)
+
+        # Create paid tenants with expired messages (should NOT be deleted)
+        paid_tenants = []
+        paid_message_ids = []
+        for i in range(2):
+            account, tenant = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL)
+            paid_tenants.append(tenant)
+            app = self._create_app(tenant, account)
+            conv = self._create_conversation(app)
+
+            # Create 2 expired messages per paid tenant
+            expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
+            for j in range(2):
+                msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j))
+                paid_message_ids.append(msg.id)
+
+        # Mock billing service - return plan and expiration_date
+        now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
+        expired_15_days_ago = now_timestamp - (15 * 24 * 60 * 60)  # Beyond 7-day grace period
+
+        plan_map = {}
+        for tenant in sandbox_tenants:
+            plan_map[tenant.id] = {
+                "plan": CloudPlan.SANDBOX,
+                "expiration_date": expired_15_days_ago,
+            }
+        for tenant in paid_tenants:
+            plan_map[tenant.id] = {
+                "plan": CloudPlan.PROFESSIONAL,
+                "expiration_date": now_timestamp + (365 * 24 * 60 * 60),  # Active for 1 year
+            }
+
+        with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
+            mock_billing.return_value = plan_map
+
+            # Act
+            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
+            policy = create_message_clean_policy(graceful_period_days=7)
+
+            assert isinstance(policy, BillingSandboxPolicy)
+
+            service = MessagesCleanService.from_time_range(
+                policy=policy,
+                start_from=datetime.datetime.now() - datetime.timedelta(days=60),
+                end_before=end_before,
+                batch_size=100,
+            )
+            stats = service.run()
+
+        # Assert
+        assert stats["total_messages"] == 10  # 2 sandbox * 3 + 2 paid * 2
+        assert stats["filtered_messages"] == 6  # 2 sandbox tenants * 3 messages
+        assert stats["total_deleted"] == 6
+
+        # Only sandbox messages should be deleted
+        assert db.session.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0
+        # Paid messages should remain
+        assert db.session.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4
+
+        # Related records of sandbox messages should be deleted
+        assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(sandbox_message_ids)).count() == 0
+        assert (
+            db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(sandbox_message_ids)).count()
+            == 0
+        )
+
+    def test_cursor_pagination_multiple_batches(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
+        """Test cursor pagination works correctly across multiple batches (B3)."""
+        # Arrange - Create sandbox tenant with messages that will span multiple batches
+        account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
+        app = self._create_app(tenant, account)
+        conv = self._create_conversation(app)
+
+        # Create 10 expired messages with different timestamps
+        base_date = datetime.datetime.now() - datetime.timedelta(days=35)
+        message_ids = []
+        for i in range(10):
+            msg = self._create_message(
+                app,
+                conv,
+                created_at=base_date + datetime.timedelta(hours=i),
+                with_relations=False,  # Skip relations for speed
+            )
+            message_ids.append(msg.id)
+
+        with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
+            mock_billing.return_value = {
+                tenant.id: {
+                    "plan": CloudPlan.SANDBOX,
+                    "expiration_date": -1,  # No previous subscription
+                }
+            }
+
+            # Act - Use small batch size to trigger multiple batches
+            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
+            policy = create_message_clean_policy(graceful_period_days=21)
+            service = MessagesCleanService.from_time_range(
+                policy=policy,
+                start_from=datetime.datetime.now() - datetime.timedelta(days=60),
+                end_before=end_before,
+                batch_size=3,  # Small batch size to test pagination
+            )
+            stats = service.run()
+
+        # 5 batches for 10 messages with batch_size=3, the last batch is empty
+        assert stats["batches"] == 5
+        assert stats["total_messages"] == 10
+        assert stats["filtered_messages"] == 10
+        assert stats["total_deleted"] == 10
+
+        # All messages should be deleted
+        assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 0
+
+    def test_dry_run_does_not_delete(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
+        """Test dry_run mode does not delete messages (B4)."""
+        # Arrange
+        account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
+        app = self._create_app(tenant, account)
+        conv = self._create_conversation(app)
+
+        # Create expired messages
+        expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
+        message_ids = []
+        for i in range(3):
+            msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i))
+            message_ids.append(msg.id)
+
+        with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
+            mock_billing.return_value = {
+                tenant.id: {
+                    "plan": CloudPlan.SANDBOX,
+                    "expiration_date": -1,  # No previous subscription
+                }
+            }
+
+            # Act
+            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
+            policy = create_message_clean_policy(graceful_period_days=21)
+            service = MessagesCleanService.from_time_range(
+                policy=policy,
+                start_from=datetime.datetime.now() - datetime.timedelta(days=60),
+                end_before=end_before,
+                batch_size=100,
+                dry_run=True,  # Dry run mode
+            )
+            stats = service.run()
+
+        # Assert
+        assert stats["total_messages"] == 3
+        assert stats["filtered_messages"] == 3  # Messages identified
+        assert stats["total_deleted"] == 0  # But NOT deleted
+
+        # All messages should still exist
+        assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 3
+        # Related records should also still exist
+        assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() == 3
+
+    def test_partial_plan_data_safe_default(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
+        """Test when billing returns partial data, unknown tenants are preserved (B5)."""
+        # Arrange - Create 3 tenants
+        tenants_data = []
+        for i in range(3):
+            account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
+            app = self._create_app(tenant, account)
+            conv = self._create_conversation(app)
+
+            expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
+            msg = self._create_message(app, conv, created_at=expired_date)
+
+            tenants_data.append(
+                {
+                    "tenant": tenant,
+                    "message_id": msg.id,
+                }
+            )
+
+        # Mock billing service to return partial data
+        now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
+
+        # Only tenant[0] is confirmed as sandbox, tenant[1] is professional, tenant[2] is missing
+        partial_plan_map = {
+            tenants_data[0]["tenant"].id: {
+                "plan": CloudPlan.SANDBOX,
+                "expiration_date": -1,  # No previous subscription
+            },
+            tenants_data[1]["tenant"].id: {
+                "plan": CloudPlan.PROFESSIONAL,
+                "expiration_date": now_timestamp + (365 * 24 * 60 * 60),  # Active for 1 year
+            },
+            # tenants_data[2] is missing from response
+        }
+
+        with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
+            mock_billing.return_value = partial_plan_map
+
+            # Act
+            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
+            policy = create_message_clean_policy(graceful_period_days=21)
+            service = MessagesCleanService.from_time_range(
+                policy=policy,
+                start_from=datetime.datetime.now() - datetime.timedelta(days=60),
+                end_before=end_before,
+                batch_size=100,
+            )
+            stats = service.run()
+
+        # Assert - Only tenant[0]'s message should be deleted
+        assert stats["total_messages"] == 3  # 3 tenants * 1 message
+        assert stats["filtered_messages"] == 1
+        assert stats["total_deleted"] == 1
+
+        # Check which messages were deleted
+        assert (
+            db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0
+        )  # Sandbox tenant's message deleted
+
+        assert (
+            db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1
+        )  # Professional tenant's message preserved
+
+        assert (
+            db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1
+        )  # Unknown tenant's message preserved (safe default)
+
+    def test_empty_plan_data_skips_deletion(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
+        """Test when billing returns empty data, skip deletion entirely (B6)."""
+        # Arrange
+        account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
+        app = self._create_app(tenant, account)
+        conv = self._create_conversation(app)
+
+        expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
+        msg = self._create_message(app, conv, created_at=expired_date)
+        msg_id = msg.id
+        db.session.commit()
+
+        # Mock billing service to return empty data (simulating failure/no data scenario)
+        with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
+            mock_billing.return_value = {}  # Empty response, tenant plan unknown
+
+            # Act - Should not raise exception, just skip deletion
+            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
+            policy = create_message_clean_policy(graceful_period_days=21)
+            service = MessagesCleanService.from_time_range(
+                policy=policy,
+                start_from=datetime.datetime.now() - datetime.timedelta(days=60),
+                end_before=end_before,
+                batch_size=100,
+            )
+            stats = service.run()
+
+        # Assert - No messages should be deleted when plan is unknown
+        assert stats["total_messages"] == 1
+        assert stats["filtered_messages"] == 0  # Cannot determine sandbox messages
+        assert stats["total_deleted"] == 0
+
+        # Message should still exist (safe default - don't delete if plan is unknown)
+        assert db.session.query(Message).where(Message.id == msg_id).count() == 1
+
+    def test_time_range_boundary_behavior(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
+        """Test that messages are correctly filtered by [start_from, end_before) time range (B7)."""
+        # Arrange
+        account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
+        app = self._create_app(tenant, account)
+        conv = self._create_conversation(app)
+
+        # Create messages: before range, in range, after range
+        msg_before = self._create_message(
+            app,
+            conv,
+            created_at=datetime.datetime(2024, 1, 1, 12, 0, 0),  # Before start_from
+            with_relations=False,
+        )
+        msg_before_id = msg_before.id
+
+        msg_at_start = self._create_message(
+            app,
+            conv,
+            created_at=datetime.datetime(2024, 1, 10, 12, 0, 0),  # At start_from (inclusive)
+            with_relations=False,
+        )
+        msg_at_start_id = msg_at_start.id
+
+        msg_in_range = self._create_message(
+            app,
+            conv,
+            created_at=datetime.datetime(2024, 1, 15, 12, 0, 0),  # In range
+            with_relations=False,
+        )
+        msg_in_range_id = msg_in_range.id
+
+        msg_at_end = self._create_message(
+            app,
+            conv,
+            created_at=datetime.datetime(2024, 1, 20, 12, 0, 0),  # At end_before (exclusive)
+            with_relations=False,
+        )
+        msg_at_end_id = msg_at_end.id
+
+        msg_after = self._create_message(
+            app,
+            conv,
+            created_at=datetime.datetime(2024, 1, 25, 12, 0, 0),  # After end_before
+            with_relations=False,
+        )
+        msg_after_id = msg_after.id
+
+        db.session.commit()
+
+        # Mock billing service
+        with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
+            mock_billing.return_value = {
+                tenant.id: {
+                    "plan": CloudPlan.SANDBOX,
+                    "expiration_date": -1,  # No previous subscription
+                }
+            }
+
+            # Act - Clean with specific time range [2024-01-10, 2024-01-20)
+            policy = create_message_clean_policy(graceful_period_days=21)
+            service = MessagesCleanService.from_time_range(
+                policy=policy,
+                start_from=datetime.datetime(2024, 1, 10, 12, 0, 0),
+                end_before=datetime.datetime(2024, 1, 20, 12, 0, 0),
+                batch_size=100,
+            )
+            stats = service.run()
+
+        # Assert - Only messages in [start_from, end_before) should be deleted
+        assert stats["total_messages"] == 2  # Only in-range messages fetched
+        assert stats["filtered_messages"] == 2  # msg_at_start and msg_in_range
+        assert stats["total_deleted"] == 2
+
+        # Verify specific messages using stored IDs
+        # Before range, kept
+        assert db.session.query(Message).where(Message.id == msg_before_id).count() == 1
+        # At start (inclusive), deleted
+        assert db.session.query(Message).where(Message.id == msg_at_start_id).count() == 0
+        # In range, deleted
+        assert db.session.query(Message).where(Message.id == msg_in_range_id).count() == 0
+        # At end (exclusive), kept
+        assert db.session.query(Message).where(Message.id == msg_at_end_id).count() == 1
+        # After range, kept
+        assert db.session.query(Message).where(Message.id == msg_after_id).count() == 1
+
+    def test_grace_period_scenarios(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
+        """Test cleaning with different graceful period scenarios (B8)."""
+        # Arrange - Create 5 different tenants with different plan and expiration scenarios
+        now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
+        graceful_period = 8  # Use 8 days for this test
+
+        # Scenario 1: Sandbox plan with expiration within graceful period (5 days ago)
+        # Should NOT be deleted
+        account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
+        app1 = self._create_app(tenant1, account1)
+        conv1 = self._create_conversation(app1)
+        expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
+        msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False)
+        msg1_id = msg1.id
+        expired_5_days_ago = now_timestamp - (5 * 24 * 60 * 60)  # Within grace period
+
+        # Scenario 2: Sandbox plan with expiration beyond graceful period (10 days ago)
+        # Should be deleted
+        account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
+        app2 = self._create_app(tenant2, account2)
+        conv2 = self._create_conversation(app2)
+        msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False)
+        msg2_id = msg2.id
+        expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60)  # Beyond grace period
+
+        # Scenario 3: Sandbox plan with expiration_date = -1 (no previous subscription)
+        # Should be deleted
+        account3, tenant3 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
+        app3 = self._create_app(tenant3, account3)
+        conv3 = self._create_conversation(app3)
+        msg3 = self._create_message(app3, conv3, created_at=expired_date, with_relations=False)
+        msg3_id = msg3.id
+
+        # Scenario 4: Non-sandbox plan (professional) with no expiration (future date)
+        # Should NOT be deleted
+        account4, tenant4 = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL)
+        app4 = self._create_app(tenant4, account4)
+        conv4 = self._create_conversation(app4)
+        msg4 = self._create_message(app4, conv4, created_at=expired_date, with_relations=False)
+        msg4_id = msg4.id
+        future_expiration = now_timestamp + (365 * 24 * 60 * 60)  # Active for 1 year
+
+        # Scenario 5: Sandbox plan with expiration exactly at grace period boundary (8 days ago)
+        # Should NOT be deleted (boundary is exclusive: > graceful_period)
+        account5, tenant5 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
+        app5 = self._create_app(tenant5, account5)
+        conv5 = self._create_conversation(app5)
+        msg5 = self._create_message(app5, conv5, created_at=expired_date, with_relations=False)
+        msg5_id = msg5.id
+        expired_exactly_8_days_ago = now_timestamp - (8 * 24 * 60 * 60)  # Exactly at boundary
+
+        db.session.commit()
+
+        # Mock billing service with all scenarios
+        plan_map = {
+            tenant1.id: {
+                "plan": CloudPlan.SANDBOX,
+                "expiration_date": expired_5_days_ago,
+            },
+            tenant2.id: {
+                "plan": CloudPlan.SANDBOX,
+                "expiration_date": expired_10_days_ago,
+            },
+            tenant3.id: {
+                "plan": CloudPlan.SANDBOX,
+                "expiration_date": -1,
+            },
+            tenant4.id: {
+                "plan": CloudPlan.PROFESSIONAL,
+                "expiration_date": future_expiration,
+            },
+            tenant5.id: {
+                "plan": CloudPlan.SANDBOX,
+                "expiration_date": expired_exactly_8_days_ago,
+            },
+        }
+
+        with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
+            mock_billing.return_value = plan_map
+
+            # Act
+            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
+            policy = create_message_clean_policy(
+                graceful_period_days=graceful_period,
+                current_timestamp=now_timestamp,  # Use fixed timestamp for deterministic behavior
+            )
+            service = MessagesCleanService.from_time_range(
+                policy=policy,
+                start_from=datetime.datetime.now() - datetime.timedelta(days=60),
+                end_before=end_before,
+                batch_size=100,
+            )
+            stats = service.run()
+
+        # Assert - Only messages from scenario 2 and 3 should be deleted
+        assert stats["total_messages"] == 5  # 5 tenants * 1 message
+        assert stats["filtered_messages"] == 2
+        assert stats["total_deleted"] == 2
+
+        # Verify each scenario using saved IDs
+        assert db.session.query(Message).where(Message.id == msg1_id).count() == 1  # Within grace, kept
+        assert db.session.query(Message).where(Message.id == msg2_id).count() == 0  # Beyond grace, deleted
+        assert db.session.query(Message).where(Message.id == msg3_id).count() == 0  # No subscription, deleted
+        assert db.session.query(Message).where(Message.id == msg4_id).count() == 1  # Professional plan, kept
+        assert db.session.query(Message).where(Message.id == msg5_id).count() == 1  # At boundary, kept
+
+    def test_tenant_whitelist(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
+        """Test that whitelisted tenants' messages are not deleted (B9)."""
+        # Arrange - Create 3 sandbox tenants with expired messages
+        tenants_data = []
+        for i in range(3):
+            account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
+            app = self._create_app(tenant, account)
+            conv = self._create_conversation(app)
+
+            expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
+            msg = self._create_message(app, conv, created_at=expired_date, with_relations=False)
+
+            tenants_data.append(
+                {
+                    "tenant": tenant,
+                    "message_id": msg.id,
+                }
+            )
+
+        # Mock billing service - all tenants are sandbox with no subscription
+        plan_map = {
+            tenants_data[0]["tenant"].id: {
+                "plan": CloudPlan.SANDBOX,
+                "expiration_date": -1,
+            },
+            tenants_data[1]["tenant"].id: {
+                "plan": CloudPlan.SANDBOX,
+                "expiration_date": -1,
+            },
+            tenants_data[2]["tenant"].id: {
+                "plan": CloudPlan.SANDBOX,
+                "expiration_date": -1,
+            },
+        }
+
+        # Setup whitelist - tenant0 and tenant1 are whitelisted, tenant2 is not
+        whitelist = [tenants_data[0]["tenant"].id, tenants_data[1]["tenant"].id]
+        mock_whitelist.return_value = whitelist
+
+        with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
+            mock_billing.return_value = plan_map
+
+            # Act
+            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
+            policy = create_message_clean_policy(graceful_period_days=21)
+            service = MessagesCleanService.from_time_range(
+                policy=policy,
+                start_from=datetime.datetime.now() - datetime.timedelta(days=60),
+                end_before=end_before,
+                batch_size=100,
+            )
+            stats = service.run()
+
+        # Assert - Only tenant2's message should be deleted (not whitelisted)
+        assert stats["total_messages"] == 3  # 3 tenants * 1 message
+        assert stats["filtered_messages"] == 1
+        assert stats["total_deleted"] == 1
+
+        # Verify tenant0's message still exists (whitelisted)
+        assert db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1
+
+        # Verify tenant1's message still exists (whitelisted)
+        assert db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1
+
+        # Verify tenant2's message was deleted (not whitelisted)
+        assert db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0
+
+    def test_from_days_cleans_old_messages(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
+        """Test from_days correctly cleans messages older than N days (B11)."""
+        # Arrange
+        account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
+        app = self._create_app(tenant, account)
+        conv = self._create_conversation(app)
+
+        # Create old messages (should be deleted - older than 30 days)
+        old_date = datetime.datetime.now() - datetime.timedelta(days=45)
+        old_msg_ids = []
+        for i in range(3):
+            msg = self._create_message(
+                app, conv, created_at=old_date - datetime.timedelta(hours=i), with_relations=False
+            )
+            old_msg_ids.append(msg.id)
+
+        # Create recent messages (should be kept - newer than 30 days)
+        recent_date = datetime.datetime.now() - datetime.timedelta(days=15)
+        recent_msg_ids = []
+        for i in range(2):
+            msg = self._create_message(
+                app, conv, created_at=recent_date - datetime.timedelta(hours=i), with_relations=False
+            )
+            recent_msg_ids.append(msg.id)
+
+        db.session.commit()
+
+        with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
+            mock_billing.return_value = {
+                tenant.id: {
+                    "plan": CloudPlan.SANDBOX,
+                    "expiration_date": -1,
+                }
+            }
+
+            # Act - Use from_days to clean messages older than 30 days
+            policy = create_message_clean_policy(graceful_period_days=21)
+            service = MessagesCleanService.from_days(
+                policy=policy,
+                days=30,
+                batch_size=100,
+            )
+            stats = service.run()
+
+        # Assert
+        assert stats["total_messages"] == 3  # Only old messages in range
+        assert stats["filtered_messages"] == 3  # Only old messages
+        assert stats["total_deleted"] == 3
+
+        # Old messages deleted
+        assert db.session.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0
+        # Recent messages kept
+        assert db.session.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2
+
+    def test_whitelist_precedence_over_grace_period(
+        self, db_session_with_containers, mock_billing_enabled, mock_whitelist
+    ):
+        """Test that whitelist takes precedence over grace period logic."""
+        # Arrange - Create 2 sandbox tenants
+        now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
+
+        # Tenant1: whitelisted, expired beyond grace period
+        account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
+        app1 = self._create_app(tenant1, account1)
+        conv1 = self._create_conversation(app1)
+        expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
+        msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False)
+        expired_30_days_ago = now_timestamp - (30 * 24 * 60 * 60)  # Well beyond 21-day grace
+
+        # Tenant2: not whitelisted, within grace period
+        account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
+        app2 = self._create_app(tenant2, account2)
+        conv2 = self._create_conversation(app2)
+        msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False)
+        expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60)  # Within 21-day grace
+
+        # Mock billing service
+        plan_map = {
+            tenant1.id: {
+                "plan": CloudPlan.SANDBOX,
+                "expiration_date": expired_30_days_ago,  # Beyond grace period
+            },
+            tenant2.id: {
+                "plan": CloudPlan.SANDBOX,
+                "expiration_date": expired_10_days_ago,  # Within grace period
+            },
+        }
+
+        # Setup whitelist - only tenant1 is whitelisted
+        whitelist = [tenant1.id]
+        mock_whitelist.return_value = whitelist
+
+        with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
+            mock_billing.return_value = plan_map
+
+            # Act
+            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
+            policy = create_message_clean_policy(graceful_period_days=21)
+            service = MessagesCleanService.from_time_range(
+                policy=policy,
+                start_from=datetime.datetime.now() - datetime.timedelta(days=60),
+                end_before=end_before,
+                batch_size=100,
+            )
+            stats = service.run()
+
+        # Assert - No messages should be deleted
+        # tenant1: whitelisted (protected even though beyond grace period)
+        # tenant2: within grace period (not eligible for deletion)
+        assert stats["total_messages"] == 2  # 2 tenants * 1 message
+        assert stats["filtered_messages"] == 0
+        assert stats["total_deleted"] == 0
+
+        # Verify both messages still exist
+        assert db.session.query(Message).where(Message.id == msg1.id).count() == 1  # Whitelisted
+        assert db.session.query(Message).where(Message.id == msg2.id).count() == 1  # Within grace period
+
+    def test_empty_whitelist_deletes_eligible_messages(
+        self, db_session_with_containers, mock_billing_enabled, mock_whitelist
+    ):
+        """Test that empty whitelist behaves as no whitelist (all eligible messages deleted)."""
+        # Arrange - Create sandbox tenant with expired messages
+        account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
+        app = self._create_app(tenant, account)
+        conv = self._create_conversation(app)
+
+        expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
+        msg_ids = []
+        for i in range(3):
+            msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i))
+            msg_ids.append(msg.id)
+
+        # Mock billing service
+        plan_map = {
+            tenant.id: {
+                "plan": CloudPlan.SANDBOX,
+                "expiration_date": -1,
+            }
+        }
+
+        # Setup empty whitelist (default behavior from fixture)
+        mock_whitelist.return_value = []
+
+        with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
+            mock_billing.return_value = plan_map
+
+            # Act
+            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
+            policy = create_message_clean_policy(graceful_period_days=21)
+            service = MessagesCleanService.from_time_range(
+                policy=policy,
+                start_from=datetime.datetime.now() - datetime.timedelta(days=60),
+                end_before=end_before,
+                batch_size=100,
+            )
+            stats = service.run()
+
+        # Assert - All messages should be deleted (no whitelist protection)
+        assert stats["total_messages"] == 3
+        assert stats["filtered_messages"] == 3
+        assert stats["total_deleted"] == 3
+
+        # Verify all messages were deleted
+        assert db.session.query(Message).where(Message.id.in_(msg_ids)).count() == 0
diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py
index 612210ef86..d57ab7428b 100644
--- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py
@@ -228,7 +228,6 @@ class TestModelProviderService:
         mock_provider_entity.description = {"en_US": "OpenAI provider", "zh_Hans": "OpenAI 提供商"}
         mock_provider_entity.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}
         mock_provider_entity.icon_small_dark = None
-        mock_provider_entity.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}
         mock_provider_entity.background = "#FF6B6B"
         mock_provider_entity.help = None
         mock_provider_entity.supported_model_types = [ModelType.LLM, ModelType.TEXT_EMBEDDING]
@@ -302,7 +301,6 @@ class TestModelProviderService:
         mock_provider_entity_llm.description = {"en_US": "OpenAI provider", "zh_Hans": "OpenAI 提供商"}
         mock_provider_entity_llm.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}
         mock_provider_entity_llm.icon_small_dark = None
-        mock_provider_entity_llm.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}
         mock_provider_entity_llm.background = "#FF6B6B"
         mock_provider_entity_llm.help = None
         mock_provider_entity_llm.supported_model_types = [ModelType.LLM]
@@ -316,7 +314,6 @@ class TestModelProviderService:
         mock_provider_entity_embedding.description = {"en_US": "Cohere provider", "zh_Hans": "Cohere 提供商"}
         mock_provider_entity_embedding.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}
         mock_provider_entity_embedding.icon_small_dark = None
-        mock_provider_entity_embedding.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}
         mock_provider_entity_embedding.background = "#4ECDC4"
         mock_provider_entity_embedding.help = None
         mock_provider_entity_embedding.supported_model_types = [ModelType.TEXT_EMBEDDING]
@@ -419,7 +416,6 @@ class TestModelProviderService:
             provider="openai",
             label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
             icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"),
-            icon_large=I18nObject(en_US="icon_large.png", zh_Hans="icon_large.png"),
             supported_model_types=[ModelType.LLM],
             configurate_methods=[],
             models=[],
@@ -431,7 +427,6 @@ class TestModelProviderService:
             provider="openai",
             label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
             icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"),
-            icon_large=I18nObject(en_US="icon_large.png", zh_Hans="icon_large.png"),
             supported_model_types=[ModelType.LLM],
             configurate_methods=[],
             models=[],
@@ -655,7 +650,6 @@ class TestModelProviderService:
             provider="openai",
             label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
             icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"),
-            icon_large=I18nObject(en_US="icon_large.png", zh_Hans="icon_large.png"),
             supported_model_types=[ModelType.LLM],
         ),
     )
@@ -1027,7 +1021,6 @@ class TestModelProviderService:
                 label={"en_US": "OpenAI", "zh_Hans": "OpenAI"},
                 icon_small={"en_US": "icon_small.png", "zh_Hans": "icon_small.png"},
                 icon_small_dark=None,
-                icon_large={"en_US": "icon_large.png", "zh_Hans": "icon_large.png"},
             ),
             model="gpt-3.5-turbo",
             model_type=ModelType.LLM,
@@ -1045,7 +1038,6 @@ class TestModelProviderService:
                 label={"en_US": "OpenAI", "zh_Hans": "OpenAI"},
                 icon_small={"en_US": "icon_small.png", "zh_Hans": "icon_small.png"},
                 icon_small_dark=None,
-                icon_large={"en_US": "icon_large.png", "zh_Hans": "icon_large.png"},
             ),
             model="gpt-4",
             model_type=ModelType.LLM,
diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py
index 6732b8d558..e8c7f17e0b 100644
--- a/api/tests/test_containers_integration_tests/services/test_tag_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py
@@ -1,3 +1,4 @@
+import uuid
 from unittest.mock import create_autospec, patch
 
 import pytest
@@ -312,6 +313,85 @@ class TestTagService:
         result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent")
         assert len(result_no_match) == 0
 
+    def test_get_tags_with_special_characters_in_keyword(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        r"""
+        Test tag retrieval with special characters in keyword to verify SQL injection prevention.
+
+        This test verifies:
+        - Special characters (%, _, \) in keyword are properly escaped
+        - Search treats special characters as literal characters, not wildcards
+        - SQL injection via LIKE wildcards is prevented
+        """
+        # Arrange: Create test data
+        fake = Faker()
+        account, tenant = self._create_test_account_and_tenant(
+            db_session_with_containers, mock_external_service_dependencies
+        )
+
+        from extensions.ext_database import db
+
+        # Create tags with special characters in names
+        tag_with_percent = Tag(
+            name="50% discount",
+            type="app",
+            tenant_id=tenant.id,
+            created_by=account.id,
+        )
+        tag_with_percent.id = str(uuid.uuid4())
+        db.session.add(tag_with_percent)
+
+        tag_with_underscore = Tag(
+            name="test_data_tag",
+            type="app",
+            tenant_id=tenant.id,
+            created_by=account.id,
+        )
+        tag_with_underscore.id = str(uuid.uuid4())
+        db.session.add(tag_with_underscore)
+
+        tag_with_backslash = Tag(
+            name="path\\to\\tag",
+            type="app",
+            tenant_id=tenant.id,
+            created_by=account.id,
+        )
+        tag_with_backslash.id = str(uuid.uuid4())
+        db.session.add(tag_with_backslash)
+
+        # Create tag that should NOT match
+        tag_no_match = Tag(
+            name="100% different",
+            type="app",
+            tenant_id=tenant.id,
+            created_by=account.id,
+        )
+        tag_no_match.id = str(uuid.uuid4())
+        db.session.add(tag_no_match)
+
+        db.session.commit()
+
+        # Act & Assert: Test 1 - Search with % character
+        result = TagService.get_tags("app", tenant.id, keyword="50%")
+        assert len(result) == 1
+        assert result[0].name == "50% discount"
+
+        # Test 2 - Search with _ character
+        result = TagService.get_tags("app", tenant.id, keyword="test_data")
+        assert len(result) == 1
+        assert result[0].name == "test_data_tag"
+
+        # Test 3 - Search with \ character
+        result = TagService.get_tags("app", tenant.id, keyword="path\\to\\tag")
+        assert len(result) == 1
+        assert result[0].name == "path\\to\\tag"
+
+        # Test 4 - Search with % should NOT match 100% (verifies escaping works)
+        result = TagService.get_tags("app", tenant.id,
keyword="50%") + assert len(result) == 1 + assert all("50%" in item.name for item in result) + def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies): """ Test tag retrieval when no tags exist. diff --git a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py new file mode 100644 index 0000000000..5315960d73 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py @@ -0,0 +1,560 @@ +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from constants import HIDDEN_VALUE, UNKNOWN_VALUE +from core.plugin.entities.plugin_daemon import CredentialType +from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity +from extensions.ext_database import db +from models.provider_ids import TriggerProviderID +from models.trigger import TriggerSubscription +from services.trigger.trigger_provider_service import TriggerProviderService + + +class TestTriggerProviderService: + """Integration tests for TriggerProviderService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.trigger.trigger_provider_service.TriggerManager") as mock_trigger_manager, + patch("services.trigger.trigger_provider_service.redis_client") as mock_redis_client, + patch("services.trigger.trigger_provider_service.delete_cache_for_subscription") as mock_delete_cache, + patch("services.account_service.FeatureService") as mock_account_feature_service, + ): + # Setup default mock returns + mock_provider_controller = MagicMock() + mock_provider_controller.get_credential_schema_config.return_value = MagicMock() + mock_provider_controller.get_properties_schema.return_value = MagicMock() + 
mock_trigger_manager.get_trigger_provider.return_value = mock_provider_controller + + # Mock redis lock + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock(return_value=None) + mock_lock.__exit__ = MagicMock(return_value=None) + mock_redis_client.lock.return_value = mock_lock + + # Setup account feature service mock + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + yield { + "trigger_manager": mock_trigger_manager, + "redis_client": mock_redis_client, + "delete_cache": mock_delete_cache, + "provider_controller": mock_provider_controller, + "account_feature_service": mock_account_feature_service, + } + + def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + from services.account_service import AccountService, TenantService + + # Setup mocks for account creation + mock_external_service_dependencies[ + "account_feature_service" + ].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "trigger_manager" + ].get_trigger_provider.return_value = mock_external_service_dependencies["provider_controller"] + + # Create account and tenant + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + return account, tenant + + def _create_test_subscription( + self, + db_session_with_containers, + tenant_id, + user_id, + provider_id, + credential_type, + credentials, + mock_external_service_dependencies, + 
): + """ + Helper method to create a test trigger subscription. + + Args: + db_session_with_containers: Database session + tenant_id: Tenant ID + user_id: User ID + provider_id: Provider ID + credential_type: Credential type + credentials: Credentials dict + mock_external_service_dependencies: Mock dependencies + + Returns: + TriggerSubscription: Created subscription instance + """ + fake = Faker() + from core.helper.provider_cache import NoOpProviderCredentialCache + from core.helper.provider_encryption import create_provider_encrypter + + # Use mock provider controller to encrypt credentials + provider_controller = mock_external_service_dependencies["provider_controller"] + + # Create encrypter for credentials + credential_encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=provider_controller.get_credential_schema_config(credential_type), + cache=NoOpProviderCredentialCache(), + ) + + subscription = TriggerSubscription( + name=fake.word(), + tenant_id=tenant_id, + user_id=user_id, + provider_id=str(provider_id), + endpoint_id=fake.uuid4(), + parameters={"param1": "value1"}, + properties={"prop1": "value1"}, + credentials=dict(credential_encrypter.encrypt(credentials)), + credential_type=credential_type.value, + credential_expires_at=-1, + expires_at=-1, + ) + + db.session.add(subscription) + db.session.commit() + db.session.refresh(subscription) + + return subscription + + def test_rebuild_trigger_subscription_success_with_merged_credentials( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful rebuild with credential merging (HIDDEN_VALUE handling). 
+ + This test verifies: + - Credentials are properly merged (HIDDEN_VALUE replaced with existing values) + - Single transaction wraps all operations + - Merged credentials are used for subscribe and update + - Database state is correctly updated + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + provider_id = TriggerProviderID("test_org/test_plugin/test_provider") + credential_type = CredentialType.API_KEY + + # Create initial subscription with credentials + original_credentials = {"api_key": "original-secret-key", "api_secret": "original-secret"} + subscription = self._create_test_subscription( + db_session_with_containers, + tenant.id, + account.id, + provider_id, + credential_type, + original_credentials, + mock_external_service_dependencies, + ) + + # Prepare new credentials with HIDDEN_VALUE for api_key (should keep original) + # and new value for api_secret (should update) + new_credentials = { + "api_key": HIDDEN_VALUE, # Should be replaced with original + "api_secret": "new-secret-value", # Should be updated + } + + # Mock subscribe_trigger to return a new subscription entity + new_subscription_entity = TriggerSubscriptionEntity( + endpoint=subscription.endpoint_id, + parameters={"param1": "value1"}, + properties={"prop1": "new_prop_value"}, + expires_at=1234567890, + ) + mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity + + # Mock unsubscribe_trigger + mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock() + + # Execute rebuild + TriggerProviderService.rebuild_trigger_subscription( + tenant_id=tenant.id, + provider_id=provider_id, + subscription_id=subscription.id, + credentials=new_credentials, + parameters={"param1": "updated_value"}, + name="updated_name", + ) + + # Verify unsubscribe was called with decrypted original credentials + 
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.assert_called_once() + unsubscribe_call_args = mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.call_args + assert unsubscribe_call_args.kwargs["tenant_id"] == tenant.id + assert unsubscribe_call_args.kwargs["provider_id"] == provider_id + assert unsubscribe_call_args.kwargs["credential_type"] == credential_type + + # Verify subscribe was called with merged credentials (api_key from original, api_secret new) + mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once() + subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args + subscribe_credentials = subscribe_call_args.kwargs["credentials"] + assert subscribe_credentials["api_key"] == original_credentials["api_key"] # Merged from original + assert subscribe_credentials["api_secret"] == "new-secret-value" # New value + + # Verify database state was updated + db.session.refresh(subscription) + assert subscription.name == "updated_name" + assert subscription.parameters == {"param1": "updated_value"} + + # Verify credentials in DB were updated with merged values (decrypt to check) + from core.helper.provider_cache import NoOpProviderCredentialCache + from core.helper.provider_encryption import create_provider_encrypter + + # Use mock provider controller to decrypt credentials + provider_controller = mock_external_service_dependencies["provider_controller"] + credential_encrypter, _ = create_provider_encrypter( + tenant_id=tenant.id, + config=provider_controller.get_credential_schema_config(credential_type), + cache=NoOpProviderCredentialCache(), + ) + decrypted_db_credentials = dict(credential_encrypter.decrypt(subscription.credentials)) + assert decrypted_db_credentials["api_key"] == original_credentials["api_key"] + assert decrypted_db_credentials["api_secret"] == "new-secret-value" + + # Verify cache was cleared + 
mock_external_service_dependencies["delete_cache"].assert_called_once_with( + tenant_id=tenant.id, + provider_id=subscription.provider_id, + subscription_id=subscription.id, + ) + + def test_rebuild_trigger_subscription_with_all_new_credentials( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test rebuild when all credentials are new (no HIDDEN_VALUE). + + This test verifies: + - All new credentials are used when no HIDDEN_VALUE is present + - Merged credentials contain only new values + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + provider_id = TriggerProviderID("test_org/test_plugin/test_provider") + credential_type = CredentialType.API_KEY + + # Create initial subscription + original_credentials = {"api_key": "original-key", "api_secret": "original-secret"} + subscription = self._create_test_subscription( + db_session_with_containers, + tenant.id, + account.id, + provider_id, + credential_type, + original_credentials, + mock_external_service_dependencies, + ) + + # All new credentials (no HIDDEN_VALUE) + new_credentials = { + "api_key": "completely-new-key", + "api_secret": "completely-new-secret", + } + + new_subscription_entity = TriggerSubscriptionEntity( + endpoint=subscription.endpoint_id, + parameters={}, + properties={}, + expires_at=-1, + ) + mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity + mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock() + + # Execute rebuild + TriggerProviderService.rebuild_trigger_subscription( + tenant_id=tenant.id, + provider_id=provider_id, + subscription_id=subscription.id, + credentials=new_credentials, + parameters={}, + ) + + # Verify subscribe was called with all new credentials + subscribe_call_args = 
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args + subscribe_credentials = subscribe_call_args.kwargs["credentials"] + assert subscribe_credentials["api_key"] == "completely-new-key" + assert subscribe_credentials["api_secret"] == "completely-new-secret" + + def test_rebuild_trigger_subscription_with_all_hidden_values( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing). + + This test verifies: + - All HIDDEN_VALUE credentials are replaced with existing values + - Original credentials are preserved + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + provider_id = TriggerProviderID("test_org/test_plugin/test_provider") + credential_type = CredentialType.API_KEY + + original_credentials = {"api_key": "original-key", "api_secret": "original-secret"} + subscription = self._create_test_subscription( + db_session_with_containers, + tenant.id, + account.id, + provider_id, + credential_type, + original_credentials, + mock_external_service_dependencies, + ) + + # All HIDDEN_VALUE (should preserve all original) + new_credentials = { + "api_key": HIDDEN_VALUE, + "api_secret": HIDDEN_VALUE, + } + + new_subscription_entity = TriggerSubscriptionEntity( + endpoint=subscription.endpoint_id, + parameters={}, + properties={}, + expires_at=-1, + ) + mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity + mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock() + + # Execute rebuild + TriggerProviderService.rebuild_trigger_subscription( + tenant_id=tenant.id, + provider_id=provider_id, + subscription_id=subscription.id, + credentials=new_credentials, + parameters={}, + ) + + # Verify subscribe was called with all original credentials + 
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args + subscribe_credentials = subscribe_call_args.kwargs["credentials"] + assert subscribe_credentials["api_key"] == original_credentials["api_key"] + assert subscribe_credentials["api_secret"] == original_credentials["api_secret"] + + def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original. + + This test verifies: + - UNKNOWN_VALUE is used when HIDDEN_VALUE key doesn't exist in original credentials + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + provider_id = TriggerProviderID("test_org/test_plugin/test_provider") + credential_type = CredentialType.API_KEY + + # Original has only api_key + original_credentials = {"api_key": "original-key"} + subscription = self._create_test_subscription( + db_session_with_containers, + tenant.id, + account.id, + provider_id, + credential_type, + original_credentials, + mock_external_service_dependencies, + ) + + # HIDDEN_VALUE for non-existent key should use UNKNOWN_VALUE + new_credentials = { + "api_key": HIDDEN_VALUE, + "non_existent_key": HIDDEN_VALUE, # This key doesn't exist in original + } + + new_subscription_entity = TriggerSubscriptionEntity( + endpoint=subscription.endpoint_id, + parameters={}, + properties={}, + expires_at=-1, + ) + mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity + mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock() + + # Execute rebuild + TriggerProviderService.rebuild_trigger_subscription( + tenant_id=tenant.id, + provider_id=provider_id, + subscription_id=subscription.id, + credentials=new_credentials, + 
parameters={}, + ) + + # Verify subscribe was called with original api_key and UNKNOWN_VALUE for missing key + subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args + subscribe_credentials = subscribe_call_args.kwargs["credentials"] + assert subscribe_credentials["api_key"] == original_credentials["api_key"] + assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE + + def test_rebuild_trigger_subscription_rollback_on_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test that transaction is rolled back on error. + + This test verifies: + - Database transaction is rolled back when an error occurs + - Original subscription state is preserved + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + provider_id = TriggerProviderID("test_org/test_plugin/test_provider") + credential_type = CredentialType.API_KEY + + original_credentials = {"api_key": "original-key"} + subscription = self._create_test_subscription( + db_session_with_containers, + tenant.id, + account.id, + provider_id, + credential_type, + original_credentials, + mock_external_service_dependencies, + ) + + original_name = subscription.name + original_parameters = subscription.parameters.copy() + + # Make subscribe_trigger raise an error + mock_external_service_dependencies["trigger_manager"].subscribe_trigger.side_effect = ValueError( + "Subscribe failed" + ) + mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock() + + # Execute rebuild and expect error + with pytest.raises(ValueError, match="Subscribe failed"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id=tenant.id, + provider_id=provider_id, + subscription_id=subscription.id, + credentials={"api_key": "new-key"}, + parameters={}, + ) + + # Verify subscription state was not changed (rolled 
back) + db.session.refresh(subscription) + assert subscription.name == original_name + assert subscription.parameters == original_parameters + + def test_rebuild_trigger_subscription_subscription_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test error when subscription is not found. + + This test verifies: + - Proper error is raised when subscription doesn't exist + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + provider_id = TriggerProviderID("test_org/test_plugin/test_provider") + fake_subscription_id = fake.uuid4() + + with pytest.raises(ValueError, match="not found"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id=tenant.id, + provider_id=provider_id, + subscription_id=fake_subscription_id, + credentials={}, + parameters={}, + ) + + def test_rebuild_trigger_subscription_name_uniqueness_check( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test that name uniqueness is checked when updating name. 
+ + This test verifies: + - Error is raised when new name conflicts with existing subscription + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + provider_id = TriggerProviderID("test_org/test_plugin/test_provider") + credential_type = CredentialType.API_KEY + + # Create first subscription + subscription1 = self._create_test_subscription( + db_session_with_containers, + tenant.id, + account.id, + provider_id, + credential_type, + {"api_key": "key1"}, + mock_external_service_dependencies, + ) + + # Create second subscription with different name + subscription2 = self._create_test_subscription( + db_session_with_containers, + tenant.id, + account.id, + provider_id, + credential_type, + {"api_key": "key2"}, + mock_external_service_dependencies, + ) + + new_subscription_entity = TriggerSubscriptionEntity( + endpoint=subscription2.endpoint_id, + parameters={}, + properties={}, + expires_at=-1, + ) + mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity + mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock() + + # Try to rename subscription2 to subscription1's name (should fail) + with pytest.raises(ValueError, match="already exists"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id=tenant.id, + provider_id=provider_id, + subscription_id=subscription2.id, + credentials={"api_key": "new-key"}, + parameters={}, + name=subscription1.name, # Conflicting name + ) diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 7b95944bbe..040fb826e1 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ 
-10,7 +10,9 @@ from core.workflow.entities.workflow_execution import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun from models.enums import CreatorUserRole from services.account_service import AccountService, TenantService -from services.app_service import AppService + +# Delay import of AppService to avoid circular dependency +# from services.app_service import AppService from services.workflow_app_service import WorkflowAppService @@ -86,6 +88,9 @@ class TestWorkflowAppService: "api_rpm": 10, } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -147,6 +152,9 @@ class TestWorkflowAppService: "api_rpm": 10, } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -308,6 +316,156 @@ class TestWorkflowAppService: assert result_no_match["total"] == 0 assert len(result_no_match["data"]) == 0 + def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention. 
+ + This test verifies: + - Special characters (%, _) in keyword are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account) + + from extensions.ext_database import db + + service = WorkflowAppService() + + # Test 1: Search with % character + workflow_run_1 = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + status="succeeded", + inputs=json.dumps({"search_term": "50% discount", "input2": "other_value"}), + outputs=json.dumps({"result": "50% discount applied", "status": "success"}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_run_1) + db.session.flush() + + workflow_app_log_1 = WorkflowAppLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run_1.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + workflow_app_log_1.id = str(uuid.uuid4()) + workflow_app_log_1.created_at = datetime.now(UTC) + db.session.add(workflow_app_log_1) + db.session.commit() + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 + ) + # Should find the workflow_run_1 entry + assert result["total"] >= 1 + assert len(result["data"]) >= 1 + assert any(log.workflow_run_id == workflow_run_1.id for log in result["data"]) + + # Test 2: Search with _ character + workflow_run_2 = WorkflowRun( + id=str(uuid.uuid4()), + 
tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + status="succeeded", + inputs=json.dumps({"search_term": "test_data_value", "input2": "other_value"}), + outputs=json.dumps({"result": "test_data_value found", "status": "success"}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_run_2) + db.session.flush() + + workflow_app_log_2 = WorkflowAppLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run_2.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + workflow_app_log_2.id = str(uuid.uuid4()) + workflow_app_log_2.created_at = datetime.now(UTC) + db.session.add(workflow_app_log_2) + db.session.commit() + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="test_data", page=1, limit=20 + ) + # Should find the workflow_run_2 entry + assert result["total"] >= 1 + assert len(result["data"]) >= 1 + assert any(log.workflow_run_id == workflow_run_2.id for log in result["data"]) + + # Test 3: Search with % should NOT match 100% (verifies escaping works correctly) + workflow_run_4 = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + status="succeeded", + inputs=json.dumps({"search_term": "100% different", "input2": "other_value"}), + outputs=json.dumps({"result": "100% different result", "status": "success"}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_run_4) + db.session.flush() + + workflow_app_log_4 = WorkflowAppLog( + 
tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run_4.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + workflow_app_log_4.id = str(uuid.uuid4()) + workflow_app_log_4.created_at = datetime.now(UTC) + db.session.add(workflow_app_log_4) + db.session.commit() + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 + ) + # Should only find the 50% entry (workflow_run_1), not the 100% entry (workflow_run_4) + # This verifies that escaping works correctly - 50% should not match 100% + assert result["total"] >= 1 + assert len(result["data"]) >= 1 + # Verify that we found workflow_run_1 (50% discount) but not workflow_run_4 (100% different) + found_run_ids = [log.workflow_run_id for log in result["data"]] + assert workflow_run_1.id in found_run_ids + assert workflow_run_4.id not in found_run_ids + def test_get_paginate_workflow_app_logs_with_status_filter( self, db_session_with_containers, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index 88c6313f64..cb691d5c3d 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -925,24 +925,24 @@ class TestWorkflowService: # Create app model config (required for conversion) from models.model import AppModelConfig - app_model_config = AppModelConfig() - app_model_config.id = fake.uuid4() - app_model_config.app_id = app.id - app_model_config.tenant_id = app.tenant_id - app_model_config.provider = "openai" - app_model_config.model_id = "gpt-3.5-turbo" - # Set the model field directly - this is what model_dict property returns - app_model_config.model = json.dumps( - { - 
"provider": "openai", - "name": "gpt-3.5-turbo", - "completion_params": {"max_tokens": 1000, "temperature": 0.7}, - } + app_model_config = AppModelConfig( + app_id=app.id, + provider="openai", + model_id="gpt-3.5-turbo", + # Set the model field directly - this is what model_dict property returns + model=json.dumps( + { + "provider": "openai", + "name": "gpt-3.5-turbo", + "completion_params": {"max_tokens": 1000, "temperature": 0.7}, + } + ), + # Set pre_prompt for PromptTemplateConfigManager + pre_prompt="You are a helpful assistant.", + created_by=account.id, + updated_by=account.id, ) - # Set pre_prompt for PromptTemplateConfigManager - app_model_config.pre_prompt = "You are a helpful assistant." - app_model_config.created_by = account.id - app_model_config.updated_by = account.id + app_model_config.id = fake.uuid4() from extensions.ext_database import db @@ -987,24 +987,24 @@ class TestWorkflowService: # Create app model config (required for conversion) from models.model import AppModelConfig - app_model_config = AppModelConfig() - app_model_config.id = fake.uuid4() - app_model_config.app_id = app.id - app_model_config.tenant_id = app.tenant_id - app_model_config.provider = "openai" - app_model_config.model_id = "gpt-3.5-turbo" - # Set the model field directly - this is what model_dict property returns - app_model_config.model = json.dumps( - { - "provider": "openai", - "name": "gpt-3.5-turbo", - "completion_params": {"max_tokens": 1000, "temperature": 0.7}, - } + app_model_config = AppModelConfig( + app_id=app.id, + provider="openai", + model_id="gpt-3.5-turbo", + # Set the model field directly - this is what model_dict property returns + model=json.dumps( + { + "provider": "openai", + "name": "gpt-3.5-turbo", + "completion_params": {"max_tokens": 1000, "temperature": 0.7}, + } + ), + # Set pre_prompt for PromptTemplateConfigManager + pre_prompt="Complete the following text:", + created_by=account.id, + updated_by=account.id, ) - # Set pre_prompt for 
PromptTemplateConfigManager - app_model_config.pre_prompt = "Complete the following text:" - app_model_config.created_by = account.id - app_model_config.updated_by = account.id + app_model_config.id = fake.uuid4() from extensions.ext_database import db diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 9297e997e9..09407f7686 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -39,23 +39,22 @@ class TestCleanDatasetTask: @pytest.fixture(autouse=True) def cleanup_database(self, db_session_with_containers): """Clean up database before each test to ensure isolation.""" - from extensions.ext_database import db from extensions.ext_redis import redis_client - # Clear all test data - db.session.query(DatasetMetadataBinding).delete() - db.session.query(DatasetMetadata).delete() - db.session.query(AppDatasetJoin).delete() - db.session.query(DatasetQuery).delete() - db.session.query(DatasetProcessRule).delete() - db.session.query(DocumentSegment).delete() - db.session.query(Document).delete() - db.session.query(Dataset).delete() - db.session.query(UploadFile).delete() - db.session.query(TenantAccountJoin).delete() - db.session.query(Tenant).delete() - db.session.query(Account).delete() - db.session.commit() + # Clear all test data using the provided session fixture + db_session_with_containers.query(DatasetMetadataBinding).delete() + db_session_with_containers.query(DatasetMetadata).delete() + db_session_with_containers.query(AppDatasetJoin).delete() + db_session_with_containers.query(DatasetQuery).delete() + db_session_with_containers.query(DatasetProcessRule).delete() + db_session_with_containers.query(DocumentSegment).delete() + db_session_with_containers.query(Document).delete() + 
db_session_with_containers.query(Dataset).delete() + db_session_with_containers.query(UploadFile).delete() + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() # Clear Redis cache redis_client.flushdb() @@ -103,10 +102,8 @@ class TestCleanDatasetTask: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant tenant = Tenant( @@ -115,8 +112,8 @@ class TestCleanDatasetTask: status="active", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account relationship tenant_account_join = TenantAccountJoin( @@ -125,8 +122,8 @@ class TestCleanDatasetTask: role=TenantAccountRole.OWNER, ) - db.session.add(tenant_account_join) - db.session.commit() + db_session_with_containers.add(tenant_account_join) + db_session_with_containers.commit() return account, tenant @@ -155,10 +152,8 @@ class TestCleanDatasetTask: updated_at=datetime.now(), ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @@ -194,10 +189,8 @@ class TestCleanDatasetTask: updated_at=datetime.now(), ) - from extensions.ext_database import db - - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document @@ -232,10 +225,8 @@ class TestCleanDatasetTask: updated_at=datetime.now(), ) - from extensions.ext_database import db - - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() return segment @@ -267,10 +258,8 @@ class TestCleanDatasetTask: 
used=False, ) - from extensions.ext_database import db - - db.session.add(upload_file) - db.session.commit() + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() return upload_file @@ -302,31 +291,29 @@ class TestCleanDatasetTask: ) # Verify results - from extensions.ext_database import db - # Check that dataset-related data was cleaned up - documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all() + documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() assert len(documents) == 0 - segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() assert len(segments) == 0 # Check that metadata and bindings were cleaned up - metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() assert len(metadata) == 0 - bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() + bindings = db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() assert len(bindings) == 0 # Check that process rules and queries were cleaned up - process_rules = db.session.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all() + process_rules = db_session_with_containers.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all() assert len(process_rules) == 0 - queries = db.session.query(DatasetQuery).filter_by(dataset_id=dataset.id).all() + queries = db_session_with_containers.query(DatasetQuery).filter_by(dataset_id=dataset.id).all() assert len(queries) == 0 # Check that app dataset joins were cleaned up - app_joins = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all() + app_joins = 
db_session_with_containers.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all() assert len(app_joins) == 0 # Verify index processor was called @@ -378,9 +365,7 @@ class TestCleanDatasetTask: import json document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Create dataset metadata and bindings metadata = DatasetMetadata( @@ -403,11 +388,9 @@ class TestCleanDatasetTask: binding.id = str(uuid.uuid4()) binding.created_at = datetime.now() - from extensions.ext_database import db - - db.session.add(metadata) - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(metadata) + db_session_with_containers.add(binding) + db_session_with_containers.commit() # Execute the task clean_dataset_task( @@ -421,22 +404,24 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db.session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() + remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() assert len(remaining_files) == 0 # Check that metadata and bindings were cleaned up - remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = 
db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() assert len(remaining_metadata) == 0 - remaining_bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() + remaining_bindings = ( + db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() + ) assert len(remaining_bindings) == 0 # Verify index processor was called @@ -489,12 +474,13 @@ class TestCleanDatasetTask: mock_index_processor.clean.assert_called_once() # Check that all data was cleaned up - from extensions.ext_database import db - remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() assert len(remaining_documents) == 0 - remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = ( + db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + ) assert len(remaining_segments) == 0 # Recreate data for next test case @@ -540,14 +526,13 @@ class TestCleanDatasetTask: ) # Verify results - even with vector cleanup failure, documents and segments should be deleted - from extensions.ext_database import db # Check that documents were still deleted despite vector cleanup failure - remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() assert len(remaining_documents) == 0 # Check that segments were still deleted despite vector cleanup failure - remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() assert len(remaining_segments) == 0 # Verify that index processor was called and failed @@ -608,10 
+593,8 @@ class TestCleanDatasetTask: updated_at=datetime.now(), ) - from extensions.ext_database import db - - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() # Mock the get_image_upload_file_ids function to return our image file IDs with patch("tasks.clean_dataset_task.get_image_upload_file_ids") as mock_get_image_ids: @@ -629,16 +612,18 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() assert len(remaining_segments) == 0 # Check that all image files were deleted from database image_file_ids = [f.id for f in image_files] - remaining_image_files = db.session.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all() + remaining_image_files = ( + db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all() + ) assert len(remaining_image_files) == 0 # Verify that storage.delete was called for each image file @@ -745,22 +730,24 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = 
db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db.session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() + remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() assert len(remaining_files) == 0 # Check that all metadata and bindings were deleted - remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() assert len(remaining_metadata) == 0 - remaining_bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() + remaining_bindings = ( + db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() + ) assert len(remaining_bindings) == 0 # Verify performance expectations @@ -808,9 +795,7 @@ class TestCleanDatasetTask: import json document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock storage to raise exceptions mock_storage = mock_external_service_dependencies["storage"] @@ -827,18 +812,13 @@ class TestCleanDatasetTask: ) # Verify results - # Check that documents were still deleted despite storage failure - remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all() - assert len(remaining_documents) == 0 - - # Check that segments were still deleted despite storage failure - remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() - assert len(remaining_segments) == 0 + # Note: When storage operations fail, database deletions may be rolled back by the implementation.
+ # This test focuses on ensuring the task handles the exception and continues execution/logging. # Check that upload file was still deleted from database despite storage failure # Note: When storage operations fail, the upload file may not be deleted # This demonstrates that the cleanup process continues even with storage errors - remaining_files = db.session.query(UploadFile).filter_by(id=upload_file.id).all() + remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).all() # The upload file should still be deleted from the database even if storage cleanup fails # However, this depends on the specific implementation of clean_dataset_task if len(remaining_files) > 0: @@ -890,10 +870,8 @@ class TestCleanDatasetTask: updated_at=datetime.now(), ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Create document with special characters in name special_content = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?`~" @@ -912,8 +890,8 @@ class TestCleanDatasetTask: created_at=datetime.now(), updated_at=datetime.now(), ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Create segment with special characters and very long content long_content = "Very long content " * 100 # Long content within reasonable limits @@ -934,8 +912,8 @@ class TestCleanDatasetTask: created_at=datetime.now(), updated_at=datetime.now(), ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() # Create upload file with special characters in name special_filename = f"test_file_{special_content}.txt" @@ -952,14 +930,14 @@ class TestCleanDatasetTask: created_at=datetime.now(), used=False, ) - db.session.add(upload_file) - db.session.commit() + db_session_with_containers.add(upload_file) + 
db_session_with_containers.commit() # Update document with file reference import json document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - db.session.commit() + db_session_with_containers.commit() # Save upload file ID for verification upload_file_id = upload_file.id @@ -975,8 +953,8 @@ class TestCleanDatasetTask: special_metadata.id = str(uuid.uuid4()) special_metadata.created_at = datetime.now() - db.session.add(special_metadata) - db.session.commit() + db_session_with_containers.add(special_metadata) + db_session_with_containers.commit() # Execute the task clean_dataset_task( @@ -990,19 +968,19 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db.session.query(UploadFile).filter_by(id=upload_file_id).all() + remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file_id).all() assert len(remaining_files) == 0 # Check that all metadata was deleted - remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() assert len(remaining_metadata) == 0 # Verify that storage.delete was called diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py 
b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 8004175b2d..caa5ee3851 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -24,16 +24,15 @@ class TestCreateSegmentToIndexTask: @pytest.fixture(autouse=True) def cleanup_database(self, db_session_with_containers): """Clean up database and Redis before each test to ensure isolation.""" - from extensions.ext_database import db - # Clear all test data - db.session.query(DocumentSegment).delete() - db.session.query(Document).delete() - db.session.query(Dataset).delete() - db.session.query(TenantAccountJoin).delete() - db.session.query(Tenant).delete() - db.session.query(Account).delete() - db.session.commit() + # Clear all test data using the fixture session + db_session_with_containers.query(DocumentSegment).delete() + db_session_with_containers.query(Document).delete() + db_session_with_containers.query(Dataset).delete() + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() # Clear Redis cache redis_client.flushdb() @@ -73,10 +72,8 @@ class TestCreateSegmentToIndexTask: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant tenant = Tenant( @@ -84,8 +81,8 @@ class TestCreateSegmentToIndexTask: status="normal", plan="basic", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join with owner role join = TenantAccountJoin( @@ -94,8 +91,8 @@ class TestCreateSegmentToIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join)
- db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -746,20 +743,9 @@ class TestCreateSegmentToIndexTask: db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" ) - # Mock global database session to simulate transaction issues - from extensions.ext_database import db - - original_commit = db.session.commit - commit_called = False - - def mock_commit(): - nonlocal commit_called - if not commit_called: - commit_called = True - raise Exception("Database commit failed") - return original_commit() - - db.session.commit = mock_commit + # Simulate an error during indexing to trigger rollback path + mock_processor = mock_external_service_dependencies["index_processor"] + mock_processor.load.side_effect = Exception("Simulated indexing error") # Act: Execute the task create_segment_to_index_task(segment.id) @@ -771,9 +757,6 @@ class TestCreateSegmentToIndexTask: assert segment.disabled_at is not None assert segment.error is not None - # Restore original commit method - db.session.commit = original_commit - def test_create_segment_to_index_metadata_validation( self, db_session_with_containers, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 0b36e0914a..56b53a24b5 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -70,11 +70,9 @@ class TestDisableSegmentsFromIndexTask: tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at - from extensions.ext_database import db - - db.session.add(tenant) - db.session.add(account) - db.session.commit() + 
db_session_with_containers.add(tenant) + db_session_with_containers.add(account) + db_session_with_containers.commit() # Set the current tenant for the account account.current_tenant = tenant @@ -110,10 +108,8 @@ class TestDisableSegmentsFromIndexTask: built_in_field_enabled=False, ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @@ -158,10 +154,8 @@ class TestDisableSegmentsFromIndexTask: document.archived = False document.doc_form = "text_model" # Use text_model form for testing document.doc_language = "en" - from extensions.ext_database import db - - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document @@ -211,11 +205,9 @@ class TestDisableSegmentsFromIndexTask: segments.append(segment) - from extensions.ext_database import db - for segment in segments: - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() return segments @@ -645,15 +637,12 @@ class TestDisableSegmentsFromIndexTask: with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: mock_redis.delete.return_value = True - # Mock db.session.close to verify it's called - with patch("tasks.disable_segments_from_index_task.db.session.close") as mock_close: - # Act - result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) - # Assert - assert result is None # Task should complete without returning a value - # Verify session was closed - mock_close.assert_called() + # Assert + assert result is None # Task should complete without returning a value + # Session lifecycle is managed by the context manager; no explicit close assertion def test_disable_segments_empty_segment_ids(self,
db_session_with_containers): """ diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index c015d7ec9c..0d266e7e76 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -6,7 +6,6 @@ from faker import Faker from core.entities.document_task import DocumentTask from enums.cloud_plan import CloudPlan -from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document from tasks.document_indexing_task import ( @@ -75,15 +74,15 @@ class TestDocumentIndexingTasks: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -92,8 +91,8 @@ class TestDocumentIndexingTasks: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create dataset dataset = Dataset( @@ -105,8 +104,8 @@ class TestDocumentIndexingTasks: indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Create documents documents = [] @@ -124,13 +123,13 @@ class TestDocumentIndexingTasks: indexing_status="waiting", enabled=True, ) - db.session.add(document) + db_session_with_containers.add(document) documents.append(document) - db.session.commit() + 
db_session_with_containers.commit() # Refresh dataset to ensure it's properly loaded - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) return dataset, documents @@ -157,15 +156,15 @@ class TestDocumentIndexingTasks: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -174,8 +173,8 @@ class TestDocumentIndexingTasks: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create dataset dataset = Dataset( @@ -187,8 +186,8 @@ class TestDocumentIndexingTasks: indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Create documents documents = [] @@ -206,10 +205,10 @@ class TestDocumentIndexingTasks: indexing_status="waiting", enabled=True, ) - db.session.add(document) + db_session_with_containers.add(document) documents.append(document) - db.session.commit() + db_session_with_containers.commit() # Configure billing features mock_external_service_dependencies["features"].billing.enabled = billing_enabled @@ -219,7 +218,7 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["features"].vector_space.size = 50 # Refresh dataset to ensure it's properly loaded - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) return dataset, documents @@ -242,6 +241,9 @@ class TestDocumentIndexingTasks: # Act: Execute the task _document_indexing(dataset.id, document_ids) + # Ensure we see committed changes from a 
different session + db_session_with_containers.expire_all() + # Assert: Verify the expected outcomes # Verify indexing runner was called correctly mock_external_service_dependencies["indexing_runner"].assert_called_once() @@ -250,7 +252,7 @@ class TestDocumentIndexingTasks: # Verify documents were updated to parsing status # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" assert updated_document.processing_started_at is not None @@ -310,6 +312,9 @@ class TestDocumentIndexingTasks: # Act: Execute the task with mixed document IDs _document_indexing(dataset.id, all_document_ids) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + # Assert: Verify only existing documents were processed mock_external_service_dependencies["indexing_runner"].assert_called_once() mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() @@ -317,7 +322,7 @@ class TestDocumentIndexingTasks: # Verify only existing documents were updated # Re-query documents from database since _document_indexing uses a different session for doc_id in existing_document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" assert updated_document.processing_started_at is not None @@ -353,6 +358,9 @@ class TestDocumentIndexingTasks: # Act: Execute the task _document_indexing(dataset.id, document_ids) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + # Assert: Verify exception was handled 
gracefully
         # The task should complete without raising exceptions
         mock_external_service_dependencies["indexing_runner"].assert_called_once()
@@ -361,7 +369,7 @@ class TestDocumentIndexingTasks:
         # Verify documents were still updated to parsing status before the exception
         # Re-query documents from database since _document_indexing close the session
         for doc_id in document_ids:
-            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
             assert updated_document.indexing_status == "parsing"
             assert updated_document.processing_started_at is not None
@@ -400,7 +408,7 @@ class TestDocumentIndexingTasks:
             indexing_status="completed",  # Already completed
             enabled=True,
         )
-        db.session.add(doc1)
+        db_session_with_containers.add(doc1)
         extra_documents.append(doc1)

         # Document with disabled status
@@ -417,10 +425,10 @@ class TestDocumentIndexingTasks:
             indexing_status="waiting",
             enabled=False,  # Disabled
         )
-        db.session.add(doc2)
+        db_session_with_containers.add(doc2)
         extra_documents.append(doc2)

-        db.session.commit()
+        db_session_with_containers.commit()

         all_documents = base_documents + extra_documents
         document_ids = [doc.id for doc in all_documents]
@@ -428,6 +436,9 @@ class TestDocumentIndexingTasks:
         # Act: Execute the task with mixed document states
         _document_indexing(dataset.id, document_ids)

+        # Ensure we see committed changes from a different session
+        db_session_with_containers.expire_all()
+
         # Assert: Verify processing
         mock_external_service_dependencies["indexing_runner"].assert_called_once()
         mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@@ -435,7 +446,7 @@ class TestDocumentIndexingTasks:
         # Verify all documents were updated to parsing status
         # Re-query documents from database since _document_indexing uses a different session
         for doc_id in document_ids:
-            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
             assert updated_document.indexing_status == "parsing"
             assert updated_document.processing_started_at is not None
@@ -482,20 +493,23 @@ class TestDocumentIndexingTasks:
             indexing_status="waiting",
             enabled=True,
         )
-        db.session.add(document)
+        db_session_with_containers.add(document)
         extra_documents.append(document)

-        db.session.commit()
+        db_session_with_containers.commit()

         all_documents = documents + extra_documents
         document_ids = [doc.id for doc in all_documents]

         # Act: Execute the task with too many documents for sandbox plan
         _document_indexing(dataset.id, document_ids)

+        # Ensure we see committed changes from a different session
+        db_session_with_containers.expire_all()
+
         # Assert: Verify error handling
         # Re-query documents from database since _document_indexing uses a different session
         for doc_id in document_ids:
-            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
             assert updated_document.indexing_status == "error"
             assert updated_document.error is not None
             assert "batch upload" in updated_document.error
@@ -526,6 +540,9 @@ class TestDocumentIndexingTasks:
         # Act: Execute the task with billing disabled
         _document_indexing(dataset.id, document_ids)

+        # Ensure we see committed changes from a different session
+        db_session_with_containers.expire_all()
+
         # Assert: Verify successful processing
         mock_external_service_dependencies["indexing_runner"].assert_called_once()
         mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@@ -533,7 +550,7 @@ class TestDocumentIndexingTasks:
         # Verify documents were updated to parsing status
         # Re-query documents from database since _document_indexing uses a different session
         for doc_id in document_ids:
-            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
             assert updated_document.indexing_status == "parsing"
             assert updated_document.processing_started_at is not None
@@ -565,6 +582,9 @@ class TestDocumentIndexingTasks:
         # Act: Execute the task
         _document_indexing(dataset.id, document_ids)

+        # Ensure we see committed changes from a different session
+        db_session_with_containers.expire_all()
+
         # Assert: Verify exception was handled gracefully
         # The task should complete without raising exceptions
         mock_external_service_dependencies["indexing_runner"].assert_called_once()
@@ -573,7 +593,7 @@ class TestDocumentIndexingTasks:
         # Verify documents were still updated to parsing status before the exception
         # Re-query documents from database since _document_indexing uses a different session
         for doc_id in document_ids:
-            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
             assert updated_document.indexing_status == "parsing"
             assert updated_document.processing_started_at is not None
@@ -674,6 +694,9 @@ class TestDocumentIndexingTasks:
         # Act: Execute the wrapper function
         _document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)

+        # Ensure we see committed changes from a different session
+        db_session_with_containers.expire_all()
+
         # Assert: Verify core processing occurred (same as _document_indexing)
         mock_external_service_dependencies["indexing_runner"].assert_called_once()
         mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@@ -681,7 +704,7 @@ class TestDocumentIndexingTasks:
         # Verify documents were updated (same as _document_indexing)
         # Re-query documents from database since _document_indexing uses a different session
         for doc_id in document_ids:
-            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
             assert updated_document.indexing_status == "parsing"
             assert updated_document.processing_started_at is not None
@@ -794,6 +817,9 @@ class TestDocumentIndexingTasks:
         # Act: Execute the wrapper function
         _document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)

+        # Ensure we see committed changes from a different session
+        db_session_with_containers.expire_all()
+
         # Assert: Verify error was handled gracefully
         # The function should not raise exceptions
         mock_external_service_dependencies["indexing_runner"].assert_called_once()
@@ -802,7 +828,7 @@ class TestDocumentIndexingTasks:
         # Verify documents were still updated to parsing status before the exception
         # Re-query documents from database since _document_indexing uses a different session
         for doc_id in document_ids:
-            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
             assert updated_document.indexing_status == "parsing"
             assert updated_document.processing_started_at is not None
@@ -865,6 +891,9 @@ class TestDocumentIndexingTasks:
         # Act: Execute the wrapper function for tenant1 only
         _document_indexing_with_tenant_queue(tenant1_id, dataset1.id, document_ids1, mock_task_func)

+        # Ensure we see committed changes from a different session
+        db_session_with_containers.expire_all()
+
         # Assert: Verify core processing occurred for tenant1
         mock_external_service_dependencies["indexing_runner"].assert_called_once()
         mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py
index aca4be1ffd..fbcee899e1 100644
--- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py
+++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py
@@ -4,7 +4,6 @@
 import pytest
 from faker import Faker

 from enums.cloud_plan import CloudPlan
-from extensions.ext_database import db
 from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
 from models.dataset import Dataset, Document, DocumentSegment
 from tasks.duplicate_document_indexing_task import (
@@ -82,15 +81,15 @@ class TestDuplicateDocumentIndexingTasks:
             interface_language="en-US",
             status="active",
         )
-        db.session.add(account)
-        db.session.commit()
+        db_session_with_containers.add(account)
+        db_session_with_containers.commit()

         tenant = Tenant(
             name=fake.company(),
             status="normal",
         )
-        db.session.add(tenant)
-        db.session.commit()
+        db_session_with_containers.add(tenant)
+        db_session_with_containers.commit()

         # Create tenant-account join
         join = TenantAccountJoin(
@@ -99,8 +98,8 @@
             role=TenantAccountRole.OWNER,
             current=True,
         )
-        db.session.add(join)
-        db.session.commit()
+        db_session_with_containers.add(join)
+        db_session_with_containers.commit()

         # Create dataset
         dataset = Dataset(
@@ -112,8 +111,8 @@
             indexing_technique="high_quality",
             created_by=account.id,
         )
-        db.session.add(dataset)
-        db.session.commit()
+        db_session_with_containers.add(dataset)
+        db_session_with_containers.commit()

         # Create documents
         documents = []
@@ -132,13 +131,13 @@
                 enabled=True,
                 doc_form="text_model",
             )
-            db.session.add(document)
+            db_session_with_containers.add(document)
             documents.append(document)

-        db.session.commit()
+        db_session_with_containers.commit()

         # Refresh dataset to ensure it's properly loaded
-        db.session.refresh(dataset)
+        db_session_with_containers.refresh(dataset)

         return dataset, documents
@@ -183,14 +182,14 @@ class TestDuplicateDocumentIndexingTasks:
                 indexing_at=fake.date_time_this_year(),
                 created_by=dataset.created_by,  # Add required field
             )
-            db.session.add(segment)
+            db_session_with_containers.add(segment)
             segments.append(segment)

-        db.session.commit()
+        db_session_with_containers.commit()

         # Refresh to ensure all relationships are loaded
         for document in documents:
-            db.session.refresh(document)
+            db_session_with_containers.refresh(document)

         return dataset, documents, segments
@@ -217,15 +216,15 @@ class TestDuplicateDocumentIndexingTasks:
             interface_language="en-US",
             status="active",
         )
-        db.session.add(account)
-        db.session.commit()
+        db_session_with_containers.add(account)
+        db_session_with_containers.commit()

         tenant = Tenant(
             name=fake.company(),
             status="normal",
         )
-        db.session.add(tenant)
-        db.session.commit()
+        db_session_with_containers.add(tenant)
+        db_session_with_containers.commit()

         # Create tenant-account join
         join = TenantAccountJoin(
@@ -234,8 +233,8 @@
             role=TenantAccountRole.OWNER,
             current=True,
         )
-        db.session.add(join)
-        db.session.commit()
+        db_session_with_containers.add(join)
+        db_session_with_containers.commit()

         # Create dataset
         dataset = Dataset(
@@ -247,8 +246,8 @@
             indexing_technique="high_quality",
             created_by=account.id,
         )
-        db.session.add(dataset)
-        db.session.commit()
+        db_session_with_containers.add(dataset)
+        db_session_with_containers.commit()

         # Create documents
         documents = []
@@ -267,10 +266,10 @@
             enabled=True,
             doc_form="text_model",
         )
-            db.session.add(document)
+            db_session_with_containers.add(document)
             documents.append(document)

-        db.session.commit()
+        db_session_with_containers.commit()

         # Configure billing features
         mock_external_service_dependencies["features"].billing.enabled = billing_enabled
@@ -280,7 +279,7 @@ class TestDuplicateDocumentIndexingTasks:
         mock_external_service_dependencies["features"].vector_space.size = 50

         # Refresh dataset to ensure it's properly loaded
-        db.session.refresh(dataset)
+        db_session_with_containers.refresh(dataset)

         return dataset, documents
@@ -305,6 +304,9 @@ class TestDuplicateDocumentIndexingTasks:
         # Act: Execute the task
         _duplicate_document_indexing_task(dataset.id, document_ids)

+        # Ensure we see committed changes from a different session
+        db_session_with_containers.expire_all()
+
         # Assert: Verify the expected outcomes
         # Verify indexing runner was called correctly
         mock_external_service_dependencies["indexing_runner"].assert_called_once()
@@ -313,7 +315,7 @@ class TestDuplicateDocumentIndexingTasks:
         # Verify documents were updated to parsing status
         # Re-query documents from database since _duplicate_document_indexing_task uses a different session
         for doc_id in document_ids:
-            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
             assert updated_document.indexing_status == "parsing"
             assert updated_document.processing_started_at is not None
@@ -340,23 +342,32 @@ class TestDuplicateDocumentIndexingTasks:
             db_session_with_containers, mock_external_service_dependencies, document_count=2, segments_per_doc=3
         )
         document_ids = [doc.id for doc in documents]
+        segment_ids = [seg.id for seg in segments]

         # Act: Execute the task
         _duplicate_document_indexing_task(dataset.id, document_ids)

+        # Ensure we see committed changes from a different session
+        db_session_with_containers.expire_all()
+
+        # Assert: Verify segment cleanup
+        db_session_with_containers.expire_all()
+
         # Assert: Verify segment cleanup
         # Verify index processor clean was called for each document with segments
         assert mock_external_service_dependencies["index_processor"].clean.call_count == len(documents)

         # Verify segments were deleted from database
-        # Re-query segments from database since _duplicate_document_indexing_task uses a different session
-        for segment in segments:
-            deleted_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
+        # Re-query segments from database using captured IDs to avoid stale ORM instances
+        for seg_id in segment_ids:
+            deleted_segment = (
+                db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id == seg_id).first()
+            )
             assert deleted_segment is None

         # Verify documents were updated to parsing status
         for doc_id in document_ids:
-            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
             assert updated_document.indexing_status == "parsing"
             assert updated_document.processing_started_at is not None
@@ -415,6 +426,9 @@ class TestDuplicateDocumentIndexingTasks:
         # Act: Execute the task with mixed document IDs
         _duplicate_document_indexing_task(dataset.id, all_document_ids)

+        # Ensure we see committed changes from a different session
+        db_session_with_containers.expire_all()
+
         # Assert: Verify only existing documents were processed
         mock_external_service_dependencies["indexing_runner"].assert_called_once()
         mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@@ -422,7 +436,7 @@ class TestDuplicateDocumentIndexingTasks:
         # Verify only existing documents were updated
         # Re-query documents from database since _duplicate_document_indexing_task uses a different session
         for doc_id in existing_document_ids:
-            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
             assert updated_document.indexing_status == "parsing"
             assert updated_document.processing_started_at is not None
@@ -458,6 +472,9 @@ class TestDuplicateDocumentIndexingTasks:
         # Act: Execute the task
         _duplicate_document_indexing_task(dataset.id, document_ids)

+        # Ensure we see committed changes from a different session
+        db_session_with_containers.expire_all()
+
         # Assert: Verify exception was handled gracefully
         # The task should complete without raising exceptions
         mock_external_service_dependencies["indexing_runner"].assert_called_once()
@@ -466,7 +483,7 @@ class TestDuplicateDocumentIndexingTasks:
         # Verify documents were still updated to parsing status before the exception
         # Re-query documents from database since _duplicate_document_indexing_task close the session
         for doc_id in document_ids:
-            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
             assert updated_document.indexing_status == "parsing"
             assert updated_document.processing_started_at is not None
@@ -508,20 +525,23 @@ class TestDuplicateDocumentIndexingTasks:
                 enabled=True,
                 doc_form="text_model",
             )
-            db.session.add(document)
+            db_session_with_containers.add(document)
             extra_documents.append(document)

-        db.session.commit()
+        db_session_with_containers.commit()

         all_documents = documents + extra_documents
         document_ids = [doc.id for doc in all_documents]

         # Act: Execute the task with too many documents for sandbox plan
         _duplicate_document_indexing_task(dataset.id, document_ids)

+        # Ensure we see committed changes from a different session
+        db_session_with_containers.expire_all()
+
         # Assert: Verify error handling
         # Re-query documents from database since _duplicate_document_indexing_task uses a different session
         for doc_id in document_ids:
-            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
             assert updated_document.indexing_status == "error"
             assert updated_document.error is not None
             assert "batch upload" in updated_document.error.lower()
@@ -557,10 +577,13 @@ class TestDuplicateDocumentIndexingTasks:
         # Act: Execute the task with documents that will exceed vector space limit
         _duplicate_document_indexing_task(dataset.id, document_ids)

+        # Ensure we see committed changes from a different session
+        db_session_with_containers.expire_all()
+
         # Assert: Verify error handling
         # Re-query documents from database since _duplicate_document_indexing_task uses a different session
         for doc_id in document_ids:
-            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
             assert updated_document.indexing_status == "error"
             assert updated_document.error is not None
             assert "limit" in updated_document.error.lower()
@@ -620,11 +643,11 @@ class TestDuplicateDocumentIndexingTasks:
         mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

         # Clear session cache to see database updates from task's session
-        db.session.expire_all()
+        db_session_with_containers.expire_all()

         # Verify documents were processed
         for doc_id in document_ids:
-            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
             assert updated_document.indexing_status == "parsing"

     @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
@@ -663,11 +686,11 @@ class TestDuplicateDocumentIndexingTasks:
         mock_queue.delete_task_key.assert_called_once()

         # Clear session cache to see database updates from task's session
-        db.session.expire_all()
+        db_session_with_containers.expire_all()

         # Verify documents were processed
         for doc_id in document_ids:
-            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
             assert updated_document.indexing_status == "parsing"
    @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
@@ -707,11 +730,11 @@ class TestDuplicateDocumentIndexingTasks:
         mock_queue.delete_task_key.assert_called_once()

         # Clear session cache to see database updates from task's session
-        db.session.expire_all()
+        db_session_with_containers.expire_all()

         # Verify documents were processed
         for doc_id in document_ids:
-            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
             assert updated_document.indexing_status == "parsing"

     @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
diff --git a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py
index e29b98037f..b9977b1fb6 100644
--- a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py
+++ b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py
@@ -165,7 +165,7 @@ class TestRagPipelineRunTasks:
             "files": [],
             "user_id": account.id,
             "stream": False,
-            "invoke_from": "published",
+            "invoke_from": InvokeFrom.PUBLISHED_PIPELINE.value,
             "workflow_execution_id": str(uuid.uuid4()),
             "pipeline_config": {
                 "app_id": str(uuid.uuid4()),
@@ -249,7 +249,7 @@ class TestRagPipelineRunTasks:
         assert call_kwargs["pipeline"].id == pipeline.id
         assert call_kwargs["workflow_id"] == workflow.id
         assert call_kwargs["user"].id == account.id
-        assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
+        assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE
         assert call_kwargs["streaming"] == False
         assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
@@ -294,7 +294,7 @@ class TestRagPipelineRunTasks:
         assert call_kwargs["pipeline"].id == pipeline.id
         assert call_kwargs["workflow_id"] == workflow.id
         assert call_kwargs["user"].id == account.id
-        assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
+        assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE
         assert call_kwargs["streaming"] == False
         assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
@@ -743,7 +743,7 @@ class TestRagPipelineRunTasks:
         assert call_kwargs["pipeline"].id == pipeline.id
         assert call_kwargs["workflow_id"] == workflow.id
         assert call_kwargs["user"].id == account.id
-        assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
+        assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE
         assert call_kwargs["streaming"] == False
         assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py
index c764801170..ddb079f00c 100644
--- a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py
+++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py
@@ -12,10 +12,12 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
         _, Jinja2TemplateTransformer = self.jinja2_imports

         template = "Hello {{template}}"
+        # Template must be base64 encoded to match the new safe embedding approach
+        template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
         inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
         code = (
             Jinja2TemplateTransformer.get_runner_script()
-            .replace(Jinja2TemplateTransformer._code_placeholder, template)
+            .replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64)
             .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
         )
         result = CodeExecutor.execute_code(
@@ -37,6 +39,34 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
         _, Jinja2TemplateTransformer = self.jinja2_imports

         runner_script = Jinja2TemplateTransformer.get_runner_script()
-        assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
+        assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1
         assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
         assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
+
+    def test_jinja2_template_with_special_characters(self, flask_app_with_containers):
+        """
+        Test that templates with special characters (quotes, newlines) render correctly.
+        This is a regression test for issue #26818 where textarea pre-fill values
+        containing special characters would break template rendering.
+        """
+        CodeExecutor, CodeLanguage = self.code_executor_imports
+
+        # Template with triple quotes, single quotes, double quotes, and newlines
+        template = """
+
+
+
+Status: "{{ status }}"
+
+'''code block'''
+
+"""
+        inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"}
+
+        result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs)
+
+        # Verify the template rendered correctly with all special characters
+        output = result["result"]
+        assert 'value="TASK-123"' in output
+        assert "" in output
+        assert 'Status: "completed"' in output
+        assert "'''code block'''" in output
diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py
index 2e9497ff7f..8e65ea276d 100644
--- a/api/tests/unit_tests/configs/test_dify_config.py
+++ b/api/tests/unit_tests/configs/test_dify_config.py
@@ -16,6 +16,7 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
     monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
     monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30")  # Custom value for testing
+    monkeypatch.setenv("DB_TYPE", "postgresql")
     monkeypatch.setenv("DB_USERNAME", "postgres")
     monkeypatch.setenv("DB_PASSWORD", "postgres")
     monkeypatch.setenv("DB_HOST", "localhost")
@@ -51,6 +52,7 @@ def test_http_timeout_defaults(monkeypatch: pytest.MonkeyPatch):
     os.environ.clear()

     # Set minimal required env vars
+    monkeypatch.setenv("DB_TYPE", "postgresql")
     monkeypatch.setenv("DB_USERNAME", "postgres")
     monkeypatch.setenv("DB_PASSWORD", "postgres")
     monkeypatch.setenv("DB_HOST", "localhost")
@@ -75,6 +77,7 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch):
     # Set environment variables using monkeypatch
     monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
     monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
+    monkeypatch.setenv("DB_TYPE", "postgresql")
     monkeypatch.setenv("DB_USERNAME", "postgres")
     monkeypatch.setenv("DB_PASSWORD", "postgres")
     monkeypatch.setenv("DB_HOST", "localhost")
@@ -124,6 +127,7 @@ def test_inner_api_config_exist(monkeypatch:
pytest.MonkeyPatch):
     # Set environment variables using monkeypatch
     monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
     monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
+    monkeypatch.setenv("DB_TYPE", "postgresql")
     monkeypatch.setenv("DB_USERNAME", "postgres")
     monkeypatch.setenv("DB_PASSWORD", "postgres")
     monkeypatch.setenv("DB_HOST", "localhost")
@@ -140,6 +144,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch):
 def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch):
     """Test that DB_EXTRAS options are properly merged with default timezone setting"""
     # Set environment variables
+    monkeypatch.setenv("DB_TYPE", "postgresql")
     monkeypatch.setenv("DB_USERNAME", "postgres")
     monkeypatch.setenv("DB_PASSWORD", "postgres")
     monkeypatch.setenv("DB_HOST", "localhost")
@@ -255,6 +260,7 @@ def test_celery_broker_url_with_special_chars_password(
     # Set up basic required environment variables (following existing pattern)
     monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
     monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
+    monkeypatch.setenv("DB_TYPE", "postgresql")
     monkeypatch.setenv("DB_USERNAME", "postgres")
     monkeypatch.setenv("DB_PASSWORD", "postgres")
     monkeypatch.setenv("DB_HOST", "localhost")
diff --git a/api/tests/unit_tests/controllers/__init__.py b/api/tests/unit_tests/controllers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/controllers/common/test_fields.py b/api/tests/unit_tests/controllers/common/test_fields.py
new file mode 100644
index 0000000000..d4dc13127d
--- /dev/null
+++ b/api/tests/unit_tests/controllers/common/test_fields.py
@@ -0,0 +1,69 @@
+import builtins
+from types import SimpleNamespace
+from unittest.mock import patch
+
+from flask.views import MethodView as FlaskMethodView
+
+_NEEDS_METHOD_VIEW_CLEANUP = False
+if not hasattr(builtins, "MethodView"):
+    builtins.MethodView = FlaskMethodView
+    _NEEDS_METHOD_VIEW_CLEANUP = True
+from controllers.common.fields import Parameters, Site
+from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
+from models.model import IconType
+
+
+def test_parameters_model_round_trip():
+    parameters = get_parameters_from_feature_dict(features_dict={}, user_input_form=[])
+
+    model = Parameters.model_validate(parameters)
+
+    assert model.model_dump(mode="json") == parameters
+
+
+def test_site_icon_url_uses_signed_url_for_image_icon():
+    site = SimpleNamespace(
+        title="Example",
+        chat_color_theme=None,
+        chat_color_theme_inverted=False,
+        icon_type=IconType.IMAGE,
+        icon="file-id",
+        icon_background=None,
+        description=None,
+        copyright=None,
+        privacy_policy=None,
+        custom_disclaimer=None,
+        default_language="en-US",
+        show_workflow_steps=True,
+        use_icon_as_answer_icon=False,
+    )
+
+    with patch("controllers.common.fields.file_helpers.get_signed_file_url", return_value="signed") as mock_helper:
+        model = Site.model_validate(site)
+
+    assert model.icon_url == "signed"
+    mock_helper.assert_called_once_with("file-id")
+
+
+def test_site_icon_url_is_none_for_non_image_icon():
+    site = SimpleNamespace(
+        title="Example",
+        chat_color_theme=None,
+        chat_color_theme_inverted=False,
+        icon_type=IconType.EMOJI,
+        icon="file-id",
+        icon_background=None,
+        description=None,
+        copyright=None,
+        privacy_policy=None,
+        custom_disclaimer=None,
+        default_language="en-US",
+        show_workflow_steps=True,
+        use_icon_as_answer_icon=False,
+    )
+
+    with patch("controllers.common.fields.file_helpers.get_signed_file_url") as mock_helper:
+        model = Site.model_validate(site)
+
+    assert model.icon_url is None
+    mock_helper.assert_not_called()
diff --git a/api/tests/unit_tests/controllers/console/__init__.py b/api/tests/unit_tests/controllers/console/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py
new file mode 100644
index 0000000000..40eb59a8f4
--- /dev/null
+++ b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py
@@ -0,0 +1,285 @@
+from __future__ import annotations
+
+import builtins
+import sys
+from datetime import datetime
+from importlib import util
+from pathlib import Path
+from types import ModuleType, SimpleNamespace
+from typing import Any
+
+import pytest
+from flask.views import MethodView
+
+# kombu references MethodView as a global when importing celery/kombu pools.
+if not hasattr(builtins, "MethodView"):
+    builtins.MethodView = MethodView  # type: ignore[attr-defined]
+
+
+def _load_app_module():
+    module_name = "controllers.console.app.app"
+    if module_name in sys.modules:
+        return sys.modules[module_name]
+
+    root = Path(__file__).resolve().parents[5]
+    module_path = root / "controllers" / "console" / "app" / "app.py"
+
+    class _StubNamespace:
+        def __init__(self):
+            self.models: dict[str, Any] = {}
+            self.payload = None
+
+        def schema_model(self, name, schema):
+            self.models[name] = schema
+
+        def _decorator(self, obj):
+            return obj
+
+        def doc(self, *args, **kwargs):
+            return self._decorator
+
+        def expect(self, *args, **kwargs):
+            return self._decorator
+
+        def response(self, *args, **kwargs):
+            return self._decorator
+
+        def route(self, *args, **kwargs):
+            def decorator(obj):
+                return obj
+
+            return decorator
+
+    stub_namespace = _StubNamespace()
+
+    original_console = sys.modules.get("controllers.console")
+    original_app_pkg = sys.modules.get("controllers.console.app")
+    stubbed_modules: list[tuple[str, ModuleType | None]] = []
+
+    console_module = ModuleType("controllers.console")
+    console_module.__path__ = [str(root / "controllers" / "console")]
+    console_module.console_ns = stub_namespace
+    console_module.api = None
+    console_module.bp = None
+    sys.modules["controllers.console"] = console_module
+
+    app_package = ModuleType("controllers.console.app")
+    app_package.__path__ = [str(root / "controllers" / "console" / "app")]
+    sys.modules["controllers.console.app"] = app_package
+    console_module.app = app_package
+
+    def _stub_module(name: str, attrs: dict[str, Any]):
+        original = sys.modules.get(name)
+        module = ModuleType(name)
+        for key, value in attrs.items():
+            setattr(module, key, value)
+        sys.modules[name] = module
+        stubbed_modules.append((name, original))
+
+    class _OpsTraceManager:
+        @staticmethod
+        def get_app_tracing_config(app_id: str) -> dict[str, Any]:
+            return {}
+
+        @staticmethod
+        def update_app_tracing_config(app_id: str, **kwargs) -> None:
+            return None
+
+    _stub_module(
+        "core.ops.ops_trace_manager",
+        {
+            "OpsTraceManager": _OpsTraceManager,
+            "TraceQueueManager": object,
+            "TraceTask": object,
+        },
+    )
+
+    spec = util.spec_from_file_location(module_name, module_path)
+    module = util.module_from_spec(spec)
+    sys.modules[module_name] = module
+
+    try:
+        assert spec.loader is not None
+        spec.loader.exec_module(module)
+    finally:
+        for name, original in reversed(stubbed_modules):
+            if original is not None:
+                sys.modules[name] = original
+            else:
+                sys.modules.pop(name, None)
+        if original_console is not None:
+            sys.modules["controllers.console"] = original_console
+        else:
+            sys.modules.pop("controllers.console", None)
+        if original_app_pkg is not None:
+            sys.modules["controllers.console.app"] = original_app_pkg
+        else:
+            sys.modules.pop("controllers.console.app", None)
+
+    return module
+
+
+_app_module = _load_app_module()
+AppDetailWithSite = _app_module.AppDetailWithSite
+AppPagination = _app_module.AppPagination
+AppPartial = _app_module.AppPartial
+
+
+@pytest.fixture(autouse=True)
+def patch_signed_url(monkeypatch):
+    """Ensure icon URL generation uses a deterministic helper for tests."""
+
+    def _fake_signed_url(key: str | None) -> str | None:
+        if not key:
+            return None
+        return f"signed:{key}"
+
+    monkeypatch.setattr(_app_module.file_helpers, "get_signed_file_url", _fake_signed_url)
+
+
+def _ts(hour: int = 12) -> datetime:
+    return datetime(2024, 1, 1, hour, 0, 0)
+
+
+def _dummy_model_config():
+    return SimpleNamespace(
+        model_dict={"provider": "openai", "name": "gpt-4o"},
+        pre_prompt="hello",
+        created_by="config-author",
+        created_at=_ts(9),
+        updated_by="config-editor",
+        updated_at=_ts(10),
+    )
+
+
+def _dummy_workflow():
+    return SimpleNamespace(
+        id="wf-1",
+        created_by="workflow-author",
+        created_at=_ts(8),
+        updated_by="workflow-editor",
+        updated_at=_ts(9),
+    )
+
+
+def test_app_partial_serialization_uses_aliases():
+    created_at = _ts()
+    app_obj = SimpleNamespace(
+        id="app-1",
+        name="My App",
+        desc_or_prompt="Prompt snippet",
+        mode_compatible_with_agent="chat",
+        icon_type="image",
+        icon="icon-key",
+        icon_background="#fff",
+        app_model_config=_dummy_model_config(),
+        workflow=_dummy_workflow(),
+        created_by="creator",
+        created_at=created_at,
+        updated_by="editor",
+        updated_at=created_at,
+        tags=[SimpleNamespace(id="tag-1", name="Utilities", type="app")],
+        access_mode="private",
+        create_user_name="Creator",
+        author_name="Author",
+        has_draft_trigger=True,
+    )
+
+    serialized = AppPartial.model_validate(app_obj, from_attributes=True).model_dump(mode="json")
+
+    assert serialized["description"] == "Prompt snippet"
+    assert serialized["mode"] == "chat"
+    assert serialized["icon_url"] == "signed:icon-key"
+    assert serialized["created_at"] == int(created_at.timestamp())
+    assert serialized["updated_at"] == int(created_at.timestamp())
+    assert serialized["model_config"]["model"] == {"provider": "openai", "name": "gpt-4o"}
+    assert serialized["workflow"]["id"] == "wf-1"
+    assert serialized["tags"][0]["name"] == "Utilities"
+
+
+def test_app_detail_with_site_includes_nested_serialization():
+    timestamp = _ts(14)
+    site = SimpleNamespace(
+        code="site-code",
+        title="Public Site",
+        icon_type="image",
+        icon="site-icon",
+        created_at=timestamp,
+        updated_at=timestamp,
+    )
+    app_obj = SimpleNamespace(
+        id="app-2",
+        name="Detailed App",
+        description="Desc",
+        mode_compatible_with_agent="advanced-chat",
+        icon_type="image",
+        icon="detail-icon",
+        icon_background="#123456",
+        enable_site=True,
+        enable_api=True,
+        app_model_config={
+            "opening_statement": "hi",
+            "model": {"provider": "openai", "name": "gpt-4o"},
+            "retriever_resource": {"enabled": True},
+        },
+        workflow=_dummy_workflow(),
+        tracing={"enabled": True},
+        use_icon_as_answer_icon=True,
+        created_by="creator",
+        created_at=timestamp,
+        updated_by="editor",
+        updated_at=timestamp,
+        access_mode="public",
+        tags=[SimpleNamespace(id="tag-2", name="Prod", type="app")],
+        api_base_url="https://api.example.com/v1",
+        max_active_requests=5,
+        deleted_tools=[{"type": "api", "tool_name": "search", "provider_id": "prov"}],
+        site=site,
+    )
+
+    serialized = AppDetailWithSite.model_validate(app_obj, from_attributes=True).model_dump(mode="json")
+
+    assert serialized["icon_url"] == "signed:detail-icon"
+    assert serialized["model_config"]["retriever_resource"] == {"enabled": True}
+    assert serialized["deleted_tools"][0]["tool_name"] == "search"
+    assert serialized["site"]["icon_url"] == "signed:site-icon"
+    assert serialized["site"]["created_at"] == int(timestamp.timestamp())
+
+
+def test_app_pagination_aliases_per_page_and_has_next():
+    item_one = SimpleNamespace(
+        id="app-10",
+        name="Paginated One",
+        desc_or_prompt="Summary",
+        mode_compatible_with_agent="chat",
+        icon_type="image",
+        icon="first-icon",
+        created_at=_ts(15),
+        updated_at=_ts(15),
+    )
+    item_two = SimpleNamespace(
+        id="app-11",
+        name="Paginated Two",
+        desc_or_prompt="Summary",
+        mode_compatible_with_agent="agent-chat",
+        icon_type="emoji",
+        icon="🙂",
+        created_at=_ts(16),
+        updated_at=_ts(16),
+    )
+    pagination = SimpleNamespace(
+        page=2,
+        per_page=10,
+        total=50,
+        has_next=True,
+        items=[item_one, item_two],
+    )
+
+    serialized = AppPagination.model_validate(pagination, from_attributes=True).model_dump(mode="json")
+
+    assert serialized["page"] == 2
+    assert serialized["limit"] == 10
+    assert serialized["has_more"] is True
+    assert len(serialized["data"]) == 2
+    assert serialized["data"][0]["icon_url"] == "signed:first-icon"
+    assert serialized["data"][1]["icon_url"] is None
diff --git a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py
index da21e0e358..d3e864a75a 100644
--- a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py
+++ b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py
@@ -40,7 +40,7 @@ class TestActivateCheckApi:
             "tenant": tenant,
         }
 
-    @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
     def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation):
         """
         Test checking valid invitation token.
@@ -66,7 +66,7 @@ class TestActivateCheckApi:
         assert response["data"]["workspace_id"] == "workspace-123"
         assert response["data"]["email"] == "invitee@example.com"
 
-    @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
     def test_check_invalid_invitation_token(self, mock_get_invitation, app):
         """
         Test checking invalid invitation token.
@@ -88,7 +88,7 @@ class TestActivateCheckApi:
         # Assert
         assert response["is_valid"] is False
 
-    @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
     def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation):
         """
         Test checking token without workspace ID.
@@ -109,7 +109,7 @@ class TestActivateCheckApi:
         assert response["is_valid"] is True
         mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token")
 
-    @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
     def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation):
         """
         Test checking token without email parameter.
@@ -130,6 +130,20 @@ class TestActivateCheckApi:
         assert response["is_valid"] is True
         mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token")
 
+    @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
+    def test_check_token_normalizes_email_to_lowercase(self, mock_get_invitation, app, mock_invitation):
+        """Ensure mixed-case emails pass validation; the case-fallback lookup receives the submitted casing."""
+        mock_get_invitation.return_value = mock_invitation
+
+        with app.test_request_context(
+            "/activate/check?workspace_id=workspace-123&email=Invitee@Example.com&token=valid_token"
+        ):
+            api = ActivateCheckApi()
+            response = api.get()
+
+        assert response["is_valid"] is True
+        mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
+
 
 class TestActivateApi:
     """Test cases for account activation endpoint."""
@@ -212,7 +226,7 @@ class TestActivateApi:
         mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
         mock_db.session.commit.assert_called_once()
 
-    @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
     def test_activation_with_invalid_token(self, mock_get_invitation, app):
         """
         Test account activation with invalid token.
@@ -241,7 +255,7 @@ class TestActivateApi:
         with pytest.raises(AlreadyActivateError):
             api.post()
 
-    @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
     @patch("controllers.console.auth.activate.RegisterService.revoke_token")
     @patch("controllers.console.auth.activate.db")
     def test_activation_sets_interface_theme(
@@ -290,7 +304,7 @@ class TestActivateApi:
             ("es-ES", "Europe/Madrid"),
         ],
     )
-    @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
    @patch("controllers.console.auth.activate.RegisterService.revoke_token")
     @patch("controllers.console.auth.activate.db")
     def test_activation_with_different_locales(
@@ -336,7 +350,7 @@ class TestActivateApi:
         assert mock_account.interface_language == language
         assert mock_account.timezone == timezone
 
-    @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
     @patch("controllers.console.auth.activate.RegisterService.revoke_token")
     @patch("controllers.console.auth.activate.db")
     def test_activation_returns_success_response(
@@ -376,7 +390,7 @@ class TestActivateApi:
         # Assert
         assert response == {"result": "success"}
 
-    @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
     @patch("controllers.console.auth.activate.RegisterService.revoke_token")
     @patch("controllers.console.auth.activate.db")
     def test_activation_without_workspace_id(
@@ -415,3 +429,37 @@ class TestActivateApi:
         # Assert
         assert response["result"] == "success"
         mock_revoke_token.assert_called_once_with(None, "invitee@example.com", "valid_token")
+
+    @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
+    @patch("controllers.console.auth.activate.RegisterService.revoke_token")
+    @patch("controllers.console.auth.activate.db")
+    def test_activation_normalizes_email_before_lookup(
+        self,
+        mock_db,
+        mock_revoke_token,
+        mock_get_invitation,
+        app,
+        mock_invitation,
+        mock_account,
+    ):
+        """Ensure revocation uses the normalized lowercase email while the case-fallback lookup receives the submitted casing."""
+        mock_get_invitation.return_value = mock_invitation
+
+        with app.test_request_context(
+            "/activate",
+            method="POST",
+            json={
+                "workspace_id": "workspace-123",
+                "email": "Invitee@Example.com",
+                "token": "valid_token",
+                "name": "John Doe",
+                "interface_language": "en-US",
+                "timezone": "UTC",
+            },
+        ):
+            api = ActivateApi()
+            response = api.post()
+
+        assert response["result"] == "success"
+        mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
+        mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
diff --git a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py
index eb21920117..cb4fe40944 100644
--- a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py
+++ b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py
@@ -34,7 +34,7 @@ class TestAuthenticationSecurity:
     @patch("controllers.console.auth.login.AccountService.authenticate")
     @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
     @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
-    @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
     def test_login_invalid_email_with_registration_allowed(
         self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
     ):
@@ -67,7 +67,7 @@ class TestAuthenticationSecurity:
     @patch("controllers.console.auth.login.AccountService.authenticate")
     @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
     @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
-    @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
     def test_login_wrong_password_returns_error(
         self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db
     ):
@@ -100,7 +100,7 @@ class TestAuthenticationSecurity:
     @patch("controllers.console.auth.login.AccountService.authenticate")
     @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
     @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
-    @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
     def test_login_invalid_email_with_registration_disabled(
         self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
     ):
diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_register.py b/api/tests/unit_tests/controllers/console/auth/test_email_register.py
new file mode 100644
index 0000000000..724c80f18c
--- /dev/null
+++ b/api/tests/unit_tests/controllers/console/auth/test_email_register.py
@@ -0,0 +1,177 @@
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+import pytest
+from flask import Flask
+
+from controllers.console.auth.email_register import (
+    EmailRegisterCheckApi,
+    EmailRegisterResetApi,
+    EmailRegisterSendEmailApi,
+)
+from services.account_service import AccountService
+
+
+@pytest.fixture
+def app():
+    flask_app = Flask(__name__)
+    flask_app.config["TESTING"] = True
+    return flask_app
+
+
+class TestEmailRegisterSendEmailApi:
+    @patch("controllers.console.auth.email_register.Session")
+    @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
+    @patch("controllers.console.auth.email_register.AccountService.send_email_register_email")
+    @patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze")
+    @patch("controllers.console.auth.email_register.AccountService.is_email_send_ip_limit", return_value=False)
+    @patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
+    def test_send_email_normalizes_and_falls_back(
+        self,
+        mock_extract_ip,
+        mock_is_email_send_ip_limit,
+        mock_is_freeze,
+        mock_send_mail,
+        mock_get_account,
+        mock_session_cls,
+        app,
+    ):
+        mock_send_mail.return_value = "token-123"
+        mock_is_freeze.return_value = False
+        mock_account = MagicMock()
+
+        mock_session = MagicMock()
+        mock_session_cls.return_value.__enter__.return_value = mock_session
+        mock_get_account.return_value = mock_account
+
+        feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
+        with (
+            patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
+            patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)),
+            patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
+            patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
+        ):
+            with app.test_request_context(
+                "/email-register/send-email",
+                method="POST",
+                json={"email": "Invitee@Example.com", "language": "en-US"},
+            ):
+                response = EmailRegisterSendEmailApi().post()
+
+        assert response == {"result": "success", "data": "token-123"}
+        mock_is_freeze.assert_called_once_with("invitee@example.com")
+        mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US")
+        mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
+        mock_extract_ip.assert_called_once()
+        mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1")
+
+
+class TestEmailRegisterCheckApi:
+    @patch("controllers.console.auth.email_register.AccountService.reset_email_register_error_rate_limit")
+    @patch("controllers.console.auth.email_register.AccountService.generate_email_register_token")
+    @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
+    @patch("controllers.console.auth.email_register.AccountService.add_email_register_error_rate_limit")
+    @patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
+    @patch("controllers.console.auth.email_register.AccountService.is_email_register_error_rate_limit")
+    def test_validity_normalizes_email_before_checks(
+        self,
+        mock_rate_limit_check,
+        mock_get_data,
+        mock_add_rate,
+        mock_revoke,
+        mock_generate_token,
+        mock_reset_rate,
+        app,
+    ):
+        mock_rate_limit_check.return_value = False
+        mock_get_data.return_value = {"email": "User@Example.com", "code": "4321"}
+        mock_generate_token.return_value = (None, "new-token")
+
+        feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
+        with (
+            patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
+            patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
+            patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
+        ):
+            with app.test_request_context(
+                "/email-register/validity",
+                method="POST",
+                json={"email": "User@Example.com", "code": "4321", "token": "token-123"},
+            ):
+                response = EmailRegisterCheckApi().post()
+
+        assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
+        mock_rate_limit_check.assert_called_once_with("user@example.com")
+        mock_generate_token.assert_called_once_with(
+            "user@example.com", code="4321", additional_data={"phase": "register"}
+        )
+        mock_reset_rate.assert_called_once_with("user@example.com")
+        mock_add_rate.assert_not_called()
+        mock_revoke.assert_called_once_with("token-123")
+
+
+class TestEmailRegisterResetApi:
+    @patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
+    @patch("controllers.console.auth.email_register.AccountService.login")
+    @patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
+    @patch("controllers.console.auth.email_register.Session")
+    @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
+    @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
+    @patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
+    @patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
+    def test_reset_creates_account_with_normalized_email(
+        self,
+        mock_extract_ip,
+        mock_get_data,
+        mock_revoke_token,
+        mock_get_account,
+        mock_session_cls,
+        mock_create_account,
+        mock_login,
+        mock_reset_login_rate,
+        app,
+    ):
+        mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
+        mock_create_account.return_value = MagicMock()
+        token_pair = MagicMock()
+        token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
+        mock_login.return_value = token_pair
+
+        mock_session = MagicMock()
+        mock_session_cls.return_value.__enter__.return_value = mock_session
+        mock_get_account.return_value = None
+
+        feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
+        with (
+            patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
+            patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
+            patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
+        ):
+            with app.test_request_context(
+                "/email-register",
+                method="POST",
+                json={"token": "token-123", "new_password": "ValidPass123!", "password_confirm": "ValidPass123!"},
+            ):
+                response = EmailRegisterResetApi().post()
+
+        assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}}
+        mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!")
+        mock_reset_login_rate.assert_called_once_with("invitee@example.com")
+        mock_revoke_token.assert_called_once_with("token-123")
+        mock_extract_ip.assert_called_once()
+        mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
+
+
+def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
+    mock_session = MagicMock()
+    first_query = MagicMock()
+    first_query.scalar_one_or_none.return_value = None
+    expected_account = MagicMock()
+    second_query = MagicMock()
+    second_query.scalar_one_or_none.return_value = expected_account
+    mock_session.execute.side_effect = [first_query, second_query]
+
+    account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
+
+    assert account is expected_account
+    assert mock_session.execute.call_count == 2
diff --git a/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py b/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py
new file mode 100644
index 0000000000..8403777dc9
--- /dev/null
+++ b/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py
@@ -0,0 +1,176 @@
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+import pytest
+from flask import Flask
+
+from controllers.console.auth.forgot_password import (
+    ForgotPasswordCheckApi,
+    ForgotPasswordResetApi,
+    ForgotPasswordSendEmailApi,
+)
+from services.account_service import AccountService
+
+
+@pytest.fixture
+def app():
+    flask_app = Flask(__name__)
+    flask_app.config["TESTING"] = True
+    return flask_app
+
+
+class TestForgotPasswordSendEmailApi:
+    @patch("controllers.console.auth.forgot_password.Session")
+    @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
+    @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
+    @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
+    @patch("controllers.console.auth.forgot_password.extract_remote_ip", return_value="127.0.0.1")
+    def test_send_normalizes_email(
+        self,
+        mock_extract_ip,
+        mock_is_ip_limit,
+        mock_send_email,
+        mock_get_account,
+        mock_session_cls,
+        app,
+    ):
+        mock_account = MagicMock()
+        mock_get_account.return_value = mock_account
+        mock_send_email.return_value = "token-123"
+        mock_session = MagicMock()
+        mock_session_cls.return_value.__enter__.return_value = mock_session
+
+        wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
+        controller_features = SimpleNamespace(is_allow_register=True)
+        with (
+            patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
+            patch(
+                "controllers.console.auth.forgot_password.FeatureService.get_system_features",
+                return_value=controller_features,
+            ),
+            patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
+            patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
+        ):
+            with app.test_request_context(
+                "/forgot-password",
+                method="POST",
+                json={"email": "User@Example.com", "language": "zh-Hans"},
+            ):
+                response = ForgotPasswordSendEmailApi().post()
+
+        assert response == {"result": "success", "data": "token-123"}
+        mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
+        mock_send_email.assert_called_once_with(
+            account=mock_account,
+            email="user@example.com",
+            language="zh-Hans",
+            is_allow_register=True,
+        )
+        mock_is_ip_limit.assert_called_once_with("127.0.0.1")
+        mock_extract_ip.assert_called_once()
+
+
+class TestForgotPasswordCheckApi:
+    @patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
+    @patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token")
+    @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
+    @patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
+    @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
+    @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
+    def test_check_normalizes_email(
+        self,
+        mock_rate_limit_check,
+        mock_get_data,
+        mock_add_rate,
+        mock_revoke_token,
+        mock_generate_token,
+        mock_reset_rate,
+        app,
+    ):
+        mock_rate_limit_check.return_value = False
+        mock_get_data.return_value = {"email": "Admin@Example.com", "code": "4321"}
+        mock_generate_token.return_value = (None, "new-token")
+
+        wraps_features = SimpleNamespace(enable_email_password_login=True)
+        with (
+            patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
+            patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
+        ):
+            with app.test_request_context(
+                "/forgot-password/validity",
+                method="POST",
+                json={"email": "ADMIN@Example.com", "code": "4321", "token": "token-123"},
+            ):
+                response = ForgotPasswordCheckApi().post()
+
+        assert response == {"is_valid": True, "email": "admin@example.com", "token": "new-token"}
+        mock_rate_limit_check.assert_called_once_with("admin@example.com")
+        mock_generate_token.assert_called_once_with(
+            "Admin@Example.com",
+            code="4321",
+            additional_data={"phase": "reset"},
+        )
+        mock_reset_rate.assert_called_once_with("admin@example.com")
+        mock_add_rate.assert_not_called()
+        mock_revoke_token.assert_called_once_with("token-123")
+
+
+class TestForgotPasswordResetApi:
+    @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
+    @patch("controllers.console.auth.forgot_password.Session")
+    @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
+    @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
+    @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
+    def test_reset_fetches_account_with_original_email(
+        self,
+        mock_get_reset_data,
+        mock_revoke_token,
+        mock_get_account,
+        mock_session_cls,
+        mock_update_account,
+        app,
+    ):
+        mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
+        mock_account = MagicMock()
+        mock_get_account.return_value = mock_account
+
+        mock_session = MagicMock()
+        mock_session_cls.return_value.__enter__.return_value = mock_session
+
+        wraps_features = SimpleNamespace(enable_email_password_login=True)
+        with (
+            patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
+            patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
+            patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
+        ):
+            with app.test_request_context(
+                "/forgot-password/resets",
+                method="POST",
+                json={
+                    "token": "token-123",
+                    "new_password": "ValidPass123!",
+                    "password_confirm": "ValidPass123!",
+                },
+            ):
+                response = ForgotPasswordResetApi().post()
+
+        assert response == {"result": "success"}
+        mock_get_reset_data.assert_called_once_with("token-123")
+        mock_revoke_token.assert_called_once_with("token-123")
+        mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
+        mock_update_account.assert_called_once()
+
+
+def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
+    mock_session = MagicMock()
+    first_query = MagicMock()
+    first_query.scalar_one_or_none.return_value = None
+    expected_account = MagicMock()
+    second_query = MagicMock()
+    second_query.scalar_one_or_none.return_value = expected_account
+    mock_session.execute.side_effect = [first_query, second_query]
+
+    account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
+
+    assert account is expected_account
+    assert mock_session.execute.call_count == 2
diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py
index 3a2cf7bad7..560971206f 100644
--- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py
+++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py
@@ -76,7 +76,7 @@ class TestLoginApi:
     @patch("controllers.console.wraps.db")
     @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
     @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
-    @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
     @patch("controllers.console.auth.login.AccountService.authenticate")
     @patch("controllers.console.auth.login.TenantService.get_join_tenants")
     @patch("controllers.console.auth.login.AccountService.login")
@@ -120,7 +120,7 @@ class TestLoginApi:
         response = login_api.post()
 
         # Assert
-        mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!")
+        mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None)
         mock_login.assert_called_once()
         mock_reset_rate_limit.assert_called_once_with("test@example.com")
         assert response.json["result"] == "success"
@@ -128,7 +128,7 @@ class TestLoginApi:
     @patch("controllers.console.wraps.db")
     @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
     @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
-    @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
     @patch("controllers.console.auth.login.AccountService.authenticate")
     @patch("controllers.console.auth.login.TenantService.get_join_tenants")
     @patch("controllers.console.auth.login.AccountService.login")
@@ -182,7 +182,7 @@ class TestLoginApi:
     @patch("controllers.console.wraps.db")
     @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
     @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
-    @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
     def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
         """
         Test login rejection when rate limit is exceeded.
@@ -230,7 +230,7 @@ class TestLoginApi:
     @patch("controllers.console.wraps.db")
     @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
     @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
-    @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
     @patch("controllers.console.auth.login.AccountService.authenticate")
     @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
     def test_login_fails_with_invalid_credentials(
@@ -269,7 +269,7 @@ class TestLoginApi:
     @patch("controllers.console.wraps.db")
     @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
     @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
-    @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
     @patch("controllers.console.auth.login.AccountService.authenticate")
     def test_login_fails_for_banned_account(
         self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app
@@ -298,7 +298,7 @@ class TestLoginApi:
     @patch("controllers.console.wraps.db")
     @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
     @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
-    @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
     @patch("controllers.console.auth.login.AccountService.authenticate")
     @patch("controllers.console.auth.login.TenantService.get_join_tenants")
     @patch("controllers.console.auth.login.FeatureService.get_system_features")
@@ -343,7 +343,7 @@ class TestLoginApi:
     @patch("controllers.console.wraps.db")
     @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
     @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
-    @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
+    @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
     def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
         """
         Test login failure when invitation email doesn't match login email.
@@ -371,6 +371,52 @@ class TestLoginApi:
             with pytest.raises(InvalidEmailError):
                 login_api.post()

+    @patch("controllers.console.wraps.db")
+    @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
+    @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
+    @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
+    @patch("controllers.console.auth.login.AccountService.authenticate")
+    @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
+    @patch("controllers.console.auth.login.TenantService.get_join_tenants")
+    @patch("controllers.console.auth.login.AccountService.login")
+    @patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
+    def test_login_retries_with_lowercase_email(
+        self,
+        mock_reset_rate_limit,
+        mock_login_service,
+        mock_get_tenants,
+        mock_add_rate_limit,
+        mock_authenticate,
+        mock_get_invitation,
+        mock_is_rate_limit,
+        mock_db,
+        app,
+        mock_account,
+        mock_token_pair,
+    ):
+        """Test that login retries with lowercase email when uppercase lookup fails."""
+        mock_db.session.query.return_value.first.return_value = MagicMock()
+        mock_is_rate_limit.return_value = False
+        mock_get_invitation.return_value = None
+        mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account]
+        mock_get_tenants.return_value = [MagicMock()]
+        mock_login_service.return_value = mock_token_pair
+
+        with app.test_request_context(
+            "/login",
+            method="POST",
+            json={"email": "Upper@Example.com", "password": encode_password("ValidPass123!")},
+        ):
+            response = LoginApi().post()
+
+        assert response.json["result"] == "success"
+        assert mock_authenticate.call_args_list == [
+            (("Upper@Example.com", "ValidPass123!", None), {}),
+            (("upper@example.com", "ValidPass123!", None), {}),
+        ]
+        mock_add_rate_limit.assert_not_called()
+        mock_reset_rate_limit.assert_called_once_with("upper@example.com")
+

 class TestLogoutApi:
     """Test cases for the LogoutApi endpoint."""
diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_oauth.py
index 399caf8c4d..6345c2ab23 100644
--- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py
+++ b/api/tests/unit_tests/controllers/console/auth/test_oauth.py
@@ -12,6 +12,7 @@ from controllers.console.auth.oauth import (
 )
 from libs.oauth import OAuthUserInfo
 from models.account import AccountStatus
+from services.account_service import AccountService
 from services.errors.account import AccountRegisterError


@@ -171,7 +172,7 @@
     ):
         mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
         mock_get_providers.return_value = {"github": oauth_setup["provider"]}
-        mock_generate_account.return_value = oauth_setup["account"]
+        mock_generate_account.return_value = (oauth_setup["account"], True)
         mock_account_service.login.return_value = oauth_setup["token_pair"]

         with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
@@ -179,7 +180,7 @@

         oauth_setup["provider"].get_access_token.assert_called_once_with("test_code")
         oauth_setup["provider"].get_user_info.assert_called_once_with("access_token")
-        mock_redirect.assert_called_once_with("http://localhost:3000")
+        mock_redirect.assert_called_once_with("http://localhost:3000?oauth_new_user=true")

     @pytest.mark.parametrize(
         ("exception", "expected_error"),
@@ -215,6 +216,34 @@ class TestOAuthCallback:
         assert status_code == 400
         assert response["error"] == expected_error

+    @patch("controllers.console.auth.oauth.dify_config")
+    @patch("controllers.console.auth.oauth.get_oauth_providers")
+    @patch("controllers.console.auth.oauth.RegisterService")
+    @patch("controllers.console.auth.oauth.redirect")
+    def test_invitation_comparison_is_case_insensitive(
+        self,
+        mock_redirect,
+        mock_register_service,
+        mock_get_providers,
+        mock_config,
+        resource,
+        app,
+        oauth_setup,
+    ):
+        mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
+        oauth_setup["provider"].get_user_info.return_value = OAuthUserInfo(
+            id="123", name="Test User", email="User@Example.com"
+        )
+        mock_get_providers.return_value = {"github": oauth_setup["provider"]}
+        mock_register_service.is_valid_invite_token.return_value = True
+        mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"}
+
+        with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"):
+            resource.get("github")
+
+        mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123")
+        mock_redirect.assert_called_once_with("http://localhost:3000/signin/invite-settings?invite_token=invite123")
+
     @pytest.mark.parametrize(
         ("account_status", "expected_redirect"),
         [
@@ -223,7 +252,7 @@ class TestOAuthCallback:
             # This documents actual behavior. See test_defensive_check_for_closed_account_status for details
             (
                 AccountStatus.CLOSED.value,
-                "http://localhost:3000",
+                "http://localhost:3000?oauth_new_user=false",
             ),
         ],
     )
@@ -260,7 +289,7 @@ class TestOAuthCallback:
         account = MagicMock()
         account.status = account_status
         account.id = "123"
-        mock_generate_account.return_value = account
+        mock_generate_account.return_value = (account, False)

         # Mock login for CLOSED status
         mock_token_pair = MagicMock()
@@ -296,7 +325,7 @@ class TestOAuthCallback:
         mock_account = MagicMock()
         mock_account.status = AccountStatus.PENDING

-        mock_generate_account.return_value = mock_account
+        mock_generate_account.return_value = (mock_account, False)

         mock_token_pair = MagicMock()
         mock_token_pair.access_token = "jwt_access_token"
@@ -360,7 +389,7 @@ class TestOAuthCallback:
         closed_account.status = AccountStatus.CLOSED
         closed_account.id = "123"
         closed_account.name = "Closed Account"
-        mock_generate_account.return_value = closed_account
+        mock_generate_account.return_value = (closed_account, False)

         # Mock successful login (current behavior)
         mock_token_pair = MagicMock()
@@ -374,7 +403,7 @@ class TestOAuthCallback:
         resource.get("github")

         # Verify current behavior: login succeeds (this is NOT ideal)
-        mock_redirect.assert_called_once_with("http://localhost:3000")
+        mock_redirect.assert_called_once_with("http://localhost:3000?oauth_new_user=false")
         mock_account_service.login.assert_called_once()

         # Document expected behavior in comments:
@@ -395,12 +424,12 @@ class TestAccountGeneration:
         account.name = "Test User"
         return account

-    @patch("controllers.console.auth.oauth.db")
-    @patch("controllers.console.auth.oauth.Account")
+    @patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback")
     @patch("controllers.console.auth.oauth.Session")
-    @patch("controllers.console.auth.oauth.select")
+    @patch("controllers.console.auth.oauth.Account")
+    @patch("controllers.console.auth.oauth.db")
     def test_should_get_account_by_openid_or_email(
-        self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account
+        self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account
     ):
         # Mock db.engine for Session creation
         mock_db.engine = MagicMock()
@@ -410,15 +439,31 @@ class TestAccountGeneration:
         result = _get_account_by_openid_or_email("github", user_info)
         assert result == mock_account
         mock_account_model.get_by_openid.assert_called_once_with("github", "123")
+        mock_get_account.assert_not_called()

-        # Test fallback to email
+        # Test fallback to email lookup
         mock_account_model.get_by_openid.return_value = None
         mock_session_instance = MagicMock()
-        mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
         mock_session.return_value.__enter__.return_value = mock_session_instance
+        mock_get_account.return_value = mock_account

         result = _get_account_by_openid_or_email("github", user_info)
         assert result == mock_account
+        mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance)
+
+    def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self):
+        mock_session = MagicMock()
+        first_result = MagicMock()
+        first_result.scalar_one_or_none.return_value = None
+        expected_account = MagicMock()
+        second_result = MagicMock()
+        second_result.scalar_one_or_none.return_value = expected_account
+        mock_session.execute.side_effect = [first_result, second_result]
+
+        result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
+
+        assert result == expected_account
+        assert mock_session.execute.call_count == 2

     @pytest.mark.parametrize(
         ("allow_register", "existing_account", "should_create"),
@@ -458,13 +503,43 @@
             with pytest.raises(AccountRegisterError):
                 _generate_account("github", user_info)
         else:
-            result = _generate_account("github", user_info)
+            result, oauth_new_user = _generate_account("github", user_info)
             assert result == mock_account
+            assert oauth_new_user == should_create

             if should_create:
                 mock_register_service.register.assert_called_once_with(
                     email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
                 )
+            else:
+                mock_register_service.register.assert_not_called()
+
+    @patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
+    @patch("controllers.console.auth.oauth.FeatureService")
+    @patch("controllers.console.auth.oauth.RegisterService")
+    @patch("controllers.console.auth.oauth.AccountService")
+    @patch("controllers.console.auth.oauth.TenantService")
+    @patch("controllers.console.auth.oauth.db")
+    def test_should_register_with_lowercase_email(
+        self,
+        mock_db,
+        mock_tenant_service,
+        mock_account_service,
+        mock_register_service,
+        mock_feature_service,
+        mock_get_account,
+        app,
+    ):
+        user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com")
+        mock_feature_service.get_system_features.return_value.is_allow_register = True
+        mock_register_service.register.return_value = MagicMock()
+
+        with app.test_request_context(headers={"Accept-Language": "en-US"}):
+            _generate_account("github", user_info)
+
+        mock_register_service.register.assert_called_once_with(
+            email="upper@example.com", name="Test User", password=None, open_id="123", provider="github"
+        )

     @patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
     @patch("controllers.console.auth.oauth.TenantService")
@@ -490,9 +565,10 @@
         mock_tenant_service.create_tenant.return_value = mock_new_tenant

         with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
-            result = _generate_account("github", user_info)
+            result, oauth_new_user = _generate_account("github", user_info)

         assert result == mock_account
+        assert oauth_new_user is False
         mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace")
         mock_tenant_service.create_tenant_member.assert_called_once_with(
             mock_new_tenant, mock_account, role="owner"
diff --git a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py b/api/tests/unit_tests/controllers/console/auth/test_password_reset.py
index f584952a00..9488cf528e 100644
--- a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py
+++ b/api/tests/unit_tests/controllers/console/auth/test_password_reset.py
@@ -28,6 +28,22 @@ from controllers.console.auth.forgot_password import (
 )
 from controllers.console.error import AccountNotFound, EmailSendIpLimitError


+@pytest.fixture(autouse=True)
+def _mock_forgot_password_session():
+    with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls:
+        mock_session = MagicMock()
+        mock_session_cls.return_value.__enter__.return_value = mock_session
+        mock_session_cls.return_value.__exit__.return_value = None
+        yield mock_session
+
+
+@pytest.fixture(autouse=True)
+def _mock_forgot_password_db():
+    with patch("controllers.console.auth.forgot_password.db") as mock_db:
+        mock_db.engine = MagicMock()
+        yield mock_db
+
+
 class TestForgotPasswordSendEmailApi:
     """Test cases for sending password reset emails."""

@@ -47,20 +63,16 @@ class TestForgotPasswordSendEmailApi:
         return account

     @patch("controllers.console.wraps.db")
-    @patch("controllers.console.auth.forgot_password.db")
     @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
-    @patch("controllers.console.auth.forgot_password.Session")
-    @patch("controllers.console.auth.forgot_password.select")
+    @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
     @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
     @patch("controllers.console.auth.forgot_password.FeatureService.get_system_features")
     def test_send_reset_email_success(
         self,
         mock_get_features,
         mock_send_email,
-        mock_select,
-        mock_session,
+        mock_get_account,
         mock_is_ip_limit,
-        mock_forgot_db,
         mock_wraps_db,
         app,
         mock_account,
@@ -75,11 +87,8 @@ class TestForgotPasswordSendEmailApi:
         """
         # Arrange
         mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
-        mock_forgot_db.engine = MagicMock()
         mock_is_ip_limit.return_value = False
-        mock_session_instance = MagicMock()
-        mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
-        mock_session.return_value.__enter__.return_value = mock_session_instance
+        mock_get_account.return_value = mock_account
         mock_send_email.return_value = "reset_token_123"
         mock_get_features.return_value.is_allow_register = True

@@ -125,20 +134,16 @@
         ],
     )
     @patch("controllers.console.wraps.db")
-    @patch("controllers.console.auth.forgot_password.db")
     @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
-    @patch("controllers.console.auth.forgot_password.Session")
-    @patch("controllers.console.auth.forgot_password.select")
+    @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
     @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
     @patch("controllers.console.auth.forgot_password.FeatureService.get_system_features")
     def test_send_reset_email_language_handling(
         self,
         mock_get_features,
         mock_send_email,
-        mock_select,
-        mock_session,
+        mock_get_account,
         mock_is_ip_limit,
-        mock_forgot_db,
         mock_wraps_db,
         app,
         mock_account,
@@ -154,11 +159,8 @@ class TestForgotPasswordSendEmailApi:
         """
         # Arrange
         mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
-        mock_forgot_db.engine = MagicMock()
         mock_is_ip_limit.return_value = False
-        mock_session_instance = MagicMock()
-        mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
-        mock_session.return_value.__enter__.return_value = mock_session_instance
+        mock_get_account.return_value = mock_account
         mock_send_email.return_value = "token"
         mock_get_features.return_value.is_allow_register = True

@@ -229,8 +231,46 @@ class TestForgotPasswordCheckApi:
         assert response["email"] == "test@example.com"
         assert response["token"] == "new_token"
         mock_revoke_token.assert_called_once_with("old_token")
+        mock_generate_token.assert_called_once_with(
+            "test@example.com", code="123456", additional_data={"phase": "reset"}
+        )
         mock_reset_rate_limit.assert_called_once_with("test@example.com")

+    @patch("controllers.console.wraps.db")
+    @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
+    @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
+    @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
+    @patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token")
+    @patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
+    def test_verify_code_preserves_token_email_case(
+        self,
+        mock_reset_rate_limit,
+        mock_generate_token,
+        mock_revoke_token,
+        mock_get_data,
+        mock_is_rate_limit,
+        mock_db,
+        app,
+    ):
+        mock_db.session.query.return_value.first.return_value = MagicMock()
+        mock_is_rate_limit.return_value = False
+        mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
+        mock_generate_token.return_value = (None, "fresh-token")
+
+        with app.test_request_context(
+            "/forgot-password/validity",
+            method="POST",
+            json={"email": "user@example.com", "code": "999888", "token": "upper_token"},
+        ):
+            response = ForgotPasswordCheckApi().post()
+
+        assert response == {"is_valid": True, "email": "user@example.com", "token": "fresh-token"}
+        mock_generate_token.assert_called_once_with(
+            "User@Example.com", code="999888", additional_data={"phase": "reset"}
+        )
+        mock_revoke_token.assert_called_once_with("upper_token")
+        mock_reset_rate_limit.assert_called_once_with("user@example.com")
+
     @patch("controllers.console.wraps.db")
     @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
     def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app):
@@ -355,20 +395,16 @@ class TestForgotPasswordResetApi:
         return account

     @patch("controllers.console.wraps.db")
-    @patch("controllers.console.auth.forgot_password.db")
     @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
     @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
-    @patch("controllers.console.auth.forgot_password.Session")
-    @patch("controllers.console.auth.forgot_password.select")
+    @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
     @patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants")
     def test_reset_password_success(
         self,
         mock_get_tenants,
-        mock_select,
-        mock_session,
+        mock_get_account,
         mock_revoke_token,
         mock_get_data,
-        mock_forgot_db,
         mock_wraps_db,
         app,
         mock_account,
@@ -383,11 +419,8 @@
         """
         # Arrange
         mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
-        mock_forgot_db.engine = MagicMock()
         mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
-        mock_session_instance = MagicMock()
-        mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
-        mock_session.return_value.__enter__.return_value = mock_session_instance
+        mock_get_account.return_value = mock_account
         mock_get_tenants.return_value = [MagicMock()]

         # Act
@@ -475,13 +508,11 @@
             api.post()

     @patch("controllers.console.wraps.db")
-    @patch("controllers.console.auth.forgot_password.db")
     @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
     @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
-    @patch("controllers.console.auth.forgot_password.Session")
-    @patch("controllers.console.auth.forgot_password.select")
+    @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
     def test_reset_password_account_not_found(
-        self, mock_select, mock_session, mock_revoke_token, mock_get_data, mock_forgot_db, mock_wraps_db, app
+        self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app
     ):
         """
         Test password reset for non-existent account.
@@ -491,11 +522,8 @@
         """
         # Arrange
         mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
-        mock_forgot_db.engine = MagicMock()
         mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"}
-        mock_session_instance = MagicMock()
-        mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
-        mock_session.return_value.__enter__.return_value = mock_session_instance
+        mock_get_account.return_value = None

         # Act & Assert
         with app.test_request_context(
diff --git a/api/tests/unit_tests/controllers/console/datasets/__init__.py b/api/tests/unit_tests/controllers/console/datasets/__init__.py
new file mode 100644
index 0000000000..a8d5622e13
--- /dev/null
+++ b/api/tests/unit_tests/controllers/console/datasets/__init__.py
@@ -0,0 +1 @@
+"""Unit tests for `controllers.console.datasets` controllers."""
diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py
new file mode 100644
index 0000000000..d5d7ee95c5
--- /dev/null
+++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py
@@ -0,0 +1,430 @@
+"""
+Unit tests for the dataset document download endpoint.
+
+These tests validate that the controller returns a signed download URL for
+upload-file documents, and rejects unsupported or missing file cases.
+"""
+
+from __future__ import annotations
+
+import importlib
+import sys
+from collections import UserDict
+from io import BytesIO
+from types import SimpleNamespace
+from typing import Any
+from zipfile import ZipFile
+
+import pytest
+from flask import Flask
+from werkzeug.exceptions import Forbidden, NotFound
+
+
+@pytest.fixture
+def app() -> Flask:
+    """Create a minimal Flask app for request-context based controller tests."""
+    app = Flask(__name__)
+    app.config["TESTING"] = True
+    return app
+
+
+@pytest.fixture
+def datasets_document_module(monkeypatch: pytest.MonkeyPatch):
+    """
+    Reload `controllers.console.datasets.datasets_document` with lightweight decorators.
+
+    We patch auth / setup / rate-limit decorators to no-ops so we can unit test the
+    controller logic without requiring the full console stack.
+    """
+
+    from controllers.console import console_ns, wraps
+    from libs import login
+
+    def _noop(func):  # type: ignore[no-untyped-def]
+        return func
+
+    # Bypass login/setup/account checks in unit tests.
+    monkeypatch.setattr(login, "login_required", _noop)
+    monkeypatch.setattr(wraps, "setup_required", _noop)
+    monkeypatch.setattr(wraps, "account_initialization_required", _noop)
+
+    # Bypass billing-related decorators used by other endpoints in this module.
+    monkeypatch.setattr(wraps, "cloud_edition_billing_resource_check", lambda *_args, **_kwargs: (lambda f: f))
+    monkeypatch.setattr(wraps, "cloud_edition_billing_rate_limit_check", lambda *_args, **_kwargs: (lambda f: f))
+
+    # Avoid Flask-RESTX route registration side effects during import.
+    def _noop_route(*_args, **_kwargs):  # type: ignore[override]
+        def _decorator(cls):
+            return cls
+
+        return _decorator
+
+    monkeypatch.setattr(console_ns, "route", _noop_route)
+
+    module_name = "controllers.console.datasets.datasets_document"
+    sys.modules.pop(module_name, None)
+    return importlib.import_module(module_name)
+
+
+def _mock_user(*, is_dataset_editor: bool = True) -> SimpleNamespace:
+    """Build a minimal user object compatible with dataset permission checks."""
+    return SimpleNamespace(is_dataset_editor=is_dataset_editor, id="user-123")
+
+
+def _mock_document(
+    *,
+    document_id: str,
+    tenant_id: str,
+    data_source_type: str,
+    upload_file_id: str | None,
+) -> SimpleNamespace:
+    """Build a minimal document object used by the controller."""
+    data_source_info_dict: dict[str, Any] | None = None
+    if upload_file_id is not None:
+        data_source_info_dict = {"upload_file_id": upload_file_id}
+    else:
+        data_source_info_dict = {}
+
+    return SimpleNamespace(
+        id=document_id,
+        tenant_id=tenant_id,
+        data_source_type=data_source_type,
+        data_source_info_dict=data_source_info_dict,
+    )
+
+
+def _wire_common_success_mocks(
+    *,
+    module,
+    monkeypatch: pytest.MonkeyPatch,
+    current_tenant_id: str,
+    document_tenant_id: str,
+    data_source_type: str,
+    upload_file_id: str | None,
+    upload_file_exists: bool,
+    signed_url: str,
+) -> None:
+    """Patch controller dependencies to create a deterministic test environment."""
+    import services.dataset_service as dataset_service_module
+
+    # Make `current_account_with_tenant()` return a known user + tenant id.
+    monkeypatch.setattr(module, "current_account_with_tenant", lambda: (_mock_user(), current_tenant_id))
+
+    # Return a dataset object and allow permission checks to pass.
+    monkeypatch.setattr(module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1"))
+    monkeypatch.setattr(module.DatasetService, "check_dataset_permission", lambda *_args, **_kwargs: None)
+
+    # Return a document that will be validated inside DocumentResource.get_document.
+    document = _mock_document(
+        document_id="doc-1",
+        tenant_id=document_tenant_id,
+        data_source_type=data_source_type,
+        upload_file_id=upload_file_id,
+    )
+    monkeypatch.setattr(module.DocumentService, "get_document", lambda *_args, **_kwargs: document)
+
+    # Mock UploadFile lookup via FileService batch helper.
+    upload_files_by_id: dict[str, Any] = {}
+    if upload_file_exists and upload_file_id is not None:
+        upload_files_by_id[str(upload_file_id)] = SimpleNamespace(id=str(upload_file_id))
+    monkeypatch.setattr(module.FileService, "get_upload_files_by_ids", lambda *_args, **_kwargs: upload_files_by_id)
+
+    # Mock signing helper so the returned URL is deterministic.
+    monkeypatch.setattr(dataset_service_module.file_helpers, "get_signed_file_url", lambda **_kwargs: signed_url)
+
+
+def _mock_send_file(obj, **kwargs):  # type: ignore[no-untyped-def]
+    """Return a lightweight representation of `send_file(...)` for unit tests."""
+
+    class _ResponseMock(UserDict):
+        def __init__(self, sent_file: object, send_file_kwargs: dict[str, object]) -> None:
+            super().__init__({"_sent_file": sent_file, "_send_file_kwargs": send_file_kwargs})
+            self._on_close: object | None = None
+
+        def call_on_close(self, func):  # type: ignore[no-untyped-def]
+            self._on_close = func
+            return func
+
+    return _ResponseMock(obj, kwargs)
+
+
+def test_batch_download_zip_returns_send_file(
+    app: Flask, datasets_document_module, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Ensure batch ZIP download returns a zip attachment via `send_file`."""
+
+    # Arrange common permission mocks.
+    monkeypatch.setattr(datasets_document_module, "current_account_with_tenant", lambda: (_mock_user(), "tenant-123"))
+    monkeypatch.setattr(
+        datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1")
+    )
+    monkeypatch.setattr(
+        datasets_document_module.DatasetService, "check_dataset_permission", lambda *_args, **_kwargs: None
+    )
+
+    # Two upload-file documents, each referencing an UploadFile.
+    doc1 = _mock_document(
+        document_id="11111111-1111-1111-1111-111111111111",
+        tenant_id="tenant-123",
+        data_source_type="upload_file",
+        upload_file_id="file-1",
+    )
+    doc2 = _mock_document(
+        document_id="22222222-2222-2222-2222-222222222222",
+        tenant_id="tenant-123",
+        data_source_type="upload_file",
+        upload_file_id="file-2",
+    )
+    monkeypatch.setattr(
+        datasets_document_module.DocumentService,
+        "get_documents_by_ids",
+        lambda *_args, **_kwargs: [doc1, doc2],
+    )
+    monkeypatch.setattr(
+        datasets_document_module.FileService,
+        "get_upload_files_by_ids",
+        lambda *_args, **_kwargs: {
+            "file-1": SimpleNamespace(id="file-1", name="a.txt", key="k1"),
+            "file-2": SimpleNamespace(id="file-2", name="b.txt", key="k2"),
+        },
+    )
+
+    # Mock storage streaming content.
+    import services.file_service as file_service_module
+
+    monkeypatch.setattr(file_service_module.storage, "load", lambda _key, stream=True: [b"hello"])
+
+    # Replace send_file used by the controller to avoid a real Flask response object.
+    monkeypatch.setattr(datasets_document_module, "send_file", _mock_send_file)
+
+    # Act
+    with app.test_request_context(
+        "/datasets/ds-1/documents/download-zip",
+        method="POST",
+        json={"document_ids": ["11111111-1111-1111-1111-111111111111", "22222222-2222-2222-2222-222222222222"]},
+    ):
+        api = datasets_document_module.DocumentBatchDownloadZipApi()
+        result = api.post(dataset_id="ds-1")
+
+    # Assert: we returned via send_file with correct mime type and attachment.
+    assert result["_send_file_kwargs"]["mimetype"] == "application/zip"
+    assert result["_send_file_kwargs"]["as_attachment"] is True
+    assert isinstance(result["_send_file_kwargs"]["download_name"], str)
+    assert result["_send_file_kwargs"]["download_name"].endswith(".zip")
+    # Ensure our cleanup hook is registered and execute it to avoid temp file leaks in unit tests.
+    assert getattr(result, "_on_close", None) is not None
+    result._on_close()  # type: ignore[attr-defined]
+
+
+def test_batch_download_zip_response_is_openable_zip(
+    app: Flask, datasets_document_module, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Ensure the real Flask `send_file` response body is a valid ZIP that can be opened."""
+
+    # Arrange: same controller mocks as the lightweight send_file test, but we keep the real `send_file`.
+    monkeypatch.setattr(datasets_document_module, "current_account_with_tenant", lambda: (_mock_user(), "tenant-123"))
+    monkeypatch.setattr(
+        datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1")
+    )
+    monkeypatch.setattr(
+        datasets_document_module.DatasetService, "check_dataset_permission", lambda *_args, **_kwargs: None
+    )
+
+    doc1 = _mock_document(
+        document_id="33333333-3333-3333-3333-333333333333",
+        tenant_id="tenant-123",
+        data_source_type="upload_file",
+        upload_file_id="file-1",
+    )
+    doc2 = _mock_document(
+        document_id="44444444-4444-4444-4444-444444444444",
+        tenant_id="tenant-123",
+        data_source_type="upload_file",
+        upload_file_id="file-2",
+    )
+    monkeypatch.setattr(
+        datasets_document_module.DocumentService,
+        "get_documents_by_ids",
+        lambda *_args, **_kwargs: [doc1, doc2],
+    )
+    monkeypatch.setattr(
+        datasets_document_module.FileService,
+        "get_upload_files_by_ids",
+        lambda *_args, **_kwargs: {
+            "file-1": SimpleNamespace(id="file-1", name="a.txt", key="k1"),
+            "file-2": SimpleNamespace(id="file-2", name="b.txt", key="k2"),
+        },
+    )
+
+    # Stream distinct bytes per key so we can verify both ZIP entries.
+    import services.file_service as file_service_module
+
+    monkeypatch.setattr(
+        file_service_module.storage, "load", lambda key, stream=True: [b"one"] if key == "k1" else [b"two"]
+    )
+
+    # Act
+    with app.test_request_context(
+        "/datasets/ds-1/documents/download-zip",
+        method="POST",
+        json={"document_ids": ["33333333-3333-3333-3333-333333333333", "44444444-4444-4444-4444-444444444444"]},
+    ):
+        api = datasets_document_module.DocumentBatchDownloadZipApi()
+        response = api.post(dataset_id="ds-1")
+
+    # Assert: response body is a valid ZIP and contains the expected entries.
+    response.direct_passthrough = False
+    data = response.get_data()
+    response.close()
+
+    with ZipFile(BytesIO(data), mode="r") as zf:
+        assert zf.namelist() == ["a.txt", "b.txt"]
+        assert zf.read("a.txt") == b"one"
+        assert zf.read("b.txt") == b"two"
+
+
+def test_batch_download_zip_rejects_non_upload_file_document(
+    app: Flask, datasets_document_module, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Ensure batch ZIP download rejects non upload-file documents."""
+
+    monkeypatch.setattr(datasets_document_module, "current_account_with_tenant", lambda: (_mock_user(), "tenant-123"))
+    monkeypatch.setattr(
+        datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1")
+    )
+    monkeypatch.setattr(
+        datasets_document_module.DatasetService, "check_dataset_permission", lambda *_args, **_kwargs: None
+    )
+
+    doc = _mock_document(
+        document_id="55555555-5555-5555-5555-555555555555",
+        tenant_id="tenant-123",
+        data_source_type="website_crawl",
+        upload_file_id="file-1",
+    )
+    monkeypatch.setattr(
+        datasets_document_module.DocumentService,
+        "get_documents_by_ids",
+        lambda *_args, **_kwargs: [doc],
+    )
+
+    with app.test_request_context(
+        "/datasets/ds-1/documents/download-zip",
+        method="POST",
+        json={"document_ids": ["55555555-5555-5555-5555-555555555555"]},
+    ):
+        api = datasets_document_module.DocumentBatchDownloadZipApi()
+        with pytest.raises(NotFound):
+            api.post(dataset_id="ds-1")
+
+
+def test_document_download_returns_url_for_upload_file_document(
+    app: Flask, datasets_document_module, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Ensure upload-file documents return a `{url}` JSON payload."""
+
+    _wire_common_success_mocks(
+        module=datasets_document_module,
+        monkeypatch=monkeypatch,
+        current_tenant_id="tenant-123",
+        document_tenant_id="tenant-123",
+        data_source_type="upload_file",
+        upload_file_id="file-123",
+        upload_file_exists=True,
+        signed_url="https://example.com/signed",
+    )
+
+    # Build a request context then call the resource method directly.
+    with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"):
+        api = datasets_document_module.DocumentDownloadApi()
+        result = api.get(dataset_id="ds-1", document_id="doc-1")
+
+    assert result == {"url": "https://example.com/signed"}
+
+
+def test_document_download_rejects_non_upload_file_document(
+    app: Flask, datasets_document_module, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Ensure non-upload documents raise 404 (no file to download)."""
+
+    _wire_common_success_mocks(
+        module=datasets_document_module,
+        monkeypatch=monkeypatch,
+        current_tenant_id="tenant-123",
+        document_tenant_id="tenant-123",
+        data_source_type="website_crawl",
+        upload_file_id="file-123",
+        upload_file_exists=True,
+        signed_url="https://example.com/signed",
+    )
+
+    with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"):
+        api = datasets_document_module.DocumentDownloadApi()
+        with pytest.raises(NotFound):
+            api.get(dataset_id="ds-1", document_id="doc-1")
+
+
+def test_document_download_rejects_missing_upload_file_id(
+    app: Flask, datasets_document_module, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Ensure missing `upload_file_id` raises 404."""
+
+    _wire_common_success_mocks(
+        module=datasets_document_module,
+        monkeypatch=monkeypatch,
+        current_tenant_id="tenant-123",
document_tenant_id="tenant-123", + data_source_type="upload_file", + upload_file_id=None, + upload_file_exists=False, + signed_url="https://example.com/signed", + ) + + with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"): + api = datasets_document_module.DocumentDownloadApi() + with pytest.raises(NotFound): + api.get(dataset_id="ds-1", document_id="doc-1") + + +def test_document_download_rejects_when_upload_file_record_missing( + app: Flask, datasets_document_module, monkeypatch: pytest.MonkeyPatch +) -> None: + """Ensure missing UploadFile row raises 404.""" + + _wire_common_success_mocks( + module=datasets_document_module, + monkeypatch=monkeypatch, + current_tenant_id="tenant-123", + document_tenant_id="tenant-123", + data_source_type="upload_file", + upload_file_id="file-123", + upload_file_exists=False, + signed_url="https://example.com/signed", + ) + + with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"): + api = datasets_document_module.DocumentDownloadApi() + with pytest.raises(NotFound): + api.get(dataset_id="ds-1", document_id="doc-1") + + +def test_document_download_rejects_tenant_mismatch( + app: Flask, datasets_document_module, monkeypatch: pytest.MonkeyPatch +) -> None: + """Ensure tenant mismatch is rejected by the shared `get_document()` permission check.""" + + _wire_common_success_mocks( + module=datasets_document_module, + monkeypatch=monkeypatch, + current_tenant_id="tenant-123", + document_tenant_id="tenant-999", + data_source_type="upload_file", + upload_file_id="file-123", + upload_file_exists=True, + signed_url="https://example.com/signed", + ) + + with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"): + api = datasets_document_module.DocumentDownloadApi() + with pytest.raises(Forbidden): + api.get(dataset_id="ds-1", document_id="doc-1") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_external_dataset_payload.py 
b/api/tests/unit_tests/controllers/console/datasets/test_external_dataset_payload.py new file mode 100644 index 0000000000..2ea1fcf544 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_external_dataset_payload.py @@ -0,0 +1,49 @@ +""" +Unit tests for the external dataset controller payload schemas. + +These tests focus on Pydantic validation rules so we can catch regressions +in request constraints (e.g. max length changes) without exercising the +full Flask/RESTX request stack. +""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from controllers.console.datasets.external import ExternalDatasetCreatePayload + + +def test_external_dataset_create_payload_allows_name_length_100() -> None: + """Ensure the `name` field accepts up to 100 characters (inclusive).""" + + # Build a request payload with a boundary-length name value. + name_100: str = "a" * 100 + payload: dict[str, object] = { + "external_knowledge_api_id": "ek-api-1", + "external_knowledge_id": "ek-1", + "name": name_100, + } + + model = ExternalDatasetCreatePayload.model_validate(payload) + assert model.name == name_100 + + +def test_external_dataset_create_payload_rejects_name_length_101() -> None: + """Ensure the `name` field rejects values longer than 100 characters.""" + + # Build a request payload that exceeds the max length by 1.
+ name_101: str = "a" * 101 + payload: dict[str, object] = { + "external_knowledge_api_id": "ek-api-1", + "external_knowledge_id": "ek-1", + "name": name_101, + } + + with pytest.raises(ValidationError) as exc_info: + ExternalDatasetCreatePayload.model_validate(payload) + + errors = exc_info.value.errors() + assert errors[0]["loc"] == ("name",) + assert errors[0]["type"] == "string_too_long" + assert errors[0]["ctx"]["max_length"] == 100 diff --git a/api/tests/unit_tests/controllers/console/test_document_detail_api_data_source_info.py b/api/tests/unit_tests/controllers/console/test_document_detail_api_data_source_info.py new file mode 100644 index 0000000000..f8dd98fdb2 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_document_detail_api_data_source_info.py @@ -0,0 +1,145 @@ +""" +Test for document detail API data_source_info serialization fix. + +This test verifies that the document detail API returns both data_source_info +and data_source_detail_dict for all data_source_type values, including "local_file". 
+""" + +import json +from typing import Generic, Literal, NotRequired, TypedDict, TypeVar, Union + +from models.dataset import Document + + +class LocalFileInfo(TypedDict): + file_path: str + size: int + created_at: NotRequired[str] + + +class UploadFileInfo(TypedDict): + upload_file_id: str + + +class NotionImportInfo(TypedDict): + notion_page_id: str + workspace_id: str + + +class WebsiteCrawlInfo(TypedDict): + url: str + job_id: str + + +RawInfo = Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo] +T_type = TypeVar("T_type", bound=str) +T_info = TypeVar("T_info", bound=Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo]) + + +class Case(TypedDict, Generic[T_type, T_info]): + data_source_type: T_type + data_source_info: str + expected_raw: T_info + + +LocalFileCase = Case[Literal["local_file"], LocalFileInfo] +UploadFileCase = Case[Literal["upload_file"], UploadFileInfo] +NotionImportCase = Case[Literal["notion_import"], NotionImportInfo] +WebsiteCrawlCase = Case[Literal["website_crawl"], WebsiteCrawlInfo] + +AnyCase = Union[LocalFileCase, UploadFileCase, NotionImportCase, WebsiteCrawlCase] + + +case_1: LocalFileCase = { + "data_source_type": "local_file", + "data_source_info": json.dumps({"file_path": "/tmp/test.txt", "size": 1024}), + "expected_raw": {"file_path": "/tmp/test.txt", "size": 1024}, +} + + +# ERROR: Expected LocalFileInfo, but got WebsiteCrawlInfo +case_2: LocalFileCase = { + "data_source_type": "local_file", + "data_source_info": "...", + "expected_raw": {"file_path": "https://google.com", "size": 123}, +} + +cases: list[AnyCase] = [case_1] + + +class TestDocumentDetailDataSourceInfo: + """Test cases for document detail API data_source_info serialization.""" + + def test_data_source_info_dict_returns_raw_data(self): + """Test that data_source_info_dict returns raw JSON data for all data_source_type values.""" + # Test data for different data_source_type values + for case in cases: + document = Document( + 
data_source_type=case["data_source_type"], + data_source_info=case["data_source_info"], + ) + + # Test data_source_info_dict (raw data) + raw_result = document.data_source_info_dict + assert raw_result == case["expected_raw"], f"Failed for {case['data_source_type']}" + + # Verify raw_result is always a valid dict + assert isinstance(raw_result, dict) + + def test_local_file_data_source_info_without_db_context(self): + """Test that local_file type data_source_info_dict works without database context.""" + test_data: LocalFileInfo = { + "file_path": "/local/path/document.txt", + "size": 512, + "created_at": "2024-01-01T00:00:00Z", + } + + document = Document( + data_source_type="local_file", + data_source_info=json.dumps(test_data), + ) + + # data_source_info_dict should return the raw data (this doesn't need DB context) + raw_data = document.data_source_info_dict + assert raw_data == test_data + assert isinstance(raw_data, dict) + + # Verify the data contains expected keys for pipeline mode + assert "file_path" in raw_data + assert "size" in raw_data + + def test_notion_and_website_crawl_data_source_detail(self): + """Test that notion_import and website_crawl return raw data in data_source_detail_dict.""" + # Test notion_import + notion_data: NotionImportInfo = {"notion_page_id": "page-123", "workspace_id": "ws-456"} + document = Document( + data_source_type="notion_import", + data_source_info=json.dumps(notion_data), + ) + + # data_source_detail_dict should return raw data for notion_import + detail_result = document.data_source_detail_dict + assert detail_result == notion_data + + # Test website_crawl + website_data: WebsiteCrawlInfo = {"url": "https://example.com", "job_id": "job-789"} + document = Document( + data_source_type="website_crawl", + data_source_info=json.dumps(website_data), + ) + + # data_source_detail_dict should return raw data for website_crawl + detail_result = document.data_source_detail_dict + assert detail_result == website_data + + def 
test_local_file_data_source_detail_dict_without_db(self): + """Test that local_file returns empty data_source_detail_dict (this doesn't need DB context).""" + # Test local_file - this should work without database context since it returns {} early + document = Document( + data_source_type="local_file", + data_source_info=json.dumps({"file_path": "/tmp/test.txt"}), + ) + + # Should return empty dict for local_file type (handled in the model) + detail_result = document.data_source_detail_dict + assert detail_result == {} diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_ping.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_ping.py new file mode 100644 index 0000000000..fc04ca078b --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_fastopenapi_ping.py @@ -0,0 +1,27 @@ +import builtins + +import pytest +from flask import Flask +from flask.views import MethodView + +from extensions import ext_fastopenapi + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +@pytest.fixture +def app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + return app + + +def test_console_ping_fastopenapi_returns_pong(app: Flask): + ext_fastopenapi.init_app(app) + + client = app.test_client() + response = client.get("/console/api/ping") + + assert response.status_code == 200 + assert response.get_json() == {"result": "pong"} diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_setup.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_setup.py new file mode 100644 index 0000000000..385539b6f3 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_fastopenapi_setup.py @@ -0,0 +1,56 @@ +import builtins +from unittest.mock import patch + +import pytest +from flask import Flask +from flask.views import MethodView + +from extensions import ext_fastopenapi + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: 
ignore[attr-defined] + + +@pytest.fixture +def app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + return app + + +def test_console_setup_fastopenapi_get_not_started(app: Flask): + ext_fastopenapi.init_app(app) + + with ( + patch("controllers.console.setup.dify_config.EDITION", "SELF_HOSTED"), + patch("controllers.console.setup.get_setup_status", return_value=None), + ): + client = app.test_client() + response = client.get("/console/api/setup") + + assert response.status_code == 200 + assert response.get_json() == {"step": "not_started", "setup_at": None} + + +def test_console_setup_fastopenapi_post_success(app: Flask): + ext_fastopenapi.init_app(app) + + payload = { + "email": "admin@example.com", + "name": "Admin", + "password": "Passw0rd1", + "language": "en-US", + } + + with ( + patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"), + patch("controllers.console.setup.get_setup_status", return_value=None), + patch("controllers.console.setup.TenantService.get_tenant_count", return_value=0), + patch("controllers.console.setup.get_init_validate_status", return_value=True), + patch("controllers.console.setup.RegisterService.setup"), + ): + client = app.test_client() + response = client.post("/console/api/setup", json=payload) + + assert response.status_code == 201 + assert response.get_json() == {"result": "success"} diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_version.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_version.py new file mode 100644 index 0000000000..c5b4e0dfcf --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_fastopenapi_version.py @@ -0,0 +1,35 @@ +import builtins +from unittest.mock import patch + +import pytest +from flask import Flask +from flask.views import MethodView + +from configs import dify_config +from extensions import ext_fastopenapi + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + 
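Several of these new test modules repeat the same `builtins.MethodView` shim before importing controllers whose module-level code expects that name to resolve as a bare global. A minimal standalone sketch of the pattern, using a hypothetical stand-in class rather than `flask.views.MethodView`:

```python
import builtins


class MethodView:
    """Hypothetical stand-in for flask.views.MethodView, for this sketch only."""


# Install the name on `builtins` only when absent, so the shim stays
# idempotent no matter how many test modules execute it.
if not hasattr(builtins, "MethodView"):
    builtins.MethodView = MethodView  # type: ignore[attr-defined]

# Any module can now resolve `MethodView` as a bare name without importing
# it, because failed global lookups fall back to `builtins`.
shimmed = getattr(builtins, "MethodView")
```

Because the guard checks `hasattr` first, re-running the shim from another test module is a no-op rather than an overwrite of whatever was installed first.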
+@pytest.fixture +def app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + return app + + +def test_console_version_fastopenapi_returns_current_version(app: Flask): + ext_fastopenapi.init_app(app) + + with patch("controllers.console.version.dify_config.CHECK_UPDATE_URL", None): + client = app.test_client() + response = client.get("/console/api/version", query_string={"current_version": "0.0.0"}) + + assert response.status_code == 200 + data = response.get_json() + assert data["version"] == dify_config.project.version + assert data["release_date"] == "" + assert data["release_notes"] == "" + assert data["can_auto_update"] is False + assert "features" in data diff --git a/api/tests/unit_tests/controllers/console/test_files_security.py b/api/tests/unit_tests/controllers/console/test_files_security.py index 2630fbcfd0..370bf63fdb 100644 --- a/api/tests/unit_tests/controllers/console/test_files_security.py +++ b/api/tests/unit_tests/controllers/console/test_files_security.py @@ -1,7 +1,9 @@ +import builtins import io from unittest.mock import patch import pytest +from flask.views import MethodView from werkzeug.exceptions import Forbidden from controllers.common.errors import ( @@ -14,6 +16,9 @@ from controllers.common.errors import ( from services.errors.file import FileTooLargeError as ServiceFileTooLargeError from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + class TestFileUploadSecurity: """Test file upload security logic without complex framework setup""" @@ -128,7 +133,7 @@ class TestFileUploadSecurity: # Test passes if no exception is raised # Test 4: Service error handling - @patch("services.file_service.FileService.upload_file") + @patch("controllers.console.files.FileService.upload_file") def test_should_handle_file_too_large_error(self, mock_upload): """Test that service FileTooLargeError is 
properly converted""" mock_upload.side_effect = ServiceFileTooLargeError("File too large") @@ -140,7 +145,7 @@ class TestFileUploadSecurity: with pytest.raises(FileTooLargeError): raise FileTooLargeError(e.description) - @patch("services.file_service.FileService.upload_file") + @patch("controllers.console.files.FileService.upload_file") def test_should_handle_unsupported_file_type_error(self, mock_upload): """Test that service UnsupportedFileTypeError is properly converted""" mock_upload.side_effect = ServiceUnsupportedFileTypeError() diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py new file mode 100644 index 0000000000..9afc1c4166 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -0,0 +1,247 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask, g + +from controllers.console.workspace.account import ( + AccountDeleteUpdateFeedbackApi, + ChangeEmailCheckApi, + ChangeEmailResetApi, + ChangeEmailSendEmailApi, + CheckEmailUnique, +) +from models import Account +from services.account_service import AccountService + + +@pytest.fixture +def app(): + app = Flask(__name__) + app.config["TESTING"] = True + app.config["RESTX_MASK_HEADER"] = "X-Fields" + app.login_manager = SimpleNamespace(_load_user=lambda: None) + return app + + +def _mock_wraps_db(mock_db): + mock_db.session.query.return_value.first.return_value = MagicMock() + + +def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account: + tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id") + account = Account(name=account_id, email=email) + account.email = email + account.id = account_id + account.status = "active" + account._current_tenant = tenant_obj + return account + + +def _set_logged_in_user(account: Account): + g._login_user = account + 
g._current_tenant = account.current_tenant + + +class TestChangeEmailSend: + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.send_change_email_email") + @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_normalize_new_email_phase( + self, + mock_features, + mock_csrf, + mock_extract_ip, + mock_is_ip_limit, + mock_send_email, + mock_get_change_data, + mock_current_account, + mock_db, + app, + ): + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_account = _build_account("current@example.com", "acc1") + mock_current_account.return_value = (mock_account, None) + mock_get_change_data.return_value = {"email": "current@example.com"} + mock_send_email.return_value = "token-abc" + + with app.test_request_context( + "/account/change-email", + method="POST", + json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + response = ChangeEmailSendEmailApi().post() + + assert response == {"result": "success", "data": "token-abc"} + mock_send_email.assert_called_once_with( + account=None, + email="new@example.com", + old_email="current@example.com", + language="en-US", + phase="new_email", + ) + mock_extract_ip.assert_called_once() + mock_is_ip_limit.assert_called_once_with("127.0.0.1") + mock_csrf.assert_called_once() + + +class TestChangeEmailValidity: + @patch("controllers.console.wraps.db") + 
@patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_validate_with_normalized_email( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app, + ): + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_account = _build_account("user@example.com", "acc2") + mock_current_account.return_value = (mock_account, None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"} + mock_generate_token.return_value = (None, "new-token") + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "User@Example.com", "code": "1234", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + response = ChangeEmailCheckApi().post() + + assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"} + mock_is_rate_limit.assert_called_once_with("user@example.com") + mock_add_rate.assert_not_called() + mock_revoke_token.assert_called_once_with("token-123") + 
mock_generate_token.assert_called_once_with( + "user@example.com", code="1234", old_email="old@example.com", additional_data={} + ) + mock_reset_rate.assert_called_once_with("user@example.com") + mock_csrf.assert_called_once() + + +class TestChangeEmailReset: + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") + @patch("controllers.console.workspace.account.AccountService.update_account_email") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.check_email_unique") + @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_normalize_new_email_before_update( + self, + mock_features, + mock_csrf, + mock_is_freeze, + mock_check_unique, + mock_get_data, + mock_revoke_token, + mock_update_account, + mock_send_notify, + mock_current_account, + mock_db, + app, + ): + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + current_user = _build_account("old@example.com", "acc3") + mock_current_account.return_value = (current_user, None) + mock_is_freeze.return_value = False + mock_check_unique.return_value = True + mock_get_data.return_value = {"old_email": "OLD@example.com"} + mock_account_after_update = _build_account("new@example.com", "acc3-updated") + mock_update_account.return_value = mock_account_after_update + + with app.test_request_context( + "/account/change-email/reset", + method="POST", + json={"new_email": "New@Example.com", "token": "token-123"}, + ): + 
_set_logged_in_user(_build_account("tester@example.com", "tester")) + ChangeEmailResetApi().post() + + mock_is_freeze.assert_called_once_with("new@example.com") + mock_check_unique.assert_called_once_with("new@example.com") + mock_revoke_token.assert_called_once_with("token-123") + mock_update_account.assert_called_once_with(current_user, email="new@example.com") + mock_send_notify.assert_called_once_with(email="new@example.com") + mock_csrf.assert_called_once() + + +class TestAccountDeletionFeedback: + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback") + def test_should_normalize_feedback_email(self, mock_update, mock_db, app): + _mock_wraps_db(mock_db) + with app.test_request_context( + "/account/delete/feedback", + method="POST", + json={"email": "User@Example.com", "feedback": "test"}, + ): + response = AccountDeleteUpdateFeedbackApi().post() + + assert response == {"result": "success"} + mock_update.assert_called_once_with("User@Example.com", "test") + + +class TestCheckEmailUnique: + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.AccountService.check_email_unique") + @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") + def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app): + _mock_wraps_db(mock_db) + mock_is_freeze.return_value = False + mock_check_unique.return_value = True + + with app.test_request_context( + "/account/change-email/check-email-unique", + method="POST", + json={"email": "Case@Test.com"}, + ): + response = CheckEmailUnique().post() + + assert response == {"result": "success"} + mock_is_freeze.assert_called_once_with("case@test.com") + mock_check_unique.assert_called_once_with("case@test.com") + + +def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): + session = MagicMock() + first = MagicMock() + first.scalar_one_or_none.return_value 
= None + second = MagicMock() + expected_account = MagicMock() + second.scalar_one_or_none.return_value = expected_account + session.execute.side_effect = [first, second] + + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session) + + assert result is expected_account + assert session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/test_workspace_members.py b/api/tests/unit_tests/controllers/console/test_workspace_members.py new file mode 100644 index 0000000000..368892b922 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_workspace_members.py @@ -0,0 +1,82 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask, g + +from controllers.console.workspace.members import MemberInviteEmailApi +from models.account import Account, TenantAccountRole + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + flask_app.login_manager = SimpleNamespace(_load_user=lambda: None) + return flask_app + + +def _mock_wraps_db(mock_db): + mock_db.session.query.return_value.first.return_value = MagicMock() + + +def _build_feature_flags(): + placeholder_quota = SimpleNamespace(limit=0, size=0) + workspace_members = SimpleNamespace(is_available=lambda count: True) + return SimpleNamespace( + billing=SimpleNamespace(enabled=False), + workspace_members=workspace_members, + members=placeholder_quota, + apps=placeholder_quota, + vector_space=placeholder_quota, + documents_upload_quota=placeholder_quota, + annotation_quota_limit=placeholder_quota, + ) + + +class TestMemberInviteEmailApi: + @patch("controllers.console.workspace.members.FeatureService.get_features") + @patch("controllers.console.workspace.members.RegisterService.invite_new_member") + @patch("controllers.console.workspace.members.current_account_with_tenant") + @patch("controllers.console.wraps.db") + 
@patch("libs.login.check_csrf_token", return_value=None) + def test_invite_normalizes_emails( + self, + mock_csrf, + mock_db, + mock_current_account, + mock_invite_member, + mock_get_features, + app, + ): + _mock_wraps_db(mock_db) + mock_get_features.return_value = _build_feature_flags() + mock_invite_member.return_value = "token-abc" + + tenant = SimpleNamespace(id="tenant-1", name="Test Tenant") + inviter = SimpleNamespace(email="Owner@Example.com", current_tenant=tenant, status="active") + mock_current_account.return_value = (inviter, tenant.id) + + with patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "https://console.example.com"): + with app.test_request_context( + "/workspaces/current/members/invite-email", + method="POST", + json={"emails": ["User@Example.com"], "role": TenantAccountRole.EDITOR.value, "language": "en-US"}, + ): + account = Account(name="tester", email="tester@example.com") + account._current_tenant = tenant + g._login_user = account + g._current_tenant = tenant + response, status_code = MemberInviteEmailApi().post() + + assert status_code == 201 + assert response["invitation_results"][0]["email"] == "user@example.com" + + assert mock_invite_member.call_count == 1 + call_args = mock_invite_member.call_args + assert call_args.kwargs["tenant"] == tenant + assert call_args.kwargs["email"] == "User@Example.com" + assert call_args.kwargs["language"] == "en-US" + assert call_args.kwargs["role"] == TenantAccountRole.EDITOR + assert call_args.kwargs["inviter"] == inviter + mock_csrf.assert_called_once() diff --git a/api/tests/unit_tests/controllers/console/workspace/__init__.py b/api/tests/unit_tests/controllers/console/workspace/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py new file mode 100644 index 0000000000..59b6614d5e --- 
/dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py @@ -0,0 +1,145 @@ +"""Unit tests for load balancing credential validation APIs.""" + +from __future__ import annotations + +import builtins +import importlib +import sys +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from flask import Flask +from flask.views import MethodView +from werkzeug.exceptions import Forbidden + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + +from models.account import TenantAccountRole + + +@pytest.fixture +def app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + return app + + +@pytest.fixture +def load_balancing_module(monkeypatch: pytest.MonkeyPatch): + """Reload controller module with lightweight decorators for testing.""" + + from controllers.console import console_ns, wraps + from libs import login + + def _noop(func): + return func + + monkeypatch.setattr(login, "login_required", _noop) + monkeypatch.setattr(wraps, "setup_required", _noop) + monkeypatch.setattr(wraps, "account_initialization_required", _noop) + + def _noop_route(*args, **kwargs): # type: ignore[override] + def _decorator(cls): + return cls + + return _decorator + + monkeypatch.setattr(console_ns, "route", _noop_route) + + module_name = "controllers.console.workspace.load_balancing_config" + sys.modules.pop(module_name, None) + module = importlib.import_module(module_name) + return module + + +def _mock_user(role: TenantAccountRole) -> SimpleNamespace: + return SimpleNamespace(current_role=role) + + +def _prepare_context(module, monkeypatch: pytest.MonkeyPatch, role=TenantAccountRole.OWNER): + user = _mock_user(role) + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "tenant-123")) 
+ mock_service = MagicMock() + monkeypatch.setattr(module, "ModelLoadBalancingService", lambda: mock_service) + return mock_service + + +def _request_payload(): + return {"model": "gpt-4o", "model_type": ModelType.LLM, "credentials": {"api_key": "sk-***"}} + + +def test_validate_credentials_success(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch): + service = _prepare_context(load_balancing_module, monkeypatch) + + with app.test_request_context( + "/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate", + method="POST", + json=_request_payload(), + ): + response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai") + + assert response == {"result": "success"} + service.validate_load_balancing_credentials.assert_called_once_with( + tenant_id="tenant-123", + provider="openai", + model="gpt-4o", + model_type=ModelType.LLM, + credentials={"api_key": "sk-***"}, + ) + + +def test_validate_credentials_returns_error_message(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch): + service = _prepare_context(load_balancing_module, monkeypatch) + service.validate_load_balancing_credentials.side_effect = CredentialsValidateFailedError("invalid credentials") + + with app.test_request_context( + "/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate", + method="POST", + json=_request_payload(), + ): + response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai") + + assert response == {"result": "error", "error": "invalid credentials"} + + +def test_validate_credentials_requires_privileged_role( + app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch +): + _prepare_context(load_balancing_module, monkeypatch, role=TenantAccountRole.NORMAL) + + with app.test_request_context( + "/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate", + method="POST", + 
json=_request_payload(), + ): + api = load_balancing_module.LoadBalancingCredentialsValidateApi() + with pytest.raises(Forbidden): + api.post(provider="openai") + + +def test_validate_credentials_with_config_id(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch): + service = _prepare_context(load_balancing_module, monkeypatch) + + with app.test_request_context( + "/workspaces/current/model-providers/openai/models/load-balancing-configs/cfg-1/credentials-validate", + method="POST", + json=_request_payload(), + ): + response = load_balancing_module.LoadBalancingConfigCredentialsValidateApi().post( + provider="openai", config_id="cfg-1" + ) + + assert response == {"result": "success"} + service.validate_load_balancing_credentials.assert_called_once_with( + tenant_id="tenant-123", + provider="openai", + model="gpt-4o", + model_type=ModelType.LLM, + credentials={"api_key": "sk-***"}, + config_id="cfg-1", + ) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py new file mode 100644 index 0000000000..c608f731c5 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py @@ -0,0 +1,100 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from flask_restx import Api + +from controllers.console.workspace.tool_providers import ToolProviderMCPApi +from core.db.session_factory import configure_session_factory +from extensions.ext_database import db +from services.tools.mcp_tools_manage_service import ReconnectResult + + +# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file. +# They are intentionally no-ops because the test already patches the required +# behaviors explicitly via @patch and context managers below. 
+@pytest.fixture +def _mock_cache(): + return + + +@pytest.fixture +def _mock_user_tenant(): + return + + +@pytest.fixture +def client(): + app = Flask(__name__) + app.config["TESTING"] = True + app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:" + api = Api(app) + api.add_resource(ToolProviderMCPApi, "/console/api/workspaces/current/tool-provider/mcp") + db.init_app(app) + # Configure session factory used by controller code + with app.app_context(): + configure_session_factory(db.engine) + return app.test_client() + + +@patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1") +) +@patch("controllers.console.workspace.tool_providers.Session") +@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url") +@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant") +def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_current_account_with_tenant, client): + # Arrange: reconnect returns tools immediately + mock_reconnect.return_value = ReconnectResult( + authed=True, + tools=json.dumps( + [{"name": "ping", "description": "ok", "inputSchema": {"type": "object"}, "outputSchema": {}}] + ), + encrypted_credentials="{}", + ) + + # Fake service.create_provider -> returns object with id for reload + svc = MagicMock() + create_result = MagicMock() + create_result.id = "provider-1" + svc.create_provider.return_value = create_result + svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1") # used by reload path + mock_session.return_value.__enter__.return_value = MagicMock() + # Patch MCPToolManageService constructed inside controller + with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc): + payload = { + "server_url": "http://example.com/mcp", + "name": "demo", + "icon": "😀", + "icon_type": "emoji", + "icon_background": "#000", + "server_identifier": "demo-sid", + 
"configuration": {"timeout": 5, "sse_read_timeout": 30}, + "headers": {}, + "authentication": {}, + } + # Act + with ( + patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), # bypass setup_required DB check + patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")), + patch("libs.login.check_csrf_token", return_value=None), # bypass CSRF in login_required + patch("libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True)), # login + patch( + "services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider", + return_value={"id": "provider-1", "tools": [{"name": "ping"}]}, + ), + ): + resp = client.post( + "/console/api/workspaces/current/tool-provider/mcp", + data=json.dumps(payload), + content_type="application/json", + ) + + # Assert + assert resp.status_code == 200 + body = resp.get_json() + assert body.get("id") == "provider-1" + # If the transformed payload contains a tools field, ensure it is non-empty + assert isinstance(body.get("tools"), list) + assert body["tools"] diff --git a/api/tests/unit_tests/controllers/web/test_forgot_password.py b/api/tests/unit_tests/controllers/web/test_forgot_password.py deleted file mode 100644 index d7c0d24f14..0000000000 --- a/api/tests/unit_tests/controllers/web/test_forgot_password.py +++ /dev/null @@ -1,195 @@ -"""Unit tests for controllers.web.forgot_password endpoints.""" - -from __future__ import annotations - -import base64 -import builtins -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask -from flask.views import MethodView - -# Ensure flask_restx.api finds MethodView during import.
-if not hasattr(builtins, "MethodView"): - builtins.MethodView = MethodView # type: ignore[attr-defined] - - -def _load_controller_module(): - """Import controllers.web.forgot_password using a stub package.""" - - import importlib - import importlib.util - import sys - from types import ModuleType - - parent_module_name = "controllers.web" - module_name = f"{parent_module_name}.forgot_password" - - if parent_module_name not in sys.modules: - from flask_restx import Namespace - - stub = ModuleType(parent_module_name) - stub.__file__ = "controllers/web/__init__.py" - stub.__path__ = ["controllers/web"] - stub.__package__ = "controllers" - stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True) - stub.web_ns = Namespace("web", description="Web API", path="/") - sys.modules[parent_module_name] = stub - - return importlib.import_module(module_name) - - -forgot_password_module = _load_controller_module() -ForgotPasswordCheckApi = forgot_password_module.ForgotPasswordCheckApi -ForgotPasswordResetApi = forgot_password_module.ForgotPasswordResetApi -ForgotPasswordSendEmailApi = forgot_password_module.ForgotPasswordSendEmailApi - - -@pytest.fixture -def app() -> Flask: - """Configure a minimal Flask app for request contexts.""" - - app = Flask(__name__) - app.config["TESTING"] = True - return app - - -@pytest.fixture(autouse=True) -def _enable_web_endpoint_guards(): - """Stub enterprise and feature toggles used by route decorators.""" - - features = SimpleNamespace(enable_email_password_login=True) - with ( - patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True), - patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), - patch("controllers.console.wraps.FeatureService.get_system_features", return_value=features), - ): - yield - - -@pytest.fixture(autouse=True) -def _mock_controller_db(): - """Replace controller-level db reference with a simple stub.""" - - fake_db = 
SimpleNamespace(engine=MagicMock(name="engine")) - fake_wraps_db = SimpleNamespace( - session=MagicMock(query=MagicMock(return_value=MagicMock(first=MagicMock(return_value=True)))) - ) - with ( - patch("controllers.web.forgot_password.db", fake_db), - patch("controllers.console.wraps.db", fake_wraps_db), - ): - yield fake_db - - -@patch("controllers.web.forgot_password.AccountService.send_reset_password_email", return_value="reset-token") -@patch("controllers.web.forgot_password.Session") -@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) -@patch("controllers.web.forgot_password.extract_remote_ip", return_value="203.0.113.10") -def test_send_reset_email_success( - mock_extract_ip: MagicMock, - mock_is_ip_limit: MagicMock, - mock_session: MagicMock, - mock_send_email: MagicMock, - app: Flask, -): - """POST /forgot-password returns token when email exists and limits allow.""" - - mock_account = MagicMock() - session_ctx = MagicMock() - mock_session.return_value.__enter__.return_value = session_ctx - session_ctx.execute.return_value.scalar_one_or_none.return_value = mock_account - - with app.test_request_context( - "/forgot-password", - method="POST", - json={"email": "user@example.com"}, - ): - response = ForgotPasswordSendEmailApi().post() - - assert response == {"result": "success", "data": "reset-token"} - mock_extract_ip.assert_called_once() - mock_is_ip_limit.assert_called_once_with("203.0.113.10") - mock_send_email.assert_called_once_with(account=mock_account, email="user@example.com", language="en-US") - - -@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit") -@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token", return_value=({}, "new-token")) -@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") -@patch("controllers.web.forgot_password.AccountService.get_reset_password_data") 
-@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit", return_value=False) -def test_check_token_success( - mock_is_rate_limited: MagicMock, - mock_get_data: MagicMock, - mock_revoke: MagicMock, - mock_generate: MagicMock, - mock_reset_limit: MagicMock, - app: Flask, -): - """POST /forgot-password/validity validates the code and refreshes token.""" - - mock_get_data.return_value = {"email": "user@example.com", "code": "123456"} - - with app.test_request_context( - "/forgot-password/validity", - method="POST", - json={"email": "user@example.com", "code": "123456", "token": "old-token"}, - ): - response = ForgotPasswordCheckApi().post() - - assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"} - mock_is_rate_limited.assert_called_once_with("user@example.com") - mock_get_data.assert_called_once_with("old-token") - mock_revoke.assert_called_once_with("old-token") - mock_generate.assert_called_once_with( - "user@example.com", - code="123456", - additional_data={"phase": "reset"}, - ) - mock_reset_limit.assert_called_once_with("user@example.com") - - -@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value") -@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef") -@patch("controllers.web.forgot_password.Session") -@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") -@patch("controllers.web.forgot_password.AccountService.get_reset_password_data") -def test_reset_password_success( - mock_get_data: MagicMock, - mock_revoke_token: MagicMock, - mock_session: MagicMock, - mock_token_bytes: MagicMock, - mock_hash_password: MagicMock, - app: Flask, -): - """POST /forgot-password/resets updates the stored password when token is valid.""" - - mock_get_data.return_value = {"email": "user@example.com", "phase": "reset"} - account = MagicMock() - session_ctx = MagicMock() - 
mock_session.return_value.__enter__.return_value = session_ctx - session_ctx.execute.return_value.scalar_one_or_none.return_value = account - - with app.test_request_context( - "/forgot-password/resets", - method="POST", - json={ - "token": "reset-token", - "new_password": "StrongPass123!", - "password_confirm": "StrongPass123!", - }, - ): - response = ForgotPasswordResetApi().post() - - assert response == {"result": "success"} - mock_get_data.assert_called_once_with("reset-token") - mock_revoke_token.assert_called_once_with("reset-token") - mock_token_bytes.assert_called_once_with(16) - mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef") - expected_password = base64.b64encode(b"hashed-value").decode() - assert account.password == expected_password - expected_salt = base64.b64encode(b"0123456789abcdef").decode() - assert account.password_salt == expected_salt - session_ctx.commit.assert_called_once() diff --git a/api/tests/unit_tests/controllers/web/test_message_list.py b/api/tests/unit_tests/controllers/web/test_message_list.py new file mode 100644 index 0000000000..2835f7ffbf --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_message_list.py @@ -0,0 +1,174 @@ +"""Unit tests for controllers.web.message message list mapping.""" + +from __future__ import annotations + +import builtins +from datetime import datetime +from types import ModuleType, SimpleNamespace +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from flask import Flask +from flask.views import MethodView + +# Ensure flask_restx.api finds MethodView during import. 
+if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +def _load_controller_module(): + """Import controllers.web.message using a stub package.""" + + import importlib + import importlib.util + import sys + + parent_module_name = "controllers.web" + module_name = f"{parent_module_name}.message" + + if parent_module_name not in sys.modules: + from flask_restx import Namespace + + stub = ModuleType(parent_module_name) + stub.__file__ = "controllers/web/__init__.py" + stub.__path__ = ["controllers/web"] + stub.__package__ = "controllers" + stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True) + stub.web_ns = Namespace("web", description="Web API", path="/") + sys.modules[parent_module_name] = stub + + wraps_module_name = f"{parent_module_name}.wraps" + if wraps_module_name not in sys.modules: + wraps_stub = ModuleType(wraps_module_name) + + class WebApiResource: + pass + + wraps_stub.WebApiResource = WebApiResource + sys.modules[wraps_module_name] = wraps_stub + + return importlib.import_module(module_name) + + +message_module = _load_controller_module() +MessageListApi = message_module.MessageListApi + + +@pytest.fixture +def app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + return app + + +def test_message_list_mapping(app: Flask) -> None: + conversation_id = str(uuid4()) + message_id = str(uuid4()) + + created_at = datetime(2024, 1, 1, 12, 0, 0) + resource_created_at = datetime(2024, 1, 1, 13, 0, 0) + thought_created_at = datetime(2024, 1, 1, 14, 0, 0) + + retriever_resource_obj = SimpleNamespace( + id="res-obj", + message_id=message_id, + position=2, + dataset_id="ds-1", + dataset_name="dataset", + document_id="doc-1", + document_name="document", + data_source_type="file", + segment_id="seg-1", + score=0.9, + hit_count=1, + word_count=10, + segment_position=0, + index_node_hash="hash", + content="content", + created_at=resource_created_at, + ) 
+ + agent_thought = SimpleNamespace( + id="thought-1", + chain_id=None, + message_chain_id="chain-1", + message_id=message_id, + position=1, + thought="thinking", + tool="tool", + tool_labels={"label": "value"}, + tool_input="{}", + created_at=thought_created_at, + observation="observed", + files=["file-a"], + ) + + message_file_obj = SimpleNamespace( + id="file-obj", + filename="b.txt", + type="file", + url=None, + mime_type=None, + size=None, + transfer_method="local", + belongs_to=None, + upload_file_id=None, + ) + + message = SimpleNamespace( + id=message_id, + conversation_id=conversation_id, + parent_message_id=None, + inputs={"foo": "bar"}, + query="hello", + re_sign_file_url_answer="answer", + user_feedback=SimpleNamespace(rating="like"), + retriever_resources=[ + {"id": "res-dict", "message_id": message_id, "position": 1}, + retriever_resource_obj, + ], + created_at=created_at, + agent_thoughts=[agent_thought], + message_files=[ + {"id": "file-dict", "filename": "a.txt", "type": "file", "transfer_method": "local"}, + message_file_obj, + ], + status="success", + error=None, + message_metadata_dict={"meta": "value"}, + ) + + pagination = SimpleNamespace(limit=20, has_more=False, data=[message]) + app_model = SimpleNamespace(mode="chat") + end_user = SimpleNamespace() + + with ( + patch.object(message_module.MessageService, "pagination_by_first_id", return_value=pagination) as mock_page, + app.test_request_context(f"/messages?conversation_id={conversation_id}&limit=20"), + ): + response = MessageListApi().get(app_model, end_user) + + mock_page.assert_called_once_with(app_model, end_user, conversation_id, None, 20) + assert response["limit"] == 20 + assert response["has_more"] is False + assert len(response["data"]) == 1 + + item = response["data"][0] + assert item["id"] == message_id + assert item["conversation_id"] == conversation_id + assert item["inputs"] == {"foo": "bar"} + assert item["answer"] == "answer" + assert item["feedback"]["rating"] == "like" + 
assert item["metadata"] == {"meta": "value"} + assert item["created_at"] == int(created_at.timestamp()) + + assert item["retriever_resources"][0]["id"] == "res-dict" + assert item["retriever_resources"][1]["id"] == "res-obj" + assert item["retriever_resources"][1]["created_at"] == int(resource_created_at.timestamp()) + + assert item["agent_thoughts"][0]["chain_id"] == "chain-1" + assert item["agent_thoughts"][0]["created_at"] == int(thought_created_at.timestamp()) + + assert item["message_files"][0]["id"] == "file-dict" + assert item["message_files"][1]["id"] == "file-obj" diff --git a/api/tests/unit_tests/controllers/web/test_web_forgot_password.py b/api/tests/unit_tests/controllers/web/test_web_forgot_password.py new file mode 100644 index 0000000000..3d7c319947 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_web_forgot_password.py @@ -0,0 +1,226 @@ +import base64 +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.forgot_password import ( + ForgotPasswordCheckApi, + ForgotPasswordResetApi, + ForgotPasswordSendEmailApi, +) + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +@pytest.fixture(autouse=True) +def _patch_wraps(): + wraps_features = SimpleNamespace(enable_email_password_login=True) + dify_settings = SimpleNamespace(ENTERPRISE_ENABLED=True, EDITION="CLOUD") + with ( + patch("controllers.console.wraps.db") as mock_db, + patch("controllers.console.wraps.dify_config", dify_settings), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), + ): + mock_db.session.query.return_value.first.return_value = MagicMock() + yield + + +class TestForgotPasswordSendEmailApi: + @patch("controllers.web.forgot_password.AccountService.send_reset_password_email") + 
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1") + @patch("controllers.web.forgot_password.Session") + def test_should_normalize_email_before_sending( + self, + mock_session_cls, + mock_extract_ip, + mock_rate_limit, + mock_get_account, + mock_send_mail, + app, + ): + mock_account = MagicMock() + mock_get_account.return_value = mock_account + mock_send_mail.return_value = "token-123" + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): + with app.test_request_context( + "/web/forgot-password", + method="POST", + json={"email": "User@Example.com", "language": "zh-Hans"}, + ): + response = ForgotPasswordSendEmailApi().post() + + assert response == {"result": "success", "data": "token-123"} + mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans") + mock_extract_ip.assert_called_once() + mock_rate_limit.assert_called_once_with("127.0.0.1") + + +class TestForgotPasswordCheckApi: + @patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit") + @patch("controllers.web.forgot_password.AccountService.generate_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.add_forgot_password_error_rate_limit") + @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") + @patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit") + def test_should_normalize_email_for_validity_checks( 
+ self, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + app, + ): + mock_is_rate_limit.return_value = False + mock_get_data.return_value = {"email": "User@Example.com", "code": "1234"} + mock_generate_token.return_value = (None, "new-token") + + with app.test_request_context( + "/web/forgot-password/validity", + method="POST", + json={"email": "User@Example.com", "code": "1234", "token": "token-123"}, + ): + response = ForgotPasswordCheckApi().post() + + assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"} + mock_is_rate_limit.assert_called_once_with("user@example.com") + mock_add_rate.assert_not_called() + mock_revoke_token.assert_called_once_with("token-123") + mock_generate_token.assert_called_once_with( + "User@Example.com", + code="1234", + additional_data={"phase": "reset"}, + ) + mock_reset_rate.assert_called_once_with("user@example.com") + + @patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit") + @patch("controllers.web.forgot_password.AccountService.generate_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") + @patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit") + def test_should_preserve_token_email_case( + self, + mock_is_rate_limit, + mock_get_data, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + app, + ): + mock_is_rate_limit.return_value = False + mock_get_data.return_value = {"email": "MixedCase@Example.com", "code": "5678"} + mock_generate_token.return_value = (None, "fresh-token") + + with app.test_request_context( + "/web/forgot-password/validity", + method="POST", + json={"email": "mixedcase@example.com", "code": "5678", "token": "token-upper"}, + ): + response = ForgotPasswordCheckApi().post() + + 
assert response == {"is_valid": True, "email": "mixedcase@example.com", "token": "fresh-token"} + mock_generate_token.assert_called_once_with( + "MixedCase@Example.com", + code="5678", + additional_data={"phase": "reset"}, + ) + mock_revoke_token.assert_called_once_with("token-upper") + mock_reset_rate.assert_called_once_with("mixedcase@example.com") + + +class TestForgotPasswordResetApi: + @patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account") + @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.web.forgot_password.Session") + @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") + def test_should_fetch_account_with_fallback( + self, + mock_get_reset_data, + mock_revoke_token, + mock_session_cls, + mock_get_account, + mock_update_account, + app, + ): + mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"} + mock_account = MagicMock() + mock_get_account.return_value = mock_account + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + }, + ): + response = ForgotPasswordResetApi().post() + + assert response == {"result": "success"} + mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_update_account.assert_called_once() + mock_revoke_token.assert_called_once_with("token-123") + + @patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value") + @patch("controllers.web.forgot_password.secrets.token_bytes", 
return_value=b"0123456789abcdef") + @patch("controllers.web.forgot_password.Session") + @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") + @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") + def test_should_update_password_and_commit( + self, + mock_get_account, + mock_get_reset_data, + mock_revoke_token, + mock_session_cls, + mock_token_bytes, + mock_hash_password, + app, + ): + mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"} + account = MagicMock() + mock_get_account.return_value = account + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "reset-token", + "new_password": "StrongPass123!", + "password_confirm": "StrongPass123!", + }, + ): + response = ForgotPasswordResetApi().post() + + assert response == {"result": "success"} + mock_get_reset_data.assert_called_once_with("reset-token") + mock_revoke_token.assert_called_once_with("reset-token") + mock_token_bytes.assert_called_once_with(16) + mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef") + expected_password = base64.b64encode(b"hashed-value").decode() + assert account.password == expected_password + expected_salt = base64.b64encode(b"0123456789abcdef").decode() + assert account.password_salt == expected_salt + mock_session.commit.assert_called_once() diff --git a/api/tests/unit_tests/controllers/web/test_web_login.py b/api/tests/unit_tests/controllers/web/test_web_login.py new file mode 100644 index 0000000000..e62993e8d5 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_web_login.py @@ -0,0 +1,91 @@ +import base64 +from types import 
SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi + + +def encode_code(code: str) -> str: + return base64.b64encode(code.encode("utf-8")).decode() + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +@pytest.fixture(autouse=True) +def _patch_wraps(): + wraps_features = SimpleNamespace(enable_email_password_login=True) + console_dify = SimpleNamespace(ENTERPRISE_ENABLED=True, EDITION="CLOUD") + web_dify = SimpleNamespace(ENTERPRISE_ENABLED=True) + with ( + patch("controllers.console.wraps.db") as mock_db, + patch("controllers.console.wraps.dify_config", console_dify), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), + patch("controllers.web.login.dify_config", web_dify), + ): + mock_db.session.query.return_value.first.return_value = MagicMock() + yield + + +class TestEmailCodeLoginSendEmailApi: + @patch("controllers.web.login.WebAppAuthService.send_email_code_login_email") + @patch("controllers.web.login.WebAppAuthService.get_user_through_email") + def test_should_fetch_account_with_original_email( + self, + mock_get_user, + mock_send_email, + app, + ): + mock_account = MagicMock() + mock_get_user.return_value = mock_account + mock_send_email.return_value = "token-123" + + with app.test_request_context( + "/web/email-code-login", + method="POST", + json={"email": "User@Example.com", "language": "en-US"}, + ): + response = EmailCodeLoginSendEmailApi().post() + + assert response == {"result": "success", "data": "token-123"} + mock_get_user.assert_called_once_with("User@Example.com") + mock_send_email.assert_called_once_with(account=mock_account, language="en-US") + + +class TestEmailCodeLoginApi: + @patch("controllers.web.login.AccountService.reset_login_error_rate_limit") + 
@patch("controllers.web.login.WebAppAuthService.login", return_value="new-access-token") + @patch("controllers.web.login.WebAppAuthService.get_user_through_email") + @patch("controllers.web.login.WebAppAuthService.revoke_email_code_login_token") + @patch("controllers.web.login.WebAppAuthService.get_email_code_login_data") + def test_should_normalize_email_before_validating( + self, + mock_get_token_data, + mock_revoke_token, + mock_get_user, + mock_login, + mock_reset_login_rate, + app, + ): + mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"} + mock_get_user.return_value = MagicMock() + + with app.test_request_context( + "/web/email-code-login/validity", + method="POST", + json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"}, + ): + response = EmailCodeLoginApi().post() + + assert response.get_json() == {"result": "success", "data": {"access_token": "new-access-token"}} + mock_get_user.assert_called_once_with("User@Example.com") + mock_revoke_token.assert_called_once_with("token-123") + mock_login.assert_called_once() + mock_reset_login_rate.assert_called_once_with("user@example.com") diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py new file mode 100644 index 0000000000..421a5246eb --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -0,0 +1,454 @@ +"""Test multimodal image output handling in BaseAppRunner.""" + +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueMessageFileEvent +from core.file.enums import FileTransferMethod, FileType +from 
core.model_runtime.entities.message_entities import ImagePromptMessageContent +from models.enums import CreatorUserRole + + +class TestBaseAppRunnerMultimodal: + """Test that BaseAppRunner correctly handles multimodal image content.""" + + @pytest.fixture + def mock_user_id(self): + """Mock user ID.""" + return str(uuid4()) + + @pytest.fixture + def mock_tenant_id(self): + """Mock tenant ID.""" + return str(uuid4()) + + @pytest.fixture + def mock_message_id(self): + """Mock message ID.""" + return str(uuid4()) + + @pytest.fixture + def mock_queue_manager(self): + """Create a mock queue manager.""" + manager = MagicMock() + manager.invoke_from = InvokeFrom.SERVICE_API + return manager + + @pytest.fixture + def mock_tool_file(self): + """Create a mock tool file.""" + tool_file = MagicMock() + tool_file.id = str(uuid4()) + return tool_file + + @pytest.fixture + def mock_message_file(self): + """Create a mock message file.""" + message_file = MagicMock() + message_file.id = str(uuid4()) + return message_file + + def test_handle_multimodal_image_content_with_url( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test handling image from URL.""" + # Arrange + image_url = "http://example.com/image.png" + content = ImagePromptMessageContent( + url=image_url, + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_url.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act 
+ # Create a mock runner with the method bound + runner = MagicMock() + + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert + # Verify tool file was created from URL + mock_mgr.create_file_by_url.assert_called_once_with( + user_id=mock_user_id, + tenant_id=mock_tenant_id, + file_url=image_url, + conversation_id=None, + ) + + # Verify message file was created with correct parameters + mock_msg_file_class.assert_called_once() + call_kwargs = mock_msg_file_class.call_args[1] + assert call_kwargs["message_id"] == mock_message_id + assert call_kwargs["type"] == FileType.IMAGE + assert call_kwargs["transfer_method"] == FileTransferMethod.TOOL_FILE + assert call_kwargs["belongs_to"] == "assistant" + assert call_kwargs["created_by"] == mock_user_id + + # Verify database operations + mock_session.add.assert_called_once_with(mock_message_file) + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once_with(mock_message_file) + + # Verify event was published + mock_queue_manager.publish.assert_called_once() + publish_call = mock_queue_manager.publish.call_args + assert isinstance(publish_call[0][0], QueueMessageFileEvent) + assert publish_call[0][0].message_file_id == mock_message_file.id + # publish_from may arrive positionally or as a keyword; check len(args) first to avoid an IndexError + assert ( + (len(publish_call.args) > 1 and publish_call.args[1] == PublishFrom.APPLICATION_MANAGER) + or publish_call.kwargs.get("publish_from") == PublishFrom.APPLICATION_MANAGER + ) + + def test_handle_multimodal_image_content_with_base64( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test handling image from base64 data.""" + # Arrange + import
base64 + + # Create a small test image (1x1 PNG) + test_image_data = base64.b64encode( + b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde" + ).decode() + content = ImagePromptMessageContent( + base64_data=test_image_data, + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_raw.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert + # Verify tool file was created from base64 + mock_mgr.create_file_by_raw.assert_called_once() + call_kwargs = mock_mgr.create_file_by_raw.call_args[1] + assert call_kwargs["user_id"] == mock_user_id + assert call_kwargs["tenant_id"] == mock_tenant_id + assert call_kwargs["conversation_id"] is None + assert "file_binary" in call_kwargs + assert call_kwargs["mimetype"] == "image/png" + assert call_kwargs["filename"].startswith("generated_image") + assert call_kwargs["filename"].endswith(".png") + + # Verify message file was created + mock_msg_file_class.assert_called_once() + + # Verify database operations + mock_session.add.assert_called_once() + 
mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once() + + # Verify event was published + mock_queue_manager.publish.assert_called_once() + + def test_handle_multimodal_image_content_with_base64_data_uri( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test handling image from base64 data with URI prefix.""" + # Arrange + # Data URI format: data:image/png;base64, + test_image_data = ( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + ) + content = ImagePromptMessageContent( + base64_data=f"data:image/png;base64,{test_image_data}", + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_raw.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - verify that base64 data was extracted correctly (without prefix) + mock_mgr.create_file_by_raw.assert_called_once() + call_kwargs = mock_mgr.create_file_by_raw.call_args[1] + # The base64 data should be decoded, so we check 
the binary was passed + assert "file_binary" in call_kwargs + + def test_handle_multimodal_image_content_without_url_or_base64( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + ): + """Test handling image content without URL or base64 data.""" + # Arrange + content = ImagePromptMessageContent( + url="", + base64_data="", + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - should not create any files or publish events + mock_mgr_class.assert_not_called() + mock_msg_file_class.assert_not_called() + mock_session.add.assert_not_called() + mock_queue_manager.publish.assert_not_called() + + def test_handle_multimodal_image_content_with_error( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + ): + """Test handling image content when an error occurs.""" + # Arrange + image_url = "http://example.com/image.png" + content = ImagePromptMessageContent( + url=image_url, + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock to raise exception + mock_mgr = MagicMock() + mock_mgr.create_file_by_url.side_effect = Exception("Network error") + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + with 
patch("core.app.apps.base_app_runner.db.session") as mock_session: + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + # Should not raise exception, just log it + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - should not create message file or publish event on error + mock_msg_file_class.assert_not_called() + mock_session.add.assert_not_called() + mock_queue_manager.publish.assert_not_called() + + def test_handle_multimodal_image_content_debugger_mode( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test that debugger mode sets correct created_by_role.""" + # Arrange + image_url = "http://example.com/image.png" + content = ImagePromptMessageContent( + url=image_url, + format="png", + mime_type="image/png", + ) + mock_queue_manager.invoke_from = InvokeFrom.DEBUGGER + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_url.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: 
method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - verify created_by_role is ACCOUNT for debugger mode + call_kwargs = mock_msg_file_class.call_args[1] + assert call_kwargs["created_by_role"] == CreatorUserRole.ACCOUNT + + def test_handle_multimodal_image_content_service_api_mode( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test that service API mode sets correct created_by_role.""" + # Arrange + image_url = "http://example.com/image.png" + content = ImagePromptMessageContent( + url=image_url, + format="png", + mime_type="image/png", + ) + mock_queue_manager.invoke_from = InvokeFrom.SERVICE_API + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_url.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - verify created_by_role is END_USER for service API + call_kwargs = mock_msg_file_class.call_args[1] + 
assert call_kwargs["created_by_role"] == CreatorUserRole.END_USER diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index cdd1950d82..d25bff92dc 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -458,10 +458,10 @@ class TestWorkflowResponseConverterServiceApiTruncation: description="Explore calls should have truncation enabled", ), TestCase( - name="published_truncation_enabled", - invoke_from=InvokeFrom.PUBLISHED, + name="published_pipeline_truncation_enabled", + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, expected_truncation_enabled=True, - description="Published app calls should have truncation enabled", + description="Published pipeline calls should have truncation enabled", ), ], ids=lambda x: x.name, diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py new file mode 100644 index 0000000000..f5903d28bd --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.workflow.app_runner import WorkflowAppRunner +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from models.workflow import Workflow + + +def _make_graph_state(): + variable_pool = VariablePool( + system_variables=SystemVariable.default(), + user_inputs={}, + environment_variables=[], + 
conversation_variables=[], + ) + return MagicMock(), variable_pool, GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + + +@pytest.mark.parametrize( + ("single_iteration_run", "single_loop_run"), + [ + (WorkflowAppGenerateEntity.SingleIterationRunEntity(node_id="iter", inputs={}), None), + (None, WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id="loop", inputs={})), + ], +) +def test_run_uses_single_node_execution_branch( + single_iteration_run: Any, + single_loop_run: Any, +) -> None: + app_config = MagicMock() + app_config.app_id = "app" + app_config.tenant_id = "tenant" + app_config.workflow_id = "workflow" + + app_generate_entity = MagicMock(spec=WorkflowAppGenerateEntity) + app_generate_entity.app_config = app_config + app_generate_entity.inputs = {} + app_generate_entity.files = [] + app_generate_entity.user_id = "user" + app_generate_entity.invoke_from = InvokeFrom.SERVICE_API + app_generate_entity.workflow_execution_id = "execution-id" + app_generate_entity.task_id = "task-id" + app_generate_entity.call_depth = 0 + app_generate_entity.trace_manager = None + app_generate_entity.single_iteration_run = single_iteration_run + app_generate_entity.single_loop_run = single_loop_run + + workflow = MagicMock(spec=Workflow) + workflow.tenant_id = "tenant" + workflow.app_id = "app" + workflow.id = "workflow" + workflow.type = "workflow" + workflow.version = "v1" + workflow.graph_dict = {"nodes": [], "edges": []} + workflow.environment_variables = [] + + runner = WorkflowAppRunner( + application_generate_entity=app_generate_entity, + queue_manager=MagicMock(spec=AppQueueManager), + variable_loader=MagicMock(), + workflow=workflow, + system_user_id="system-user", + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + graph, variable_pool, graph_runtime_state = _make_graph_state() + mock_workflow_entry = MagicMock() + mock_workflow_entry.graph_engine = MagicMock() + mock_workflow_entry.graph_engine.layer = 
MagicMock() + mock_workflow_entry.run.return_value = iter([]) + + with ( + patch("core.app.apps.workflow.app_runner.RedisChannel"), + patch("core.app.apps.workflow.app_runner.redis_client"), + patch("core.app.apps.workflow.app_runner.WorkflowEntry", return_value=mock_workflow_entry) as entry_class, + patch.object( + runner, + "_prepare_single_node_execution", + return_value=( + graph, + variable_pool, + graph_runtime_state, + ), + ) as prepare_single, + patch.object(runner, "_init_graph") as init_graph, + ): + runner.run() + + prepare_single.assert_called_once_with( + workflow=workflow, + single_iteration_run=single_iteration_run, + single_loop_run=single_loop_run, + ) + init_graph.assert_not_called() + + entry_kwargs = entry_class.call_args.kwargs + assert entry_kwargs["invoke_from"] == InvokeFrom.DEBUGGER + assert entry_kwargs["variable_pool"] is variable_pool + assert entry_kwargs["graph_runtime_state"] is graph_runtime_state diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py new file mode 100644 index 0000000000..b6e8cc9c8e --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -0,0 +1,144 @@ +from collections.abc import Sequence +from datetime import datetime +from unittest.mock import Mock + +from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer +from core.variables import StringVariable +from core.variables.segments import Segment +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.graph_engine.protocols.command_channel import CommandChannel +from core.workflow.graph_events.node import NodeRunSucceededEvent +from core.workflow.node_events import NodeRunResult +from core.workflow.nodes.variable_assigner.common import helpers as 
common_helpers +from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState +from core.workflow.system_variable import SystemVariable + + +class MockReadOnlyVariablePool: + def __init__(self, variables: dict[tuple[str, str], Segment] | None = None) -> None: + self._variables = variables or {} + + def get(self, selector: Sequence[str]) -> Segment | None: + if len(selector) < 2: + return None + return self._variables.get((selector[0], selector[1])) + + def get_all_by_node(self, node_id: str) -> dict[str, object]: + return {key: value for (nid, key), value in self._variables.items() if nid == node_id} + + def get_by_prefix(self, prefix: str) -> dict[str, object]: + return {key: value for (nid, key), value in self._variables.items() if nid == prefix} + + +def _build_graph_runtime_state( + variable_pool: MockReadOnlyVariablePool, + conversation_id: str | None = None, +) -> ReadOnlyGraphRuntimeState: + graph_runtime_state = Mock(spec=ReadOnlyGraphRuntimeState) + graph_runtime_state.variable_pool = variable_pool + graph_runtime_state.system_variable = SystemVariable(conversation_id=conversation_id).as_view() + return graph_runtime_state + + +def _build_node_run_succeeded_event( + *, + node_type: NodeType, + outputs: dict[str, object] | None = None, + process_data: dict[str, object] | None = None, +) -> NodeRunSucceededEvent: + return NodeRunSucceededEvent( + id="node-exec-id", + node_id="assigner", + node_type=node_type, + start_at=datetime.utcnow(), + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs=outputs or {}, + process_data=process_data or {}, + ), + ) + + +def test_persists_conversation_variables_from_assigner_output(): + conversation_id = "conv-123" + variable = StringVariable( + id="var-1", + name="name", + value="updated", + selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], + ) + process_data = common_helpers.set_updated_variables( + {}, 
[common_helpers.variable_to_processed_data(variable.selector, variable)] + ) + + variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) + + updater = Mock() + layer = ConversationVariablePersistenceLayer(updater) + layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + + event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data) + layer.on_event(event) + + updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable) + updater.flush.assert_called_once() + + +def test_skips_when_outputs_missing(): + conversation_id = "conv-456" + variable = StringVariable( + id="var-2", + name="name", + value="updated", + selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], + ) + + variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) + + updater = Mock() + layer = ConversationVariablePersistenceLayer(updater) + layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + + event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER) + layer.on_event(event) + + updater.update.assert_not_called() + updater.flush.assert_not_called() + + +def test_skips_non_assigner_nodes(): + updater = Mock() + layer = ConversationVariablePersistenceLayer(updater) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel)) + + event = _build_node_run_succeeded_event(node_type=NodeType.LLM) + layer.on_event(event) + + updater.update.assert_not_called() + updater.flush.assert_not_called() + + +def test_skips_non_conversation_variables(): + conversation_id = "conv-789" + non_conversation_variable = StringVariable( + id="var-3", + name="name", + value="updated", + selector=["environment", "name"], + ) + process_data = common_helpers.set_updated_variables( + {}, 
[common_helpers.variable_to_processed_data(non_conversation_variable.selector, non_conversation_variable)] + ) + + variable_pool = MockReadOnlyVariablePool() + + updater = Mock() + layer = ConversationVariablePersistenceLayer(updater) + layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + + event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data) + layer.on_event(event) + + updater.update.assert_not_called() + updater.flush.assert_not_called() diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 534420f21e..1d885f6b2e 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -1,4 +1,5 @@ import json +from collections.abc import Sequence from time import time from unittest.mock import Mock @@ -15,6 +16,7 @@ from core.app.layers.pause_state_persist_layer import ( from core.variables.segments import Segment from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.graph_engine.entities.commands import GraphEngineCommand +from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError from core.workflow.graph_events.graph import ( GraphRunFailedEvent, GraphRunPausedEvent, @@ -66,8 +68,10 @@ class MockReadOnlyVariablePool: def __init__(self, variables: dict[tuple[str, str], object] | None = None): self._variables = variables or {} - def get(self, node_id: str, variable_key: str) -> Segment | None: - value = self._variables.get((node_id, variable_key)) + def get(self, selector: Sequence[str]) -> Segment | None: + if len(selector) < 2: + return None + value = self._variables.get((selector[0], selector[1])) if value is None: return None mock_segment = Mock(spec=Segment) @@ -209,8 +213,9 @@ class 
TestPauseStatePersistenceLayer: assert layer._session_maker is session_factory assert layer._state_owner_user_id == state_owner_user_id - assert not hasattr(layer, "graph_runtime_state") - assert not hasattr(layer, "command_channel") + with pytest.raises(GraphEngineLayerNotInitializedError): + _ = layer.graph_runtime_state + assert layer.command_channel is None def test_initialize_sets_dependencies(self): session_factory = Mock(name="session_factory") @@ -295,7 +300,7 @@ class TestPauseStatePersistenceLayer: mock_factory.assert_not_called() mock_repo.create_workflow_pause.assert_not_called() - def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self): + def test_on_event_raises_when_graph_runtime_state_is_uninitialized(self): session_factory = Mock(name="session_factory") layer = PauseStatePersistenceLayer( session_factory=session_factory, @@ -305,7 +310,7 @@ class TestPauseStatePersistenceLayer: event = TestDataFactory.create_graph_run_paused_event() - with pytest.raises(AttributeError): + with pytest.raises(GraphEngineLayerNotInitializedError): layer.on_event(event) def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch): diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py index 5ef7f0d7f4..5a43a247e3 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -1,7 +1,6 @@ """Unit tests for the message cycle manager optimization.""" -from types import SimpleNamespace -from unittest.mock import ANY, Mock, patch +from unittest.mock import Mock, patch import pytest from flask import current_app @@ -28,17 +27,14 @@ class TestMessageCycleManagerOptimization: def test_get_message_event_type_with_message_file(self, 
message_cycle_manager): """Test get_message_event_type returns MESSAGE_FILE when message has files.""" - with ( - patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, - patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), - ): + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Setup mock session and message file mock_session = Mock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session mock_message_file = Mock() - # Current implementation uses session.query(...).scalar() - mock_session.query.return_value.scalar.return_value = mock_message_file + # Current implementation uses session.scalar(select(...)) + mock_session.scalar.return_value = mock_message_file # Execute with current_app.app_context(): @@ -46,19 +42,16 @@ class TestMessageCycleManagerOptimization: # Assert assert result == StreamEvent.MESSAGE_FILE - mock_session.query.return_value.scalar.assert_called_once() + mock_session.scalar.assert_called_once() def test_get_message_event_type_without_message_file(self, message_cycle_manager): """Test get_message_event_type returns MESSAGE when message has no files.""" - with ( - patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, - patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), - ): + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Setup mock session and no message file mock_session = Mock() - mock_session_class.return_value.__enter__.return_value = mock_session - # Current implementation uses session.query(...).scalar() - mock_session.query.return_value.scalar.return_value = None + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + # Current implementation uses 
session.scalar(select(...)) + mock_session.scalar.return_value = None # Execute with current_app.app_context(): @@ -66,21 +59,18 @@ class TestMessageCycleManagerOptimization: # Assert assert result == StreamEvent.MESSAGE - mock_session.query.return_value.scalar.assert_called_once() + mock_session.scalar.assert_called_once() def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager): """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it.""" - with ( - patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, - patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), - ): + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Setup mock session and message file mock_session = Mock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session mock_message_file = Mock() - # Current implementation uses session.query(...).scalar() - mock_session.query.return_value.scalar.return_value = mock_message_file + # Current implementation uses session.scalar(select(...)) + mock_session.scalar.return_value = mock_message_file # Execute: compute event type once, then pass to message_to_stream_response with current_app.app_context(): @@ -94,11 +84,11 @@ class TestMessageCycleManagerOptimization: assert result.answer == "Hello world" assert result.id == "test-message-id" assert result.event == StreamEvent.MESSAGE_FILE - mock_session.query.return_value.scalar.assert_called_once() + mock_session.scalar.assert_called_once() def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager): """Test that message_to_stream_response skips database query when event_type is provided.""" - with patch("core.app.task_pipeline.message_cycle_manager.Session") as 
mock_session_class: + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Execute with event_type provided result = message_cycle_manager.message_to_stream_response( answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE @@ -109,8 +99,8 @@ class TestMessageCycleManagerOptimization: assert result.answer == "Hello world" assert result.id == "test-message-id" assert result.event == StreamEvent.MESSAGE - # Should not query database when event_type is provided - mock_session_class.assert_not_called() + # Should not open a session when event_type is provided + mock_session_factory.create_session.assert_not_called() def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager): """Test message_to_stream_response with from_variable_selector parameter.""" @@ -130,24 +120,21 @@ class TestMessageCycleManagerOptimization: def test_optimization_usage_example(self, message_cycle_manager): """Test the optimization pattern that should be used by callers.""" # Step 1: Get event type once (this queries database) - with ( - patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, - patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), - ): + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: mock_session = Mock() - mock_session_class.return_value.__enter__.return_value = mock_session - # Current implementation uses session.query(...).scalar() - mock_session.query.return_value.scalar.return_value = None # No files + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + # Current implementation uses session.scalar(select(...)) + mock_session.scalar.return_value = None # No files with current_app.app_context(): event_type = message_cycle_manager.get_message_event_type("test-message-id") - # Should query database once - 
mock_session_class.assert_called_once_with(ANY, expire_on_commit=False) + # Should open session once + mock_session_factory.create_session.assert_called_once() assert event_type == StreamEvent.MESSAGE # Step 2: Use event_type for multiple calls (no additional queries) - with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class: - mock_session_class.return_value.__enter__.return_value = Mock() + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: + mock_session_factory.create_session.return_value.__enter__.return_value = Mock() chunk1_response = message_cycle_manager.message_to_stream_response( answer="Chunk 1", message_id="test-message-id", event_type=event_type @@ -157,8 +144,8 @@ class TestMessageCycleManagerOptimization: answer="Chunk 2", message_id="test-message-id", event_type=event_type ) - # Should not query database again - mock_session_class.assert_not_called() + # Should not open session again when event_type provided + mock_session_factory.create_session.assert_not_called() assert chunk1_response.event == StreamEvent.MESSAGE assert chunk2_response.event == StreamEvent.MESSAGE diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index beae1d0358..d6d75fb72f 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -14,12 +14,12 @@ def test_successful_request(mock_get_client): mock_client = MagicMock() mock_response = MagicMock() mock_response.status_code = 200 - mock_client.send.return_value = mock_response mock_client.request.return_value = mock_response mock_get_client.return_value = mock_client response = make_request("GET", "http://example.com") assert response.status_code == 200 + mock_client.request.assert_called_once() @patch("core.helper.ssrf_proxy._get_ssrf_client") @@ -27,7 +27,6 @@ def test_retry_exceed_max_retries(mock_get_client): 
mock_client = MagicMock() mock_response = MagicMock() mock_response.status_code = 500 - mock_client.send.return_value = mock_response mock_client.request.return_value = mock_response mock_get_client.return_value = mock_client @@ -72,34 +71,12 @@ class TestGetUserProvidedHostHeader: assert result in ("first.com", "second.com") -@patch("core.helper.ssrf_proxy._get_ssrf_client") -def test_host_header_preservation_without_user_header(mock_get_client): - """Test that when no Host header is provided, the default behavior is maintained.""" - mock_client = MagicMock() - mock_request = MagicMock() - mock_request.headers = {} - mock_response = MagicMock() - mock_response.status_code = 200 - mock_client.send.return_value = mock_response - mock_client.request.return_value = mock_response - mock_get_client.return_value = mock_client - - response = make_request("GET", "http://example.com") - - assert response.status_code == 200 - # Host should not be set if not provided by user - assert "Host" not in mock_request.headers or mock_request.headers.get("Host") is None - - @patch("core.helper.ssrf_proxy._get_ssrf_client") def test_host_header_preservation_with_user_header(mock_get_client): """Test that user-provided Host header is preserved in the request.""" mock_client = MagicMock() - mock_request = MagicMock() - mock_request.headers = {} mock_response = MagicMock() mock_response.status_code = 200 - mock_client.send.return_value = mock_response mock_client.request.return_value = mock_response mock_get_client.return_value = mock_client @@ -107,3 +84,93 @@ def test_host_header_preservation_with_user_header(mock_get_client): response = make_request("GET", "http://example.com", headers={"Host": custom_host}) assert response.status_code == 200 + # Verify client.request was called with the host header preserved (lowercase) + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs["headers"]["host"] == custom_host + + +@patch("core.helper.ssrf_proxy._get_ssrf_client") 
+@pytest.mark.parametrize("host_key", ["host", "HOST", "Host"]) +def test_host_header_preservation_case_insensitive(mock_get_client, host_key): + """Test that Host header is preserved regardless of case.""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client.request.return_value = mock_response + mock_get_client.return_value = mock_client + + response = make_request("GET", "http://example.com", headers={host_key: "api.example.com"}) + + assert response.status_code == 200 + # Host header should be normalized to lowercase "host" + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs["headers"]["host"] == "api.example.com" + + +class TestFollowRedirectsParameter: + """Tests for follow_redirects parameter handling. + + These tests verify that follow_redirects is correctly passed to client.request(). + """ + + @patch("core.helper.ssrf_proxy._get_ssrf_client") + def test_follow_redirects_passed_to_request(self, mock_get_client): + """Verify follow_redirects IS passed to client.request().""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client.request.return_value = mock_response + mock_get_client.return_value = mock_client + + make_request("GET", "http://example.com", follow_redirects=True) + + # Verify follow_redirects was passed to request + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs.get("follow_redirects") is True + + @patch("core.helper.ssrf_proxy._get_ssrf_client") + def test_allow_redirects_converted_to_follow_redirects(self, mock_get_client): + """Verify allow_redirects (requests-style) is converted to follow_redirects (httpx-style).""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client.request.return_value = mock_response + mock_get_client.return_value = mock_client + + # Use allow_redirects (requests-style parameter) + make_request("GET", 
"http://example.com", allow_redirects=True) + + # Verify it was converted to follow_redirects + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs.get("follow_redirects") is True + assert "allow_redirects" not in call_kwargs + + @patch("core.helper.ssrf_proxy._get_ssrf_client") + def test_follow_redirects_not_set_when_not_specified(self, mock_get_client): + """Verify follow_redirects is not in kwargs when not specified (httpx default behavior).""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client.request.return_value = mock_response + mock_get_client.return_value = mock_client + + make_request("GET", "http://example.com") + + # follow_redirects should not be in kwargs, letting httpx use its default + call_kwargs = mock_client.request.call_args.kwargs + assert "follow_redirects" not in call_kwargs + + @patch("core.helper.ssrf_proxy._get_ssrf_client") + def test_follow_redirects_takes_precedence_over_allow_redirects(self, mock_get_client): + """Verify follow_redirects takes precedence when both are specified.""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client.request.return_value = mock_response + mock_get_client.return_value = mock_client + + # Both specified - follow_redirects should take precedence + make_request("GET", "http://example.com", allow_redirects=False, follow_redirects=True) + + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs.get("follow_redirects") is True diff --git a/api/tests/unit_tests/core/helper/test_tool_provider_cache.py b/api/tests/unit_tests/core/helper/test_tool_provider_cache.py deleted file mode 100644 index d237c68f35..0000000000 --- a/api/tests/unit_tests/core/helper/test_tool_provider_cache.py +++ /dev/null @@ -1,126 +0,0 @@ -import json -from unittest.mock import patch - -import pytest -from redis.exceptions import RedisError - -from core.helper.tool_provider_cache import 
ToolProviderListCache -from core.tools.entities.api_entities import ToolProviderTypeApiLiteral - - -@pytest.fixture -def mock_redis_client(): - """Fixture: Mock Redis client""" - with patch("core.helper.tool_provider_cache.redis_client") as mock: - yield mock - - -class TestToolProviderListCache: - """Test class for ToolProviderListCache""" - - def test_generate_cache_key(self): - """Test cache key generation logic""" - # Scenario 1: Specify typ (valid literal value) - tenant_id = "tenant_123" - typ: ToolProviderTypeApiLiteral = "builtin" - expected_key = f"tool_providers:tenant_id:{tenant_id}:type:{typ}" - assert ToolProviderListCache._generate_cache_key(tenant_id, typ) == expected_key - - # Scenario 2: typ is None (defaults to "all") - expected_key_all = f"tool_providers:tenant_id:{tenant_id}:type:all" - assert ToolProviderListCache._generate_cache_key(tenant_id) == expected_key_all - - def test_get_cached_providers_hit(self, mock_redis_client): - """Test get cached providers - cache hit and successful decoding""" - tenant_id = "tenant_123" - typ: ToolProviderTypeApiLiteral = "api" - mock_providers = [{"id": "tool", "name": "test_provider"}] - mock_redis_client.get.return_value = json.dumps(mock_providers).encode("utf-8") - - result = ToolProviderListCache.get_cached_providers(tenant_id, typ) - - mock_redis_client.get.assert_called_once_with(ToolProviderListCache._generate_cache_key(tenant_id, typ)) - assert result == mock_providers - - def test_get_cached_providers_decode_error(self, mock_redis_client): - """Test get cached providers - cache hit but decoding failed""" - tenant_id = "tenant_123" - mock_redis_client.get.return_value = b"invalid_json_data" - - result = ToolProviderListCache.get_cached_providers(tenant_id) - - assert result is None - mock_redis_client.get.assert_called_once() - - def test_get_cached_providers_miss(self, mock_redis_client): - """Test get cached providers - cache miss""" - tenant_id = "tenant_123" - mock_redis_client.get.return_value 
= None - - result = ToolProviderListCache.get_cached_providers(tenant_id) - - assert result is None - mock_redis_client.get.assert_called_once() - - def test_set_cached_providers(self, mock_redis_client): - """Test set cached providers""" - tenant_id = "tenant_123" - typ: ToolProviderTypeApiLiteral = "builtin" - mock_providers = [{"id": "tool", "name": "test_provider"}] - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) - - ToolProviderListCache.set_cached_providers(tenant_id, typ, mock_providers) - - mock_redis_client.setex.assert_called_once_with( - cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(mock_providers) - ) - - def test_invalidate_cache_specific_type(self, mock_redis_client): - """Test invalidate cache - specific type""" - tenant_id = "tenant_123" - typ: ToolProviderTypeApiLiteral = "workflow" - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) - - ToolProviderListCache.invalidate_cache(tenant_id, typ) - - mock_redis_client.delete.assert_called_once_with(cache_key) - - def test_invalidate_cache_all_types(self, mock_redis_client): - """Test invalidate cache - clear all tenant cache""" - tenant_id = "tenant_123" - mock_keys = [ - b"tool_providers:tenant_id:tenant_123:type:all", - b"tool_providers:tenant_id:tenant_123:type:builtin", - ] - mock_redis_client.scan_iter.return_value = mock_keys - - ToolProviderListCache.invalidate_cache(tenant_id) - - def test_invalidate_cache_no_keys(self, mock_redis_client): - """Test invalidate cache - no cache keys for tenant""" - tenant_id = "tenant_123" - mock_redis_client.scan_iter.return_value = [] - - ToolProviderListCache.invalidate_cache(tenant_id) - - mock_redis_client.delete.assert_not_called() - - def test_redis_fallback_default_return(self, mock_redis_client): - """Test redis_fallback decorator - default return value (Redis error)""" - mock_redis_client.get.side_effect = RedisError("Redis connection error") - - result = 
ToolProviderListCache.get_cached_providers("tenant_123") - - assert result is None - mock_redis_client.get.assert_called_once() - - def test_redis_fallback_no_default(self, mock_redis_client): - """Test redis_fallback decorator - no default return value (Redis error)""" - mock_redis_client.setex.side_effect = RedisError("Redis connection error") - - try: - ToolProviderListCache.set_cached_providers("tenant_123", "mcp", []) - except RedisError: - pytest.fail("set_cached_providers should not raise RedisError (handled by fallback)") - - mock_redis_client.setex.assert_called_once() diff --git a/api/tests/unit_tests/core/logging/__init__.py b/api/tests/unit_tests/core/logging/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/logging/test_context.py b/api/tests/unit_tests/core/logging/test_context.py new file mode 100644 index 0000000000..f388a3a0b9 --- /dev/null +++ b/api/tests/unit_tests/core/logging/test_context.py @@ -0,0 +1,79 @@ +"""Tests for logging context module.""" + +import uuid + +from core.logging.context import ( + clear_request_context, + get_request_id, + get_trace_id, + init_request_context, +) + + +class TestLoggingContext: + """Tests for the logging context functions.""" + + def test_init_creates_request_id(self): + """init_request_context should create a 10-char request ID.""" + init_request_context() + request_id = get_request_id() + assert len(request_id) == 10 + assert all(c in "0123456789abcdef" for c in request_id) + + def test_init_creates_trace_id(self): + """init_request_context should create a 32-char trace ID.""" + init_request_context() + trace_id = get_trace_id() + assert len(trace_id) == 32 + assert all(c in "0123456789abcdef" for c in trace_id) + + def test_trace_id_derived_from_request_id(self): + """trace_id should be deterministically derived from request_id.""" + init_request_context() + request_id = get_request_id() + trace_id = get_trace_id() + + # Verify trace_id is derived using 
uuid5 + expected_trace = uuid.uuid5(uuid.NAMESPACE_DNS, request_id).hex + assert trace_id == expected_trace + + def test_clear_resets_context(self): + """clear_request_context should reset both IDs to empty strings.""" + init_request_context() + assert get_request_id() != "" + assert get_trace_id() != "" + + clear_request_context() + assert get_request_id() == "" + assert get_trace_id() == "" + + def test_default_values_are_empty(self): + """Default values should be empty strings before init.""" + clear_request_context() + assert get_request_id() == "" + assert get_trace_id() == "" + + def test_multiple_inits_create_different_ids(self): + """Each init should create new unique IDs.""" + init_request_context() + first_request_id = get_request_id() + first_trace_id = get_trace_id() + + init_request_context() + second_request_id = get_request_id() + second_trace_id = get_trace_id() + + assert first_request_id != second_request_id + assert first_trace_id != second_trace_id + + def test_context_isolation(self): + """Re-initializing within the same thread should replace the previous context.""" + init_request_context() + id1 = get_request_id() + + # Simulate another request + init_request_context() + id2 = get_request_id() + + # IDs should be different + assert id1 != id2 diff --git a/api/tests/unit_tests/core/logging/test_filters.py b/api/tests/unit_tests/core/logging/test_filters.py new file mode 100644 index 0000000000..b66ad111d5 --- /dev/null +++ b/api/tests/unit_tests/core/logging/test_filters.py @@ -0,0 +1,79 @@ +"""Tests for logging filters.""" + +import logging +from unittest import mock + +import pytest + + +@pytest.fixture +def log_record(): + return logging.LogRecord( + name="test", + level=logging.INFO, + pathname="", + lineno=0, + msg="test", + args=(), + exc_info=None, + ) + + +class TestTraceContextFilter: + def test_sets_empty_trace_id_without_context(self, log_record): + from core.logging.context import clear_request_context + from core.logging.filters 
import TraceContextFilter + + # Ensure no context is set + clear_request_context() + + filter = TraceContextFilter() + result = filter.filter(log_record) + + assert result is True + assert hasattr(log_record, "trace_id") + assert hasattr(log_record, "span_id") + assert hasattr(log_record, "req_id") + # Without context, IDs should be empty + assert log_record.trace_id == "" + assert log_record.req_id == "" + + def test_sets_trace_id_from_context(self, log_record): + """Test that trace_id and req_id are set from ContextVar when initialized.""" + from core.logging.context import init_request_context + from core.logging.filters import TraceContextFilter + + # Initialize context (no Flask needed!) + init_request_context() + + filter = TraceContextFilter() + filter.filter(log_record) + + # With context initialized, IDs should be set + assert log_record.trace_id != "" + assert len(log_record.trace_id) == 32 + assert log_record.req_id != "" + assert len(log_record.req_id) == 10 + + def test_filter_always_returns_true(self, log_record): + from core.logging.filters import TraceContextFilter + + filter = TraceContextFilter() + result = filter.filter(log_record) + assert result is True + + def test_sets_trace_id_from_otel_when_available(self, log_record): + from core.logging.filters import TraceContextFilter + + mock_span = mock.MagicMock() + mock_context = mock.MagicMock() + mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2 + mock_context.span_id = 0x051581BF3BB55C45 + mock_span.get_span_context.return_value = mock_context + + with ( + mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span), + mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0), + mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0), + ): + filter = TraceContextFilter() + filter.filter(log_record) + + assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2" + assert log_record.span_id == "051581bf3bb55c45" + + +class TestIdentityContextFilter: + def 
test_sets_empty_identity_without_request_context(self, log_record): + from core.logging.filters import IdentityContextFilter + + filter = IdentityContextFilter() + result = filter.filter(log_record) + + assert result is True + assert log_record.tenant_id == "" + assert log_record.user_id == "" + assert log_record.user_type == "" + + def test_filter_always_returns_true(self, log_record): + from core.logging.filters import IdentityContextFilter + + filter = IdentityContextFilter() + result = filter.filter(log_record) + assert result is True + + def test_handles_exception_gracefully(self, log_record): + from core.logging.filters import IdentityContextFilter + + filter = IdentityContextFilter() + + # Should not raise even if something goes wrong + with mock.patch("core.logging.filters.flask.has_request_context", side_effect=Exception("Test error")): + result = filter.filter(log_record) + assert result is True + assert log_record.tenant_id == "" diff --git a/api/tests/unit_tests/core/logging/test_structured_formatter.py b/api/tests/unit_tests/core/logging/test_structured_formatter.py new file mode 100644 index 0000000000..94b91d205e --- /dev/null +++ b/api/tests/unit_tests/core/logging/test_structured_formatter.py @@ -0,0 +1,267 @@ +"""Tests for structured JSON formatter.""" + +import logging +import sys + +import orjson + + +class TestStructuredJSONFormatter: + def test_basic_log_format(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter(service_name="test-service") + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=42, + msg="Test message", + args=(), + exc_info=None, + ) + + output = formatter.format(record) + log_dict = orjson.loads(output) + + assert log_dict["severity"] == "INFO" + assert log_dict["service"] == "test-service" + assert log_dict["caller"] == "test.py:42" + assert log_dict["message"] == "Test message" + assert "ts" in log_dict + assert 
log_dict["ts"].endswith("Z") + + def test_severity_mapping(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + + test_cases = [ + (logging.DEBUG, "DEBUG"), + (logging.INFO, "INFO"), + (logging.WARNING, "WARN"), + (logging.ERROR, "ERROR"), + (logging.CRITICAL, "ERROR"), + ] + + for level, expected_severity in test_cases: + record = logging.LogRecord( + name="test", + level=level, + pathname="test.py", + lineno=1, + msg="Test", + args=(), + exc_info=None, + ) + output = formatter.format(record) + log_dict = orjson.loads(output) + assert log_dict["severity"] == expected_severity, f"Level {level} should map to {expected_severity}" + + def test_error_with_stack_trace(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + + try: + raise ValueError("Test error") + except ValueError: + exc_info = sys.exc_info() + + record = logging.LogRecord( + name="test", + level=logging.ERROR, + pathname="test.py", + lineno=10, + msg="Error occurred", + args=(), + exc_info=exc_info, + ) + + output = formatter.format(record) + log_dict = orjson.loads(output) + + assert log_dict["severity"] == "ERROR" + assert "stack_trace" in log_dict + assert "ValueError: Test error" in log_dict["stack_trace"] + + def test_no_stack_trace_for_info(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + + try: + raise ValueError("Test error") + except ValueError: + exc_info = sys.exc_info() + + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=10, + msg="Info message", + args=(), + exc_info=exc_info, + ) + + output = formatter.format(record) + log_dict = orjson.loads(output) + + assert "stack_trace" not in log_dict + + def test_trace_context_included(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = 
StructuredJSONFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="Test", + args=(), + exc_info=None, + ) + record.trace_id = "5b8aa5a2d2c872e8321cf37308d69df2" + record.span_id = "051581bf3bb55c45" + + output = formatter.format(record) + log_dict = orjson.loads(output) + + assert log_dict["trace_id"] == "5b8aa5a2d2c872e8321cf37308d69df2" + assert log_dict["span_id"] == "051581bf3bb55c45" + + def test_identity_context_included(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="Test", + args=(), + exc_info=None, + ) + record.tenant_id = "t-global-corp" + record.user_id = "u-admin-007" + record.user_type = "admin" + + output = formatter.format(record) + log_dict = orjson.loads(output) + + assert "identity" in log_dict + assert log_dict["identity"]["tenant_id"] == "t-global-corp" + assert log_dict["identity"]["user_id"] == "u-admin-007" + assert log_dict["identity"]["user_type"] == "admin" + + def test_no_identity_when_empty(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="Test", + args=(), + exc_info=None, + ) + + output = formatter.format(record) + log_dict = orjson.loads(output) + + assert "identity" not in log_dict + + def test_attributes_included(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="Test", + args=(), + exc_info=None, + ) + record.attributes = {"order_id": "ord-123", "amount": 99.99} + + output = formatter.format(record) + log_dict = orjson.loads(output) + + 
assert log_dict["attributes"]["order_id"] == "ord-123" + assert log_dict["attributes"]["amount"] == 99.99 + + def test_message_with_args(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="User %s logged in from %s", + args=("john", "192.168.1.1"), + exc_info=None, + ) + + output = formatter.format(record) + log_dict = orjson.loads(output) + + assert log_dict["message"] == "User john logged in from 192.168.1.1" + + def test_timestamp_format(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="Test", + args=(), + exc_info=None, + ) + + output = formatter.format(record) + log_dict = orjson.loads(output) + + # Verify ISO 8601 format with Z suffix + ts = log_dict["ts"] + assert ts.endswith("Z") + assert "T" in ts + # Should have milliseconds + assert "." 
in ts + + def test_fallback_for_non_serializable_attributes(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="Test with non-serializable", + args=(), + exc_info=None, + ) + # Set is not serializable by orjson + record.attributes = {"items": {1, 2, 3}, "custom": object()} + + # Should not raise, fallback to json.dumps with default=str + output = formatter.format(record) + + # Verify it's valid JSON (parsed by stdlib json since orjson may fail) + import json + + log_dict = json.loads(output) + assert log_dict["message"] == "Test with non-serializable" + assert "attributes" in log_dict diff --git a/api/tests/unit_tests/core/logging/test_trace_helpers.py b/api/tests/unit_tests/core/logging/test_trace_helpers.py new file mode 100644 index 0000000000..aab1753b9b --- /dev/null +++ b/api/tests/unit_tests/core/logging/test_trace_helpers.py @@ -0,0 +1,102 @@ +"""Tests for trace helper functions.""" + +import re +from unittest import mock + + +class TestGetSpanIdFromOtelContext: + def test_returns_none_without_span(self): + from core.helper.trace_id_helper import get_span_id_from_otel_context + + with mock.patch("opentelemetry.trace.get_current_span", return_value=None): + result = get_span_id_from_otel_context() + assert result is None + + def test_returns_span_id_when_available(self): + from core.helper.trace_id_helper import get_span_id_from_otel_context + + mock_span = mock.MagicMock() + mock_context = mock.MagicMock() + mock_context.span_id = 0x051581BF3BB55C45 + mock_span.get_span_context.return_value = mock_context + + with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span): + with mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0): + result = get_span_id_from_otel_context() + assert result == "051581bf3bb55c45" + + def test_returns_none_on_exception(self): + 
from core.helper.trace_id_helper import get_span_id_from_otel_context + + with mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception("Test error")): + result = get_span_id_from_otel_context() + assert result is None + + +class TestGenerateTraceparentHeader: + def test_generates_valid_format(self): + from core.helper.trace_id_helper import generate_traceparent_header + + with mock.patch("opentelemetry.trace.get_current_span", return_value=None): + result = generate_traceparent_header() + + assert result is not None + # Format: 00-{trace_id}-{span_id}-01 + parts = result.split("-") + assert len(parts) == 4 + assert parts[0] == "00" # version + assert len(parts[1]) == 32 # trace_id (32 hex chars) + assert len(parts[2]) == 16 # span_id (16 hex chars) + assert parts[3] == "01" # flags + + def test_uses_otel_context_when_available(self): + from core.helper.trace_id_helper import generate_traceparent_header + + mock_span = mock.MagicMock() + mock_context = mock.MagicMock() + mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2 + mock_context.span_id = 0x051581BF3BB55C45 + mock_span.get_span_context.return_value = mock_context + + with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span): + with ( + mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0), + mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0), + ): + result = generate_traceparent_header() + + assert result == "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01" + + def test_generates_hex_only_values(self): + from core.helper.trace_id_helper import generate_traceparent_header + + with mock.patch("opentelemetry.trace.get_current_span", return_value=None): + result = generate_traceparent_header() + + parts = result.split("-") + # All parts should be valid hex + assert re.match(r"^[0-9a-f]+$", parts[1]) + assert re.match(r"^[0-9a-f]+$", parts[2]) + + +class TestParseTraceparentHeader: + def test_parses_valid_traceparent(self): + from 
core.helper.trace_id_helper import parse_traceparent_header + + traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01" + result = parse_traceparent_header(traceparent) + + assert result == "5b8aa5a2d2c872e8321cf37308d69df2" + + def test_returns_none_for_invalid_format(self): + from core.helper.trace_id_helper import parse_traceparent_header + + # Wrong number of parts + assert parse_traceparent_header("00-abc-def") is None + # Wrong trace_id length + assert parse_traceparent_header("00-abc-def-01") is None + + def test_returns_none_for_empty_string(self): + from core.helper.trace_id_helper import parse_traceparent_header + + assert parse_traceparent_header("") is None diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py b/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py index 93d8a20cac..5fbdabceed 100644 --- a/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py +++ b/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py @@ -1,5 +1,7 @@ from unittest.mock import MagicMock, patch +import pytest + from core.model_runtime.entities.message_entities import AssistantPromptMessage from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call @@ -97,3 +99,14 @@ def test__increase_tool_call(): mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4] with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator): _run_case(INPUTS_CASE_4, EXPECTED_CASE_4) + + +def test__increase_tool_call__no_id_no_name_first_delta_should_raise(): + inputs = [ + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='"value"}')), + ] + actual: list[ToolCall] = [] + with 
patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", MagicMock()): + with pytest.raises(ValueError): + _increase_tool_call(inputs, actual) diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py b/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py new file mode 100644 index 0000000000..91352b2a5f --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py @@ -0,0 +1,103 @@ +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_plugin_result + + +def _make_chunk( + *, + model: str = "test-model", + content: str | list[TextPromptMessageContent] | None, + tool_calls: list[AssistantPromptMessage.ToolCall] | None = None, + usage: LLMUsage | None = None, + system_fingerprint: str | None = None, +) -> LLMResultChunk: + message = AssistantPromptMessage(content=content, tool_calls=tool_calls or []) + delta = LLMResultChunkDelta(index=0, message=message, usage=usage) + return LLMResultChunk(model=model, delta=delta, system_fingerprint=system_fingerprint) + + +def test__normalize_non_stream_plugin_result__from_first_chunk_str_content_and_tool_calls(): + prompt_messages = [UserPromptMessage(content="hi")] + + tool_calls = [ + AssistantPromptMessage.ToolCall( + id="1", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments=""), + ), + AssistantPromptMessage.ToolCall( + id="", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='{"arg1": '), + ), + AssistantPromptMessage.ToolCall( + id="", + type="function", + 
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='"value"}'), + ), + ] + + usage = LLMUsage.empty_usage().model_copy(update={"prompt_tokens": 1, "total_tokens": 1}) + chunk = _make_chunk(content="hello", tool_calls=tool_calls, usage=usage, system_fingerprint="fp-1") + + result = _normalize_non_stream_plugin_result( + model="test-model", prompt_messages=prompt_messages, result=iter([chunk]) + ) + + assert result.model == "test-model" + assert result.prompt_messages == prompt_messages + assert result.message.content == "hello" + assert result.usage.prompt_tokens == 1 + assert result.system_fingerprint == "fp-1" + assert result.message.tool_calls == [ + AssistantPromptMessage.ToolCall( + id="1", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'), + ) + ] + + +def test__normalize_non_stream_plugin_result__from_first_chunk_list_content(): + prompt_messages = [UserPromptMessage(content="hi")] + + content_list = [TextPromptMessageContent(data="a"), TextPromptMessageContent(data="b")] + chunk = _make_chunk(content=content_list, usage=LLMUsage.empty_usage()) + + result = _normalize_non_stream_plugin_result( + model="test-model", prompt_messages=prompt_messages, result=iter([chunk]) + ) + + assert result.message.content == content_list + + +def test__normalize_non_stream_plugin_result__passthrough_llm_result(): + prompt_messages = [UserPromptMessage(content="hi")] + llm_result = LLMResult( + model="test-model", + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content="ok"), + usage=LLMUsage.empty_usage(), + ) + + assert ( + _normalize_non_stream_plugin_result(model="test-model", prompt_messages=prompt_messages, result=llm_result) + == llm_result + ) + + +def test__normalize_non_stream_plugin_result__empty_iterator_defaults(): + prompt_messages = [UserPromptMessage(content="hi")] + + result = _normalize_non_stream_plugin_result(model="test-model", 
prompt_messages=prompt_messages, result=iter([])) + + assert result.model == "test-model" + assert result.prompt_messages == prompt_messages + assert result.message.content == [] + assert result.message.tool_calls == [] + assert result.usage == LLMUsage.empty_usage() + assert result.system_fingerprint is None diff --git a/api/tests/unit_tests/core/plugin/test_endpoint_client.py b/api/tests/unit_tests/core/plugin/test_endpoint_client.py new file mode 100644 index 0000000000..53056ee42a --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_endpoint_client.py @@ -0,0 +1,279 @@ +"""Unit tests for PluginEndpointClient functionality. + +This test module covers the endpoint client operations including: +- Successful endpoint deletion +- Idempotent delete behavior (record not found) +- Non-idempotent delete behavior (other errors) + +Tests follow the Arrange-Act-Assert pattern for clarity. +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.impl.endpoint import PluginEndpointClient +from core.plugin.impl.exc import PluginDaemonInternalServerError + + +class TestPluginEndpointClientDelete: + """Unit tests for PluginEndpointClient delete_endpoint operation. + + Tests cover: + - Successful endpoint deletion + - Idempotent behavior when endpoint is already deleted (record not found) + - Non-idempotent behavior for other errors + """ + + @pytest.fixture + def endpoint_client(self): + """Create a PluginEndpointClient instance for testing.""" + return PluginEndpointClient() + + @pytest.fixture + def mock_config(self): + """Mock plugin daemon configuration.""" + with ( + patch("core.plugin.impl.base.dify_config.PLUGIN_DAEMON_URL", "http://127.0.0.1:5002"), + patch("core.plugin.impl.base.dify_config.PLUGIN_DAEMON_KEY", "test-api-key"), + ): + yield + + def test_delete_endpoint_success(self, endpoint_client, mock_config): + """Test successful endpoint deletion. 
+ + Given: + - A valid tenant_id, user_id, and endpoint_id + - The plugin daemon returns success response + When: + - delete_endpoint is called + Then: + - The method should return True + - The request should be made with correct parameters + """ + # Arrange + tenant_id = "tenant-123" + user_id = "user-456" + endpoint_id = "endpoint-789" + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "code": 0, + "message": "success", + "data": True, + } + + with patch("httpx.request", return_value=mock_response): + # Act + result = endpoint_client.delete_endpoint( + tenant_id=tenant_id, + user_id=user_id, + endpoint_id=endpoint_id, + ) + + # Assert + assert result is True + + def test_delete_endpoint_idempotent_record_not_found(self, endpoint_client, mock_config): + """Test idempotent delete behavior when endpoint is already deleted. + + Given: + - A valid tenant_id, user_id, and endpoint_id + - The plugin daemon returns "record not found" error + When: + - delete_endpoint is called + Then: + - The method should return True (idempotent behavior) + - No exception should be raised + """ + # Arrange + tenant_id = "tenant-123" + user_id = "user-456" + endpoint_id = "endpoint-789" + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "code": -1, + "message": ( + '{"error_type": "PluginDaemonInternalServerError", ' + '"message": "failed to remove endpoint: record not found"}' + ), + } + + with patch("httpx.request", return_value=mock_response): + # Act + result = endpoint_client.delete_endpoint( + tenant_id=tenant_id, + user_id=user_id, + endpoint_id=endpoint_id, + ) + + # Assert - should return True instead of raising an error + assert result is True + + def test_delete_endpoint_non_idempotent_other_errors(self, endpoint_client, mock_config): + """Test non-idempotent delete behavior for other errors. 
+ + Given: + - A valid tenant_id, user_id, and endpoint_id + - The plugin daemon returns a different error (not "record not found") + When: + - delete_endpoint is called + Then: + - The method should raise PluginDaemonInternalServerError + """ + # Arrange + tenant_id = "tenant-123" + user_id = "user-456" + endpoint_id = "endpoint-789" + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "code": -1, + "message": ( + '{"error_type": "PluginDaemonInternalServerError", ' + '"message": "failed to remove endpoint: internal server error"}' + ), + } + + with patch("httpx.request", return_value=mock_response): + # Act & Assert + with pytest.raises(PluginDaemonInternalServerError) as exc_info: + endpoint_client.delete_endpoint( + tenant_id=tenant_id, + user_id=user_id, + endpoint_id=endpoint_id, + ) + + # Assert - the error message should not be "record not found" + assert "record not found" not in str(exc_info.value.description) + + def test_delete_endpoint_idempotent_case_insensitive(self, endpoint_client, mock_config): + """Test idempotent delete behavior with case-insensitive error message. 
+ + Given: + - A valid tenant_id, user_id, and endpoint_id + - The plugin daemon returns "Record Not Found" error (different case) + When: + - delete_endpoint is called + Then: + - The method should return True (idempotent behavior) + """ + # Arrange + tenant_id = "tenant-123" + user_id = "user-456" + endpoint_id = "endpoint-789" + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "code": -1, + "message": '{"error_type": "PluginDaemonInternalServerError", "message": "Record Not Found"}', + } + + with patch("httpx.request", return_value=mock_response): + # Act + result = endpoint_client.delete_endpoint( + tenant_id=tenant_id, + user_id=user_id, + endpoint_id=endpoint_id, + ) + + # Assert - should still return True + assert result is True + + def test_delete_endpoint_multiple_calls_idempotent(self, endpoint_client, mock_config): + """Test that multiple delete calls are idempotent. + + Given: + - A valid tenant_id, user_id, and endpoint_id + - The first call succeeds + - Subsequent calls return "record not found" + When: + - delete_endpoint is called multiple times + Then: + - All calls should return True + """ + # Arrange + tenant_id = "tenant-123" + user_id = "user-456" + endpoint_id = "endpoint-789" + + # First call - success + mock_response_success = MagicMock() + mock_response_success.status_code = 200 + mock_response_success.json.return_value = { + "code": 0, + "message": "success", + "data": True, + } + + # Second call - record not found + mock_response_not_found = MagicMock() + mock_response_not_found.status_code = 200 + mock_response_not_found.json.return_value = { + "code": -1, + "message": ( + '{"error_type": "PluginDaemonInternalServerError", ' + '"message": "failed to remove endpoint: record not found"}' + ), + } + + with patch("httpx.request") as mock_request: + # Act - first call + mock_request.return_value = mock_response_success + result1 = endpoint_client.delete_endpoint( + tenant_id=tenant_id, + 
user_id=user_id, + endpoint_id=endpoint_id, + ) + + # Act - second call (already deleted) + mock_request.return_value = mock_response_not_found + result2 = endpoint_client.delete_endpoint( + tenant_id=tenant_id, + user_id=user_id, + endpoint_id=endpoint_id, + ) + + # Assert - both should return True + assert result1 is True + assert result2 is True + + def test_delete_endpoint_non_idempotent_unauthorized_error(self, endpoint_client, mock_config): + """Test that authorization errors are not treated as idempotent. + + Given: + - A valid tenant_id, user_id, and endpoint_id + - The plugin daemon returns an unauthorized error + When: + - delete_endpoint is called + Then: + - The method should raise the appropriate error (not return True) + """ + # Arrange + tenant_id = "tenant-123" + user_id = "user-456" + endpoint_id = "endpoint-789" + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "code": -1, + "message": '{"error_type": "PluginDaemonUnauthorizedError", "message": "unauthorized access"}', + } + + with patch("httpx.request", return_value=mock_response): + # Act & Assert + with pytest.raises(Exception) as exc_info: + endpoint_client.delete_endpoint( + tenant_id=tenant_id, + user_id=user_id, + endpoint_id=endpoint_id, + ) + + # Assert - should not return True for unauthorized errors + assert exc_info.value.__class__.__name__ == "PluginDaemonUnauthorizedError" diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index 2a0b293a39..9e911e1fce 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -346,6 +346,7 @@ class TestPluginRuntimeErrorHandling: mock_response.status_code = 200 invoke_error = { "error_type": "InvokeRateLimitError", + "message": "Rate limit exceeded", "args": {"description": "Rate limit exceeded"}, } error_message = json.dumps({"error_type": 
"PluginInvokeError", "message": json.dumps(invoke_error)}) @@ -364,6 +365,7 @@ class TestPluginRuntimeErrorHandling: mock_response.status_code = 200 invoke_error = { "error_type": "InvokeAuthorizationError", + "message": "Invalid credentials", "args": {"description": "Invalid credentials"}, } error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)}) @@ -382,6 +384,7 @@ class TestPluginRuntimeErrorHandling: mock_response.status_code = 200 invoke_error = { "error_type": "InvokeBadRequestError", + "message": "Invalid parameters", "args": {"description": "Invalid parameters"}, } error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)}) @@ -400,6 +403,7 @@ class TestPluginRuntimeErrorHandling: mock_response.status_code = 200 invoke_error = { "error_type": "InvokeConnectionError", + "message": "Connection to external service failed", "args": {"description": "Connection to external service failed"}, } error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)}) @@ -418,6 +422,7 @@ class TestPluginRuntimeErrorHandling: mock_response.status_code = 200 invoke_error = { "error_type": "InvokeServerUnavailableError", + "message": "Service temporarily unavailable", "args": {"description": "Service temporarily unavailable"}, } error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)}) diff --git a/api/tests/unit_tests/core/rag/cleaner/__init__.py b/api/tests/unit_tests/core/rag/cleaner/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py b/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py new file mode 100644 index 0000000000..65ee62b8dd --- /dev/null +++ b/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py @@ -0,0 +1,213 @@ +from core.rag.cleaner.clean_processor import CleanProcessor + + +class 
TestCleanProcessor: + """Test cases for CleanProcessor.clean method.""" + + def test_clean_default_removal_of_invalid_symbols(self): + """Test default cleaning removes invalid symbols.""" + # Test <| replacement + assert CleanProcessor.clean("text<|with<|invalid", None) == "text<with<invalid" + + # Test |> replacement + assert CleanProcessor.clean("text|>with|>invalid", None) == "text>with>invalid" + + # Test removal of control characters + text_with_control = "normal\x00text\x1fwith\x07control\x7fchars" + expected = "normaltextwithcontrolchars" + assert CleanProcessor.clean(text_with_control, None) == expected + + # Test U+FFFE removal + text_with_ufffe = "normal\ufffepadding" + expected = "normalpadding" + assert CleanProcessor.clean(text_with_ufffe, None) == expected + + def test_clean_with_none_process_rule(self): + """Test cleaning with None process_rule - only default cleaning applied.""" + text = "Hello<|World\x00" + expected = "Hello<World" + assert CleanProcessor.clean(text, None) == expected + + def test_clean_combined_default_rules(self): + """Test combined default cleaning in a single string.""" + # <| becomes <, |> becomes >, control chars and U+FFFE are removed + text = "<|<|\x00|>\ufffe|>" + assert CleanProcessor.clean(text, None) == "<<>>" + + def test_clean_multiple_markdown_links_preserved(self): + """Test multiple markdown links are all preserved.""" + process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}} + + text = "[One](https://one.com) [Two](http://two.org) [Three](https://three.net)" + expected = "[One](https://one.com) [Two](http://two.org) [Three](https://three.net)" + assert CleanProcessor.clean(text, process_rule) == expected + + def test_clean_markdown_link_text_as_url(self): + """Test markdown link where the link text itself is a URL.""" + process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}} + + # Link text that looks like URL should be preserved + text = "[https://text-url.com](https://actual-url.com)" + expected = "[https://text-url.com](https://actual-url.com)" + assert CleanProcessor.clean(text, process_rule) == expected + + # Text URL without markdown should be removed + text = 
"https://text-url.com https://actual-url.com" + expected = " " + assert CleanProcessor.clean(text, process_rule) == expected + + def test_clean_complex_markdown_link_content(self): + """Test markdown links with complex content - known limitation with brackets in link text.""" + process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}} + + # Note: The regex pattern [^\]]* cannot handle ] within link text + # This is a known limitation - the pattern stops at the first ] + text = "[Text with [brackets] and (parens)](https://example.com)" + # Actual behavior: only matches up to first ], URL gets removed + expected = "[Text with [brackets] and (parens)](" + assert CleanProcessor.clean(text, process_rule) == expected + + # Test that properly formatted markdown links work + text = "[Text with (parens) and symbols](https://example.com)" + expected = "[Text with (parens) and symbols](https://example.com)" + assert CleanProcessor.clean(text, process_rule) == expected diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/__init__.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py new file mode 100644 index 0000000000..4998a9858f --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py @@ -0,0 +1,327 @@ +import unittest +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.datasource.vdb.pgvector.pgvector import ( + PGVector, + PGVectorConfig, +) + + +class TestPGVector(unittest.TestCase): + def setUp(self): + self.config = PGVectorConfig( + host="localhost", + port=5432, + user="test_user", + password="test_password", + database="test_db", + min_connection=1, + max_connection=5, + pg_bigm=False, + ) + self.collection_name = "test_collection" 
+ + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + def test_init(self, mock_pool_class): + """Test PGVector initialization.""" + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + pgvector = PGVector(self.collection_name, self.config) + + assert pgvector._collection_name == self.collection_name + assert pgvector.table_name == f"embedding_{self.collection_name}" + assert pgvector.get_type() == "pgvector" + assert pgvector.pool is not None + assert pgvector.pg_bigm is False + assert pgvector.index_hash is not None + + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + def test_init_with_pg_bigm(self, mock_pool_class): + """Test PGVector initialization with pg_bigm enabled.""" + config = PGVectorConfig( + host="localhost", + port=5432, + user="test_user", + password="test_password", + database="test_db", + min_connection=1, + max_connection=5, + pg_bigm=True, + ) + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + pgvector = PGVector(self.collection_name, config) + + assert pgvector.pg_bigm is True + + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + def test_create_collection_basic(self, mock_redis, mock_pool_class): + """Test basic collection creation.""" + # Mock Redis operations + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = None + mock_redis.set.return_value = None + + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.getconn.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.return_value = [1] # vector extension exists + + pgvector 
= PGVector(self.collection_name, self.config) + pgvector._create_collection(1536) + + # Verify SQL execution calls + assert mock_cursor.execute.called + + # Check that CREATE TABLE was called with correct dimension + create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)] + assert len(create_table_calls) == 1 + assert "vector(1536)" in create_table_calls[0][0][0] + + # Check that CREATE INDEX was called (dimension <= 2000) + create_index_calls = [ + call for call in mock_cursor.execute.call_args_list if "CREATE INDEX" in str(call) and "hnsw" in str(call) + ] + assert len(create_index_calls) == 1 + + # Verify Redis cache was set + mock_redis.set.assert_called_once() + + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + def test_create_collection_with_large_dimension(self, mock_redis, mock_pool_class): + """Test collection creation with dimension > 2000 (no HNSW index).""" + # Mock Redis operations + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = None + mock_redis.set.return_value = None + + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.getconn.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.return_value = [1] # vector extension exists + + pgvector = PGVector(self.collection_name, self.config) + pgvector._create_collection(3072) # Dimension > 2000 + + # Check that CREATE TABLE was called + create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)] + assert len(create_table_calls) == 1 + assert "vector(3072)" in create_table_calls[0][0][0] + + # Check that 
HNSW index was NOT created (dimension > 2000) + hnsw_index_calls = [call for call in mock_cursor.execute.call_args_list if "hnsw" in str(call)] + assert len(hnsw_index_calls) == 0 + + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + def test_create_collection_with_pg_bigm(self, mock_redis, mock_pool_class): + """Test collection creation with pg_bigm enabled.""" + config = PGVectorConfig( + host="localhost", + port=5432, + user="test_user", + password="test_password", + database="test_db", + min_connection=1, + max_connection=5, + pg_bigm=True, + ) + + # Mock Redis operations + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = None + mock_redis.set.return_value = None + + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.getconn.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.return_value = [1] # vector extension exists + + pgvector = PGVector(self.collection_name, config) + pgvector._create_collection(1536) + + # Check that pg_bigm index was created + bigm_index_calls = [call for call in mock_cursor.execute.call_args_list if "gin_bigm_ops" in str(call)] + assert len(bigm_index_calls) == 1 + + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + def test_create_collection_creates_vector_extension(self, mock_redis, mock_pool_class): + """Test that vector extension is created if it doesn't exist.""" + # Mock Redis operations + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = 
mock_lock + mock_redis.get.return_value = None + mock_redis.set.return_value = None + + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.getconn.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + # First call: vector extension doesn't exist + mock_cursor.fetchone.return_value = None + + pgvector = PGVector(self.collection_name, self.config) + pgvector._create_collection(1536) + + # Check that CREATE EXTENSION was called + create_extension_calls = [ + call for call in mock_cursor.execute.call_args_list if "CREATE EXTENSION vector" in str(call) + ] + assert len(create_extension_calls) == 1 + + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + def test_create_collection_with_cache_hit(self, mock_redis, mock_pool_class): + """Test that collection creation is skipped when cache exists.""" + # Mock Redis operations - cache exists + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = 1 # Cache exists + + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.getconn.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + + pgvector = PGVector(self.collection_name, self.config) + pgvector._create_collection(1536) + + # Check that no SQL was executed (early return due to cache) + assert mock_cursor.execute.call_count == 0 + + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + def test_create_collection_with_redis_lock(self, 
mock_redis, mock_pool_class): + """Test that Redis lock is used during collection creation.""" + # Mock Redis operations + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = None + mock_redis.set.return_value = None + + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.getconn.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.return_value = [1] # vector extension exists + + pgvector = PGVector(self.collection_name, self.config) + pgvector._create_collection(1536) + + # Verify Redis lock was acquired with correct lock name + mock_redis.lock.assert_called_once_with("vector_indexing_test_collection_lock", timeout=20) + + # Verify lock context manager was entered and exited + mock_lock.__enter__.assert_called_once() + mock_lock.__exit__.assert_called_once() + + @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + def test_get_cursor_context_manager(self, mock_pool_class): + """Test that _get_cursor properly manages connection lifecycle.""" + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.getconn.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + + pgvector = PGVector(self.collection_name, self.config) + + with pgvector._get_cursor() as cur: + assert cur == mock_cursor + + # Verify connection lifecycle methods were called + mock_pool.getconn.assert_called_once() + mock_cursor.close.assert_called_once() + mock_conn.commit.assert_called_once() + mock_pool.putconn.assert_called_once_with(mock_conn) + + +@pytest.mark.parametrize( + "invalid_config_override", + [ + {"host": ""}, # Test empty host + {"port": 0}, # Test invalid 
port + {"user": ""}, # Test empty user + {"password": ""}, # Test empty password + {"database": ""}, # Test empty database + {"min_connection": 0}, # Test invalid min_connection + {"max_connection": 0}, # Test invalid max_connection + {"min_connection": 10, "max_connection": 5}, # Test min > max + ], +) +def test_config_validation_parametrized(invalid_config_override): + """Test configuration validation for various invalid inputs using parametrize.""" + config = { + "host": "localhost", + "port": 5432, + "user": "test_user", + "password": "test_password", + "database": "test_db", + "min_connection": 1, + "max_connection": 5, + } + config.update(invalid_config_override) + + with pytest.raises(ValueError): + PGVectorConfig(**config) + + +if __name__ == "__main__": + unittest.main() diff --git a/api/tests/unit_tests/core/rag/extractor/test_pdf_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_pdf_extractor.py new file mode 100644 index 0000000000..3167a9a301 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_pdf_extractor.py @@ -0,0 +1,186 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import core.rag.extractor.pdf_extractor as pe + + +@pytest.fixture +def mock_dependencies(monkeypatch): + # Mock storage + saves = [] + + def save(key, data): + saves.append((key, data)) + + monkeypatch.setattr(pe, "storage", SimpleNamespace(save=save)) + + # Mock db + class DummySession: + def __init__(self): + self.added = [] + self.committed = False + + def add(self, obj): + self.added.append(obj) + + def add_all(self, objs): + self.added.extend(objs) + + def commit(self): + self.committed = True + + db_stub = SimpleNamespace(session=DummySession()) + monkeypatch.setattr(pe, "db", db_stub) + + # Mock UploadFile + class FakeUploadFile: + DEFAULT_ID = "test_file_id" + + def __init__(self, **kwargs): + # Assign id from DEFAULT_ID, allow override via kwargs if needed + self.id = self.DEFAULT_ID + for k, v in 
kwargs.items(): + setattr(self, k, v) + + monkeypatch.setattr(pe, "UploadFile", FakeUploadFile) + + # Mock config + monkeypatch.setattr(pe.dify_config, "FILES_URL", "http://files.local") + monkeypatch.setattr(pe.dify_config, "INTERNAL_FILES_URL", None) + monkeypatch.setattr(pe.dify_config, "STORAGE_TYPE", "local") + + return SimpleNamespace(saves=saves, db=db_stub, UploadFile=FakeUploadFile) + + +@pytest.mark.parametrize( + ("image_bytes", "expected_mime", "expected_ext", "file_id"), + [ + (b"\xff\xd8\xff some jpeg", "image/jpeg", "jpg", "test_file_id_jpeg"), + (b"\x89PNG\r\n\x1a\n some png", "image/png", "png", "test_file_id_png"), + ], +) +def test_extract_images_formats(mock_dependencies, monkeypatch, image_bytes, expected_mime, expected_ext, file_id): + saves = mock_dependencies.saves + db_stub = mock_dependencies.db + + # Customize FakeUploadFile id for this test case. + # Using monkeypatch ensures the class attribute is reset between parameter sets. + monkeypatch.setattr(mock_dependencies.UploadFile, "DEFAULT_ID", file_id) + + # Mock page and image objects + mock_page = MagicMock() + mock_image_obj = MagicMock() + + def mock_extract(buf, fb_format=None): + buf.write(image_bytes) + + mock_image_obj.extract.side_effect = mock_extract + + mock_page.get_objects.return_value = [mock_image_obj] + + extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1") + + # We need to handle the import inside _extract_images + with patch("pypdfium2.raw") as mock_raw: + mock_raw.FPDF_PAGEOBJ_IMAGE = 1 + result = extractor._extract_images(mock_page) + + assert f"![image](http://files.local/files/{file_id}/file-preview)" in result + assert len(saves) == 1 + assert saves[0][1] == image_bytes + assert len(db_stub.session.added) == 1 + assert db_stub.session.added[0].tenant_id == "t1" + assert db_stub.session.added[0].size == len(image_bytes) + assert db_stub.session.added[0].mime_type == expected_mime + assert db_stub.session.added[0].extension == expected_ext 
+ assert db_stub.session.committed is True + + +@pytest.mark.parametrize( + ("get_objects_side_effect", "get_objects_return_value"), + [ + (None, []), # Empty list + (None, None), # None returned + (Exception("Failed to get objects"), None), # Exception raised + ], +) +def test_extract_images_get_objects_scenarios(mock_dependencies, get_objects_side_effect, get_objects_return_value): + mock_page = MagicMock() + if get_objects_side_effect: + mock_page.get_objects.side_effect = get_objects_side_effect + else: + mock_page.get_objects.return_value = get_objects_return_value + + extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1") + + with patch("pypdfium2.raw") as mock_raw: + mock_raw.FPDF_PAGEOBJ_IMAGE = 1 + result = extractor._extract_images(mock_page) + + assert result == "" + + +def test_extract_calls_extract_images(mock_dependencies, monkeypatch): + # Mock pypdfium2 + mock_pdf_doc = MagicMock() + mock_page = MagicMock() + mock_pdf_doc.__iter__.return_value = [mock_page] + + # Mock text extraction + mock_text_page = MagicMock() + mock_text_page.get_text_range.return_value = "Page text content" + mock_page.get_textpage.return_value = mock_text_page + + with patch("pypdfium2.PdfDocument", return_value=mock_pdf_doc): + # Mock Blob + mock_blob = MagicMock() + mock_blob.source = "test.pdf" + with patch("core.rag.extractor.pdf_extractor.Blob.from_path", return_value=mock_blob): + extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1") + + # Mock _extract_images to return a known string + monkeypatch.setattr(extractor, "_extract_images", lambda p: "![image](img_url)") + + documents = list(extractor.extract()) + + assert len(documents) == 1 + assert "Page text content" in documents[0].page_content + assert "![image](img_url)" in documents[0].page_content + assert documents[0].metadata["page"] == 0 + + +def test_extract_images_failures(mock_dependencies): + saves = mock_dependencies.saves + db_stub = mock_dependencies.db + 
+ # Mock page and image objects + mock_page = MagicMock() + mock_image_obj_fail = MagicMock() + mock_image_obj_ok = MagicMock() + + # First image raises exception + mock_image_obj_fail.extract.side_effect = Exception("Extraction failure") + + # Second image is OK (JPEG) + jpeg_bytes = b"\xff\xd8\xff some image data" + + def mock_extract(buf, fb_format=None): + buf.write(jpeg_bytes) + + mock_image_obj_ok.extract.side_effect = mock_extract + + mock_page.get_objects.return_value = [mock_image_obj_fail, mock_image_obj_ok] + + extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1") + + with patch("pypdfium2.raw") as mock_raw: + mock_raw.FPDF_PAGEOBJ_IMAGE = 1 + result = extractor._extract_images(mock_page) + + # Should have one success + assert "![image](http://files.local/files/test_file_id/file-preview)" in result + assert len(saves) == 1 + assert saves[0][1] == jpeg_bytes + assert db_stub.session.committed is True diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 3203aab8c3..f9e59a5f05 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -1,8 +1,12 @@ """Primarily used for testing merged cell scenarios""" +import os +import tempfile from types import SimpleNamespace from docx import Document +from docx.oxml import OxmlElement +from docx.oxml.ns import qn import core.rag.extractor.word_extractor as we from core.rag.extractor.word_extractor import WordExtractor @@ -165,3 +169,110 @@ def test_extract_images_from_docx_uses_internal_files_url(): dify_config.FILES_URL = original_files_url if original_internal_files_url is not None: dify_config.INTERNAL_FILES_URL = original_internal_files_url + + +def test_extract_hyperlinks(monkeypatch): + # Mock db and storage to avoid issues during image extraction (even if no images are present) + monkeypatch.setattr(we, 
"storage", SimpleNamespace(save=lambda k, d: None)) + db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None)) + monkeypatch.setattr(we, "db", db_stub) + monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False) + monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False) + + doc = Document() + p = doc.add_paragraph("Visit ") + + # Adding a hyperlink manually + r_id = "rId99" + hyperlink = OxmlElement("w:hyperlink") + hyperlink.set(qn("r:id"), r_id) + + new_run = OxmlElement("w:r") + t = OxmlElement("w:t") + t.text = "Dify" + new_run.append(t) + hyperlink.append(new_run) + p._p.append(hyperlink) + + # Add relationship to the part + doc.part.rels.add_relationship( + "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink", + "https://dify.ai", + r_id, + is_external=True, + ) + + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp: + doc.save(tmp.name) + tmp_path = tmp.name + + try: + extractor = WordExtractor(tmp_path, "tenant_id", "user_id") + docs = extractor.extract() + # Verify modern hyperlink extraction + assert "Visit[Dify](https://dify.ai)" in docs[0].page_content + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) + + +def test_extract_legacy_hyperlinks(monkeypatch): + # Mock db and storage + monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None)) + db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None)) + monkeypatch.setattr(we, "db", db_stub) + monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False) + monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False) + + doc = Document() + p = doc.add_paragraph() + + # Construct a legacy HYPERLINK field: + # 1. w:fldChar (begin) + # 2. w:instrText (HYPERLINK "http://example.com") + # 3. w:fldChar (separate) + # 4. w:r (visible text "Example") + # 5. 
w:fldChar (end) + + run1 = OxmlElement("w:r") + fldCharBegin = OxmlElement("w:fldChar") + fldCharBegin.set(qn("w:fldCharType"), "begin") + run1.append(fldCharBegin) + p._p.append(run1) + + run2 = OxmlElement("w:r") + instrText = OxmlElement("w:instrText") + instrText.text = ' HYPERLINK "http://example.com" ' + run2.append(instrText) + p._p.append(run2) + + run3 = OxmlElement("w:r") + fldCharSep = OxmlElement("w:fldChar") + fldCharSep.set(qn("w:fldCharType"), "separate") + run3.append(fldCharSep) + p._p.append(run3) + + run4 = OxmlElement("w:r") + t4 = OxmlElement("w:t") + t4.text = "Example" + run4.append(t4) + p._p.append(run4) + + run5 = OxmlElement("w:r") + fldCharEnd = OxmlElement("w:fldChar") + fldCharEnd.set(qn("w:fldCharType"), "end") + run5.append(fldCharEnd) + p._p.append(run5) + + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp: + doc.save(tmp.name) + tmp_path = tmp.name + + try: + extractor = WordExtractor(tmp_path, "tenant_id", "user_id") + docs = extractor.extract() + # Verify legacy hyperlink extraction + assert "[Example](http://example.com)" in docs[0].page_content + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index affd6c648f..ca08cb0591 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -73,6 +73,7 @@ import pytest from core.rag.datasource.retrieval_service import RetrievalService from core.rag.models.document import Document +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from models.dataset import Dataset @@ -421,7 +422,18 @@ class TestRetrievalService: # In real code, this waits for all futures to complete # In tests, futures complete immediately, so wait is a no-op with 
patch("core.rag.datasource.retrieval_service.concurrent.futures.wait"): - yield mock_executor + # Mock concurrent.futures.as_completed for early error propagation + # In real code, this yields futures as they complete + # In tests, we yield all futures immediately since they're already done + def mock_as_completed(futures_list, timeout=None): + """Mock as_completed that yields futures immediately.""" + yield from futures_list + + with patch( + "core.rag.datasource.retrieval_service.concurrent.futures.as_completed", + side_effect=mock_as_completed, + ): + yield mock_executor # ==================== Vector Search Tests ==================== @@ -1507,6 +1519,282 @@ class TestRetrievalService: call_kwargs = mock_retrieve.call_args.kwargs assert call_kwargs["reranking_model"] == reranking_model + # ==================== Multiple Retrieve Thread Tests ==================== + + @patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever") + def test_multiple_retrieve_thread_skips_second_reranking_with_single_dataset( + self, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset + ): + """ + Test that _multiple_retrieve_thread skips second reranking when dataset_count is 1. + + When there is only one dataset, the second reranking is unnecessary + because the documents are already ranked from the first retrieval. + This optimization avoids the overhead of reranking when it won't + provide any benefit. 
+ + Verifies: + - DataPostProcessor is NOT called when dataset_count == 1 + - Documents are still added to all_documents + - Standard scoring logic is applied instead + """ + # Arrange + dataset_retrieval = DatasetRetrieval() + tenant_id = str(uuid4()) + + # Create test documents + doc1 = Document( + page_content="Test content 1", + metadata={"doc_id": "doc1", "score": 0.9, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + doc2 = Document( + page_content="Test content 2", + metadata={"doc_id": "doc2", "score": 0.8, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + + # Mock _retriever to return documents + def side_effect_retriever( + flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids + ): + all_documents.extend([doc1, doc2]) + + mock_retriever.side_effect = side_effect_retriever + + # Set up dataset with high_quality indexing + mock_dataset.indexing_technique = "high_quality" + + all_documents = [] + + # Act - Call with dataset_count = 1 + dataset_retrieval._multiple_retrieve_thread( + flask_app=mock_flask_app, + available_datasets=[mock_dataset], + metadata_condition=None, + metadata_filter_document_ids=None, + all_documents=all_documents, + tenant_id=tenant_id, + reranking_enable=True, + reranking_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, + weights=None, + top_k=5, + score_threshold=0.5, + query="test query", + attachment_id=None, + dataset_count=1, # Single dataset - should skip second reranking + ) + + # Assert + # DataPostProcessor should NOT be called (second reranking skipped) + mock_data_processor_class.assert_not_called() + + # Documents should still be added to all_documents + assert len(all_documents) == 2 + assert all_documents[0].page_content == "Test content 1" + assert all_documents[1].page_content == "Test content 2" + + 
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.calculate_vector_score") + def test_multiple_retrieve_thread_performs_second_reranking_with_multiple_datasets( + self, mock_calculate_vector_score, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset + ): + """ + Test that _multiple_retrieve_thread performs second reranking when dataset_count > 1. + + When there are multiple datasets, the second reranking is necessary + to merge and re-rank results from different datasets. This ensures + the most relevant documents across all datasets are returned. + + Verifies: + - DataPostProcessor IS called when dataset_count > 1 + - Reranking is applied with correct parameters + - Documents are processed correctly + """ + # Arrange + dataset_retrieval = DatasetRetrieval() + tenant_id = str(uuid4()) + + # Create test documents + doc1 = Document( + page_content="Test content 1", + metadata={"doc_id": "doc1", "score": 0.7, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + doc2 = Document( + page_content="Test content 2", + metadata={"doc_id": "doc2", "score": 0.6, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + + # Mock _retriever to return documents + def side_effect_retriever( + flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids + ): + all_documents.extend([doc1, doc2]) + + mock_retriever.side_effect = side_effect_retriever + + # Set up dataset with high_quality indexing + mock_dataset.indexing_technique = "high_quality" + + # Mock DataPostProcessor instance and its invoke method + mock_processor_instance = Mock() + # Simulate reranking - return documents in reversed order with updated scores + reranked_docs = [ + Document( + page_content="Test content 2", + metadata={"doc_id": 
"doc2", "score": 0.95, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ), + Document( + page_content="Test content 1", + metadata={"doc_id": "doc1", "score": 0.85, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ), + ] + mock_processor_instance.invoke.return_value = reranked_docs + mock_data_processor_class.return_value = mock_processor_instance + + all_documents = [] + + # Create second dataset + mock_dataset2 = Mock(spec=Dataset) + mock_dataset2.id = str(uuid4()) + mock_dataset2.indexing_technique = "high_quality" + mock_dataset2.provider = "dify" + + # Act - Call with dataset_count = 2 + dataset_retrieval._multiple_retrieve_thread( + flask_app=mock_flask_app, + available_datasets=[mock_dataset, mock_dataset2], + metadata_condition=None, + metadata_filter_document_ids=None, + all_documents=all_documents, + tenant_id=tenant_id, + reranking_enable=True, + reranking_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, + weights=None, + top_k=5, + score_threshold=0.5, + query="test query", + attachment_id=None, + dataset_count=2, # Multiple datasets - should perform second reranking + ) + + # Assert + # DataPostProcessor SHOULD be called (second reranking performed) + mock_data_processor_class.assert_called_once_with( + tenant_id, + "reranking_model", + {"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, + None, + False, + ) + + # Verify invoke was called with correct parameters + mock_processor_instance.invoke.assert_called_once() + + # Documents should be added to all_documents after reranking + assert len(all_documents) == 2 + # The reranked order should be reflected + assert all_documents[0].page_content == "Test content 2" + assert all_documents[1].page_content == "Test content 1" + + @patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor") + 
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.calculate_vector_score") + def test_multiple_retrieve_thread_single_dataset_uses_standard_scoring( + self, mock_calculate_vector_score, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset + ): + """ + Test that _multiple_retrieve_thread uses standard scoring when dataset_count is 1 + and reranking is enabled. + + When there's only one dataset, instead of using DataPostProcessor, + the method should fall through to the standard scoring logic + (calculate_vector_score for high_quality datasets). + + Verifies: + - DataPostProcessor is NOT called + - calculate_vector_score IS called for high_quality indexing + - Documents are scored correctly + """ + # Arrange + dataset_retrieval = DatasetRetrieval() + tenant_id = str(uuid4()) + + # Create test documents + doc1 = Document( + page_content="Test content 1", + metadata={"doc_id": "doc1", "score": 0.9, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + doc2 = Document( + page_content="Test content 2", + metadata={"doc_id": "doc2", "score": 0.8, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + + # Mock _retriever to return documents + def side_effect_retriever( + flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids + ): + all_documents.extend([doc1, doc2]) + + mock_retriever.side_effect = side_effect_retriever + + # Set up dataset with high_quality indexing + mock_dataset.indexing_technique = "high_quality" + + # Mock calculate_vector_score to return scored documents + scored_docs = [ + Document( + page_content="Test content 1", + metadata={"doc_id": "doc1", "score": 0.95, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ), + ] + mock_calculate_vector_score.return_value = scored_docs + + all_documents = [] + + 
# Act - Call with dataset_count = 1 + dataset_retrieval._multiple_retrieve_thread( + flask_app=mock_flask_app, + available_datasets=[mock_dataset], + metadata_condition=None, + metadata_filter_document_ids=None, + all_documents=all_documents, + tenant_id=tenant_id, + reranking_enable=True, # Reranking enabled but should be skipped for single dataset + reranking_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, + weights=None, + top_k=5, + score_threshold=0.5, + query="test query", + attachment_id=None, + dataset_count=1, + ) + + # Assert + # DataPostProcessor should NOT be called + mock_data_processor_class.assert_not_called() + + # calculate_vector_score SHOULD be called for high_quality datasets + mock_calculate_vector_score.assert_called_once() + call_args = mock_calculate_vector_score.call_args + assert call_args[0][1] == 5 # top_k + + # Documents should be added after standard scoring + assert len(all_documents) == 1 + assert all_documents[0].page_content == "Test content 1" + class TestRetrievalMethods: """ diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py new file mode 100644 index 0000000000..07d6e51e4b --- /dev/null +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py @@ -0,0 +1,873 @@ +""" +Unit tests for DatasetRetrieval.process_metadata_filter_func. + +This module provides comprehensive test coverage for the process_metadata_filter_func +method in the DatasetRetrieval class, which is responsible for building SQLAlchemy +filter expressions based on metadata filtering conditions. + +Conditions Tested: +================== +1. **String Conditions**: contains, not contains, start with, end with +2. **Equality Conditions**: is / =, is not / ≠ +3. **Null Conditions**: empty, not empty +4. 
**Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >= +5. **List Conditions**: in +6. **Edge Cases**: None values, different data types (str, int, float) + +Test Architecture: +================== +- Direct instantiation of DatasetRetrieval +- Mocking of DatasetDocument model attributes +- Verification of SQLAlchemy filter expressions +- Follows Arrange-Act-Assert (AAA) pattern + +Running Tests: +============== + # Run all tests in this module + uv run --project api pytest \ + api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v + + # Run a specific test + uv run --project api pytest \ + api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\ +TestProcessMetadataFilterFunc::test_contains_condition -v +""" + +from unittest.mock import MagicMock + +import pytest + +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval + + +class TestProcessMetadataFilterFunc: + """ + Comprehensive test suite for process_metadata_filter_func method. + + This test class validates all metadata filtering conditions supported by + the DatasetRetrieval class, including string operations, numeric comparisons, + null checks, and list operations. + + Method Signature: + ================== + def process_metadata_filter_func( + self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list + ) -> list: + + The method builds SQLAlchemy filter expressions by: + 1. Validating value is not None (except for empty/not empty conditions) + 2. Using DatasetDocument.doc_metadata JSON field operations + 3. Adding appropriate SQLAlchemy expressions to the filters list + 4. 
Returning the updated filters list + + Mocking Strategy: + ================== + - Mock DatasetDocument.doc_metadata to avoid database dependencies + - Verify filter expressions are created correctly + - Test with various data types (str, int, float, list) + """ + + @pytest.fixture + def retrieval(self): + """ + Create a DatasetRetrieval instance for testing. + + Returns: + DatasetRetrieval: Instance to test process_metadata_filter_func + """ + return DatasetRetrieval() + + @pytest.fixture + def mock_doc_metadata(self): + """ + Mock the DatasetDocument.doc_metadata JSON field. + + The method uses DatasetDocument.doc_metadata[metadata_name] to access + JSON fields. We mock this to avoid database dependencies. + + Returns: + Mock: Mocked doc_metadata attribute + """ + mock_metadata_field = MagicMock() + + # Create mock for string access + mock_string_access = MagicMock() + mock_string_access.like = MagicMock() + mock_string_access.notlike = MagicMock() + mock_string_access.__eq__ = MagicMock(return_value=MagicMock()) + mock_string_access.__ne__ = MagicMock(return_value=MagicMock()) + mock_string_access.in_ = MagicMock(return_value=MagicMock()) + + # Create mock for float access (for numeric comparisons) + mock_float_access = MagicMock() + mock_float_access.__eq__ = MagicMock(return_value=MagicMock()) + mock_float_access.__ne__ = MagicMock(return_value=MagicMock()) + mock_float_access.__lt__ = MagicMock(return_value=MagicMock()) + mock_float_access.__gt__ = MagicMock(return_value=MagicMock()) + mock_float_access.__le__ = MagicMock(return_value=MagicMock()) + mock_float_access.__ge__ = MagicMock(return_value=MagicMock()) + + # Create mock for null checks + mock_null_access = MagicMock() + mock_null_access.is_ = MagicMock(return_value=MagicMock()) + mock_null_access.isnot = MagicMock(return_value=MagicMock()) + + # Setup __getitem__ to return appropriate mock based on usage + def getitem_side_effect(name): + if name in ["author", "title", "category"]: + return 
mock_string_access + elif name in ["year", "price", "rating"]: + return mock_float_access + else: + return mock_string_access + + mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect) + mock_metadata_field.as_string.return_value = mock_string_access + mock_metadata_field.as_float.return_value = mock_float_access + # Null-check operators (is_/isnot) are routed through mock_null_access so + # tests can assert on them; they live on the access mock that __getitem__ returns + mock_string_access.is_ = mock_null_access.is_ + mock_string_access.isnot = mock_null_access.isnot + + return mock_metadata_field + + # ==================== String Condition Tests ==================== + + def test_contains_condition_string_value(self, retrieval): + """ + Test 'contains' condition with string value. + + Verifies: + - Filters list is populated with LIKE expression + - Pattern matching uses %value% syntax + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = "John" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_not_contains_condition(self, retrieval): + """ + Test 'not contains' condition. + + Verifies: + - Filters list is populated with NOT LIKE expression + - Pattern matching uses %value% syntax with negation + """ + filters = [] + sequence = 0 + condition = "not contains" + metadata_name = "title" + value = "banned" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_start_with_condition(self, retrieval): + """ + Test 'start with' condition.
+ + Verifies: + - Filters list is populated with LIKE expression + - Pattern matching uses value% syntax + """ + filters = [] + sequence = 0 + condition = "start with" + metadata_name = "category" + value = "tech" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_end_with_condition(self, retrieval): + """ + Test 'end with' condition. + + Verifies: + - Filters list is populated with LIKE expression + - Pattern matching uses %value syntax + """ + filters = [] + sequence = 0 + condition = "end with" + metadata_name = "filename" + value = ".pdf" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Equality Condition Tests ==================== + + def test_is_condition_with_string_value(self, retrieval): + """ + Test 'is' (=) condition with string value. + + Verifies: + - Filters list is populated with equality expression + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "author" + value = "Jane Doe" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_equals_condition_with_string_value(self, retrieval): + """ + Test '=' condition with string value. + + Verifies: + - Same behavior as 'is' condition + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "=" + metadata_name = "category" + value = "technology" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_condition_with_int_value(self, retrieval): + """ + Test 'is' condition with integer value. 
+ + Verifies: + - Numeric comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "year" + value = 2023 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_condition_with_float_value(self, retrieval): + """ + Test 'is' condition with float value. + + Verifies: + - Numeric comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "price" + value = 19.99 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_not_condition_with_string_value(self, retrieval): + """ + Test 'is not' (≠) condition with string value. + + Verifies: + - Filters list is populated with inequality expression + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "is not" + metadata_name = "author" + value = "Unknown" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_not_equals_condition(self, retrieval): + """ + Test '≠' condition with string value. + + Verifies: + - Same behavior as 'is not' condition + - Inequality expression is used + """ + filters = [] + sequence = 0 + condition = "≠" + metadata_name = "category" + value = "archived" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_not_condition_with_numeric_value(self, retrieval): + """ + Test 'is not' condition with numeric value. 
+ + Verifies: + - Numeric inequality comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is not" + metadata_name = "year" + value = 2000 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Null Condition Tests ==================== + + def test_empty_condition(self, retrieval): + """ + Test 'empty' condition (null check). + + Verifies: + - Filters list is populated with IS NULL expression + - Value can be None for this condition + """ + filters = [] + sequence = 0 + condition = "empty" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_not_empty_condition(self, retrieval): + """ + Test 'not empty' condition (not null check). + + Verifies: + - Filters list is populated with IS NOT NULL expression + - Value can be None for this condition + """ + filters = [] + sequence = 0 + condition = "not empty" + metadata_name = "description" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Numeric Comparison Tests ==================== + + def test_before_condition(self, retrieval): + """ + Test 'before' (<) condition. + + Verifies: + - Filters list is populated with less than expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "before" + metadata_name = "year" + value = 2020 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_condition(self, retrieval): + """ + Test '<' condition. 
+ + Verifies: + - Same behavior as 'before' condition + - Less than expression is used + """ + filters = [] + sequence = 0 + condition = "<" + metadata_name = "price" + value = 100.0 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_after_condition(self, retrieval): + """ + Test 'after' (>) condition. + + Verifies: + - Filters list is populated with greater than expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "after" + metadata_name = "year" + value = 2020 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_condition(self, retrieval): + """ + Test '>' condition. + + Verifies: + - Same behavior as 'after' condition + - Greater than expression is used + """ + filters = [] + sequence = 0 + condition = ">" + metadata_name = "rating" + value = 4.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_or_equal_condition_unicode(self, retrieval): + """ + Test '≤' condition. + + Verifies: + - Filters list is populated with less than or equal expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "≤" + metadata_name = "price" + value = 50.0 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_or_equal_condition_ascii(self, retrieval): + """ + Test '<=' condition. 
+ + Verifies: + - Same behavior as '≤' condition + - Less than or equal expression is used + """ + filters = [] + sequence = 0 + condition = "<=" + metadata_name = "year" + value = 2023 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_or_equal_condition_unicode(self, retrieval): + """ + Test '≥' condition. + + Verifies: + - Filters list is populated with greater than or equal expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "≥" + metadata_name = "rating" + value = 3.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_or_equal_condition_ascii(self, retrieval): + """ + Test '>=' condition. + + Verifies: + - Same behavior as '≥' condition + - Greater than or equal expression is used + """ + filters = [] + sequence = 0 + condition = ">=" + metadata_name = "year" + value = 2000 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== List/In Condition Tests ==================== + + def test_in_condition_with_comma_separated_string(self, retrieval): + """ + Test 'in' condition with comma-separated string value. + + Verifies: + - String is split into list + - Whitespace is trimmed from each value + - IN expression is created + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "tech, science, AI " + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_list_value(self, retrieval): + """ + Test 'in' condition with list value. 
+ + Verifies: + - List is processed correctly + - None values are filtered out + - IN expression is created with valid values + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "tags" + value = ["python", "javascript", None, "golang"] + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_tuple_value(self, retrieval): + """ + Test 'in' condition with tuple value. + + Verifies: + - Tuple is processed like a list + - IN expression is created + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = ("tech", "science", "ai") + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_empty_string(self, retrieval): + """ + Test 'in' condition with empty string value. + + Verifies: + - Empty string results in literal(False) filter + - No valid values to match + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + # Verify it's a literal(False) expression + # This is a bit tricky to test without access to the actual expression + + def test_in_condition_with_only_whitespace(self, retrieval): + """ + Test 'in' condition with whitespace-only string value. 
+ + Verifies: + - Whitespace-only string results in literal(False) filter + - All values are stripped and filtered out + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = " , , " + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_single_string(self, retrieval): + """ + Test 'in' condition with single non-comma string. + + Verifies: + - Single string is treated as single-item list + - IN expression is created with one value + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "technology" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Edge Case Tests ==================== + + def test_none_value_with_non_empty_condition(self, retrieval): + """ + Test None value with conditions that require value. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values (except empty/not empty) + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 # No filter added + + def test_none_value_with_equals_condition(self, retrieval): + """ + Test None value with 'is' (=) condition. 
+ + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_none_value_with_numeric_condition(self, retrieval): + """ + Test None value with numeric comparison condition. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values + """ + filters = [] + sequence = 0 + condition = ">" + metadata_name = "year" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_existing_filters_preserved(self, retrieval): + """ + Test that existing filters are preserved. + + Verifies: + - Existing filters in the list are not removed + - New filters are appended to the list + """ + existing_filter = MagicMock() + filters = [existing_filter] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = "test" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 2 + assert filters[0] == existing_filter + + def test_multiple_filters_accumulated(self, retrieval): + """ + Test multiple calls to accumulate filters. 
+ + Verifies: + - Each call adds a new filter to the list + - All filters are preserved across calls + """ + filters = [] + + # First filter + retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters) + assert len(filters) == 1 + + # Second filter + retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters) + assert len(filters) == 2 + + # Third filter + retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters) + assert len(filters) == 3 + + def test_unknown_condition(self, retrieval): + """ + Test unknown/unsupported condition. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for unknown conditions + """ + filters = [] + sequence = 0 + condition = "unknown_condition" + metadata_name = "author" + value = "test" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_empty_string_value_with_contains(self, retrieval): + """ + Test empty string value with 'contains' condition. + + Verifies: + - Filter is added even with empty string + - LIKE expression is created + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = "" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_special_characters_in_value(self, retrieval): + """ + Test special characters in value string. 
+ + Verifies: + - Special characters are handled in value + - LIKE expression is created correctly + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "title" + value = "C++ & Python's features" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_zero_value_with_numeric_condition(self, retrieval): + """ + Test zero value with numeric comparison condition. + + Verifies: + - Zero is treated as valid value + - Numeric comparison is performed + """ + filters = [] + sequence = 0 + condition = ">" + metadata_name = "price" + value = 0 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_negative_value_with_numeric_condition(self, retrieval): + """ + Test negative value with numeric comparison condition. + + Verifies: + - Negative numbers are handled correctly + - Numeric comparison is performed + """ + filters = [] + sequence = 0 + condition = "<" + metadata_name = "temperature" + value = -10.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_float_value_with_numeric_condition(self, retrieval): + """ + Test float value with numeric comparison condition.
+ + Verifies: + - Float values work correctly + - Numeric comparison is performed + """ + filters = [] + sequence = 0 + condition = ">=" + metadata_name = "rating" + value = 4.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 diff --git a/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py new file mode 100644 index 0000000000..5f461d53ae --- /dev/null +++ b/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py @@ -0,0 +1,113 @@ +import threading +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest +from flask import Flask, current_app + +from core.rag.models.document import Document +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from models.dataset import Dataset + + +class TestRetrievalService: + @pytest.fixture + def mock_dataset(self) -> Dataset: + dataset = Mock(spec=Dataset) + dataset.id = str(uuid4()) + dataset.tenant_id = str(uuid4()) + dataset.name = "test_dataset" + dataset.indexing_technique = "high_quality" + dataset.provider = "dify" + return dataset + + def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset): + """ + Repro test for current bug: + reranking runs after `with flask_app.app_context():` exits. + `_multiple_retrieve_thread` catches exceptions and stores them into `thread_exceptions`, + so we must assert from that list (not from an outer try/except). 
+ """ + dataset_retrieval = DatasetRetrieval() + flask_app = Flask(__name__) + tenant_id = str(uuid4()) + + # second dataset to ensure dataset_count > 1 reranking branch + secondary_dataset = Mock(spec=Dataset) + secondary_dataset.id = str(uuid4()) + secondary_dataset.provider = "dify" + secondary_dataset.indexing_technique = "high_quality" + + # retriever returns 1 doc into internal list (all_documents_item) + document = Document( + page_content="Context aware doc", + metadata={ + "doc_id": "doc1", + "score": 0.95, + "document_id": str(uuid4()), + "dataset_id": mock_dataset.id, + }, + provider="dify", + ) + + def fake_retriever( + flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids + ): + all_documents.append(document) + + called = {"init": 0, "invoke": 0} + + class ContextRequiredPostProcessor: + def __init__(self, *args, **kwargs): + called["init"] += 1 + # will raise RuntimeError if no Flask app context exists + _ = current_app.name + + def invoke(self, *args, **kwargs): + called["invoke"] += 1 + _ = current_app.name + return kwargs.get("documents") or args[1] + + # output list from _multiple_retrieve_thread + all_documents: list[Document] = [] + + # IMPORTANT: _multiple_retrieve_thread swallows exceptions and appends them here + thread_exceptions: list[Exception] = [] + + def target(): + with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever): + with patch( + "core.rag.retrieval.dataset_retrieval.DataPostProcessor", + ContextRequiredPostProcessor, + ): + dataset_retrieval._multiple_retrieve_thread( + flask_app=flask_app, + available_datasets=[mock_dataset, secondary_dataset], + metadata_condition=None, + metadata_filter_document_ids=None, + all_documents=all_documents, + tenant_id=tenant_id, + reranking_enable=True, + reranking_mode="reranking_model", + reranking_model={ + "reranking_provider_name": "cohere", + "reranking_model_name": "rerank-v2", + }, + weights=None, + top_k=3, + 
score_threshold=0.0, + query="test query", + attachment_id=None, + dataset_count=2, # force reranking branch + thread_exceptions=thread_exceptions, # collects exceptions swallowed by the worker thread + ) + + t = threading.Thread(target=target) + t.start() + t.join() + + # Ensure reranking branch was actually executed + assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run." + + # After the fix, no exception should be recorded; buggy code would append a RuntimeError here (not raise it) + assert not thread_exceptions, thread_exceptions diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py index 9060cf7b6c..636fac7a40 100644 --- a/api/tests/unit_tests/core/test_provider_configuration.py +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -32,7 +32,6 @@ def mock_provider_entity(): label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), description=I18nObject(en_US="OpenAI provider", zh_Hans="OpenAI 提供商"), icon_small=I18nObject(en_US="icon.png", zh_Hans="icon.png"), - icon_large=I18nObject(en_US="icon.png", zh_Hans="icon.png"), background="background.png", help=None, supported_model_types=[ModelType.LLM], diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 5d180c7cbc..cd45292488 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -228,11 +228,28 @@ def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.M def scalar(self, _stmt): return self.results.pop(0) + # SQLAlchemy Session APIs used by code under test + def expunge(self, *_args, **_kwargs): + pass + + def close(self): + pass + + # support `with session_factory.create_session() as session:` + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + tenant = SimpleNamespace(id="tenant_id") end_user =
SimpleNamespace(id="end_user_id", tenant_id="tenant_id") - db_stub = SimpleNamespace(session=StubSession([tenant, None, end_user])) - monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub) + # Monkeypatch session factory to return our stub session + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.session_factory.create_session", + lambda: StubSession([tenant, None, end_user]), + ) entity = ToolEntity( identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), @@ -266,8 +283,23 @@ def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pyt def scalar(self, _stmt): return self.results.pop(0) - db_stub = SimpleNamespace(session=StubSession([None])) - monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub) + def expunge(self, *_args, **_kwargs): + pass + + def close(self): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + + # Monkeypatch session factory to return our stub session with no tenant + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.session_factory.create_session", + lambda: StubSession([None]), + ) entity = ToolEntity( identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index af4f96ba23..aa16c8af1c 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -35,7 +35,6 @@ from core.variables.variables import ( SecretVariable, StringVariable, Variable, - VariableUnion, ) from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable @@ -96,7 +95,7 @@ class _Segments(BaseModel): class _Variables(BaseModel): - variables: list[VariableUnion] + variables: list[Variable] def create_test_file( @@ -194,7 +193,7 @@ 
class TestSegmentDumpAndLoad: # Create one instance of each variable type test_file = create_test_file() - all_variables: list[VariableUnion] = [ + all_variables: list[Variable] = [ NoneVariable(name="none_var"), StringVariable(value="test string", name="string_var"), IntegerVariable(value=42, name="int_var"), diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index 925142892c..fb4b18b57a 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -11,7 +11,7 @@ from core.variables import ( SegmentType, StringVariable, ) -from core.variables.variables import Variable +from core.variables.variables import VariableBase def test_frozen_variables(): @@ -76,7 +76,7 @@ def test_object_variable_to_object(): def test_variable_to_object(): - var: Variable = StringVariable(name="text", value="text") + var: VariableBase = StringVariable(name="text", value="text") assert var.to_object() == "text" var = IntegerVariable(name="integer", value=42) assert var.to_object() == 42 diff --git a/api/tests/unit_tests/core/workflow/context/__init__.py b/api/tests/unit_tests/core/workflow/context/__init__.py new file mode 100644 index 0000000000..ac81c5c9e8 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/context/__init__.py @@ -0,0 +1 @@ +"""Tests for workflow context management.""" diff --git a/api/tests/unit_tests/core/workflow/context/test_execution_context.py b/api/tests/unit_tests/core/workflow/context/test_execution_context.py new file mode 100644 index 0000000000..8dd669e17f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/context/test_execution_context.py @@ -0,0 +1,339 @@ +"""Tests for execution context module.""" + +import contextvars +import threading +from contextlib import contextmanager +from typing import Any +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel + +from 
core.workflow.context.execution_context import ( + AppContext, + ExecutionContext, + ExecutionContextBuilder, + IExecutionContext, + NullAppContext, + read_context, + register_context, +) + + +class TestAppContext: + """Test AppContext abstract base class.""" + + def test_app_context_is_abstract(self): + """Test that AppContext cannot be instantiated directly.""" + with pytest.raises(TypeError): + AppContext() # type: ignore + + +class TestNullAppContext: + """Test NullAppContext implementation.""" + + def test_null_app_context_get_config(self): + """Test get_config returns value from config dict.""" + config = {"key1": "value1", "key2": "value2"} + ctx = NullAppContext(config=config) + + assert ctx.get_config("key1") == "value1" + assert ctx.get_config("key2") == "value2" + + def test_null_app_context_get_config_default(self): + """Test get_config returns default when key not found.""" + ctx = NullAppContext() + + assert ctx.get_config("nonexistent", "default") == "default" + assert ctx.get_config("nonexistent") is None + + def test_null_app_context_get_extension(self): + """Test get_extension returns stored extension.""" + ctx = NullAppContext() + extension = MagicMock() + ctx.set_extension("db", extension) + + assert ctx.get_extension("db") == extension + + def test_null_app_context_get_extension_not_found(self): + """Test get_extension returns None when extension not found.""" + ctx = NullAppContext() + + assert ctx.get_extension("nonexistent") is None + + def test_null_app_context_enter_yield(self): + """Test enter method yields without any side effects.""" + ctx = NullAppContext() + + with ctx.enter(): + # Should not raise any exception + pass + + +class TestExecutionContext: + """Test ExecutionContext class.""" + + def test_initialization_with_all_params(self): + """Test ExecutionContext initialization with all parameters.""" + app_ctx = NullAppContext() + context_vars = contextvars.copy_context() + user = MagicMock() + + ctx = ExecutionContext( + 
app_context=app_ctx, + context_vars=context_vars, + user=user, + ) + + assert ctx.app_context == app_ctx + assert ctx.context_vars == context_vars + assert ctx.user == user + + def test_initialization_with_minimal_params(self): + """Test ExecutionContext initialization with minimal parameters.""" + ctx = ExecutionContext() + + assert ctx.app_context is None + assert ctx.context_vars is None + assert ctx.user is None + + def test_enter_with_context_vars(self): + """Test enter restores context variables.""" + test_var = contextvars.ContextVar("test_var") + test_var.set("original_value") + + # Copy context with the variable + context_vars = contextvars.copy_context() + + # Change the variable + test_var.set("new_value") + + # Create execution context and enter it + ctx = ExecutionContext(context_vars=context_vars) + + with ctx.enter(): + # Variable should be restored to original value + assert test_var.get() == "original_value" + + # After exiting, variable stays at the value from within the context + # (this is expected Python contextvars behavior) + assert test_var.get() == "original_value" + + def test_enter_with_app_context(self): + """Test enter enters app context if available.""" + app_ctx = NullAppContext() + ctx = ExecutionContext(app_context=app_ctx) + + # Should not raise any exception + with ctx.enter(): + pass + + def test_enter_without_app_context(self): + """Test enter works without app context.""" + ctx = ExecutionContext(app_context=None) + + # Should not raise any exception + with ctx.enter(): + pass + + def test_context_manager_protocol(self): + """Test ExecutionContext supports context manager protocol.""" + ctx = ExecutionContext() + + with ctx: + # Should not raise any exception + pass + + def test_user_property(self): + """Test user property returns set user.""" + user = MagicMock() + ctx = ExecutionContext(user=user) + + assert ctx.user == user + + def test_thread_safe_context_manager(self): + """Test shared ExecutionContext works across threads 
without token mismatch.""" + test_var = contextvars.ContextVar("thread_safe_test_var") + + class TrackingAppContext(AppContext): + def get_config(self, key: str, default: Any = None) -> Any: + return default + + def get_extension(self, name: str) -> Any: + return None + + @contextmanager + def enter(self): + token = test_var.set(threading.get_ident()) + try: + yield + finally: + test_var.reset(token) + + ctx = ExecutionContext(app_context=TrackingAppContext()) + errors: list[Exception] = [] + barrier = threading.Barrier(2) + + def worker(): + try: + for _ in range(20): + with ctx: + try: + barrier.wait() + barrier.wait() + except threading.BrokenBarrierError: + return + except Exception as exc: + errors.append(exc) + try: + barrier.abort() + except Exception: + pass + + t1 = threading.Thread(target=worker) + t2 = threading.Thread(target=worker) + t1.start() + t2.start() + t1.join(timeout=5) + t2.join(timeout=5) + + assert not errors + + +class TestIExecutionContextProtocol: + """Test IExecutionContext protocol.""" + + def test_execution_context_implements_protocol(self): + """Test that ExecutionContext implements IExecutionContext protocol.""" + ctx = ExecutionContext() + + # Should have __enter__ and __exit__ methods + assert hasattr(ctx, "__enter__") + assert hasattr(ctx, "__exit__") + assert hasattr(ctx, "user") + + def test_protocol_compatibility(self): + """Test that ExecutionContext can be used where IExecutionContext is expected.""" + + def accept_context(context: IExecutionContext) -> Any: + """Function that accepts IExecutionContext protocol.""" + # Just verify it has the required protocol attributes + assert hasattr(context, "__enter__") + assert hasattr(context, "__exit__") + assert hasattr(context, "user") + return context.user + + ctx = ExecutionContext(user="test_user") + result = accept_context(ctx) + + assert result == "test_user" + + def test_protocol_with_flask_execution_context(self): + """Test that IExecutionContext protocol is compatible with 
different implementations.""" + # Verify the protocol works with ExecutionContext + ctx = ExecutionContext(user="test_user") + + # Should have the required protocol attributes + assert hasattr(ctx, "__enter__") + assert hasattr(ctx, "__exit__") + assert hasattr(ctx, "user") + assert ctx.user == "test_user" + + # Should work as context manager + with ctx: + assert ctx.user == "test_user" + + +class TestExecutionContextBuilder: + """Test ExecutionContextBuilder class.""" + + def test_builder_with_all_params(self): + """Test builder with all parameters set.""" + app_ctx = NullAppContext() + context_vars = contextvars.copy_context() + user = MagicMock() + + ctx = ( + ExecutionContextBuilder().with_app_context(app_ctx).with_context_vars(context_vars).with_user(user).build() + ) + + assert ctx.app_context == app_ctx + assert ctx.context_vars == context_vars + assert ctx.user == user + + def test_builder_with_partial_params(self): + """Test builder with only some parameters set.""" + app_ctx = NullAppContext() + + ctx = ExecutionContextBuilder().with_app_context(app_ctx).build() + + assert ctx.app_context == app_ctx + assert ctx.context_vars is None + assert ctx.user is None + + def test_builder_fluent_interface(self): + """Test builder provides fluent interface.""" + builder = ExecutionContextBuilder() + + # Each method should return the builder + assert isinstance(builder.with_app_context(NullAppContext()), ExecutionContextBuilder) + assert isinstance(builder.with_context_vars(contextvars.copy_context()), ExecutionContextBuilder) + assert isinstance(builder.with_user(None), ExecutionContextBuilder) + + +class TestCaptureCurrentContext: + """Test capture_current_context function.""" + + def test_capture_current_context_returns_context(self): + """Test that capture_current_context returns a valid context.""" + from core.workflow.context.execution_context import capture_current_context + + result = capture_current_context() + + # Should return an object that implements 
IExecutionContext + assert hasattr(result, "__enter__") + assert hasattr(result, "__exit__") + assert hasattr(result, "user") + + def test_capture_current_context_captures_contextvars(self): + """Test that capture_current_context captures context variables.""" + # Set a context variable before capturing + import contextvars + + test_var = contextvars.ContextVar("capture_test_var") + test_var.set("test_value_123") + + from core.workflow.context.execution_context import capture_current_context + + result = capture_current_context() + + # Context variables should be captured + assert result.context_vars is not None + + +class TestTenantScopedContextRegistry: + def setup_method(self): + from core.workflow.context import reset_context_provider + + reset_context_provider() + + def teardown_method(self): + from core.workflow.context import reset_context_provider + + reset_context_provider() + + def test_tenant_provider_read_ok(self): + class SandboxContext(BaseModel): + base_url: str | None = None + + register_context("workflow.sandbox", "t1", lambda: SandboxContext(base_url="http://t1")) + register_context("workflow.sandbox", "t2", lambda: SandboxContext(base_url="http://t2")) + + assert read_context("workflow.sandbox", tenant_id="t1").base_url == "http://t1" + assert read_context("workflow.sandbox", tenant_id="t2").base_url == "http://t2" + + def test_missing_provider_raises_keyerror(self): + from core.workflow.context import ContextProviderNotFoundError + + with pytest.raises(ContextProviderNotFoundError): + read_context("missing", tenant_id="unknown") diff --git a/api/tests/unit_tests/core/workflow/context/test_flask_app_context.py b/api/tests/unit_tests/core/workflow/context/test_flask_app_context.py new file mode 100644 index 0000000000..a809b29552 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/context/test_flask_app_context.py @@ -0,0 +1,316 @@ +"""Tests for Flask app context module.""" + +import contextvars +from unittest.mock import MagicMock, patch + 
+import pytest + + +class TestFlaskAppContext: + """Test FlaskAppContext implementation.""" + + @pytest.fixture + def mock_flask_app(self): + """Create a mock Flask app.""" + app = MagicMock() + app.config = {"TEST_KEY": "test_value"} + app.extensions = {"db": MagicMock(), "cache": MagicMock()} + app.app_context = MagicMock() + app.app_context.return_value.__enter__ = MagicMock(return_value=None) + app.app_context.return_value.__exit__ = MagicMock(return_value=None) + return app + + def test_flask_app_context_initialization(self, mock_flask_app): + """Test FlaskAppContext initialization.""" + # Import here to avoid Flask dependency in test environment + from context.flask_app_context import FlaskAppContext + + ctx = FlaskAppContext(mock_flask_app) + + assert ctx.flask_app == mock_flask_app + + def test_flask_app_context_get_config(self, mock_flask_app): + """Test get_config returns Flask app config value.""" + from context.flask_app_context import FlaskAppContext + + ctx = FlaskAppContext(mock_flask_app) + + assert ctx.get_config("TEST_KEY") == "test_value" + + def test_flask_app_context_get_config_default(self, mock_flask_app): + """Test get_config returns default when key not found.""" + from context.flask_app_context import FlaskAppContext + + ctx = FlaskAppContext(mock_flask_app) + + assert ctx.get_config("NONEXISTENT", "default") == "default" + + def test_flask_app_context_get_extension(self, mock_flask_app): + """Test get_extension returns Flask extension.""" + from context.flask_app_context import FlaskAppContext + + ctx = FlaskAppContext(mock_flask_app) + db_ext = mock_flask_app.extensions["db"] + + assert ctx.get_extension("db") == db_ext + + def test_flask_app_context_get_extension_not_found(self, mock_flask_app): + """Test get_extension returns None when extension not found.""" + from context.flask_app_context import FlaskAppContext + + ctx = FlaskAppContext(mock_flask_app) + + assert ctx.get_extension("nonexistent") is None + + def 
test_flask_app_context_enter(self, mock_flask_app): + """Test enter method enters Flask app context.""" + from context.flask_app_context import FlaskAppContext + + ctx = FlaskAppContext(mock_flask_app) + + with ctx.enter(): + # Should not raise any exception + pass + + # Verify app_context was called + mock_flask_app.app_context.assert_called_once() + + +class TestFlaskExecutionContext: + """Test FlaskExecutionContext class.""" + + @pytest.fixture + def mock_flask_app(self): + """Create a mock Flask app.""" + app = MagicMock() + app.config = {} + app.app_context = MagicMock() + app.app_context.return_value.__enter__ = MagicMock(return_value=None) + app.app_context.return_value.__exit__ = MagicMock(return_value=None) + return app + + def test_initialization(self, mock_flask_app): + """Test FlaskExecutionContext initialization.""" + from context.flask_app_context import FlaskExecutionContext + + context_vars = contextvars.copy_context() + user = MagicMock() + + ctx = FlaskExecutionContext( + flask_app=mock_flask_app, + context_vars=context_vars, + user=user, + ) + + assert ctx.context_vars == context_vars + assert ctx.user == user + + def test_app_context_property(self, mock_flask_app): + """Test app_context property returns FlaskAppContext.""" + from context.flask_app_context import FlaskAppContext, FlaskExecutionContext + + ctx = FlaskExecutionContext( + flask_app=mock_flask_app, + context_vars=contextvars.copy_context(), + ) + + assert isinstance(ctx.app_context, FlaskAppContext) + assert ctx.app_context.flask_app == mock_flask_app + + def test_context_manager_protocol(self, mock_flask_app): + """Test FlaskExecutionContext supports context manager protocol.""" + from context.flask_app_context import FlaskExecutionContext + + ctx = FlaskExecutionContext( + flask_app=mock_flask_app, + context_vars=contextvars.copy_context(), + ) + + # Should have __enter__ and __exit__ methods + assert hasattr(ctx, "__enter__") + assert hasattr(ctx, "__exit__") + + # Should work as 
context manager + with ctx: + pass + + +class TestCaptureFlaskContext: + """Test capture_flask_context function.""" + + @patch("context.flask_app_context.current_app") + @patch("context.flask_app_context.g") + def test_capture_flask_context_captures_app(self, mock_g, mock_current_app): + """Test capture_flask_context captures Flask app.""" + mock_app = MagicMock() + mock_app._get_current_object = MagicMock(return_value=mock_app) + mock_current_app._get_current_object = MagicMock(return_value=mock_app) + + from context.flask_app_context import capture_flask_context + + ctx = capture_flask_context() + + assert ctx._flask_app == mock_app + + @patch("context.flask_app_context.current_app") + @patch("context.flask_app_context.g") + def test_capture_flask_context_captures_user_from_g(self, mock_g, mock_current_app): + """Test capture_flask_context captures user from Flask g object.""" + mock_app = MagicMock() + mock_app._get_current_object = MagicMock(return_value=mock_app) + mock_current_app._get_current_object = MagicMock(return_value=mock_app) + + mock_user = MagicMock() + mock_user.id = "user_123" + mock_g._login_user = mock_user + + from context.flask_app_context import capture_flask_context + + ctx = capture_flask_context() + + assert ctx.user == mock_user + + @patch("context.flask_app_context.current_app") + def test_capture_flask_context_with_explicit_user(self, mock_current_app): + """Test capture_flask_context uses explicit user parameter.""" + mock_app = MagicMock() + mock_app._get_current_object = MagicMock(return_value=mock_app) + mock_current_app._get_current_object = MagicMock(return_value=mock_app) + + explicit_user = MagicMock() + explicit_user.id = "user_456" + + from context.flask_app_context import capture_flask_context + + ctx = capture_flask_context(user=explicit_user) + + assert ctx.user == explicit_user + + @patch("context.flask_app_context.current_app") + def test_capture_flask_context_captures_contextvars(self, mock_current_app): + """Test 
capture_flask_context captures context variables.""" + mock_app = MagicMock() + mock_app._get_current_object = MagicMock(return_value=mock_app) + mock_current_app._get_current_object = MagicMock(return_value=mock_app) + + # Set a context variable + test_var = contextvars.ContextVar("test_var") + test_var.set("test_value") + + from context.flask_app_context import capture_flask_context + + ctx = capture_flask_context() + + # Context variables should be captured + assert ctx.context_vars is not None + # Verify the variable is in the captured context + captured_value = ctx.context_vars[test_var] + assert captured_value == "test_value" + + +class TestFlaskExecutionContextIntegration: + """Integration tests for FlaskExecutionContext.""" + + @pytest.fixture + def mock_flask_app(self): + """Create a mock Flask app with proper app context.""" + app = MagicMock() + app.config = {"TEST": "value"} + app.extensions = {"db": MagicMock()} + + # Mock app context + mock_app_context = MagicMock() + mock_app_context.__enter__ = MagicMock(return_value=None) + mock_app_context.__exit__ = MagicMock(return_value=None) + app.app_context.return_value = mock_app_context + + return app + + def test_enter_restores_context_vars(self, mock_flask_app): + """Test that enter restores captured context variables.""" + # Create a context variable and set a value + test_var = contextvars.ContextVar("integration_test_var") + test_var.set("original_value") + + # Capture the context + context_vars = contextvars.copy_context() + + # Change the value + test_var.set("new_value") + + # Create FlaskExecutionContext and enter it + from context.flask_app_context import FlaskExecutionContext + + ctx = FlaskExecutionContext( + flask_app=mock_flask_app, + context_vars=context_vars, + ) + + with ctx: + # Value should be restored to original + assert test_var.get() == "original_value" + + # After exiting, variable stays at the value from within the context + # (this is expected Python contextvars behavior) + assert 
test_var.get() == "original_value" + + def test_enter_enters_flask_app_context(self, mock_flask_app): + """Test that enter enters Flask app context.""" + from context.flask_app_context import FlaskExecutionContext + + ctx = FlaskExecutionContext( + flask_app=mock_flask_app, + context_vars=contextvars.copy_context(), + ) + + with ctx: + # Verify app context was entered + assert mock_flask_app.app_context.called + + @patch("context.flask_app_context.g") + def test_enter_restores_user_in_g(self, mock_g, mock_flask_app): + """Test that enter restores user in Flask g object.""" + mock_user = MagicMock() + mock_user.id = "test_user" + + # Note: FlaskExecutionContext saves user from g before entering context, + # then restores it after entering the app context. + # The user passed to constructor is NOT restored to g. + # So we need to test the actual behavior. + + # Create FlaskExecutionContext with user in constructor + from context.flask_app_context import FlaskExecutionContext + + ctx = FlaskExecutionContext( + flask_app=mock_flask_app, + context_vars=contextvars.copy_context(), + user=mock_user, + ) + + # Set user in g before entering (simulating existing user in g) + mock_g._login_user = mock_user + + with ctx: + # After entering, the user from g before entry should be restored + assert mock_g._login_user == mock_user + + # The user in constructor is stored but not automatically restored to g + # (it's available via ctx.user property) + assert ctx.user == mock_user + + def test_enter_method_as_context_manager(self, mock_flask_app): + """Test enter method returns a proper context manager.""" + from context.flask_app_context import FlaskExecutionContext + + ctx = FlaskExecutionContext( + flask_app=mock_flask_app, + context_vars=contextvars.copy_context(), + ) + + # enter() should return a generator/context manager + with ctx.enter(): + # Should work without issues + pass + + # Verify app context was called + assert mock_flask_app.app_context.called diff --git 
a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py new file mode 100644 index 0000000000..6858120335 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory +from core.workflow.entities import GraphInitParams +from core.workflow.graph import Graph +from core.workflow.graph.validation import GraphValidationError +from core.workflow.nodes import NodeType +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from models.enums import UserFrom + + +def _build_iteration_graph(node_id: str) -> dict[str, Any]: + return { + "nodes": [ + { + "id": node_id, + "data": { + "type": "iteration", + "title": "Iteration", + "iterator_selector": ["start", "items"], + "output_selector": [node_id, "output"], + }, + } + ], + "edges": [], + } + + +def _build_loop_graph(node_id: str) -> dict[str, Any]: + return { + "nodes": [ + { + "id": node_id, + "data": { + "type": "loop", + "title": "Loop", + "loop_count": 1, + "break_conditions": [], + "logical_operator": "and", + "loop_variables": [], + "outputs": {}, + }, + } + ], + "edges": [], + } + + +def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory: + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool( + system_variables=SystemVariable.default(), + user_inputs={}, + environment_variables=[], + ), + start_at=0.0, + ) + return 
DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) + + +def test_iteration_root_requires_skip_validation(): + node_id = "iteration-node" + graph_config = _build_iteration_graph(node_id) + node_factory = _make_factory(graph_config) + + with pytest.raises(GraphValidationError): + Graph.init( + graph_config=graph_config, + node_factory=node_factory, + root_node_id=node_id, + ) + + graph = Graph.init( + graph_config=graph_config, + node_factory=node_factory, + root_node_id=node_id, + skip_validation=True, + ) + + assert graph.root_node.id == node_id + assert graph.root_node.node_type == NodeType.ITERATION + + +def test_loop_root_requires_skip_validation(): + node_id = "loop-node" + graph_config = _build_loop_graph(node_id) + node_factory = _make_factory(graph_config) + + with pytest.raises(GraphValidationError): + Graph.init( + graph_config=graph_config, + node_factory=node_factory, + root_node_id=node_id, + ) + + graph = Graph.init( + graph_config=graph_config, + node_factory=node_factory, + root_node_id=node_id, + skip_validation=True, + ) + + assert graph.root_node.id == node_id + assert graph.root_node.node_type == NodeType.LOOP diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py index 8677325d4e..f33fd0deeb 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py @@ -3,8 +3,15 @@ import json from unittest.mock import MagicMock +from core.variables import IntegerVariable, StringVariable from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, GraphEngineCommand +from core.workflow.graph_engine.entities.commands import ( + AbortCommand, + 
CommandType, + GraphEngineCommand, + UpdateVariablesCommand, + VariableUpdate, +) class TestRedisChannel: @@ -148,6 +155,43 @@ class TestRedisChannel: assert commands[0].command_type == CommandType.ABORT assert isinstance(commands[1], AbortCommand) + def test_fetch_commands_with_update_variables_command(self): + """Test fetching update variables command from Redis.""" + mock_redis = MagicMock() + pending_pipe = MagicMock() + fetch_pipe = MagicMock() + pending_context = MagicMock() + fetch_context = MagicMock() + pending_context.__enter__.return_value = pending_pipe + pending_context.__exit__.return_value = None + fetch_context.__enter__.return_value = fetch_pipe + fetch_context.__exit__.return_value = None + mock_redis.pipeline.side_effect = [pending_context, fetch_context] + + update_command = UpdateVariablesCommand( + updates=[ + VariableUpdate( + value=StringVariable(name="foo", value="bar", selector=["node1", "foo"]), + ), + VariableUpdate( + value=IntegerVariable(name="baz", value=123, selector=["node2", "baz"]), + ), + ] + ) + command_json = json.dumps(update_command.model_dump()) + + pending_pipe.execute.return_value = [b"1", 1] + fetch_pipe.execute.return_value = [[command_json.encode()], 1] + + channel = RedisChannel(mock_redis, "test:key") + commands = channel.fetch_commands() + + assert len(commands) == 1 + assert isinstance(commands[0], UpdateVariablesCommand) + assert isinstance(commands[0].updates[0].value, StringVariable) + assert list(commands[0].updates[0].value.selector) == ["node1", "foo"] + assert commands[0].updates[0].value.value == "bar" + def test_fetch_commands_skips_invalid_json(self): """Test that invalid JSON commands are skipped.""" mock_redis = MagicMock() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py new file mode 100644 index 0000000000..cf8811dc2b --- /dev/null +++ 
b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py @@ -0,0 +1 @@ +"""Tests for graph traversal components.""" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py new file mode 100644 index 0000000000..0019020ede --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py @@ -0,0 +1,307 @@ +"""Unit tests for skip propagator.""" + +from unittest.mock import MagicMock, create_autospec + +from core.workflow.graph import Edge, Graph +from core.workflow.graph_engine.graph_state_manager import GraphStateManager +from core.workflow.graph_engine.graph_traversal.skip_propagator import SkipPropagator + + +class TestSkipPropagator: + """Test suite for SkipPropagator.""" + + def test_propagate_skip_from_edge_with_unknown_edges_stops_processing(self) -> None: + """When there are unknown incoming edges, propagation should stop.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create a mock edge + mock_edge = MagicMock(spec=Edge) + mock_edge.id = "edge_1" + mock_edge.head = "node_2" + + # Setup graph edges dict + mock_graph.edges = {"edge_1": mock_edge} + + # Setup incoming edges + incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge)] + mock_graph.get_incoming_edges.return_value = incoming_edges + + # Setup state manager to return has_unknown=True + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": True, + "has_taken": False, + "all_skipped": False, + } + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert + mock_graph.get_incoming_edges.assert_called_once_with("node_2") + mock_state_manager.analyze_edge_states.assert_called_once_with(incoming_edges) + # Should not call any other state manager 
methods + mock_state_manager.enqueue_node.assert_not_called() + mock_state_manager.start_execution.assert_not_called() + mock_state_manager.mark_node_skipped.assert_not_called() + + def test_propagate_skip_from_edge_with_taken_edge_enqueues_node(self) -> None: + """When there is at least one taken edge, node should be enqueued.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create a mock edge + mock_edge = MagicMock(spec=Edge) + mock_edge.id = "edge_1" + mock_edge.head = "node_2" + + mock_graph.edges = {"edge_1": mock_edge} + incoming_edges = [MagicMock(spec=Edge)] + mock_graph.get_incoming_edges.return_value = incoming_edges + + # Setup state manager to return has_taken=True + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": False, + "has_taken": True, + "all_skipped": False, + } + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert + mock_state_manager.enqueue_node.assert_called_once_with("node_2") + mock_state_manager.start_execution.assert_called_once_with("node_2") + mock_state_manager.mark_node_skipped.assert_not_called() + + def test_propagate_skip_from_edge_with_all_skipped_propagates_to_node(self) -> None: + """When all incoming edges are skipped, should propagate skip to node.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create a mock edge + mock_edge = MagicMock(spec=Edge) + mock_edge.id = "edge_1" + mock_edge.head = "node_2" + + mock_graph.edges = {"edge_1": mock_edge} + incoming_edges = [MagicMock(spec=Edge)] + mock_graph.get_incoming_edges.return_value = incoming_edges + + # Setup state manager to return all_skipped=True + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": False, + "has_taken": False, + "all_skipped": True, + } + + propagator = SkipPropagator(mock_graph, 
mock_state_manager) + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert + mock_state_manager.mark_node_skipped.assert_called_once_with("node_2") + mock_state_manager.enqueue_node.assert_not_called() + mock_state_manager.start_execution.assert_not_called() + + def test_propagate_skip_to_node_marks_node_and_outgoing_edges_skipped(self) -> None: + """_propagate_skip_to_node should mark node and all outgoing edges as skipped.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create outgoing edges + edge1 = MagicMock(spec=Edge) + edge1.id = "edge_2" + edge1.head = "node_downstream_1" # Set head for propagate_skip_from_edge + + edge2 = MagicMock(spec=Edge) + edge2.id = "edge_3" + edge2.head = "node_downstream_2" + + # Setup graph edges dict for propagate_skip_from_edge + mock_graph.edges = {"edge_2": edge1, "edge_3": edge2} + mock_graph.get_outgoing_edges.return_value = [edge1, edge2] + + # Setup get_incoming_edges to return empty list to stop recursion + mock_graph.get_incoming_edges.return_value = [] + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act: exercise the private helper directly + propagator._propagate_skip_to_node("node_1") + + # Assert + mock_state_manager.mark_node_skipped.assert_called_once_with("node_1") + mock_state_manager.mark_edge_skipped.assert_any_call("edge_2") + mock_state_manager.mark_edge_skipped.assert_any_call("edge_3") + assert mock_state_manager.mark_edge_skipped.call_count == 2 + # propagate_skip_from_edge then recurses from each outgoing edge; the empty + # incoming-edge list configured above stops that recursion immediately
+ + def test_skip_branch_paths_marks_unselected_edges_and_propagates(self) -> None: + """skip_branch_paths should mark all unselected edges as skipped and propagate.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create unselected edges + edge1 = MagicMock(spec=Edge) + edge1.id = "edge_1" + edge1.head = "node_downstream_1" + + edge2 = MagicMock(spec=Edge) + edge2.id = "edge_2" + edge2.head = "node_downstream_2" + + unselected_edges = [edge1, edge2] + + # Setup graph edges dict + mock_graph.edges = {"edge_1": edge1, "edge_2": edge2} + # Setup get_incoming_edges to return empty list to stop recursion + mock_graph.get_incoming_edges.return_value = [] + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act + propagator.skip_branch_paths(unselected_edges) + + # Assert + mock_state_manager.mark_edge_skipped.assert_any_call("edge_1") + mock_state_manager.mark_edge_skipped.assert_any_call("edge_2") + assert mock_state_manager.mark_edge_skipped.call_count == 2 + # propagate_skip_from_edge also runs for each unselected edge; the empty + # incoming-edge list configured above stops the propagation immediately + + def test_propagate_skip_from_edge_recursively_propagates_through_graph(self) -> None: + """Skip propagation should recursively propagate through the graph.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create edge chain: edge_1 -> node_2 -> edge_3 -> node_4 + edge1 = MagicMock(spec=Edge) + edge1.id = "edge_1" + edge1.head = "node_2" + + edge3 = MagicMock(spec=Edge) + edge3.id = "edge_3" + edge3.head = "node_4" + + mock_graph.edges = {"edge_1": edge1, "edge_3": edge3} + + # Setup get_incoming_edges to return different values based on node + def get_incoming_edges_side_effect(node_id): + if node_id == "node_2": + return [edge1] + elif node_id == "node_4": + return [edge3] + return [] + + mock_graph.get_incoming_edges.side_effect =
get_incoming_edges_side_effect + + # Setup get_outgoing_edges to return different values based on node + def get_outgoing_edges_side_effect(node_id): + if node_id == "node_2": + return [edge3] + elif node_id == "node_4": + return [] # No outgoing edges, stops recursion + return [] + + mock_graph.get_outgoing_edges.side_effect = get_outgoing_edges_side_effect + + # Setup state manager to return all_skipped for both nodes + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": False, + "has_taken": False, + "all_skipped": True, + } + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert + # Should mark node_2 as skipped + mock_state_manager.mark_node_skipped.assert_any_call("node_2") + # Should mark edge_3 as skipped + mock_state_manager.mark_edge_skipped.assert_any_call("edge_3") + # Should propagate to node_4 + mock_state_manager.mark_node_skipped.assert_any_call("node_4") + assert mock_state_manager.mark_node_skipped.call_count == 2 + + def test_propagate_skip_from_edge_with_mixed_edge_states_handles_correctly(self) -> None: + """Test with mixed edge states (some unknown, some taken, some skipped).""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + mock_edge = MagicMock(spec=Edge) + mock_edge.id = "edge_1" + mock_edge.head = "node_2" + + mock_graph.edges = {"edge_1": mock_edge} + incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge), MagicMock(spec=Edge)] + mock_graph.get_incoming_edges.return_value = incoming_edges + + # Test 1: has_unknown=True, has_taken=False, all_skipped=False + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": True, + "has_taken": False, + "all_skipped": False, + } + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert - should stop processing + 
mock_state_manager.enqueue_node.assert_not_called() + mock_state_manager.mark_node_skipped.assert_not_called() + + # Reset mocks for next test + mock_state_manager.reset_mock() + mock_graph.reset_mock() + + # Test 2: has_unknown=False, has_taken=True, all_skipped=False + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": False, + "has_taken": True, + "all_skipped": False, + } + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert - should enqueue node + mock_state_manager.enqueue_node.assert_called_once_with("node_2") + mock_state_manager.start_execution.assert_called_once_with("node_2") + + # Reset mocks for next test + mock_state_manager.reset_mock() + mock_graph.reset_mock() + + # Test 3: has_unknown=False, has_taken=False, all_skipped=True + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": False, + "has_taken": False, + "all_skipped": True, + } + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert - should propagate skip + mock_state_manager.mark_node_skipped.assert_called_once_with("node_2") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py index b18a3369e9..35a234be0b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -90,12 +90,47 @@ def mock_tool_node(): @pytest.fixture def mock_is_instrument_flag_enabled_false(): """Mock is_instrument_flag_enabled to return False.""" - with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=False): + with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=False): yield @pytest.fixture def mock_is_instrument_flag_enabled_true(): """Mock is_instrument_flag_enabled to return True.""" - with 
patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=True): + with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=True): yield + + +@pytest.fixture +def mock_retrieval_node(): + """Create a mock Knowledge Retrieval Node.""" + node = MagicMock() + node.id = "test-retrieval-node-id" + node.title = "Retrieval Node" + node.execution_id = "test-retrieval-execution-id" + node.node_type = NodeType.KNOWLEDGE_RETRIEVAL + return node + + +@pytest.fixture +def mock_result_event(): + """Create a mock result event with NodeRunResult.""" + from datetime import datetime + + from core.workflow.graph_events.node import NodeRunSucceededEvent + from core.workflow.node_events.base import NodeRunResult + + node_run_result = NodeRunResult( + inputs={"query": "test query"}, + outputs={"result": [{"content": "test content", "metadata": {}}]}, + process_data={}, + metadata={}, + ) + + return NodeRunSucceededEvent( + id="test-execution-id", + node_id="test-node-id", + node_type=NodeType.LLM, + start_at=datetime.now(), + node_run_result=node_run_result, + ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py new file mode 100644 index 0000000000..d6ba61c50c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import pytest + +from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine.command_channels import InMemoryChannel +from core.workflow.graph_engine.layers.base import ( + GraphEngineLayer, + GraphEngineLayerNotInitializedError, +) +from core.workflow.graph_events import GraphEngineEvent + +from ..test_table_runner import WorkflowRunner + + +class LayerForTest(GraphEngineLayer): + def on_graph_start(self) -> None: + pass + + def on_event(self, 
event: GraphEngineEvent) -> None: + pass + + def on_graph_end(self, error: Exception | None) -> None: + pass + + +def test_layer_runtime_state_raises_when_uninitialized() -> None: + layer = LayerForTest() + + with pytest.raises(GraphEngineLayerNotInitializedError): + _ = layer.graph_runtime_state + + +def test_layer_runtime_state_available_after_engine_layer() -> None: + runner = WorkflowRunner() + fixture_data = runner.load_fixture("simple_passthrough_workflow") + graph, graph_runtime_state = runner.create_graph_from_fixture( + fixture_data, + inputs={"query": "test layer state"}, + ) + engine = GraphEngine( + workflow_id="test_workflow", + graph=graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + ) + + layer = LayerForTest() + engine.layer(layer) + + outputs = layer.graph_runtime_state.outputs + ready_queue_size = layer.graph_runtime_state.ready_queue_size + + assert outputs == {} + assert isinstance(ready_queue_size, int) + assert ready_queue_size >= 0 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py index 458cf2cc67..ade846df28 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -4,7 +4,8 @@ Tests for ObservabilityLayer. 
Test coverage: - Initialization and enable/disable logic - Node span lifecycle (start, end, error handling) -- Parser integration (default and tool-specific) +- Parser integration (default, tool, LLM, and retrieval parsers) +- Result event parameter extraction (inputs/outputs) - Graph lifecycle management - Disabled mode behavior """ @@ -14,14 +15,14 @@ from unittest.mock import patch import pytest from opentelemetry.trace import StatusCode +from core.app.workflow.layers.observability import ObservabilityLayer from core.workflow.enums import NodeType -from core.workflow.graph_engine.layers.observability import ObservabilityLayer class TestObservabilityLayerInitialization: """Test ObservabilityLayer initialization logic.""" - @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") def test_initialization_when_otel_enabled(self, tracer_provider_with_memory_exporter): """Test that layer initializes correctly when OTel is enabled.""" @@ -31,7 +32,7 @@ class TestObservabilityLayerInitialization: assert NodeType.TOOL in layer._parsers assert layer._default_parser is not None - @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False) + @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False) @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_true") def test_initialization_when_instrument_flag_enabled(self, tracer_provider_with_memory_exporter): """Test that layer enables when instrument flag is enabled.""" @@ -45,7 +46,7 @@ class TestObservabilityLayerInitialization: class TestObservabilityLayerNodeSpanLifecycle: """Test node span creation and lifecycle management.""" - @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) 
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") def test_node_span_created_and_ended( self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node @@ -62,7 +63,7 @@ class TestObservabilityLayerNodeSpanLifecycle: assert spans[0].name == mock_llm_node.title assert spans[0].status.status_code == StatusCode.OK - @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") def test_node_error_recorded_in_span( self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node @@ -81,7 +82,7 @@ class TestObservabilityLayerNodeSpanLifecycle: assert len(spans[0].events) > 0 assert any("exception" in event.name.lower() for event in spans[0].events) - @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") def test_node_end_without_start_handled_gracefully( self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node @@ -99,7 +100,7 @@ class TestObservabilityLayerNodeSpanLifecycle: class TestObservabilityLayerParserIntegration: """Test parser integration for different node types.""" - @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") def test_default_parser_used_for_regular_node( self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node @@ -118,7 +119,7 @@ class TestObservabilityLayerParserIntegration: assert attrs["node.execution_id"] == mock_start_node.execution_id assert attrs["node.type"] == mock_start_node.node_type.value - 
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") def test_tool_parser_used_for_tool_node( self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_tool_node @@ -134,15 +135,107 @@ class TestObservabilityLayerParserIntegration: assert len(spans) == 1 attrs = spans[0].attributes assert attrs["node.id"] == mock_tool_node.id - assert attrs["tool.provider.id"] == mock_tool_node._node_data.provider_id - assert attrs["tool.provider.type"] == mock_tool_node._node_data.provider_type.value - assert attrs["tool.name"] == mock_tool_node._node_data.tool_name + assert attrs["gen_ai.tool.name"] == mock_tool_node.title + assert attrs["gen_ai.tool.type"] == mock_tool_node._node_data.provider_type.value + + @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_llm_parser_used_for_llm_node( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node, mock_result_event + ): + """Test that LLM parser is used for LLM nodes and extracts LLM-specific attributes.""" + from core.workflow.node_events.base import NodeRunResult + + mock_result_event.node_run_result = NodeRunResult( + inputs={}, + outputs={"text": "test completion", "finish_reason": "stop"}, + process_data={ + "model_name": "gpt-4", + "model_provider": "openai", + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + "prompts": [{"role": "user", "text": "test prompt"}], + }, + metadata={}, + ) + + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_llm_node) + layer.on_node_run_end(mock_llm_node, None, mock_result_event) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + attrs = spans[0].attributes + assert 
attrs["node.id"] == mock_llm_node.id + assert attrs["gen_ai.request.model"] == "gpt-4" + assert attrs["gen_ai.provider.name"] == "openai" + assert attrs["gen_ai.usage.input_tokens"] == 10 + assert attrs["gen_ai.usage.output_tokens"] == 20 + assert attrs["gen_ai.usage.total_tokens"] == 30 + assert attrs["gen_ai.completion"] == "test completion" + assert attrs["gen_ai.response.finish_reason"] == "stop" + + @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_retrieval_parser_used_for_retrieval_node( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_retrieval_node, mock_result_event + ): + """Test that retrieval parser is used for retrieval nodes and extracts retrieval-specific attributes.""" + from core.workflow.node_events.base import NodeRunResult + + mock_result_event.node_run_result = NodeRunResult( + inputs={"query": "test query"}, + outputs={"result": [{"content": "test content", "metadata": {"score": 0.9, "document_id": "doc1"}}]}, + process_data={}, + metadata={}, + ) + + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_retrieval_node) + layer.on_node_run_end(mock_retrieval_node, None, mock_result_event) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + attrs = spans[0].attributes + assert attrs["node.id"] == mock_retrieval_node.id + assert attrs["retrieval.query"] == "test query" + assert "retrieval.document" in attrs + + @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_result_event_extracts_inputs_and_outputs( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node, mock_result_event + ): + """Test that result_event parameter allows parsers to extract inputs and outputs.""" + from core.workflow.node_events.base import 
NodeRunResult + + mock_result_event.node_run_result = NodeRunResult( + inputs={"input_key": "input_value"}, + outputs={"output_key": "output_value"}, + process_data={}, + metadata={}, + ) + + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_start_node) + layer.on_node_run_end(mock_start_node, None, mock_result_event) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + attrs = spans[0].attributes + assert "input.value" in attrs + assert "output.value" in attrs class TestObservabilityLayerGraphLifecycle: """Test graph lifecycle management.""" - @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") def test_on_graph_start_clears_contexts(self, tracer_provider_with_memory_exporter, mock_llm_node): """Test that on_graph_start clears node contexts.""" @@ -155,7 +248,7 @@ class TestObservabilityLayerGraphLifecycle: layer.on_graph_start() assert len(layer._node_contexts) == 0 - @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") def test_on_graph_end_with_no_unfinished_spans( self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node @@ -171,7 +264,7 @@ class TestObservabilityLayerGraphLifecycle: spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 - @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") def test_on_graph_end_with_unfinished_spans_logs_warning( self, tracer_provider_with_memory_exporter, mock_llm_node, caplog @@ 
-192,7 +285,7 @@ class TestObservabilityLayerGraphLifecycle: class TestObservabilityLayerDisabledMode: """Test behavior when layer is disabled.""" - @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False) + @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False) @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") def test_disabled_mode_skips_node_start(self, memory_span_exporter, mock_start_node): """Test that disabled layer doesn't create spans on node start.""" @@ -206,7 +299,7 @@ class TestObservabilityLayerDisabledMode: spans = memory_span_exporter.get_finished_spans() assert len(spans) == 0 - @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False) + @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False) @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") def test_disabled_mode_skips_node_end(self, memory_span_exporter, mock_llm_node): """Test that disabled layer doesn't process node end.""" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py index c1fc4acd73..fe3ea576c1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py @@ -3,6 +3,7 @@ from __future__ import annotations import queue +import threading from unittest import mock from core.workflow.entities.pause_reason import SchedulingPause @@ -36,6 +37,7 @@ def test_dispatcher_should_consume_remains_events_after_pause(): event_queue=event_queue, event_handler=event_handler, execution_coordinator=execution_coordinator, + stop_event=threading.Event(), ) dispatcher._dispatcher_loop() assert event_queue.empty() @@ -96,6 +98,7 @@ def _run_dispatcher_for_event(event) -> int: event_queue=event_queue, 
event_handler=event_handler, execution_coordinator=coordinator, + stop_event=threading.Event(), ) dispatcher._dispatcher_loop() @@ -181,6 +184,7 @@ def test_dispatcher_drain_event_queue(): event_queue=event_queue, event_handler=event_handler, execution_coordinator=coordinator, + stop_event=threading.Event(), ) dispatcher._dispatcher_loop() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index b074a11be9..d826f7a900 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -4,12 +4,19 @@ import time from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom +from core.variables import IntegerVariable, StringVariable from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand +from core.workflow.graph_engine.entities.commands import ( + AbortCommand, + CommandType, + PauseCommand, + UpdateVariablesCommand, + VariableUpdate, +) from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent from core.workflow.nodes.start.start_node import StartNode from core.workflow.runtime import GraphRuntimeState, VariablePool @@ -180,3 +187,67 @@ def test_pause_command(): graph_execution = engine.graph_runtime_state.graph_execution assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")] + + +def test_update_variables_command_updates_pool(): + """Test that GraphEngine updates variable pool via update variables command.""" + + 
shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + shared_runtime_state.variable_pool.add(("node1", "foo"), "old value") + + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=shared_runtime_state, + ) + mock_graph.nodes["start"] = start_node + + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + command_channel = InMemoryChannel() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=shared_runtime_state, + command_channel=command_channel, + ) + + update_command = UpdateVariablesCommand( + updates=[ + VariableUpdate( + value=StringVariable(name="foo", value="new value", selector=["node1", "foo"]), + ), + VariableUpdate( + value=IntegerVariable(name="bar", value=123, selector=["node2", "bar"]), + ), + ] + ) + command_channel.send_command(update_command) + + list(engine.run()) + + updated_existing = shared_runtime_state.variable_pool.get(["node1", "foo"]) + added_new = shared_runtime_state.variable_pool.get(["node2", "bar"]) + + assert updated_existing is not None + assert updated_existing.value == "new value" + assert added_new is not None + assert added_new.value == 123 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py index b02f90588b..5ceb8dd7f7 100644 --- 
a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py @@ -5,6 +5,8 @@ This module provides a flexible configuration system for customizing the behavior of mock nodes during testing. """ +from __future__ import annotations + from collections.abc import Callable from dataclasses import dataclass, field from typing import Any @@ -95,67 +97,67 @@ class MockConfigBuilder: def __init__(self) -> None: self._config = MockConfig() - def with_auto_mock(self, enabled: bool = True) -> "MockConfigBuilder": + def with_auto_mock(self, enabled: bool = True) -> MockConfigBuilder: """Enable or disable auto-mocking.""" self._config.enable_auto_mock = enabled return self - def with_delays(self, enabled: bool = True) -> "MockConfigBuilder": + def with_delays(self, enabled: bool = True) -> MockConfigBuilder: """Enable or disable simulated execution delays.""" self._config.simulate_delays = enabled return self - def with_llm_response(self, response: str) -> "MockConfigBuilder": + def with_llm_response(self, response: str) -> MockConfigBuilder: """Set default LLM response.""" self._config.default_llm_response = response return self - def with_agent_response(self, response: str) -> "MockConfigBuilder": + def with_agent_response(self, response: str) -> MockConfigBuilder: """Set default agent response.""" self._config.default_agent_response = response return self - def with_tool_response(self, response: dict[str, Any]) -> "MockConfigBuilder": + def with_tool_response(self, response: dict[str, Any]) -> MockConfigBuilder: """Set default tool response.""" self._config.default_tool_response = response return self - def with_retrieval_response(self, response: str) -> "MockConfigBuilder": + def with_retrieval_response(self, response: str) -> MockConfigBuilder: """Set default retrieval response.""" self._config.default_retrieval_response = response return self - def with_http_response(self, response: dict[str, 
Any]) -> "MockConfigBuilder": + def with_http_response(self, response: dict[str, Any]) -> MockConfigBuilder: """Set default HTTP response.""" self._config.default_http_response = response return self - def with_template_transform_response(self, response: str) -> "MockConfigBuilder": + def with_template_transform_response(self, response: str) -> MockConfigBuilder: """Set default template transform response.""" self._config.default_template_transform_response = response return self - def with_code_response(self, response: dict[str, Any]) -> "MockConfigBuilder": + def with_code_response(self, response: dict[str, Any]) -> MockConfigBuilder: """Set default code execution response.""" self._config.default_code_response = response return self - def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> "MockConfigBuilder": + def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> MockConfigBuilder: """Set outputs for a specific node.""" self._config.set_node_outputs(node_id, outputs) return self - def with_node_error(self, node_id: str, error: str) -> "MockConfigBuilder": + def with_node_error(self, node_id: str, error: str) -> MockConfigBuilder: """Set error for a specific node.""" self._config.set_node_error(node_id, error) return self - def with_node_config(self, config: NodeMockConfig) -> "MockConfigBuilder": + def with_node_config(self, config: NodeMockConfig) -> MockConfigBuilder: """Add a node-specific configuration.""" self._config.set_node_config(config.node_id, config) return self - def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> "MockConfigBuilder": + def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> MockConfigBuilder: """Set default configuration for a node type.""" self._config.set_default_config(node_type, config) return self diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py 
index eeffdd27fe..170445225b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -7,9 +7,9 @@ requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request from typing import TYPE_CHECKING, Any +from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.enums import NodeType from core.workflow.nodes.base.node import Node -from core.workflow.nodes.node_factory import DifyNodeFactory from .test_mock_nodes import ( MockAgentNode, @@ -103,13 +103,25 @@ class MockNodeFactory(DifyNodeFactory): # Create mock node instance mock_class = self._mock_node_types[node_type] - mock_instance = mock_class( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - mock_config=self.mock_config, - ) + if node_type == NodeType.CODE: + mock_instance = mock_class( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + mock_config=self.mock_config, + code_executor=self._code_executor, + code_providers=self._code_providers, + code_limits=self._code_limits, + ) + else: + mock_instance = mock_class( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + mock_config=self.mock_config, + ) return mock_instance diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index fd94a5e833..5937bbfb39 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -40,12 +40,14 @@ class MockNodeMixin: graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", mock_config: Optional["MockConfig"] = None, + **kwargs: Any, ): super().__init__( 
id=id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, + **kwargs, ) self.mock_config = mock_config diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py index 4fb693a5c2..de08cc3497 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py @@ -5,11 +5,24 @@ This module tests the functionality of MockTemplateTransformNode and MockCodeNod to ensure they work correctly with the TableTestRunner. """ +from configs import dify_config from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.nodes.code.limits import CodeNodeLimits from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockCodeNode, MockTemplateTransformNode +DEFAULT_CODE_LIMITS = CodeNodeLimits( + max_string_length=dify_config.CODE_MAX_STRING_LENGTH, + max_number=dify_config.CODE_MAX_NUMBER, + min_number=dify_config.CODE_MIN_NUMBER, + max_precision=dify_config.CODE_MAX_PRECISION, + max_depth=dify_config.CODE_MAX_DEPTH, + max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, + max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, + max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, +) + class TestMockTemplateTransformNode: """Test cases for MockTemplateTransformNode.""" @@ -306,6 +319,7 @@ class TestMockCodeNode: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, mock_config=mock_config, + code_limits=DEFAULT_CODE_LIMITS, ) # Run the node @@ -370,6 +384,7 @@ class TestMockCodeNode: 
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, mock_config=mock_config, + code_limits=DEFAULT_CODE_LIMITS, ) # Run the node @@ -438,6 +453,7 @@ class TestMockCodeNode: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, mock_config=mock_config, + code_limits=DEFAULT_CODE_LIMITS, ) # Run the node diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index b76fe42fce..e8cd665107 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -13,6 +13,7 @@ from unittest.mock import patch from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.graph import Graph @@ -26,7 +27,6 @@ from core.workflow.graph_events import ( ) from core.workflow.node_events import NodeRunResult, StreamCompletedEvent from core.workflow.nodes.llm.node import LLMNode -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from models.enums import UserFrom diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py b/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py new file mode 100644 index 0000000000..ea8d3a977f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py @@ -0,0 +1,539 @@ +""" +Unit tests for stop_event functionality in GraphEngine. + +Tests the unified stop_event management by GraphEngine and its propagation +to WorkerPool, Worker, Dispatcher, and Nodes. 
+""" + +import threading +import time +from unittest.mock import MagicMock, Mock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.graph_init_params import GraphInitParams +from core.workflow.graph import Graph +from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine.command_channels import InMemoryChannel +from core.workflow.graph_events import ( + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunStartedEvent, +) +from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.runtime import GraphRuntimeState, VariablePool +from models.enums import UserFrom + + +class TestStopEventPropagation: + """Test suite for stop_event propagation through GraphEngine components.""" + + def test_graph_engine_creates_stop_event(self): + """Test that GraphEngine creates a stop_event on initialization.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Verify stop_event was created + assert engine._stop_event is not None + assert isinstance(engine._stop_event, threading.Event) + + # Verify it was set in graph_runtime_state + assert runtime_state.stop_event is not None + assert runtime_state.stop_event is engine._stop_event + + def test_stop_event_cleared_on_start(self): + """Test that stop_event is cleared when execution starts.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" # Set proper 
id + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes["start"] = start_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Set the stop_event before running + engine._stop_event.set() + assert engine._stop_event.is_set() + + # Run the engine (should clear the stop_event) + events = list(engine.run()) + + # After running, stop_event should be set again (by _stop_execution) + # But during start it was cleared + assert any(isinstance(e, GraphRunStartedEvent) for e in events) + assert any(isinstance(e, GraphRunSucceededEvent) for e in events) + + def test_stop_event_set_on_stop(self): + """Test that stop_event is set when execution stops.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" # Set proper id + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes["start"] = start_node + mock_graph.get_outgoing_edges = 
MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Initially not set + assert not engine._stop_event.is_set() + + # Run the engine + list(engine.run()) + + # After execution completes, stop_event should be set + assert engine._stop_event.is_set() + + def test_stop_event_passed_to_worker_pool(self): + """Test that stop_event is passed to WorkerPool.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Verify WorkerPool has the stop_event + assert engine._worker_pool._stop_event is not None + assert engine._worker_pool._stop_event is engine._stop_event + + def test_stop_event_passed_to_dispatcher(self): + """Test that stop_event is passed to Dispatcher.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Verify Dispatcher has the stop_event + assert engine._dispatcher._stop_event is not None + assert engine._dispatcher._stop_event is engine._stop_event + + +class TestNodeStopCheck: + """Test suite for Node._should_stop() functionality.""" + + def test_node_should_stop_checks_runtime_state(self): + """Test that Node._should_stop() checks GraphRuntimeState.stop_event.""" + runtime_state = 
GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + answer_node = AnswerNode( + id="answer", + config={"id": "answer", "data": {"title": "answer", "answer": "{{#start.result#}}"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + # Initially stop_event is not set + assert not answer_node._should_stop() + + # Set the stop_event + runtime_state.stop_event.set() + + # Now _should_stop should return True + assert answer_node._should_stop() + + def test_node_run_checks_stop_event_between_yields(self): + """Test that Node.run() checks stop_event between yielding events.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + # Create a simple node + answer_node = AnswerNode( + id="answer", + config={"id": "answer", "data": {"title": "answer", "answer": "hello"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + # Set stop_event BEFORE running the node + runtime_state.stop_event.set() + + # Run the node - should yield start event then detect stop + # The node should check stop_event before processing + assert answer_node._should_stop(), "stop_event should be set" + + # Run and collect events + events = list(answer_node.run()) + + # Since stop_event is set at the start, we should get: + # 1. NodeRunStartedEvent (always yielded first) + # 2. 
Either NodeRunFailedEvent (if detected early) or NodeRunSucceededEvent (if too fast) + assert len(events) >= 2 + assert isinstance(events[0], NodeRunStartedEvent) + + # Note: AnswerNode is very simple and might complete before stop check + # The important thing is that _should_stop() returns True when stop_event is set + assert answer_node._should_stop() + + +class TestStopEventIntegration: + """Integration tests for stop_event in workflow execution.""" + + def test_simple_workflow_respects_stop_event(self): + """Test that a simple workflow respects stop_event.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" + + # Create start and answer nodes + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + answer_node = AnswerNode( + id="answer", + config={"id": "answer", "data": {"title": "answer", "answer": "hello"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + mock_graph.nodes["start"] = start_node + mock_graph.nodes["answer"] = answer_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + 
command_channel=InMemoryChannel(), + ) + + # Set stop_event before running + runtime_state.stop_event.set() + + # Run the engine + events = list(engine.run()) + + # Should get started event but not succeeded (due to stop) + assert any(isinstance(e, GraphRunStartedEvent) for e in events) + # The workflow should still complete (start node runs quickly) + # but answer node might be cancelled depending on timing + + def test_stop_event_with_concurrent_nodes(self): + """Test stop_event behavior with multiple concurrent nodes.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + # Create multiple nodes + for i in range(3): + answer_node = AnswerNode( + id=f"answer_{i}", + config={"id": f"answer_{i}", "data": {"title": f"answer_{i}", "answer": f"test{i}"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes[f"answer_{i}"] = answer_node + + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # All nodes should share the same stop_event + for node in mock_graph.nodes.values(): + assert node.graph_runtime_state.stop_event is runtime_state.stop_event + assert node.graph_runtime_state.stop_event is engine._stop_event + + +class TestStopEventTimeoutBehavior: + """Test stop_event behavior with join timeouts.""" + + @patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread") + def 
test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock): + """Test that Dispatcher uses 2s timeout instead of 10s.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + dispatcher = engine._dispatcher + dispatcher.start() # This will create and start the mocked thread + + mock_thread_instance = mock_thread_cls.return_value + mock_thread_instance.is_alive.return_value = True + + dispatcher.stop() + + mock_thread_instance.join.assert_called_once_with(timeout=2.0) + + @patch("core.workflow.graph_engine.worker_management.worker_pool.Worker") + def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock): + """Test that WorkerPool uses 2s timeout instead of 10s.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + worker_pool = engine._worker_pool + worker_pool.start(initial_count=1) # Start with one worker + + mock_worker_instance = mock_worker_cls.return_value + mock_worker_instance.is_alive.return_value = True + + worker_pool.stop() + + mock_worker_instance.join.assert_called_once_with(timeout=2.0) + + +class TestStopEventResumeBehavior: + """Test stop_event behavior during workflow resume.""" + + def test_stop_event_cleared_on_resume(self): + """Test that stop_event is cleared when resuming a paused workflow.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), 
start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" # Set proper id + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes["start"] = start_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Simulate a previous execution that set stop_event + engine._stop_event.set() + assert engine._stop_event.is_set() + + # Run the engine (should clear stop_event in _start_execution) + events = list(engine.run()) + + # Execution should complete successfully + assert any(isinstance(e, GraphRunStartedEvent) for e in events) + assert any(isinstance(e, GraphRunSucceededEvent) for e in events) + + +class TestWorkerStopBehavior: + """Test Worker behavior with shared stop_event.""" + + def test_worker_uses_shared_stop_event(self): + """Test that Worker uses shared stop_event from GraphEngine.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Get the worker pool and check workers + worker_pool = engine._worker_pool + + # Start the worker 
pool to create workers + worker_pool.start() + + # Check that at least one worker was created + assert len(worker_pool._workers) > 0 + + # Verify workers use the shared stop_event + for worker in worker_pool._workers: + assert worker._stop_event is engine._stop_event + + # Clean up + worker_pool.stop() + + def test_worker_stop_is_noop(self): + """Test that Worker.stop() is now a no-op.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + # Create a mock worker + from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue + from core.workflow.graph_engine.worker import Worker + + ready_queue = InMemoryReadyQueue() + event_queue = MagicMock() + + # Create a proper mock graph with real dict + mock_graph = Mock(spec=Graph) + mock_graph.nodes = {} # Use real dict + + stop_event = threading.Event() + + worker = Worker( + ready_queue=ready_queue, + event_queue=event_queue, + graph=mock_graph, + layers=[], + stop_event=stop_event, + ) + + # Calling stop() should do nothing (no-op) + # and should NOT set the stop_event + worker.stop() + assert not stop_event.is_set() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py index 1f4c063bf0..99157a7c3e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py @@ -45,3 +45,33 @@ def test_streaming_conversation_variables(): runner = TableTestRunner() result = runner.run_test_case(case) assert result.success, f"Test failed: {result.error}" + + +def test_streaming_conversation_variables_v1_overwrite_waits_for_assignment(): + fixture_name = "test_streaming_conversation_variables_v1_overwrite" + input_query = "overwrite-value" + + case = WorkflowTestCase( + fixture_path=fixture_name, + 
+        use_auto_mock=False,
+        mock_config=MockConfigBuilder().build(),
+        query=input_query,
+        inputs={},
+        expected_outputs={"answer": f"Current Value Of `conv_var` is:{input_query}"},
+    )
+
+    runner = TableTestRunner()
+    result = runner.run_test_case(case)
+    assert result.success, f"Test failed: {result.error}"
+
+    events = result.events
+    conv_var_chunk_events = [
+        event
+        for event in events
+        if isinstance(event, NodeRunStreamChunkEvent) and tuple(event.selector) == ("conversation", "conv_var")
+    ]
+
+    assert conv_var_chunk_events, "Expected conversation variable chunk events to be emitted"
+    assert all(event.chunk == input_query for event in conv_var_chunk_events), (
+        "Expected streamed conversation variable value to match the input query"
+    )
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py
index 08f7b00a33..10ac1206fb 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py
@@ -19,6 +19,7 @@
 from functools import lru_cache
 from pathlib import Path
 from typing import Any
 
+from core.app.workflow.node_factory import DifyNodeFactory
 from core.tools.utils.yaml_utils import _load_yaml_file
 from core.variables import (
     ArrayNumberVariable,
@@ -38,7 +39,6 @@ from core.workflow.graph_events import (
     GraphRunStartedEvent,
     GraphRunSucceededEvent,
 )
-from core.workflow.nodes.node_factory import DifyNodeFactory
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
index 98d9560e64..1e95ec1970 100644
--- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
+++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
@@ -3,11 +3,11 @@ import uuid
 from unittest.mock import MagicMock
 
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.workflow.node_factory import DifyNodeFactory
 from core.workflow.entities import GraphInitParams
 from core.workflow.enums import WorkflowNodeExecutionStatus
 from core.workflow.graph import Graph
 from core.workflow.nodes.answer.answer_node import AnswerNode
-from core.workflow.nodes.node_factory import DifyNodeFactory
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
 from extensions.ext_database import db
diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py
index 596e72ddd0..2262d25a14 100644
--- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py
+++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py
@@ -1,3 +1,4 @@
+from configs import dify_config
 from core.helper.code_executor.code_executor import CodeLanguage
 from core.variables.types import SegmentType
 from core.workflow.nodes.code.code_node import CodeNode
@@ -7,6 +8,18 @@ from core.workflow.nodes.code.exc import (
     DepthLimitError,
     OutputValidationError,
 )
+from core.workflow.nodes.code.limits import CodeNodeLimits
+
+CodeNode._limits = CodeNodeLimits(
+    max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
+    max_number=dify_config.CODE_MAX_NUMBER,
+    min_number=dify_config.CODE_MIN_NUMBER,
+    max_precision=dify_config.CODE_MAX_PRECISION,
+    max_depth=dify_config.CODE_MAX_DEPTH,
+    max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
+    max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
+    max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
+)
 
 
 class TestCodeNodeExceptions:
diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
index 27df938102..cefc4967ac 100644
--- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
+++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
@@ -16,7 +16,7 @@ from core.workflow.system_variable import SystemVariable
 def test_executor_with_json_body_and_number_variable():
     # Prepare the variable pool
     variable_pool = VariablePool(
-        system_variables=SystemVariable.empty(),
+        system_variables=SystemVariable.default(),
         user_inputs={},
     )
     variable_pool.add(["pre_node_id", "number"], 42)
@@ -69,7 +69,7 @@ def test_executor_with_json_body_and_object_variable():
     # Prepare the variable pool
     variable_pool = VariablePool(
-        system_variables=SystemVariable.empty(),
+        system_variables=SystemVariable.default(),
         user_inputs={},
     )
     variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
@@ -124,7 +124,7 @@ def test_executor_with_json_body_and_nested_object_variable():
     # Prepare the variable pool
     variable_pool = VariablePool(
-        system_variables=SystemVariable.empty(),
+        system_variables=SystemVariable.default(),
         user_inputs={},
     )
     variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
@@ -178,7 +178,7 @@ def test_extract_selectors_from_template_with_newline():
-    variable_pool = VariablePool(system_variables=SystemVariable.empty())
+    variable_pool = VariablePool(system_variables=SystemVariable.default())
     variable_pool.add(("node_id", "custom_query"), "line1\nline2")
     node_data = HttpRequestNodeData(
         title="Test JSON Body with Nested Object Variable",
@@ -205,7 +205,7 @@ def test_executor_with_form_data():
     # Prepare the variable pool
     variable_pool = VariablePool(
-        system_variables=SystemVariable.empty(),
+        system_variables=SystemVariable.default(),
         user_inputs={},
     )
     variable_pool.add(["pre_node_id", "text_field"], "Hello, World!")
@@ -290,7 +290,7 @@ def test_init_headers():
         return Executor(
             node_data=node_data,
             timeout=timeout,
-            variable_pool=VariablePool(system_variables=SystemVariable.empty()),
+            variable_pool=VariablePool(system_variables=SystemVariable.default()),
         )
 
     executor = create_executor("aa\n cc:")
@@ -324,7 +324,7 @@ def test_init_params():
         return Executor(
             node_data=node_data,
             timeout=timeout,
-            variable_pool=VariablePool(system_variables=SystemVariable.empty()),
+            variable_pool=VariablePool(system_variables=SystemVariable.default()),
         )
 
     # Test basic key-value pairs
@@ -355,7 +355,7 @@
 def test_empty_api_key_raises_error_bearer():
     """Test that empty API key raises AuthorizationConfigError for bearer auth."""
-    variable_pool = VariablePool(system_variables=SystemVariable.empty())
+    variable_pool = VariablePool(system_variables=SystemVariable.default())
     node_data = HttpRequestNodeData(
         title="test",
         method="get",
@@ -379,7 +379,7 @@
 def test_empty_api_key_raises_error_basic():
     """Test that empty API key raises AuthorizationConfigError for basic auth."""
-    variable_pool = VariablePool(system_variables=SystemVariable.empty())
+    variable_pool = VariablePool(system_variables=SystemVariable.default())
     node_data = HttpRequestNodeData(
         title="test",
         method="get",
@@ -403,7 +403,7 @@
 def test_empty_api_key_raises_error_custom():
     """Test that empty API key raises AuthorizationConfigError for custom auth."""
-    variable_pool = VariablePool(system_variables=SystemVariable.empty())
+    variable_pool = VariablePool(system_variables=SystemVariable.default())
     node_data = HttpRequestNodeData(
         title="test",
         method="get",
@@ -427,7 +427,7 @@
 def test_whitespace_only_api_key_raises_error():
     """Test that whitespace-only API key raises AuthorizationConfigError."""
-    variable_pool = VariablePool(system_variables=SystemVariable.empty())
+    variable_pool = VariablePool(system_variables=SystemVariable.default())
     node_data = HttpRequestNodeData(
         title="test",
         method="get",
@@ -451,7 +451,7 @@
 def test_valid_api_key_works():
     """Test that valid API key works correctly for bearer auth."""
-    variable_pool = VariablePool(system_variables=SystemVariable.empty())
+    variable_pool = VariablePool(system_variables=SystemVariable.default())
     node_data = HttpRequestNodeData(
         title="test",
         method="get",
@@ -475,3 +475,130 @@ def test_valid_api_key_works():
     headers = executor._assembling_headers()
     assert "Authorization" in headers
     assert headers["Authorization"] == "Bearer valid-api-key-123"
+
+
+def test_executor_with_json_body_and_unquoted_uuid_variable():
+    """Test that unquoted UUID variables are correctly handled in JSON body.
+
+    This test verifies the fix for issue #31436 where json_repair would truncate
+    certain UUID patterns (like 57eeeeb1-...) when they appeared as unquoted values.
+    """
+    # UUID that triggers the json_repair truncation bug
+    test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
+
+    variable_pool = VariablePool(
+        system_variables=SystemVariable.default(),
+        user_inputs={},
+    )
+    variable_pool.add(["pre_node_id", "uuid"], test_uuid)
+
+    node_data = HttpRequestNodeData(
+        title="Test JSON Body with Unquoted UUID Variable",
+        method="post",
+        url="https://api.example.com/data",
+        authorization=HttpRequestNodeAuthorization(type="no-auth"),
+        headers="Content-Type: application/json",
+        params="",
+        body=HttpRequestNodeBody(
+            type="json",
+            data=[
+                BodyData(
+                    key="",
+                    type="text",
+                    # UUID variable without quotes - this is the problematic case
+                    value='{"rowId": {{#pre_node_id.uuid#}}}',
+                )
+            ],
+        ),
+    )
+
+    executor = Executor(
+        node_data=node_data,
+        timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
+        variable_pool=variable_pool,
+    )
+
+    # The UUID should be preserved in full, not truncated
+    assert executor.json == {"rowId": test_uuid}
+    assert len(executor.json["rowId"]) == len(test_uuid)
+
+
+def test_executor_with_json_body_and_unquoted_uuid_with_newlines():
+    """Test that unquoted UUID variables with newlines in JSON are handled correctly.
+
+    This is a specific case from issue #31436 where the JSON body contains newlines.
+    """
+    test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
+
+    variable_pool = VariablePool(
+        system_variables=SystemVariable.default(),
+        user_inputs={},
+    )
+    variable_pool.add(["pre_node_id", "uuid"], test_uuid)
+
+    node_data = HttpRequestNodeData(
+        title="Test JSON Body with Unquoted UUID and Newlines",
+        method="post",
+        url="https://api.example.com/data",
+        authorization=HttpRequestNodeAuthorization(type="no-auth"),
+        headers="Content-Type: application/json",
+        params="",
+        body=HttpRequestNodeBody(
+            type="json",
+            data=[
+                BodyData(
+                    key="",
+                    type="text",
+                    # JSON with newlines and unquoted UUID variable
+                    value='{\n"rowId": {{#pre_node_id.uuid#}}\n}',
+                )
+            ],
+        ),
+    )
+
+    executor = Executor(
+        node_data=node_data,
+        timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
+        variable_pool=variable_pool,
+    )
+
+    # The UUID should be preserved in full
+    assert executor.json == {"rowId": test_uuid}
+
+
+def test_executor_with_json_body_preserves_numbers_and_strings():
+    """Test that numbers are preserved and string values are properly quoted."""
+    variable_pool = VariablePool(
+        system_variables=SystemVariable.default(),
+        user_inputs={},
+    )
+    variable_pool.add(["node", "count"], 42)
+    variable_pool.add(["node", "id"], "abc-123")
+
+    node_data = HttpRequestNodeData(
+        title="Test JSON Body with mixed types",
+        method="post",
+        url="https://api.example.com/data",
+        authorization=HttpRequestNodeAuthorization(type="no-auth"),
+        headers="",
+        params="",
+        body=HttpRequestNodeBody(
+            type="json",
+            data=[
+                BodyData(
+                    key="",
+                    type="text",
+                    value='{"count": {{#node.count#}}, "id": {{#node.id#}}}',
+                )
+            ],
+        ),
+    )
+
+    executor = Executor(
+        node_data=node_data,
+        timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
+        variable_pool=variable_pool,
+    )
+
+    assert executor.json["count"] == 42
+    assert executor.json["id"] == "abc-123"
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py
index e8f257bf2f..1e224d56a5 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py
@@ -78,7 +78,7 @@ class TestFileSaverImpl:
             file_binary=_PNG_DATA,
             mimetype=mime_type,
         )
-        mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
+        mocked_sign_file.assert_called_once_with(tool_file_id=mock_tool_file.id, extension=".png", for_external=True)
 
     def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch):
         _TEST_URL = "https://example.com/image.png"
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index 77264022bc..3d1b8b2f27 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -86,7 +86,7 @@ def graph_init_params() -> GraphInitParams:
 @pytest.fixture
 def graph_runtime_state() -> GraphRuntimeState:
     variable_pool = VariablePool(
-        system_variables=SystemVariable.empty(),
+        system_variables=SystemVariable.default(),
         user_inputs={},
     )
     return GraphRuntimeState(
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
index 1a67d5c3e3..66d6c3c56b 100644
--- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
+++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
@@ -5,8 +5,8 @@ from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 
-from core.helper.code_executor.code_executor import CodeExecutionError
 from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
+from core.workflow.nodes.template_transform.template_renderer import TemplateRenderError
 from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
 from models.workflow import WorkflowType
@@ -127,7 +127,9 @@ class TestTemplateTransformNode:
         """Test version class method."""
         assert TemplateTransformNode.version() == "1"
 
-    @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+    @patch(
+        "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
+    )
     def test_run_simple_template(
         self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
     ):
@@ -145,7 +147,7 @@ class TestTemplateTransformNode:
         mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
 
         # Setup mock executor
-        mock_execute.return_value = {"result": "Hello Alice, you are 30 years old!"}
+        mock_execute.return_value = "Hello Alice, you are 30 years old!"
 
         node = TemplateTransformNode(
             id="test_node",
@@ -162,7 +164,9 @@ class TestTemplateTransformNode:
         assert result.inputs["name"] == "Alice"
         assert result.inputs["age"] == 30
 
-    @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+    @patch(
+        "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
+    )
     def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
         """Test _run with None variable values."""
         node_data = {
@@ -172,7 +176,7 @@ class TestTemplateTransformNode:
         }
 
         mock_graph_runtime_state.variable_pool.get.return_value = None
-        mock_execute.return_value = {"result": "Value: "}
+        mock_execute.return_value = "Value: "
 
         node = TemplateTransformNode(
             id="test_node",
@@ -187,13 +191,15 @@ class TestTemplateTransformNode:
         assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
         assert result.inputs["value"] is None
 
-    @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+    @patch(
+        "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
+    )
     def test_run_with_code_execution_error(
         self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
     ):
         """Test _run when code execution fails."""
         mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
-        mock_execute.side_effect = CodeExecutionError("Template syntax error")
+        mock_execute.side_effect = TemplateRenderError("Template syntax error")
 
         node = TemplateTransformNode(
             id="test_node",
@@ -208,14 +214,16 @@ class TestTemplateTransformNode:
         assert result.status == WorkflowNodeExecutionStatus.FAILED
         assert "Template syntax error" in result.error
 
-    @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+    @patch(
+        "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
+    )
     @patch("core.workflow.nodes.template_transform.template_transform_node.MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH", 10)
     def test_run_output_length_exceeds_limit(
         self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
     ):
         """Test _run when output exceeds maximum length."""
         mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
-        mock_execute.return_value = {"result": "This is a very long output that exceeds the limit"}
+        mock_execute.return_value = "This is a very long output that exceeds the limit"
 
         node = TemplateTransformNode(
             id="test_node",
@@ -230,7 +238,9 @@ class TestTemplateTransformNode:
         assert result.status == WorkflowNodeExecutionStatus.FAILED
         assert "Output length exceeds" in result.error
 
-    @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+    @patch(
+        "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
+    )
     def test_run_with_complex_jinja2_template(
         self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params
     ):
@@ -257,7 +267,7 @@ class TestTemplateTransformNode:
             ("sys", "show_total"): mock_show_total,
         }
         mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
-        mock_execute.return_value = {"result": "apple, banana, orange (Total: 3)"}
+        mock_execute.return_value = "apple, banana, orange (Total: 3)"
 
         node = TemplateTransformNode(
             id="test_node",
@@ -292,7 +302,9 @@ class TestTemplateTransformNode:
         assert mapping["node_123.var1"] == ["sys", "input1"]
         assert mapping["node_123.var2"] == ["sys", "input2"]
 
-    @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+    @patch(
+        "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
+    )
     def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
         """Test _run with no variables (static template)."""
         node_data = {
@@ -301,7 +313,7 @@ class TestTemplateTransformNode:
             "template": "This is a static message.",
         }
 
-        mock_execute.return_value = {"result": "This is a static message."}
+        mock_execute.return_value = "This is a static message."
 
         node = TemplateTransformNode(
             id="test_node",
@@ -317,7 +329,9 @@ class TestTemplateTransformNode:
         assert result.outputs["output"] == "This is a static message."
         assert result.inputs == {}
 
-    @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+    @patch(
+        "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
+    )
     def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
         """Test _run with numeric variable values."""
         node_data = {
@@ -339,7 +353,7 @@ class TestTemplateTransformNode:
             ("sys", "quantity"): mock_quantity,
         }
         mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
-        mock_execute.return_value = {"result": "Total: $31.5"}
+        mock_execute.return_value = "Total: $31.5"
 
         node = TemplateTransformNode(
             id="test_node",
@@ -354,7 +368,9 @@ class TestTemplateTransformNode:
         assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
         assert result.outputs["output"] == "Total: $31.5"
 
-    @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+    @patch(
+        "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
+    )
     def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
         """Test _run with dictionary variable values."""
         node_data = {
@@ -367,7 +383,7 @@ class TestTemplateTransformNode:
         mock_user.to_object.return_value = {"name": "John Doe", "email": "john@example.com"}
 
         mock_graph_runtime_state.variable_pool.get.return_value = mock_user
-        mock_execute.return_value = {"result": "Name: John Doe, Email: john@example.com"}
+        mock_execute.return_value = "Name: John Doe, Email: john@example.com"
 
         node = TemplateTransformNode(
             id="test_node",
@@ -383,7 +399,9 @@ class TestTemplateTransformNode:
         assert "John Doe" in result.outputs["output"]
         assert "john@example.com" in result.outputs["output"]
 
-    @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+    @patch(
+        "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
+    )
     def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
         """Test _run with list variable values."""
         node_data = {
@@ -396,7 +414,7 @@ class TestTemplateTransformNode:
         mock_tags.to_object.return_value = ["python", "ai", "workflow"]
 
         mock_graph_runtime_state.variable_pool.get.return_value = mock_tags
-        mock_execute.return_value = {"result": "Tags: #python #ai #workflow "}
+        mock_execute.return_value = "Tags: #python #ai #workflow "
 
         node = TemplateTransformNode(
             id="test_node",
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
index dc7175f964..d700888c2f 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
@@ -5,6 +5,7 @@ from unittest.mock import MagicMock, Mock
 import pytest
 
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.workflow.node_factory import DifyNodeFactory
 from core.file import File, FileTransferMethod, FileType
 from core.variables import ArrayFileSegment
 from core.workflow.entities import GraphInitParams
@@ -12,7 +13,6 @@ from core.workflow.enums import WorkflowNodeExecutionStatus
 from core.workflow.graph import Graph
 from core.workflow.nodes.if_else.entities import IfElseNodeData
 from core.workflow.nodes.if_else.if_else_node import IfElseNode
-from core.workflow.nodes.node_factory import DifyNodeFactory
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
 from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
index 539e72edb5..16b432bae6 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
@@ -58,6 +58,8 @@
         }
     )
 
+    schema = json.loads(schema)
+
     variables = [
         VariableEntity(
             variable="profile",
@@ -68,7 +70,7 @@
         )
     ]
 
-    user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})}
+    user_inputs = {"profile": {"age": 20, "name": "Tom"}}
 
     node = make_start_node(user_inputs, variables)
     result = node._run()
@@ -87,6 +89,8 @@
             "required": ["age", "name"],
         }
     )
+
+    schema = json.loads(schema)
 
     variables = [
         VariableEntity(
             variable="profile",
@@ -97,12 +101,12 @@
         )
     ]
 
-    # Missing closing brace makes this invalid JSON
+    # Providing a string instead of an object should raise a type error
     user_inputs = {"profile": '{"age": 20, "name": "Tom"'}
 
     node = make_start_node(user_inputs, variables)
 
-    with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'):
+    with pytest.raises(ValueError, match="JSON object for 'profile' must be an object"):
         node._run()
@@ -118,6 +122,8 @@
         }
     )
 
+    schema = json.loads(schema)
+
    variables = [
         VariableEntity(
             variable="profile",
@@ -129,7 +135,7 @@
         )
     ]
 
     # age is a string, which violates the schema (expects number)
-    user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})}
+    user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}
 
     node = make_start_node(user_inputs, variables)
@@ -149,6 +155,8 @@
         }
     )
 
+    schema = json.loads(schema)
+
     variables = [
         VariableEntity(
             variable="profile",
@@ -160,7 +168,7 @@
         )
     ]
 
     # Missing required field "name"
-    user_inputs = {"profile": json.dumps({"age": 20})}
+    user_inputs = {"profile": {"age": 20}}
 
     node = make_start_node(user_inputs, variables)
diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
index 09b8191870..06927cddcf 100644
--- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import sys
 import types
 from collections.abc import Generator
@@ -21,7 +23,7 @@ if TYPE_CHECKING:  # pragma: no cover - imported for type checking only
 
 
 @pytest.fixture
-def tool_node(monkeypatch) -> "ToolNode":
+def tool_node(monkeypatch) -> ToolNode:
     module_name = "core.ops.ops_trace_manager"
     if module_name not in sys.modules:
         ops_stub = types.ModuleType(module_name)
@@ -85,7 +87,7 @@ def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]:
     return events, stop.value
 
 
-def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]:
+def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]:
     def _identity_transform(messages, *_args, **_kwargs):
         return messages
@@ -103,7 +105,7 @@ def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[l
     return _collect_events(generator)
 
 
-def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"):
+def test_link_messages_with_file_populate_files_output(tool_node: ToolNode):
     file_obj = File(
         tenant_id="tenant-id",
         type=FileType.DOCUMENT,
@@ -139,7 +141,7 @@ def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"):
     assert files_segment.value == [file_obj]
 
 
-def test_plain_link_messages_remain_links(tool_node: "ToolNode"):
+def test_plain_link_messages_remain_links(tool_node: ToolNode):
     message = ToolInvokeMessage(
         type=ToolInvokeMessage.MessageType.LINK,
         message=ToolInvokeMessage.TextMessage(text="https://dify.ai"),
diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py
index c62fc4d8fe..d4b7a017f9 100644
--- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py
+++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py
@@ -1,14 +1,14 @@
 import time
 import uuid
-from unittest import mock
 from uuid import uuid4
 
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.workflow.node_factory import DifyNodeFactory
 from core.variables import ArrayStringVariable, StringVariable
-from core.workflow.conversation_variable_updater import ConversationVariableUpdater
 from core.workflow.entities import GraphInitParams
 from core.workflow.graph import Graph
-from core.workflow.nodes.node_factory import DifyNodeFactory
+from core.workflow.graph_events.node import NodeRunSucceededEvent
+from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
 from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
 from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
 from core.workflow.runtime import GraphRuntimeState, VariablePool
@@ -86,9 +86,6 @@ def test_overwrite_string_variable():
     )
     graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
 
-    mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
-    mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
-
     node_config = {
         "id": "node_id",
         "data": {
@@ -104,20 +101,14 @@ def test_overwrite_string_variable():
         graph_init_params=init_params,
         graph_runtime_state=graph_runtime_state,
         config=node_config,
-        conv_var_updater_factory=mock_conv_var_updater_factory,
     )
 
-    list(node.run())
-    expected_var = StringVariable(
-        id=conversation_variable.id,
-        name=conversation_variable.name,
-        description=conversation_variable.description,
-        selector=conversation_variable.selector,
-        value_type=conversation_variable.value_type,
-        value=input_variable.value,
-    )
-    mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
-    mock_conv_var_updater.flush.assert_called_once()
+    events = list(node.run())
+    succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
+    updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
+    assert updated_variables is not None
+    assert updated_variables[0].name == conversation_variable.name
+    assert updated_variables[0].new_value == input_variable.value
 
     got = variable_pool.get(["conversation", conversation_variable.name])
     assert got is not None
@@ -191,9 +182,6 @@ def test_append_variable_to_array():
     )
     graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
 
-    mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
-    mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
-
     node_config = {
         "id": "node_id",
         "data": {
@@ -209,22 +197,14 @@ def test_append_variable_to_array():
         graph_init_params=init_params,
         graph_runtime_state=graph_runtime_state,
         config=node_config,
-        conv_var_updater_factory=mock_conv_var_updater_factory,
     )
 
-    list(node.run())
-    expected_value = list(conversation_variable.value)
-    expected_value.append(input_variable.value)
-    expected_var = ArrayStringVariable(
-        id=conversation_variable.id,
-        name=conversation_variable.name,
-        description=conversation_variable.description,
-        selector=conversation_variable.selector,
-        value_type=conversation_variable.value_type,
-        value=expected_value,
-    )
-    mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
-    mock_conv_var_updater.flush.assert_called_once()
+    events = list(node.run())
+    succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
+    updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
+    assert updated_variables is not None
+    assert updated_variables[0].name == conversation_variable.name
+    assert updated_variables[0].new_value == ["the first value", "the second value"]
 
     got = variable_pool.get(["conversation", conversation_variable.name])
     assert got is not None
@@ -287,9 +267,6 @@ def test_clear_array():
     )
     graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
 
-    mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
-    mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
-
     node_config = {
         "id": "node_id",
         "data": {
@@ -305,20 +282,14 @@ def test_clear_array():
         graph_init_params=init_params,
         graph_runtime_state=graph_runtime_state,
         config=node_config,
-        conv_var_updater_factory=mock_conv_var_updater_factory,
     )
 
-    list(node.run())
-    expected_var = ArrayStringVariable(
-        id=conversation_variable.id,
-        name=conversation_variable.name,
-        description=conversation_variable.description,
-        selector=conversation_variable.selector,
-        value_type=conversation_variable.value_type,
-        value=[],
-    )
-    mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
-    mock_conv_var_updater.flush.assert_called_once()
+    events = list(node.run())
+    succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
+    updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
+    assert updated_variables is not None
+    assert updated_variables[0].name == conversation_variable.name
+    assert updated_variables[0].new_value == []
 
     got = variable_pool.get(["conversation", conversation_variable.name])
     assert got is not None
diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py
index caa36734ad..b08f9c37b4 100644
--- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py
+++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py
@@ -3,10 +3,10 @@ import uuid
 from uuid import uuid4
 
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.workflow.node_factory import DifyNodeFactory
 from core.variables import ArrayStringVariable
 from core.workflow.entities import GraphInitParams
 from core.workflow.graph import Graph
-from core.workflow.nodes.node_factory import DifyNodeFactory
 from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode
 from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation
 from core.workflow.runtime import GraphRuntimeState, VariablePool
@@ -390,3 +390,42 @@ def test_remove_last_from_empty_array():
     got = variable_pool.get(["conversation", conversation_variable.name])
     assert got is not None
     assert got.to_object() == []
+
+
+def test_node_factory_creates_variable_assigner_node():
+    graph_config = {
+        "edges": [],
+        "nodes": [
+            {
+                "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []},
+                "id": "assigner",
+            },
+        ],
+    }
+
+    init_params = GraphInitParams(
+        tenant_id="1",
+        app_id="1",
+        workflow_id="1",
+        graph_config=graph_config,
+        user_id="1",
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.DEBUGGER,
+        call_depth=0,
+    )
+    variable_pool = VariablePool(
+        system_variables=SystemVariable(conversation_id="conversation_id"),
+        user_inputs={},
+        environment_variables=[],
+        conversation_variables=[],
+    )
+    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+
+    node_factory = DifyNodeFactory(
+        graph_init_params=init_params,
+        graph_runtime_state=graph_runtime_state,
+    )
+
+    node = node_factory.create_node(graph_config["nodes"][0])
+
+    assert isinstance(node, VariableAssignerNode)
diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
index ead2334473..d8f6b41f89 100644
--- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
+++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
@@ -111,7 +111,7 @@ def test_webhook_node_file_conversion_to_file_variable():
     )
 
     variable_pool = VariablePool(
-        system_variables=SystemVariable.empty(),
+        system_variables=SystemVariable.default(),
         user_inputs={
             "webhook_data": {
                 "headers": {},
@@ -184,7 +184,7 @@ def test_webhook_node_file_conversion_with_missing_files():
     )
 
     variable_pool = VariablePool(
-        system_variables=SystemVariable.empty(),
+        system_variables=SystemVariable.default(),
         user_inputs={
             "webhook_data": {
                 "headers": {},
@@ -219,7 +219,7 @@ def test_webhook_node_file_conversion_with_none_file():
     )
 
     variable_pool = VariablePool(
-        system_variables=SystemVariable.empty(),
+        system_variables=SystemVariable.default(),
         user_inputs={
             "webhook_data": {
                 "headers": {},
@@ -256,7 +256,7 @@ def test_webhook_node_file_conversion_with_non_dict_file():
     )
 
     variable_pool = VariablePool(
-        system_variables=SystemVariable.empty(),
+        system_variables=SystemVariable.default(),
         user_inputs={
             "webhook_data": {
                 "headers": {},
@@ -300,7 +300,7 @@ def test_webhook_node_file_conversion_mixed_parameters():
     )
 
     variable_pool = VariablePool(
-        system_variables=SystemVariable.empty(),
+        system_variables=SystemVariable.default(),
         user_inputs={
             "webhook_data": {
                 "headers": {},
@@ -370,7 +370,7 @@ def test_webhook_node_different_file_types():
     )
 
     variable_pool = VariablePool(
-        system_variables=SystemVariable.empty(),
+        system_variables=SystemVariable.default(),
         user_inputs={
             "webhook_data": {
                 "headers": {},
@@ -430,7 +430,7 @@ def test_webhook_node_file_conversion_with_non_dict_wrapper():
     )
 
     variable_pool = VariablePool(
-        system_variables=SystemVariable.empty(),
+        system_variables=SystemVariable.default(),
         user_inputs={
             "webhook_data": {
                 "headers": {},
diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
index bbb5511923..3b5aedebca 100644
--- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
@@ -75,7 +75,7 @@ def test_webhook_node_basic_initialization():
     )
 
     variable_pool = VariablePool(
-        system_variables=SystemVariable.empty(),
+        system_variables=SystemVariable.default(),
         user_inputs={},
     )
 
@@ -118,7 +118,7 @@ def test_webhook_node_run_with_headers():
     )
 
     variable_pool = VariablePool(
-        system_variables=SystemVariable.empty(),
+        system_variables=SystemVariable.default(),
         user_inputs={
             "webhook_data": {
                 "headers": {
@@ -154,7 +154,7 @@ def test_webhook_node_run_with_query_params():
     )
 
     variable_pool = VariablePool(
-        system_variables=SystemVariable.empty(),
+        system_variables=SystemVariable.default(),
         user_inputs={
             "webhook_data": {
                 "headers": {},
@@ -190,7 +190,7 @@ def 
test_webhook_node_run_with_body_params(): ) variable_pool = VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), user_inputs={ "webhook_data": { "headers": {}, @@ -249,7 +249,7 @@ def test_webhook_node_run_with_file_params(): ) variable_pool = VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), user_inputs={ "webhook_data": { "headers": {}, @@ -302,7 +302,7 @@ def test_webhook_node_run_mixed_parameters(): ) variable_pool = VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), user_inputs={ "webhook_data": { "headers": {"Authorization": "Bearer token"}, @@ -342,7 +342,7 @@ def test_webhook_node_run_empty_webhook_data(): ) variable_pool = VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), user_inputs={}, # No webhook_data ) @@ -368,7 +368,7 @@ def test_webhook_node_run_case_insensitive_headers(): ) variable_pool = VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), user_inputs={ "webhook_data": { "headers": { @@ -398,7 +398,7 @@ def test_webhook_node_variable_pool_user_inputs(): # Add some additional variables to the pool variable_pool = VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), user_inputs={ "webhook_data": {"headers": {}, "query_params": {}, "body": {}, "files": {}}, "other_var": "should_be_included", @@ -429,7 +429,7 @@ def test_webhook_node_different_methods(method): ) variable_pool = VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), user_inputs={ "webhook_data": { "headers": {}, diff --git a/api/tests/unit_tests/core/workflow/test_enums.py b/api/tests/unit_tests/core/workflow/test_enums.py index 7cdb2328f2..078ec5f6ab 100644 --- a/api/tests/unit_tests/core/workflow/test_enums.py +++ 
b/api/tests/unit_tests/core/workflow/test_enums.py @@ -30,3 +30,12 @@ class TestWorkflowExecutionStatus: for status in non_ended_statuses: assert not status.is_ended(), f"{status} should not be considered ended" + + def test_ended_values(self): + """Test ended_values returns the expected status values.""" + assert set(WorkflowExecutionStatus.ended_values()) == { + WorkflowExecutionStatus.SUCCEEDED.value, + WorkflowExecutionStatus.FAILED.value, + WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value, + WorkflowExecutionStatus.STOPPED.value, + } diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index 9733bf60eb..b8869dbf1d 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -24,7 +24,7 @@ from core.variables.variables import ( IntegerVariable, ObjectVariable, StringVariable, - VariableUnion, + Variable, ) from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.runtime import VariablePool @@ -160,7 +160,7 @@ class TestVariablePoolSerialization: ) # Create environment variables with all types including ArrayFileVariable - env_vars: list[VariableUnion] = [ + env_vars: list[Variable] = [ StringVariable( id="env_string_id", name="env_string", @@ -182,7 +182,7 @@ class TestVariablePoolSerialization: ] # Create conversation variables with complex data - conv_vars: list[VariableUnion] = [ + conv_vars: list[Variable] = [ StringVariable( id="conv_string_id", name="conv_string", diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 68d6c109e8..27ffa455d6 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -2,13 +2,17 @@ from types import SimpleNamespace import pytest +from 
configs import dify_config from core.file.enums import FileType from core.file.models import File, FileTransferMethod +from core.helper.code_executor.code_executor import CodeLanguage from core.variables.variables import StringVariable from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, ) +from core.workflow.nodes.code.code_node import CodeNode +from core.workflow.nodes.code.limits import CodeNodeLimits from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry @@ -96,12 +100,64 @@ class TestWorkflowEntry: assert output_var is not None assert output_var.value == "system_user" + def test_single_step_run_injects_code_limits(self): + """Ensure single-step CodeNode execution configures limits.""" + # Arrange + node_id = "code_node" + node_data = { + "type": "code", + "title": "Code", + "desc": None, + "variables": [], + "code_language": CodeLanguage.PYTHON3, + "code": "def main():\n return {}", + "outputs": {}, + } + node_config = {"id": node_id, "data": node_data} + + class StubWorkflow: + def __init__(self): + self.tenant_id = "tenant" + self.app_id = "app" + self.id = "workflow" + self.graph_dict = {"nodes": [node_config], "edges": []} + + def get_node_config_by_id(self, target_id: str): + assert target_id == node_id + return node_config + + workflow = StubWorkflow() + variable_pool = VariablePool(system_variables=SystemVariable.default(), user_inputs={}) + expected_limits = CodeNodeLimits( + max_string_length=dify_config.CODE_MAX_STRING_LENGTH, + max_number=dify_config.CODE_MAX_NUMBER, + min_number=dify_config.CODE_MIN_NUMBER, + max_precision=dify_config.CODE_MAX_PRECISION, + max_depth=dify_config.CODE_MAX_DEPTH, + max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, + max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, + max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, + ) + + 
# Act + node, _ = WorkflowEntry.single_step_run( + workflow=workflow, + node_id=node_id, + user_id="user", + user_inputs={}, + variable_pool=variable_pool, + ) + + # Assert + assert isinstance(node, CodeNode) + assert node._limits == expected_limits + def test_mapping_user_inputs_to_variable_pool_with_env_variables(self): """Test mapping environment variables from user inputs to variable pool.""" # Initialize variable pool with environment variables env_var = StringVariable(name="API_KEY", value="existing_key") variable_pool = VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), environment_variables=[env_var], user_inputs={}, ) @@ -142,7 +198,7 @@ class TestWorkflowEntry: # Initialize variable pool with conversation variables conv_var = StringVariable(name="last_message", value="Hello") variable_pool = VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), conversation_variables=[conv_var], user_inputs={}, ) @@ -183,7 +239,7 @@ class TestWorkflowEntry: """Test mapping regular node variables from user inputs to variable pool.""" # Initialize empty variable pool variable_pool = VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), user_inputs={}, ) @@ -225,7 +281,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_with_file_handling(self): """Test mapping file inputs from user inputs to variable pool.""" variable_pool = VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), user_inputs={}, ) @@ -284,7 +340,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_missing_variable_error(self): """Test that mapping raises error when required variable is missing.""" variable_pool = VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), user_inputs={}, ) @@ -310,7 +366,7 @@ class TestWorkflowEntry: def 
test_mapping_user_inputs_with_alternative_key_format(self): """Test mapping with alternative key format (without node prefix).""" variable_pool = VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), user_inputs={}, ) @@ -340,7 +396,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_with_complex_selectors(self): """Test mapping with complex node variable keys.""" variable_pool = VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), user_inputs={}, ) @@ -376,7 +432,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_invalid_node_variable(self): """Test that mapping handles invalid node variable format.""" variable_pool = VariablePool( - system_variables=SystemVariable.empty(), + system_variables=SystemVariable.default(), user_inputs={}, ) diff --git a/api/tests/unit_tests/extensions/logstore/__init__.py b/api/tests/unit_tests/extensions/logstore/__init__.py new file mode 100644 index 0000000000..fe9ada9128 --- /dev/null +++ b/api/tests/unit_tests/extensions/logstore/__init__.py @@ -0,0 +1 @@ +"""LogStore extension unit tests.""" diff --git a/api/tests/unit_tests/extensions/logstore/test_sql_escape.py b/api/tests/unit_tests/extensions/logstore/test_sql_escape.py new file mode 100644 index 0000000000..63172b3f9b --- /dev/null +++ b/api/tests/unit_tests/extensions/logstore/test_sql_escape.py @@ -0,0 +1,469 @@ +""" +Unit tests for SQL escape utility functions. + +These tests ensure that SQL injection attacks are properly prevented +in LogStore queries, particularly for cross-tenant access scenarios. 
+""" + +import pytest + +from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string + + +class TestEscapeSQLString: + """Test escape_sql_string function.""" + + def test_escape_empty_string(self): + """Test escaping empty string.""" + assert escape_sql_string("") == "" + + def test_escape_normal_string(self): + """Test escaping string without special characters.""" + assert escape_sql_string("tenant_abc123") == "tenant_abc123" + assert escape_sql_string("app-uuid-1234") == "app-uuid-1234" + + def test_escape_single_quote(self): + """Test escaping single quote.""" + # Single quote should be doubled + assert escape_sql_string("tenant'id") == "tenant''id" + assert escape_sql_string("O'Reilly") == "O''Reilly" + + def test_escape_multiple_quotes(self): + """Test escaping multiple single quotes.""" + assert escape_sql_string("a'b'c") == "a''b''c" + assert escape_sql_string("'''") == "''''''" + + # === SQL Injection Attack Scenarios === + + def test_prevent_boolean_injection(self): + """Test prevention of boolean injection attacks.""" + # Classic OR 1=1 attack + malicious_input = "tenant' OR '1'='1" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant'' OR ''1''=''1" + + # When used in SQL, this becomes a safe string literal + sql = f"WHERE tenant_id='{escaped}'" + assert sql == "WHERE tenant_id='tenant'' OR ''1''=''1'" + # The entire input is now a string literal that won't match any tenant + + def test_prevent_or_injection(self): + """Test prevention of OR-based injection.""" + malicious_input = "tenant_a' OR tenant_id='tenant_b" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant_a'' OR tenant_id=''tenant_b" + + sql = f"WHERE tenant_id='{escaped}'" + # The OR is now part of the string literal, not SQL logic + assert "OR tenant_id=" in sql + # The SQL has: opening ', doubled internal quotes '', and closing ' + assert sql == "WHERE tenant_id='tenant_a'' OR 
tenant_id=''tenant_b'" + + def test_prevent_union_injection(self): + """Test prevention of UNION-based injection.""" + malicious_input = "xxx' UNION SELECT password FROM users WHERE '1'='1" + escaped = escape_sql_string(malicious_input) + assert escaped == "xxx'' UNION SELECT password FROM users WHERE ''1''=''1" + + # UNION becomes part of the string literal + assert "UNION" in escaped + assert escaped.count("''") == 4 # All internal quotes are doubled + + def test_prevent_comment_injection(self): + """Test prevention of comment-based injection.""" + # SQL comment to bypass remaining conditions + malicious_input = "tenant' --" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant'' --" + + sql = f"WHERE tenant_id='{escaped}' AND deleted=false" + # The -- is now inside the string, not a SQL comment + assert "--" in sql + assert "AND deleted=false" in sql # This part is NOT commented out + + def test_prevent_semicolon_injection(self): + """Test prevention of semicolon-based multi-statement injection.""" + malicious_input = "tenant'; DROP TABLE users; --" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant''; DROP TABLE users; --" + + # Semicolons and DROP are now part of the string + assert "DROP TABLE" in escaped + + def test_prevent_time_based_blind_injection(self): + """Test prevention of time-based blind SQL injection.""" + malicious_input = "tenant' AND SLEEP(5) --" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant'' AND SLEEP(5) --" + + # SLEEP becomes part of the string + assert "SLEEP" in escaped + + def test_prevent_wildcard_injection(self): + """Test prevention of wildcard-based injection.""" + malicious_input = "tenant' OR tenant_id LIKE '%" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant'' OR tenant_id LIKE ''%" + + # The LIKE and wildcard are now part of the string + assert "LIKE" in escaped + + def test_prevent_null_byte_injection(self): + """Test handling of 
null bytes.""" + # Null bytes can sometimes bypass filters + malicious_input = "tenant\x00' OR '1'='1" + escaped = escape_sql_string(malicious_input) + # Null byte is preserved, but quote is escaped + assert "''1''=''1" in escaped + + # === Real-world SAAS Scenarios === + + def test_cross_tenant_access_attempt(self): + """Test prevention of cross-tenant data access.""" + # Attacker tries to access another tenant's data + attacker_input = "tenant_b' OR tenant_id='tenant_a" + escaped = escape_sql_string(attacker_input) + + sql = f"SELECT * FROM workflow_runs WHERE tenant_id='{escaped}'" + # The query will look for a tenant literally named "tenant_b' OR tenant_id='tenant_a" + # which doesn't exist - preventing access to either tenant's data + assert "tenant_b'' OR tenant_id=''tenant_a" in sql + + def test_cross_app_access_attempt(self): + """Test prevention of cross-application data access.""" + attacker_input = "app1' OR app_id='app2" + escaped = escape_sql_string(attacker_input) + + sql = f"WHERE app_id='{escaped}'" + # Cannot access app2's data + assert "app1'' OR app_id=''app2" in sql + + def test_bypass_status_filter(self): + """Test prevention of bypassing status filters.""" + # Try to see all statuses instead of just 'running' + attacker_input = "running' OR status LIKE '%" + escaped = escape_sql_string(attacker_input) + + sql = f"WHERE status='{escaped}'" + # Status condition is not bypassed + assert "running'' OR status LIKE ''%" in sql + + # === Edge Cases === + + def test_escape_only_quotes(self): + """Test string with only quotes.""" + assert escape_sql_string("'") == "''" + assert escape_sql_string("''") == "''''" + + def test_escape_mixed_content(self): + """Test string with mixed quotes and other chars.""" + input_str = "It's a 'test' of O'Reilly's code" + escaped = escape_sql_string(input_str) + assert escaped == "It''s a ''test'' of O''Reilly''s code" + + def test_escape_unicode_with_quotes(self): + """Test Unicode strings with quotes.""" + input_str 
= "租户' OR '1'='1" + escaped = escape_sql_string(input_str) + assert escaped == "租户'' OR ''1''=''1" + + +class TestEscapeIdentifier: + """Test escape_identifier function.""" + + def test_escape_uuid(self): + """Test escaping UUID identifiers.""" + uuid = "550e8400-e29b-41d4-a716-446655440000" + assert escape_identifier(uuid) == uuid + + def test_escape_alphanumeric_id(self): + """Test escaping alphanumeric identifiers.""" + assert escape_identifier("tenant_123") == "tenant_123" + assert escape_identifier("app-abc-123") == "app-abc-123" + + def test_escape_identifier_with_quote(self): + """Test escaping identifier with single quote.""" + malicious = "tenant' OR '1'='1" + escaped = escape_identifier(malicious) + assert escaped == "tenant'' OR ''1''=''1" + + def test_identifier_injection_attempt(self): + """Test prevention of injection through identifiers.""" + # Common identifier injection patterns + test_cases = [ + ("id' OR '1'='1", "id'' OR ''1''=''1"), + ("id'; DROP TABLE", "id''; DROP TABLE"), + ("id' UNION SELECT", "id'' UNION SELECT"), + ] + + for malicious, expected in test_cases: + assert escape_identifier(malicious) == expected + + +class TestSQLInjectionIntegration: + """Integration tests simulating real SQL construction scenarios.""" + + def test_complete_where_clause_safety(self): + """Test that a complete WHERE clause is safe from injection.""" + # Simulating typical query construction + tenant_id = "tenant' OR '1'='1" + app_id = "app' UNION SELECT" + run_id = "run' --" + + escaped_tenant = escape_identifier(tenant_id) + escaped_app = escape_identifier(app_id) + escaped_run = escape_identifier(run_id) + + sql = f""" + SELECT * FROM workflow_runs + WHERE tenant_id='{escaped_tenant}' + AND app_id='{escaped_app}' + AND id='{escaped_run}' + """ + + # Verify all special characters are escaped + assert "tenant'' OR ''1''=''1" in sql + assert "app'' UNION SELECT" in sql + assert "run'' --" in sql + + # Verify SQL structure is preserved (3 conditions with AND) + 
assert sql.count("AND") == 2 + + def test_multiple_conditions_with_injection_attempts(self): + """Test multiple conditions all attempting injection.""" + conditions = { + "tenant_id": "t1' OR tenant_id='t2", + "app_id": "a1' OR app_id='a2", + "status": "running' OR '1'='1", + } + + where_parts = [] + for field, value in conditions.items(): + escaped = escape_sql_string(value) + where_parts.append(f"{field}='{escaped}'") + + where_clause = " AND ".join(where_parts) + + # All injection attempts are neutralized + assert "t1'' OR tenant_id=''t2" in where_clause + assert "a1'' OR app_id=''a2" in where_clause + assert "running'' OR ''1''=''1" in where_clause + + # AND structure is preserved + assert where_clause.count(" AND ") == 2 + + @pytest.mark.parametrize( + ("attack_vector", "description"), + [ + ("' OR '1'='1", "Boolean injection"), + ("' OR '1'='1' --", "Boolean with comment"), + ("' UNION SELECT * FROM users --", "Union injection"), + ("'; DROP TABLE workflow_runs; --", "Destructive command"), + ("' AND SLEEP(10) --", "Time-based blind"), + ("' OR tenant_id LIKE '%", "Wildcard injection"), + ("admin' --", "Comment bypass"), + ("' OR 1=1 LIMIT 1 --", "Limit bypass"), + ], + ) + def test_common_injection_vectors(self, attack_vector, description): + """Test protection against common injection attack vectors.""" + escaped = escape_sql_string(attack_vector) + + # Build SQL + sql = f"WHERE tenant_id='{escaped}'" + + # Verify the attack string is now a safe literal + # The key indicator: all internal single quotes are doubled + internal_quotes = escaped.count("''") + original_quotes = attack_vector.count("'") + + # Each original quote should be doubled + assert internal_quotes == original_quotes + + # Verify SQL keeps its delimiting quotes (doubled internal quotes add more) + assert sql.count("'") >= 2 # At least opening and closing + + def test_logstore_specific_scenario(self): + """Test SQL injection prevention in LogStore-specific scenarios.""" + # Simulate LogStore query with window
function + tenant_id = "tenant' OR '1'='1" + app_id = "app' UNION SELECT" + + escaped_tenant = escape_identifier(tenant_id) + escaped_app = escape_identifier(app_id) + + sql = f""" + SELECT * FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn + FROM workflow_execution_logstore + WHERE tenant_id='{escaped_tenant}' + AND app_id='{escaped_app}' + AND __time__ > 0 + ) AS subquery WHERE rn = 1 + """ + + # Complex query structure is maintained + assert "ROW_NUMBER()" in sql + assert "PARTITION BY id" in sql + + # Injection attempts are escaped + assert "tenant'' OR ''1''=''1" in sql + assert "app'' UNION SELECT" in sql + + +# ==================================================================================== +# Tests for LogStore Query Syntax (SDK Mode) +# ==================================================================================== + + +class TestLogStoreQueryEscape: + """Test escape_logstore_query_value for SDK mode query syntax.""" + + def test_normal_value(self): + """Test escaping normal alphanumeric value.""" + value = "550e8400-e29b-41d4-a716-446655440000" + escaped = escape_logstore_query_value(value) + + # Should be wrapped in double quotes + assert escaped == '"550e8400-e29b-41d4-a716-446655440000"' + + def test_empty_value(self): + """Test escaping empty string.""" + assert escape_logstore_query_value("") == '""' + + def test_value_with_and_keyword(self): + """Test that 'and' keyword is neutralized when quoted.""" + malicious = "value and field:evil" + escaped = escape_logstore_query_value(malicious) + + # Should be wrapped in quotes, making 'and' a literal + assert escaped == '"value and field:evil"' + + # Simulate using in query + query = f"tenant_id:{escaped}" + assert query == 'tenant_id:"value and field:evil"' + + def test_value_with_or_keyword(self): + """Test that 'or' keyword is neutralized when quoted.""" + malicious = "tenant_a or tenant_id:tenant_b" + escaped = escape_logstore_query_value(malicious) + + 
assert escaped == '"tenant_a or tenant_id:tenant_b"' + + query = f"tenant_id:{escaped}" + assert "or" in query # Present but as literal string + + def test_value_with_not_keyword(self): + """Test that 'not' keyword is neutralized when quoted.""" + malicious = "not field:value" + escaped = escape_logstore_query_value(malicious) + + assert escaped == '"not field:value"' + + def test_value_with_parentheses(self): + """Test that parentheses are neutralized when quoted.""" + malicious = "(tenant_a or tenant_b)" + escaped = escape_logstore_query_value(malicious) + + assert escaped == '"(tenant_a or tenant_b)"' + assert "(" in escaped # Present as literal + assert ")" in escaped # Present as literal + + def test_value_with_colon(self): + """Test that colons are neutralized when quoted.""" + malicious = "field:value" + escaped = escape_logstore_query_value(malicious) + + assert escaped == '"field:value"' + assert ":" in escaped # Present as literal + + def test_value_with_double_quotes(self): + """Test that internal double quotes are escaped.""" + value_with_quotes = 'tenant"test"value' + escaped = escape_logstore_query_value(value_with_quotes) + + # Double quotes should be escaped with backslash + assert escaped == '"tenant\\"test\\"value"' + # Should have outer quotes plus escaped inner quotes + assert '\\"' in escaped + + def test_value_with_backslash(self): + """Test that backslashes are escaped.""" + value_with_backslash = "tenant\\test" + escaped = escape_logstore_query_value(value_with_backslash) + + # Backslash should be escaped + assert escaped == '"tenant\\\\test"' + assert "\\\\" in escaped + + def test_value_with_backslash_and_quote(self): + """Test escaping both backslash and double quote.""" + value = 'path\\to\\"file"' + escaped = escape_logstore_query_value(value) + + # Both should be escaped + assert escaped == '"path\\\\to\\\\\\"file\\""' + # Verify escape order is correct + assert "\\\\" in escaped # Escaped backslash + assert '\\"' in escaped # Escaped 
double quote + + def test_complex_injection_attempt(self): + """Test complex injection combining multiple operators.""" + malicious = 'tenant_a" or (tenant_id:"tenant_b" and app_id:"evil")' + escaped = escape_logstore_query_value(malicious) + + # All special chars should be literals or escaped + assert escaped.startswith('"') + assert escaped.endswith('"') + # Inner double quotes escaped, operators become literals + assert "or" in escaped + assert "and" in escaped + assert '\\"' in escaped # Escaped quotes + + def test_only_backslash(self): + """Test escaping a single backslash.""" + assert escape_logstore_query_value("\\") == '"\\\\"' + + def test_only_double_quote(self): + """Test escaping a single double quote.""" + assert escape_logstore_query_value('"') == '"\\""' + + def test_multiple_backslashes(self): + """Test escaping multiple consecutive backslashes.""" + assert escape_logstore_query_value("\\\\\\") == '"\\\\\\\\\\\\"' # 3 backslashes -> 6 + + def test_escape_sequence_like_input(self): + """Test that existing escape sequences are properly escaped.""" + # Input looks like already escaped, but we still escape it + value = 'value\\"test' + escaped = escape_logstore_query_value(value) + # \\ -> \\\\, " -> \" + assert escaped == '"value\\\\\\"test"' + + +@pytest.mark.parametrize( + ("attack_scenario", "field", "malicious_value"), + [ + ("Cross-tenant via OR", "tenant_id", "tenant_a or tenant_id:tenant_b"), + ("Cross-app via AND", "app_id", "app_a and (app_id:app_b or app_id:app_c)"), + ("Boolean logic", "status", "succeeded or status:failed"), + ("Negation", "tenant_id", "not tenant_a"), + ("Field injection", "run_id", "run123 and tenant_id:evil_tenant"), + ("Parentheses grouping", "app_id", "app1 or (app_id:app2 and tenant_id:tenant2)"), + ("Quote breaking attempt", "tenant_id", 'tenant" or "1"="1'), + ("Backslash escape bypass", "app_id", "app\\ and app_id:evil"), + ], +) +def test_logstore_query_injection_scenarios(attack_scenario: str, field: str, 
malicious_value: str): + """Test that various LogStore query injection attempts are neutralized.""" + escaped = escape_logstore_query_value(malicious_value) + + # Build query + query = f"{field}:{escaped}" + + # All operators should be within quoted string (literals) + assert escaped.startswith('"') + assert escaped.endswith('"') + + # Verify the full query structure is safe + assert query.count(":") >= 1 # At least the main field:value separator diff --git a/api/tests/unit_tests/extensions/test_celery_ssl.py b/api/tests/unit_tests/extensions/test_celery_ssl.py index fc7a090ef9..d3a4d69f07 100644 --- a/api/tests/unit_tests/extensions/test_celery_ssl.py +++ b/api/tests/unit_tests/extensions/test_celery_ssl.py @@ -8,11 +8,12 @@ class TestCelerySSLConfiguration: """Test suite for Celery SSL configuration.""" def test_get_celery_ssl_options_when_ssl_disabled(self): - """Test SSL options when REDIS_USE_SSL is False.""" - mock_config = MagicMock() - mock_config.REDIS_USE_SSL = False + """Test SSL options when BROKER_USE_SSL is False.""" + from configs import DifyConfig - with patch("extensions.ext_celery.dify_config", mock_config): + dify_config = DifyConfig(CELERY_BROKER_URL="redis://localhost:6379/0") + + with patch("extensions.ext_celery.dify_config", dify_config): from extensions.ext_celery import _get_celery_ssl_options result = _get_celery_ssl_options() @@ -21,7 +22,6 @@ class TestCelerySSLConfiguration: def test_get_celery_ssl_options_when_broker_not_redis(self): """Test SSL options when broker is not Redis.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True mock_config.CELERY_BROKER_URL = "amqp://localhost:5672" with patch("extensions.ext_celery.dify_config", mock_config): @@ -33,7 +33,6 @@ class TestCelerySSLConfiguration: def test_get_celery_ssl_options_with_cert_none(self): """Test SSL options with CERT_NONE requirement.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" 
mock_config.REDIS_SSL_CERT_REQS = "CERT_NONE" mock_config.REDIS_SSL_CA_CERTS = None @@ -53,7 +52,6 @@ class TestCelerySSLConfiguration: def test_get_celery_ssl_options_with_cert_required(self): """Test SSL options with CERT_REQUIRED and certificates.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True mock_config.CELERY_BROKER_URL = "rediss://localhost:6380/0" mock_config.REDIS_SSL_CERT_REQS = "CERT_REQUIRED" mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt" @@ -73,7 +71,6 @@ class TestCelerySSLConfiguration: def test_get_celery_ssl_options_with_cert_optional(self): """Test SSL options with CERT_OPTIONAL requirement.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" mock_config.REDIS_SSL_CERT_REQS = "CERT_OPTIONAL" mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt" @@ -91,7 +88,6 @@ class TestCelerySSLConfiguration: def test_get_celery_ssl_options_with_invalid_cert_reqs(self): """Test SSL options with invalid cert requirement defaults to CERT_NONE.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" mock_config.REDIS_SSL_CERT_REQS = "INVALID_VALUE" mock_config.REDIS_SSL_CA_CERTS = None @@ -108,7 +104,6 @@ class TestCelerySSLConfiguration: def test_celery_init_applies_ssl_to_broker_and_backend(self): """Test that SSL options are applied to both broker and backend when using Redis.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" mock_config.CELERY_BACKEND = "redis" mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0" diff --git a/api/tests/unit_tests/fields/test_file_fields.py b/api/tests/unit_tests/fields/test_file_fields.py new file mode 100644 index 0000000000..8be8df16f4 --- /dev/null +++ b/api/tests/unit_tests/fields/test_file_fields.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from datetime import datetime 
+from types import SimpleNamespace + +from fields.file_fields import FileResponse, FileWithSignedUrl, RemoteFileInfo, UploadConfig + + +def test_file_response_serializes_datetime() -> None: + created_at = datetime(2024, 1, 1, 12, 0, 0) + file_obj = SimpleNamespace( + id="file-1", + name="example.txt", + size=1024, + extension="txt", + mime_type="text/plain", + created_by="user-1", + created_at=created_at, + preview_url="https://preview", + source_url="https://source", + original_url="https://origin", + user_id="user-1", + tenant_id="tenant-1", + conversation_id="conv-1", + file_key="key-1", + ) + + serialized = FileResponse.model_validate(file_obj, from_attributes=True).model_dump(mode="json") + + assert serialized["id"] == "file-1" + assert serialized["created_at"] == int(created_at.timestamp()) + assert serialized["preview_url"] == "https://preview" + assert serialized["source_url"] == "https://source" + assert serialized["original_url"] == "https://origin" + assert serialized["user_id"] == "user-1" + assert serialized["tenant_id"] == "tenant-1" + assert serialized["conversation_id"] == "conv-1" + assert serialized["file_key"] == "key-1" + + +def test_file_with_signed_url_builds_payload() -> None: + payload = FileWithSignedUrl( + id="file-2", + name="remote.pdf", + size=2048, + extension="pdf", + url="https://signed", + mime_type="application/pdf", + created_by="user-2", + created_at=datetime(2024, 1, 2, 0, 0, 0), + ) + + dumped = payload.model_dump(mode="json") + + assert dumped["url"] == "https://signed" + assert dumped["created_at"] == int(datetime(2024, 1, 2, 0, 0, 0).timestamp()) + + +def test_remote_file_info_and_upload_config() -> None: + info = RemoteFileInfo(file_type="text/plain", file_length=123) + assert info.model_dump(mode="json") == {"file_type": "text/plain", "file_length": 123} + + config = UploadConfig( + file_size_limit=1, + batch_count_limit=2, + file_upload_limit=3, + image_file_size_limit=4, + video_file_size_limit=5, + 
audio_file_size_limit=6, + workflow_file_upload_limit=7, + image_file_batch_limit=8, + single_chunk_attachment_limit=9, + attachment_image_file_size_limit=10, + ) + + dumped = config.model_dump(mode="json") + assert dumped["file_upload_limit"] == 3 + assert dumped["attachment_image_file_size_limit"] == 10 diff --git a/api/tests/unit_tests/libs/test_archive_storage.py b/api/tests/unit_tests/libs/test_archive_storage.py new file mode 100644 index 0000000000..de3c9c4737 --- /dev/null +++ b/api/tests/unit_tests/libs/test_archive_storage.py @@ -0,0 +1,286 @@ +import base64 +import hashlib +from datetime import datetime +from unittest.mock import ANY, MagicMock + +import pytest +from botocore.exceptions import ClientError + +from libs import archive_storage as storage_module +from libs.archive_storage import ( + ArchiveStorage, + ArchiveStorageError, + ArchiveStorageNotConfiguredError, +) + +BUCKET_NAME = "archive-bucket" + + +def _configure_storage(monkeypatch, **overrides): + defaults = { + "ARCHIVE_STORAGE_ENABLED": True, + "ARCHIVE_STORAGE_ENDPOINT": "https://storage.example.com", + "ARCHIVE_STORAGE_ARCHIVE_BUCKET": BUCKET_NAME, + "ARCHIVE_STORAGE_ACCESS_KEY": "access", + "ARCHIVE_STORAGE_SECRET_KEY": "secret", + "ARCHIVE_STORAGE_REGION": "auto", + } + defaults.update(overrides) + for key, value in defaults.items(): + monkeypatch.setattr(storage_module.dify_config, key, value, raising=False) + + +def _client_error(code: str) -> ClientError: + return ClientError({"Error": {"Code": code}}, "Operation") + + +def _mock_client(monkeypatch): + client = MagicMock() + client.head_bucket.return_value = None + # Configure put_object to return a proper ETag that matches the MD5 hash + # The ETag format is typically the MD5 hash wrapped in quotes + + def mock_put_object(**kwargs): + body = kwargs.get("Body", b"") + if isinstance(body, bytes): + md5_hash = hashlib.md5(body).hexdigest() + else: + md5_hash = hashlib.md5(body.encode()).hexdigest() + response =
MagicMock() + response.get.return_value = f'"{md5_hash}"' + return response + + client.put_object.side_effect = mock_put_object + boto_client = MagicMock(return_value=client) + monkeypatch.setattr(storage_module.boto3, "client", boto_client) + return client, boto_client + + +def test_init_disabled(monkeypatch): + _configure_storage(monkeypatch, ARCHIVE_STORAGE_ENABLED=False) + with pytest.raises(ArchiveStorageNotConfiguredError, match="not enabled"): + ArchiveStorage(bucket=BUCKET_NAME) + + +def test_init_missing_config(monkeypatch): + _configure_storage(monkeypatch, ARCHIVE_STORAGE_ENDPOINT=None) + with pytest.raises(ArchiveStorageNotConfiguredError, match="incomplete"): + ArchiveStorage(bucket=BUCKET_NAME) + + +def test_init_bucket_not_found(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.head_bucket.side_effect = _client_error("404") + + with pytest.raises(ArchiveStorageNotConfiguredError, match="does not exist"): + ArchiveStorage(bucket=BUCKET_NAME) + + +def test_init_bucket_access_denied(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.head_bucket.side_effect = _client_error("403") + + with pytest.raises(ArchiveStorageNotConfiguredError, match="Access denied"): + ArchiveStorage(bucket=BUCKET_NAME) + + +def test_init_bucket_other_error(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.head_bucket.side_effect = _client_error("500") + + with pytest.raises(ArchiveStorageError, match="Failed to access archive bucket"): + ArchiveStorage(bucket=BUCKET_NAME) + + +def test_init_sets_client(monkeypatch): + _configure_storage(monkeypatch) + client, boto_client = _mock_client(monkeypatch) + + storage = ArchiveStorage(bucket=BUCKET_NAME) + + boto_client.assert_called_once_with( + "s3", + endpoint_url="https://storage.example.com", + aws_access_key_id="access", + aws_secret_access_key="secret", + region_name="auto", + config=ANY, + 
) + assert storage.client is client + assert storage.bucket == BUCKET_NAME + + +def test_put_object_returns_checksum(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + storage = ArchiveStorage(bucket=BUCKET_NAME) + + data = b"hello" + checksum = storage.put_object("key", data) + + expected_md5 = hashlib.md5(data).hexdigest() + expected_content_md5 = base64.b64encode(hashlib.md5(data).digest()).decode() + client.put_object.assert_called_once_with( + Bucket="archive-bucket", + Key="key", + Body=data, + ContentMD5=expected_content_md5, + ) + assert checksum == expected_md5 + + +def test_put_object_raises_on_error(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + storage = ArchiveStorage(bucket=BUCKET_NAME) + client.put_object.side_effect = _client_error("500") + + with pytest.raises(ArchiveStorageError, match="Failed to upload object"): + storage.put_object("key", b"data") + + +def test_get_object_returns_bytes(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + body = MagicMock() + body.read.return_value = b"payload" + client.get_object.return_value = {"Body": body} + storage = ArchiveStorage(bucket=BUCKET_NAME) + + assert storage.get_object("key") == b"payload" + + +def test_get_object_missing(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.get_object.side_effect = _client_error("NoSuchKey") + storage = ArchiveStorage(bucket=BUCKET_NAME) + + with pytest.raises(FileNotFoundError, match="Archive object not found"): + storage.get_object("missing") + + +def test_get_object_stream(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + body = MagicMock() + body.iter_chunks.return_value = [b"a", b"b"] + client.get_object.return_value = {"Body": body} + storage = ArchiveStorage(bucket=BUCKET_NAME) + + assert list(storage.get_object_stream("key")) == [b"a", b"b"] + + +def 
test_get_object_stream_missing(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.get_object.side_effect = _client_error("NoSuchKey") + storage = ArchiveStorage(bucket=BUCKET_NAME) + + with pytest.raises(FileNotFoundError, match="Archive object not found"): + list(storage.get_object_stream("missing")) + + +def test_object_exists(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + storage = ArchiveStorage(bucket=BUCKET_NAME) + + assert storage.object_exists("key") is True + client.head_object.side_effect = _client_error("404") + assert storage.object_exists("missing") is False + + +def test_delete_object_error(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.delete_object.side_effect = _client_error("500") + storage = ArchiveStorage(bucket=BUCKET_NAME) + + with pytest.raises(ArchiveStorageError, match="Failed to delete object"): + storage.delete_object("key") + + +def test_list_objects(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + paginator = MagicMock() + paginator.paginate.return_value = [ + {"Contents": [{"Key": "a"}, {"Key": "b"}]}, + {"Contents": [{"Key": "c"}]}, + ] + client.get_paginator.return_value = paginator + storage = ArchiveStorage(bucket=BUCKET_NAME) + + assert storage.list_objects("prefix") == ["a", "b", "c"] + paginator.paginate.assert_called_once_with(Bucket="archive-bucket", Prefix="prefix") + + +def test_list_objects_error(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + paginator = MagicMock() + paginator.paginate.side_effect = _client_error("500") + client.get_paginator.return_value = paginator + storage = ArchiveStorage(bucket=BUCKET_NAME) + + with pytest.raises(ArchiveStorageError, match="Failed to list objects"): + storage.list_objects("prefix") + + +def test_generate_presigned_url(monkeypatch): + _configure_storage(monkeypatch) 
+ client, _ = _mock_client(monkeypatch) + client.generate_presigned_url.return_value = "http://signed-url" + storage = ArchiveStorage(bucket=BUCKET_NAME) + + url = storage.generate_presigned_url("key", expires_in=123) + + client.generate_presigned_url.assert_called_once_with( + ClientMethod="get_object", + Params={"Bucket": "archive-bucket", "Key": "key"}, + ExpiresIn=123, + ) + assert url == "http://signed-url" + + +def test_generate_presigned_url_error(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.generate_presigned_url.side_effect = _client_error("500") + storage = ArchiveStorage(bucket=BUCKET_NAME) + + with pytest.raises(ArchiveStorageError, match="Failed to generate pre-signed URL"): + storage.generate_presigned_url("key") + + +def test_serialization_roundtrip(): + records = [ + { + "id": "1", + "created_at": datetime(2024, 1, 1, 12, 0, 0), + "payload": {"nested": "value"}, + "items": [{"name": "a"}], + }, + {"id": "2", "value": 123}, + ] + + data = ArchiveStorage.serialize_to_jsonl(records) + decoded = ArchiveStorage.deserialize_from_jsonl(data) + + assert decoded[0]["id"] == "1" + assert decoded[0]["payload"]["nested"] == "value" + assert decoded[0]["items"][0]["name"] == "a" + assert "2024-01-01T12:00:00" in decoded[0]["created_at"] + assert decoded[1]["value"] == 123 + + +def test_content_md5_matches_checksum(): + data = b"checksum" + expected = base64.b64encode(hashlib.md5(data).digest()).decode() + + assert ArchiveStorage._content_md5(data) == expected + assert ArchiveStorage.compute_checksum(data) == hashlib.md5(data).hexdigest() diff --git a/api/tests/unit_tests/libs/test_external_api.py b/api/tests/unit_tests/libs/test_external_api.py index 9aa157a651..5135970bcc 100644 --- a/api/tests/unit_tests/libs/test_external_api.py +++ b/api/tests/unit_tests/libs/test_external_api.py @@ -99,29 +99,20 @@ def test_external_api_json_message_and_bad_request_rewrite(): assert res.get_json()["message"] == "Invalid 
JSON payload received or JSON payload is empty." -def test_external_api_param_mapping_and_quota_and_exc_info_none(): - # Force exc_info() to return (None,None,None) only during request - import libs.external_api as ext +def test_external_api_param_mapping_and_quota(): + app = _create_api_app() + client = app.test_client() - orig_exc_info = ext.sys.exc_info - try: - ext.sys.exc_info = lambda: (None, None, None) + # Param errors mapping payload path + res = client.get("/api/param-errors") + assert res.status_code == 400 + data = res.get_json() + assert data["code"] == "invalid_param" + assert data["params"] == "field" - app = _create_api_app() - client = app.test_client() - - # Param errors mapping payload path - res = client.get("/api/param-errors") - assert res.status_code == 400 - data = res.get_json() - assert data["code"] == "invalid_param" - assert data["params"] == "field" - - # Quota path — depending on Flask-RESTX internals it may be handled - res = client.get("/api/quota") - assert res.status_code in (400, 429) - finally: - ext.sys.exc_info = orig_exc_info # type: ignore[assignment] + # Quota path — depending on Flask-RESTX internals it may be handled + res = client.get("/api/quota") + assert res.status_code in (400, 429) def test_unauthorized_and_force_logout_clears_cookies(): diff --git a/api/tests/unit_tests/libs/test_helper.py b/api/tests/unit_tests/libs/test_helper.py index a96817cfbc..1a93dbbca1 100644 --- a/api/tests/unit_tests/libs/test_helper.py +++ b/api/tests/unit_tests/libs/test_helper.py @@ -2,7 +2,7 @@ from datetime import datetime import pytest -from libs.helper import OptionalTimestampField, extract_tenant_id +from libs.helper import OptionalTimestampField, escape_like_pattern, extract_tenant_id from models.account import Account from models.model import EndUser @@ -78,3 +78,51 @@ class TestOptionalTimestampField: value = datetime(2024, 1, 2, 3, 4, 5) assert field.format(value) == int(value.timestamp()) + + +class TestEscapeLikePattern: + 
"""Test cases for the escape_like_pattern utility function.""" + + def test_escape_percent_character(self): + """Test escaping percent character.""" + result = escape_like_pattern("50% discount") + assert result == "50\\% discount" + + def test_escape_underscore_character(self): + """Test escaping underscore character.""" + result = escape_like_pattern("test_data") + assert result == "test\\_data" + + def test_escape_backslash_character(self): + """Test escaping backslash character.""" + result = escape_like_pattern("path\\to\\file") + assert result == "path\\\\to\\\\file" + + def test_escape_combined_special_characters(self): + """Test escaping multiple special characters together.""" + result = escape_like_pattern("file_50%\\path") + assert result == "file\\_50\\%\\\\path" + + def test_escape_empty_string(self): + """Test escaping empty string returns empty string.""" + result = escape_like_pattern("") + assert result == "" + + def test_escape_none_handling(self): + """Test escaping None returns None (falsy check handles it).""" + # The function checks `if not pattern`, so None is falsy and returns as-is + result = escape_like_pattern(None) + assert result is None + + def test_escape_normal_string_no_change(self): + """Test that normal strings without special characters are unchanged.""" + result = escape_like_pattern("normal text") + assert result == "normal text" + + def test_escape_order_matters(self): + """Test that backslash is escaped first to prevent double escaping.""" + # If we escape % first, then escape \, we might get wrong results + # This test ensures the order is correct: \ first, then % and _ + result = escape_like_pattern("test\\%_value") + # Should be: test\\\%\_value + assert result == "test\\\\\\%\\_value" diff --git a/api/tests/unit_tests/libs/test_smtp_client.py b/api/tests/unit_tests/libs/test_smtp_client.py index fcee01ca00..042bc15643 100644 --- a/api/tests/unit_tests/libs/test_smtp_client.py +++ 
b/api/tests/unit_tests/libs/test_smtp_client.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest @@ -17,7 +17,7 @@ def test_smtp_plain_success(mock_smtp_cls: MagicMock): client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com") client.send(_mail()) - mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10) + mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10, local_hostname=ANY) mock_smtp.sendmail.assert_called_once() mock_smtp.quit.assert_called_once() @@ -38,7 +38,7 @@ def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock): ) client.send(_mail()) - mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10) + mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10, local_hostname=ANY) assert mock_smtp.ehlo.call_count == 2 mock_smtp.starttls.assert_called_once() mock_smtp.login.assert_called_once_with("user", "pass") diff --git a/api/tests/unit_tests/libs/test_workspace_permission.py b/api/tests/unit_tests/libs/test_workspace_permission.py new file mode 100644 index 0000000000..89586ccf26 --- /dev/null +++ b/api/tests/unit_tests/libs/test_workspace_permission.py @@ -0,0 +1,142 @@ +from unittest.mock import Mock, patch + +import pytest +from werkzeug.exceptions import Forbidden + +from libs.workspace_permission import ( + check_workspace_member_invite_permission, + check_workspace_owner_transfer_permission, +) + + +class TestWorkspacePermissionHelper: + """Test workspace permission helper functions.""" + + @patch("libs.workspace_permission.dify_config") + @patch("libs.workspace_permission.EnterpriseService") + def test_community_edition_allows_invite(self, mock_enterprise_service, mock_config): + """Community edition should always allow invitations without calling any service.""" + mock_config.ENTERPRISE_ENABLED = False + + # Should not raise + 
check_workspace_member_invite_permission("test-workspace-id") + + # EnterpriseService should NOT be called in community edition + mock_enterprise_service.WorkspacePermissionService.get_permission.assert_not_called() + + @patch("libs.workspace_permission.dify_config") + @patch("libs.workspace_permission.FeatureService") + def test_community_edition_allows_transfer(self, mock_feature_service, mock_config): + """Community edition should check the billing plan but not call the enterprise service.""" + mock_config.ENTERPRISE_ENABLED = False + mock_features = Mock() + mock_features.is_allow_transfer_workspace = True + mock_feature_service.get_features.return_value = mock_features + + # Should not raise + check_workspace_owner_transfer_permission("test-workspace-id") + + mock_feature_service.get_features.assert_called_once_with("test-workspace-id") + + @patch("libs.workspace_permission.EnterpriseService") + @patch("libs.workspace_permission.dify_config") + def test_enterprise_blocks_invite_when_disabled(self, mock_config, mock_enterprise_service): + """Enterprise edition should block invitations when workspace policy is False.""" + mock_config.ENTERPRISE_ENABLED = True + + mock_permission = Mock() + mock_permission.allow_member_invite = False + mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission + + with pytest.raises(Forbidden, match="Workspace policy prohibits member invitations"): + check_workspace_member_invite_permission("test-workspace-id") + + mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id") + + @patch("libs.workspace_permission.EnterpriseService") + @patch("libs.workspace_permission.dify_config") + def test_enterprise_allows_invite_when_enabled(self, mock_config, mock_enterprise_service): + """Enterprise edition should allow invitations when workspace policy is True.""" + mock_config.ENTERPRISE_ENABLED = True + + mock_permission = Mock() + 
mock_permission.allow_member_invite = True + mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission + + # Should not raise + check_workspace_member_invite_permission("test-workspace-id") + + mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id") + + @patch("libs.workspace_permission.EnterpriseService") + @patch("libs.workspace_permission.dify_config") + @patch("libs.workspace_permission.FeatureService") + def test_billing_plan_blocks_transfer(self, mock_feature_service, mock_config, mock_enterprise_service): + """SANDBOX billing plan should block owner transfer before checking enterprise policy.""" + mock_config.ENTERPRISE_ENABLED = True + mock_features = Mock() + mock_features.is_allow_transfer_workspace = False # SANDBOX plan + mock_feature_service.get_features.return_value = mock_features + + with pytest.raises(Forbidden, match="Your current plan does not allow workspace ownership transfer"): + check_workspace_owner_transfer_permission("test-workspace-id") + + # Enterprise service should NOT be called since billing plan already blocks + mock_enterprise_service.WorkspacePermissionService.get_permission.assert_not_called() + + @patch("libs.workspace_permission.EnterpriseService") + @patch("libs.workspace_permission.dify_config") + @patch("libs.workspace_permission.FeatureService") + def test_enterprise_blocks_transfer_when_disabled(self, mock_feature_service, mock_config, mock_enterprise_service): + """Enterprise edition should block transfer when workspace policy is False.""" + mock_config.ENTERPRISE_ENABLED = True + mock_features = Mock() + mock_features.is_allow_transfer_workspace = True # Billing plan allows + mock_feature_service.get_features.return_value = mock_features + + mock_permission = Mock() + mock_permission.allow_owner_transfer = False # Workspace policy blocks + mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = 
mock_permission + + with pytest.raises(Forbidden, match="Workspace policy prohibits ownership transfer"): + check_workspace_owner_transfer_permission("test-workspace-id") + + mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id") + + @patch("libs.workspace_permission.EnterpriseService") + @patch("libs.workspace_permission.dify_config") + @patch("libs.workspace_permission.FeatureService") + def test_enterprise_allows_transfer_when_both_enabled( + self, mock_feature_service, mock_config, mock_enterprise_service + ): + """Enterprise edition should allow transfer when both billing and workspace policy allow.""" + mock_config.ENTERPRISE_ENABLED = True + mock_features = Mock() + mock_features.is_allow_transfer_workspace = True # Billing plan allows + mock_feature_service.get_features.return_value = mock_features + + mock_permission = Mock() + mock_permission.allow_owner_transfer = True # Workspace policy allows + mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission + + # Should not raise + check_workspace_owner_transfer_permission("test-workspace-id") + + mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id") + + @patch("libs.workspace_permission.logger") + @patch("libs.workspace_permission.EnterpriseService") + @patch("libs.workspace_permission.dify_config") + def test_enterprise_service_error_fails_open(self, mock_config, mock_enterprise_service, mock_logger): + """On enterprise service error, should fail-open (allow) and log error.""" + mock_config.ENTERPRISE_ENABLED = True + + # Simulate enterprise service error + mock_enterprise_service.WorkspacePermissionService.get_permission.side_effect = Exception("Service unavailable") + + # Should not raise (fail-open) + check_workspace_member_invite_permission("test-workspace-id") + + # Should log the error + mock_logger.exception.assert_called_once() + assert "Failed to 
check workspace invite permission" in str(mock_logger.exception.call_args) diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index 8f93de46b9..3317e9a311 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -114,7 +114,7 @@ class TestAppModelValidation: def test_icon_type_validation(self): """Test icon type enum values.""" # Assert - assert {t.value for t in IconType} == {"image", "emoji"} + assert {t.value for t in IconType} == {"image", "emoji", "link"} def test_app_desc_or_prompt_with_description(self): """Test desc_or_prompt property when description exists.""" diff --git a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py index 303f0493bd..a0fed1aa14 100644 --- a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py +++ b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from qcloud_cos import CosConfig @@ -18,3 +18,72 @@ class TestTencentCos(BaseStorageTest): with patch.object(CosConfig, "__init__", return_value=None): self.storage = TencentCosStorage() self.storage.bucket_name = get_example_bucket() + + +class TestTencentCosConfiguration: + """Tests for TencentCosStorage initialization with different configurations.""" + + def test_init_with_custom_domain(self): + """Test initialization with custom domain configured.""" + # Mock dify_config to return custom domain configuration + mock_dify_config = MagicMock() + mock_dify_config.TENCENT_COS_CUSTOM_DOMAIN = "cos.example.com" + mock_dify_config.TENCENT_COS_SECRET_ID = "test-secret-id" + mock_dify_config.TENCENT_COS_SECRET_KEY = "test-secret-key" + mock_dify_config.TENCENT_COS_SCHEME = "https" + + # Mock CosConfig and CosS3Client + mock_config_instance = MagicMock() + mock_client = MagicMock() + + with ( + 
patch("extensions.storage.tencent_cos_storage.dify_config", mock_dify_config), + patch( + "extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance + ) as mock_cos_config, + patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client), + ): + TencentCosStorage() + + # Verify CosConfig was called with Domain parameter (not Region) + mock_cos_config.assert_called_once() + call_kwargs = mock_cos_config.call_args[1] + assert "Domain" in call_kwargs + assert call_kwargs["Domain"] == "cos.example.com" + assert "Region" not in call_kwargs + assert call_kwargs["SecretId"] == "test-secret-id" + assert call_kwargs["SecretKey"] == "test-secret-key" + assert call_kwargs["Scheme"] == "https" + + def test_init_with_region(self): + """Test initialization with region configured (no custom domain).""" + # Mock dify_config to return region configuration + mock_dify_config = MagicMock() + mock_dify_config.TENCENT_COS_CUSTOM_DOMAIN = None + mock_dify_config.TENCENT_COS_REGION = "ap-guangzhou" + mock_dify_config.TENCENT_COS_SECRET_ID = "test-secret-id" + mock_dify_config.TENCENT_COS_SECRET_KEY = "test-secret-key" + mock_dify_config.TENCENT_COS_SCHEME = "https" + + # Mock CosConfig and CosS3Client + mock_config_instance = MagicMock() + mock_client = MagicMock() + + with ( + patch("extensions.storage.tencent_cos_storage.dify_config", mock_dify_config), + patch( + "extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance + ) as mock_cos_config, + patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client), + ): + TencentCosStorage() + + # Verify CosConfig was called with Region parameter (not Domain) + mock_cos_config.assert_called_once() + call_kwargs = mock_cos_config.call_args[1] + assert "Region" in call_kwargs + assert call_kwargs["Region"] == "ap-guangzhou" + assert "Domain" not in call_kwargs + assert call_kwargs["SecretId"] == "test-secret-id" + assert 
call_kwargs["SecretKey"] == "test-secret-key" + assert call_kwargs["Scheme"] == "https" diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index e9b99e7f53..710b817548 100644 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -4,6 +4,7 @@ from datetime import UTC, datetime from unittest.mock import Mock, patch import pytest +from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Session, sessionmaker from core.workflow.entities.pause_reason import HumanInputRequired, PauseReasonType @@ -109,6 +110,42 @@ class TestDifyAPISQLAlchemyWorkflowRunRepository: return pause +class TestGetRunsBatchByTimeRange(TestDifyAPISQLAlchemyWorkflowRunRepository): + def test_get_runs_batch_by_time_range_filters_terminal_statuses( + self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock + ): + scalar_result = Mock() + scalar_result.all.return_value = [] + mock_session.scalars.return_value = scalar_result + + repository.get_runs_batch_by_time_range( + start_from=None, + end_before=datetime(2024, 1, 1), + last_seen=None, + batch_size=50, + ) + + stmt = mock_session.scalars.call_args[0][0] + compiled_sql = str( + stmt.compile( + dialect=postgresql.dialect(), + compile_kwargs={"literal_binds": True}, + ) + ) + + assert "workflow_runs.status" in compiled_sql + for status in ( + WorkflowExecutionStatus.SUCCEEDED, + WorkflowExecutionStatus.FAILED, + WorkflowExecutionStatus.STOPPED, + WorkflowExecutionStatus.PARTIAL_SUCCEEDED, + ): + assert f"'{status.value}'" in compiled_sql + + assert "'running'" not in compiled_sql + assert "'paused'" not in compiled_sql + + class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): """Test create_workflow_pause method.""" @@ -186,6 +223,61 @@ class 
TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): ) +class TestDeleteRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository): + def test_uses_trigger_log_repository(self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock): + node_ids_result = Mock() + node_ids_result.all.return_value = [] + pause_ids_result = Mock() + pause_ids_result.all.return_value = [] + mock_session.scalars.side_effect = [node_ids_result, pause_ids_result] + + # app_logs delete, runs delete + mock_session.execute.side_effect = [Mock(rowcount=0), Mock(rowcount=1)] + + fake_trigger_repo = Mock() + fake_trigger_repo.delete_by_run_ids.return_value = 3 + + run = Mock(id="run-1", tenant_id="t1", app_id="a1", workflow_id="w1", triggered_from="tf") + counts = repository.delete_runs_with_related( + [run], + delete_node_executions=lambda session, runs: (2, 1), + delete_trigger_logs=lambda session, run_ids: fake_trigger_repo.delete_by_run_ids(run_ids), + ) + + fake_trigger_repo.delete_by_run_ids.assert_called_once_with(["run-1"]) + assert counts["node_executions"] == 2 + assert counts["offloads"] == 1 + assert counts["trigger_logs"] == 3 + assert counts["runs"] == 1 + + +class TestCountRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository): + def test_uses_trigger_log_repository(self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock): + pause_ids_result = Mock() + pause_ids_result.all.return_value = ["pause-1", "pause-2"] + mock_session.scalars.return_value = pause_ids_result + mock_session.scalar.side_effect = [5, 2] + + fake_trigger_repo = Mock() + fake_trigger_repo.count_by_run_ids.return_value = 3 + + run = Mock(id="run-1", tenant_id="t1", app_id="a1", workflow_id="w1", triggered_from="tf") + counts = repository.count_runs_with_related( + [run], + count_node_executions=lambda session, runs: (2, 1), + count_trigger_logs=lambda session, run_ids: fake_trigger_repo.count_by_run_ids(run_ids), + ) + + 
fake_trigger_repo.count_by_run_ids.assert_called_once_with(["run-1"]) + assert counts["node_executions"] == 2 + assert counts["offloads"] == 1 + assert counts["trigger_logs"] == 3 + assert counts["app_logs"] == 5 + assert counts["pauses"] == 2 + assert counts["pause_reasons"] == 2 + assert counts["runs"] == 1 + + class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): """Test resume_workflow_pause method.""" diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_workflow_trigger_log_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_workflow_trigger_log_repository.py new file mode 100644 index 0000000000..d409618211 --- /dev/null +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_workflow_trigger_log_repository.py @@ -0,0 +1,31 @@ +from unittest.mock import Mock + +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import Session + +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository + + +def test_delete_by_run_ids_executes_delete(): + session = Mock(spec=Session) + session.execute.return_value = Mock(rowcount=2) + repo = SQLAlchemyWorkflowTriggerLogRepository(session) + + deleted = repo.delete_by_run_ids(["run-1", "run-2"]) + + stmt = session.execute.call_args[0][0] + compiled_sql = str(stmt.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True})) + assert "workflow_trigger_logs" in compiled_sql + assert "'run-1'" in compiled_sql + assert "'run-2'" in compiled_sql + assert deleted == 2 + + +def test_delete_by_run_ids_empty_short_circuits(): + session = Mock(spec=Session) + repo = SQLAlchemyWorkflowTriggerLogRepository(session) + + deleted = repo.delete_by_run_ids([]) + + session.execute.assert_not_called() + assert deleted == 0 diff --git a/api/tests/unit_tests/services/enterprise/__init__.py b/api/tests/unit_tests/services/enterprise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/api/tests/unit_tests/services/enterprise/test_traceparent_propagation.py b/api/tests/unit_tests/services/enterprise/test_traceparent_propagation.py new file mode 100644 index 0000000000..87c03f13a3 --- /dev/null +++ b/api/tests/unit_tests/services/enterprise/test_traceparent_propagation.py @@ -0,0 +1,59 @@ +"""Unit tests for traceparent header propagation in EnterpriseRequest. + +This test module verifies that the W3C traceparent header is properly +generated and included in HTTP requests made by EnterpriseRequest. +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from services.enterprise.base import EnterpriseRequest + + +class TestTraceparentPropagation: + """Unit tests for traceparent header propagation.""" + + @pytest.fixture + def mock_enterprise_config(self): + """Mock EnterpriseRequest configuration.""" + with ( + patch.object(EnterpriseRequest, "base_url", "https://enterprise-api.example.com"), + patch.object(EnterpriseRequest, "secret_key", "test-secret-key"), + patch.object(EnterpriseRequest, "secret_key_header", "Enterprise-Api-Secret-Key"), + ): + yield + + @pytest.fixture + def mock_httpx_client(self): + """Mock httpx.Client for testing.""" + with patch("services.enterprise.base.httpx.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value.__enter__.return_value = mock_client + mock_client_class.return_value.__exit__.return_value = None + + # Setup default response + mock_response = MagicMock() + mock_response.json.return_value = {"result": "success"} + mock_client.request.return_value = mock_response + + yield mock_client + + def test_traceparent_header_included_when_generated(self, mock_enterprise_config, mock_httpx_client): + """Test that traceparent header is included when successfully generated.""" + # Arrange + expected_traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01" + + with patch("services.enterprise.base.generate_traceparent_header", 
return_value=expected_traceparent): + # Act + EnterpriseRequest.send_request("GET", "/test") + + # Assert + mock_httpx_client.request.assert_called_once() + call_args = mock_httpx_client.request.call_args + headers = call_args[1]["headers"] + + assert "traceparent" in headers + assert headers["traceparent"] == expected_traceparent + assert headers["Content-Type"] == "application/json" + assert headers["Enterprise-Api-Secret-Key"] == "test-secret-key" diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 627a04bcd0..8ae20f35d8 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch import pytest from configs import dify_config -from models.account import Account +from models.account import Account, AccountStatus from services.account_service import AccountService, RegisterService, TenantService from services.errors.account import ( AccountAlreadyInTenantError, @@ -619,8 +619,13 @@ class TestTenantService: mock_tenant_instance.name = "Test User's Workspace" mock_tenant_class.return_value = mock_tenant_instance - # Execute test - TenantService.create_owner_tenant_if_not_exist(mock_account) + # Mock the db import in CreditPoolService to avoid database connection + with patch("services.credit_pool_service.db") as mock_credit_pool_db: + mock_credit_pool_db.session.add = MagicMock() + mock_credit_pool_db.session.commit = MagicMock() + + # Execute test + TenantService.create_owner_tenant_if_not_exist(mock_account) # Verify tenant was created with correct parameters mock_db_dependencies["db"].session.add.assert_called() @@ -1142,9 +1147,13 @@ class TestRegisterService: mock_session = MagicMock() mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account - with patch("services.account_service.Session") as mock_session_class: + 
with ( + patch("services.account_service.Session") as mock_session_class, + patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, + ): mock_session_class.return_value.__enter__.return_value = mock_session mock_session_class.return_value.__exit__.return_value = None + mock_lookup.return_value = None # Mock RegisterService.register mock_new_account = TestAccountAssociatedDataFactory.create_account_mock( @@ -1177,9 +1186,59 @@ class TestRegisterService: email="newuser@example.com", name="newuser", language="en-US", - status="pending", + status=AccountStatus.PENDING, is_setup=True, ) + mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session) + + def test_invite_new_member_normalizes_new_account_email( + self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies + ): + """Ensure inviting with a mixed-case email normalizes it before registering.""" + mock_tenant = MagicMock() + mock_tenant.id = "tenant-456" + mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter") + mixed_email = "Invitee@Example.com" + + mock_session = MagicMock() + with ( + patch("services.account_service.Session") as mock_session_class, + patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, + ): + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_class.return_value.__exit__.return_value = None + mock_lookup.return_value = None + + mock_new_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="new-user-789", email="invitee@example.com", name="invitee", status="pending" + ) + with patch("services.account_service.RegisterService.register") as mock_register: + mock_register.return_value = mock_new_account + with ( + patch("services.account_service.TenantService.check_member_permission") as mock_check_permission, 
patch("services.account_service.TenantService.create_tenant_member") as mock_create_member, + patch("services.account_service.TenantService.switch_tenant") as mock_switch_tenant, + patch("services.account_service.RegisterService.generate_invite_token") as mock_generate_token, + ): + mock_generate_token.return_value = "invite-token-abc" + + RegisterService.invite_new_member( + tenant=mock_tenant, + email=mixed_email, + language="en-US", + role="normal", + inviter=mock_inviter, + ) + + mock_register.assert_called_once_with( + email="invitee@example.com", + name="invitee", + language="en-US", + status=AccountStatus.PENDING, + is_setup=True, + ) + mock_lookup.assert_called_once_with(mixed_email, session=mock_session) + mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add") mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal") mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id) mock_generate_token.assert_called_once_with(mock_tenant, mock_new_account) @@ -1202,9 +1261,13 @@ class TestRegisterService: mock_session = MagicMock() mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account - with patch("services.account_service.Session") as mock_session_class: + with ( + patch("services.account_service.Session") as mock_session_class, + patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, + ): mock_session_class.return_value.__enter__.return_value = mock_session mock_session_class.return_value.__exit__.return_value = None + mock_lookup.return_value = mock_existing_account # Mock the db.session.query for TenantAccountJoin mock_db_query = MagicMock() @@ -1233,6 +1296,7 @@ class TestRegisterService: mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal") mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account) 
mock_task_dependencies.delay.assert_called_once() + mock_lookup.assert_called_once_with("existing@example.com", session=mock_session) def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies): """Test inviting a member who is already in the tenant.""" @@ -1246,7 +1310,6 @@ class TestRegisterService: # Mock database queries query_results = { - ("Account", "email", "existing@example.com"): mock_existing_account, ( "TenantAccountJoin", "tenant_id", @@ -1256,7 +1319,11 @@ class TestRegisterService: ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) # Mock TenantService methods - with patch("services.account_service.TenantService.check_member_permission") as mock_check_permission: + with ( + patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, + patch("services.account_service.TenantService.check_member_permission") as mock_check_permission, + ): + mock_lookup.return_value = mock_existing_account # Execute test and verify exception self._assert_exception_raised( AccountAlreadyInTenantError, @@ -1267,6 +1334,7 @@ class TestRegisterService: role="normal", inviter=mock_inviter, ) + mock_lookup.assert_called_once() def test_invite_new_member_no_inviter(self): """Test inviting a member without providing an inviter.""" @@ -1492,6 +1560,30 @@ class TestRegisterService: # Verify results assert result is None + def test_get_invitation_with_case_fallback_returns_initial_match(self): + """Fallback helper should return the initial invitation when present.""" + invitation = {"workspace_id": "tenant-456"} + with patch( + "services.account_service.RegisterService.get_invitation_if_token_valid", return_value=invitation + ) as mock_get: + result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123") + + assert result == invitation + mock_get.assert_called_once_with("tenant-456", "User@Test.com", "token-123") + + 
def test_get_invitation_with_case_fallback_retries_with_lowercase(self): + """Fallback helper should retry with lowercase email when needed.""" + invitation = {"workspace_id": "tenant-456"} + with patch("services.account_service.RegisterService.get_invitation_if_token_valid") as mock_get: + mock_get.side_effect = [None, invitation] + result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123") + + assert result == invitation + assert mock_get.call_args_list == [ + (("tenant-456", "User@Test.com", "token-123"),), + (("tenant-456", "user@test.com", "token-123"),), + ] + # ==================== Helper Method Tests ==================== def test_get_invitation_token_key(self): diff --git a/api/tests/unit_tests/services/test_archive_workflow_run_logs.py b/api/tests/unit_tests/services/test_archive_workflow_run_logs.py new file mode 100644 index 0000000000..ef62dacd6b --- /dev/null +++ b/api/tests/unit_tests/services/test_archive_workflow_run_logs.py @@ -0,0 +1,54 @@ +""" +Unit tests for workflow run archiving functionality. 
+ +This module contains tests for: +- Archive service +- Rollback service +""" + +from datetime import datetime +from unittest.mock import MagicMock, patch + +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME + + +class TestWorkflowRunArchiver: + """Tests for the WorkflowRunArchiver class.""" + + @patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") + @patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage") + def test_archiver_initialization(self, mock_get_storage, mock_config): + """Test archiver can be initialized with various options.""" + from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver + + mock_config.BILLING_ENABLED = False + + archiver = WorkflowRunArchiver( + days=90, + batch_size=100, + tenant_ids=["test-tenant"], + limit=50, + dry_run=True, + ) + + assert archiver.days == 90 + assert archiver.batch_size == 100 + assert archiver.tenant_ids == ["test-tenant"] + assert archiver.limit == 50 + assert archiver.dry_run is True + + def test_get_archive_key(self): + """Test archive key generation.""" + from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver + + archiver = WorkflowRunArchiver.__new__(WorkflowRunArchiver) + + mock_run = MagicMock() + mock_run.tenant_id = "tenant-123" + mock_run.app_id = "app-999" + mock_run.id = "run-456" + mock_run.created_at = datetime(2024, 1, 15, 12, 0, 0) + + key = archiver._get_archive_key(mock_run) + + assert key == f"tenant-123/app_id=app-999/year=2024/month=01/workflow_run_id=run-456/{ARCHIVE_BUNDLE_NAME}" diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index f50f744a75..eecb3c7672 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -171,22 +171,26 @@ class TestBillingServiceSendRequest: 
"status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND] ) def test_delete_request_non_200_with_valid_json(self, mock_httpx_request, mock_billing_config, status_code): - """Test DELETE request with non-200 status code but valid JSON response. + """Test DELETE request with non-200 status code raises ValueError. - DELETE doesn't check status code, so it returns the error JSON. + DELETE now checks status code and raises ValueError for non-200 responses. """ # Arrange error_response = {"detail": "Error message"} mock_response = MagicMock() mock_response.status_code = status_code + mock_response.text = "Error message" mock_response.json.return_value = error_response mock_httpx_request.return_value = mock_response - # Act - result = BillingService._send_request("DELETE", "/test", json={"key": "value"}) - - # Assert - assert result == error_response + # Act & Assert + with patch("services.billing_service.logger") as mock_logger: + with pytest.raises(ValueError) as exc_info: + BillingService._send_request("DELETE", "/test", json={"key": "value"}) + assert "Unable to process delete request" in str(exc_info.value) + # Verify error logging + mock_logger.error.assert_called_once() + assert "DELETE response" in str(mock_logger.error.call_args) @pytest.mark.parametrize( "status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND] @@ -210,9 +214,9 @@ class TestBillingServiceSendRequest: "status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND] ) def test_delete_request_non_200_with_invalid_json(self, mock_httpx_request, mock_billing_config, status_code): - """Test DELETE request with non-200 status code and invalid JSON response raises exception. + """Test DELETE request with non-200 status code raises ValueError before JSON parsing. 
- DELETE doesn't check status code, so it calls response.json() which raises JSONDecodeError + DELETE now checks status code before calling response.json(), so a ValueError is raised even when the response cannot be parsed as JSON (e.g., an empty response). """ + # Arrange @@ -223,8 +227,13 @@ class TestBillingServiceSendRequest: mock_httpx_request.return_value = mock_response # Act & Assert - with pytest.raises(json.JSONDecodeError): - BillingService._send_request("DELETE", "/test", json={"key": "value"}) + with patch("services.billing_service.logger") as mock_logger: + with pytest.raises(ValueError) as exc_info: + BillingService._send_request("DELETE", "/test", json={"key": "value"}) + assert "Unable to process delete request" in str(exc_info.value) + # Verify error logging + mock_logger.error.assert_called_once() + assert "DELETE response" in str(mock_logger.error.call_args) def test_retry_on_request_error(self, mock_httpx_request, mock_billing_config): """Test that _send_request retries on httpx.RequestError.""" @@ -789,7 +798,7 @@ class TestBillingServiceAccountManagement: # Assert assert result == expected_response - mock_send_request.assert_called_once_with("DELETE", "/account/", params={"account_id": account_id}) + mock_send_request.assert_called_once_with("DELETE", "/account", params={"account_id": account_id}) def test_is_email_in_freeze_true(self, mock_send_request): """Test checking if email is frozen (returns True).""" @@ -1294,6 +1303,42 @@ class TestBillingServiceSubscriptionOperations: # Assert assert result == {} + def test_get_plan_bulk_with_invalid_tenant_plan_skipped(self, mock_send_request): + """Test bulk plan retrieval when one tenant has invalid plan data (should skip that tenant).""" + # Arrange + tenant_ids = ["tenant-valid-1", "tenant-invalid", "tenant-valid-2"] + + # Response with one invalid tenant plan (missing expiration_date) and two valid ones + mock_send_request.return_value = { + "data": { + "tenant-valid-1": {"plan": "sandbox", 
"expiration_date": 1735689600}, + "tenant-invalid": {"plan": "professional"}, # Missing expiration_date field + "tenant-valid-2": {"plan": "team", "expiration_date": 1767225600}, + } + } + + # Act + with patch("services.billing_service.logger") as mock_logger: + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert - should only contain valid tenants + assert len(result) == 2 + assert "tenant-valid-1" in result + assert "tenant-valid-2" in result + assert "tenant-invalid" not in result + + # Verify valid tenants have correct data + assert result["tenant-valid-1"]["plan"] == "sandbox" + assert result["tenant-valid-1"]["expiration_date"] == 1735689600 + assert result["tenant-valid-2"]["plan"] == "team" + assert result["tenant-valid-2"]["expiration_date"] == 1767225600 + + # Verify exception was logged for the invalid tenant + mock_logger.exception.assert_called_once() + log_call_args = mock_logger.exception.call_args[0] + assert "get_plan_bulk: failed to validate subscription plan for tenant" in log_call_args[0] + assert "tenant-invalid" in log_call_args[1] + def test_get_expired_subscription_cleanup_whitelist_success(self, mock_send_request): """Test successful retrieval of expired subscription cleanup whitelist.""" # Arrange diff --git a/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py new file mode 100644 index 0000000000..8c80e2b4ad --- /dev/null +++ b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py @@ -0,0 +1,327 @@ +import datetime +from typing import Any + +import pytest + +from services.billing_service import SubscriptionPlan +from services.retention.workflow_run import clear_free_plan_expired_workflow_run_logs as cleanup_module +from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup + + +class FakeRun: + def __init__( + self, + run_id: str, + tenant_id: str, 
+ created_at: datetime.datetime, + app_id: str = "app-1", + workflow_id: str = "wf-1", + triggered_from: str = "workflow-run", + ) -> None: + self.id = run_id + self.tenant_id = tenant_id + self.app_id = app_id + self.workflow_id = workflow_id + self.triggered_from = triggered_from + self.created_at = created_at + + +class FakeRepo: + def __init__( + self, + batches: list[list[FakeRun]], + delete_result: dict[str, int] | None = None, + count_result: dict[str, int] | None = None, + ) -> None: + self.batches = batches + self.call_idx = 0 + self.deleted: list[list[str]] = [] + self.counted: list[list[str]] = [] + self.delete_result = delete_result or { + "runs": 0, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + self.count_result = count_result or { + "runs": 0, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + + def get_runs_batch_by_time_range( + self, + start_from: datetime.datetime | None, + end_before: datetime.datetime, + last_seen: tuple[datetime.datetime, str] | None, + batch_size: int, + ) -> list[FakeRun]: + if self.call_idx >= len(self.batches): + return [] + batch = self.batches[self.call_idx] + self.call_idx += 1 + return batch + + def delete_runs_with_related( + self, runs: list[FakeRun], delete_node_executions=None, delete_trigger_logs=None + ) -> dict[str, int]: + self.deleted.append([run.id for run in runs]) + result = self.delete_result.copy() + result["runs"] = len(runs) + return result + + def count_runs_with_related( + self, runs: list[FakeRun], count_node_executions=None, count_trigger_logs=None + ) -> dict[str, int]: + self.counted.append([run.id for run in runs]) + result = self.count_result.copy() + result["runs"] = len(runs) + return result + + +def plan_info(plan: str, expiration: int) -> SubscriptionPlan: + return SubscriptionPlan(plan=plan, expiration_date=expiration) + + +def create_cleanup( + 
monkeypatch: pytest.MonkeyPatch, + repo: FakeRepo, + *, + grace_period_days: int = 0, + whitelist: set[str] | None = None, + **kwargs: Any, +) -> WorkflowRunCleanup: + monkeypatch.setattr( + cleanup_module.dify_config, + "SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD", + grace_period_days, + ) + monkeypatch.setattr( + cleanup_module.WorkflowRunCleanup, + "_get_cleanup_whitelist", + lambda self: whitelist or set(), + ) + return WorkflowRunCleanup(workflow_run_repo=repo, **kwargs) + + +def test_filter_free_tenants_billing_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False) + + def fail_bulk(_: list[str]) -> dict[str, SubscriptionPlan]: + raise RuntimeError("should not call") + + monkeypatch.setattr(cleanup_module.BillingService, "get_plan_bulk_with_cache", staticmethod(fail_bulk)) + + tenants = {"t1", "t2"} + free = cleanup._filter_free_tenants(tenants) + + assert free == tenants + + +def test_filter_free_tenants_bulk_mixed(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + cleanup_module.BillingService, + "get_plan_bulk_with_cache", + staticmethod( + lambda tenant_ids: { + tenant_id: (plan_info("team", -1) if tenant_id == "t_paid" else plan_info("sandbox", -1)) + for tenant_id in tenant_ids + } + ), + ) + + free = cleanup._filter_free_tenants({"t_free", "t_paid", "t_missing"}) + + assert free == {"t_free", "t_missing"} + + +def test_filter_free_tenants_respects_grace_period(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, grace_period_days=45) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + now = datetime.datetime.now(datetime.UTC) 
+ within_grace_ts = int((now - datetime.timedelta(days=10)).timestamp()) + outside_grace_ts = int((now - datetime.timedelta(days=90)).timestamp()) + + def fake_bulk(_: list[str]) -> dict[str, SubscriptionPlan]: + return { + "recently_downgraded": plan_info("sandbox", within_grace_ts), + "long_sandbox": plan_info("sandbox", outside_grace_ts), + } + + monkeypatch.setattr(cleanup_module.BillingService, "get_plan_bulk_with_cache", staticmethod(fake_bulk)) + + free = cleanup._filter_free_tenants({"recently_downgraded", "long_sandbox"}) + + assert free == {"long_sandbox"} + + +def test_filter_free_tenants_skips_cleanup_whitelist(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup( + monkeypatch, + repo=FakeRepo([]), + days=30, + batch_size=10, + whitelist={"tenant_whitelist"}, + ) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + cleanup_module.BillingService, + "get_plan_bulk_with_cache", + staticmethod( + lambda tenant_ids: { + tenant_id: (plan_info("team", -1) if tenant_id == "t_paid" else plan_info("sandbox", -1)) + for tenant_id in tenant_ids + } + ), + ) + + tenants = {"tenant_whitelist", "tenant_regular"} + free = cleanup._filter_free_tenants(tenants) + + assert free == {"tenant_regular"} + + +def test_filter_free_tenants_bulk_failure(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + cleanup_module.BillingService, + "get_plan_bulk_with_cache", + staticmethod(lambda tenant_ids: (_ for _ in ()).throw(RuntimeError("boom"))), + ) + + free = cleanup._filter_free_tenants({"t1", "t2"}) + + assert free == set() + + +def test_run_deletes_only_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None: + cutoff = datetime.datetime.now() + repo = FakeRepo( + batches=[ + [ + FakeRun("run-free", "t_free", cutoff), + FakeRun("run-paid", 
"t_paid", cutoff), + ] + ] + ) + cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + cleanup_module.BillingService, + "get_plan_bulk_with_cache", + staticmethod( + lambda tenant_ids: { + tenant_id: (plan_info("team", -1) if tenant_id == "t_paid" else plan_info("sandbox", -1)) + for tenant_id in tenant_ids + } + ), + ) + + cleanup.run() + + assert repo.deleted == [["run-free"]] + + +def test_run_skips_when_no_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None: + cutoff = datetime.datetime.now() + repo = FakeRepo(batches=[[FakeRun("run-paid", "t_paid", cutoff)]]) + cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + cleanup_module.BillingService, + "get_plan_bulk_with_cache", + staticmethod(lambda tenant_ids: {tenant_id: plan_info("team", 1893456000) for tenant_id in tenant_ids}), + ) + + cleanup.run() + + assert repo.deleted == [] + + +def test_run_exits_on_empty_batch(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) + + cleanup.run() + + +def test_run_dry_run_skips_deletions(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None: + cutoff = datetime.datetime.now() + repo = FakeRepo( + batches=[[FakeRun("run-free", "t_free", cutoff)]], + count_result={ + "runs": 0, + "node_executions": 2, + "offloads": 1, + "app_logs": 3, + "trigger_logs": 4, + "pauses": 5, + "pause_reasons": 6, + }, + ) + cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10, dry_run=True) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False) + + cleanup.run() + + assert repo.deleted == [] + assert repo.counted == [["run-free"]] + captured = capsys.readouterr().out + assert "Dry run mode enabled" in captured + assert 
"would delete 1 runs" in captured + assert "related records" in captured + assert "node_executions 2" in captured + assert "offloads 1" in captured + assert "app_logs 3" in captured + assert "trigger_logs 4" in captured + assert "pauses 5" in captured + assert "pause_reasons 6" in captured + + +def test_between_sets_window_bounds(monkeypatch: pytest.MonkeyPatch) -> None: + start_from = datetime.datetime(2024, 5, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 6, 1, 0, 0, 0) + cleanup = create_cleanup( + monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_from=start_from, end_before=end_before + ) + + assert cleanup.window_start == start_from + assert cleanup.window_end == end_before + + +def test_between_requires_both_boundaries(monkeypatch: pytest.MonkeyPatch) -> None: + with pytest.raises(ValueError): + create_cleanup( + monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_from=datetime.datetime.now(), end_before=None + ) + with pytest.raises(ValueError): + create_cleanup( + monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_from=None, end_before=datetime.datetime.now() + ) + + +def test_between_requires_end_after_start(monkeypatch: pytest.MonkeyPatch) -> None: + start_from = datetime.datetime(2024, 6, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 5, 1, 0, 0, 0) + with pytest.raises(ValueError): + create_cleanup( + monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_from=start_from, end_before=end_before + ) diff --git a/api/tests/unit_tests/services/test_delete_archived_workflow_run.py b/api/tests/unit_tests/services/test_delete_archived_workflow_run.py new file mode 100644 index 0000000000..2c9d946ea6 --- /dev/null +++ b/api/tests/unit_tests/services/test_delete_archived_workflow_run.py @@ -0,0 +1,180 @@ +""" +Unit tests for archived workflow run deletion service. 
+""" + +from unittest.mock import MagicMock, patch + + +class TestArchivedWorkflowRunDeletion: + def test_delete_by_run_id_returns_error_when_run_missing(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + repo = MagicMock() + session = MagicMock() + session.get.return_value = None + + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = session + session_maker.return_value.__exit__.return_value = None + mock_db = MagicMock() + mock_db.engine = MagicMock() + + with ( + patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), + patch( + "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker + ), + patch.object(deleter, "_get_workflow_run_repo", return_value=repo), + ): + result = deleter.delete_by_run_id("run-1") + + assert result.success is False + assert result.error == "Workflow run run-1 not found" + repo.get_archived_run_ids.assert_not_called() + + def test_delete_by_run_id_returns_error_when_not_archived(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + repo = MagicMock() + repo.get_archived_run_ids.return_value = set() + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + + session = MagicMock() + session.get.return_value = run + + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = session + session_maker.return_value.__exit__.return_value = None + mock_db = MagicMock() + mock_db.engine = MagicMock() + + with ( + patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), + patch( + "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker + ), + patch.object(deleter, "_get_workflow_run_repo", return_value=repo), + patch.object(deleter, 
"_delete_run") as mock_delete_run, + ): + result = deleter.delete_by_run_id("run-1") + + assert result.success is False + assert result.error == "Workflow run run-1 is not archived" + mock_delete_run.assert_not_called() + + def test_delete_by_run_id_calls_delete_run(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + repo = MagicMock() + repo.get_archived_run_ids.return_value = {"run-1"} + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + + session = MagicMock() + session.get.return_value = run + + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = session + session_maker.return_value.__exit__.return_value = None + mock_db = MagicMock() + mock_db.engine = MagicMock() + + with ( + patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), + patch( + "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker + ), + patch.object(deleter, "_get_workflow_run_repo", return_value=repo), + patch.object(deleter, "_delete_run", return_value=MagicMock(success=True)) as mock_delete_run, + ): + result = deleter.delete_by_run_id("run-1") + + assert result.success is True + mock_delete_run.assert_called_once_with(run) + + def test_delete_batch_uses_repo(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + repo = MagicMock() + run1 = MagicMock() + run1.id = "run-1" + run1.tenant_id = "tenant-1" + run2 = MagicMock() + run2.id = "run-2" + run2.tenant_id = "tenant-1" + repo.get_archived_runs_by_time_range.return_value = [run1, run2] + + session = MagicMock() + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = session + session_maker.return_value.__exit__.return_value = None + start_date = MagicMock() + end_date = MagicMock() + 
mock_db = MagicMock() + mock_db.engine = MagicMock() + + with ( + patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), + patch( + "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker + ), + patch.object(deleter, "_get_workflow_run_repo", return_value=repo), + patch.object( + deleter, "_delete_run", side_effect=[MagicMock(success=True), MagicMock(success=True)] + ) as mock_delete_run, + ): + results = deleter.delete_batch( + tenant_ids=["tenant-1"], + start_date=start_date, + end_date=end_date, + limit=2, + ) + + assert len(results) == 2 + repo.get_archived_runs_by_time_range.assert_called_once_with( + session=session, + tenant_ids=["tenant-1"], + start_date=start_date, + end_date=end_date, + limit=2, + ) + assert mock_delete_run.call_count == 2 + + def test_delete_run_dry_run(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion(dry_run=True) + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + + with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo: + result = deleter._delete_run(run) + + assert result.success is True + mock_get_repo.assert_not_called() + + def test_delete_run_calls_repo(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + + repo = MagicMock() + repo.delete_runs_with_related.return_value = {"runs": 1} + + with patch.object(deleter, "_get_workflow_run_repo", return_value=repo): + result = deleter._delete_run(run) + + assert result.success is True + assert result.deleted_counts == {"runs": 1} + repo.delete_runs_with_related.assert_called_once() diff --git a/api/tests/unit_tests/services/test_file_service_zip_and_lookup.py 
b/api/tests/unit_tests/services/test_file_service_zip_and_lookup.py new file mode 100644 index 0000000000..7b4d349e33 --- /dev/null +++ b/api/tests/unit_tests/services/test_file_service_zip_and_lookup.py @@ -0,0 +1,99 @@ +""" +Unit tests for `services.file_service.FileService` helpers. + +We keep these tests focused on: +- ZIP tempfile building (sanitization + deduplication + content writes) +- tenant-scoped batch lookup behavior (`get_upload_files_by_ids`) +""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any +from zipfile import ZipFile + +import pytest + +import services.file_service as file_service_module +from services.file_service import FileService + + +def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch: pytest.MonkeyPatch) -> None: + """Ensure ZIP entry names are safe and unique while preserving extensions.""" + + # Arrange: three upload files that all sanitize down to the same basename ("b.txt"). + upload_files: list[Any] = [ + SimpleNamespace(name="a/b.txt", key="k1"), + SimpleNamespace(name="c/b.txt", key="k2"), + SimpleNamespace(name="../b.txt", key="k3"), + ] + + # Stream distinct bytes per key so we can verify content is written to the right entry. + data_by_key: dict[str, list[bytes]] = {"k1": [b"one"], "k2": [b"two"], "k3": [b"three"]} + + def _load(key: str, stream: bool = True) -> list[bytes]: + # Return the corresponding chunks for this key (the production code iterates chunks). + assert stream is True + return data_by_key[key] + + monkeypatch.setattr(file_service_module.storage, "load", _load) + + # Act: build zip in a tempfile. + with FileService.build_upload_files_zip_tempfile(upload_files=upload_files) as tmp: + with ZipFile(tmp, mode="r") as zf: + # Assert: names are sanitized (no directory components) and deduped with suffixes. + assert zf.namelist() == ["b.txt", "b (1).txt", "b (2).txt"] + + # Assert: each entry contains the correct bytes from storage. 
+ assert zf.read("b.txt") == b"one" + assert zf.read("b (1).txt") == b"two" + assert zf.read("b (2).txt") == b"three" + + +def test_get_upload_files_by_ids_returns_empty_when_no_ids(monkeypatch: pytest.MonkeyPatch) -> None: + """Ensure empty input returns an empty mapping without hitting the database.""" + + class _Session: + def scalars(self, _stmt): # type: ignore[no-untyped-def] + raise AssertionError("db.session.scalars should not be called for empty id lists") + + monkeypatch.setattr(file_service_module, "db", SimpleNamespace(session=_Session())) + + assert FileService.get_upload_files_by_ids("tenant-1", []) == {} + + +def test_get_upload_files_by_ids_returns_id_keyed_mapping(monkeypatch: pytest.MonkeyPatch) -> None: + """Ensure batch lookup returns a dict keyed by stringified UploadFile ids.""" + + upload_files: list[Any] = [ + SimpleNamespace(id="file-1", tenant_id="tenant-1"), + SimpleNamespace(id="file-2", tenant_id="tenant-1"), + ] + + class _ScalarResult: + def __init__(self, items: list[Any]) -> None: + self._items = items + + def all(self) -> list[Any]: + return self._items + + class _Session: + def __init__(self, items: list[Any]) -> None: + self._items = items + self.calls: list[object] = [] + + def scalars(self, stmt): # type: ignore[no-untyped-def] + # Capture the statement so we can at least assert the query path is taken. + self.calls.append(stmt) + return _ScalarResult(self._items) + + session = _Session(upload_files) + monkeypatch.setattr(file_service_module, "db", SimpleNamespace(session=session)) + + # Provide duplicates to ensure callers can safely pass repeated ids. 
+ result = FileService.get_upload_files_by_ids("tenant-1", ["file-1", "file-1", "file-2"]) + + assert set(result.keys()) == {"file-1", "file-2"} + assert result["file-1"].id == "file-1" + assert result["file-2"].id == "file-2" + assert len(session.calls) == 1 diff --git a/api/tests/unit_tests/services/test_messages_clean_service.py b/api/tests/unit_tests/services/test_messages_clean_service.py new file mode 100644 index 0000000000..3b619195c7 --- /dev/null +++ b/api/tests/unit_tests/services/test_messages_clean_service.py @@ -0,0 +1,627 @@ +import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from enums.cloud_plan import CloudPlan +from services.retention.conversation.messages_clean_policy import ( + BillingDisabledPolicy, + BillingSandboxPolicy, + SimpleMessage, + create_message_clean_policy, +) +from services.retention.conversation.messages_clean_service import MessagesCleanService + + +def make_simple_message(msg_id: str, app_id: str) -> SimpleMessage: + """Helper to create a SimpleMessage with a fixed created_at timestamp.""" + return SimpleMessage(id=msg_id, app_id=app_id, created_at=datetime.datetime(2024, 1, 1)) + + +def make_plan_provider(tenant_plans: dict) -> MagicMock: + """Helper to create a mock plan_provider that returns the given tenant_plans.""" + provider = MagicMock() + provider.return_value = tenant_plans + return provider + + +class TestBillingSandboxPolicyFilterMessageIds: + """Unit tests for BillingSandboxPolicy.filter_message_ids method.""" + + # Fixed timestamp for deterministic tests + CURRENT_TIMESTAMP = 1000000 + GRACEFUL_PERIOD_DAYS = 8 + GRACEFUL_PERIOD_SECONDS = GRACEFUL_PERIOD_DAYS * 24 * 60 * 60 + + def test_missing_tenant_mapping_excluded(self): + """Test that messages with missing app-to-tenant mapping are excluded.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant = {} # No mapping + tenant_plans = {"tenant1": {"plan": 
CloudPlan.SANDBOX, "expiration_date": -1}} + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert + assert list(result) == [] + + def test_missing_tenant_plan_excluded(self): + """Test that messages with missing tenant plan are excluded (safe default).""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2"} + tenant_plans = {} # No plans + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert + assert list(result) == [] + + def test_non_sandbox_plan_excluded(self): + """Test that messages from non-sandbox plans (PROFESSIONAL/TEAM) are excluded.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + make_simple_message("msg3", "app3"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.PROFESSIONAL, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.TEAM, "expiration_date": -1}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, # Only this one + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - only msg3 (sandbox tenant) should be included + assert set(result) == {"msg3"} + 
+ def test_whitelist_skip(self): + """Test that whitelisted tenants are excluded even if sandbox + expired.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), # Whitelisted - excluded + make_simple_message("msg2", "app2"), # Not whitelisted - included + make_simple_message("msg3", "app3"), # Whitelisted - excluded + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + } + plan_provider = make_plan_provider(tenant_plans) + tenant_whitelist = ["tenant1", "tenant3"] + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + tenant_whitelist=tenant_whitelist, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - only msg2 should be included + assert set(result) == {"msg2"} + + def test_no_previous_subscription_included(self): + """Test that messages with expiration_date=-1 (no previous subscription) are included.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - all messages should be included + assert set(result) == {"msg1", "msg2"} + + def test_within_grace_period_excluded(self): + """Test that messages within 
grace period are excluded.""" + # Arrange + now = self.CURRENT_TIMESTAMP + expired_1_day_ago = now - (1 * 24 * 60 * 60) + expired_5_days_ago = now - (5 * 24 * 60 * 60) + expired_7_days_ago = now - (7 * 24 * 60 * 60) + + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + make_simple_message("msg3", "app3"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_1_day_ago}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_5_days_ago}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_7_days_ago}, + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, # 8 days + current_timestamp=now, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - all within 8-day grace period, none should be included + assert list(result) == [] + + def test_exactly_at_boundary_excluded(self): + """Test that messages exactly at grace period boundary are excluded (code uses >).""" + # Arrange + now = self.CURRENT_TIMESTAMP + expired_exactly_8_days_ago = now - self.GRACEFUL_PERIOD_SECONDS # Exactly at boundary + + messages = [make_simple_message("msg1", "app1")] + app_to_tenant = {"app1": "tenant1"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_exactly_8_days_ago}, + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=now, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - exactly at boundary (==) should be excluded (code uses >) + assert list(result) == [] + + def test_beyond_grace_period_included(self): + """Test that messages beyond grace 
period are included.""" + # Arrange + now = self.CURRENT_TIMESTAMP + expired_9_days_ago = now - (9 * 24 * 60 * 60) # Just beyond 8-day grace + expired_30_days_ago = now - (30 * 24 * 60 * 60) # Well beyond + + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_9_days_ago}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_30_days_ago}, + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=now, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - both beyond grace period, should be included + assert set(result) == {"msg1", "msg2"} + + def test_empty_messages_returns_empty(self): + """Test that empty messages returns empty list.""" + # Arrange + messages: list[SimpleMessage] = [] + app_to_tenant = {"app1": "tenant1"} + plan_provider = make_plan_provider({"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}}) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert + assert list(result) == [] + + def test_plan_provider_called_with_correct_tenant_ids(self): + """Test that plan_provider is called with correct tenant_ids.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + make_simple_message("msg3", "app3"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant1"} # tenant1 appears twice + plan_provider = make_plan_provider({}) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + 
graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + policy.filter_message_ids(messages, app_to_tenant) + + # Assert - plan_provider should be called once with unique tenant_ids + plan_provider.assert_called_once() + called_tenant_ids = set(plan_provider.call_args[0][0]) + assert called_tenant_ids == {"tenant1", "tenant2"} + + def test_complex_mixed_scenario(self): + """Test complex scenario with mixed plans, expirations, whitelist, and missing mappings.""" + # Arrange + now = self.CURRENT_TIMESTAMP + sandbox_expired_old = now - (15 * 24 * 60 * 60) # Beyond grace + sandbox_expired_recent = now - (3 * 24 * 60 * 60) # Within grace + future_expiration = now + (30 * 24 * 60 * 60) + + messages = [ + make_simple_message("msg1", "app1"), # Sandbox, no subscription - included + make_simple_message("msg2", "app2"), # Sandbox, expired old - included + make_simple_message("msg3", "app3"), # Sandbox, within grace - excluded + make_simple_message("msg4", "app4"), # Team plan, active - excluded + make_simple_message("msg5", "app5"), # No tenant mapping - excluded + make_simple_message("msg6", "app6"), # No plan info - excluded + make_simple_message("msg7", "app7"), # Sandbox, expired old, whitelisted - excluded + ] + app_to_tenant = { + "app1": "tenant1", + "app2": "tenant2", + "app3": "tenant3", + "app4": "tenant4", + "app6": "tenant6", # Has mapping but no plan + "app7": "tenant7", + # app5 has no mapping + } + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_recent}, + "tenant4": {"plan": CloudPlan.TEAM, "expiration_date": future_expiration}, + "tenant7": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old}, + # tenant6 has no plan + } + plan_provider = make_plan_provider(tenant_plans) + tenant_whitelist = ["tenant7"] + 
+ policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + tenant_whitelist=tenant_whitelist, + current_timestamp=now, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - only msg1 and msg2 should be included + assert set(result) == {"msg1", "msg2"} + + +class TestBillingDisabledPolicyFilterMessageIds: + """Unit tests for BillingDisabledPolicy.filter_message_ids method.""" + + def test_returns_all_message_ids(self): + """Test that all message IDs are returned (order-preserving).""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + make_simple_message("msg3", "app3"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2"} + + policy = BillingDisabledPolicy() + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - all message IDs returned in order + assert list(result) == ["msg1", "msg2", "msg3"] + + def test_ignores_app_to_tenant(self): + """Test that app_to_tenant mapping is ignored.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant: dict[str, str] = {} # Empty - should be ignored + + policy = BillingDisabledPolicy() + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - all message IDs still returned + assert list(result) == ["msg1", "msg2"] + + def test_empty_messages_returns_empty(self): + """Test that empty messages returns empty list.""" + # Arrange + messages: list[SimpleMessage] = [] + app_to_tenant = {"app1": "tenant1"} + + policy = BillingDisabledPolicy() + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert + assert list(result) == [] + + +class TestCreateMessageCleanPolicy: + """Unit tests for create_message_clean_policy factory function.""" + + @patch("services.retention.conversation.messages_clean_policy.dify_config") + def 
test_billing_disabled_returns_billing_disabled_policy(self, mock_config): + """Test that BILLING_ENABLED=False returns BillingDisabledPolicy.""" + # Arrange + mock_config.BILLING_ENABLED = False + + # Act + policy = create_message_clean_policy(graceful_period_days=21) + + # Assert + assert isinstance(policy, BillingDisabledPolicy) + + @patch("services.retention.conversation.messages_clean_policy.BillingService") + @patch("services.retention.conversation.messages_clean_policy.dify_config") + def test_billing_enabled_policy_has_correct_internals(self, mock_config, mock_billing_service): + """Test that BillingSandboxPolicy is created with correct internal values.""" + # Arrange + mock_config.BILLING_ENABLED = True + whitelist = ["tenant1", "tenant2"] + mock_billing_service.get_expired_subscription_cleanup_whitelist.return_value = whitelist + mock_plan_provider = MagicMock() + mock_billing_service.get_plan_bulk_with_cache = mock_plan_provider + + # Act + policy = create_message_clean_policy(graceful_period_days=14, current_timestamp=1234567) + + # Assert + mock_billing_service.get_expired_subscription_cleanup_whitelist.assert_called_once() + assert isinstance(policy, BillingSandboxPolicy) + assert policy._graceful_period_days == 14 + assert list(policy._tenant_whitelist) == whitelist + assert policy._plan_provider == mock_plan_provider + assert policy._current_timestamp == 1234567 + + +class TestMessagesCleanServiceFromTimeRange: + """Unit tests for MessagesCleanService.from_time_range factory method.""" + + def test_start_from_end_before_raises_value_error(self): + """Test that start_from == end_before raises ValueError.""" + policy = BillingDisabledPolicy() + + # Arrange + same_time = datetime.datetime(2024, 1, 1, 12, 0, 0) + + # Act & Assert + with pytest.raises(ValueError, match="start_from .* must be less than end_before"): + MessagesCleanService.from_time_range( + policy=policy, + start_from=same_time, + end_before=same_time, + ) + + # Arrange + start_from = 
datetime.datetime(2024, 12, 31) + end_before = datetime.datetime(2024, 1, 1) + + # Act & Assert + with pytest.raises(ValueError, match="start_from .* must be less than end_before"): + MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + ) + + def test_batch_size_raises_value_error(self): + """Test that batch_size=0 raises ValueError.""" + # Arrange + start_from = datetime.datetime(2024, 1, 1) + end_before = datetime.datetime(2024, 2, 1) + policy = BillingDisabledPolicy() + + # Act & Assert + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=0, + ) + + start_from = datetime.datetime(2024, 1, 1) + end_before = datetime.datetime(2024, 2, 1) + policy = BillingDisabledPolicy() + + # Act & Assert + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=-100, + ) + + def test_valid_params_creates_instance(self): + """Test that valid parameters create a correctly configured instance.""" + # Arrange + start_from = datetime.datetime(2024, 1, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 12, 31, 23, 59, 59) + policy = BillingDisabledPolicy() + batch_size = 500 + dry_run = True + + # Act + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=batch_size, + dry_run=dry_run, + ) + + # Assert + assert isinstance(service, MessagesCleanService) + assert service._policy is policy + assert service._start_from == start_from + assert service._end_before == end_before + assert service._batch_size == batch_size + assert service._dry_run == dry_run + + def test_default_params(self): + """Test that default parameters are applied correctly.""" + # Arrange + start_from = 
datetime.datetime(2024, 1, 1) + end_before = datetime.datetime(2024, 2, 1) + policy = BillingDisabledPolicy() + + # Act + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + ) + + # Assert + assert service._batch_size == 1000 # default + assert service._dry_run is False # default + + +class TestMessagesCleanServiceFromDays: + """Unit tests for MessagesCleanService.from_days factory method.""" + + def test_days_raises_value_error(self): + """Test that days < 0 raises ValueError.""" + # Arrange + policy = BillingDisabledPolicy() + + # Act & Assert + with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"): + MessagesCleanService.from_days(policy=policy, days=-1) + + # Act + with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime: + fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0) + mock_datetime.datetime.now.return_value = fixed_now + mock_datetime.timedelta = datetime.timedelta + + service = MessagesCleanService.from_days(policy=policy, days=0) + + # Assert + assert service._end_before == fixed_now + + def test_batch_size_raises_value_error(self): + """Test that batch_size=0 raises ValueError.""" + # Arrange + policy = BillingDisabledPolicy() + + # Act & Assert + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_days(policy=policy, days=30, batch_size=0) + + # Act & Assert + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_days(policy=policy, days=30, batch_size=-500) + + def test_valid_params_creates_instance(self): + """Test that valid parameters create a correctly configured instance.""" + # Arrange + policy = BillingDisabledPolicy() + days = 90 + batch_size = 500 + dry_run = True + + # Act + with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime: + fixed_now = 
datetime.datetime(2024, 6, 15, 10, 30, 0) + mock_datetime.datetime.now.return_value = fixed_now + mock_datetime.timedelta = datetime.timedelta + + service = MessagesCleanService.from_days( + policy=policy, + days=days, + batch_size=batch_size, + dry_run=dry_run, + ) + + # Assert + expected_end_before = fixed_now - datetime.timedelta(days=days) + assert isinstance(service, MessagesCleanService) + assert service._policy is policy + assert service._start_from is None + assert service._end_before == expected_end_before + assert service._batch_size == batch_size + assert service._dry_run == dry_run + + def test_default_params(self): + """Test that default parameters are applied correctly.""" + # Arrange + policy = BillingDisabledPolicy() + + # Act + with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime: + fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0) + mock_datetime.datetime.now.return_value = fixed_now + mock_datetime.timedelta = datetime.timedelta + + service = MessagesCleanService.from_days(policy=policy) + + # Assert + expected_end_before = fixed_now - datetime.timedelta(days=30) # default days=30 + assert service._end_before == expected_end_before + assert service._batch_size == 1000 # default + assert service._dry_run is False # default diff --git a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py index 9a107da1c7..e2360b116d 100644 --- a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py +++ b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py @@ -27,7 +27,6 @@ def service_with_fake_configurations(): description=None, icon_small=None, icon_small_dark=None, - icon_large=None, background=None, help=None, supported_model_types=[ModelType.LLM], diff --git a/api/tests/unit_tests/services/test_restore_archived_workflow_run.py 
b/api/tests/unit_tests/services/test_restore_archived_workflow_run.py new file mode 100644 index 0000000000..68aa8c0fe1 --- /dev/null +++ b/api/tests/unit_tests/services/test_restore_archived_workflow_run.py @@ -0,0 +1,65 @@ +""" +Unit tests for workflow run restore functionality. +""" + +from datetime import datetime +from unittest.mock import MagicMock + + +class TestWorkflowRunRestore: + """Tests for the WorkflowRunRestore class.""" + + def test_restore_initialization(self): + """Restore service should respect dry_run flag.""" + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + restore = WorkflowRunRestore(dry_run=True) + + assert restore.dry_run is True + + def test_convert_datetime_fields(self): + """ISO datetime strings should be converted to datetime objects.""" + from models.workflow import WorkflowRun + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + record = { + "id": "test-id", + "created_at": "2024-01-01T12:00:00", + "finished_at": "2024-01-01T12:05:00", + "name": "test", + } + + restore = WorkflowRunRestore() + result = restore._convert_datetime_fields(record, WorkflowRun) + + assert isinstance(result["created_at"], datetime) + assert result["created_at"].year == 2024 + assert result["created_at"].month == 1 + assert result["name"] == "test" + + def test_restore_table_records_returns_rowcount(self): + """Restore should return inserted rowcount.""" + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + session = MagicMock() + session.execute.return_value = MagicMock(rowcount=2) + + restore = WorkflowRunRestore() + records = [{"id": "p1", "workflow_run_id": "r1", "created_at": "2024-01-01T00:00:00"}] + + restored = restore._restore_table_records(session, "workflow_pauses", records, schema_version="1.0") + + assert restored == 2 + session.execute.assert_called_once() + + def 
test_restore_table_records_unknown_table(self): + """Unknown table names should be ignored gracefully.""" + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + session = MagicMock() + + restore = WorkflowRunRestore() + restored = restore._restore_table_records(session, "unknown_table", [{"id": "x1"}], schema_version="1.0") + + assert restored == 0 + session.execute.assert_not_called() diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py index bace66bec4..cb18d15084 100644 --- a/api/tests/unit_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -49,10 +49,14 @@ def pipeline_id(): @pytest.fixture def mock_db_session(): - """Mock database session with query capabilities.""" - with patch("tasks.clean_dataset_task.db") as mock_db: + """Mock database session via session_factory.create_session().""" + with patch("tasks.clean_dataset_task.session_factory") as mock_sf: mock_session = MagicMock() - mock_db.session = mock_session + # context manager for create_session() + cm = MagicMock() + cm.__enter__.return_value = mock_session + cm.__exit__.return_value = None + mock_sf.create_session.return_value = cm # Setup query chain mock_query = MagicMock() @@ -66,7 +70,10 @@ def mock_db_session(): # Setup execute for JOIN queries mock_session.execute.return_value.all.return_value = [] - yield mock_db + # Yield an object with a `.session` attribute to keep tests unchanged + wrapper = MagicMock() + wrapper.session = mock_session + yield wrapper @pytest.fixture @@ -227,7 +234,9 @@ class TestBasicCleanup: # Assert mock_db_session.session.delete.assert_any_call(mock_document) - mock_db_session.session.delete.assert_any_call(mock_segment) + # Segments are deleted in batch; verify a DELETE on document_segments was issued + execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list] + 
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls) mock_db_session.session.commit.assert_called_once() def test_clean_dataset_task_deletes_related_records( @@ -413,7 +422,9 @@ class TestErrorHandling: # Assert - documents and segments should still be deleted mock_db_session.session.delete.assert_any_call(mock_document) - mock_db_session.session.delete.assert_any_call(mock_segment) + # Segments are deleted in batch; verify a DELETE on document_segments was issued + execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list] + assert any("DELETE FROM document_segments" in sql for sql in execute_sqls) mock_db_session.session.commit.assert_called_once() def test_clean_dataset_task_storage_delete_failure_continues( @@ -461,7 +472,7 @@ class TestErrorHandling: [mock_segment], # segments ] mock_get_image_upload_file_ids.return_value = [image_file_id] - mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file + mock_db_session.session.query.return_value.where.return_value.all.return_value = [mock_upload_file] mock_storage.delete.side_effect = Exception("Storage service unavailable") # Act @@ -476,8 +487,9 @@ class TestErrorHandling: # Assert - storage delete was attempted for image file mock_storage.delete.assert_called_with(mock_upload_file.key) - # Image file should still be deleted from database - mock_db_session.session.delete.assert_any_call(mock_upload_file) + # Upload files are deleted in batch; verify a DELETE on upload_files was issued + execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list] + assert any("DELETE FROM upload_files" in sql for sql in execute_sqls) def test_clean_dataset_task_database_error_rollback( self, @@ -691,8 +703,10 @@ class TestSegmentAttachmentCleanup: # Assert mock_storage.delete.assert_called_with(mock_attachment_file.key) - 
mock_db_session.session.delete.assert_any_call(mock_attachment_file) - mock_db_session.session.delete.assert_any_call(mock_binding) + # Attachment file and binding are deleted in batch; verify DELETEs were issued + execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list] + assert any("DELETE FROM upload_files" in sql for sql in execute_sqls) + assert any("DELETE FROM segment_attachment_bindings" in sql for sql in execute_sqls) def test_clean_dataset_task_attachment_storage_failure( self, @@ -734,9 +748,10 @@ class TestSegmentAttachmentCleanup: # Assert - storage delete was attempted mock_storage.delete.assert_called_once() - # Records should still be deleted from database - mock_db_session.session.delete.assert_any_call(mock_attachment_file) - mock_db_session.session.delete.assert_any_call(mock_binding) + # Records are deleted in batch; verify DELETEs were issued + execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list] + assert any("DELETE FROM upload_files" in sql for sql in execute_sqls) + assert any("DELETE FROM segment_attachment_bindings" in sql for sql in execute_sqls) # ============================================================================ @@ -784,7 +799,7 @@ class TestUploadFileCleanup: [mock_document], # documents [], # segments ] - mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file + mock_db_session.session.query.return_value.where.return_value.all.return_value = [mock_upload_file] # Act clean_dataset_task( @@ -798,7 +813,9 @@ class TestUploadFileCleanup: # Assert mock_storage.delete.assert_called_with(mock_upload_file.key) - mock_db_session.session.delete.assert_any_call(mock_upload_file) + # Upload files are deleted in batch; verify a DELETE on upload_files was issued + execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list] + assert any("DELETE FROM 
upload_files" in sql for sql in execute_sqls) def test_clean_dataset_task_handles_missing_upload_file( self, @@ -832,7 +849,7 @@ class TestUploadFileCleanup: [mock_document], # documents [], # segments ] - mock_db_session.session.query.return_value.where.return_value.first.return_value = None + mock_db_session.session.query.return_value.where.return_value.all.return_value = [] # Act - should not raise exception clean_dataset_task( @@ -949,11 +966,11 @@ class TestImageFileCleanup: [mock_segment], # segments ] - # Setup a mock query chain that returns files in sequence + # Setup a mock query chain that returns files in batch (align with .in_().all()) mock_query = MagicMock() mock_where = MagicMock() mock_query.where.return_value = mock_where - mock_where.first.side_effect = mock_image_files + mock_where.all.return_value = mock_image_files mock_db_session.session.query.return_value = mock_query # Act @@ -966,10 +983,10 @@ class TestImageFileCleanup: doc_form="paragraph_index", ) - # Assert - assert mock_storage.delete.call_count == 2 - mock_storage.delete.assert_any_call("images/image-1.jpg") - mock_storage.delete.assert_any_call("images/image-2.jpg") + # Assert - each expected image key was deleted at least once + calls = [c.args[0] for c in mock_storage.delete.call_args_list] + assert "images/image-1.jpg" in calls + assert "images/image-2.jpg" in calls def test_clean_dataset_task_handles_missing_image_file( self, @@ -1010,7 +1027,7 @@ class TestImageFileCleanup: ] # Image file not found - mock_db_session.session.query.return_value.where.return_value.first.return_value = None + mock_db_session.session.query.return_value.where.return_value.all.return_value = [] # Act - should not raise exception clean_dataset_task( @@ -1086,14 +1103,15 @@ class TestEdgeCases: doc_form="paragraph_index", ) - # Assert - all documents and segments should be deleted + # Assert - all documents and segments should be deleted (documents per-entity, segments in batch) delete_calls = 
mock_db_session.session.delete.call_args_list deleted_items = [call[0][0] for call in delete_calls] for doc in mock_documents: assert doc in deleted_items - for seg in mock_segments: - assert seg in deleted_items + # Verify a batch DELETE on document_segments occurred + execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list] + assert any("DELETE FROM document_segments" in sql for sql in execute_sqls) def test_clean_dataset_task_document_with_empty_data_source_info( self, diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 9d7599b8fe..e24ef32a24 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -81,12 +81,25 @@ def mock_documents(document_ids, dataset_id): @pytest.fixture def mock_db_session(): - """Mock database session.""" - with patch("tasks.document_indexing_task.db.session") as mock_session: - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - yield mock_session + """Mock database session via session_factory.create_session().""" + with patch("tasks.document_indexing_task.session_factory") as mock_sf: + session = MagicMock() + # Ensure tests that expect session.close() to be called can observe it via the context manager + session.close = MagicMock() + cm = MagicMock() + cm.__enter__.return_value = session + # Link __exit__ to session.close so "close" expectations reflect context manager teardown + + def _exit_side_effect(*args, **kwargs): + session.close() + + cm.__exit__.side_effect = _exit_side_effect + mock_sf.create_session.return_value = cm + + query = MagicMock() + session.query.return_value = query + query.where.return_value = query + yield session @pytest.fixture diff --git a/api/tests/unit_tests/tasks/test_delete_account_task.py 
b/api/tests/unit_tests/tasks/test_delete_account_task.py index 3b148e63f2..8a12a4a169 100644 --- a/api/tests/unit_tests/tasks/test_delete_account_task.py +++ b/api/tests/unit_tests/tasks/test_delete_account_task.py @@ -18,12 +18,18 @@ from tasks.delete_account_task import delete_account_task @pytest.fixture def mock_db_session(): - """Mock the db.session used in delete_account_task.""" - with patch("tasks.delete_account_task.db.session") as mock_session: - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - yield mock_session + """Mock session via session_factory.create_session().""" + with patch("tasks.delete_account_task.session_factory") as mock_sf: + session = MagicMock() + cm = MagicMock() + cm.__enter__.return_value = session + cm.__exit__.return_value = None + mock_sf.create_session.return_value = cm + + query = MagicMock() + session.query.return_value = query + query.where.return_value = query + yield session @pytest.fixture diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py index 374abe0368..fa33034f40 100644 --- a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -109,13 +109,25 @@ def mock_document_segments(document_id): @pytest.fixture def mock_db_session(): - """Mock database session.""" - with patch("tasks.document_indexing_sync_task.db.session") as mock_session: - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_session.scalars.return_value = MagicMock() - yield mock_session + """Mock database session via session_factory.create_session().""" + with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf: + session = MagicMock() + # Ensure tests can observe session.close() via context manager teardown + session.close = 
MagicMock() + cm = MagicMock() + cm.__enter__.return_value = session + + def _exit_side_effect(*args, **kwargs): + session.close() + + cm.__exit__.side_effect = _exit_side_effect + mock_sf.create_session.return_value = cm + + query = MagicMock() + session.query.return_value = query + query.where.return_value = query + session.scalars.return_value = MagicMock() + yield session @pytest.fixture @@ -251,8 +263,8 @@ class TestDocumentIndexingSyncTask: # Assert # Document status should remain unchanged assert mock_document.indexing_status == "completed" - # No session operations should be performed beyond the initial query - mock_db_session.close.assert_not_called() + # Session should still be closed via context manager teardown + assert mock_db_session.close.called def test_successful_sync_when_page_updated( self, @@ -286,9 +298,9 @@ class TestDocumentIndexingSyncTask: mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value mock_processor.clean.assert_called_once() - # Verify segments were deleted from database - for segment in mock_document_segments: - mock_db_session.delete.assert_any_call(segment) + # Verify segments were deleted from database in batch (DELETE FROM document_segments) + execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list] + assert any("DELETE FROM document_segments" in sql for sql in execute_sqls) # Verify indexing runner was called mock_indexing_runner.run.assert_called_once_with([mock_document]) diff --git a/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py index 0be6ea045e..8a4c6da2e9 100644 --- a/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py @@ -94,13 +94,25 @@ def mock_document_segments(document_ids): @pytest.fixture def mock_db_session(): - """Mock database session.""" - with 
patch("tasks.duplicate_document_indexing_task.db.session") as mock_session: - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_session.scalars.return_value = MagicMock() - yield mock_session + """Mock database session via session_factory.create_session().""" + with patch("tasks.duplicate_document_indexing_task.session_factory") as mock_sf: + session = MagicMock() + # Allow tests to observe session.close() via context manager teardown + session.close = MagicMock() + cm = MagicMock() + cm.__enter__.return_value = session + + def _exit_side_effect(*args, **kwargs): + session.close() + + cm.__exit__.side_effect = _exit_side_effect + mock_sf.create_session.return_value = cm + + query = MagicMock() + session.query.return_value = query + query.where.return_value = query + session.scalars.return_value = MagicMock() + yield session @pytest.fixture @@ -200,8 +212,25 @@ class TestDuplicateDocumentIndexingTaskCore: ): """Test successful duplicate document indexing flow.""" # Arrange - mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents - mock_db_session.scalars.return_value.all.return_value = mock_document_segments + # Dataset via query.first() + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + # scalars() call sequence: + # 1) documents list + # 2..N) segments per document + + def _scalars_side_effect(*args, **kwargs): + m = MagicMock() + # First call returns documents; subsequent calls return segments + if not hasattr(_scalars_side_effect, "_calls"): + _scalars_side_effect._calls = 0 + if _scalars_side_effect._calls == 0: + m.all.return_value = mock_documents + else: + m.all.return_value = mock_document_segments + _scalars_side_effect._calls += 1 + return m + + mock_db_session.scalars.side_effect = _scalars_side_effect # Act _duplicate_document_indexing_task(dataset_id, document_ids) @@ -264,8 +293,21 @@ 
class TestDuplicateDocumentIndexingTaskCore: ): """Test duplicate document indexing when billing limit is exceeded.""" # Arrange - mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents - mock_db_session.scalars.return_value.all.return_value = [] # No segments to clean + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + # First scalars() -> documents; subsequent -> empty segments + + def _scalars_side_effect(*args, **kwargs): + m = MagicMock() + if not hasattr(_scalars_side_effect, "_calls"): + _scalars_side_effect._calls = 0 + if _scalars_side_effect._calls == 0: + m.all.return_value = mock_documents + else: + m.all.return_value = [] + _scalars_side_effect._calls += 1 + return m + + mock_db_session.scalars.side_effect = _scalars_side_effect mock_features = mock_feature_service.get_features.return_value mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.TEAM @@ -294,8 +336,20 @@ class TestDuplicateDocumentIndexingTaskCore: ): """Test duplicate document indexing when IndexingRunner raises an error.""" # Arrange - mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents - mock_db_session.scalars.return_value.all.return_value = [] + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + def _scalars_side_effect(*args, **kwargs): + m = MagicMock() + if not hasattr(_scalars_side_effect, "_calls"): + _scalars_side_effect._calls = 0 + if _scalars_side_effect._calls == 0: + m.all.return_value = mock_documents + else: + m.all.return_value = [] + _scalars_side_effect._calls += 1 + return m + + mock_db_session.scalars.side_effect = _scalars_side_effect mock_indexing_runner.run.side_effect = Exception("Indexing error") # Act @@ -318,8 +372,20 @@ class TestDuplicateDocumentIndexingTaskCore: ): """Test duplicate document indexing when document is paused.""" # Arrange 
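The repeated `_scalars_side_effect` helpers in these hunks all implement the same rule: the first `session.scalars()` call returns the document list, and every later call returns segments. A minimal standalone sketch of that sequencing (the helper name and data here are illustrative, not from the task module):

```python
from itertools import chain, repeat
from unittest.mock import MagicMock

def scalars_sequence(first_result, later_results):
    """side_effect for session.scalars(): the first call's .all() returns
    first_result (documents); every subsequent call returns later_results
    (segments), mirroring the query order assumed by these tests."""
    results = chain([first_result], repeat(later_results))

    def _side_effect(*args, **kwargs):
        m = MagicMock()
        m.all.return_value = next(results)
        return m

    return _side_effect

session = MagicMock()
session.scalars.side_effect = scalars_sequence(["doc-1", "doc-2"], ["seg-1"])

assert session.scalars(None).all() == ["doc-1", "doc-2"]  # documents first
assert session.scalars(None).all() == ["seg-1"]           # then segments
assert session.scalars(None).all() == ["seg-1"]           # and on every later call
```

Using `itertools.chain` plus `repeat` avoids the mutable-function-attribute counter while keeping identical call-order behavior.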
- mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents - mock_db_session.scalars.return_value.all.return_value = [] + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + def _scalars_side_effect(*args, **kwargs): + m = MagicMock() + if not hasattr(_scalars_side_effect, "_calls"): + _scalars_side_effect._calls = 0 + if _scalars_side_effect._calls == 0: + m.all.return_value = mock_documents + else: + m.all.return_value = [] + _scalars_side_effect._calls += 1 + return m + + mock_db_session.scalars.side_effect = _scalars_side_effect mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused") # Act @@ -343,8 +409,20 @@ class TestDuplicateDocumentIndexingTaskCore: ): """Test that duplicate document indexing cleans old segments.""" # Arrange - mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents - mock_db_session.scalars.return_value.all.return_value = mock_document_segments + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + def _scalars_side_effect(*args, **kwargs): + m = MagicMock() + if not hasattr(_scalars_side_effect, "_calls"): + _scalars_side_effect._calls = 0 + if _scalars_side_effect._calls == 0: + m.all.return_value = mock_documents + else: + m.all.return_value = mock_document_segments + _scalars_side_effect._calls += 1 + return m + + mock_db_session.scalars.side_effect = _scalars_side_effect mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value # Act @@ -354,9 +432,9 @@ class TestDuplicateDocumentIndexingTaskCore: # Verify clean was called for each document assert mock_processor.clean.call_count == len(mock_documents) - # Verify segments were deleted - for segment in mock_document_segments: - mock_db_session.delete.assert_any_call(segment) + # Verify segments were deleted in batch (DELETE FROM document_segments) + 
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list] + assert any("DELETE FROM document_segments" in sql for sql in execute_sqls) # ============================================================================ diff --git a/api/tests/unit_tests/tasks/test_mail_send_task.py b/api/tests/unit_tests/tasks/test_mail_send_task.py index 736871d784..5cb933e3f3 100644 --- a/api/tests/unit_tests/tasks/test_mail_send_task.py +++ b/api/tests/unit_tests/tasks/test_mail_send_task.py @@ -9,7 +9,7 @@ This module tests the mail sending functionality including: """ import smtplib -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest @@ -151,7 +151,7 @@ class TestSMTPIntegration: client.send(mail_data) # Assert - mock_smtp_ssl.assert_called_once_with("smtp.example.com", 465, timeout=10) + mock_smtp_ssl.assert_called_once_with("smtp.example.com", 465, timeout=10, local_hostname=ANY) mock_server.login.assert_called_once_with("user@example.com", "password123") mock_server.sendmail.assert_called_once() mock_server.quit.assert_called_once() @@ -181,7 +181,7 @@ class TestSMTPIntegration: client.send(mail_data) # Assert - mock_smtp.assert_called_once_with("smtp.example.com", 587, timeout=10) + mock_smtp.assert_called_once_with("smtp.example.com", 587, timeout=10, local_hostname=ANY) mock_server.ehlo.assert_called() mock_server.starttls.assert_called_once() assert mock_server.ehlo.call_count == 2 # Before and after STARTTLS @@ -213,7 +213,7 @@ class TestSMTPIntegration: client.send(mail_data) # Assert - mock_smtp.assert_called_once_with("smtp.example.com", 25, timeout=10) + mock_smtp.assert_called_once_with("smtp.example.com", 25, timeout=10, local_hostname=ANY) mock_server.login.assert_called_once() mock_server.sendmail.assert_called_once() mock_server.quit.assert_called_once() diff --git a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py 
b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py index 1fe77c2935..a14bbb01d0 100644 --- a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,7 +2,11 @@ from unittest.mock import ANY, MagicMock, call, patch import pytest +from libs.archive_storage import ArchiveStorageNotConfiguredError +from models.workflow import WorkflowArchiveLog from tasks.remove_app_and_related_data_task import ( + _delete_app_workflow_archive_logs, + _delete_archived_workflow_run_files, _delete_draft_variable_offload_data, _delete_draft_variables, delete_draft_variables_batch, @@ -11,21 +15,18 @@ from tasks.remove_app_and_related_data_task import ( class TestDeleteDraftVariablesBatch: @patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data") - @patch("tasks.remove_app_and_related_data_task.db") - def test_delete_draft_variables_batch_success(self, mock_db, mock_offload_cleanup): + @patch("tasks.remove_app_and_related_data_task.session_factory") + def test_delete_draft_variables_batch_success(self, mock_sf, mock_offload_cleanup): """Test successful deletion of draft variables in batches.""" app_id = "test-app-id" batch_size = 100 - # Mock database connection and engine - mock_conn = MagicMock() - mock_engine = MagicMock() - mock_db.engine = mock_engine - # Properly mock the context manager + # Mock session via session_factory + mock_session = MagicMock() mock_context_manager = MagicMock() - mock_context_manager.__enter__.return_value = mock_conn + mock_context_manager.__enter__.return_value = mock_session mock_context_manager.__exit__.return_value = None - mock_engine.begin.return_value = mock_context_manager + mock_sf.create_session.return_value = mock_context_manager # Mock two batches of results, then empty batch1_data = [(f"var-{i}", f"file-{i}" if i % 2 == 0 else None) for i in range(100)] @@ -68,7 +69,7 @@ class TestDeleteDraftVariablesBatch: 
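The fixtures in these hunks share one wiring pattern: patch `session_factory`, then return a mock whose `create_session()` yields a context manager whose `__enter__` hands back the mock session. A self-contained sketch of that wiring (the factory below is a plain `MagicMock` stand-in, not the real module attribute):

```python
from unittest.mock import MagicMock

def make_mock_session_factory():
    """Return (factory, session) where factory.create_session() acts as a
    context manager whose __enter__ yields the mock session."""
    session = MagicMock()
    cm = MagicMock()
    cm.__enter__.return_value = session
    cm.__exit__.return_value = None  # falsy: do not suppress exceptions
    factory = MagicMock()
    factory.create_session.return_value = cm
    return factory, session

factory, session = make_mock_session_factory()
with factory.create_session() as s:
    s.execute("SELECT 1")

assert s is session
session.execute.assert_called_once_with("SELECT 1")
```

When a test needs to observe teardown, the same shape works with `cm.__exit__.side_effect` routed to `session.close()`, as the indexing-task fixtures above do.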
select_result3.__iter__.return_value = iter([]) # Configure side effects in the correct order - mock_conn.execute.side_effect = [ + mock_session.execute.side_effect = [ select_result1, # First SELECT delete_result1, # First DELETE select_result2, # Second SELECT @@ -86,54 +87,49 @@ class TestDeleteDraftVariablesBatch: assert result == 150 # Verify database calls - assert mock_conn.execute.call_count == 5 # 3 selects + 2 deletes + assert mock_session.execute.call_count == 5 # 3 selects + 2 deletes # Verify offload cleanup was called for both batches with file_ids - expected_offload_calls = [call(mock_conn, batch1_file_ids), call(mock_conn, batch2_file_ids)] + expected_offload_calls = [call(mock_session, batch1_file_ids), call(mock_session, batch2_file_ids)] mock_offload_cleanup.assert_has_calls(expected_offload_calls) # Simplified verification - check that the right number of calls were made # and that the SQL queries contain the expected patterns - actual_calls = mock_conn.execute.call_args_list + actual_calls = mock_session.execute.call_args_list for i, actual_call in enumerate(actual_calls): + sql_text = str(actual_call[0][0]) + normalized = " ".join(sql_text.split()) if i % 2 == 0: # SELECT calls (even indices: 0, 2, 4) - # Verify it's a SELECT query that now includes file_id - sql_text = str(actual_call[0][0]) - assert "SELECT id, file_id FROM workflow_draft_variables" in sql_text - assert "WHERE app_id = :app_id" in sql_text - assert "LIMIT :batch_size" in sql_text + assert "SELECT id, file_id FROM workflow_draft_variables" in normalized + assert "WHERE app_id = :app_id" in normalized + assert "LIMIT :batch_size" in normalized else: # DELETE calls (odd indices: 1, 3) - # Verify it's a DELETE query - sql_text = str(actual_call[0][0]) - assert "DELETE FROM workflow_draft_variables" in sql_text - assert "WHERE id IN :ids" in sql_text + assert "DELETE FROM workflow_draft_variables" in normalized + assert "WHERE id IN :ids" in normalized 
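The assertions in this patch repeatedly collapse whitespace (`" ".join(str(...).split())`) before substring-matching compiled SQL, presumably because the statements are written as multi-line text. A small helper capturing the same normalization (the helper name is ours, not from the tests):

```python
from unittest.mock import MagicMock

def executed_sql(session):
    """Every statement passed to session.execute(), with whitespace runs
    collapsed so multi-line SQL can be matched with plain substrings."""
    return [" ".join(str(c.args[0]).split()) for c in session.execute.call_args_list]

session = MagicMock()
session.execute("DELETE FROM\n    document_segments\n  WHERE document_id = :id")

sqls = executed_sql(session)
assert any("DELETE FROM document_segments" in sql for sql in sqls)
```

Centralizing the normalization keeps the `assert any("DELETE FROM ..." in sql ...)` checks from silently failing on indentation or line-break differences.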
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data") - @patch("tasks.remove_app_and_related_data_task.db") - def test_delete_draft_variables_batch_empty_result(self, mock_db, mock_offload_cleanup): + @patch("tasks.remove_app_and_related_data_task.session_factory") + def test_delete_draft_variables_batch_empty_result(self, mock_sf, mock_offload_cleanup): """Test deletion when no draft variables exist for the app.""" app_id = "nonexistent-app-id" batch_size = 1000 - # Mock database connection - mock_conn = MagicMock() - mock_engine = MagicMock() - mock_db.engine = mock_engine - # Properly mock the context manager + # Mock session via session_factory + mock_session = MagicMock() mock_context_manager = MagicMock() - mock_context_manager.__enter__.return_value = mock_conn + mock_context_manager.__enter__.return_value = mock_session mock_context_manager.__exit__.return_value = None - mock_engine.begin.return_value = mock_context_manager + mock_sf.create_session.return_value = mock_context_manager # Mock empty result empty_result = MagicMock() empty_result.__iter__.return_value = iter([]) - mock_conn.execute.return_value = empty_result + mock_session.execute.return_value = empty_result result = delete_draft_variables_batch(app_id, batch_size) assert result == 0 - assert mock_conn.execute.call_count == 1 # Only one select query + assert mock_session.execute.call_count == 1 # Only one select query mock_offload_cleanup.assert_not_called() # No files to clean up def test_delete_draft_variables_batch_invalid_batch_size(self): @@ -147,22 +143,19 @@ class TestDeleteDraftVariablesBatch: delete_draft_variables_batch(app_id, 0) @patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data") - @patch("tasks.remove_app_and_related_data_task.db") + @patch("tasks.remove_app_and_related_data_task.session_factory") @patch("tasks.remove_app_and_related_data_task.logger") - def test_delete_draft_variables_batch_logs_progress(self, 
mock_logging, mock_db, mock_offload_cleanup): + def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_sf, mock_offload_cleanup): """Test that batch deletion logs progress correctly.""" app_id = "test-app-id" batch_size = 50 - # Mock database - mock_conn = MagicMock() - mock_engine = MagicMock() - mock_db.engine = mock_engine - # Properly mock the context manager + # Mock session via session_factory + mock_session = MagicMock() mock_context_manager = MagicMock() - mock_context_manager.__enter__.return_value = mock_conn + mock_context_manager.__enter__.return_value = mock_session mock_context_manager.__exit__.return_value = None - mock_engine.begin.return_value = mock_context_manager + mock_sf.create_session.return_value = mock_context_manager # Mock one batch then empty batch_data = [(f"var-{i}", f"file-{i}" if i % 3 == 0 else None) for i in range(30)] @@ -183,7 +176,7 @@ class TestDeleteDraftVariablesBatch: empty_result = MagicMock() empty_result.__iter__.return_value = iter([]) - mock_conn.execute.side_effect = [ + mock_session.execute.side_effect = [ # Select query result select_result, # Delete query result @@ -201,7 +194,7 @@ class TestDeleteDraftVariablesBatch: # Verify offload cleanup was called with file_ids if batch_file_ids: - mock_offload_cleanup.assert_called_once_with(mock_conn, batch_file_ids) + mock_offload_cleanup.assert_called_once_with(mock_session, batch_file_ids) # Verify logging calls assert mock_logging.info.call_count == 2 @@ -261,19 +254,19 @@ class TestDeleteDraftVariableOffloadData: actual_calls = mock_conn.execute.call_args_list # First call should be the SELECT query - select_call_sql = str(actual_calls[0][0][0]) + select_call_sql = " ".join(str(actual_calls[0][0][0]).split()) assert "SELECT wdvf.id, uf.key, uf.id as upload_file_id" in select_call_sql assert "FROM workflow_draft_variable_files wdvf" in select_call_sql assert "JOIN upload_files uf ON wdvf.upload_file_id = uf.id" in select_call_sql assert "WHERE 
wdvf.id IN :file_ids" in select_call_sql # Second call should be DELETE upload_files - delete_upload_call_sql = str(actual_calls[1][0][0]) + delete_upload_call_sql = " ".join(str(actual_calls[1][0][0]).split()) assert "DELETE FROM upload_files" in delete_upload_call_sql assert "WHERE id IN :upload_file_ids" in delete_upload_call_sql # Third call should be DELETE workflow_draft_variable_files - delete_variable_files_call_sql = str(actual_calls[2][0][0]) + delete_variable_files_call_sql = " ".join(str(actual_calls[2][0][0]).split()) assert "DELETE FROM workflow_draft_variable_files" in delete_variable_files_call_sql assert "WHERE id IN :file_ids" in delete_variable_files_call_sql @@ -335,3 +328,68 @@ class TestDeleteDraftVariableOffloadData: # Verify error was logged mock_logging.exception.assert_called_once_with("Error deleting draft variable offload data:") + + +class TestDeleteWorkflowArchiveLogs: + @patch("tasks.remove_app_and_related_data_task._delete_records") + @patch("tasks.remove_app_and_related_data_task.db") + def test_delete_app_workflow_archive_logs_calls_delete_records(self, mock_db, mock_delete_records): + tenant_id = "tenant-1" + app_id = "app-1" + + _delete_app_workflow_archive_logs(tenant_id, app_id) + + mock_delete_records.assert_called_once() + query_sql, params, delete_func, name = mock_delete_records.call_args[0] + assert "workflow_archive_logs" in query_sql + assert params == {"tenant_id": tenant_id, "app_id": app_id} + assert name == "workflow archive log" + + mock_query = MagicMock() + mock_delete_query = MagicMock() + mock_query.where.return_value = mock_delete_query + mock_db.session.query.return_value = mock_query + + delete_func("log-1") + + mock_db.session.query.assert_called_once_with(WorkflowArchiveLog) + mock_query.where.assert_called_once() + mock_delete_query.delete.assert_called_once_with(synchronize_session=False) + + +class TestDeleteArchivedWorkflowRunFiles: + @patch("tasks.remove_app_and_related_data_task.get_archive_storage") 
+ @patch("tasks.remove_app_and_related_data_task.logger") + def test_delete_archived_workflow_run_files_not_configured(self, mock_logger, mock_get_storage): + mock_get_storage.side_effect = ArchiveStorageNotConfiguredError("missing config") + + _delete_archived_workflow_run_files("tenant-1", "app-1") + + assert mock_logger.info.call_count == 1 + assert "Archive storage not configured" in mock_logger.info.call_args[0][0] + + @patch("tasks.remove_app_and_related_data_task.get_archive_storage") + @patch("tasks.remove_app_and_related_data_task.logger") + def test_delete_archived_workflow_run_files_list_failure(self, mock_logger, mock_get_storage): + storage = MagicMock() + storage.list_objects.side_effect = Exception("list failed") + mock_get_storage.return_value = storage + + _delete_archived_workflow_run_files("tenant-1", "app-1") + + storage.list_objects.assert_called_once_with("tenant-1/app_id=app-1/") + storage.delete_object.assert_not_called() + mock_logger.exception.assert_called_once_with("Failed to list archive files for app %s", "app-1") + + @patch("tasks.remove_app_and_related_data_task.get_archive_storage") + @patch("tasks.remove_app_and_related_data_task.logger") + def test_delete_archived_workflow_run_files_success(self, mock_logger, mock_get_storage): + storage = MagicMock() + storage.list_objects.return_value = ["key-1", "key-2"] + mock_get_storage.return_value = storage + + _delete_archived_workflow_run_files("tenant-1", "app-1") + + storage.list_objects.assert_called_once_with("tenant-1/app_id=app-1/") + storage.delete_object.assert_has_calls([call("key-1"), call("key-2")], any_order=False) + mock_logger.info.assert_called_with("Deleted %s archive objects for app %s", 2, "app-1") diff --git a/api/tests/unit_tests/tools/test_mcp_tool.py b/api/tests/unit_tests/tools/test_mcp_tool.py new file mode 100644 index 0000000000..a527773e4e --- /dev/null +++ b/api/tests/unit_tests/tools/test_mcp_tool.py @@ -0,0 +1,122 @@ +import base64 +from unittest.mock import 
Mock, patch
+
+import pytest
+
+from core.mcp.types import (
+    AudioContent,
+    BlobResourceContents,
+    CallToolResult,
+    EmbeddedResource,
+    ImageContent,
+    TextResourceContents,
+)
+from core.tools.__base.tool_runtime import ToolRuntime
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
+from core.tools.mcp_tool.tool import MCPTool
+
+
+def _make_mcp_tool(output_schema: dict | None = None) -> MCPTool:
+    identity = ToolIdentity(
+        author="test",
+        name="test_mcp_tool",
+        label=I18nObject(en_US="Test MCP Tool", zh_Hans="测试MCP工具"),
+        provider="test_provider",
+    )
+    entity = ToolEntity(identity=identity, output_schema=output_schema or {})
+    runtime = Mock(spec=ToolRuntime)
+    runtime.credentials = {}
+    return MCPTool(
+        entity=entity,
+        runtime=runtime,
+        tenant_id="test_tenant",
+        icon="",
+        server_url="https://server.invalid",
+        provider_id="provider_1",
+        headers={},
+    )
+
+
+class TestMCPToolInvoke:
+    @pytest.mark.parametrize(
+        ("content_factory", "mime_type"),
+        [
+            (
+                lambda b64, mt: ImageContent(type="image", data=b64, mimeType=mt),
+                "image/png",
+            ),
+            (
+                lambda b64, mt: AudioContent(type="audio", data=b64, mimeType=mt),
+                "audio/mpeg",
+            ),
+        ],
+    )
+    def test_invoke_image_or_audio_yields_blob(self, content_factory, mime_type) -> None:
+        tool = _make_mcp_tool()
+        raw = b"\x00\x01test-bytes\x02"
+        b64 = base64.b64encode(raw).decode()
+        content = content_factory(b64, mime_type)
+        result = CallToolResult(content=[content])
+
+        with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
+            messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
+
+        assert len(messages) == 1
+        msg = messages[0]
+        assert msg.type == ToolInvokeMessage.MessageType.BLOB
+        assert isinstance(msg.message, ToolInvokeMessage.BlobMessage)
+        assert msg.message.blob == raw
+        assert msg.meta == {"mime_type": mime_type}
+
+    def test_invoke_embedded_text_resource_yields_text(self) -> None:
+        tool = _make_mcp_tool()
+        text_resource = TextResourceContents(uri="file://test.txt", mimeType="text/plain", text="hello world")
+        content = EmbeddedResource(type="resource", resource=text_resource)
+        result = CallToolResult(content=[content])
+
+        with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
+            messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
+
+        assert len(messages) == 1
+        msg = messages[0]
+        assert msg.type == ToolInvokeMessage.MessageType.TEXT
+        assert isinstance(msg.message, ToolInvokeMessage.TextMessage)
+        assert msg.message.text == "hello world"
+
+    @pytest.mark.parametrize(
+        ("mime_type", "expected_mime"),
+        [("application/pdf", "application/pdf"), (None, "application/octet-stream")],
+    )
+    def test_invoke_embedded_blob_resource_yields_blob(self, mime_type, expected_mime) -> None:
+        tool = _make_mcp_tool()
+        raw = b"binary-data"
+        b64 = base64.b64encode(raw).decode()
+        blob_resource = BlobResourceContents(uri="file://doc.bin", mimeType=mime_type, blob=b64)
+        content = EmbeddedResource(type="resource", resource=blob_resource)
+        result = CallToolResult(content=[content])
+
+        with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
+            messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
+
+        assert len(messages) == 1
+        msg = messages[0]
+        assert msg.type == ToolInvokeMessage.MessageType.BLOB
+        assert isinstance(msg.message, ToolInvokeMessage.BlobMessage)
+        assert msg.message.blob == raw
+        assert msg.meta == {"mime_type": expected_mime}
+
+    def test_invoke_yields_variables_when_structured_content_and_schema(self) -> None:
+        tool = _make_mcp_tool(output_schema={"type": "object"})
+        result = CallToolResult(content=[], structuredContent={"a": 1, "b": "x"})
+
+        with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
+            messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
+
+        # Expect two variable messages corresponding to keys a and b
+        assert len(messages) == 2
+        var_msgs = [m for m in messages if isinstance(m.message, ToolInvokeMessage.VariableMessage)]
+        assert {m.message.variable_name for m in var_msgs} == {"a", "b"}
+        # Validate values
+        values = {m.message.variable_name: m.message.variable_value for m in var_msgs}
+        assert values == {"a": 1, "b": "x"}
diff --git a/api/tests/unit_tests/utils/test_text_processing.py b/api/tests/unit_tests/utils/test_text_processing.py
index 11e017464a..bf61162a66 100644
--- a/api/tests/unit_tests/utils/test_text_processing.py
+++ b/api/tests/unit_tests/utils/test_text_processing.py
@@ -15,6 +15,11 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols
         ("", ""),
         (" ", " "),
         ("【测试】", "【测试】"),
+        # Markdown links are preserved when the text starts with one
+        ("[Google](https://google.com) is a search engine", "[Google](https://google.com) is a search engine"),
+        ("[Example](http://example.com) some text", "[Example](http://example.com) some text"),
+        # Leading symbols before a markdown link are removed, including the opening bracket [
+        ("@[Test](https://example.com)", "Test](https://example.com)"),
     ],
 )
 def test_remove_leading_symbols(input_text, expected_output):
diff --git a/api/uv.lock b/api/uv.lock
index c31b7fe445..7808c16a8c 100644
--- a/api/uv.lock
+++ b/api/uv.lock
@@ -48,7 +48,7 @@ wheels = [
 
 [[package]]
 name = "aiohttp"
-version = "3.13.2"
+version = "3.13.3"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "aiohappyeyeballs" },
@@ -59,42 +59,42 @@ dependencies = [
     { name = "propcache" },
     { name = "yarl" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/1c/ce/3b83ebba6b3207a7135e5fcaba49706f8a4b6008153b4e30540c982fae26/aiohttp-3.13.2.tar.gz", hash = "sha256:40176a52c186aefef6eb3cad2cdd30cd06e3afbe88fe8ab2af9c0b90f228daca", size = 7837994, upload-time = "2025-10-28T20:59:39.937Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/50/42/32cf8e7704ceb4481406eb87161349abb46a57fee3f008ba9cb610968646/aiohttp-3.13.3.tar.gz", hash = "sha256:a949eee43d3782f2daae4f4a2819b2cb9b0c5d3b7f7a927067cc84dafdbb9f88", size = 7844556, upload-time = "2026-01-03T17:33:05.204Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/35/74/b321e7d7ca762638cdf8cdeceb39755d9c745aff7a64c8789be96ddf6e96/aiohttp-3.13.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4647d02df098f6434bafd7f32ad14942f05a9caa06c7016fdcc816f343997dd0", size = 743409, upload-time = "2025-10-28T20:56:00.354Z" },
-    { url = "https://files.pythonhosted.org/packages/99/3d/91524b905ec473beaf35158d17f82ef5a38033e5809fe8742e3657cdbb97/aiohttp-3.13.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e3403f24bcb9c3b29113611c3c16a2a447c3953ecf86b79775e7be06f7ae7ccb", size = 497006, upload-time = "2025-10-28T20:56:01.85Z" },
-    { url = "https://files.pythonhosted.org/packages/eb/d3/7f68bc02a67716fe80f063e19adbd80a642e30682ce74071269e17d2dba1/aiohttp-3.13.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:43dff14e35aba17e3d6d5ba628858fb8cb51e30f44724a2d2f0c75be492c55e9", size = 493195, upload-time = "2025-10-28T20:56:03.314Z" },
-    { url = "https://files.pythonhosted.org/packages/98/31/913f774a4708775433b7375c4f867d58ba58ead833af96c8af3621a0d243/aiohttp-3.13.2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e2a9ea08e8c58bb17655630198833109227dea914cd20be660f52215f6de5613", size = 1747759, upload-time = "2025-10-28T20:56:04.904Z" },
-    { url = "https://files.pythonhosted.org/packages/e8/63/04efe156f4326f31c7c4a97144f82132c3bb21859b7bb84748d452ccc17c/aiohttp-3.13.2-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:53b07472f235eb80e826ad038c9d106c2f653584753f3ddab907c83f49eedead", size = 1704456, upload-time = "2025-10-28T20:56:06.986Z" },
-    { url = "https://files.pythonhosted.org/packages/8e/02/4e16154d8e0a9cf4ae76f692941fd52543bbb148f02f098ca73cab9b1c1b/aiohttp-3.13.2-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e736c93e9c274fce6419af4aac199984d866e55f8a4cec9114671d0ea9688780", size = 1807572, upload-time = "2025-10-28T20:56:08.558Z" },
-    { url = "https://files.pythonhosted.org/packages/34/58/b0583defb38689e7f06798f0285b1ffb3a6fb371f38363ce5fd772112724/aiohttp-3.13.2-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ff5e771f5dcbc81c64898c597a434f7682f2259e0cd666932a913d53d1341d1a", size = 1895954, upload-time = "2025-10-28T20:56:10.545Z" },
-    { url = "https://files.pythonhosted.org/packages/6b/f3/083907ee3437425b4e376aa58b2c915eb1a33703ec0dc30040f7ae3368c6/aiohttp-3.13.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3b6fb0c207cc661fa0bf8c66d8d9b657331ccc814f4719468af61034b478592", size = 1747092, upload-time = "2025-10-28T20:56:12.118Z" },
-    { url = "https://files.pythonhosted.org/packages/ac/61/98a47319b4e425cc134e05e5f3fc512bf9a04bf65aafd9fdcda5d57ec693/aiohttp-3.13.2-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:97a0895a8e840ab3520e2288db7cace3a1981300d48babeb50e7425609e2e0ab", size = 1606815, upload-time = "2025-10-28T20:56:14.191Z" },
-    { url = "https://files.pythonhosted.org/packages/97/4b/e78b854d82f66bb974189135d31fce265dee0f5344f64dd0d345158a5973/aiohttp-3.13.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9e8f8afb552297aca127c90cb840e9a1d4bfd6a10d7d8f2d9176e1acc69bad30", size = 1723789, upload-time = "2025-10-28T20:56:16.101Z" },
-    { url = "https://files.pythonhosted.org/packages/ed/fc/9d2ccc794fc9b9acd1379d625c3a8c64a45508b5091c546dea273a41929e/aiohttp-3.13.2-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:ed2f9c7216e53c3df02264f25d824b079cc5914f9e2deba94155190ef648ee40", size = 1718104, upload-time = "2025-10-28T20:56:17.655Z" },
-    { url = "https://files.pythonhosted.org/packages/66/65/34564b8765ea5c7d79d23c9113135d1dd3609173da13084830f1507d56cf/aiohttp-3.13.2-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:99c5280a329d5fa18ef30fd10c793a190d996567667908bef8a7f81f8202b948", size = 1785584, upload-time = "2025-10-28T20:56:19.238Z" },
-    { url = "https://files.pythonhosted.org/packages/30/be/f6a7a426e02fc82781afd62016417b3948e2207426d90a0e478790d1c8a4/aiohttp-3.13.2-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:2ca6ffef405fc9c09a746cb5d019c1672cd7f402542e379afc66b370833170cf", size = 1595126, upload-time = "2025-10-28T20:56:20.836Z" },
-    { url = "https://files.pythonhosted.org/packages/e5/c7/8e22d5d28f94f67d2af496f14a83b3c155d915d1fe53d94b66d425ec5b42/aiohttp-3.13.2-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:47f438b1a28e926c37632bff3c44df7d27c9b57aaf4e34b1def3c07111fdb782", size = 1800665, upload-time = "2025-10-28T20:56:22.922Z" },
-    { url = "https://files.pythonhosted.org/packages/d1/11/91133c8b68b1da9fc16555706aa7276fdf781ae2bb0876c838dd86b8116e/aiohttp-3.13.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9acda8604a57bb60544e4646a4615c1866ee6c04a8edef9b8ee6fd1d8fa2ddc8", size = 1739532, upload-time = "2025-10-28T20:56:25.924Z" },
-    { url = "https://files.pythonhosted.org/packages/17/6b/3747644d26a998774b21a616016620293ddefa4d63af6286f389aedac844/aiohttp-3.13.2-cp311-cp311-win32.whl", hash = "sha256:868e195e39b24aaa930b063c08bb0c17924899c16c672a28a65afded9c46c6ec", size = 431876, upload-time = "2025-10-28T20:56:27.524Z" },
-    { url = "https://files.pythonhosted.org/packages/c3/63/688462108c1a00eb9f05765331c107f95ae86f6b197b865d29e930b7e462/aiohttp-3.13.2-cp311-cp311-win_amd64.whl", hash = "sha256:7fd19df530c292542636c2a9a85854fab93474396a52f1695e799186bbd7f24c", size = 456205, upload-time = "2025-10-28T20:56:29.062Z" },
-    { url = "https://files.pythonhosted.org/packages/29/9b/01f00e9856d0a73260e86dd8ed0c2234a466c5c1712ce1c281548df39777/aiohttp-3.13.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:b1e56bab2e12b2b9ed300218c351ee2a3d8c8fdab5b1ec6193e11a817767e47b", size = 737623, upload-time = "2025-10-28T20:56:30.797Z" },
-    { url = "https://files.pythonhosted.org/packages/5a/1b/4be39c445e2b2bd0aab4ba736deb649fabf14f6757f405f0c9685019b9e9/aiohttp-3.13.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:364e25edaabd3d37b1db1f0cbcee8c73c9a3727bfa262b83e5e4cf3489a2a9dc", size = 492664, upload-time = "2025-10-28T20:56:32.708Z" },
-    { url = "https://files.pythonhosted.org/packages/28/66/d35dcfea8050e131cdd731dff36434390479b4045a8d0b9d7111b0a968f1/aiohttp-3.13.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c5c94825f744694c4b8db20b71dba9a257cd2ba8e010a803042123f3a25d50d7", size = 491808, upload-time = "2025-10-28T20:56:34.57Z" },
-    { url = "https://files.pythonhosted.org/packages/00/29/8e4609b93e10a853b65f8291e64985de66d4f5848c5637cddc70e98f01f8/aiohttp-3.13.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ba2715d842ffa787be87cbfce150d5e88c87a98e0b62e0f5aa489169a393dbbb", size = 1738863, upload-time = "2025-10-28T20:56:36.377Z" },
-    { url = "https://files.pythonhosted.org/packages/9d/fa/4ebdf4adcc0def75ced1a0d2d227577cd7b1b85beb7edad85fcc87693c75/aiohttp-3.13.2-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:585542825c4bc662221fb257889e011a5aa00f1ae4d75d1d246a5225289183e3", size = 1700586, upload-time = "2025-10-28T20:56:38.034Z" },
-    { url = "https://files.pythonhosted.org/packages/da/04/73f5f02ff348a3558763ff6abe99c223381b0bace05cd4530a0258e52597/aiohttp-3.13.2-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:39d02cb6025fe1aabca329c5632f48c9532a3dabccd859e7e2f110668972331f", size = 1768625, upload-time = "2025-10-28T20:56:39.75Z" },
-    { url = "https://files.pythonhosted.org/packages/f8/49/a825b79ffec124317265ca7d2344a86bcffeb960743487cb11988ffb3494/aiohttp-3.13.2-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:e67446b19e014d37342f7195f592a2a948141d15a312fe0e700c2fd2f03124f6", size = 1867281, upload-time = "2025-10-28T20:56:41.471Z" },
-    { url = "https://files.pythonhosted.org/packages/b9/48/adf56e05f81eac31edcfae45c90928f4ad50ef2e3ea72cb8376162a368f8/aiohttp-3.13.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4356474ad6333e41ccefd39eae869ba15a6c5299c9c01dfdcfdd5c107be4363e", size = 1752431, upload-time = "2025-10-28T20:56:43.162Z" },
-    { url = "https://files.pythonhosted.org/packages/30/ab/593855356eead019a74e862f21523db09c27f12fd24af72dbc3555b9bfd9/aiohttp-3.13.2-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:eeacf451c99b4525f700f078becff32c32ec327b10dcf31306a8a52d78166de7", size = 1562846, upload-time = "2025-10-28T20:56:44.85Z" },
-    { url = "https://files.pythonhosted.org/packages/39/0f/9f3d32271aa8dc35036e9668e31870a9d3b9542dd6b3e2c8a30931cb27ae/aiohttp-3.13.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d8a9b889aeabd7a4e9af0b7f4ab5ad94d42e7ff679aaec6d0db21e3b639ad58d", size = 1699606, upload-time = "2025-10-28T20:56:46.519Z" },
-    { url = "https://files.pythonhosted.org/packages/2c/3c/52d2658c5699b6ef7692a3f7128b2d2d4d9775f2a68093f74bca06cf01e1/aiohttp-3.13.2-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:fa89cb11bc71a63b69568d5b8a25c3ca25b6d54c15f907ca1c130d72f320b76b", size = 1720663, upload-time = "2025-10-28T20:56:48.528Z" },
-    { url = "https://files.pythonhosted.org/packages/9b/d4/8f8f3ff1fb7fb9e3f04fcad4e89d8a1cd8fc7d05de67e3de5b15b33008ff/aiohttp-3.13.2-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:8aa7c807df234f693fed0ecd507192fc97692e61fee5702cdc11155d2e5cadc8", size = 1737939, upload-time = "2025-10-28T20:56:50.77Z" },
-    { url = "https://files.pythonhosted.org/packages/03/d3/ddd348f8a27a634daae39a1b8e291ff19c77867af438af844bf8b7e3231b/aiohttp-3.13.2-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:9eb3e33fdbe43f88c3c75fa608c25e7c47bbd80f48d012763cb67c47f39a7e16", size = 1555132, upload-time = "2025-10-28T20:56:52.568Z" },
-    { url = "https://files.pythonhosted.org/packages/39/b8/46790692dc46218406f94374903ba47552f2f9f90dad554eed61bfb7b64c/aiohttp-3.13.2-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:9434bc0d80076138ea986833156c5a48c9c7a8abb0c96039ddbb4afc93184169", size = 1764802, upload-time = "2025-10-28T20:56:54.292Z" },
-    { url = "https://files.pythonhosted.org/packages/ba/e4/19ce547b58ab2a385e5f0b8aa3db38674785085abcf79b6e0edd1632b12f/aiohttp-3.13.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ff15c147b2ad66da1f2cbb0622313f2242d8e6e8f9b79b5206c84523a4473248", size = 1719512, upload-time = "2025-10-28T20:56:56.428Z" },
-    { url = "https://files.pythonhosted.org/packages/70/30/6355a737fed29dcb6dfdd48682d5790cb5eab050f7b4e01f49b121d3acad/aiohttp-3.13.2-cp312-cp312-win32.whl", hash = "sha256:27e569eb9d9e95dbd55c0fc3ec3a9335defbf1d8bc1d20171a49f3c4c607b93e", size = 426690, upload-time = "2025-10-28T20:56:58.736Z" },
-    { url = "https://files.pythonhosted.org/packages/0a/0d/b10ac09069973d112de6ef980c1f6bb31cb7dcd0bc363acbdad58f927873/aiohttp-3.13.2-cp312-cp312-win_amd64.whl", hash = "sha256:8709a0f05d59a71f33fd05c17fc11fcb8c30140506e13c2f5e8ee1b8964e1b45", size = 453465, upload-time = "2025-10-28T20:57:00.795Z" },
+    { url = "https://files.pythonhosted.org/packages/f1/4c/a164164834f03924d9a29dc3acd9e7ee58f95857e0b467f6d04298594ebb/aiohttp-3.13.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5b6073099fb654e0a068ae678b10feff95c5cae95bbfcbfa7af669d361a8aa6b", size = 746051, upload-time = "2026-01-03T17:29:43.287Z" },
+    { url = "https://files.pythonhosted.org/packages/82/71/d5c31390d18d4f58115037c432b7e0348c60f6f53b727cad33172144a112/aiohttp-3.13.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1cb93e166e6c28716c8c6aeb5f99dfb6d5ccf482d29fe9bf9a794110e6d0ab64", size = 499234, upload-time = "2026-01-03T17:29:44.822Z" },
+    { url = "https://files.pythonhosted.org/packages/0e/c9/741f8ac91e14b1d2e7100690425a5b2b919a87a5075406582991fb7de920/aiohttp-3.13.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:28e027cf2f6b641693a09f631759b4d9ce9165099d2b5d92af9bd4e197690eea", size = 494979, upload-time = "2026-01-03T17:29:46.405Z" },
+    { url = "https://files.pythonhosted.org/packages/75/b5/31d4d2e802dfd59f74ed47eba48869c1c21552c586d5e81a9d0d5c2ad640/aiohttp-3.13.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3b61b7169ababd7802f9568ed96142616a9118dd2be0d1866e920e77ec8fa92a", size = 1748297, upload-time = "2026-01-03T17:29:48.083Z" },
+    { url = "https://files.pythonhosted.org/packages/1a/3e/eefad0ad42959f226bb79664826883f2687d602a9ae2941a18e0484a74d3/aiohttp-3.13.3-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:80dd4c21b0f6237676449c6baaa1039abae86b91636b6c91a7f8e61c87f89540", size = 1707172, upload-time = "2026-01-03T17:29:49.648Z" },
+    { url = "https://files.pythonhosted.org/packages/c5/3a/54a64299fac2891c346cdcf2aa6803f994a2e4beeaf2e5a09dcc54acc842/aiohttp-3.13.3-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:65d2ccb7eabee90ce0503c17716fc77226be026dcc3e65cce859a30db715025b", size = 1805405, upload-time = "2026-01-03T17:29:51.244Z" },
+    { url = "https://files.pythonhosted.org/packages/6c/70/ddc1b7169cf64075e864f64595a14b147a895a868394a48f6a8031979038/aiohttp-3.13.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5b179331a481cb5529fca8b432d8d3c7001cb217513c94cd72d668d1248688a3", size = 1899449, upload-time = "2026-01-03T17:29:53.938Z" },
+    { url = "https://files.pythonhosted.org/packages/a1/7e/6815aab7d3a56610891c76ef79095677b8b5be6646aaf00f69b221765021/aiohttp-3.13.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9d4c940f02f49483b18b079d1c27ab948721852b281f8b015c058100e9421dd1", size = 1748444, upload-time = "2026-01-03T17:29:55.484Z" },
+    { url = "https://files.pythonhosted.org/packages/6b/f2/073b145c4100da5511f457dc0f7558e99b2987cf72600d42b559db856fbc/aiohttp-3.13.3-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f9444f105664c4ce47a2a7171a2418bce5b7bae45fb610f4e2c36045d85911d3", size = 1606038, upload-time = "2026-01-03T17:29:57.179Z" },
+    { url = "https://files.pythonhosted.org/packages/0a/c1/778d011920cae03ae01424ec202c513dc69243cf2db303965615b81deeea/aiohttp-3.13.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:694976222c711d1d00ba131904beb60534f93966562f64440d0c9d41b8cdb440", size = 1724156, upload-time = "2026-01-03T17:29:58.914Z" },
+    { url = "https://files.pythonhosted.org/packages/0e/cb/3419eabf4ec1e9ec6f242c32b689248365a1cf621891f6f0386632525494/aiohttp-3.13.3-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:f33ed1a2bf1997a36661874b017f5c4b760f41266341af36febaf271d179f6d7", size = 1722340, upload-time = "2026-01-03T17:30:01.962Z" },
+    { url = "https://files.pythonhosted.org/packages/7a/e5/76cf77bdbc435bf233c1f114edad39ed4177ccbfab7c329482b179cff4f4/aiohttp-3.13.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e636b3c5f61da31a92bf0d91da83e58fdfa96f178ba682f11d24f31944cdd28c", size = 1783041, upload-time = "2026-01-03T17:30:03.609Z" },
+    { url = "https://files.pythonhosted.org/packages/9d/d4/dd1ca234c794fd29c057ce8c0566b8ef7fd6a51069de5f06fa84b9a1971c/aiohttp-3.13.3-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:5d2d94f1f5fcbe40838ac51a6ab5704a6f9ea42e72ceda48de5e6b898521da51", size = 1596024, upload-time = "2026-01-03T17:30:05.132Z" },
+    { url = "https://files.pythonhosted.org/packages/55/58/4345b5f26661a6180afa686c473620c30a66afdf120ed3dd545bbc809e85/aiohttp-3.13.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:2be0e9ccf23e8a94f6f0650ce06042cefc6ac703d0d7ab6c7a917289f2539ad4", size = 1804590, upload-time = "2026-01-03T17:30:07.135Z" },
+    { url = "https://files.pythonhosted.org/packages/7b/06/05950619af6c2df7e0a431d889ba2813c9f0129cec76f663e547a5ad56f2/aiohttp-3.13.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9af5e68ee47d6534d36791bbe9b646d2a7c7deb6fc24d7943628edfbb3581f29", size = 1740355, upload-time = "2026-01-03T17:30:09.083Z" },
+    { url = "https://files.pythonhosted.org/packages/3e/80/958f16de79ba0422d7c1e284b2abd0c84bc03394fbe631d0a39ffa10e1eb/aiohttp-3.13.3-cp311-cp311-win32.whl", hash = "sha256:a2212ad43c0833a873d0fb3c63fa1bacedd4cf6af2fee62bf4b739ceec3ab239", size = 433701, upload-time = "2026-01-03T17:30:10.869Z" },
+    { url = "https://files.pythonhosted.org/packages/dc/f2/27cdf04c9851712d6c1b99df6821a6623c3c9e55956d4b1e318c337b5a48/aiohttp-3.13.3-cp311-cp311-win_amd64.whl", hash = "sha256:642f752c3eb117b105acbd87e2c143de710987e09860d674e068c4c2c441034f", size = 457678, upload-time = "2026-01-03T17:30:12.719Z" },
+    { url = "https://files.pythonhosted.org/packages/a0/be/4fc11f202955a69e0db803a12a062b8379c970c7c84f4882b6da17337cc1/aiohttp-3.13.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:b903a4dfee7d347e2d87697d0713be59e0b87925be030c9178c5faa58ea58d5c", size = 739732, upload-time = "2026-01-03T17:30:14.23Z" },
+    { url = "https://files.pythonhosted.org/packages/97/2c/621d5b851f94fa0bb7430d6089b3aa970a9d9b75196bc93bb624b0db237a/aiohttp-3.13.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a45530014d7a1e09f4a55f4f43097ba0fd155089372e105e4bff4ca76cb1b168", size = 494293, upload-time = "2026-01-03T17:30:15.96Z" },
+    { url = "https://files.pythonhosted.org/packages/5d/43/4be01406b78e1be8320bb8316dc9c42dbab553d281c40364e0f862d5661c/aiohttp-3.13.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:27234ef6d85c914f9efeb77ff616dbf4ad2380be0cda40b4db086ffc7ddd1b7d", size = 493533, upload-time = "2026-01-03T17:30:17.431Z" },
+    { url = "https://files.pythonhosted.org/packages/8d/a8/5a35dc56a06a2c90d4742cbf35294396907027f80eea696637945a106f25/aiohttp-3.13.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d32764c6c9aafb7fb55366a224756387cd50bfa720f32b88e0e6fa45b27dcf29", size = 1737839, upload-time = "2026-01-03T17:30:19.422Z" },
+    { url = "https://files.pythonhosted.org/packages/bf/62/4b9eeb331da56530bf2e198a297e5303e1c1ebdceeb00fe9b568a65c5a0c/aiohttp-3.13.3-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:b1a6102b4d3ebc07dad44fbf07b45bb600300f15b552ddf1851b5390202ea2e3", size = 1703932, upload-time = "2026-01-03T17:30:21.756Z" },
+    { url = "https://files.pythonhosted.org/packages/7c/f6/af16887b5d419e6a367095994c0b1332d154f647e7dc2bd50e61876e8e3d/aiohttp-3.13.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c014c7ea7fb775dd015b2d3137378b7be0249a448a1612268b5a90c2d81de04d", size = 1771906, upload-time = "2026-01-03T17:30:23.932Z" },
+    { url = "https://files.pythonhosted.org/packages/ce/83/397c634b1bcc24292fa1e0c7822800f9f6569e32934bdeef09dae7992dfb/aiohttp-3.13.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2b8d8ddba8f95ba17582226f80e2de99c7a7948e66490ef8d947e272a93e9463", size = 1871020, upload-time = "2026-01-03T17:30:26Z" },
+    { url = "https://files.pythonhosted.org/packages/86/f6/a62cbbf13f0ac80a70f71b1672feba90fdb21fd7abd8dbf25c0105fb6fa3/aiohttp-3.13.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ae8dd55c8e6c4257eae3a20fd2c8f41edaea5992ed67156642493b8daf3cecc", size = 1755181, upload-time = "2026-01-03T17:30:27.554Z" },
+    { url = "https://files.pythonhosted.org/packages/0a/87/20a35ad487efdd3fba93d5843efdfaa62d2f1479eaafa7453398a44faf13/aiohttp-3.13.3-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:01ad2529d4b5035578f5081606a465f3b814c542882804e2e8cda61adf5c71bf", size = 1561794, upload-time = "2026-01-03T17:30:29.254Z" },
+    { url = "https://files.pythonhosted.org/packages/de/95/8fd69a66682012f6716e1bc09ef8a1a2a91922c5725cb904689f112309c4/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bb4f7475e359992b580559e008c598091c45b5088f28614e855e42d39c2f1033", size = 1697900, upload-time = "2026-01-03T17:30:31.033Z" },
+    { url = "https://files.pythonhosted.org/packages/e5/66/7b94b3b5ba70e955ff597672dad1691333080e37f50280178967aff68657/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:c19b90316ad3b24c69cd78d5c9b4f3aa4497643685901185b65166293d36a00f", size = 1728239, upload-time = "2026-01-03T17:30:32.703Z" },
+    { url = "https://files.pythonhosted.org/packages/47/71/6f72f77f9f7d74719692ab65a2a0252584bf8d5f301e2ecb4c0da734530a/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:96d604498a7c782cb15a51c406acaea70d8c027ee6b90c569baa6e7b93073679", size = 1740527, upload-time = "2026-01-03T17:30:34.695Z" },
+    { url = "https://files.pythonhosted.org/packages/fa/b4/75ec16cbbd5c01bdaf4a05b19e103e78d7ce1ef7c80867eb0ace42ff4488/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:084911a532763e9d3dd95adf78a78f4096cd5f58cdc18e6fdbc1b58417a45423", size = 1554489, upload-time = "2026-01-03T17:30:36.864Z" },
+    { url = "https://files.pythonhosted.org/packages/52/8f/bc518c0eea29f8406dcf7ed1f96c9b48e3bc3995a96159b3fc11f9e08321/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:7a4a94eb787e606d0a09404b9c38c113d3b099d508021faa615d70a0131907ce", size = 1767852, upload-time = "2026-01-03T17:30:39.433Z" },
+    { url = "https://files.pythonhosted.org/packages/9d/f2/a07a75173124f31f11ea6f863dc44e6f09afe2bca45dd4e64979490deab1/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:87797e645d9d8e222e04160ee32aa06bc5c163e8499f24db719e7852ec23093a", size = 1722379, upload-time = "2026-01-03T17:30:41.081Z" },
+    { url = "https://files.pythonhosted.org/packages/3c/4a/1a3fee7c21350cac78e5c5cef711bac1b94feca07399f3d406972e2d8fcd/aiohttp-3.13.3-cp312-cp312-win32.whl", hash = "sha256:b04be762396457bef43f3597c991e192ee7da460a4953d7e647ee4b1c28e7046", size = 428253, upload-time = "2026-01-03T17:30:42.644Z" },
+    { url = "https://files.pythonhosted.org/packages/d9/b7/76175c7cb4eb73d91ad63c34e29fc4f77c9386bba4a65b53ba8e05ee3c39/aiohttp-3.13.3-cp312-cp312-win_amd64.whl", hash = "sha256:e3531d63d3bdfa7e3ac5e9b27b2dd7ec9df3206a98e0b3445fa906f233264c57", size = 455407, upload-time = "2026-01-03T17:30:44.195Z" },
 ]
 
 [[package]]
@@ -441,27 +441,27 @@ wheels = [
 
 [[package]]
 name = "authlib"
-version = "1.6.5"
+version = "1.6.6"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "cryptography" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/cd/3f/1d3bbd0bf23bdd99276d4def22f29c27a914067b4cf66f753ff9b8bbd0f3/authlib-1.6.5.tar.gz", hash = "sha256:6aaf9c79b7cc96c900f0b284061691c5d4e61221640a948fe690b556a6d6d10b", size = 164553, upload-time = "2025-10-02T13:36:09.489Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/bb/9b/b1661026ff24bc641b76b78c5222d614776b0c085bcfdac9bd15a1cb4b35/authlib-1.6.6.tar.gz", hash = "sha256:45770e8e056d0f283451d9996fbb59b70d45722b45d854d58f32878d0a40c38e", size = 164894, upload-time = "2025-12-12T08:01:41.464Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/f8/aa/5082412d1ee302e9e7d80b6949bc4d2a8fa1149aaab610c5fc24709605d6/authlib-1.6.5-py2.py3-none-any.whl", hash = "sha256:3e0e0507807f842b02175507bdee8957a1d5707fd4afb17c32fb43fee90b6e3a", size = 243608, upload-time = "2025-10-02T13:36:07.637Z" },
+    { url = "https://files.pythonhosted.org/packages/54/51/321e821856452f7386c4e9df866f196720b1ad0c5ea1623ea7399969ae3b/authlib-1.6.6-py2.py3-none-any.whl", hash = "sha256:7d9e9bc535c13974313a87f53e8430eb6ea3d1cf6ae4f6efcd793f2e949143fd", size = 244005, upload-time = "2025-12-12T08:01:40.209Z" },
 ]
 
 [[package]]
 name = "azure-core"
-version = "1.36.0"
+version = "1.38.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "requests" },
     { name = "typing-extensions" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/0a/c4/d4ff3bc3ddf155156460bff340bbe9533f99fac54ddea165f35a8619f162/azure_core-1.36.0.tar.gz", hash = "sha256:22e5605e6d0bf1d229726af56d9e92bc37b6e726b141a18be0b4d424131741b7", size = 351139, upload-time = "2025-10-15T00:33:49.083Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/dc/1b/e503e08e755ea94e7d3419c9242315f888fc664211c90d032e40479022bf/azure_core-1.38.0.tar.gz", hash = "sha256:8194d2682245a3e4e3151a667c686464c3786fed7918b394d035bdcd61bb5993", size = 363033, upload-time = "2026-01-12T17:03:05.535Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/b1/3c/b90d5afc2e47c4a45f4bba00f9c3193b0417fad5ad3bb07869f9d12832aa/azure_core-1.36.0-py3-none-any.whl", hash = "sha256:fee9923a3a753e94a259563429f3644aaf05c486d45b1215d098115102d91d3b", size = 213302, upload-time = "2025-10-15T00:33:51.058Z" },
+    { url = "https://files.pythonhosted.org/packages/fc/d8/b8fcba9464f02b121f39de2db2bf57f0b216fe11d014513d666e8634380d/azure_core-1.38.0-py3-none-any.whl", hash = "sha256:ab0c9b2cd71fecb1842d52c965c95285d3cfb38902f6766e4a471f1cd8905335", size = 217825, upload-time = "2026-01-12T17:03:07.291Z" },
 ]
 
 [[package]]
@@ -1368,7 +1368,7 @@ wheels = [
 
 [[package]]
 name = "dify-api"
-version = "1.11.2"
+version = "1.11.4"
 source = { virtual = "." }
 dependencies = [
     { name = "aliyun-log-python-sdk" },
@@ -1382,6 +1382,7 @@ dependencies = [
     { name = "celery" },
     { name = "charset-normalizer" },
     { name = "croniter" },
+    { name = "fastopenapi", extra = ["flask"] },
     { name = "flask" },
     { name = "flask-compress" },
     { name = "flask-cors" },
@@ -1580,6 +1581,7 @@ requires-dist = [
    { name = "celery", specifier = "~=5.5.2" },
     { name = "charset-normalizer", specifier = ">=3.4.4" },
     { name = "croniter", specifier = ">=6.0.0" },
+    { name = "fastopenapi", extras = ["flask"], specifier = ">=0.7.0" },
     { name = "flask", specifier = "~=3.1.2" },
     { name = "flask-compress", specifier = ">=1.17,<1.18" },
     { name = "flask-cors", specifier = "~=6.0.0" },
@@ -1600,7 +1602,7 @@ requires-dist = [
     { name = "httpx", extras = ["socks"], specifier = "~=0.27.0" },
     { name = "httpx-sse", specifier = "~=0.4.0" },
     { name = "jieba", specifier = "==0.42.1" },
-    { name = "json-repair", specifier = ">=0.41.1" },
+    { name = "json-repair", specifier = ">=0.55.1" },
     { name = "jsonschema", specifier = ">=4.25.1" },
     { name = "langfuse", specifier = "~=2.51.3" },
     { name = "langsmith", specifier = "~=0.1.77" },
@@ -1631,7 +1633,7 @@ requires-dist = [
     { name = "pandas", extras = ["excel", "output-formatting", "performance"], specifier = "~=2.2.2" },
     { name = "psycogreen", specifier = "~=1.0.2" },
     { name = "psycopg2-binary", specifier = "~=2.9.6" },
-    { name = "pycryptodome", specifier = "==3.19.1" },
+    { name = "pycryptodome", specifier = "==3.23.0" },
     { name = "pydantic", specifier = "~=2.11.4" },
     { name = "pydantic-extra-types", specifier = "~=2.10.3" },
     { name = "pydantic-settings", specifier = "~=2.11.0" },
@@ -1731,7 +1733,7 @@ storage = [
     { name = "opendal", specifier = "~=0.46.0" },
     { name = "oss2", specifier = "==2.18.5" },
     { name = "supabase", specifier = "~=2.18.1" },
-    { name = "tos", specifier = "~=2.7.1" },
+    { name = "tos", specifier = "~=2.9.0" },
 ]
 tools = [
     { name = "cloudscraper", specifier = "~=1.2.71" },
@@ -1921,6 +1923,23 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/7a/93/aa8072af4ff37b795f6bbf43dcaf61115f40f49935c7dbb180c9afc3f421/fastapi-0.122.0-py3-none-any.whl", hash = "sha256:a456e8915dfc6c8914a50d9651133bd47ec96d331c5b44600baa635538a30d67", size = 110671, upload-time = "2025-11-24T19:17:45.96Z" },
 ]
 
+[[package]]
+name = "fastopenapi"
+version = "0.7.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pydantic" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/0e/02/6ee3ecc1e176bbb8c02cbeee30b11f526167c77ef2f7e741ab7999787ad0/fastopenapi-0.7.0.tar.gz", hash = "sha256:5a671fa663e3d89608e9b39a213595f7ac0bf0caf71f2b6016adf4d8c3e1a50e", size = 17191, upload-time = "2025-04-27T13:38:48.368Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/d5/ad/1881ed46a0a7d3ce14472425db0ddc7aa45c858b4b14727647805487f10d/fastopenapi-0.7.0-py3-none-any.whl", hash = "sha256:482914301f348270cb231617863cacfadf1841012c5ff7d4255a27077704c7b2", size = 21272, upload-time = "2025-04-27T13:38:46.877Z" },
+]
+
+[package.optional-dependencies]
+flask = [
+    { name = "flask" },
+]
+
 [[package]]
 name = "fastuuid"
 version = "0.14.0"
@@ -1953,23 +1972,23 @@ wheels = [
 
 [[package]]
 name = "fickling"
-version = "0.1.5"
+version = "0.1.7"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "stdlib-list" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/41/94/0d0ce455952c036cfee235637f786c1d1d07d1b90f6a4dfb50e0eff929d6/fickling-0.1.5.tar.gz", hash = "sha256:92f9b49e717fa8dbc198b4b7b685587adb652d85aa9ede8131b3e44494efca05", size = 282462, upload-time = "2025-11-18T05:04:30.748Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/79/91/e05428d1891970047c9bb81324391f47bf3c612c4ec39f4eef3e40009e05/fickling-0.1.7.tar.gz", hash = "sha256:03d11db2fbb86eb40bdc12a3c4e7cac1dbb16e1207893511d7df0d91ae000899", size = 284009, upload-time = "2026-01-09T18:14:03.198Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/bf/a7/d25912b2e3a5b0a37e6f460050bbc396042b5906a6563a1962c484abc3c6/fickling-0.1.5-py3-none-any.whl", hash = "sha256:6aed7270bfa276e188b0abe043a27b3a042129d28ec1fa6ff389bdcc5ad178bb", size = 46240, upload-time = "2025-11-18T05:04:29.048Z" },
+    { url = "https://files.pythonhosted.org/packages/85/44/9ce98b41f8b13bb8f7d5d688b95b8a1190533da39e7eb3d231f45ee38351/fickling-0.1.7-py3-none-any.whl", hash = "sha256:cebee4df382e27b6e33fb98a4c76fee01a333609bb992a26e140673954e561e4", size = 47923, upload-time = "2026-01-09T18:14:02.076Z" },
 ]
 
 [[package]]
 name = "filelock"
-version = "3.20.0"
+version = "3.20.3"
 source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/58/46/0028a82567109b5ef6e4d2a1f04a583fb513e6cf9527fcdd09afd817deeb/filelock-3.20.0.tar.gz", hash = "sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4", size = 18922, upload-time = "2025-10-08T18:03:50.056Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/1d/65/ce7f1b70157833bf3cb851b556a37d4547ceafc158aa9b34b36782f23696/filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1", size = 19485, upload-time = "2026-01-09T17:55:05.421Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/76/91/7216b27286936c16f5b4d0c530087e4a54eead683e6b0b73dd0c64844af6/filelock-3.20.0-py3-none-any.whl", hash = "sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2", size = 16054, upload-time = "2025-10-08T18:03:48.35Z" },
+    { url = "https://files.pythonhosted.org/packages/b5/36/7fb70f04bf00bc646cd5bb45aa9eddb15e19437a28b8fb2b4a5249fac770/filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1", size = 16701, upload-time = "2026-01-09T17:55:04.334Z" },
 ]
 
 [[package]]
@@ -2955,14 +2974,14 @@ wheels = [
 
 [[package]]
 name = "intersystems-irispython"
-version = "5.3.0"
+version = "5.3.1"
 source = { registry = "https://pypi.org/simple" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/5d/56/16d93576b50408d97a5cbbd055d8da024d585e96a360e2adc95b41ae6284/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-macosx_10_9_universal2.whl", hash = "sha256:59d3176a35867a55b1ab69a6b5c75438b460291bccb254c2d2f4173be08b6e55", size = 6594480, upload-time = "2025-10-09T20:47:27.629Z" },
-    { url = "https://files.pythonhosted.org/packages/99/bc/19e144ee805ea6ee0df6342a711e722c84347c05a75b3bf040c5fbe19982/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56bccefd1997c25f9f9f6c4086214c18d4fdaac0a93319d4b21dd9a6c59c9e51", size = 14779928, upload-time = "2025-10-09T20:47:30.564Z" },
-    { url = "https://files.pythonhosted.org/packages/e6/fb/59ba563a80b39e9450b4627b5696019aa831dce27dacc3831b8c1e669102/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e160adc0785c55bb64e4264b8e99075691a15b0afa5d8d529f1b4bac7e57b81", size = 14422035, upload-time = "2025-10-09T20:47:32.552Z" },
-    { url = "https://files.pythonhosted.org/packages/c1/68/ade8ad43f0ed1e5fba60e1710fa5ddeb01285f031e465e8c006329072e63/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-win32.whl", hash = "sha256:820f2c5729119e5173a5bf6d6ac2a41275c4f1ffba6af6c59ea313ecd8f499cc", size = 2824316, upload-time = "2025-10-09T20:47:28.998Z" },
-    { url = "https://files.pythonhosted.org/packages/f4/03/cd45cb94e42c01dc525efebf3c562543a18ee55b67fde4022665ca672351/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-win_amd64.whl", hash = "sha256:fc07ec24bc50b6f01573221cd7d86f2937549effe31c24af8db118e0131e340c", size = 3463297, upload-time = "2025-10-09T20:47:34.636Z" },
+    { url = "https://files.pythonhosted.org/packages/33/5b/8eac672a6ef26bef6ef79a7c9557096167b50c4d3577d558ae6999c195fe/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-macosx_10_9_universal2.whl", hash = "sha256:634c9b4ec620837d830ff49543aeb2797a1ce8d8570a0e868398b85330dfcc4d", size = 6736686, upload-time = "2025-12-19T16:24:57.734Z" },
+    { url = "https://files.pythonhosted.org/packages/ba/17/bab3e525ffb6711355f7feea18c1b7dced9c2484cecbcdd83f74550398c0/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf912f30f85e2a42f2c2ea77fbeb98a24154d5ea7428a50382786a684ec4f583", size = 16005259, upload-time = "2025-12-19T16:25:05.578Z" },
+    { url = "https://files.pythonhosted.org/packages/39/59/9bb79d9e32e3e55fc9aed8071a797b4497924cbc6457cea9255bb09320b7/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:be5659a6bb57593910f2a2417eddb9f5dc2f93a337ead6ddca778f557b8a359a", size = 15638040, upload-time = "2025-12-19T16:24:54.429Z" },
+    { url = "https://files.pythonhosted.org/packages/cf/47/654ccf9c5cca4f5491f070888544165c9e2a6a485e320ea703e4e38d2358/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-win32.whl", hash = "sha256:583e4f17088c1e0530f32efda1c0ccb02993cbc22035bc8b4c71d8693b04ee7e", size = 2879644, upload-time = "2025-12-19T16:24:59.945Z" },
+    { url = "https://files.pythonhosted.org/packages/68/95/19cc13d09f1b4120bd41b1434509052e1d02afd27f2679266d7ad9cc1750/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-win_amd64.whl", hash = "sha256:1d5d40450a0cdeec2a1f48d12d946a8a8ffc7c128576fcae7d58e66e3a127eae", size = 3522092,
upload-time = "2025-12-19T16:25:01.834Z" }, ] [[package]] @@ -3072,11 +3091,11 @@ wheels = [ [[package]] name = "json-repair" -version = "0.54.1" +version = "0.55.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/00/46/d3a4d9a3dad39bb4a2ad16b8adb9fe2e8611b20b71197fe33daa6768e85d/json_repair-0.54.1.tar.gz", hash = "sha256:d010bc31f1fc66e7c36dc33bff5f8902674498ae5cb8e801ad455a53b455ad1d", size = 38555, upload-time = "2025-11-19T14:55:24.265Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c0/de/71d6bb078d167c0d0959776cee6b6bb8d2ad843f512a5222d7151dde4955/json_repair-0.55.1.tar.gz", hash = "sha256:b27aa0f6bf2e5bf58554037468690446ef26f32ca79c8753282adb3df25fb888", size = 39231, upload-time = "2026-01-23T09:37:20.93Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/db/96/c9aad7ee949cc1bf15df91f347fbc2d3bd10b30b80c7df689ce6fe9332b5/json_repair-0.54.1-py3-none-any.whl", hash = "sha256:016160c5db5d5fe443164927bb58d2dfbba5f43ad85719fa9bc51c713a443ab1", size = 29311, upload-time = "2025-11-19T14:55:22.886Z" }, + { url = "https://files.pythonhosted.org/packages/56/da/289ba9eb550ae420cfc457926f6c49b87cacf8083ee9927e96921888a665/json_repair-0.55.1-py3-none-any.whl", hash = "sha256:a1bcc151982a12bc3ef9e9528198229587b1074999cfe08921ab6333b0c8e206", size = 29743, upload-time = "2026-01-23T09:37:19.404Z" }, ] [[package]] @@ -3382,14 +3401,14 @@ wheels = [ [[package]] name = "marshmallow" -version = "3.26.1" +version = "3.26.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "packaging" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ab/5e/5e53d26b42ab75491cda89b871dab9e97c840bf12c63ec58a1919710cd06/marshmallow-3.26.1.tar.gz", hash = "sha256:e6d8affb6cb61d39d26402096dc0aee12d5a26d490a121f118d2e81dc0719dc6", size = 221825, upload-time = "2025-02-03T15:32:25.093Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/55/79/de6c16cc902f4fc372236926b0ce2ab7845268dcc30fb2fbb7f71b418631/marshmallow-3.26.2.tar.gz", hash = "sha256:bbe2adb5a03e6e3571b573f42527c6fe926e17467833660bebd11593ab8dfd57", size = 222095, upload-time = "2025-12-22T06:53:53.309Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/34/75/51952c7b2d3873b44a0028b1bd26a25078c18f92f256608e8d1dc61b39fd/marshmallow-3.26.1-py3-none-any.whl", hash = "sha256:3350409f20a70a7e4e11a27661187b77cdcaeb20abca41c1454fe33636bea09c", size = 50878, upload-time = "2025-02-03T15:32:22.295Z" }, + { url = "https://files.pythonhosted.org/packages/be/2f/5108cb3ee4ba6501748c4908b908e55f42a5b66245b4cfe0c99326e1ef6e/marshmallow-3.26.2-py3-none-any.whl", hash = "sha256:013fa8a3c4c276c24d26d84ce934dc964e2aa794345a0f8c7e5a7191482c8a73", size = 50964, upload-time = "2025-12-22T06:53:51.801Z" }, ] [[package]] @@ -4403,15 +4422,15 @@ wheels = [ [[package]] name = "pdfminer-six" -version = "20250506" +version = "20251230" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "charset-normalizer" }, { name = "cryptography" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/78/46/5223d613ac4963e1f7c07b2660fe0e9e770102ec6bda8c038400113fb215/pdfminer_six-20250506.tar.gz", hash = "sha256:b03cc8df09cf3c7aba8246deae52e0bca7ebb112a38895b5e1d4f5dd2b8ca2e7", size = 7387678, upload-time = "2025-05-06T16:17:00.787Z" } +sdist = { url = "https://files.pythonhosted.org/packages/46/9a/d79d8fa6d47a0338846bb558b39b9963b8eb2dfedec61867c138c1b17eeb/pdfminer_six-20251230.tar.gz", hash = "sha256:e8f68a14c57e00c2d7276d26519ea64be1b48f91db1cdc776faa80528ca06c1e", size = 8511285, upload-time = "2025-12-30T15:49:13.104Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/73/16/7a432c0101fa87457e75cb12c879e1749c5870a786525e2e0f42871d6462/pdfminer_six-20250506-py3-none-any.whl", hash = "sha256:d81ad173f62e5f841b53a8ba63af1a4a355933cfc0ffabd608e568b9193909e3", size = 5620187, 
upload-time = "2025-05-06T16:16:58.669Z" }, + { url = "https://files.pythonhosted.org/packages/65/d7/b288ea32deb752a09aab73c75e1e7572ab2a2b56c3124a5d1eb24c62ceb3/pdfminer_six-20251230-py3-none-any.whl", hash = "sha256:9ff2e3466a7dfc6de6fd779478850b6b7c2d9e9405aa2a5869376a822771f485", size = 6591909, upload-time = "2025-12-30T15:49:10.76Z" }, ] [[package]] @@ -4509,7 +4528,7 @@ wheels = [ [[package]] name = "polyfile-weave" -version = "0.5.7" +version = "0.5.8" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "abnf" }, @@ -4527,9 +4546,9 @@ dependencies = [ { name = "pyyaml" }, { name = "setuptools" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/02/c3/5a2a2ba06850bc5ec27f83ac8b92210dff9ff6736b2c42f700b489b3fd86/polyfile_weave-0.5.7.tar.gz", hash = "sha256:c3d863f51c30322c236bdf385e116ac06d4e7de9ec25a3aae14d42b1d528e33b", size = 5987445, upload-time = "2025-09-22T19:21:11.222Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/d4/76e56e4429646d9353b4287794f8324ff94201bdb0a2c35ce88cf3de90d0/polyfile_weave-0.5.8.tar.gz", hash = "sha256:cf2ca6a1351165fbbf2971ace4b8bebbb03b2c00e4f2159ff29bed88854e7b32", size = 5989602, upload-time = "2026-01-08T04:21:26.689Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cd/f6/d1efedc0f9506e47699616e896d8efe39e8f0b6a7d1d590c3e97455ecf4a/polyfile_weave-0.5.7-py3-none-any.whl", hash = "sha256:880454788bc383408bf19eefd6d1c49a18b965d90c99bccb58f4da65870c82dd", size = 1655397, upload-time = "2025-09-22T19:21:09.142Z" }, + { url = "https://files.pythonhosted.org/packages/54/32/c09fd626366c00325d1981e310be5cac8661c09206098d267a592e0c5000/polyfile_weave-0.5.8-py3-none-any.whl", hash = "sha256:f68c570ef189a4219798a7c797730fc3b7feace7ff5bd7e662490f89b772964a", size = 1656208, upload-time = "2026-01-08T04:21:15.213Z" }, ] [[package]] @@ -4747,11 +4766,11 @@ wheels = [ [[package]] name = "pyasn1" -version = "0.6.1" +version = "0.6.2" source = { registry = 
"https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322, upload-time = "2024-09-10T22:41:42.55Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/b6/6e630dff89739fcd427e3f72b3d905ce0acb85a45d4ec3e2678718a3487f/pyasn1-0.6.2.tar.gz", hash = "sha256:9b59a2b25ba7e4f8197db7686c09fb33e658b98339fadb826e9512629017833b", size = 146586, upload-time = "2026-01-16T18:04:18.534Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135, upload-time = "2024-09-11T16:00:36.122Z" }, + { url = "https://files.pythonhosted.org/packages/44/b5/a96872e5184f354da9c84ae119971a0a4c221fe9b27a4d94bd43f2596727/pyasn1-0.6.2-py3-none-any.whl", hash = "sha256:1eb26d860996a18e9b6ed05e7aae0e9fc21619fcee6af91cca9bad4fbea224bf", size = 83371, upload-time = "2026-01-16T18:04:17.174Z" }, ] [[package]] @@ -4777,20 +4796,21 @@ wheels = [ [[package]] name = "pycryptodome" -version = "3.19.1" +version = "3.23.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b1/38/42a8855ff1bf568c61ca6557e2203f318fb7afeadaf2eb8ecfdbde107151/pycryptodome-3.19.1.tar.gz", hash = "sha256:8ae0dd1bcfada451c35f9e29a3e5db385caabc190f98e4a80ad02a61098fb776", size = 4782144, upload-time = "2023-12-28T06:52:40.741Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/a6/8452177684d5e906854776276ddd34eca30d1b1e15aa1ee9cefc289a33f5/pycryptodome-3.23.0.tar.gz", hash = "sha256:447700a657182d60338bab09fdb27518f8856aecd80ae4c6bdddb67ff5da44ef", size = 4921276, upload-time = "2025-05-17T17:21:45.242Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/a8/ef/4931bc30674f0de0ca0e827b58c8b0c17313a8eae2754976c610b866118b/pycryptodome-3.19.1-cp35-abi3-macosx_10_9_universal2.whl", hash = "sha256:67939a3adbe637281c611596e44500ff309d547e932c449337649921b17b6297", size = 2417027, upload-time = "2023-12-28T06:51:50.138Z" }, - { url = "https://files.pythonhosted.org/packages/67/e6/238c53267fd8d223029c0a0d3730cb1b6594d60f62e40c4184703dc490b1/pycryptodome-3.19.1-cp35-abi3-macosx_10_9_x86_64.whl", hash = "sha256:11ddf6c9b52116b62223b6a9f4741bc4f62bb265392a4463282f7f34bb287180", size = 1579728, upload-time = "2023-12-28T06:51:52.385Z" }, - { url = "https://files.pythonhosted.org/packages/7c/87/7181c42c8d5ba89822a4b824830506d0aeec02959bb893614767e3279846/pycryptodome-3.19.1-cp35-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3e6f89480616781d2a7f981472d0cdb09b9da9e8196f43c1234eff45c915766", size = 2051440, upload-time = "2023-12-28T06:51:55.751Z" }, - { url = "https://files.pythonhosted.org/packages/34/dd/332c4c0055527d17dac317ed9f9c864fc047b627d82f4b9a56c110afc6fc/pycryptodome-3.19.1-cp35-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27e1efcb68993b7ce5d1d047a46a601d41281bba9f1971e6be4aa27c69ab8065", size = 2125379, upload-time = "2023-12-28T06:51:58.567Z" }, - { url = "https://files.pythonhosted.org/packages/24/9e/320b885ea336c218ff54ec2b276cd70ba6904e4f5a14a771ed39a2c47d59/pycryptodome-3.19.1-cp35-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1c6273ca5a03b672e504995529b8bae56da0ebb691d8ef141c4aa68f60765700", size = 2153951, upload-time = "2023-12-28T06:52:01.699Z" }, - { url = "https://files.pythonhosted.org/packages/f4/54/8ae0c43d1257b41bc9d3277c3f875174fd8ad86b9567f0b8609b99c938ee/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:b0bfe61506795877ff974f994397f0c862d037f6f1c0bfc3572195fc00833b96", size = 2044041, upload-time = "2023-12-28T06:52:03.737Z" }, - { url 
= "https://files.pythonhosted.org/packages/45/93/f8450a92cc38541c3ba1f4cb4e267e15ae6d6678ca617476d52c3a3764d4/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_i686.whl", hash = "sha256:f34976c5c8eb79e14c7d970fb097482835be8d410a4220f86260695ede4c3e17", size = 2182446, upload-time = "2023-12-28T06:52:05.588Z" }, - { url = "https://files.pythonhosted.org/packages/af/cd/ed6e429fb0792ce368f66e83246264dd3a7a045b0b1e63043ed22a063ce5/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7c9e222d0976f68d0cf6409cfea896676ddc1d98485d601e9508f90f60e2b0a2", size = 2144914, upload-time = "2023-12-28T06:52:07.44Z" }, - { url = "https://files.pythonhosted.org/packages/f6/23/b064bd4cfbf2cc5f25afcde0e7c880df5b20798172793137ba4b62d82e72/pycryptodome-3.19.1-cp35-abi3-win32.whl", hash = "sha256:4805e053571140cb37cf153b5c72cd324bb1e3e837cbe590a19f69b6cf85fd03", size = 1713105, upload-time = "2023-12-28T06:52:09.585Z" }, - { url = "https://files.pythonhosted.org/packages/7d/e0/ded1968a5257ab34216a0f8db7433897a2337d59e6d03be113713b346ea2/pycryptodome-3.19.1-cp35-abi3-win_amd64.whl", hash = "sha256:a470237ee71a1efd63f9becebc0ad84b88ec28e6784a2047684b693f458f41b7", size = 1749222, upload-time = "2023-12-28T06:52:11.534Z" }, + { url = "https://files.pythonhosted.org/packages/db/6c/a1f71542c969912bb0e106f64f60a56cc1f0fabecf9396f45accbe63fa68/pycryptodome-3.23.0-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:187058ab80b3281b1de11c2e6842a357a1f71b42cb1e15bce373f3d238135c27", size = 2495627, upload-time = "2025-05-17T17:20:47.139Z" }, + { url = "https://files.pythonhosted.org/packages/6e/4e/a066527e079fc5002390c8acdd3aca431e6ea0a50ffd7201551175b47323/pycryptodome-3.23.0-cp37-abi3-macosx_10_9_x86_64.whl", hash = "sha256:cfb5cd445280c5b0a4e6187a7ce8de5a07b5f3f897f235caa11f1f435f182843", size = 1640362, upload-time = "2025-05-17T17:20:50.392Z" }, + { url = 
"https://files.pythonhosted.org/packages/50/52/adaf4c8c100a8c49d2bd058e5b551f73dfd8cb89eb4911e25a0c469b6b4e/pycryptodome-3.23.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:67bd81fcbe34f43ad9422ee8fd4843c8e7198dd88dd3d40e6de42ee65fbe1490", size = 2182625, upload-time = "2025-05-17T17:20:52.866Z" }, + { url = "https://files.pythonhosted.org/packages/5f/e9/a09476d436d0ff1402ac3867d933c61805ec2326c6ea557aeeac3825604e/pycryptodome-3.23.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8987bd3307a39bc03df5c8e0e3d8be0c4c3518b7f044b0f4c15d1aa78f52575", size = 2268954, upload-time = "2025-05-17T17:20:55.027Z" }, + { url = "https://files.pythonhosted.org/packages/f9/c5/ffe6474e0c551d54cab931918127c46d70cab8f114e0c2b5a3c071c2f484/pycryptodome-3.23.0-cp37-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aa0698f65e5b570426fc31b8162ed4603b0c2841cbb9088e2b01641e3065915b", size = 2308534, upload-time = "2025-05-17T17:20:57.279Z" }, + { url = "https://files.pythonhosted.org/packages/18/28/e199677fc15ecf43010f2463fde4c1a53015d1fe95fb03bca2890836603a/pycryptodome-3.23.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:53ecbafc2b55353edcebd64bf5da94a2a2cdf5090a6915bcca6eca6cc452585a", size = 2181853, upload-time = "2025-05-17T17:20:59.322Z" }, + { url = "https://files.pythonhosted.org/packages/ce/ea/4fdb09f2165ce1365c9eaefef36625583371ee514db58dc9b65d3a255c4c/pycryptodome-3.23.0-cp37-abi3-musllinux_1_2_i686.whl", hash = "sha256:156df9667ad9f2ad26255926524e1c136d6664b741547deb0a86a9acf5ea631f", size = 2342465, upload-time = "2025-05-17T17:21:03.83Z" }, + { url = "https://files.pythonhosted.org/packages/22/82/6edc3fc42fe9284aead511394bac167693fb2b0e0395b28b8bedaa07ef04/pycryptodome-3.23.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:dea827b4d55ee390dc89b2afe5927d4308a8b538ae91d9c6f7a5090f397af1aa", size = 2267414, upload-time = "2025-05-17T17:21:06.72Z" }, + { url = 
"https://files.pythonhosted.org/packages/59/fe/aae679b64363eb78326c7fdc9d06ec3de18bac68be4b612fc1fe8902693c/pycryptodome-3.23.0-cp37-abi3-win32.whl", hash = "sha256:507dbead45474b62b2bbe318eb1c4c8ee641077532067fec9c1aa82c31f84886", size = 1768484, upload-time = "2025-05-17T17:21:08.535Z" }, + { url = "https://files.pythonhosted.org/packages/54/2f/e97a1b8294db0daaa87012c24a7bb714147c7ade7656973fd6c736b484ff/pycryptodome-3.23.0-cp37-abi3-win_amd64.whl", hash = "sha256:c75b52aacc6c0c260f204cbdd834f76edc9fb0d8e0da9fbf8352ef58202564e2", size = 1799636, upload-time = "2025-05-17T17:21:10.393Z" }, + { url = "https://files.pythonhosted.org/packages/18/3d/f9441a0d798bf2b1e645adc3265e55706aead1255ccdad3856dbdcffec14/pycryptodome-3.23.0-cp37-abi3-win_arm64.whl", hash = "sha256:11eeeb6917903876f134b56ba11abe95c0b0fd5e3330def218083c7d98bbcb3c", size = 1703675, upload-time = "2025-05-17T17:21:13.146Z" }, ] [[package]] @@ -4984,11 +5004,11 @@ wheels = [ [[package]] name = "pypdf" -version = "6.4.0" +version = "6.6.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f3/01/f7510cc6124f494cfbec2e8d3c2e1a20d4f6c18622b0c03a3a70e968bacb/pypdf-6.4.0.tar.gz", hash = "sha256:4769d471f8ddc3341193ecc5d6560fa44cf8cd0abfabf21af4e195cc0c224072", size = 5276661, upload-time = "2025-11-23T14:04:43.185Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b8/bb/a44bab1ac3c54dbcf653d7b8bcdee93dddb2d3bf025a3912cacb8149a2f2/pypdf-6.6.2.tar.gz", hash = "sha256:0a3ea3b3303982333404e22d8f75d7b3144f9cf4b2970b96856391a516f9f016", size = 5281850, upload-time = "2026-01-26T11:57:55.964Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cd/f2/9c9429411c91ac1dd5cd66780f22b6df20c64c3646cdd1e6d67cf38579c4/pypdf-6.4.0-py3-none-any.whl", hash = "sha256:55ab9837ed97fd7fcc5c131d52fcc2223bc5c6b8a1488bbf7c0e27f1f0023a79", size = 329497, upload-time = "2025-11-23T14:04:41.448Z" }, + { url = 
"https://files.pythonhosted.org/packages/7d/be/549aaf1dfa4ab4aed29b09703d2fb02c4366fc1f05e880948c296c5764b9/pypdf-6.6.2-py3-none-any.whl", hash = "sha256:44c0c9811cfb3b83b28f1c3d054531d5b8b81abaedee0d8cb403650d023832ba", size = 329132, upload-time = "2026-01-26T11:57:54.099Z" }, ] [[package]] @@ -6148,7 +6168,7 @@ wheels = [ [[package]] name = "tos" -version = "2.7.2" +version = "2.9.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "crcmod" }, @@ -6156,8 +6176,9 @@ dependencies = [ { name = "pytz" }, { name = "requests" }, { name = "six" }, + { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0c/01/f811af86f1f80d5f289be075c3b281e74bf3fe081cfbe5cfce44954d2c3a/tos-2.7.2.tar.gz", hash = "sha256:3c31257716785bca7b2cac51474ff32543cda94075a7b7aff70d769c15c7b7ed", size = 123407, upload-time = "2024-10-16T15:59:08.634Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/b3/13451226f564f88d9db2323e9b7eabcced792a0ad5ee1e333751a7634257/tos-2.9.0.tar.gz", hash = "sha256:861cfc348e770f099f911cb96b2c41774ada6c9c51b7a89d97e0c426074dd99e", size = 157071, upload-time = "2026-01-06T04:13:08.921Z" } [[package]] name = "tqdm" @@ -6858,11 +6879,11 @@ wheels = [ [[package]] name = "urllib3" -version = "2.6.0" +version = "2.6.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/1c/43/554c2569b62f49350597348fc3ac70f786e3c32e7f19d266e19817812dd3/urllib3-2.6.0.tar.gz", hash = "sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1", size = 432585, upload-time = "2025-12-05T15:08:47.885Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556, upload-time = "2026-01-07T16:24:43.925Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/56/1a/9ffe814d317c5224166b23e7c47f606d6e473712a2fad0f704ea9b99f246/urllib3-2.6.0-py3-none-any.whl", hash = "sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f", size = 131083, upload-time = "2025-12-05T15:08:45.983Z" }, + { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" }, ] [[package]] @@ -7134,43 +7155,43 @@ wheels = [ [[package]] name = "werkzeug" -version = "3.1.4" +version = "3.1.5" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markupsafe" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/45/ea/b0f8eeb287f8df9066e56e831c7824ac6bab645dd6c7a8f4b2d767944f9b/werkzeug-3.1.4.tar.gz", hash = "sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e", size = 864687, upload-time = "2025-11-29T02:15:22.841Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5a/70/1469ef1d3542ae7c2c7b72bd5e3a4e6ee69d7978fa8a3af05a38eca5becf/werkzeug-3.1.5.tar.gz", hash = "sha256:6a548b0e88955dd07ccb25539d7d0cc97417ee9e179677d22c7041c8f078ce67", size = 864754, upload-time = "2026-01-08T17:49:23.247Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2f/f9/9e082990c2585c744734f85bec79b5dae5df9c974ffee58fe421652c8e91/werkzeug-3.1.4-py3-none-any.whl", hash = "sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905", size = 224960, upload-time = "2025-11-29T02:15:21.13Z" }, + { url = "https://files.pythonhosted.org/packages/ad/e4/8d97cca767bcc1be76d16fb76951608305561c6e056811587f36cb1316a8/werkzeug-3.1.5-py3-none-any.whl", hash = "sha256:5111e36e91086ece91f93268bb39b4a35c1e6f1feac762c9c822ded0a4e322dc", size = 225025, upload-time = "2026-01-08T17:49:21.859Z" }, ] [[package]] name = "wrapt" -version = 
"1.17.3" +version = "1.16.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/95/8f/aeb76c5b46e273670962298c23e7ddde79916cb74db802131d49a85e4b7d/wrapt-1.17.3.tar.gz", hash = "sha256:f66eb08feaa410fe4eebd17f2a2c8e2e46d3476e9f8c783daa8e09e0faa666d0", size = 55547, upload-time = "2025-08-12T05:53:21.714Z" } +sdist = { url = "https://files.pythonhosted.org/packages/95/4c/063a912e20bcef7124e0df97282a8af3ff3e4b603ce84c481d6d7346be0a/wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d", size = 53972, upload-time = "2023-11-09T06:33:30.191Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/52/db/00e2a219213856074a213503fdac0511203dceefff26e1daa15250cc01a0/wrapt-1.17.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:273a736c4645e63ac582c60a56b0acb529ef07f78e08dc6bfadf6a46b19c0da7", size = 53482, upload-time = "2025-08-12T05:51:45.79Z" }, - { url = "https://files.pythonhosted.org/packages/5e/30/ca3c4a5eba478408572096fe9ce36e6e915994dd26a4e9e98b4f729c06d9/wrapt-1.17.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5531d911795e3f935a9c23eb1c8c03c211661a5060aab167065896bbf62a5f85", size = 38674, upload-time = "2025-08-12T05:51:34.629Z" }, - { url = "https://files.pythonhosted.org/packages/31/25/3e8cc2c46b5329c5957cec959cb76a10718e1a513309c31399a4dad07eb3/wrapt-1.17.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0610b46293c59a3adbae3dee552b648b984176f8562ee0dba099a56cfbe4df1f", size = 38959, upload-time = "2025-08-12T05:51:56.074Z" }, - { url = "https://files.pythonhosted.org/packages/5d/8f/a32a99fc03e4b37e31b57cb9cefc65050ea08147a8ce12f288616b05ef54/wrapt-1.17.3-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b32888aad8b6e68f83a8fdccbf3165f5469702a7544472bdf41f582970ed3311", size = 82376, upload-time = "2025-08-12T05:52:32.134Z" }, - { url = 
"https://files.pythonhosted.org/packages/31/57/4930cb8d9d70d59c27ee1332a318c20291749b4fba31f113c2f8ac49a72e/wrapt-1.17.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cccf4f81371f257440c88faed6b74f1053eef90807b77e31ca057b2db74edb1", size = 83604, upload-time = "2025-08-12T05:52:11.663Z" }, - { url = "https://files.pythonhosted.org/packages/a8/f3/1afd48de81d63dd66e01b263a6fbb86e1b5053b419b9b33d13e1f6d0f7d0/wrapt-1.17.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8a210b158a34164de8bb68b0e7780041a903d7b00c87e906fb69928bf7890d5", size = 82782, upload-time = "2025-08-12T05:52:12.626Z" }, - { url = "https://files.pythonhosted.org/packages/1e/d7/4ad5327612173b144998232f98a85bb24b60c352afb73bc48e3e0d2bdc4e/wrapt-1.17.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:79573c24a46ce11aab457b472efd8d125e5a51da2d1d24387666cd85f54c05b2", size = 82076, upload-time = "2025-08-12T05:52:33.168Z" }, - { url = "https://files.pythonhosted.org/packages/bb/59/e0adfc831674a65694f18ea6dc821f9fcb9ec82c2ce7e3d73a88ba2e8718/wrapt-1.17.3-cp311-cp311-win32.whl", hash = "sha256:c31eebe420a9a5d2887b13000b043ff6ca27c452a9a22fa71f35f118e8d4bf89", size = 36457, upload-time = "2025-08-12T05:53:03.936Z" }, - { url = "https://files.pythonhosted.org/packages/83/88/16b7231ba49861b6f75fc309b11012ede4d6b0a9c90969d9e0db8d991aeb/wrapt-1.17.3-cp311-cp311-win_amd64.whl", hash = "sha256:0b1831115c97f0663cb77aa27d381237e73ad4f721391a9bfb2fe8bc25fa6e77", size = 38745, upload-time = "2025-08-12T05:53:02.885Z" }, - { url = "https://files.pythonhosted.org/packages/9a/1e/c4d4f3398ec073012c51d1c8d87f715f56765444e1a4b11e5180577b7e6e/wrapt-1.17.3-cp311-cp311-win_arm64.whl", hash = "sha256:5a7b3c1ee8265eb4c8f1b7d29943f195c00673f5ab60c192eba2d4a7eae5f46a", size = 36806, upload-time = "2025-08-12T05:52:53.368Z" }, - { url = 
"https://files.pythonhosted.org/packages/9f/41/cad1aba93e752f1f9268c77270da3c469883d56e2798e7df6240dcb2287b/wrapt-1.17.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ab232e7fdb44cdfbf55fc3afa31bcdb0d8980b9b95c38b6405df2acb672af0e0", size = 53998, upload-time = "2025-08-12T05:51:47.138Z" }, - { url = "https://files.pythonhosted.org/packages/60/f8/096a7cc13097a1869fe44efe68dace40d2a16ecb853141394047f0780b96/wrapt-1.17.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9baa544e6acc91130e926e8c802a17f3b16fbea0fd441b5a60f5cf2cc5c3deba", size = 39020, upload-time = "2025-08-12T05:51:35.906Z" }, - { url = "https://files.pythonhosted.org/packages/33/df/bdf864b8997aab4febb96a9ae5c124f700a5abd9b5e13d2a3214ec4be705/wrapt-1.17.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6b538e31eca1a7ea4605e44f81a48aa24c4632a277431a6ed3f328835901f4fd", size = 39098, upload-time = "2025-08-12T05:51:57.474Z" }, - { url = "https://files.pythonhosted.org/packages/9f/81/5d931d78d0eb732b95dc3ddaeeb71c8bb572fb01356e9133916cd729ecdd/wrapt-1.17.3-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:042ec3bb8f319c147b1301f2393bc19dba6e176b7da446853406d041c36c7828", size = 88036, upload-time = "2025-08-12T05:52:34.784Z" }, - { url = "https://files.pythonhosted.org/packages/ca/38/2e1785df03b3d72d34fc6252d91d9d12dc27a5c89caef3335a1bbb8908ca/wrapt-1.17.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3af60380ba0b7b5aeb329bc4e402acd25bd877e98b3727b0135cb5c2efdaefe9", size = 88156, upload-time = "2025-08-12T05:52:13.599Z" }, - { url = "https://files.pythonhosted.org/packages/b3/8b/48cdb60fe0603e34e05cffda0b2a4adab81fd43718e11111a4b0100fd7c1/wrapt-1.17.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0b02e424deef65c9f7326d8c19220a2c9040c51dc165cddb732f16198c168396", size = 87102, upload-time = "2025-08-12T05:52:14.56Z" }, - { url = 
"https://files.pythonhosted.org/packages/3c/51/d81abca783b58f40a154f1b2c56db1d2d9e0d04fa2d4224e357529f57a57/wrapt-1.17.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:74afa28374a3c3a11b3b5e5fca0ae03bef8450d6aa3ab3a1e2c30e3a75d023dc", size = 87732, upload-time = "2025-08-12T05:52:36.165Z" }, - { url = "https://files.pythonhosted.org/packages/9e/b1/43b286ca1392a006d5336412d41663eeef1ad57485f3e52c767376ba7e5a/wrapt-1.17.3-cp312-cp312-win32.whl", hash = "sha256:4da9f45279fff3543c371d5ababc57a0384f70be244de7759c85a7f989cb4ebe", size = 36705, upload-time = "2025-08-12T05:53:07.123Z" }, - { url = "https://files.pythonhosted.org/packages/28/de/49493f962bd3c586ab4b88066e967aa2e0703d6ef2c43aa28cb83bf7b507/wrapt-1.17.3-cp312-cp312-win_amd64.whl", hash = "sha256:e71d5c6ebac14875668a1e90baf2ea0ef5b7ac7918355850c0908ae82bcb297c", size = 38877, upload-time = "2025-08-12T05:53:05.436Z" }, - { url = "https://files.pythonhosted.org/packages/f1/48/0f7102fe9cb1e8a5a77f80d4f0956d62d97034bbe88d33e94699f99d181d/wrapt-1.17.3-cp312-cp312-win_arm64.whl", hash = "sha256:604d076c55e2fdd4c1c03d06dc1a31b95130010517b5019db15365ec4a405fc6", size = 36885, upload-time = "2025-08-12T05:52:54.367Z" }, - { url = "https://files.pythonhosted.org/packages/1f/f6/a933bd70f98e9cf3e08167fc5cd7aaaca49147e48411c0bd5ae701bb2194/wrapt-1.17.3-py3-none-any.whl", hash = "sha256:7171ae35d2c33d326ac19dd8facb1e82e5fd04ef8c6c0e394d7af55a55051c22", size = 23591, upload-time = "2025-08-12T05:53:20.674Z" }, + { url = "https://files.pythonhosted.org/packages/fd/03/c188ac517f402775b90d6f312955a5e53b866c964b32119f2ed76315697e/wrapt-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09", size = 37313, upload-time = "2023-11-09T06:31:52.168Z" }, + { url = "https://files.pythonhosted.org/packages/0f/16/ea627d7817394db04518f62934a5de59874b587b792300991b3c347ff5e0/wrapt-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d", size = 38164, upload-time = "2023-11-09T06:31:53.522Z" }, + { url = "https://files.pythonhosted.org/packages/7f/a7/f1212ba098f3de0fd244e2de0f8791ad2539c03bef6c05a9fcb03e45b089/wrapt-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389", size = 80890, upload-time = "2023-11-09T06:31:55.247Z" }, + { url = "https://files.pythonhosted.org/packages/b7/96/bb5e08b3d6db003c9ab219c487714c13a237ee7dcc572a555eaf1ce7dc82/wrapt-1.16.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060", size = 73118, upload-time = "2023-11-09T06:31:57.023Z" }, + { url = "https://files.pythonhosted.org/packages/6e/52/2da48b35193e39ac53cfb141467d9f259851522d0e8c87153f0ba4205fb1/wrapt-1.16.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1", size = 80746, upload-time = "2023-11-09T06:31:58.686Z" }, + { url = "https://files.pythonhosted.org/packages/11/fb/18ec40265ab81c0e82a934de04596b6ce972c27ba2592c8b53d5585e6bcd/wrapt-1.16.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3", size = 85668, upload-time = "2023-11-09T06:31:59.992Z" }, + { url = "https://files.pythonhosted.org/packages/0f/ef/0ecb1fa23145560431b970418dce575cfaec555ab08617d82eb92afc7ccf/wrapt-1.16.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956", size = 78556, upload-time = "2023-11-09T06:32:01.942Z" }, + { url = "https://files.pythonhosted.org/packages/25/62/cd284b2b747f175b5a96cbd8092b32e7369edab0644c45784871528eb852/wrapt-1.16.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d", size = 85712, upload-time = "2023-11-09T06:32:03.686Z" }, + { url = "https://files.pythonhosted.org/packages/e5/a7/47b7ff74fbadf81b696872d5ba504966591a3468f1bc86bca2f407baef68/wrapt-1.16.0-cp311-cp311-win32.whl", hash = "sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362", size = 35327, upload-time = "2023-11-09T06:32:05.284Z" }, + { url = "https://files.pythonhosted.org/packages/cf/c3/0084351951d9579ae83a3d9e38c140371e4c6b038136909235079f2e6e78/wrapt-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89", size = 37523, upload-time = "2023-11-09T06:32:07.17Z" }, + { url = "https://files.pythonhosted.org/packages/92/17/224132494c1e23521868cdd57cd1e903f3b6a7ba6996b7b8f077ff8ac7fe/wrapt-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b", size = 37614, upload-time = "2023-11-09T06:32:08.859Z" }, + { url = "https://files.pythonhosted.org/packages/6a/d7/cfcd73e8f4858079ac59d9db1ec5a1349bc486ae8e9ba55698cc1f4a1dff/wrapt-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36", size = 38316, upload-time = "2023-11-09T06:32:10.719Z" }, + { url = "https://files.pythonhosted.org/packages/7e/79/5ff0a5c54bda5aec75b36453d06be4f83d5cd4932cc84b7cb2b52cee23e2/wrapt-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73", size = 86322, upload-time = "2023-11-09T06:32:12.592Z" }, + { url = "https://files.pythonhosted.org/packages/c4/81/e799bf5d419f422d8712108837c1d9bf6ebe3cb2a81ad94413449543a923/wrapt-1.16.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809", size = 79055, 
upload-time = "2023-11-09T06:32:14.394Z" }, + { url = "https://files.pythonhosted.org/packages/62/62/30ca2405de6a20448ee557ab2cd61ab9c5900be7cbd18a2639db595f0b98/wrapt-1.16.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b", size = 87291, upload-time = "2023-11-09T06:32:16.201Z" }, + { url = "https://files.pythonhosted.org/packages/49/4e/5d2f6d7b57fc9956bf06e944eb00463551f7d52fc73ca35cfc4c2cdb7aed/wrapt-1.16.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81", size = 90374, upload-time = "2023-11-09T06:32:18.052Z" }, + { url = "https://files.pythonhosted.org/packages/a6/9b/c2c21b44ff5b9bf14a83252a8b973fb84923764ff63db3e6dfc3895cf2e0/wrapt-1.16.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9", size = 83896, upload-time = "2023-11-09T06:32:19.533Z" }, + { url = "https://files.pythonhosted.org/packages/14/26/93a9fa02c6f257df54d7570dfe8011995138118d11939a4ecd82cb849613/wrapt-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c", size = 91738, upload-time = "2023-11-09T06:32:20.989Z" }, + { url = "https://files.pythonhosted.org/packages/a2/5b/4660897233eb2c8c4de3dc7cefed114c61bacb3c28327e64150dc44ee2f6/wrapt-1.16.0-cp312-cp312-win32.whl", hash = "sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc", size = 35568, upload-time = "2023-11-09T06:32:22.715Z" }, + { url = "https://files.pythonhosted.org/packages/5c/cc/8297f9658506b224aa4bd71906447dea6bb0ba629861a758c28f67428b91/wrapt-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8", size = 37653, upload-time = "2023-11-09T06:32:24.533Z" }, + { url = 
"https://files.pythonhosted.org/packages/ff/21/abdedb4cdf6ff41ebf01a74087740a709e2edb146490e4d9beea054b0b7a/wrapt-1.16.0-py3-none-any.whl", hash = "sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1", size = 23362, upload-time = "2023-11-09T06:33:28.271Z" }, ] [[package]] diff --git a/dev/setup b/dev/setup new file mode 100755 index 0000000000..399c8f28a5 --- /dev/null +++ b/dev/setup @@ -0,0 +1,28 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(dirname "$(realpath "$0")")" +ROOT="$(dirname "$SCRIPT_DIR")" + +API_ENV_EXAMPLE="$ROOT/api/.env.example" +API_ENV="$ROOT/api/.env" +WEB_ENV_EXAMPLE="$ROOT/web/.env.example" +WEB_ENV="$ROOT/web/.env.local" +MIDDLEWARE_ENV_EXAMPLE="$ROOT/docker/middleware.env.example" +MIDDLEWARE_ENV="$ROOT/docker/middleware.env" + +# 1) Copy api/.env.example -> api/.env +cp "$API_ENV_EXAMPLE" "$API_ENV" + +# 2) Copy web/.env.example -> web/.env.local +cp "$WEB_ENV_EXAMPLE" "$WEB_ENV" + +# 3) Copy docker/middleware.env.example -> docker/middleware.env +cp "$MIDDLEWARE_ENV_EXAMPLE" "$MIDDLEWARE_ENV" + +# 4) Install deps +cd "$ROOT/api" +uv sync --group dev + +cd "$ROOT/web" +pnpm install diff --git a/dev/start-api b/dev/start-api index 0b50ad0d0a..bdfa58eca6 100755 --- a/dev/start-api +++ b/dev/start-api @@ -3,8 +3,9 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/.." 
+cd "$SCRIPT_DIR/../api" +uv run flask db upgrade -uv --directory api run \ +uv run \ flask run --host 0.0.0.0 --port=5001 --debug diff --git a/dev/start-docker-compose b/dev/start-docker-compose new file mode 100755 index 0000000000..9652be169d --- /dev/null +++ b/dev/start-docker-compose @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(dirname "$(realpath "$0")")" +ROOT="$(dirname "$SCRIPT_DIR")" + +cd "$ROOT/docker" +docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d diff --git a/dev/start-web b/dev/start-web index 31c5e168f9..f853f4a895 100755 --- a/dev/start-web +++ b/dev/start-web @@ -5,4 +5,4 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../web" -pnpm install && pnpm dev +pnpm install && pnpm dev:inspect diff --git a/dev/start-worker b/dev/start-worker index 7876620188..3e48065631 100755 --- a/dev/start-worker +++ b/dev/start-worker @@ -83,7 +83,7 @@ while [[ $# -gt 0 ]]; do done SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/.." +cd "$SCRIPT_DIR/../api" if [[ -n "${ENV_FILE}" ]]; then if [[ ! 
-f "${ENV_FILE}" ]]; then @@ -123,6 +123,6 @@ echo " Concurrency: ${CONCURRENCY}" echo " Pool: ${POOL}" echo " Log Level: ${LOGLEVEL}" -uv --directory api run \ +uv run \ celery -A app.celery worker \ -P ${POOL} -c ${CONCURRENCY} --loglevel ${LOGLEVEL} -Q ${QUEUES} diff --git a/dev/update-uv b/dev/update-uv index 189bd1f6b1..35d723fe20 100755 --- a/dev/update-uv +++ b/dev/update-uv @@ -4,7 +4,7 @@ set -e set -o pipefail -SCRIPT_DIR="$(dirname "$0")" +SCRIPT_DIR="$(dirname "$(realpath "$0")")" REPO_ROOT="$(dirname "${SCRIPT_DIR}")" # rely on `poetry` in path diff --git a/docker/.env.example b/docker/.env.example index 55a7a26776..3583adbf48 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -58,8 +58,8 @@ FILES_URL= INTERNAL_FILES_URL= # Ensure UTF-8 encoding -LANG=en_US.UTF-8 -LC_ALL=en_US.UTF-8 +LANG=C.UTF-8 +LC_ALL=C.UTF-8 PYTHONIOENCODING=utf-8 # ------------------------------ @@ -69,6 +69,8 @@ PYTHONIOENCODING=utf-8 # The log level for the application. # Supported values are `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL` LOG_LEVEL=INFO +# Log output format: text or json +LOG_OUTPUT_FORMAT=text # Log file path LOG_FILE=/app/logs/server.log # Log file max size, the unit is MB @@ -231,7 +233,7 @@ NEXT_PUBLIC_ENABLE_SINGLE_DOLLAR_LATEX=false # You can adjust the database configuration according to your needs. # ------------------------------ -# Database type, supported values are `postgresql` and `mysql` +# Database type, supported values are `postgresql`, `mysql`, `oceanbase`, `seekdb` DB_TYPE=postgresql # For MySQL, only `root` user is supported for now DB_USERNAME=postgres @@ -399,6 +401,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=* COOKIE_DOMAIN= # When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1. NEXT_PUBLIC_COOKIE_DOMAIN= +NEXT_PUBLIC_BATCH_CONCURRENCY=5 # ------------------------------ # File Storage Configuration @@ -446,6 +449,15 @@ S3_SECRET_KEY= # If set to false, the access key and secret key must be provided. 
S3_USE_AWS_MANAGED_IAM=false +# Workflow run and Conversation archive storage (S3-compatible) +ARCHIVE_STORAGE_ENABLED=false +ARCHIVE_STORAGE_ENDPOINT= +ARCHIVE_STORAGE_ARCHIVE_BUCKET= +ARCHIVE_STORAGE_EXPORT_BUCKET= +ARCHIVE_STORAGE_ACCESS_KEY= +ARCHIVE_STORAGE_SECRET_KEY= +ARCHIVE_STORAGE_REGION=auto + # Azure Blob Configuration # AZURE_BLOB_ACCOUNT_NAME=difyai @@ -477,6 +489,7 @@ TENCENT_COS_SECRET_KEY=your-secret-key TENCENT_COS_SECRET_ID=your-secret-id TENCENT_COS_REGION=your-region TENCENT_COS_SCHEME=your-scheme +TENCENT_COS_CUSTOM_DOMAIN=your-custom-domain # Oracle Storage Configuration # @@ -520,7 +533,7 @@ SUPABASE_URL=your-server-url # ------------------------------ # The type of vector store to use. -# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`. +# Supported values are `weaviate`, `oceanbase`, `seekdb`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`. VECTOR_STORE=weaviate # Prefix used to create collection name in vector database VECTOR_INDEX_NAME_PREFIX=Vector_index @@ -531,9 +544,9 @@ WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih WEAVIATE_GRPC_ENDPOINT=grpc://weaviate:50051 WEAVIATE_TOKENIZATION=word -# For OceanBase metadata database configuration, available when `DB_TYPE` is `mysql` and `COMPOSE_PROFILES` includes `oceanbase`. 
+# For OceanBase metadata database configuration, available when `DB_TYPE` is `oceanbase`. # For OceanBase vector database configuration, available when `VECTOR_STORE` is `oceanbase` -# If you want to use OceanBase as both vector database and metadata database, you need to set `DB_TYPE` to `mysql`, `COMPOSE_PROFILES` is `oceanbase`, and set Database Configuration is the same as the vector database. +# If you want to use OceanBase as both vector database and metadata database, you need to set both `DB_TYPE` and `VECTOR_STORE` to `oceanbase`, and use the same Database Configuration as the vector database. # seekdb is the lite version of OceanBase and shares the connection configuration with OceanBase. OCEANBASE_VECTOR_HOST=oceanbase OCEANBASE_VECTOR_PORT=2881 @@ -955,6 +968,8 @@ SMTP_USERNAME= SMTP_PASSWORD= SMTP_USE_TLS=true SMTP_OPPORTUNISTIC_TLS=false +# Optional: override the local hostname used for SMTP HELO/EHLO +SMTP_LOCAL_HOSTNAME= # Sendgid configuration SENDGRID_API_KEY= @@ -1024,18 +1039,26 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms # Options: # - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default) # - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository +# - extensions.logstore.repositories.logstore_workflow_execution_repository.LogstoreWorkflowExecutionRepository CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository # Core workflow node execution repository implementation # Options: # - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default) # - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository +# - extensions.logstore.repositories.logstore_workflow_node_execution_repository.LogstoreWorkflowNodeExecutionRepository
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository # API workflow run repository implementation +# Options: +# - repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository (default) +# - extensions.logstore.repositories.logstore_api_workflow_run_repository.LogstoreAPIWorkflowRunRepository API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository # API workflow node execution repository implementation +# Options: +# - repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository (default) +# - extensions.logstore.repositories.logstore_api_workflow_node_execution_repository.LogstoreAPIWorkflowNodeExecutionRepository API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository # Workflow log cleanup configuration @@ -1064,6 +1087,10 @@ LOGSTORE_DUAL_WRITE_ENABLED=false # Enable dual-read fallback to SQL database when LogStore returns no results (default: true) # Useful for migration scenarios where historical data exists only in SQL database LOGSTORE_DUAL_READ_ENABLED=true +# Controls whether the full `graph` field is written to LogStore. +# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true" (the default), write the full +# `graph` field; otherwise write an empty {} instead.
+LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 @@ -1461,6 +1488,7 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false ENABLE_CREATE_TIDB_SERVERLESS_TASK=false ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false ENABLE_CLEAN_MESSAGES=false +ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false ENABLE_DATASETS_QUEUE_MONITOR=false ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true @@ -1515,3 +1543,6 @@ PUBSUB_REDIS_USE_CLUSTERS=false ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true # Human input timeout check interval in minutes HUMAN_INPUT_TIMEOUT_TASK_INTERVAL=1 + + +SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000 diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 3c88cddf8c..9659990383 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -63,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -102,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -132,7 +132,7 @@ services: # Frontend web application. 
web: - image: langgenius/dify-web:1.11.2 + image: langgenius/dify-web:1.11.4 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -475,7 +475,8 @@ services: OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OB_SERVER_IP: 127.0.0.1 MODE: mini - LANG: en_US.UTF-8 + LANG: C.UTF-8 + LC_ALL: C.UTF-8 ports: - "${OCEANBASE_VECTOR_PORT:-2881}:2881" healthcheck: diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index dba61d1816..81c34fc6a2 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -129,6 +129,7 @@ services: - ./middleware.env environment: # Use the shared environment variables. + LOG_OUTPUT_FORMAT: ${LOG_OUTPUT_FORMAT:-text} DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin} REDIS_HOST: ${REDIS_HOST:-redis} REDIS_PORT: ${REDIS_PORT:-6379} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 9a1138512e..79fc9af77a 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -13,10 +13,11 @@ x-shared-env: &shared-api-worker-env APP_WEB_URL: ${APP_WEB_URL:-} FILES_URL: ${FILES_URL:-} INTERNAL_FILES_URL: ${INTERNAL_FILES_URL:-} - LANG: ${LANG:-en_US.UTF-8} - LC_ALL: ${LC_ALL:-en_US.UTF-8} + LANG: ${LANG:-C.UTF-8} + LC_ALL: ${LC_ALL:-C.UTF-8} PYTHONIOENCODING: ${PYTHONIOENCODING:-utf-8} LOG_LEVEL: ${LOG_LEVEL:-INFO} + LOG_OUTPUT_FORMAT: ${LOG_OUTPUT_FORMAT:-text} LOG_FILE: ${LOG_FILE:-/app/logs/server.log} LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20} LOG_FILE_BACKUP_COUNT: ${LOG_FILE_BACKUP_COUNT:-5} @@ -108,6 +109,7 @@ x-shared-env: &shared-api-worker-env CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*} COOKIE_DOMAIN: ${COOKIE_DOMAIN:-} NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-} + NEXT_PUBLIC_BATCH_CONCURRENCY: ${NEXT_PUBLIC_BATCH_CONCURRENCY:-5} STORAGE_TYPE: ${STORAGE_TYPE:-opendal} OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs} OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage} @@ -121,6 +123,13 @@ x-shared-env: 
&shared-api-worker-env S3_ACCESS_KEY: ${S3_ACCESS_KEY:-} S3_SECRET_KEY: ${S3_SECRET_KEY:-} S3_USE_AWS_MANAGED_IAM: ${S3_USE_AWS_MANAGED_IAM:-false} + ARCHIVE_STORAGE_ENABLED: ${ARCHIVE_STORAGE_ENABLED:-false} + ARCHIVE_STORAGE_ENDPOINT: ${ARCHIVE_STORAGE_ENDPOINT:-} + ARCHIVE_STORAGE_ARCHIVE_BUCKET: ${ARCHIVE_STORAGE_ARCHIVE_BUCKET:-} + ARCHIVE_STORAGE_EXPORT_BUCKET: ${ARCHIVE_STORAGE_EXPORT_BUCKET:-} + ARCHIVE_STORAGE_ACCESS_KEY: ${ARCHIVE_STORAGE_ACCESS_KEY:-} + ARCHIVE_STORAGE_SECRET_KEY: ${ARCHIVE_STORAGE_SECRET_KEY:-} + ARCHIVE_STORAGE_REGION: ${ARCHIVE_STORAGE_REGION:-auto} AZURE_BLOB_ACCOUNT_NAME: ${AZURE_BLOB_ACCOUNT_NAME:-difyai} AZURE_BLOB_ACCOUNT_KEY: ${AZURE_BLOB_ACCOUNT_KEY:-difyai} AZURE_BLOB_CONTAINER_NAME: ${AZURE_BLOB_CONTAINER_NAME:-difyai-container} @@ -140,6 +149,7 @@ x-shared-env: &shared-api-worker-env TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id} TENCENT_COS_REGION: ${TENCENT_COS_REGION:-your-region} TENCENT_COS_SCHEME: ${TENCENT_COS_SCHEME:-your-scheme} + TENCENT_COS_CUSTOM_DOMAIN: ${TENCENT_COS_CUSTOM_DOMAIN:-your-custom-domain} OCI_ENDPOINT: ${OCI_ENDPOINT:-https://your-object-storage-namespace.compat.objectstorage.us-ashburn-1.oraclecloud.com} OCI_BUCKET_NAME: ${OCI_BUCKET_NAME:-your-bucket-name} OCI_ACCESS_KEY: ${OCI_ACCESS_KEY:-your-access-key} @@ -415,6 +425,7 @@ x-shared-env: &shared-api-worker-env SMTP_PASSWORD: ${SMTP_PASSWORD:-} SMTP_USE_TLS: ${SMTP_USE_TLS:-true} SMTP_OPPORTUNISTIC_TLS: ${SMTP_OPPORTUNISTIC_TLS:-false} + SMTP_LOCAL_HOSTNAME: ${SMTP_LOCAL_HOSTNAME:-} SENDGRID_API_KEY: ${SENDGRID_API_KEY:-} INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000} INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72} @@ -465,6 +476,7 @@ x-shared-env: &shared-api-worker-env ALIYUN_SLS_LOGSTORE_TTL: ${ALIYUN_SLS_LOGSTORE_TTL:-365} LOGSTORE_DUAL_WRITE_ENABLED: ${LOGSTORE_DUAL_WRITE_ENABLED:-false} LOGSTORE_DUAL_READ_ENABLED: ${LOGSTORE_DUAL_READ_ENABLED:-true} + 
LOGSTORE_ENABLE_PUT_GRAPH_FIELD: ${LOGSTORE_ENABLE_PUT_GRAPH_FIELD:-true} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} @@ -651,6 +663,7 @@ x-shared-env: &shared-api-worker-env ENABLE_CREATE_TIDB_SERVERLESS_TASK: ${ENABLE_CREATE_TIDB_SERVERLESS_TASK:-false} ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK: ${ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK:-false} ENABLE_CLEAN_MESSAGES: ${ENABLE_CLEAN_MESSAGES:-false} + ENABLE_WORKFLOW_RUN_CLEANUP_TASK: ${ENABLE_WORKFLOW_RUN_CLEANUP_TASK:-false} ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: ${ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK:-false} ENABLE_DATASETS_QUEUE_MONITOR: ${ENABLE_DATASETS_QUEUE_MONITOR:-false} ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: ${ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:-true} @@ -674,6 +687,7 @@ x-shared-env: &shared-api-worker-env PUBSUB_REDIS_USE_CLUSTERS: ${PUBSUB_REDIS_USE_CLUSTERS:-false} ENABLE_HUMAN_INPUT_TIMEOUT_TASK: ${ENABLE_HUMAN_INPUT_TIMEOUT_TASK:-true} HUMAN_INPUT_TIMEOUT_TASK_INTERVAL: ${HUMAN_INPUT_TIMEOUT_TASK_INTERVAL:-1} + SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: ${SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL:-90000} services: # Init container to fix permissions @@ -697,7 +711,7 @@ services: # API service api: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -739,7 +753,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -778,7 +792,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. 
worker_beat: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -808,7 +822,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.11.2 + image: langgenius/dify-web:1.11.4 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -1151,7 +1165,8 @@ services: OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OB_SERVER_IP: 127.0.0.1 MODE: mini - LANG: en_US.UTF-8 + LANG: C.UTF-8 + LC_ALL: C.UTF-8 ports: - "${OCEANBASE_VECTOR_PORT:-2881}:2881" healthcheck: diff --git a/docker/middleware.env.example b/docker/middleware.env.example index f7e0252a6f..c88dbe5511 100644 --- a/docker/middleware.env.example +++ b/docker/middleware.env.example @@ -233,4 +233,8 @@ ALIYUN_SLS_LOGSTORE_TTL=365 LOGSTORE_DUAL_WRITE_ENABLED=true # Enable dual-read fallback to SQL database when LogStore returns no results (default: true) # Useful for migration scenarios where historical data exists only in SQL database -LOGSTORE_DUAL_READ_ENABLED=true \ No newline at end of file +LOGSTORE_DUAL_READ_ENABLED=true +# Controls whether the full `graph` field is written to LogStore. +# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true" (the default), write the full +# `graph` field; otherwise write an empty {} instead. +LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true \ No newline at end of file diff --git a/docker/ssrf_proxy/squid.conf.template b/docker/ssrf_proxy/squid.conf.template index 1775a1fff9..256e669c8d 100644 --- a/docker/ssrf_proxy/squid.conf.template +++ b/docker/ssrf_proxy/squid.conf.template @@ -54,3 +54,52 @@ http_access allow src_all # Unless the option's size is increased, an error will occur when uploading more than two files.
client_request_buffer_max_size 100 MB + +################################## Performance & Concurrency ############################### +# Increase file descriptor limit for high concurrency +max_filedescriptors 65536 + +# Timeout configurations for image requests +connect_timeout 30 seconds +request_timeout 2 minutes +read_timeout 2 minutes +client_lifetime 5 minutes +shutdown_lifetime 30 seconds + +# Persistent connections - improve performance for multiple requests +server_persistent_connections on +client_persistent_connections on +persistent_request_timeout 30 seconds +pconn_timeout 1 minute + +# Connection pool and concurrency limits +client_db on +server_idle_pconn_timeout 2 minutes +client_idle_pconn_timeout 2 minutes + +# Quick abort settings - don't abort requests that are mostly done +quick_abort_min 16 KB +quick_abort_max 16 MB +quick_abort_pct 95 + +# Memory and cache optimization +memory_cache_mode disk +cache_mem 256 MB +maximum_object_size_in_memory 512 KB + +# DNS resolver settings for better performance +dns_timeout 30 seconds +dns_retransmit_interval 5 seconds +# By default, Squid uses the system's configured DNS resolvers. +# If you need to override them, set dns_nameservers to appropriate servers +# for your environment (for example, internal/corporate DNS). 
The following +# is an example using public DNS and SHOULD be customized before use: +# dns_nameservers 8.8.8.8 8.8.4.4 + +# Logging format for better debugging +logformat dify_log %ts.%03tu %6tr %>a %Ss/%03>Hs %=6.0.0'} - '@babel/helper-string-parser@7.27.1': resolution: {integrity: sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==} engines: {node: '>=6.9.0'} @@ -63,14 +59,9 @@ packages: resolution: {integrity: sha512-qQ5m48eI/MFLQ5PxQj4PFaprjyCTLI37ElWMmNs0K8Lk3dVeOdNpB3ks8jc7yM5CDmVC73eMVk/trk3fgmrUpA==} engines: {node: '>=6.9.0'} - '@bcoe/v8-coverage@0.2.3': - resolution: {integrity: sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==} - - '@esbuild/aix-ppc64@0.21.5': - resolution: {integrity: sha512-1SDgH6ZSPTlggy1yI6+Dbkiz8xzpHJEVAlF/AM1tHPLsf5STom9rwtjE4hKAF20FfXXNTFqEYXyJNWh1GiZedQ==} - engines: {node: '>=12'} - cpu: [ppc64] - os: [aix] + '@bcoe/v8-coverage@1.0.2': + resolution: {integrity: sha512-6zABk/ECA/QYSCQ1NGiVwwbQerUCZ+TQbp64Q3AgmfNvurHH0j8TtXa1qbShXA6qqkpAj4V5W8pP6mLe1mcMqA==} + engines: {node: '>=18'} '@esbuild/aix-ppc64@0.27.2': resolution: {integrity: sha512-GZMB+a0mOMZs4MpDbj8RJp4cw+w1WV5NYD6xzgvzUJ5Ek2jerwfO2eADyI6ExDSUED+1X8aMbegahsJi+8mgpw==} @@ -78,192 +69,96 @@ packages: cpu: [ppc64] os: [aix] - '@esbuild/android-arm64@0.21.5': - resolution: {integrity: sha512-c0uX9VAUBQ7dTDCjq+wdyGLowMdtR/GoC2U5IYk/7D1H1JYC0qseD7+11iMP2mRLN9RcCMRcjC4YMclCzGwS/A==} - engines: {node: '>=12'} - cpu: [arm64] - os: [android] - '@esbuild/android-arm64@0.27.2': resolution: {integrity: sha512-pvz8ZZ7ot/RBphf8fv60ljmaoydPU12VuXHImtAs0XhLLw+EXBi2BLe3OYSBslR4rryHvweW5gmkKFwTiFy6KA==} engines: {node: '>=18'} cpu: [arm64] os: [android] - '@esbuild/android-arm@0.21.5': - resolution: {integrity: sha512-vCPvzSjpPHEi1siZdlvAlsPxXl7WbOVUBBAowWug4rJHb68Ox8KualB+1ocNvT5fjv6wpkX6o/iEpbDrf68zcg==} - engines: {node: '>=12'} - cpu: [arm] - os: [android] - '@esbuild/android-arm@0.27.2': 
resolution: {integrity: sha512-DVNI8jlPa7Ujbr1yjU2PfUSRtAUZPG9I1RwW4F4xFB1Imiu2on0ADiI/c3td+KmDtVKNbi+nffGDQMfcIMkwIA==} engines: {node: '>=18'} cpu: [arm] os: [android] - '@esbuild/android-x64@0.21.5': - resolution: {integrity: sha512-D7aPRUUNHRBwHxzxRvp856rjUHRFW1SdQATKXH2hqA0kAZb1hKmi02OpYRacl0TxIGz/ZmXWlbZgjwWYaCakTA==} - engines: {node: '>=12'} - cpu: [x64] - os: [android] - '@esbuild/android-x64@0.27.2': resolution: {integrity: sha512-z8Ank4Byh4TJJOh4wpz8g2vDy75zFL0TlZlkUkEwYXuPSgX8yzep596n6mT7905kA9uHZsf/o2OJZubl2l3M7A==} engines: {node: '>=18'} cpu: [x64] os: [android] - '@esbuild/darwin-arm64@0.21.5': - resolution: {integrity: sha512-DwqXqZyuk5AiWWf3UfLiRDJ5EDd49zg6O9wclZ7kUMv2WRFr4HKjXp/5t8JZ11QbQfUS6/cRCKGwYhtNAY88kQ==} - engines: {node: '>=12'} - cpu: [arm64] - os: [darwin] - '@esbuild/darwin-arm64@0.27.2': resolution: {integrity: sha512-davCD2Zc80nzDVRwXTcQP/28fiJbcOwvdolL0sOiOsbwBa72kegmVU0Wrh1MYrbuCL98Omp5dVhQFWRKR2ZAlg==} engines: {node: '>=18'} cpu: [arm64] os: [darwin] - '@esbuild/darwin-x64@0.21.5': - resolution: {integrity: sha512-se/JjF8NlmKVG4kNIuyWMV/22ZaerB+qaSi5MdrXtd6R08kvs2qCN4C09miupktDitvh8jRFflwGFBQcxZRjbw==} - engines: {node: '>=12'} - cpu: [x64] - os: [darwin] - '@esbuild/darwin-x64@0.27.2': resolution: {integrity: sha512-ZxtijOmlQCBWGwbVmwOF/UCzuGIbUkqB1faQRf5akQmxRJ1ujusWsb3CVfk/9iZKr2L5SMU5wPBi1UWbvL+VQA==} engines: {node: '>=18'} cpu: [x64] os: [darwin] - '@esbuild/freebsd-arm64@0.21.5': - resolution: {integrity: sha512-5JcRxxRDUJLX8JXp/wcBCy3pENnCgBR9bN6JsY4OmhfUtIHe3ZW0mawA7+RDAcMLrMIZaf03NlQiX9DGyB8h4g==} - engines: {node: '>=12'} - cpu: [arm64] - os: [freebsd] - '@esbuild/freebsd-arm64@0.27.2': resolution: {integrity: sha512-lS/9CN+rgqQ9czogxlMcBMGd+l8Q3Nj1MFQwBZJyoEKI50XGxwuzznYdwcav6lpOGv5BqaZXqvBSiB/kJ5op+g==} engines: {node: '>=18'} cpu: [arm64] os: [freebsd] - '@esbuild/freebsd-x64@0.21.5': - resolution: {integrity: sha512-J95kNBj1zkbMXtHVH29bBriQygMXqoVQOQYA+ISs0/2l3T9/kj42ow2mpqerRBxDJnmkUDCaQT/dfNXWX/ZZCQ==} - 
engines: {node: '>=12'} - cpu: [x64] - os: [freebsd] - '@esbuild/freebsd-x64@0.27.2': resolution: {integrity: sha512-tAfqtNYb4YgPnJlEFu4c212HYjQWSO/w/h/lQaBK7RbwGIkBOuNKQI9tqWzx7Wtp7bTPaGC6MJvWI608P3wXYA==} engines: {node: '>=18'} cpu: [x64] os: [freebsd] - '@esbuild/linux-arm64@0.21.5': - resolution: {integrity: sha512-ibKvmyYzKsBeX8d8I7MH/TMfWDXBF3db4qM6sy+7re0YXya+K1cem3on9XgdT2EQGMu4hQyZhan7TeQ8XkGp4Q==} - engines: {node: '>=12'} - cpu: [arm64] - os: [linux] - '@esbuild/linux-arm64@0.27.2': resolution: {integrity: sha512-hYxN8pr66NsCCiRFkHUAsxylNOcAQaxSSkHMMjcpx0si13t1LHFphxJZUiGwojB1a/Hd5OiPIqDdXONia6bhTw==} engines: {node: '>=18'} cpu: [arm64] os: [linux] - '@esbuild/linux-arm@0.21.5': - resolution: {integrity: sha512-bPb5AHZtbeNGjCKVZ9UGqGwo8EUu4cLq68E95A53KlxAPRmUyYv2D6F0uUI65XisGOL1hBP5mTronbgo+0bFcA==} - engines: {node: '>=12'} - cpu: [arm] - os: [linux] - '@esbuild/linux-arm@0.27.2': resolution: {integrity: sha512-vWfq4GaIMP9AIe4yj1ZUW18RDhx6EPQKjwe7n8BbIecFtCQG4CfHGaHuh7fdfq+y3LIA2vGS/o9ZBGVxIDi9hw==} engines: {node: '>=18'} cpu: [arm] os: [linux] - '@esbuild/linux-ia32@0.21.5': - resolution: {integrity: sha512-YvjXDqLRqPDl2dvRODYmmhz4rPeVKYvppfGYKSNGdyZkA01046pLWyRKKI3ax8fbJoK5QbxblURkwK/MWY18Tg==} - engines: {node: '>=12'} - cpu: [ia32] - os: [linux] - '@esbuild/linux-ia32@0.27.2': resolution: {integrity: sha512-MJt5BRRSScPDwG2hLelYhAAKh9imjHK5+NE/tvnRLbIqUWa+0E9N4WNMjmp/kXXPHZGqPLxggwVhz7QP8CTR8w==} engines: {node: '>=18'} cpu: [ia32] os: [linux] - '@esbuild/linux-loong64@0.21.5': - resolution: {integrity: sha512-uHf1BmMG8qEvzdrzAqg2SIG/02+4/DHB6a9Kbya0XDvwDEKCoC8ZRWI5JJvNdUjtciBGFQ5PuBlpEOXQj+JQSg==} - engines: {node: '>=12'} - cpu: [loong64] - os: [linux] - '@esbuild/linux-loong64@0.27.2': resolution: {integrity: sha512-lugyF1atnAT463aO6KPshVCJK5NgRnU4yb3FUumyVz+cGvZbontBgzeGFO1nF+dPueHD367a2ZXe1NtUkAjOtg==} engines: {node: '>=18'} cpu: [loong64] os: [linux] - '@esbuild/linux-mips64el@0.21.5': - resolution: {integrity: 
sha512-IajOmO+KJK23bj52dFSNCMsz1QP1DqM6cwLUv3W1QwyxkyIWecfafnI555fvSGqEKwjMXVLokcV5ygHW5b3Jbg==} - engines: {node: '>=12'} - cpu: [mips64el] - os: [linux] - '@esbuild/linux-mips64el@0.27.2': resolution: {integrity: sha512-nlP2I6ArEBewvJ2gjrrkESEZkB5mIoaTswuqNFRv/WYd+ATtUpe9Y09RnJvgvdag7he0OWgEZWhviS1OTOKixw==} engines: {node: '>=18'} cpu: [mips64el] os: [linux] - '@esbuild/linux-ppc64@0.21.5': - resolution: {integrity: sha512-1hHV/Z4OEfMwpLO8rp7CvlhBDnjsC3CttJXIhBi+5Aj5r+MBvy4egg7wCbe//hSsT+RvDAG7s81tAvpL2XAE4w==} - engines: {node: '>=12'} - cpu: [ppc64] - os: [linux] - '@esbuild/linux-ppc64@0.27.2': resolution: {integrity: sha512-C92gnpey7tUQONqg1n6dKVbx3vphKtTHJaNG2Ok9lGwbZil6DrfyecMsp9CrmXGQJmZ7iiVXvvZH6Ml5hL6XdQ==} engines: {node: '>=18'} cpu: [ppc64] os: [linux] - '@esbuild/linux-riscv64@0.21.5': - resolution: {integrity: sha512-2HdXDMd9GMgTGrPWnJzP2ALSokE/0O5HhTUvWIbD3YdjME8JwvSCnNGBnTThKGEB91OZhzrJ4qIIxk/SBmyDDA==} - engines: {node: '>=12'} - cpu: [riscv64] - os: [linux] - '@esbuild/linux-riscv64@0.27.2': resolution: {integrity: sha512-B5BOmojNtUyN8AXlK0QJyvjEZkWwy/FKvakkTDCziX95AowLZKR6aCDhG7LeF7uMCXEJqwa8Bejz5LTPYm8AvA==} engines: {node: '>=18'} cpu: [riscv64] os: [linux] - '@esbuild/linux-s390x@0.21.5': - resolution: {integrity: sha512-zus5sxzqBJD3eXxwvjN1yQkRepANgxE9lgOW2qLnmr8ikMTphkjgXu1HR01K4FJg8h1kEEDAqDcZQtbrRnB41A==} - engines: {node: '>=12'} - cpu: [s390x] - os: [linux] - '@esbuild/linux-s390x@0.27.2': resolution: {integrity: sha512-p4bm9+wsPwup5Z8f4EpfN63qNagQ47Ua2znaqGH6bqLlmJ4bx97Y9JdqxgGZ6Y8xVTixUnEkoKSHcpRlDnNr5w==} engines: {node: '>=18'} cpu: [s390x] os: [linux] - '@esbuild/linux-x64@0.21.5': - resolution: {integrity: sha512-1rYdTpyv03iycF1+BhzrzQJCdOuAOtaqHTWJZCWvijKD2N5Xu0TtVC8/+1faWqcP9iBCWOmjmhoH94dH82BxPQ==} - engines: {node: '>=12'} - cpu: [x64] - os: [linux] - '@esbuild/linux-x64@0.27.2': resolution: {integrity: sha512-uwp2Tip5aPmH+NRUwTcfLb+W32WXjpFejTIOWZFw/v7/KnpCDKG66u4DLcurQpiYTiYwQ9B7KOeMJvLCu/OvbA==} engines: {node: '>=18'} @@ 
-276,12 +171,6 @@ packages: cpu: [arm64] os: [netbsd] - '@esbuild/netbsd-x64@0.21.5': - resolution: {integrity: sha512-Woi2MXzXjMULccIwMnLciyZH4nCIMpWQAs049KEeMvOcNADVxo0UBIQPfSmxB3CWKedngg7sWZdLvLczpe0tLg==} - engines: {node: '>=12'} - cpu: [x64] - os: [netbsd] - '@esbuild/netbsd-x64@0.27.2': resolution: {integrity: sha512-HwGDZ0VLVBY3Y+Nw0JexZy9o/nUAWq9MlV7cahpaXKW6TOzfVno3y3/M8Ga8u8Yr7GldLOov27xiCnqRZf0tCA==} engines: {node: '>=18'} @@ -294,12 +183,6 @@ packages: cpu: [arm64] os: [openbsd] - '@esbuild/openbsd-x64@0.21.5': - resolution: {integrity: sha512-HLNNw99xsvx12lFBUwoT8EVCsSvRNDVxNpjZ7bPn947b8gJPzeHWyNVhFsaerc0n3TsbOINvRP2byTZ5LKezow==} - engines: {node: '>=12'} - cpu: [x64] - os: [openbsd] - '@esbuild/openbsd-x64@0.27.2': resolution: {integrity: sha512-/it7w9Nb7+0KFIzjalNJVR5bOzA9Vay+yIPLVHfIQYG/j+j9VTH84aNB8ExGKPU4AzfaEvN9/V4HV+F+vo8OEg==} engines: {node: '>=18'} @@ -312,48 +195,24 @@ packages: cpu: [arm64] os: [openharmony] - '@esbuild/sunos-x64@0.21.5': - resolution: {integrity: sha512-6+gjmFpfy0BHU5Tpptkuh8+uw3mnrvgs+dSPQXQOv3ekbordwnzTVEb4qnIvQcYXq6gzkyTnoZ9dZG+D4garKg==} - engines: {node: '>=12'} - cpu: [x64] - os: [sunos] - '@esbuild/sunos-x64@0.27.2': resolution: {integrity: sha512-kMtx1yqJHTmqaqHPAzKCAkDaKsffmXkPHThSfRwZGyuqyIeBvf08KSsYXl+abf5HDAPMJIPnbBfXvP2ZC2TfHg==} engines: {node: '>=18'} cpu: [x64] os: [sunos] - '@esbuild/win32-arm64@0.21.5': - resolution: {integrity: sha512-Z0gOTd75VvXqyq7nsl93zwahcTROgqvuAcYDUr+vOv8uHhNSKROyU961kgtCD1e95IqPKSQKH7tBTslnS3tA8A==} - engines: {node: '>=12'} - cpu: [arm64] - os: [win32] - '@esbuild/win32-arm64@0.27.2': resolution: {integrity: sha512-Yaf78O/B3Kkh+nKABUF++bvJv5Ijoy9AN1ww904rOXZFLWVc5OLOfL56W+C8F9xn5JQZa3UX6m+IktJnIb1Jjg==} engines: {node: '>=18'} cpu: [arm64] os: [win32] - '@esbuild/win32-ia32@0.21.5': - resolution: {integrity: sha512-SWXFF1CL2RVNMaVs+BBClwtfZSvDgtL//G/smwAc5oVK/UPu2Gu9tIaRgFmYFFKrmg3SyAjSrElf0TiJ1v8fYA==} - engines: {node: '>=12'} - cpu: [ia32] - os: [win32] - 
'@esbuild/win32-ia32@0.27.2': resolution: {integrity: sha512-Iuws0kxo4yusk7sw70Xa2E2imZU5HoixzxfGCdxwBdhiDgt9vX9VUCBhqcwY7/uh//78A1hMkkROMJq9l27oLQ==} engines: {node: '>=18'} cpu: [ia32] os: [win32] - '@esbuild/win32-x64@0.21.5': - resolution: {integrity: sha512-tQd/1efJuzPC6rCFwEvLtci/xNFcTZknmXs98FYDfGE4wP9ClFV98nyKrzJKVPMhdDnjzLhdUyMX4PsQAPjwIw==} - engines: {node: '>=12'} - cpu: [x64] - os: [win32] - '@esbuild/win32-x64@0.27.2': resolution: {integrity: sha512-sRdU18mcKf7F+YgheI/zGf5alZatMUTKj/jNS6l744f9u3WFu4v7twcUI9vu4mknF4Y9aDlblIie0IM+5xxaqQ==} engines: {node: '>=18'} @@ -414,14 +273,6 @@ packages: resolution: {integrity: sha512-bV0Tgo9K4hfPCek+aMAn81RppFKv2ySDQeMoSZuvTASywNTnVJCArCZE2FWqpvIatKu7VMRLWlR1EazvVhDyhQ==} engines: {node: '>=18.18'} - '@istanbuljs/schema@0.1.3': - resolution: {integrity: sha512-ZXRY4jNvVgSVQ8DL3LTcakaAtXwTVUxE81hslsyD2AtoXW/wVob10HkOJ1X/pAlcI7D+2YoZKg5do8G/w6RYgA==} - engines: {node: '>=8'} - - '@jest/schemas@29.6.3': - resolution: {integrity: sha512-mo5j5X+jIZmJQveBKeS/clAueipV7KgiX1vMgCxam1RNYiqE1w62n0/tJJnHtjW8ZHcQco5gY85jA3mi0L+nSA==} - engines: {node: ^14.15.0 || ^16.10.0 || >=18.0.0} - '@jridgewell/gen-mapping@0.3.13': resolution: {integrity: sha512-2kkt/7niJ6MgEPxF0bYdQ6etZaA+fQvDcLKckhy1yIQOzaoKjBBjSj63/aLVjYE3qhRt5dvM+uUyfCg6UKCBbA==} @@ -545,8 +396,14 @@ packages: cpu: [x64] os: [win32] - '@sinclair/typebox@0.27.8': - resolution: {integrity: sha512-+Fj43pSMwJs4KRrH/938Uf+uAELIgVBmQzg/q1YG10djyfA3TnrU8N8XzqCh/okZdszqBQTZf96idMfE5lnwTA==} + '@standard-schema/spec@1.1.0': + resolution: {integrity: sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==} + + '@types/chai@5.2.3': + resolution: {integrity: sha512-Mw558oeA9fFbv65/y4mHtXDs9bPnFMZAL/jxdPFUpOHHIXX91mcgEHbS5Lahr+pwZFR8A7GQleRWeI6cGFC2UA==} + + '@types/deep-eql@4.0.2': + resolution: {integrity: sha512-c9h9dVVMigMPc4bwTvC5dxqtqJZwQPePsWjPlpSOnojbor6pGqdk541lfA7AqFQr5pB1BRdq0juY9db81BwyFw==} '@types/estree@1.0.8': resolution: 
{integrity: sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==} @@ -554,8 +411,8 @@ packages: '@types/json-schema@7.0.15': resolution: {integrity: sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==} - '@types/node@20.19.27': - resolution: {integrity: sha512-N2clP5pJhB2YnZJ3PIHFk5RkygRX5WO/5f0WC08tp0wd+sv0rsJk3MqWn3CbNmT2J505a5336jaQj4ph1AdMug==} + '@types/node@25.0.3': + resolution: {integrity: sha512-W609buLVRVmeW693xKfzHeIV6nJGGz98uCPfeXI1ELMLXVeKYZ9m15fAMSaUPBHYLGFsVRcMmSCksQOrZV9BYA==} '@typescript-eslint/eslint-plugin@8.50.1': resolution: {integrity: sha512-PKhLGDq3JAg0Jk/aK890knnqduuI/Qj+udH7wCf0217IGi4gt+acgCyPVe79qoT+qKUvHMDQkwJeKW9fwl8Cyw==} @@ -616,35 +473,49 @@ packages: resolution: {integrity: sha512-IrDKrw7pCRUR94zeuCSUWQ+w8JEf5ZX5jl/e6AHGSLi1/zIr0lgutfn/7JpfCey+urpgQEdrZVYzCaVVKiTwhQ==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - '@vitest/coverage-v8@1.6.1': - resolution: {integrity: sha512-6YeRZwuO4oTGKxD3bijok756oktHSIm3eczVVzNe3scqzuhLwltIF3S9ZL/vwOVIpURmU6SnZhziXXAfw8/Qlw==} + '@vitest/coverage-v8@4.0.16': + resolution: {integrity: sha512-2rNdjEIsPRzsdu6/9Eq0AYAzYdpP6Bx9cje9tL3FE5XzXRQF1fNU9pe/1yE8fCrS0HD+fBtt6gLPh6LI57tX7A==} peerDependencies: - vitest: 1.6.1 + '@vitest/browser': 4.0.16 + vitest: 4.0.16 + peerDependenciesMeta: + '@vitest/browser': + optional: true - '@vitest/expect@1.6.1': - resolution: {integrity: sha512-jXL+9+ZNIJKruofqXuuTClf44eSpcHlgj3CiuNihUF3Ioujtmc0zIa3UJOW5RjDK1YLBJZnWBlPuqhYycLioog==} + '@vitest/expect@4.0.16': + resolution: {integrity: sha512-eshqULT2It7McaJkQGLkPjPjNph+uevROGuIMJdG3V+0BSR2w9u6J9Lwu+E8cK5TETlfou8GRijhafIMhXsimA==} - '@vitest/runner@1.6.1': - resolution: {integrity: sha512-3nSnYXkVkf3mXFfE7vVyPmi3Sazhb/2cfZGGs0JRzFsPFvAMBEcrweV1V1GsrstdXeKCTXlJbvnQwGWgEIHmOA==} + '@vitest/mocker@4.0.16': + resolution: {integrity: sha512-yb6k4AZxJTB+q9ycAvsoxGn+j/po0UaPgajllBgt1PzoMAAmJGYFdDk0uCcRcxb3BrME34I6u8gHZTQlkqSZpg==} 
+ peerDependencies: + msw: ^2.4.9 + vite: ^6.0.0 || ^7.0.0-0 + peerDependenciesMeta: + msw: + optional: true + vite: + optional: true - '@vitest/snapshot@1.6.1': - resolution: {integrity: sha512-WvidQuWAzU2p95u8GAKlRMqMyN1yOJkGHnx3M1PL9Raf7AQ1kwLKg04ADlCa3+OXUZE7BceOhVZiuWAbzCKcUQ==} + '@vitest/pretty-format@4.0.16': + resolution: {integrity: sha512-eNCYNsSty9xJKi/UdVD8Ou16alu7AYiS2fCPRs0b1OdhJiV89buAXQLpTbe+X8V9L6qrs9CqyvU7OaAopJYPsA==} - '@vitest/spy@1.6.1': - resolution: {integrity: sha512-MGcMmpGkZebsMZhbQKkAf9CX5zGvjkBTqf8Zx3ApYWXr3wG+QvEu2eXWfnIIWYSJExIp4V9FCKDEeygzkYrXMw==} + '@vitest/runner@4.0.16': + resolution: {integrity: sha512-VWEDm5Wv9xEo80ctjORcTQRJ539EGPB3Pb9ApvVRAY1U/WkHXmmYISqU5E79uCwcW7xYUV38gwZD+RV755fu3Q==} - '@vitest/utils@1.6.1': - resolution: {integrity: sha512-jOrrUvXM4Av9ZWiG1EajNto0u96kWAhJ1LmPmJhXXQx/32MecEKd10pOLYgS2BQx1TgkGhloPU1ArDW2vvaY6g==} + '@vitest/snapshot@4.0.16': + resolution: {integrity: sha512-sf6NcrYhYBsSYefxnry+DR8n3UV4xWZwWxYbCJUt2YdvtqzSPR7VfGrY0zsv090DAbjFZsi7ZaMi1KnSRyK1XA==} + + '@vitest/spy@4.0.16': + resolution: {integrity: sha512-4jIOWjKP0ZUaEmJm00E0cOBLU+5WE0BpeNr3XN6TEF05ltro6NJqHWxXD0kA8/Zc8Nh23AT8WQxwNG+WeROupw==} + + '@vitest/utils@4.0.16': + resolution: {integrity: sha512-h8z9yYhV3e1LEfaQ3zdypIrnAg/9hguReGZoS7Gl0aBG5xgA410zBqECqmaF/+RkTggRsfnzc1XaAHA6bmUufA==} acorn-jsx@5.3.2: resolution: {integrity: sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==} peerDependencies: acorn: ^6.0.0 || ^7.0.0 || ^8.0.0 - acorn-walk@8.3.4: - resolution: {integrity: sha512-ueEepnujpqee2o5aIYnvHU6C0A42MNdsIDeqy5BydrkuC5R1ZuUFnm27EeFJGoEHJQgn3uleRvmTXaJgfXbt4g==} - engines: {node: '>=0.4.0'} - acorn@8.15.0: resolution: {integrity: sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==} engines: {node: '>=0.4.0'} @@ -657,18 +528,18 @@ packages: resolution: {integrity: 
sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==} engines: {node: '>=8'} - ansi-styles@5.2.0: - resolution: {integrity: sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==} - engines: {node: '>=10'} - any-promise@1.3.0: resolution: {integrity: sha512-7UvmKalWRt1wgjL1RrGxoSJW/0QZFIegpeGvZG9kjp8vrRu55XTHbwnqq2GpXm9uLbcuhxm3IqX9OB4MZR1b2A==} argparse@2.0.1: resolution: {integrity: sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==} - assertion-error@1.1.0: - resolution: {integrity: sha512-jgsaNduz+ndvGyFt3uSuWqvy4lCnIJiovtouQN5JZHOKCS2QuhEdbcQHFhVksz2N2U9hXJo8odG7ETyWlEeuDw==} + assertion-error@2.0.1: + resolution: {integrity: sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==} + engines: {node: '>=12'} + + ast-v8-to-istanbul@0.3.10: + resolution: {integrity: sha512-p4K7vMz2ZSk3wN8l5o3y2bJAoZXT3VuJI5OLTATY/01CYWumWvwkUw0SqDBnNq6IiTO3qDa1eSQDibAV8g7XOQ==} asynckit@0.4.0: resolution: {integrity: sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==} @@ -703,17 +574,14 @@ packages: resolution: {integrity: sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==} engines: {node: '>=6'} - chai@4.5.0: - resolution: {integrity: sha512-RITGBfijLkBddZvnn8jdqoTypxvqbOLYQkGGxXzeFjVHvudaPw0HNFD9x928/eUwYWd2dPCugVqspGALTZZQKw==} - engines: {node: '>=4'} + chai@6.2.2: + resolution: {integrity: sha512-NUPRluOfOiTKBKvWPtSD4PhFvWCqOi0BGStNWs57X9js7XGTprSmFoz5F0tWhR4WPjNeR9jXqdC7/UpSJTnlRg==} + engines: {node: '>=18'} chalk@4.1.2: resolution: {integrity: sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==} engines: {node: '>=10'} - check-error@1.0.3: - resolution: {integrity: sha512-iKEoDYaRmd1mxM90a2OEfWhjsjPpYPuQ+lMYsoxB126+t8fw7ySEO48nmDg5COTjxDI65/Y2OWpeEHk3ZOe8zg==} - 
chokidar@4.0.3: resolution: {integrity: sha512-Qgzu8kfBvo+cA4962jnP1KkS6Dop5NS6g7R5LFYJr4b8Ub94PPQXUksCw9PvXoeXPRRddRNC5C1JQUR2SMGtnA==} engines: {node: '>= 14.16.0'} @@ -756,10 +624,6 @@ packages: supports-color: optional: true - deep-eql@4.1.4: - resolution: {integrity: sha512-SUwdGfqdKOwxCPeVYjwSyRpJ7Z+fhpwIAtmCUdZIWZ/YP5R9WAsyuSgpLVDi9bjWoN2LXHNss/dk3urXtdQxGg==} - engines: {node: '>=6'} - deep-is@0.1.4: resolution: {integrity: sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==} @@ -767,10 +631,6 @@ packages: resolution: {integrity: sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==} engines: {node: '>=0.4.0'} - diff-sequences@29.6.3: - resolution: {integrity: sha512-EjePK1srD3P08o2j4f0ExnylqRs5B9tJjcp9t1krH2qRi8CCdsYfwe9JgSLurFBWwq4uOlipzfk5fHNvwFKr8Q==} - engines: {node: ^14.15.0 || ^16.10.0 || >=18.0.0} - dunder-proto@1.0.1: resolution: {integrity: sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==} engines: {node: '>= 0.4'} @@ -783,6 +643,9 @@ packages: resolution: {integrity: sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==} engines: {node: '>= 0.4'} + es-module-lexer@1.7.0: + resolution: {integrity: sha512-jEQoCwk8hyb2AZziIOLhDqpm5+2ww5uIE6lkO/6jcOCusfk6LhMHpXXfBLXTZ7Ydyt0j4VoUQv6uGNYbdW+kBA==} + es-object-atoms@1.1.1: resolution: {integrity: sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==} engines: {node: '>= 0.4'} @@ -791,11 +654,6 @@ packages: resolution: {integrity: sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==} engines: {node: '>= 0.4'} - esbuild@0.21.5: - resolution: {integrity: sha512-mg3OPMV4hXywwpoDxu3Qda5xCKQi+vCTZq8S9J/EpkhB2HzKXq4SNFZE3+NK93JYxc8VMSep+lOUSC/RVKaBqw==} - engines: {node: '>=12'} - hasBin: true - esbuild@0.27.2: resolution: {integrity: 
sha512-HyNQImnsOC7X9PMNaCIeAm4ISCQXs5a5YasTXVliKv4uuBo1dKrG0A+uQS8M5eXjVMnLg3WgXaKvprHlFJQffw==} engines: {node: '>=18'} @@ -850,9 +708,9 @@ packages: resolution: {integrity: sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==} engines: {node: '>=0.10.0'} - execa@8.0.1: - resolution: {integrity: sha512-VyhnebXciFV2DESc+p6B+y0LjSm0krU4OgJN44qFAhBY0TJ+1V61tYD2+wHusZ6F9n5K+vl8k0sTy7PEfV4qpg==} - engines: {node: '>=16.17'} + expect-type@1.3.0: + resolution: {integrity: sha512-knvyeauYhqjOYvQ66MznSMs83wmHrCycNEN6Ao+2AeYEfxUIkuiVxdEa1qlGEPK+We3n0THiDciYSsCcgW/DoA==} + engines: {node: '>=12.0.0'} fast-deep-equal@3.1.3: resolution: {integrity: sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==} @@ -903,9 +761,6 @@ packages: resolution: {integrity: sha512-8RipRLol37bNs2bhoV67fiTEvdTrbMUYcFTiy3+wuuOnUog2QBHCZWXDRijWQfAkhBj2Uf5UnVaiWwA5vdd82w==} engines: {node: '>= 6'} - fs.realpath@1.0.0: - resolution: {integrity: sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==} - fsevents@2.3.3: resolution: {integrity: sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==} engines: {node: ^8.16.0 || ^10.6.0 || >=11.0.0} @@ -914,9 +769,6 @@ packages: function-bind@1.1.2: resolution: {integrity: sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==} - get-func-name@2.0.2: - resolution: {integrity: sha512-8vXOvuE167CtIc3OyItco7N/dpRtBbYOsPsXCz7X/PMnlGjYjSGuZJgM1Y7mmew7BKf9BqvLX2tnOVy1BBUsxQ==} - get-intrinsic@1.3.0: resolution: {integrity: sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==} engines: {node: '>= 0.4'} @@ -925,18 +777,10 @@ packages: resolution: {integrity: sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==} engines: {node: '>= 0.4'} - get-stream@8.0.1: - resolution: 
{integrity: sha512-VaUJspBffn/LMCJVoMvSAdmscJyS1auj5Zulnn5UoYcY531UWmdwhRWkcGKnGU93m5HSXP9LP2usOryrBtQowA==} - engines: {node: '>=16'} - glob-parent@6.0.2: resolution: {integrity: sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==} engines: {node: '>=10.13.0'} - glob@7.2.3: - resolution: {integrity: sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==} - deprecated: Glob versions prior to v9 are no longer supported - globals@14.0.0: resolution: {integrity: sha512-oahGvuMGQlPw/ivIYBjVSrWAfWLBeku5tpPE2fOPLi+WHffIWbuh2tCjhyQhTBPMf5E9jDEH4FOmTYgYwbKwtQ==} engines: {node: '>=18'} @@ -964,10 +808,6 @@ packages: html-escaper@2.0.2: resolution: {integrity: sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg==} - human-signals@5.0.0: - resolution: {integrity: sha512-AXcZb6vzzrFAUE61HnN4mpLqd/cSIwNQjtNWR0euPm6y0iqx3G4gOXaIDdtdDwZmhwe82LA6+zinmW4UBWVePQ==} - engines: {node: '>=16.17.0'} - ignore@5.3.2: resolution: {integrity: sha512-hsBTNUqQTDwkWtcdYI2i06Y/nUBEsNEDJKjWdigLvegy8kDuJAS8uRlpkkcQpyEXL0Z/pjDy5HBmMjRCJ2gq+g==} engines: {node: '>= 4'} @@ -984,13 +824,6 @@ packages: resolution: {integrity: sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==} engines: {node: '>=0.8.19'} - inflight@1.0.6: - resolution: {integrity: sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==} - deprecated: This module is not supported, and leaks memory. Do not use it. Check out lru-cache if you want a good and tested way to coalesce async requests by a key value, which is much more comprehensive and powerful. 
- - inherits@2.0.4: - resolution: {integrity: sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==} - is-extglob@2.1.1: resolution: {integrity: sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==} engines: {node: '>=0.10.0'} @@ -999,10 +832,6 @@ packages: resolution: {integrity: sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==} engines: {node: '>=0.10.0'} - is-stream@3.0.0: - resolution: {integrity: sha512-LnQR4bZ9IADDRSkvpqMGvt/tEJWclzklNgSw48V5EAaAeDd6qGvN8ei6k5p0tvxSR171VmGyHuTiAOfxAbr8kA==} - engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} - isexe@2.0.0: resolution: {integrity: sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==} @@ -1060,10 +889,6 @@ packages: resolution: {integrity: sha512-IXO6OCs9yg8tMKzfPZ1YmheJbZCiEsnBdcB03l0OcfK9prKnJb96siuHCr5Fl37/yo9DnKU+TLpxzTUspw9shg==} engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} - local-pkg@0.5.1: - resolution: {integrity: sha512-9rrA30MRRP3gBD3HTGnC6cDFpaE1kVDWxWgqWJUN0RvDNAo+Nz/9GxB+nHOH0ifbVFy0hSA1V6vFDvnx54lTEQ==} - engines: {node: '>=14'} - locate-path@6.0.0: resolution: {integrity: sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==} engines: {node: '>=10'} @@ -1071,14 +896,11 @@ packages: lodash.merge@4.6.2: resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==} - loupe@2.3.7: - resolution: {integrity: sha512-zSMINGVYkdpYSOBmLi0D1Uo7JU9nVdQKrHxC8eYlV+9YKK9WePqAlL7lSlorG/U2Fw1w0hTBmaa/jrQ3UbPHtA==} - magic-string@0.30.21: resolution: {integrity: sha512-vd2F4YUyEXKGcLHoq+TEyCjxueSeHnFxyyjNp80yg0XV4vUhnDer/lvvlqM/arB5bXQN5K2/3oinyCRyx8T2CQ==} - magicast@0.3.5: - resolution: {integrity: sha512-L0WhttDl+2BOsybvEOLK7fW3UA0OQ0IQ2d6Zl2x/a6vVRs3bAY0ECOSHHeL5jD+SbOpOCUEi0y1DgHEn9Qn1AQ==} + magicast@0.5.1: + resolution: 
{integrity: sha512-xrHS24IxaLrvuo613F719wvOIv9xPHFWQHuvGUBmPnCA/3MQxKI3b+r7n1jAoDHmsbC5bRhTZYR77invLAxVnw==} make-dir@4.0.0: resolution: {integrity: sha512-hXdUTZYIVOt1Ex//jAQi+wTZZpUpwBj/0QsOzqegb3rGMMeJiSEu5xLHnYfBrRV4RH2+OCSOO95Is/7x1WJ4bw==} @@ -1088,9 +910,6 @@ packages: resolution: {integrity: sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==} engines: {node: '>= 0.4'} - merge-stream@2.0.0: - resolution: {integrity: sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w==} - mime-db@1.52.0: resolution: {integrity: sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==} engines: {node: '>= 0.6'} @@ -1099,10 +918,6 @@ packages: resolution: {integrity: sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==} engines: {node: '>= 0.6'} - mimic-fn@4.0.0: - resolution: {integrity: sha512-vqiC06CuhBTUdZH+RYl8sFrL096vA45Ok5ISO6sE/Mr1jRbGH4Csnhi8f3wKVl7x8mO4Au7Ir9D3Oyv1VYMFJw==} - engines: {node: '>=12'} - minimatch@3.1.2: resolution: {integrity: sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==} @@ -1127,20 +942,12 @@ packages: natural-compare@1.4.0: resolution: {integrity: sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==} - npm-run-path@5.3.0: - resolution: {integrity: sha512-ppwTtiJZq0O/ai0z7yfudtBpWIoxM8yE6nHi1X47eFR2EWORqfbu6CnPlNsjeN683eT0qG6H/Pyf9fCcvjnnnQ==} - engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} - object-assign@4.1.1: resolution: {integrity: sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==} engines: {node: '>=0.10.0'} - once@1.4.0: - resolution: {integrity: sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==} - - onetime@6.0.0: - resolution: {integrity: 
sha512-1FlR+gjXK7X+AsAHso35MnyN5KqGwJRi/31ft6x0M194ht7S+rWAvd7PHss9xSKMzE0asv1pyIHaJYq+BbacAQ==} - engines: {node: '>=12'} + obug@2.1.1: + resolution: {integrity: sha512-uTqF9MuPraAQ+IsnPf366RG4cP9RtUi7MLO1N3KEc+wb0a6yKpeL0lmk2IB1jY5KHPAlTc6T/JRdC/YqxHNwkQ==} optionator@0.9.4: resolution: {integrity: sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g==} @@ -1150,10 +957,6 @@ packages: resolution: {integrity: sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==} engines: {node: '>=10'} - p-limit@5.0.0: - resolution: {integrity: sha512-/Eaoq+QyLSiXQ4lyYV23f14mZRQcXnxfHrN0vCai+ak9G0pp9iEQukIIZq5NccEvwRB8PUnZT0KsOoDCINS1qQ==} - engines: {node: '>=18'} - p-locate@5.0.0: resolution: {integrity: sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw==} engines: {node: '>=10'} @@ -1166,27 +969,13 @@ packages: resolution: {integrity: sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==} engines: {node: '>=8'} - path-is-absolute@1.0.1: - resolution: {integrity: sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==} - engines: {node: '>=0.10.0'} - path-key@3.1.1: resolution: {integrity: sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==} engines: {node: '>=8'} - path-key@4.0.0: - resolution: {integrity: sha512-haREypq7xkM7ErfgIyA0z+Bj4AGKlMSdlQE2jvJo6huWD1EdkKYV+G/T4nq0YEF2vgTT8kqMFKo1uHn950r4SQ==} - engines: {node: '>=12'} - - pathe@1.1.2: - resolution: {integrity: sha512-whLdWMYL2TwI08hn8/ZqAbrVemu0LNaNNJZX73O6qaIdCTfXutsLhMkjdENX0qhsQ9uIimo4/aQOmXkoon2nDQ==} - pathe@2.0.3: resolution: {integrity: sha512-WUjGcAqP1gQacoQe+OBJsFA7Ld4DyXuUIjZ5cc75cLHvJ7dtNsTugphxIADwspS+AraAUePCKrSVtPLFj/F88w==} - pathval@1.1.1: - resolution: {integrity: 
sha512-Dp6zGqpTdETdR63lehJYPeIOqpiNBNtc7BpWSLrOje7UaIsE5aY92r/AunQA7rsXvet3lrJ3JnZX29UPTKXyKQ==} - picocolors@1.1.1: resolution: {integrity: sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==} @@ -1227,10 +1016,6 @@ packages: resolution: {integrity: sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==} engines: {node: '>= 0.8.0'} - pretty-format@29.7.0: - resolution: {integrity: sha512-Pdlw/oPxN+aXdmM9R00JVC9WVFoCLTKJvDVLgmJ+qAffBMxsV85l/Lu7sNx4zSzPyoL2euImuEwHhOXdEgNFZQ==} - engines: {node: ^14.15.0 || ^16.10.0 || >=18.0.0} - proxy-from-env@1.1.0: resolution: {integrity: sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==} @@ -1238,9 +1023,6 @@ packages: resolution: {integrity: sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==} engines: {node: '>=6'} - react-is@18.3.1: - resolution: {integrity: sha512-/LLMVyas0ljjAtoYiPqYiL8VWXzUUdThrmU5+n20DZv+a+ClRoevUzw5JxU+Ieh5/c87ytoTBV9G1FiKfNJdmg==} - readdirp@4.1.2: resolution: {integrity: sha512-GDhwkLfywWL2s6vEjyhri+eXmfH6j1L7JE27WhqLeYzoh/A3DBaYGEj2H/HFZCn/kMfim73FXxEJTw06WtxQwg==} engines: {node: '>= 14.18.0'} @@ -1274,10 +1056,6 @@ packages: siginfo@2.0.0: resolution: {integrity: sha512-ybx0WO1/8bSBLEWXZvEd7gMW3Sn3JFlW3TvX1nREbDLRNQNaeNN8WK0meBwPdAaOI7TtRRRJn/Es1zhrrCHu7g==} - signal-exit@4.1.0: - resolution: {integrity: sha512-bzyZ1e88w9O1iNJbKnOlvYTrWPDl46O1bG0D3XInv+9tkPrxrN8jUUTiFlDkkmKWgn1M6CfIA13SuGqOa9Korw==} - engines: {node: '>=14'} - source-map-js@1.2.1: resolution: {integrity: sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==} engines: {node: '>=0.10.0'} @@ -1292,17 +1070,10 @@ packages: std-env@3.10.0: resolution: {integrity: sha512-5GS12FdOZNliM5mAOxFRg7Ir0pWz8MdpYm6AY6VPkGpbA7ZzmbzNcBJQ0GPvvyWgcY7QAhCgf9Uy89I03faLkg==} - strip-final-newline@3.0.0: - resolution: {integrity: 
sha512-dOESqjYr96iWYylGObzd39EuNTa5VJxyvVAEm5Jnh7KGo75V43Hk1odPQkNDyXNmUR6k+gEiDVXnjB8HJ3crXw==} - engines: {node: '>=12'} - strip-json-comments@3.1.1: resolution: {integrity: sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==} engines: {node: '>=8'} - strip-literal@2.1.1: - resolution: {integrity: sha512-631UJ6O00eNGfMiWG78ck80dfBab8X6IVFB51jZK5Icd7XAs60Z5y7QdSd/wGIklnWvRbUNloVzhOKKmutxQ6Q==} - sucrase@3.35.1: resolution: {integrity: sha512-DhuTmvZWux4H1UOnWMB3sk0sbaCVOoQZjv8u1rDoTV0HTdGem9hkAZtl4JZy8P2z4Bg0nT+YMeOFyVr4zcG5Tw==} engines: {node: '>=16 || 14 >=14.17'} @@ -1312,10 +1083,6 @@ packages: resolution: {integrity: sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==} engines: {node: '>=8'} - test-exclude@6.0.0: - resolution: {integrity: sha512-cAGWPIyOHU6zlmg88jwm7VRyXnMN7iV68OGAbYDk/Mh/xC/pzVPlQtY6ngoIH/5/tciuhGfvESU8GrHrcxD56w==} - engines: {node: '>=8'} - thenify-all@1.6.0: resolution: {integrity: sha512-RNxQH/qI8/t3thXJDwcstUO4zeqo64+Uy/+sNVRBx4Xn2OX+OZ9oP+iJnNFqplFra2ZUVeKCSa2oVWi3T4uVmA==} engines: {node: '>=0.8'} @@ -1329,16 +1096,16 @@ packages: tinyexec@0.3.2: resolution: {integrity: sha512-KQQR9yN7R5+OSwaK0XQoj22pwHoTlgYqmUscPYoknOoWCWfj/5/ABTMRi69FrKU5ffPVh5QcFikpWJI/P1ocHA==} + tinyexec@1.0.2: + resolution: {integrity: sha512-W/KYk+NFhkmsYpuHq5JykngiOCnxeVL8v8dFnqxSD8qEEdRfXk1SDM6JzNqcERbcGYj9tMrDQBYV9cjgnunFIg==} + engines: {node: '>=18'} + tinyglobby@0.2.15: resolution: {integrity: sha512-j2Zq4NyQYG5XMST4cbs02Ak8iJUdxRM0XI5QyxXuZOzKOINmWurp3smXu3y5wDcJrptwpSjgXHzIQxR0omXljQ==} engines: {node: '>=12.0.0'} - tinypool@0.8.4: - resolution: {integrity: sha512-i11VH5gS6IFeLY3gMBQ00/MmLncVP7JLXOw1vlgkytLmJK7QnEr7NXf0LBdxfmNPAeyetukOk0bOYrJrFGjYJQ==} - engines: {node: '>=14.0.0'} - - tinyspy@2.2.1: - resolution: {integrity: sha512-KYad6Vy5VDWV4GH3fjpseMQ/XU2BhIYP7Vzd0LG44qRWm/Yt2WCOTicFdvmgo6gWaqooMQCawTtILVQJupKu7A==} + tinyrainbow@3.0.3: + resolution: 
{integrity: sha512-PSkbLUoxOFRzJYjjxHJt9xro7D+iilgMX/C9lawzVuYiIdcihh9DXmVibBe8lmcFrRi/VzlPjBxbN7rH24q8/Q==} engines: {node: '>=14.0.0'} tree-kill@1.2.2: @@ -1377,10 +1144,6 @@ packages: resolution: {integrity: sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==} engines: {node: '>= 0.8.0'} - type-detect@4.1.0: - resolution: {integrity: sha512-Acylog8/luQ8L7il+geoSxhEkazvkslg7PSNKOX59mbB9cOveP5aq9h74Y7YU8yDpJwetzQQrfIwtf4Wp4LKcw==} - engines: {node: '>=4'} - typescript@5.9.3: resolution: {integrity: sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==} engines: {node: '>=14.17'} @@ -1389,33 +1152,33 @@ packages: ufo@1.6.1: resolution: {integrity: sha512-9a4/uxlTWJ4+a5i0ooc1rU7C7YOw3wT+UGqdeNNHWnOF9qcMBgLRS+4IYUqbczewFx4mLEig6gawh7X6mFlEkA==} - undici-types@6.21.0: - resolution: {integrity: sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==} + undici-types@7.16.0: + resolution: {integrity: sha512-Zz+aZWSj8LE6zoxD+xrjh4VfkIG8Ya6LvYkZqtUQGJPZjYl53ypCaUwWqo7eI0x66KBGeRo+mlBEkMSeSZ38Nw==} uri-js@4.4.1: resolution: {integrity: sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==} - vite-node@1.6.1: - resolution: {integrity: sha512-YAXkfvGtuTzwWbDSACdJSg4A4DZiAqckWe90Zapc/sEX3XvHcw1NdurM/6od8J207tSDqNbSsgdCacBgvJKFuA==} - engines: {node: ^18.0.0 || >=20.0.0} - hasBin: true - - vite@5.4.21: - resolution: {integrity: sha512-o5a9xKjbtuhY6Bi5S3+HvbRERmouabWbyUcpXXUA1u+GNUKoROi9byOJ8M0nHbHYHkYICiMlqxkg1KkYmm25Sw==} - engines: {node: ^18.0.0 || >=20.0.0} + vite@7.3.0: + resolution: {integrity: sha512-dZwN5L1VlUBewiP6H9s2+B3e3Jg96D0vzN+Ry73sOefebhYr9f94wwkMNN/9ouoU8pV1BqA1d1zGk8928cx0rg==} + engines: {node: ^20.19.0 || >=22.12.0} hasBin: true peerDependencies: - '@types/node': ^18.0.0 || >=20.0.0 - less: '*' + '@types/node': ^20.19.0 || >=22.12.0 + jiti: '>=1.21.0' + less: ^4.0.0 lightningcss: ^1.21.0 - 
sass: '*' - sass-embedded: '*' - stylus: '*' - sugarss: '*' - terser: ^5.4.0 + sass: ^1.70.0 + sass-embedded: ^1.70.0 + stylus: '>=0.54.8' + sugarss: ^5.0.0 + terser: ^5.16.0 + tsx: ^4.8.1 + yaml: ^2.4.2 peerDependenciesMeta: '@types/node': optional: true + jiti: + optional: true less: optional: true lightningcss: @@ -1430,24 +1193,37 @@ packages: optional: true terser: optional: true + tsx: + optional: true + yaml: + optional: true - vitest@1.6.1: - resolution: {integrity: sha512-Ljb1cnSJSivGN0LqXd/zmDbWEM0RNNg2t1QW/XUhYl/qPqyu7CsqeWtqQXHVaJsecLPuDoak2oJcZN2QoRIOag==} - engines: {node: ^18.0.0 || >=20.0.0} + vitest@4.0.16: + resolution: {integrity: sha512-E4t7DJ9pESL6E3I8nFjPa4xGUd3PmiWDLsDztS2qXSJWfHtbQnwAWylaBvSNY48I3vr8PTqIZlyK8TE3V3CA4Q==} + engines: {node: ^20.0.0 || ^22.0.0 || >=24.0.0} hasBin: true peerDependencies: '@edge-runtime/vm': '*' - '@types/node': ^18.0.0 || >=20.0.0 - '@vitest/browser': 1.6.1 - '@vitest/ui': 1.6.1 + '@opentelemetry/api': ^1.9.0 + '@types/node': ^20.0.0 || ^22.0.0 || >=24.0.0 + '@vitest/browser-playwright': 4.0.16 + '@vitest/browser-preview': 4.0.16 + '@vitest/browser-webdriverio': 4.0.16 + '@vitest/ui': 4.0.16 happy-dom: '*' jsdom: '*' peerDependenciesMeta: '@edge-runtime/vm': optional: true + '@opentelemetry/api': + optional: true '@types/node': optional: true - '@vitest/browser': + '@vitest/browser-playwright': + optional: true + '@vitest/browser-preview': + optional: true + '@vitest/browser-webdriverio': optional: true '@vitest/ui': optional: true @@ -1470,24 +1246,12 @@ packages: resolution: {integrity: sha512-BN22B5eaMMI9UMtjrGd5g5eCYPpCPDUy0FJXbYsaT5zYxjFOckS53SQDE3pWkVoWpHXVb3BrYcEN4Twa55B5cA==} engines: {node: '>=0.10.0'} - wrappy@1.0.2: - resolution: {integrity: sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==} - yocto-queue@0.1.0: resolution: {integrity: sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==} engines: {node: '>=10'} - 
yocto-queue@1.2.2: - resolution: {integrity: sha512-4LCcse/U2MHZ63HAJVE+v71o7yOdIe4cZ70Wpf8D/IyjDKYQLV5GD46B+hSTjJsvV5PztjvHoU580EftxjDZFQ==} - engines: {node: '>=12.20'} - snapshots: - '@ampproject/remapping@2.3.0': - dependencies: - '@jridgewell/gen-mapping': 0.3.13 - '@jridgewell/trace-mapping': 0.3.31 - '@babel/helper-string-parser@7.27.1': {} '@babel/helper-validator-identifier@7.28.5': {} @@ -1501,152 +1265,83 @@ snapshots: '@babel/helper-string-parser': 7.27.1 '@babel/helper-validator-identifier': 7.28.5 - '@bcoe/v8-coverage@0.2.3': {} - - '@esbuild/aix-ppc64@0.21.5': - optional: true + '@bcoe/v8-coverage@1.0.2': {} '@esbuild/aix-ppc64@0.27.2': optional: true - '@esbuild/android-arm64@0.21.5': - optional: true - '@esbuild/android-arm64@0.27.2': optional: true - '@esbuild/android-arm@0.21.5': - optional: true - '@esbuild/android-arm@0.27.2': optional: true - '@esbuild/android-x64@0.21.5': - optional: true - '@esbuild/android-x64@0.27.2': optional: true - '@esbuild/darwin-arm64@0.21.5': - optional: true - '@esbuild/darwin-arm64@0.27.2': optional: true - '@esbuild/darwin-x64@0.21.5': - optional: true - '@esbuild/darwin-x64@0.27.2': optional: true - '@esbuild/freebsd-arm64@0.21.5': - optional: true - '@esbuild/freebsd-arm64@0.27.2': optional: true - '@esbuild/freebsd-x64@0.21.5': - optional: true - '@esbuild/freebsd-x64@0.27.2': optional: true - '@esbuild/linux-arm64@0.21.5': - optional: true - '@esbuild/linux-arm64@0.27.2': optional: true - '@esbuild/linux-arm@0.21.5': - optional: true - '@esbuild/linux-arm@0.27.2': optional: true - '@esbuild/linux-ia32@0.21.5': - optional: true - '@esbuild/linux-ia32@0.27.2': optional: true - '@esbuild/linux-loong64@0.21.5': - optional: true - '@esbuild/linux-loong64@0.27.2': optional: true - '@esbuild/linux-mips64el@0.21.5': - optional: true - '@esbuild/linux-mips64el@0.27.2': optional: true - '@esbuild/linux-ppc64@0.21.5': - optional: true - '@esbuild/linux-ppc64@0.27.2': optional: true - '@esbuild/linux-riscv64@0.21.5': - 
optional: true - '@esbuild/linux-riscv64@0.27.2': optional: true - '@esbuild/linux-s390x@0.21.5': - optional: true - '@esbuild/linux-s390x@0.27.2': optional: true - '@esbuild/linux-x64@0.21.5': - optional: true - '@esbuild/linux-x64@0.27.2': optional: true '@esbuild/netbsd-arm64@0.27.2': optional: true - '@esbuild/netbsd-x64@0.21.5': - optional: true - '@esbuild/netbsd-x64@0.27.2': optional: true '@esbuild/openbsd-arm64@0.27.2': optional: true - '@esbuild/openbsd-x64@0.21.5': - optional: true - '@esbuild/openbsd-x64@0.27.2': optional: true '@esbuild/openharmony-arm64@0.27.2': optional: true - '@esbuild/sunos-x64@0.21.5': - optional: true - '@esbuild/sunos-x64@0.27.2': optional: true - '@esbuild/win32-arm64@0.21.5': - optional: true - '@esbuild/win32-arm64@0.27.2': optional: true - '@esbuild/win32-ia32@0.21.5': - optional: true - '@esbuild/win32-ia32@0.27.2': optional: true - '@esbuild/win32-x64@0.21.5': - optional: true - '@esbuild/win32-x64@0.27.2': optional: true @@ -1707,12 +1402,6 @@ snapshots: '@humanwhocodes/retry@0.4.3': {} - '@istanbuljs/schema@0.1.3': {} - - '@jest/schemas@29.6.3': - dependencies: - '@sinclair/typebox': 0.27.8 - '@jridgewell/gen-mapping@0.3.13': dependencies: '@jridgewell/sourcemap-codec': 1.5.5 @@ -1793,15 +1482,22 @@ snapshots: '@rollup/rollup-win32-x64-msvc@4.54.0': optional: true - '@sinclair/typebox@0.27.8': {} + '@standard-schema/spec@1.1.0': {} + + '@types/chai@5.2.3': + dependencies: + '@types/deep-eql': 4.0.2 + assertion-error: 2.0.1 + + '@types/deep-eql@4.0.2': {} '@types/estree@1.0.8': {} '@types/json-schema@7.0.15': {} - '@types/node@20.19.27': + '@types/node@25.0.3': dependencies: - undici-types: 6.21.0 + undici-types: 7.16.0 '@typescript-eslint/eslint-plugin@8.50.1(@typescript-eslint/parser@8.50.1(eslint@9.39.2)(typescript@5.9.3))(eslint@9.39.2)(typescript@5.9.3)': dependencies: @@ -1894,62 +1590,66 @@ snapshots: '@typescript-eslint/types': 8.50.1 eslint-visitor-keys: 4.2.1 - 
'@vitest/coverage-v8@1.6.1(vitest@1.6.1(@types/node@20.19.27))': + '@vitest/coverage-v8@4.0.16(vitest@4.0.16(@types/node@25.0.3))': dependencies: - '@ampproject/remapping': 2.3.0 - '@bcoe/v8-coverage': 0.2.3 - debug: 4.4.3 + '@bcoe/v8-coverage': 1.0.2 + '@vitest/utils': 4.0.16 + ast-v8-to-istanbul: 0.3.10 istanbul-lib-coverage: 3.2.2 istanbul-lib-report: 3.0.1 istanbul-lib-source-maps: 5.0.6 istanbul-reports: 3.2.0 - magic-string: 0.30.21 - magicast: 0.3.5 - picocolors: 1.1.1 + magicast: 0.5.1 + obug: 2.1.1 std-env: 3.10.0 - strip-literal: 2.1.1 - test-exclude: 6.0.0 - vitest: 1.6.1(@types/node@20.19.27) + tinyrainbow: 3.0.3 + vitest: 4.0.16(@types/node@25.0.3) transitivePeerDependencies: - supports-color - '@vitest/expect@1.6.1': + '@vitest/expect@4.0.16': dependencies: - '@vitest/spy': 1.6.1 - '@vitest/utils': 1.6.1 - chai: 4.5.0 + '@standard-schema/spec': 1.1.0 + '@types/chai': 5.2.3 + '@vitest/spy': 4.0.16 + '@vitest/utils': 4.0.16 + chai: 6.2.2 + tinyrainbow: 3.0.3 - '@vitest/runner@1.6.1': + '@vitest/mocker@4.0.16(vite@7.3.0(@types/node@25.0.3))': dependencies: - '@vitest/utils': 1.6.1 - p-limit: 5.0.0 - pathe: 1.1.2 - - '@vitest/snapshot@1.6.1': - dependencies: - magic-string: 0.30.21 - pathe: 1.1.2 - pretty-format: 29.7.0 - - '@vitest/spy@1.6.1': - dependencies: - tinyspy: 2.2.1 - - '@vitest/utils@1.6.1': - dependencies: - diff-sequences: 29.6.3 + '@vitest/spy': 4.0.16 estree-walker: 3.0.3 - loupe: 2.3.7 - pretty-format: 29.7.0 + magic-string: 0.30.21 + optionalDependencies: + vite: 7.3.0(@types/node@25.0.3) + + '@vitest/pretty-format@4.0.16': + dependencies: + tinyrainbow: 3.0.3 + + '@vitest/runner@4.0.16': + dependencies: + '@vitest/utils': 4.0.16 + pathe: 2.0.3 + + '@vitest/snapshot@4.0.16': + dependencies: + '@vitest/pretty-format': 4.0.16 + magic-string: 0.30.21 + pathe: 2.0.3 + + '@vitest/spy@4.0.16': {} + + '@vitest/utils@4.0.16': + dependencies: + '@vitest/pretty-format': 4.0.16 + tinyrainbow: 3.0.3 acorn-jsx@5.3.2(acorn@8.15.0): dependencies: 
acorn: 8.15.0 - acorn-walk@8.3.4: - dependencies: - acorn: 8.15.0 - acorn@8.15.0: {} ajv@6.12.6: @@ -1963,13 +1663,17 @@ snapshots: dependencies: color-convert: 2.0.1 - ansi-styles@5.2.0: {} - any-promise@1.3.0: {} argparse@2.0.1: {} - assertion-error@1.1.0: {} + assertion-error@2.0.1: {} + + ast-v8-to-istanbul@0.3.10: + dependencies: + '@jridgewell/trace-mapping': 0.3.31 + estree-walker: 3.0.3 + js-tokens: 9.0.1 asynckit@0.4.0: {} @@ -2006,25 +1710,13 @@ snapshots: callsites@3.1.0: {} - chai@4.5.0: - dependencies: - assertion-error: 1.1.0 - check-error: 1.0.3 - deep-eql: 4.1.4 - get-func-name: 2.0.2 - loupe: 2.3.7 - pathval: 1.1.1 - type-detect: 4.1.0 + chai@6.2.2: {} chalk@4.1.2: dependencies: ansi-styles: 4.3.0 supports-color: 7.2.0 - check-error@1.0.3: - dependencies: - get-func-name: 2.0.2 - chokidar@4.0.3: dependencies: readdirp: 4.1.2 @@ -2057,16 +1749,10 @@ snapshots: dependencies: ms: 2.1.3 - deep-eql@4.1.4: - dependencies: - type-detect: 4.1.0 - deep-is@0.1.4: {} delayed-stream@1.0.0: {} - diff-sequences@29.6.3: {} - dunder-proto@1.0.1: dependencies: call-bind-apply-helpers: 1.0.2 @@ -2077,6 +1763,8 @@ snapshots: es-errors@1.3.0: {} + es-module-lexer@1.7.0: {} + es-object-atoms@1.1.1: dependencies: es-errors: 1.3.0 @@ -2088,32 +1776,6 @@ snapshots: has-tostringtag: 1.0.2 hasown: 2.0.2 - esbuild@0.21.5: - optionalDependencies: - '@esbuild/aix-ppc64': 0.21.5 - '@esbuild/android-arm': 0.21.5 - '@esbuild/android-arm64': 0.21.5 - '@esbuild/android-x64': 0.21.5 - '@esbuild/darwin-arm64': 0.21.5 - '@esbuild/darwin-x64': 0.21.5 - '@esbuild/freebsd-arm64': 0.21.5 - '@esbuild/freebsd-x64': 0.21.5 - '@esbuild/linux-arm': 0.21.5 - '@esbuild/linux-arm64': 0.21.5 - '@esbuild/linux-ia32': 0.21.5 - '@esbuild/linux-loong64': 0.21.5 - '@esbuild/linux-mips64el': 0.21.5 - '@esbuild/linux-ppc64': 0.21.5 - '@esbuild/linux-riscv64': 0.21.5 - '@esbuild/linux-s390x': 0.21.5 - '@esbuild/linux-x64': 0.21.5 - '@esbuild/netbsd-x64': 0.21.5 - '@esbuild/openbsd-x64': 0.21.5 - 
'@esbuild/sunos-x64': 0.21.5 - '@esbuild/win32-arm64': 0.21.5 - '@esbuild/win32-ia32': 0.21.5 - '@esbuild/win32-x64': 0.21.5 - esbuild@0.27.2: optionalDependencies: '@esbuild/aix-ppc64': 0.27.2 @@ -2215,17 +1877,7 @@ snapshots: esutils@2.0.3: {} - execa@8.0.1: - dependencies: - cross-spawn: 7.0.6 - get-stream: 8.0.1 - human-signals: 5.0.0 - is-stream: 3.0.0 - merge-stream: 2.0.0 - npm-run-path: 5.3.0 - onetime: 6.0.0 - signal-exit: 4.1.0 - strip-final-newline: 3.0.0 + expect-type@1.3.0: {} fast-deep-equal@3.1.3: {} @@ -2269,15 +1921,11 @@ snapshots: hasown: 2.0.2 mime-types: 2.1.35 - fs.realpath@1.0.0: {} - fsevents@2.3.3: optional: true function-bind@1.1.2: {} - get-func-name@2.0.2: {} - get-intrinsic@1.3.0: dependencies: call-bind-apply-helpers: 1.0.2 @@ -2296,21 +1944,10 @@ snapshots: dunder-proto: 1.0.1 es-object-atoms: 1.1.1 - get-stream@8.0.1: {} - glob-parent@6.0.2: dependencies: is-glob: 4.0.3 - glob@7.2.3: - dependencies: - fs.realpath: 1.0.0 - inflight: 1.0.6 - inherits: 2.0.4 - minimatch: 3.1.2 - once: 1.4.0 - path-is-absolute: 1.0.1 - globals@14.0.0: {} gopd@1.2.0: {} @@ -2329,8 +1966,6 @@ snapshots: html-escaper@2.0.2: {} - human-signals@5.0.0: {} - ignore@5.3.2: {} ignore@7.0.5: {} @@ -2342,21 +1977,12 @@ snapshots: imurmurhash@0.1.4: {} - inflight@1.0.6: - dependencies: - once: 1.4.0 - wrappy: 1.0.2 - - inherits@2.0.4: {} - is-extglob@2.1.1: {} is-glob@4.0.3: dependencies: is-extglob: 2.1.1 - is-stream@3.0.0: {} - isexe@2.0.0: {} istanbul-lib-coverage@3.2.2: {} @@ -2409,26 +2035,17 @@ snapshots: load-tsconfig@0.2.5: {} - local-pkg@0.5.1: - dependencies: - mlly: 1.8.0 - pkg-types: 1.3.1 - locate-path@6.0.0: dependencies: p-locate: 5.0.0 lodash.merge@4.6.2: {} - loupe@2.3.7: - dependencies: - get-func-name: 2.0.2 - magic-string@0.30.21: dependencies: '@jridgewell/sourcemap-codec': 1.5.5 - magicast@0.3.5: + magicast@0.5.1: dependencies: '@babel/parser': 7.28.5 '@babel/types': 7.28.5 @@ -2440,16 +2057,12 @@ snapshots: math-intrinsics@1.1.0: {} - 
merge-stream@2.0.0: {} - mime-db@1.52.0: {} mime-types@2.1.35: dependencies: mime-db: 1.52.0 - mimic-fn@4.0.0: {} - minimatch@3.1.2: dependencies: brace-expansion: 1.1.12 @@ -2477,19 +2090,9 @@ snapshots: natural-compare@1.4.0: {} - npm-run-path@5.3.0: - dependencies: - path-key: 4.0.0 - object-assign@4.1.1: {} - once@1.4.0: - dependencies: - wrappy: 1.0.2 - - onetime@6.0.0: - dependencies: - mimic-fn: 4.0.0 + obug@2.1.1: {} optionator@0.9.4: dependencies: @@ -2504,10 +2107,6 @@ snapshots: dependencies: yocto-queue: 0.1.0 - p-limit@5.0.0: - dependencies: - yocto-queue: 1.2.2 - p-locate@5.0.0: dependencies: p-limit: 3.1.0 @@ -2518,18 +2117,10 @@ snapshots: path-exists@4.0.0: {} - path-is-absolute@1.0.1: {} - path-key@3.1.1: {} - path-key@4.0.0: {} - - pathe@1.1.2: {} - pathe@2.0.3: {} - pathval@1.1.1: {} - picocolors@1.1.1: {} picomatch@4.0.3: {} @@ -2556,18 +2147,10 @@ snapshots: prelude-ls@1.2.1: {} - pretty-format@29.7.0: - dependencies: - '@jest/schemas': 29.6.3 - ansi-styles: 5.2.0 - react-is: 18.3.1 - proxy-from-env@1.1.0: {} punycode@2.3.1: {} - react-is@18.3.1: {} - readdirp@4.1.2: {} resolve-from@4.0.0: {} @@ -2612,8 +2195,6 @@ snapshots: siginfo@2.0.0: {} - signal-exit@4.1.0: {} - source-map-js@1.2.1: {} source-map@0.7.6: {} @@ -2622,14 +2203,8 @@ snapshots: std-env@3.10.0: {} - strip-final-newline@3.0.0: {} - strip-json-comments@3.1.1: {} - strip-literal@2.1.1: - dependencies: - js-tokens: 9.0.1 - sucrase@3.35.1: dependencies: '@jridgewell/gen-mapping': 0.3.13 @@ -2644,12 +2219,6 @@ snapshots: dependencies: has-flag: 4.0.0 - test-exclude@6.0.0: - dependencies: - '@istanbuljs/schema': 0.1.3 - glob: 7.2.3 - minimatch: 3.1.2 - thenify-all@1.6.0: dependencies: thenify: 3.3.1 @@ -2662,14 +2231,14 @@ snapshots: tinyexec@0.3.2: {} + tinyexec@1.0.2: {} + tinyglobby@0.2.15: dependencies: fdir: 6.5.0(picomatch@4.0.3) picomatch: 4.0.3 - tinypool@0.8.4: {} - - tinyspy@2.2.1: {} + tinyrainbow@3.0.3: {} tree-kill@1.2.2: {} @@ -2711,78 +2280,64 @@ snapshots: 
dependencies: prelude-ls: 1.2.1 - type-detect@4.1.0: {} - typescript@5.9.3: {} ufo@1.6.1: {} - undici-types@6.21.0: {} + undici-types@7.16.0: {} uri-js@4.4.1: dependencies: punycode: 2.3.1 - vite-node@1.6.1(@types/node@20.19.27): + vite@7.3.0(@types/node@25.0.3): dependencies: - cac: 6.7.14 - debug: 4.4.3 - pathe: 1.1.2 - picocolors: 1.1.1 - vite: 5.4.21(@types/node@20.19.27) - transitivePeerDependencies: - - '@types/node' - - less - - lightningcss - - sass - - sass-embedded - - stylus - - sugarss - - supports-color - - terser - - vite@5.4.21(@types/node@20.19.27): - dependencies: - esbuild: 0.21.5 + esbuild: 0.27.2 + fdir: 6.5.0(picomatch@4.0.3) + picomatch: 4.0.3 postcss: 8.5.6 rollup: 4.54.0 + tinyglobby: 0.2.15 optionalDependencies: - '@types/node': 20.19.27 + '@types/node': 25.0.3 fsevents: 2.3.3 - vitest@1.6.1(@types/node@20.19.27): + vitest@4.0.16(@types/node@25.0.3): dependencies: - '@vitest/expect': 1.6.1 - '@vitest/runner': 1.6.1 - '@vitest/snapshot': 1.6.1 - '@vitest/spy': 1.6.1 - '@vitest/utils': 1.6.1 - acorn-walk: 8.3.4 - chai: 4.5.0 - debug: 4.4.3 - execa: 8.0.1 - local-pkg: 0.5.1 + '@vitest/expect': 4.0.16 + '@vitest/mocker': 4.0.16(vite@7.3.0(@types/node@25.0.3)) + '@vitest/pretty-format': 4.0.16 + '@vitest/runner': 4.0.16 + '@vitest/snapshot': 4.0.16 + '@vitest/spy': 4.0.16 + '@vitest/utils': 4.0.16 + es-module-lexer: 1.7.0 + expect-type: 1.3.0 magic-string: 0.30.21 - pathe: 1.1.2 - picocolors: 1.1.1 + obug: 2.1.1 + pathe: 2.0.3 + picomatch: 4.0.3 std-env: 3.10.0 - strip-literal: 2.1.1 tinybench: 2.9.0 - tinypool: 0.8.4 - vite: 5.4.21(@types/node@20.19.27) - vite-node: 1.6.1(@types/node@20.19.27) + tinyexec: 1.0.2 + tinyglobby: 0.2.15 + tinyrainbow: 3.0.3 + vite: 7.3.0(@types/node@25.0.3) why-is-node-running: 2.3.0 optionalDependencies: - '@types/node': 20.19.27 + '@types/node': 25.0.3 transitivePeerDependencies: + - jiti - less - lightningcss + - msw - sass - sass-embedded - stylus - sugarss - - supports-color - terser + - tsx + - yaml 
which@2.0.2: dependencies: @@ -2795,8 +2350,4 @@ snapshots: word-wrap@1.2.5: {} - wrappy@1.0.2: {} - yocto-queue@0.1.0: {} - - yocto-queue@1.2.2: {} diff --git a/sdks/nodejs-client/pnpm-workspace.yaml b/sdks/nodejs-client/pnpm-workspace.yaml new file mode 100644 index 0000000000..efc037aa84 --- /dev/null +++ b/sdks/nodejs-client/pnpm-workspace.yaml @@ -0,0 +1,2 @@ +onlyBuiltDependencies: + - esbuild diff --git a/web/.dockerignore b/web/.dockerignore index 31eb66c210..91437a2259 100644 --- a/web/.dockerignore +++ b/web/.dockerignore @@ -7,8 +7,12 @@ logs # node node_modules +dist +build +coverage .husky .next +.pnpm-store # vscode .vscode @@ -22,3 +26,7 @@ node_modules # Jetbrains .idea + +# git +.git +.gitignore \ No newline at end of file diff --git a/web/.env.example b/web/.env.example index b488c31057..df4e725c51 100644 --- a/web/.env.example +++ b/web/.env.example @@ -31,6 +31,8 @@ NEXT_PUBLIC_UPLOAD_IMAGE_AS_ICON=false # The timeout for the text generation in millisecond NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000 +# Used by web/docker/entrypoint.sh to overwrite/export NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS at container startup (Docker only) +TEXT_GENERATION_TIMEOUT_MS=60000 # CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP NEXT_PUBLIC_CSP_WHITELIST= @@ -47,6 +49,8 @@ NEXT_PUBLIC_TOP_K_MAX_VALUE=10 # The maximum number of tokens for segmentation NEXT_PUBLIC_INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000 +# Used by web/docker/entrypoint.sh to overwrite/export NEXT_PUBLIC_INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH at container startup (Docker only) +INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000 # Maximum loop count in the workflow NEXT_PUBLIC_LOOP_NODE_MAX_COUNT=100 @@ -73,3 +77,6 @@ NEXT_PUBLIC_MAX_TREE_DEPTH=50 # The API key of amplitude NEXT_PUBLIC_AMPLITUDE_API_KEY= + +# number of concurrency +NEXT_PUBLIC_BATCH_CONCURRENCY=5 diff --git a/web/.gitignore b/web/.gitignore index 9de3dc83f9..a4ae324795 100644 --- a/web/.gitignore +++ b/web/.gitignore @@ 
-64,3 +64,5 @@ public/fallback-*.js .vscode/settings.json .vscode/mcp.json + +.eslintcache diff --git a/web/.npmrc b/web/.npmrc new file mode 100644 index 0000000000..cffe8cdef1 --- /dev/null +++ b/web/.npmrc @@ -0,0 +1 @@ +save-exact=true diff --git a/web/.nvmrc b/web/.nvmrc new file mode 100644 index 0000000000..a45fd52cc5 --- /dev/null +++ b/web/.nvmrc @@ -0,0 +1 @@ +24 diff --git a/web/.storybook/main.ts b/web/.storybook/main.ts index ca56261431..918860c786 100644 --- a/web/.storybook/main.ts +++ b/web/.storybook/main.ts @@ -1,27 +1,15 @@ -import type { StorybookConfig } from '@storybook/nextjs' -import path from 'node:path' -import { fileURLToPath } from 'node:url' - -const storybookDir = path.dirname(fileURLToPath(import.meta.url)) +import type { StorybookConfig } from '@storybook/nextjs-vite' const config: StorybookConfig = { stories: ['../app/components/**/*.stories.@(js|jsx|mjs|ts|tsx)'], addons: [ - '@storybook/addon-onboarding', + // Not working with Storybook Vite framework + // '@storybook/addon-onboarding', '@storybook/addon-links', '@storybook/addon-docs', '@chromatic-com/storybook', ], - framework: { - name: '@storybook/nextjs', - options: { - builder: { - useSWC: true, - lazyCompilation: false, - }, - nextConfigPath: undefined, - }, - }, + framework: '@storybook/nextjs-vite', staticDirs: ['../public'], core: { disableWhatsNewNotifications: true, @@ -29,17 +17,5 @@ const config: StorybookConfig = { docs: { defaultName: 'Documentation', }, - webpackFinal: async (config) => { - // Add alias to mock problematic modules with circular dependencies - config.resolve = config.resolve || {} - config.resolve.alias = { - ...config.resolve.alias, - // Mock the plugin index files to avoid circular dependencies - [path.resolve(storybookDir, '../app/components/base/prompt-editor/plugins/context-block/index.tsx')]: path.resolve(storybookDir, '__mocks__/context-block.tsx'), - [path.resolve(storybookDir, 
'../app/components/base/prompt-editor/plugins/history-block/index.tsx')]: path.resolve(storybookDir, '__mocks__/history-block.tsx'), - [path.resolve(storybookDir, '../app/components/base/prompt-editor/plugins/query-block/index.tsx')]: path.resolve(storybookDir, '__mocks__/query-block.tsx'), - } - return config - }, } export default config diff --git a/web/.storybook/preview.tsx b/web/.storybook/preview.tsx index 37c636cc75..5b38424776 100644 --- a/web/.storybook/preview.tsx +++ b/web/.storybook/preview.tsx @@ -1,8 +1,10 @@ import type { Preview } from '@storybook/react' +import type { Resource } from 'i18next' import { withThemeByDataAttribute } from '@storybook/addon-themes' import { QueryClient, QueryClientProvider } from '@tanstack/react-query' import { ToastProvider } from '../app/components/base/toast' -import I18N from '../app/components/i18n' +import { I18nClientProvider as I18N } from '../app/components/provider/i18n' +import commonEnUS from '../i18n/en-US/common.json' import '../app/styles/globals.css' import '../app/styles/markdown.scss' @@ -16,6 +18,14 @@ const queryClient = new QueryClient({ }, }) +const storyResources: Resource = { + 'en-US': { + // Preload the most common namespace to avoid missing keys during initial render; + // other namespaces will be loaded on demand via resourcesToBackend. 
+ common: commonEnUS as unknown as Record, + }, +} + export const decorators = [ withThemeByDataAttribute({ themes: { @@ -28,7 +38,7 @@ export const decorators = [ (Story) => { return ( - + diff --git a/web/.vscode/extensions.json b/web/.vscode/extensions.json index 68f5c7bf0e..01235650ab 100644 --- a/web/.vscode/extensions.json +++ b/web/.vscode/extensions.json @@ -1,6 +1,8 @@ { "recommendations": [ "bradlc.vscode-tailwindcss", - "kisstkondoros.vscode-codemetrics" + "kisstkondoros.vscode-codemetrics", + "johnsoncodehk.vscode-tsslint", + "dbaeumer.vscode-eslint" ] } diff --git a/web/Dockerfile b/web/Dockerfile index f24e9f2fc3..d71b1b6ba6 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -1,5 +1,5 @@ # base image -FROM node:22-alpine3.21 AS base +FROM node:24-alpine AS base LABEL maintainer="takatost@gmail.com" # if you located in China, you can use aliyun mirror to speed up @@ -12,7 +12,8 @@ RUN apk add --no-cache tzdata RUN corepack enable ENV PNPM_HOME="/pnpm" ENV PATH="$PNPM_HOME:$PATH" -ENV NEXT_PUBLIC_BASE_PATH="" +ARG NEXT_PUBLIC_BASE_PATH="" +ENV NEXT_PUBLIC_BASE_PATH="$NEXT_PUBLIC_BASE_PATH" # install packages diff --git a/web/README.md b/web/README.md index 7f5740a471..9c731a081a 100644 --- a/web/README.md +++ b/web/README.md @@ -8,8 +8,18 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next Before starting the web frontend service, please make sure the following environment is ready. -- [Node.js](https://nodejs.org) >= v22.11.x -- [pnpm](https://pnpm.io) v10.x +- [Node.js](https://nodejs.org) +- [pnpm](https://pnpm.io) + +> [!TIP] +> It is recommended to install and enable Corepack to manage package manager versions automatically: +> +> ```bash +> npm install -g corepack +> corepack enable +> ``` +> +> Learn more: [Corepack](https://github.com/nodejs/corepack#readme) First, install the dependencies: @@ -128,7 +138,7 @@ This will help you determine the testing strategy. 
See [web/testing/testing.md]( ## Documentation -Visit to view the full documentation. +Visit to view the full documentation. ## Community diff --git a/web/__mocks__/provider-context.ts b/web/__mocks__/provider-context.ts index c69a2ad1d2..373c2f86d3 100644 --- a/web/__mocks__/provider-context.ts +++ b/web/__mocks__/provider-context.ts @@ -1,6 +1,7 @@ import type { Plan, UsagePlanInfo } from '@/app/components/billing/type' import type { ProviderContextState } from '@/context/provider-context' -import { merge, noop } from 'es-toolkit/compat' +import { merge } from 'es-toolkit/compat' +import { noop } from 'es-toolkit/function' import { defaultPlan } from '@/app/components/billing/config' // Avoid being mocked in tests diff --git a/web/__mocks__/zustand.ts b/web/__mocks__/zustand.ts new file mode 100644 index 0000000000..7255e5ef86 --- /dev/null +++ b/web/__mocks__/zustand.ts @@ -0,0 +1,56 @@ +import type * as ZustandExportedTypes from 'zustand' +import { act } from '@testing-library/react' + +export * from 'zustand' + +const { create: actualCreate, createStore: actualCreateStore } + // eslint-disable-next-line antfu/no-top-level-await + = await vi.importActual('zustand') + +export const storeResetFns = new Set<() => void>() + +const createUncurried = ( + stateCreator: ZustandExportedTypes.StateCreator, +) => { + const store = actualCreate(stateCreator) + const initialState = store.getInitialState() + storeResetFns.add(() => { + store.setState(initialState, true) + }) + return store +} + +export const create = (( + stateCreator: ZustandExportedTypes.StateCreator, +) => { + return typeof stateCreator === 'function' + ? 
createUncurried(stateCreator) + : createUncurried +}) as typeof ZustandExportedTypes.create + +const createStoreUncurried = ( + stateCreator: ZustandExportedTypes.StateCreator, +) => { + const store = actualCreateStore(stateCreator) + const initialState = store.getInitialState() + storeResetFns.add(() => { + store.setState(initialState, true) + }) + return store +} + +export const createStore = (( + stateCreator: ZustandExportedTypes.StateCreator, +) => { + return typeof stateCreator === 'function' + ? createStoreUncurried(stateCreator) + : createStoreUncurried +}) as typeof ZustandExportedTypes.createStore + +afterEach(() => { + act(() => { + storeResetFns.forEach((resetFn) => { + resetFn() + }) + }) +}) diff --git a/web/__tests__/check-i18n.test.ts b/web/__tests__/check-i18n.test.ts index a6f86d8107..9f573bda10 100644 --- a/web/__tests__/check-i18n.test.ts +++ b/web/__tests__/check-i18n.test.ts @@ -3,7 +3,7 @@ import path from 'node:path' import vm from 'node:vm' import { transpile } from 'typescript' -describe('check-i18n script functionality', () => { +describe('i18n:check script functionality', () => { const testDir = path.join(__dirname, '../i18n-test') const testEnDir = path.join(testDir, 'en-US') const testZhDir = path.join(testDir, 'zh-Hans') diff --git a/web/__tests__/embedded-user-id-store.test.tsx b/web/__tests__/embedded-user-id-store.test.tsx index 276b22bcd7..901218e76b 100644 --- a/web/__tests__/embedded-user-id-store.test.tsx +++ b/web/__tests__/embedded-user-id-store.test.tsx @@ -53,6 +53,7 @@ vi.mock('@/context/global-public-context', () => { ) return { useGlobalPublicStore, + useIsSystemFeaturesPending: () => false, } }) diff --git a/web/__tests__/goto-anything/search-error-handling.test.ts b/web/__tests__/goto-anything/search-error-handling.test.ts index 3a495834cd..42eb829583 100644 --- a/web/__tests__/goto-anything/search-error-handling.test.ts +++ b/web/__tests__/goto-anything/search-error-handling.test.ts @@ -14,6 +14,14 @@ import { 
fetchAppList } from '@/service/apps' import { postMarketplace } from '@/service/base' import { fetchDatasets } from '@/service/datasets' +// Mock react-i18next before importing modules that use it +vi.mock('react-i18next', () => ({ + getI18n: () => ({ + t: (key: string) => key, + language: 'en', + }), +})) + // Mock API functions vi.mock('@/service/base', () => ({ postMarketplace: vi.fn(), diff --git a/web/__tests__/i18n-upload-features.test.ts b/web/__tests__/i18n-upload-features.test.ts index af418262d6..3d0f1d30dd 100644 --- a/web/__tests__/i18n-upload-features.test.ts +++ b/web/__tests__/i18n-upload-features.test.ts @@ -16,7 +16,7 @@ const getSupportedLocales = (): string[] => { // Helper function to load translation file content const loadTranslationContent = (locale: string): string => { - const filePath = path.join(I18N_DIR, locale, 'app-debug.ts') + const filePath = path.join(I18N_DIR, locale, 'app-debug.json') if (!fs.existsSync(filePath)) throw new Error(`Translation file not found: ${filePath}`) @@ -24,14 +24,14 @@ const loadTranslationContent = (locale: string): string => { return fs.readFileSync(filePath, 'utf-8') } -// Helper function to check if upload features exist +// Helper function to check if upload features exist (supports flattened JSON) const hasUploadFeatures = (content: string): { [key: string]: boolean } => { return { - fileUpload: /fileUpload\s*:\s*\{/.test(content), - imageUpload: /imageUpload\s*:\s*\{/.test(content), - documentUpload: /documentUpload\s*:\s*\{/.test(content), - audioUpload: /audioUpload\s*:\s*\{/.test(content), - featureBar: /bar\s*:\s*\{/.test(content), + fileUpload: /"feature\.fileUpload\.title"/.test(content), + imageUpload: /"feature\.imageUpload\.title"/.test(content), + documentUpload: /"feature\.documentUpload\.title"/.test(content), + audioUpload: /"feature\.audioUpload\.title"/.test(content), + featureBar: /"feature\.bar\.empty"/.test(content), } } @@ -45,7 +45,7 @@ describe('Upload Features i18n Translations - 
Issue #23062', () => { it('all locales should have translation files', () => { supportedLocales.forEach((locale) => { - const filePath = path.join(I18N_DIR, locale, 'app-debug.ts') + const filePath = path.join(I18N_DIR, locale, 'app-debug.json') expect(fs.existsSync(filePath)).toBe(true) }) }) @@ -76,12 +76,9 @@ describe('Upload Features i18n Translations - Issue #23062', () => { previouslyMissingLocales.forEach((locale) => { const content = loadTranslationContent(locale) - // Verify audioUpload exists - expect(/audioUpload\s*:\s*\{/.test(content)).toBe(true) - - // Verify it has title and description - expect(/audioUpload[^}]*title\s*:/.test(content)).toBe(true) - expect(/audioUpload[^}]*description\s*:/.test(content)).toBe(true) + // Verify audioUpload exists with title and description (flattened JSON format) + expect(/"feature\.audioUpload\.title"/.test(content)).toBe(true) + expect(/"feature\.audioUpload\.description"/.test(content)).toBe(true) console.log(`✅ ${locale} - Issue #23062 resolved: audioUpload feature present`) }) @@ -91,28 +88,28 @@ describe('Upload Features i18n Translations - Issue #23062', () => { supportedLocales.forEach((locale) => { const content = loadTranslationContent(locale) - // Check fileUpload has required properties - if (/fileUpload\s*:\s*\{/.test(content)) { - expect(/fileUpload[^}]*title\s*:/.test(content)).toBe(true) - expect(/fileUpload[^}]*description\s*:/.test(content)).toBe(true) + // Check fileUpload has required properties (flattened JSON format) + if (/"feature\.fileUpload\.title"/.test(content)) { + expect(/"feature\.fileUpload\.title"/.test(content)).toBe(true) + expect(/"feature\.fileUpload\.description"/.test(content)).toBe(true) } // Check imageUpload has required properties - if (/imageUpload\s*:\s*\{/.test(content)) { - expect(/imageUpload[^}]*title\s*:/.test(content)).toBe(true) - expect(/imageUpload[^}]*description\s*:/.test(content)).toBe(true) + if (/"feature\.imageUpload\.title"/.test(content)) { + 
expect(/"feature\.imageUpload\.title"/.test(content)).toBe(true) + expect(/"feature\.imageUpload\.description"/.test(content)).toBe(true) } // Check documentUpload has required properties - if (/documentUpload\s*:\s*\{/.test(content)) { - expect(/documentUpload[^}]*title\s*:/.test(content)).toBe(true) - expect(/documentUpload[^}]*description\s*:/.test(content)).toBe(true) + if (/"feature\.documentUpload\.title"/.test(content)) { + expect(/"feature\.documentUpload\.title"/.test(content)).toBe(true) + expect(/"feature\.documentUpload\.description"/.test(content)).toBe(true) } // Check audioUpload has required properties - if (/audioUpload\s*:\s*\{/.test(content)) { - expect(/audioUpload[^}]*title\s*:/.test(content)).toBe(true) - expect(/audioUpload[^}]*description\s*:/.test(content)).toBe(true) + if (/"feature\.audioUpload\.title"/.test(content)) { + expect(/"feature\.audioUpload\.title"/.test(content)).toBe(true) + expect(/"feature\.audioUpload\.description"/.test(content)).toBe(true) } }) }) diff --git a/web/__tests__/workflow-parallel-limit.test.tsx b/web/__tests__/workflow-parallel-limit.test.tsx index 18657f4bd2..ba3840ac3e 100644 --- a/web/__tests__/workflow-parallel-limit.test.tsx +++ b/web/__tests__/workflow-parallel-limit.test.tsx @@ -64,7 +64,6 @@ vi.mock('i18next', () => ({ // Mock the useConfig hook vi.mock('@/app/components/workflow/nodes/iteration/use-config', () => ({ - __esModule: true, default: () => ({ inputs: { is_parallel: true, diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx index c1feb3ea5d..470f4477fa 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx @@ -70,7 +70,7 @@ const AppDetailLayout: FC = (props) => { const navConfig = [ ...(isCurrentWorkspaceEditor ? 
[{ - name: t('common.appMenus.promptEng'), + name: t('appMenus.promptEng', { ns: 'common' }), href: `/app/${appId}/${(mode === AppModeEnum.WORKFLOW || mode === AppModeEnum.ADVANCED_CHAT) ? 'workflow' : 'configuration'}`, icon: RiTerminalWindowLine, selectedIcon: RiTerminalWindowFill, @@ -78,7 +78,7 @@ const AppDetailLayout: FC = (props) => { : [] ), { - name: t('common.appMenus.apiAccess'), + name: t('appMenus.apiAccess', { ns: 'common' }), href: `/app/${appId}/develop`, icon: RiTerminalBoxLine, selectedIcon: RiTerminalBoxFill, @@ -86,8 +86,8 @@ const AppDetailLayout: FC = (props) => { ...(isCurrentWorkspaceEditor ? [{ name: mode !== AppModeEnum.WORKFLOW - ? t('common.appMenus.logAndAnn') - : t('common.appMenus.logs'), + ? t('appMenus.logAndAnn', { ns: 'common' }) + : t('appMenus.logs', { ns: 'common' }), href: `/app/${appId}/logs`, icon: RiFileList3Line, selectedIcon: RiFileList3Fill, @@ -95,7 +95,7 @@ const AppDetailLayout: FC = (props) => { : [] ), { - name: t('common.appMenus.overview'), + name: t('appMenus.overview', { ns: 'common' }), href: `/app/${appId}/overview`, icon: RiDashboard2Line, selectedIcon: RiDashboard2Fill, @@ -104,7 +104,7 @@ const AppDetailLayout: FC = (props) => { return navConfig }, [t]) - useDocumentTitle(appDetail?.name || t('common.menus.appDetail')) + useDocumentTitle(appDetail?.name || t('menus.appDetail', { ns: 'common' })) useEffect(() => { if (appDetail) { diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx index 34bb0f82f8..f07b2932c9 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx @@ -4,7 +4,7 @@ import type { IAppCardProps } from '@/app/components/app/overview/app-card' import type { BlockEnum } from '@/app/components/workflow/types' import type { UpdateAppSiteCodeResponse } from '@/models/app' 
import type { App } from '@/types/app' -import * as React from 'react' +import type { I18nKeysByPrefix } from '@/types/i18n' import { useCallback, useMemo } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' @@ -16,7 +16,6 @@ import { ToastContext } from '@/app/components/base/toast' import MCPServiceCard from '@/app/components/tools/mcp/mcp-service-card' import { isTriggerNode } from '@/app/components/workflow/types' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' -import { useDocLink } from '@/context/i18n' import { fetchAppDetail, updateAppSiteAccessToken, @@ -35,7 +34,6 @@ export type ICardViewProps = { const CardView: FC = ({ appId, isInPanel, className }) => { const { t } = useTranslation() - const docLink = useDocLink() const { notify } = useContext(ToastContext) const appDetail = useAppStore(state => state.appDetail) const setAppDetail = useAppStore(state => state.setAppDetail) @@ -58,32 +56,22 @@ const CardView: FC = ({ appId, isInPanel, className }) => { const shouldRenderAppCards = !isWorkflowApp || hasTriggerNode === false const disableAppCards = !shouldRenderAppCards - const triggerDocUrl = docLink('/guides/workflow/node/start') const buildTriggerModeMessage = useCallback((featureName: string) => (
- {t('appOverview.overview.disableTooltip.triggerMode', { feature: featureName })} -
-
{ - event.stopPropagation() - window.open(triggerDocUrl, '_blank') - }} - > - {t('appOverview.overview.appInfo.enableTooltip.learnMore')} + {t('overview.disableTooltip.triggerMode', { ns: 'appOverview', feature: featureName })}
- ), [t, triggerDocUrl])
+ ), [t])
   const disableWebAppTooltip = disableAppCards
-    ? buildTriggerModeMessage(t('appOverview.overview.appInfo.title'))
+    ? buildTriggerModeMessage(t('overview.appInfo.title', { ns: 'appOverview' }))
     : null
   const disableApiTooltip = disableAppCards
-    ? buildTriggerModeMessage(t('appOverview.overview.apiInfo.title'))
+    ? buildTriggerModeMessage(t('overview.apiInfo.title', { ns: 'appOverview' }))
     : null
   const disableMcpTooltip = disableAppCards
-    ? buildTriggerModeMessage(t('tools.mcp.server.title'))
+    ? buildTriggerModeMessage(t('mcp.server.title', { ns: 'tools' }))
     : null
   const updateAppDetail = async () => {
@@ -94,7 +82,7 @@ const CardView: FC = ({ appId, isInPanel, className }) => {
     catch (error) { console.error(error) }
   }
-  const handleCallbackResult = (err: Error | null, message?: string) => {
+  const handleCallbackResult = (err: Error | null, message?: I18nKeysByPrefix<'common', 'actionMsg.'>) => {
     const type = err ? 'error' : 'success'
     message ||= (type === 'success' ? 'modifiedSuccessfully' : 'modifiedUnsuccessfully')
@@ -104,7 +92,7 @@ const CardView: FC = ({ appId, isInPanel, className }) => {
     notify({
       type,
-      message: t(`common.actionMsg.${message}` as any) as string,
+      message: t(`actionMsg.${message}`, { ns: 'common' }) as string,
     })
   }
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chart-view.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chart-view.tsx
index e1fca90c5d..b6e902f456 100644
--- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chart-view.tsx
+++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chart-view.tsx
@@ -1,5 +1,6 @@
 'use client'
 import type { PeriodParams } from '@/app/components/app/overview/app-chart'
+import type { I18nKeysByPrefix } from '@/types/i18n'
 import dayjs from 'dayjs'
 import quarterOfYear from 'dayjs/plugin/quarterOfYear'
 import * as React from 'react'
@@ -16,7 +17,9 @@
 dayjs.extend(quarterOfYear)

 const today = dayjs()

-const TIME_PERIOD_MAPPING = [
+type TimePeriodName = I18nKeysByPrefix<'appLog', 'filter.period.'>
+
+const TIME_PERIOD_MAPPING: { value: number, name: TimePeriodName }[] = [
   { value: 0, name: 'today' },
   { value: 7, name: 'last7days' },
   { value: 30, name: 'last30days' },
@@ -35,8 +38,8 @@ export default function ChartView({ appId, headerRight }: IChartViewProps) {
   const isChatApp = appDetail?.mode !== 'completion' && appDetail?.mode !== 'workflow'
   const isWorkflow = appDetail?.mode === 'workflow'
   const [period, setPeriod] = useState(IS_CLOUD_EDITION
-    ? { name: t('appLog.filter.period.today'), query: { start: today.startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } }
-    : { name: t('appLog.filter.period.last7days'), query: { start: today.subtract(7, 'day').startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } },
+    ? { name: t('filter.period.today', { ns: 'appLog' }), query: { start: today.startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } }
+    : { name: t('filter.period.last7days', { ns: 'appLog' }), query: { start: today.subtract(7, 'day').startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } },
 )
 if (!appDetail)
@@ -45,7 +48,7 @@ export default function ChartView({ appId, headerRight }: IChartViewProps) {
 return (
- {t('common.appMenus.overview')}
+ {t('appMenus.overview', { ns: 'common' })}
{IS_CLOUD_EDITION ? ( diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/long-time-range-picker.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/long-time-range-picker.tsx index 55ed906a45..f7178d7ac2 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/long-time-range-picker.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/long-time-range-picker.tsx @@ -2,13 +2,16 @@ import type { FC } from 'react' import type { PeriodParams } from '@/app/components/app/overview/app-chart' import type { Item } from '@/app/components/base/select' +import type { I18nKeysByPrefix } from '@/types/i18n' import dayjs from 'dayjs' import * as React from 'react' import { useTranslation } from 'react-i18next' import { SimpleSelect } from '@/app/components/base/select' +type TimePeriodName = I18nKeysByPrefix<'appLog', 'filter.period.'> + type Props = { - periodMapping: { [key: string]: { value: number, name: string } } + periodMapping: { [key: string]: { value: number, name: TimePeriodName } } onSelect: (payload: PeriodParams) => void queryDateFormat: string } @@ -25,9 +28,9 @@ const LongTimeRangePicker: FC = ({ const handleSelect = React.useCallback((item: Item) => { const id = item.value const value = periodMapping[id]?.value ?? 
'-1' - const name = item.name || t('appLog.filter.period.allTime') + const name = item.name || t('filter.period.allTime', { ns: 'appLog' }) if (value === -1) { - onSelect({ name: t('appLog.filter.period.allTime'), query: undefined }) + onSelect({ name: t('filter.period.allTime', { ns: 'appLog' }), query: undefined }) } else if (value === 0) { const startOfToday = today.startOf('day').format(queryDateFormat) @@ -53,7 +56,7 @@ const LongTimeRangePicker: FC = ({ return ( ({ value: k, name: t(`appLog.filter.period.${v.name}` as any) as string }))} + items={Object.entries(periodMapping).map(([k, v]) => ({ value: k, name: t(`filter.period.${v.name}`, { ns: 'appLog' }) }))} className="mt-0 !w-40" notClearable={true} onSelect={handleSelect} diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx index 004f83afc5..368c3dcfc3 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx @@ -4,11 +4,11 @@ import type { FC } from 'react' import type { TriggerProps } from '@/app/components/base/date-and-time-picker/types' import { RiCalendarLine } from '@remixicon/react' import dayjs from 'dayjs' -import { noop } from 'es-toolkit/compat' +import { noop } from 'es-toolkit/function' import * as React from 'react' import { useCallback } from 'react' import Picker from '@/app/components/base/date-and-time-picker/date-picker' -import { useI18N } from '@/context/i18n' +import { useLocale } from '@/context/i18n' import { cn } from '@/utils/classnames' import { formatToLocalTime } from '@/utils/format' @@ -26,7 +26,7 @@ const DatePicker: FC = ({ onStartChange, onEndChange, }) => { - const { locale } = useI18N() + const locale = useLocale() const renderDate = useCallback(({ value, 
handleClickTrigger, isOpen }: TriggerProps) => { return ( diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/index.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/index.tsx index 469bc97737..53794ad8db 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/index.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/index.tsx @@ -2,19 +2,22 @@ import type { Dayjs } from 'dayjs' import type { FC } from 'react' import type { PeriodParams, PeriodParamsWithTimeRange } from '@/app/components/app/overview/app-chart' +import type { I18nKeysByPrefix } from '@/types/i18n' import dayjs from 'dayjs' import * as React from 'react' import { useCallback, useState } from 'react' import { HourglassShape } from '@/app/components/base/icons/src/vender/other' -import { useI18N } from '@/context/i18n' +import { useLocale } from '@/context/i18n' import { formatToLocalTime } from '@/utils/format' import DatePicker from './date-picker' import RangeSelector from './range-selector' const today = dayjs() +type TimePeriodName = I18nKeysByPrefix<'appLog', 'filter.period.'> + type Props = { - ranges: { value: number, name: string }[] + ranges: { value: number, name: TimePeriodName }[] onSelect: (payload: PeriodParams) => void queryDateFormat: string } @@ -24,7 +27,7 @@ const TimeRangePicker: FC = ({ onSelect, queryDateFormat, }) => { - const { locale } = useI18N() + const locale = useLocale() const [isCustomRange, setIsCustomRange] = useState(false) const [start, setStart] = useState(today) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx index 88cb79ce0d..986170728f 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx +++ 
b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx @@ -2,6 +2,7 @@ import type { FC } from 'react' import type { PeriodParamsWithTimeRange, TimeRange } from '@/app/components/app/overview/app-chart' import type { Item } from '@/app/components/base/select' +import type { I18nKeysByPrefix } from '@/types/i18n' import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react' import dayjs from 'dayjs' import * as React from 'react' @@ -12,9 +13,11 @@ import { cn } from '@/utils/classnames' const today = dayjs() +type TimePeriodName = I18nKeysByPrefix<'appLog', 'filter.period.'> + type Props = { isCustomRange: boolean - ranges: { value: number, name: string }[] + ranges: { value: number, name: TimePeriodName }[] onSelect: (payload: PeriodParamsWithTimeRange) => void } @@ -42,7 +45,7 @@ const RangeSelector: FC = ({ const renderTrigger = useCallback((item: Item | null, isOpen: boolean) => { return (
- {isCustomRange ? t('appLog.filter.period.custom') : item?.name}
+ {isCustomRange ? t('filter.period.custom', { ns: 'appLog' }) : item?.name}
) @@ -66,7 +69,7 @@ const RangeSelector: FC = ({ }, []) return ( ({ ...v, name: t(`appLog.filter.period.${v.name}` as any) as string }))} + items={ranges.map(v => ({ ...v, name: t(`filter.period.${v.name}`, { ns: 'appLog' }) }))} className="mt-0 !w-40" notClearable={true} onSelect={handleSelectRange} diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx index 35ab8a61ec..4469459b52 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx @@ -15,7 +15,7 @@ import ProviderPanel from './provider-panel' import TracingIcon from './tracing-icon' import { TracingProvider } from './type' -const I18N_PREFIX = 'app.tracing' +const I18N_PREFIX = 'tracing' export type PopupProps = { appId: string @@ -327,19 +327,19 @@ const ConfigPopup: FC = ({
- {t(`${I18N_PREFIX}.tracing`)}
+ {t(`${I18N_PREFIX}.tracing`, { ns: 'app' })}
- {t(`${I18N_PREFIX}.${enabled ? 'enabled' : 'disabled'}`)}
+ {t(`${I18N_PREFIX}.${enabled ? 'enabled' : 'disabled'}`, { ns: 'app' })}
{!readOnly && ( <> {providerAllNotConfigured ? ( {switchContent} @@ -351,14 +351,14 @@ const ConfigPopup: FC = ({
- {t(`${I18N_PREFIX}.tracingDescription`)}
+ {t(`${I18N_PREFIX}.tracingDescription`, { ns: 'app' })}
{(providerAllConfigured || providerAllNotConfigured) ? ( <>
- {t(`${I18N_PREFIX}.configProviderTitle.${providerAllConfigured ? 'configured' : 'notConfigured'}`)}
+ {t(`${I18N_PREFIX}.configProviderTitle.${providerAllConfigured ? 'configured' : 'notConfigured'}`, { ns: 'app' })}
{langfusePanel} {langSmithPanel}
@@ -375,11 +375,11 @@ const ConfigPopup: FC = ({
) : ( <>
- {t(`${I18N_PREFIX}.configProviderTitle.configured`)}
+ {t(`${I18N_PREFIX}.configProviderTitle.configured`, { ns: 'app' })}
{configuredProviderPanel()}
- {t(`${I18N_PREFIX}.configProviderTitle.moreProvider`)}
+ {t(`${I18N_PREFIX}.configProviderTitle.moreProvider`, { ns: 'app' })}
{moreProviderPanel()}
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx index 438dbddfb8..5e7d98d191 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx @@ -23,7 +23,7 @@ import ConfigButton from './config-button' import TracingIcon from './tracing-icon' import { TracingProvider } from './type' -const I18N_PREFIX = 'app.tracing' +const I18N_PREFIX = 'tracing' const Panel: FC = () => { const { t } = useTranslation() @@ -45,7 +45,7 @@ const Panel: FC = () => { if (!noToast) { Toast.notify({ type: 'success', - message: t('common.api.success'), + message: t('api.success', { ns: 'common' }), }) } } @@ -254,7 +254,7 @@ const Panel: FC = () => { )} > -
{t(`${I18N_PREFIX}.title`)}
+ {t(`${I18N_PREFIX}.title`, { ns: 'app' })}
@@ -295,7 +295,7 @@ const Panel: FC = () => {
- {t(`${I18N_PREFIX}.${enabled ? 'enabled' : 'disabled'}`)}
+ {t(`${I18N_PREFIX}.${enabled ? 'enabled' : 'disabled'}`, { ns: 'app' })}
{InUseProviderIcon && } diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx index 0ef97e6970..ff78712c3c 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx @@ -30,7 +30,7 @@ type Props = { onChosen: (provider: TracingProvider) => void } -const I18N_PREFIX = 'app.tracing.configProvider' +const I18N_PREFIX = 'tracing.configProvider' const arizeConfigTemplate = { api_key: '', @@ -157,7 +157,7 @@ const ProviderConfigModal: FC = ({ }) Toast.notify({ type: 'success', - message: t('common.api.remove'), + message: t('api.remove', { ns: 'common' }), }) onRemoved() hideRemoveConfirm() @@ -177,37 +177,37 @@ const ProviderConfigModal: FC = ({ if (type === TracingProvider.arize) { const postData = config as ArizeConfig if (!postData.api_key) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'API Key' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'API Key' }) if (!postData.space_id) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Space ID' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Space ID' }) if (!errorMessage && !postData.project) - errorMessage = t('common.errorMsg.fieldRequired', { field: t(`${I18N_PREFIX}.project`) }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.project`, { ns: 'app' }) }) } if (type === TracingProvider.phoenix) { const postData = config as PhoenixConfig if (!postData.api_key) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'API Key' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'API Key' }) if (!errorMessage && !postData.project) - errorMessage = t('common.errorMsg.fieldRequired', { field: 
t(`${I18N_PREFIX}.project`) }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.project`, { ns: 'app' }) }) } if (type === TracingProvider.langSmith) { const postData = config as LangSmithConfig if (!postData.api_key) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'API Key' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'API Key' }) if (!errorMessage && !postData.project) - errorMessage = t('common.errorMsg.fieldRequired', { field: t(`${I18N_PREFIX}.project`) }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.project`, { ns: 'app' }) }) } if (type === TracingProvider.langfuse) { const postData = config as LangFuseConfig if (!errorMessage && !postData.secret_key) - errorMessage = t('common.errorMsg.fieldRequired', { field: t(`${I18N_PREFIX}.secretKey`) }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.secretKey`, { ns: 'app' }) }) if (!errorMessage && !postData.public_key) - errorMessage = t('common.errorMsg.fieldRequired', { field: t(`${I18N_PREFIX}.publicKey`) }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.publicKey`, { ns: 'app' }) }) if (!errorMessage && !postData.host) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Host' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Host' }) } if (type === TracingProvider.opik) { @@ -218,43 +218,43 @@ const ProviderConfigModal: FC = ({ if (type === TracingProvider.weave) { const postData = config as WeaveConfig if (!errorMessage && !postData.api_key) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'API Key' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'API Key' }) if (!errorMessage && !postData.project) - errorMessage = t('common.errorMsg.fieldRequired', { field: t(`${I18N_PREFIX}.project`) }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', 
field: t(`${I18N_PREFIX}.project`, { ns: 'app' }) }) } if (type === TracingProvider.aliyun) { const postData = config as AliyunConfig if (!errorMessage && !postData.app_name) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'App Name' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'App Name' }) if (!errorMessage && !postData.license_key) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'License Key' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'License Key' }) if (!errorMessage && !postData.endpoint) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Endpoint' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Endpoint' }) } if (type === TracingProvider.mlflow) { const postData = config as MLflowConfig if (!errorMessage && !postData.tracking_uri) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Tracking URI' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Tracking URI' }) } if (type === TracingProvider.databricks) { const postData = config as DatabricksConfig if (!errorMessage && !postData.experiment_id) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Experiment ID' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Experiment ID' }) if (!errorMessage && !postData.host) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Host' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Host' }) } if (type === TracingProvider.tencent) { const postData = config as TencentConfig if (!errorMessage && !postData.token) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Token' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Token' }) if (!errorMessage && !postData.endpoint) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Endpoint' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Endpoint' }) if 
(!errorMessage && !postData.service_name) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Service Name' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Service Name' }) } return errorMessage @@ -281,7 +281,7 @@ const ProviderConfigModal: FC = ({ }) Toast.notify({ type: 'success', - message: t('common.api.success'), + message: t('api.success', { ns: 'common' }), }) onSaved(config) if (isAdd) @@ -303,8 +303,8 @@ const ProviderConfigModal: FC = ({
- {t(`${I18N_PREFIX}.title`)}
- {t(`app.tracing.${type}.title`)}
+ {t(`${I18N_PREFIX}.title`, { ns: 'app' })}
+ {t(`tracing.${type}.title`, { ns: 'app' })}
@@ -317,7 +317,7 @@ const ProviderConfigModal: FC = ({ isRequired value={(config as ArizeConfig).api_key} onChange={handleConfigChange('api_key')} - placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'API Key' })!} + placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'API Key' })!} /> = ({ isRequired value={(config as ArizeConfig).space_id} onChange={handleConfigChange('space_id')} - placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'Space ID' })!} + placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'Space ID' })!} /> = ({ isRequired value={(config as PhoenixConfig).api_key} onChange={handleConfigChange('api_key')} - placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'API Key' })!} + placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'API Key' })!} /> = ({ isRequired value={(config as AliyunConfig).license_key} onChange={handleConfigChange('license_key')} - placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'License Key' })!} + placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'License Key' })!} /> = ({ isRequired value={(config as TencentConfig).token} onChange={handleConfigChange('token')} - placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'Token' })!} + placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'Token' })!} /> = ({ isRequired value={(config as WeaveConfig).api_key} onChange={handleConfigChange('api_key')} - placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'API Key' })!} + placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'API Key' })!} /> = ({ isRequired value={(config as LangSmithConfig).api_key} onChange={handleConfigChange('api_key')} - placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'API Key' })!} + placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'API Key' })!} /> = ({ {type === TracingProvider.langfuse && ( <> = ({ labelClassName="!text-sm" value={(config as OpikConfig).api_key} onChange={handleConfigChange('api_key')} - 
placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'API Key' })!} + placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'API Key' })!} /> = ({ {type === TracingProvider.mlflow && ( <> = ({ placeholder="http://localhost:5000" /> )} {type === TracingProvider.databricks && ( <> )} @@ -634,7 +634,7 @@ const ProviderConfigModal: FC = ({ target="_blank" href={docURL[type]} > - {t(`${I18N_PREFIX}.viewDocsLink`, { key: t(`app.tracing.${type}.title`) })} + {t(`${I18N_PREFIX}.viewDocsLink`, { ns: 'app', key: t(`tracing.${type}.title`, { ns: 'app' }) })}
@@ -644,7 +644,7 @@ const ProviderConfigModal: FC = ({ className="h-9 text-sm font-medium text-text-secondary" onClick={showRemoveConfirm} > - {t('common.operation.remove')} + {t('operation.remove', { ns: 'common' })} @@ -653,7 +653,7 @@ const ProviderConfigModal: FC = ({ className="mr-2 h-9 text-sm font-medium text-text-secondary" onClick={onCancel} > - {t('common.operation.cancel')} + {t('operation.cancel', { ns: 'common' })}
@@ -670,7 +670,7 @@ const ProviderConfigModal: FC = ({
- {t('common.modelProvider.encrypted.front')}
+ {t('modelProvider.encrypted.front', { ns: 'common' })}
= ({ > PKCS1_OAEP
- {t('common.modelProvider.encrypted.back')}
+ {t('modelProvider.encrypted.back', { ns: 'common' })}
@@ -691,8 +691,8 @@ const ProviderConfigModal: FC = ({ diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx index 6c66b19ad3..dc0fe2fbbc 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx @@ -11,7 +11,7 @@ import { Eye as View } from '@/app/components/base/icons/src/vender/solid/genera import { cn } from '@/utils/classnames' import { TracingProvider } from './type' -const I18N_PREFIX = 'app.tracing' +const I18N_PREFIX = 'tracing' type Props = { type: TracingProvider @@ -82,14 +82,14 @@ const ProviderPanel: FC = ({
- {isChosen && {t(`${I18N_PREFIX}.inUse`)} }
+ {isChosen && {t(`${I18N_PREFIX}.inUse`, { ns: 'app' })} }
{!readOnly && (
{hasConfigured && (
- {t(`${I18N_PREFIX}.view`)}
+ {t(`${I18N_PREFIX}.view`, { ns: 'app' })}
)}
= ({ onClick={handleConfigBtnClick} >
- {t(`${I18N_PREFIX}.config`)}
+ {t(`${I18N_PREFIX}.config`, { ns: 'app' })}
)}
- {t(`${I18N_PREFIX}.${type}.description`)}
+ {t(`${I18N_PREFIX}.${type}.description`, { ns: 'app' })}
) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/layout.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/layout.tsx index 4135482dd9..a918ae2786 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/layout.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/layout.tsx @@ -15,7 +15,7 @@ const AppDetail: FC = ({ children }) => { const router = useRouter() const { isCurrentWorkspaceDatasetOperator } = useAppContext() const { t } = useTranslation() - useDocumentTitle(t('common.menus.appDetail')) + useDocumentTitle(t('menus.appDetail', { ns: 'common' })) useEffect(() => { if (isCurrentWorkspaceDatasetOperator) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index 10a12d75e1..1c5434924f 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -70,14 +70,14 @@ const DatasetDetailLayout: FC = (props) => { const navigation = useMemo(() => { const baseNavigation = [ { - name: t('common.datasetMenus.hitTesting'), + name: t('datasetMenus.hitTesting', { ns: 'common' }), href: `/datasets/${datasetId}/hitTesting`, icon: RiFocus2Line, selectedIcon: RiFocus2Fill, disabled: isButtonDisabledWithPipeline, }, { - name: t('common.datasetMenus.settings'), + name: t('datasetMenus.settings', { ns: 'common' }), href: `/datasets/${datasetId}/settings`, icon: RiEqualizer2Line, selectedIcon: RiEqualizer2Fill, @@ -87,14 +87,14 @@ const DatasetDetailLayout: FC = (props) => { if (datasetRes?.provider !== 'external') { baseNavigation.unshift({ - name: t('common.datasetMenus.pipeline'), + name: t('datasetMenus.pipeline', { ns: 'common' }), href: `/datasets/${datasetId}/pipeline`, icon: PipelineLine as RemixiconComponentType, selectedIcon: PipelineFill as RemixiconComponentType, disabled: false, }) baseNavigation.unshift({ - 
name: t('common.datasetMenus.documents'), + name: t('datasetMenus.documents', { ns: 'common' }), href: `/datasets/${datasetId}/documents`, icon: RiFileTextLine, selectedIcon: RiFileTextFill, @@ -105,7 +105,7 @@ const DatasetDetailLayout: FC = (props) => { return baseNavigation }, [t, datasetId, isButtonDisabledWithPipeline, datasetRes?.provider]) - useDocumentTitle(datasetRes?.name || t('common.menus.datasets')) + useDocumentTitle(datasetRes?.name || t('menus.datasets', { ns: 'common' })) const setAppSidebarExpand = useStore(state => state.setAppSidebarExpand) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx index aa64df3449..1d65e4de53 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx @@ -1,15 +1,13 @@ -import * as React from 'react' +import { useTranslation } from '#i18n' import Form from '@/app/components/datasets/settings/form' -import { getLocaleOnServer, getTranslation } from '@/i18n-config/server' -const Settings = async () => { - const locale = await getLocaleOnServer() - const { t } = await getTranslation(locale, 'dataset-settings') +const Settings = () => { + const { t } = useTranslation('datasetSettings') return (
- {t('title') as any}
+ {t('title')}
{t('desc')}
diff --git a/web/app/(commonLayout)/explore/layout.tsx b/web/app/(commonLayout)/explore/layout.tsx index 5928308cdc..7d59d397f9 100644 --- a/web/app/(commonLayout)/explore/layout.tsx +++ b/web/app/(commonLayout)/explore/layout.tsx @@ -7,7 +7,7 @@ import useDocumentTitle from '@/hooks/use-document-title' const ExploreLayout: FC = ({ children }) => { const { t } = useTranslation() - useDocumentTitle(t('common.menus.explore')) + useDocumentTitle(t('menus.explore', { ns: 'common' })) return ( {children} diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx index 91bdb3f99a..a0ccde957d 100644 --- a/web/app/(commonLayout)/layout.tsx +++ b/web/app/(commonLayout)/layout.tsx @@ -1,5 +1,6 @@ import type { ReactNode } from 'react' import * as React from 'react' +import { AppInitializer } from '@/app/components/app-initializer' import AmplitudeProvider from '@/app/components/base/amplitude' import GA, { GaType } from '@/app/components/base/ga' import Zendesk from '@/app/components/base/zendesk' @@ -7,7 +8,6 @@ import GotoAnything from '@/app/components/goto-anything' import Header from '@/app/components/header' import HeaderWrapper from '@/app/components/header/header-wrapper' import ReadmePanel from '@/app/components/plugins/readme-panel' -import SwrInitializer from '@/app/components/swr-initializer' import { AppContextProvider } from '@/context/app-context' import { EventEmitterContextProvider } from '@/context/event-emitter' import { ModalContextProvider } from '@/context/modal-context' @@ -20,7 +20,7 @@ const Layout = ({ children }: { children: ReactNode }) => { <> - + @@ -38,7 +38,7 @@ const Layout = ({ children }: { children: ReactNode }) => { - + ) } diff --git a/web/app/(commonLayout)/plugins/page.tsx b/web/app/(commonLayout)/plugins/page.tsx index 2df9cf23c4..f366200cf9 100644 --- a/web/app/(commonLayout)/plugins/page.tsx +++ b/web/app/(commonLayout)/plugins/page.tsx @@ -1,14 +1,12 @@ import Marketplace from 
'@/app/components/plugins/marketplace' import PluginPage from '@/app/components/plugins/plugin-page' import PluginsPanel from '@/app/components/plugins/plugin-page/plugins-panel' -import { getLocaleOnServer } from '@/i18n-config/server' -const PluginList = async () => { - const locale = await getLocaleOnServer() +const PluginList = () => { return ( } - marketplace={} + marketplace={} /> ) } diff --git a/web/app/(commonLayout)/tools/page.tsx b/web/app/(commonLayout)/tools/page.tsx index 2d5c1a8e44..3e88050eba 100644 --- a/web/app/(commonLayout)/tools/page.tsx +++ b/web/app/(commonLayout)/tools/page.tsx @@ -12,7 +12,7 @@ const ToolsList: FC = () => { const router = useRouter() const { isCurrentWorkspaceDatasetOperator } = useAppContext() const { t } = useTranslation() - useDocumentTitle(t('common.menus.tools')) + useDocumentTitle(t('menus.tools', { ns: 'common' })) useEffect(() => { if (isCurrentWorkspaceDatasetOperator) diff --git a/web/app/(shareLayout)/components/authenticated-layout.tsx b/web/app/(shareLayout)/components/authenticated-layout.tsx index 00288b7a61..113f3b5680 100644 --- a/web/app/(shareLayout)/components/authenticated-layout.tsx +++ b/web/app/(shareLayout)/components/authenticated-layout.tsx @@ -81,7 +81,7 @@ const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => { return (
- {t('common.userProfile.logout')}
+ {t('userProfile.logout', { ns: 'common' })}
) } diff --git a/web/app/(shareLayout)/components/splash.tsx b/web/app/(shareLayout)/components/splash.tsx index b8ea1b2e56..9f89a03993 100644 --- a/web/app/(shareLayout)/components/splash.tsx +++ b/web/app/(shareLayout)/components/splash.tsx @@ -94,8 +94,8 @@ const Splash: FC = ({ children }) => { if (message) { return (
-
- {code === '403' ? t('common.userProfile.logout') : t('share.login.backToHome')}
+
+ {code === '403' ? t('userProfile.logout', { ns: 'common' }) : t('login.backToHome', { ns: 'share' })}
) } diff --git a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx index 9ce058340c..fbf45259e5 100644 --- a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx @@ -3,12 +3,12 @@ import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react' import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import Countdown from '@/app/components/signin/countdown' -import I18NContext from '@/context/i18n' + +import { useLocale } from '@/context/i18n' import { sendWebAppResetPasswordCode, verifyWebAppResetPasswordCode } from '@/service/common' export default function CheckCode() { @@ -19,21 +19,21 @@ export default function CheckCode() { const token = decodeURIComponent(searchParams.get('token') as string) const [code, setVerifyCode] = useState('') const [loading, setIsLoading] = useState(false) - const { locale } = useContext(I18NContext) + const locale = useLocale() const verify = async () => { try { if (!code.trim()) { Toast.notify({ type: 'error', - message: t('login.checkCode.emptyCode'), + message: t('checkCode.emptyCode', { ns: 'login' }), }) return } if (!/\d{6}/.test(code)) { Toast.notify({ type: 'error', - message: t('login.checkCode.invalidCode'), + message: t('checkCode.invalidCode', { ns: 'login' }), }) return } @@ -69,22 +69,22 @@ export default function CheckCode() {
- {t('login.checkCode.checkYourEmail')}
+ {t('checkCode.checkYourEmail', { ns: 'login' })}
- {t('login.checkCode.tipsPrefix')}
+ {t('checkCode.tipsPrefix', { ns: 'login' })}
{email}
- {t('login.checkCode.validTime')}
+ {t('checkCode.validTime', { ns: 'login' })}

- - setVerifyCode(e.target.value)} maxLength={6} className="mt-1" placeholder={t('login.checkCode.verificationCodePlaceholder') || ''} /> - + + setVerifyCode(e.target.value)} maxLength={6} className="mt-1" placeholder={t('checkCode.verificationCodePlaceholder', { ns: 'login' }) || ''} /> +
@@ -94,7 +94,7 @@ export default function CheckCode() {
- {t('login.back')}
+ {t('back', { ns: 'login' })}
) diff --git a/web/app/(shareLayout)/webapp-reset-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/page.tsx index 108bd4b22e..9b9a853cdd 100644 --- a/web/app/(shareLayout)/webapp-reset-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/page.tsx @@ -1,17 +1,17 @@ 'use client' import { RiArrowLeftLine, RiLockPasswordLine } from '@remixicon/react' -import { noop } from 'es-toolkit/compat' +import { noop } from 'es-toolkit/function' import Link from 'next/link' import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown' import { emailRegex } from '@/config' -import I18NContext from '@/context/i18n' + +import { useLocale } from '@/context/i18n' import useDocumentTitle from '@/hooks/use-document-title' import { sendResetPasswordCode } from '@/service/common' @@ -22,19 +22,19 @@ export default function CheckCode() { const router = useRouter() const [email, setEmail] = useState('') const [loading, setIsLoading] = useState(false) - const { locale } = useContext(I18NContext) + const locale = useLocale() const handleGetEMailVerificationCode = async () => { try { if (!email) { - Toast.notify({ type: 'error', message: t('login.error.emailEmpty') }) + Toast.notify({ type: 'error', message: t('error.emailEmpty', { ns: 'login' }) }) return } if (!emailRegex.test(email)) { Toast.notify({ type: 'error', - message: t('login.error.emailInValid'), + message: t('error.emailInValid', { ns: 'login' }), }) return } @@ -50,7 +50,7 @@ export default function CheckCode() { else if (res.code === 'account_not_found') { Toast.notify({ type: 'error', - message: t('login.error.registrationNotAllowed'), + 
message: t('error.registrationNotAllowed', { ns: 'login' }), }) } else { @@ -74,21 +74,21 @@ export default function CheckCode() {
-

{t('login.resetPassword')}

+

{t('resetPassword', { ns: 'login' })}

- {t('login.resetPasswordDesc')} + {t('resetPasswordDesc', { ns: 'login' })}

- +
- setEmail(e.target.value)} /> + setEmail(e.target.value)} />
- +
@@ -99,7 +99,7 @@ export default function CheckCode() {
- {t('login.backToLogin')} + {t('backToLogin', { ns: 'login' })}
) diff --git a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx index bfb71d9c6f..9f59e8f9eb 100644 --- a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx @@ -45,15 +45,15 @@ const ChangePasswordForm = () => { const valid = useCallback(() => { if (!password.trim()) { - showErrorMessage(t('login.error.passwordEmpty')) + showErrorMessage(t('error.passwordEmpty', { ns: 'login' })) return false } if (!validPassword.test(password)) { - showErrorMessage(t('login.error.passwordInvalid')) + showErrorMessage(t('error.passwordInvalid', { ns: 'login' })) return false } if (password !== confirmPassword) { - showErrorMessage(t('common.account.notEqual')) + showErrorMessage(t('account.notEqual', { ns: 'common' })) return false } return true @@ -92,10 +92,10 @@ const ChangePasswordForm = () => {

- {t('login.changePassword')} + {t('changePassword', { ns: 'login' })}

- {t('login.changePasswordTip')} + {t('changePasswordTip', { ns: 'login' })}

@@ -104,7 +104,7 @@ const ChangePasswordForm = () => { {/* Password */}
{ type={showPassword ? 'text' : 'password'} value={password} onChange={e => setPassword(e.target.value)} - placeholder={t('login.passwordPlaceholder') || ''} + placeholder={t('passwordPlaceholder', { ns: 'login' }) || ''} />
@@ -125,12 +125,12 @@ const ChangePasswordForm = () => {
-
{t('login.error.passwordInvalid')}
+
{t('error.passwordInvalid', { ns: 'login' })}
{/* Confirm Password */}
{ type={showConfirmPassword ? 'text' : 'password'} value={confirmPassword} onChange={e => setConfirmPassword(e.target.value)} - placeholder={t('login.confirmPasswordPlaceholder') || ''} + placeholder={t('confirmPasswordPlaceholder', { ns: 'login' }) || ''} />
@@ -171,7 +171,7 @@ const ChangePasswordForm = () => {

- {t('login.passwordChangedTip')} + {t('passwordChangedTip', { ns: 'login' })}

@@ -183,7 +183,7 @@ const ChangePasswordForm = () => { router.replace(getSignInUrl()) }} > - {t('login.passwordChanged')} + {t('passwordChanged', { ns: 'login' })} {' '} ( {Math.round(countdown / 1000)} diff --git a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx index ee7fc22bea..afea9d668b 100644 --- a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx @@ -4,16 +4,16 @@ import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react' import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import Countdown from '@/app/components/signin/countdown' -import I18NContext from '@/context/i18n' +import { useLocale } from '@/context/i18n' import { useWebAppStore } from '@/context/web-app-context' import { sendWebAppEMailLoginCode, webAppEmailLoginWithCode } from '@/service/common' import { fetchAccessToken } from '@/service/share' import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth' +import { encryptVerificationCode } from '@/utils/encryption' export default function CheckCode() { const { t } = useTranslation() @@ -23,7 +23,7 @@ export default function CheckCode() { const token = decodeURIComponent(searchParams.get('token') as string) const [code, setVerifyCode] = useState('') const [loading, setIsLoading] = useState(false) - const { locale } = useContext(I18NContext) + const locale = useLocale() const codeInputRef = useRef(null) const redirectUrl = searchParams.get('redirect_url') const embeddedUserId = useWebAppStore(s => s.embeddedUserId) @@ -45,28 +45,30 @@ export default function CheckCode() { if 
(!code.trim()) { Toast.notify({ type: 'error', - message: t('login.checkCode.emptyCode'), + message: t('checkCode.emptyCode', { ns: 'login' }), }) return } if (!/\d{6}/.test(code)) { Toast.notify({ type: 'error', - message: t('login.checkCode.invalidCode'), + message: t('checkCode.invalidCode', { ns: 'login' }), }) return } if (!redirectUrl || !appCode) { Toast.notify({ type: 'error', - message: t('login.error.redirectUrlMissing'), + message: t('error.redirectUrlMissing', { ns: 'login' }), }) return } setIsLoading(true) - const ret = await webAppEmailLoginWithCode({ email, code, token }) + const ret = await webAppEmailLoginWithCode({ email, code: encryptVerificationCode(code), token }) if (ret.result === 'success') { - setWebAppAccessToken(ret.data.access_token) + if (ret?.data?.access_token) { + setWebAppAccessToken(ret.data.access_token) + } const { access_token } = await fetchAccessToken({ appCode: appCode!, userId: embeddedUserId || undefined, @@ -108,19 +110,19 @@ export default function CheckCode() {
-

{t('login.checkCode.checkYourEmail')}

+

{t('checkCode.checkYourEmail', { ns: 'login' })}

- {t('login.checkCode.tipsPrefix')} + {t('checkCode.tipsPrefix', { ns: 'login' })} {email}
- {t('login.checkCode.validTime')} + {t('checkCode.validTime', { ns: 'login' })}

- + setVerifyCode(e.target.value)} maxLength={6} className="mt-1" - placeholder={t('login.checkCode.verificationCodePlaceholder') || ''} + placeholder={t('checkCode.verificationCodePlaceholder', { ns: 'login' }) || ''} /> - +
@@ -140,7 +142,7 @@ export default function CheckCode() {
- {t('login.back')} + {t('back', { ns: 'login' })}
) diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx index 8b611b9eea..5aa9d9f141 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx @@ -1,14 +1,13 @@ -import { noop } from 'es-toolkit/compat' +import { noop } from 'es-toolkit/function' import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown' import { emailRegex } from '@/config' -import I18NContext from '@/context/i18n' +import { useLocale } from '@/context/i18n' import { sendWebAppEMailLoginCode } from '@/service/common' export default function MailAndCodeAuth() { @@ -18,19 +17,19 @@ export default function MailAndCodeAuth() { const emailFromLink = decodeURIComponent(searchParams.get('email') || '') const [email, setEmail] = useState(emailFromLink) const [loading, setIsLoading] = useState(false) - const { locale } = useContext(I18NContext) + const locale = useLocale() const handleGetEMailVerificationCode = async () => { try { if (!email) { - Toast.notify({ type: 'error', message: t('login.error.emailEmpty') }) + Toast.notify({ type: 'error', message: t('error.emailEmpty', { ns: 'login' }) }) return } if (!emailRegex.test(email)) { Toast.notify({ type: 'error', - message: t('login.error.emailInValid'), + message: t('error.emailInValid', { ns: 'login' }), }) return } @@ -56,12 +55,12 @@ export default function MailAndCodeAuth() {
- +
- setEmail(e.target.value)} /> + setEmail(e.target.value)} />
- +
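The change repeated throughout this diff splits a dotted key such as `login.error.emailEmpty` into a bare key plus an explicit `ns` option. A minimal sketch of that transformation — the helper name is hypothetical, not part of the codebase:

```typescript
// Hypothetical codemod helper: turn the argument of t('login.error.emailEmpty')
// into the pair used by t('error.emailEmpty', { ns: 'login' }) by peeling off
// the first dot-separated segment as the namespace.
function toNamespacedKey(fullKey: string): { key: string; ns: string } {
  const dot = fullKey.indexOf('.')
  if (dot === -1)
    throw new Error(`key "${fullKey}" has no namespace prefix`)
  return { ns: fullKey.slice(0, dot), key: fullKey.slice(dot + 1) }
}
```

Note that only the first segment becomes the namespace: `common.account.notEqual` maps to key `account.notEqual` under `ns: 'common'`, matching the replacements above.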
diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx index 46645ed68c..e49559401d 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx @@ -1,19 +1,19 @@ 'use client' -import { noop } from 'es-toolkit/compat' +import { noop } from 'es-toolkit/function' import Link from 'next/link' import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import { emailRegex } from '@/config' -import I18NContext from '@/context/i18n' +import { useLocale } from '@/context/i18n' import { useWebAppStore } from '@/context/web-app-context' import { webAppLogin } from '@/service/common' import { fetchAccessToken } from '@/service/share' import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth' +import { encryptPassword } from '@/utils/encryption' type MailAndPasswordAuthProps = { isEmailSetup: boolean @@ -21,7 +21,7 @@ type MailAndPasswordAuthProps = { export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAuthProps) { const { t } = useTranslation() - const { locale } = useContext(I18NContext) + const locale = useLocale() const router = useRouter() const searchParams = useSearchParams() const [showPassword, setShowPassword] = useState(false) @@ -46,25 +46,25 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut const appCode = getAppCodeFromRedirectUrl() const handleEmailPasswordLogin = async () => { if (!email) { - Toast.notify({ type: 'error', message: t('login.error.emailEmpty') }) + Toast.notify({ type: 
'error', message: t('error.emailEmpty', { ns: 'login' }) }) return } if (!emailRegex.test(email)) { Toast.notify({ type: 'error', - message: t('login.error.emailInValid'), + message: t('error.emailInValid', { ns: 'login' }), }) return } if (!password?.trim()) { - Toast.notify({ type: 'error', message: t('login.error.passwordEmpty') }) + Toast.notify({ type: 'error', message: t('error.passwordEmpty', { ns: 'login' }) }) return } if (!redirectUrl || !appCode) { Toast.notify({ type: 'error', - message: t('login.error.redirectUrlMissing'), + message: t('error.redirectUrlMissing', { ns: 'login' }), }) return } @@ -72,7 +72,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut setIsLoading(true) const loginData: Record = { email, - password, + password: encryptPassword(password), language: locale, remember_me: true, } @@ -82,7 +82,9 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut body: loginData, }) if (res.result === 'success') { - setWebAppAccessToken(res.data.access_token) + if (res?.data?.access_token) { + setWebAppAccessToken(res.data.access_token) + } const { access_token } = await fetchAccessToken({ appCode: appCode!, @@ -111,7 +113,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
@@ -128,14 +130,14 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
@@ -149,7 +151,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut }} type={showPassword ? 'text' : 'password'} autoComplete="current-password" - placeholder={t('login.passwordPlaceholder') || ''} + placeholder={t('passwordPlaceholder', { ns: 'login' }) || ''} tabIndex={2} />
@@ -172,7 +174,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut disabled={isLoading || !email || !password} className="w-full" > - {t('login.signBtn')} + {t('signBtn', { ns: 'login' })}
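Both login flows in this diff stop assuming `res.data.access_token` exists and guard before storing it. A framework-free sketch of that guard — the response type here is illustrative, the real one lives in the service layer:

```typescript
// Illustrative shape of the login response; the actual type is defined by the
// web/service API, not here.
type LoginResponse = { result: string; data?: { access_token?: string } }

// Store the token only when the backend actually returned one, mirroring the
// `if (res?.data?.access_token)` guard added in the diff.
function storeAccessTokenIfPresent(
  res: LoginResponse | undefined,
  store: (token: string) => void,
): boolean {
  if (res?.data?.access_token) {
    store(res.data.access_token)
    return true
  }
  return false
}
```

The point of the guard is that a `result: 'success'` response without a token (e.g. when the session is carried in a cookie) no longer crashes on a property access.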
diff --git a/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx index 472952c2c8..d8f3854868 100644 --- a/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx @@ -82,7 +82,7 @@ const SSOAuth: FC = ({ className="w-full" > - {t('login.withSSO')} + {t('withSSO', { ns: 'login' })} ) } diff --git a/web/app/(shareLayout)/webapp-signin/layout.tsx b/web/app/(shareLayout)/webapp-signin/layout.tsx index dd4510a541..21cb0e1f57 100644 --- a/web/app/(shareLayout)/webapp-signin/layout.tsx +++ b/web/app/(shareLayout)/webapp-signin/layout.tsx @@ -9,7 +9,7 @@ import { cn } from '@/utils/classnames' export default function SignInLayout({ children }: PropsWithChildren) { const { t } = useTranslation() const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) - useDocumentTitle(t('login.webapp.login')) + useDocumentTitle(t('webapp.login', { ns: 'login' })) return ( <>
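These files also swap `useContext(I18NContext)` for a dedicated `useLocale()` hook. The hook's internals are not shown in this diff; a subscribable store of the kind such a hook typically wraps (e.g. via `useSyncExternalStore`) might look like the following — entirely a sketch, with assumed names and shape:

```typescript
type Listener = () => void

// Minimal locale store a useLocale() hook could read from. The subscribe /
// snapshot split lets components re-render only when the locale changes,
// which is the usual motivation for replacing a broad context read.
function createLocaleStore(initial: string) {
  let locale = initial
  const listeners = new Set<Listener>()
  return {
    getLocale: () => locale,
    setLocale: (next: string) => {
      locale = next
      listeners.forEach(fn => fn())
    },
    subscribe: (fn: Listener) => {
      listeners.add(fn)
      return () => listeners.delete(fn)
    },
  }
}
```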
diff --git a/web/app/(shareLayout)/webapp-signin/normalForm.tsx b/web/app/(shareLayout)/webapp-signin/normalForm.tsx index 40d34dcaf5..b15145346f 100644 --- a/web/app/(shareLayout)/webapp-signin/normalForm.tsx +++ b/web/app/(shareLayout)/webapp-signin/normalForm.tsx @@ -60,8 +60,8 @@ const NormalForm = () => {
-

{t('login.licenseLost')}

-

{t('login.licenseLostTip')}

+

{t('licenseLost', { ns: 'login' })}

+

{t('licenseLostTip', { ns: 'login' })}

@@ -76,8 +76,8 @@ const NormalForm = () => {
-

{t('login.licenseExpired')}

-

{t('login.licenseExpiredTip')}

+

{t('licenseExpired', { ns: 'login' })}

+

{t('licenseExpiredTip', { ns: 'login' })}

@@ -92,8 +92,8 @@ const NormalForm = () => { -

{t('login.licenseInactive')}

-

{t('login.licenseInactiveTip')}

+

{t('licenseInactive', { ns: 'login' })}

+

{t('licenseInactiveTip', { ns: 'login' })}

@@ -104,8 +104,8 @@ const NormalForm = () => { <>
-

{systemFeatures.branding.enabled ? t('login.pageTitleForE') : t('login.pageTitle')}

-

{t('login.welcome')}

+

{systemFeatures.branding.enabled ? t('pageTitleForE', { ns: 'login' }) : t('pageTitle', { ns: 'login' })}

+

{t('welcome', { ns: 'login' })}

@@ -122,7 +122,7 @@ const NormalForm = () => {
- {t('login.or')} + {t('or', { ns: 'login' })}
)} @@ -134,7 +134,7 @@ const NormalForm = () => { {systemFeatures.enable_email_password_login && (
{ updateAuthType('password') }}> - {t('login.usePassword')} + {t('usePassword', { ns: 'login' })}
)} @@ -144,7 +144,7 @@ const NormalForm = () => { {systemFeatures.enable_email_code_login && (
{ updateAuthType('code') }}> - {t('login.useVerificationCode')} + {t('useVerificationCode', { ns: 'login' })}
)} @@ -158,8 +158,8 @@ const NormalForm = () => {
-

{t('login.noLoginMethod')}

-

{t('login.noLoginMethodTip')}

+

{t('noLoginMethod', { ns: 'login' })}

+

{t('noLoginMethodTip', { ns: 'login' })}

{step === STEP.start && ( <> -
{t('common.account.changeEmail.title')}
+
{t('account.changeEmail.title', { ns: 'common' })}
-
{t('common.account.changeEmail.authTip')}
+
{t('account.changeEmail.authTip', { ns: 'common' })}
}} values={{ email }} /> @@ -227,34 +228,35 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { variant="primary" onClick={sendCodeToOriginEmail} > - {t('common.account.changeEmail.sendVerifyCode')} + {t('account.changeEmail.sendVerifyCode', { ns: 'common' })}
)} {step === STEP.verifyOrigin && ( <> -
{t('common.account.changeEmail.verifyEmail')}
+
{t('account.changeEmail.verifyEmail', { ns: 'common' })}
}} values={{ email }} />
-
{t('common.account.changeEmail.codeLabel')}
+
{t('account.changeEmail.codeLabel', { ns: 'common' })}
setCode(e.target.value)} maxLength={6} @@ -267,46 +269,46 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { variant="primary" onClick={handleVerifyOriginEmail} > - {t('common.account.changeEmail.continue')} + {t('account.changeEmail.continue', { ns: 'common' })}
- {t('common.account.changeEmail.resendTip')} + {t('account.changeEmail.resendTip', { ns: 'common' })} {time > 0 && ( - {t('common.account.changeEmail.resendCount', { count: time })} + {t('account.changeEmail.resendCount', { ns: 'common', count: time })} )} {!time && ( - {t('common.account.changeEmail.resend')} + {t('account.changeEmail.resend', { ns: 'common' })} )}
)} {step === STEP.newEmail && ( <> -
{t('common.account.changeEmail.newEmail')}
+
{t('account.changeEmail.newEmail', { ns: 'common' })}
-
{t('common.account.changeEmail.content3')}
+
{t('account.changeEmail.content3', { ns: 'common' })}
-
{t('common.account.changeEmail.emailLabel')}
+
{t('account.changeEmail.emailLabel', { ns: 'common' })}
handleNewEmailValueChange(e.target.value)} destructive={newEmailExited || unAvailableEmail} /> {newEmailExited && ( -
{t('common.account.changeEmail.existingEmail')}
+
{t('account.changeEmail.existingEmail', { ns: 'common' })}
)} {unAvailableEmail && ( -
{t('common.account.changeEmail.unAvailableEmail')}
+
{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}
)}
@@ -316,34 +318,35 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { variant="primary" onClick={sendCodeToNewEmail} > - {t('common.account.changeEmail.sendVerifyCode')} + {t('account.changeEmail.sendVerifyCode', { ns: 'common' })}
)} {step === STEP.verifyNew && ( <> -
{t('common.account.changeEmail.verifyNew')}
+
{t('account.changeEmail.verifyNew', { ns: 'common' })}
}} values={{ email: mail }} />
-
{t('common.account.changeEmail.codeLabel')}
+
{t('account.changeEmail.codeLabel', { ns: 'common' })}
setCode(e.target.value)} maxLength={6} @@ -356,22 +359,22 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { variant="primary" onClick={submitNewEmail} > - {t('common.account.changeEmail.changeTo', { email: mail })} + {t('account.changeEmail.changeTo', { ns: 'common', email: mail })}
- {t('common.account.changeEmail.resendTip')} + {t('account.changeEmail.resendTip', { ns: 'common' })} {time > 0 && ( - {t('common.account.changeEmail.resendCount', { count: time })} + {t('account.changeEmail.resendCount', { ns: 'common', count: time })} )} {!time && ( - {t('common.account.changeEmail.resend')} + {t('account.changeEmail.resend', { ns: 'common' })} )}
diff --git a/web/app/account/(commonLayout)/account-page/index.tsx b/web/app/account/(commonLayout)/account-page/index.tsx index baa9759ffe..f01efc002c 100644 --- a/web/app/account/(commonLayout)/account-page/index.tsx +++ b/web/app/account/(commonLayout)/account-page/index.tsx @@ -61,7 +61,7 @@ export default function AccountPage() { try { setEditing(true) await updateUserProfile({ url: 'account/name', body: { name: editName } }) - notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) mutateUserProfile() setEditNameModalVisible(false) setEditing(false) @@ -80,15 +80,15 @@ export default function AccountPage() { } const valid = () => { if (!password.trim()) { - showErrorMessage(t('login.error.passwordEmpty')) + showErrorMessage(t('error.passwordEmpty', { ns: 'login' })) return false } if (!validPassword.test(password)) { - showErrorMessage(t('login.error.passwordInvalid')) + showErrorMessage(t('error.passwordInvalid', { ns: 'login' })) return false } if (password !== confirmPassword) { - showErrorMessage(t('common.account.notEqual')) + showErrorMessage(t('account.notEqual', { ns: 'common' })) return false } @@ -112,7 +112,7 @@ export default function AccountPage() { repeat_new_password: confirmPassword, }, }) - notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) mutateUserProfile() setEditPasswordModalVisible(false) resetPasswordForm() @@ -146,7 +146,7 @@ export default function AccountPage() { return ( <>
-

{t('common.account.myAccount')}

+

{t('account.myAccount', { ns: 'common' })}

@@ -164,25 +164,25 @@ export default function AccountPage() {
-
{t('common.account.name')}
+
{t('account.name', { ns: 'common' })}
{userProfile.name}
- {t('common.operation.edit')} + {t('operation.edit', { ns: 'common' })}
-
{t('common.account.email')}
+
{t('account.email', { ns: 'common' })}
{userProfile.email}
{systemFeatures.enable_change_email && (
setShowUpdateEmail(true)}> - {t('common.operation.change')} + {t('operation.change', { ns: 'common' })}
)}
@@ -191,26 +191,26 @@ export default function AccountPage() { systemFeatures.enable_email_password_login && (
-
{t('common.account.password')}
-
{t('common.account.passwordTip')}
+
{t('account.password', { ns: 'common' })}
+
{t('account.passwordTip', { ns: 'common' })}
- +
) }
-
{t('common.account.langGeniusAccount')}
-
{t('common.account.langGeniusAccountTip')}
+
{t('account.langGeniusAccount', { ns: 'common' })}
+
{t('account.langGeniusAccountTip', { ns: 'common' })}
{!!apps.length && ( ({ ...app, key: app.id, name: app.name }))} renderItem={renderAppItem} wrapperClassName="mt-2" /> )} - {!IS_CE_EDITION && } + {!IS_CE_EDITION && }
{ editNameModalVisible && ( @@ -219,21 +219,21 @@ export default function AccountPage() { onClose={() => setEditNameModalVisible(false)} className="!w-[420px] !p-6" > -
{t('common.account.editName')}
-
{t('common.account.name')}
+
{t('account.editName', { ns: 'common' })}
+
{t('account.name', { ns: 'common' })}
setEditName(e.target.value)} />
- +
@@ -249,10 +249,10 @@ export default function AccountPage() { }} className="!w-[420px] !p-6" > -
{userProfile.is_password_set ? t('common.account.resetPassword') : t('common.account.setPassword')}
+
{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}
{userProfile.is_password_set && ( <> -
{t('common.account.currentPassword')}
+
{t('account.currentPassword', { ns: 'common' })}
)}
- {userProfile.is_password_set ? t('common.account.newPassword') : t('common.account.password')} + {userProfile.is_password_set ? t('account.newPassword', { ns: 'common' }) : t('account.password', { ns: 'common' })}
-
{t('common.account.confirmPassword')}
+
{t('account.confirmPassword', { ns: 'common' })}
- {t('common.operation.cancel')} + {t('operation.cancel', { ns: 'common' })}
diff --git a/web/app/account/(commonLayout)/avatar.tsx b/web/app/account/(commonLayout)/avatar.tsx index 3cbbb47a76..8ea29e8e45 100644 --- a/web/app/account/(commonLayout)/avatar.tsx +++ b/web/app/account/(commonLayout)/avatar.tsx @@ -94,7 +94,7 @@ export default function AppSelector() { className="group flex h-9 cursor-pointer items-center justify-start rounded-lg px-3 hover:bg-state-base-hover" > -
{t('common.userProfile.logout')}
+
{t('userProfile.logout', { ns: 'common' })}
diff --git a/web/app/account/(commonLayout)/delete-account/components/check-email.tsx b/web/app/account/(commonLayout)/delete-account/components/check-email.tsx index c0bbf4422a..65a58c936e 100644 --- a/web/app/account/(commonLayout)/delete-account/components/check-email.tsx +++ b/web/app/account/(commonLayout)/delete-account/components/check-email.tsx @@ -31,22 +31,22 @@ export default function CheckEmail(props: DeleteAccountProps) { return ( <>
- {t('common.account.deleteTip')} + {t('account.deleteTip', { ns: 'common' })}
- {t('common.account.deletePrivacyLinkTip')} - {t('common.account.deletePrivacyLink')} + {t('account.deletePrivacyLinkTip', { ns: 'common' })} + {t('account.deletePrivacyLink', { ns: 'common' })}
- + { setUserInputEmail(e.target.value) }} />
- - + +
) diff --git a/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx b/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx index 4a1b41cb20..67fea3c141 100644 --- a/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx +++ b/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx @@ -28,7 +28,7 @@ export default function FeedBack(props: DeleteAccountProps) { await logout() // Tokens are now stored in cookies and cleared by backend router.push('/signin') - Toast.notify({ type: 'info', message: t('common.account.deleteSuccessTip') }) + Toast.notify({ type: 'info', message: t('account.deleteSuccessTip', { ns: 'common' }) }) } catch (error) { console.error(error) } }, [router, t]) @@ -50,22 +50,22 @@ export default function FeedBack(props: DeleteAccountProps) { - + @@ -116,10 +116,10 @@ const MCPServerModal = ({ {latestParams.length > 0 && (
-
{t('tools.mcp.server.modal.parameters')}
+
{t('mcp.server.modal.parameters', { ns: 'tools' })}
-
{t('tools.mcp.server.modal.parametersTip')}
+
{t('mcp.server.modal.parametersTip', { ns: 'tools' })}
{latestParams.map(paramItem => (
- - + +
) diff --git a/web/app/components/tools/mcp/mcp-server-param-item.spec.tsx b/web/app/components/tools/mcp/mcp-server-param-item.spec.tsx new file mode 100644 index 0000000000..6e3a48e330 --- /dev/null +++ b/web/app/components/tools/mcp/mcp-server-param-item.spec.tsx @@ -0,0 +1,165 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import MCPServerParamItem from './mcp-server-param-item' + +describe('MCPServerParamItem', () => { + const defaultProps = { + data: { + label: 'Test Label', + variable: 'test_variable', + type: 'string', + }, + value: '', + onChange: vi.fn(), + } + + describe('Rendering', () => { + it('should render without crashing', () => { + render() + expect(screen.getByText('Test Label')).toBeInTheDocument() + }) + + it('should display label', () => { + render() + expect(screen.getByText('Test Label')).toBeInTheDocument() + }) + + it('should display variable name', () => { + render() + expect(screen.getByText('test_variable')).toBeInTheDocument() + }) + + it('should display type', () => { + render() + expect(screen.getByText('string')).toBeInTheDocument() + }) + + it('should display separator dot', () => { + render() + expect(screen.getByText('·')).toBeInTheDocument() + }) + + it('should render textarea with placeholder', () => { + render() + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.parametersPlaceholder') + expect(textarea).toBeInTheDocument() + }) + }) + + describe('Value Display', () => { + it('should display empty value by default', () => { + render() + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.parametersPlaceholder') + expect(textarea).toHaveValue('') + }) + + it('should display provided value', () => { + render() + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.parametersPlaceholder') + expect(textarea).toHaveValue('test value') + }) + + it('should display long text value', () => { + const 
longValue = 'This is a very long text value that might span multiple lines' + render() + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.parametersPlaceholder') + expect(textarea).toHaveValue(longValue) + }) + }) + + describe('User Interactions', () => { + it('should call onChange when text is entered', () => { + const onChange = vi.fn() + render() + + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.parametersPlaceholder') + fireEvent.change(textarea, { target: { value: 'new value' } }) + + expect(onChange).toHaveBeenCalledWith('new value') + }) + + it('should call onChange with empty string when cleared', () => { + const onChange = vi.fn() + render() + + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.parametersPlaceholder') + fireEvent.change(textarea, { target: { value: '' } }) + + expect(onChange).toHaveBeenCalledWith('') + }) + + it('should handle multiple changes', () => { + const onChange = vi.fn() + render() + + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.parametersPlaceholder') + + fireEvent.change(textarea, { target: { value: 'first' } }) + fireEvent.change(textarea, { target: { value: 'second' } }) + fireEvent.change(textarea, { target: { value: 'third' } }) + + expect(onChange).toHaveBeenCalledTimes(3) + expect(onChange).toHaveBeenLastCalledWith('third') + }) + }) + + describe('Different Data Types', () => { + it('should display number type', () => { + const props = { + ...defaultProps, + data: { label: 'Count', variable: 'count', type: 'number' }, + } + render() + expect(screen.getByText('number')).toBeInTheDocument() + }) + + it('should display boolean type', () => { + const props = { + ...defaultProps, + data: { label: 'Enabled', variable: 'enabled', type: 'boolean' }, + } + render() + expect(screen.getByText('boolean')).toBeInTheDocument() + }) + + it('should display array type', () => { + const props = { + ...defaultProps, + data: { label: 'Items', variable: 
'items', type: 'array' }, + } + render() + expect(screen.getByText('array')).toBeInTheDocument() + }) + }) + + describe('Edge Cases', () => { + it('should handle special characters in label', () => { + const props = { + ...defaultProps, + data: { label: 'Test
)}
+        {showMCPServerModal && (
+        )}
-        {/* button copy link/ button regenerate */}
+        {showConfirmDelete && (
-            onConfirm={() => {
-              onGenCode()
-              setShowConfirmDelete(false)
-            }}
-            onCancel={() => setShowConfirmDelete(false)}
+            onConfirm={onConfirmRegenerate}
+            onCancel={closeConfirmDelete}
          />
        )}
diff --git a/web/app/components/tools/mcp/modal.spec.tsx b/web/app/components/tools/mcp/modal.spec.tsx
new file mode 100644
index 0000000000..c2fe8b46c3
--- /dev/null
+++ b/web/app/components/tools/mcp/modal.spec.tsx
@@ -0,0 +1,745 @@
+import type { ReactNode } from 'react'
+import type { ToolWithProvider } from '@/app/components/workflow/types'
+import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
+import { fireEvent, render, screen, waitFor } from '@testing-library/react'
+import * as React from 'react'
+import { describe, expect, it, vi } from 'vitest'
+import MCPModal from './modal'
+
+// Mock the service API
+vi.mock('@/service/common', () => ({
+  uploadRemoteFileInfo: vi.fn().mockResolvedValue({ url: 'https://example.com/icon.png' }),
+}))
+
+// Mock the AppIconPicker component
+type IconPayload = {
+  type: string
+  icon: string
+  background: string
+}
+
+type AppIconPickerProps = {
+  onSelect: (payload: IconPayload) => void
+  onClose: () => void
+}
+
+vi.mock('@/app/components/base/app-icon-picker', () => ({
+  default: ({ onSelect, onClose }: AppIconPickerProps) => (
+    <div data-testid="app-icon-picker">
+      <button data-testid="select-emoji-btn" onClick={() => onSelect({ type: 'emoji', icon: '🔗', background: '#6366F1' })}>select</button>
+      <button data-testid="close-picker-btn" onClick={onClose}>close</button>
+    </div>
+  ),
+}))
+
+// Mock the plugins service to avoid React Query issues from TabSlider
+vi.mock('@/service/use-plugins', () => ({
+  useInstalledPluginList: () => ({
+    data: { pages: [] },
+    hasNextPage: false,
+    isFetchingNextPage: false,
+    fetchNextPage: vi.fn(),
+    isLoading: false,
+    isSuccess: true,
+  }),
+}))
+
+describe('MCPModal', () => {
+  const createWrapper = () => {
+    const queryClient = new QueryClient({
+      defaultOptions: {
+        queries: {
+          retry: false,
+        },
+      },
+    })
+    return ({ children }: { children: ReactNode }) =>
+      React.createElement(QueryClientProvider, { client: queryClient }, children)
+  }
+
+  const defaultProps = {
+    show: true,
+    onConfirm: vi.fn(),
+    onHide: vi.fn(),
+  }
+
+  describe('Rendering', () => {
+    it('should render without crashing', () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+      expect(screen.getByText('tools.mcp.modal.title')).toBeInTheDocument()
+    })
+
+    it('should not render when show is false', () => {
+      render(<MCPModal {...defaultProps} show={false} />, { wrapper: createWrapper() })
+      expect(screen.queryByText('tools.mcp.modal.title')).not.toBeInTheDocument()
+    })
+
+    it('should render create title when no data is provided', () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+      expect(screen.getByText('tools.mcp.modal.title')).toBeInTheDocument()
+    })
+
+    it('should render edit title when data is provided', () => {
+      const mockData = {
+        id: 'test-id',
+        name: 'Test Server',
+        server_url: 'https://example.com/mcp',
+        server_identifier: 'test-server',
+        icon: { content: '🔗', background: '#6366F1' },
+      } as unknown as ToolWithProvider
+
+      render(<MCPModal {...defaultProps} data={mockData} />, { wrapper: createWrapper() })
+      expect(screen.getByText('tools.mcp.modal.editTitle')).toBeInTheDocument()
+    })
+  })
+
+  describe('Form Fields', () => {
+    it('should render server URL input', () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+      expect(screen.getByText('tools.mcp.modal.serverUrl')).toBeInTheDocument()
+    })
+
+    it('should render name input', () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+      expect(screen.getByText('tools.mcp.modal.name')).toBeInTheDocument()
+    })
+
+    it('should render server identifier input', () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+      expect(screen.getByText('tools.mcp.modal.serverIdentifier')).toBeInTheDocument()
+    })
+
+    it('should render auth method tabs', () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+      expect(screen.getByText('tools.mcp.modal.authentication')).toBeInTheDocument()
+      expect(screen.getByText('tools.mcp.modal.headers')).toBeInTheDocument()
+      expect(screen.getByText('tools.mcp.modal.configurations')).toBeInTheDocument()
+    })
+  })
+
+  describe('Form Interactions', () => {
+    it('should update URL input value', () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+
+      const urlInput = screen.getByPlaceholderText('tools.mcp.modal.serverUrlPlaceholder')
+      fireEvent.change(urlInput, { target: { value: 'https://test.com/mcp' } })
+
+      expect(urlInput).toHaveValue('https://test.com/mcp')
+    })
+
+    it('should update name input value', () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+
+      const nameInput = screen.getByPlaceholderText('tools.mcp.modal.namePlaceholder')
+      fireEvent.change(nameInput, { target: { value: 'My Server' } })
+
+      expect(nameInput).toHaveValue('My Server')
+    })
+
+    it('should update server identifier input value', () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+
+      const identifierInput = screen.getByPlaceholderText('tools.mcp.modal.serverIdentifierPlaceholder')
+      fireEvent.change(identifierInput, { target: { value: 'my-server' } })
+
+      expect(identifierInput).toHaveValue('my-server')
+    })
+  })
+
+  describe('Tab Navigation', () => {
+    it('should show authentication section by default', () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+      expect(screen.getByText('tools.mcp.modal.useDynamicClientRegistration')).toBeInTheDocument()
+    })
+
+    it('should switch to headers section when clicked', async () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+
+      const headersTab = screen.getByText('tools.mcp.modal.headers')
+      fireEvent.click(headersTab)
+
+      await waitFor(() => {
+        expect(screen.getByText('tools.mcp.modal.headersTip')).toBeInTheDocument()
+      })
+    })
+
+    it('should switch to configurations section when clicked', async () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+
+      const configTab = screen.getByText('tools.mcp.modal.configurations')
+      fireEvent.click(configTab)
+
+      await waitFor(() => {
+        expect(screen.getByText('tools.mcp.modal.timeout')).toBeInTheDocument()
+        expect(screen.getByText('tools.mcp.modal.sseReadTimeout')).toBeInTheDocument()
+      })
+    })
+  })
+
+  describe('Action Buttons', () => {
+    it('should render confirm button', () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+      expect(screen.getByText('tools.mcp.modal.confirm')).toBeInTheDocument()
+    })
+
+    it('should render save button in edit mode', () => {
+      const mockData = {
+        id: 'test-id',
+        name: 'Test',
+        icon: { content: '🔗', background: '#6366F1' },
+      } as unknown as ToolWithProvider
+
+      render(<MCPModal {...defaultProps} data={mockData} />, { wrapper: createWrapper() })
+      expect(screen.getByText('tools.mcp.modal.save')).toBeInTheDocument()
+    })
+
+    it('should render cancel button', () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+      expect(screen.getByText('tools.mcp.modal.cancel')).toBeInTheDocument()
+    })
+
+    it('should call onHide when cancel is clicked', () => {
+      const onHide = vi.fn()
+      render(<MCPModal {...defaultProps} onHide={onHide} />, { wrapper: createWrapper() })
+
+      const cancelButton = screen.getByText('tools.mcp.modal.cancel')
+      fireEvent.click(cancelButton)
+
+      expect(onHide).toHaveBeenCalledTimes(1)
+    })
+
+    it('should call onHide when close icon is clicked', () => {
+      const onHide = vi.fn()
+      render(<MCPModal {...defaultProps} onHide={onHide} />, { wrapper: createWrapper() })
+
+      // Find the close button by its parent div with cursor-pointer class
+      const closeButtons = document.querySelectorAll('.cursor-pointer')
+      const closeButton = Array.from(closeButtons).find(el =>
+        el.querySelector('svg'),
+      )
+
+      if (closeButton) {
+        fireEvent.click(closeButton)
+        expect(onHide).toHaveBeenCalled()
+      }
+    })
+
+    it('should have confirm button disabled when form is empty', () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+
+      const confirmButton = screen.getByText('tools.mcp.modal.confirm')
+      expect(confirmButton).toBeDisabled()
+    })
+
+    it('should enable confirm button when required fields are filled', () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+
+      // Fill required fields
+      const urlInput = screen.getByPlaceholderText('tools.mcp.modal.serverUrlPlaceholder')
+      const nameInput = screen.getByPlaceholderText('tools.mcp.modal.namePlaceholder')
+      const identifierInput = screen.getByPlaceholderText('tools.mcp.modal.serverIdentifierPlaceholder')
+
+      fireEvent.change(urlInput, { target: { value: 'https://example.com/mcp' } })
+      fireEvent.change(nameInput, { target: { value: 'Test Server' } })
+      fireEvent.change(identifierInput, { target: { value: 'test-server' } })
+
+      const confirmButton = screen.getByText('tools.mcp.modal.confirm')
+      expect(confirmButton).not.toBeDisabled()
+    })
+  })
+
+  describe('Form Submission', () => {
+    it('should call onConfirm with correct data when form is submitted', async () => {
+      const onConfirm = vi.fn().mockResolvedValue(undefined)
+      render(<MCPModal {...defaultProps} onConfirm={onConfirm} />, { wrapper: createWrapper() })
+
+      // Fill required fields
+      const urlInput = screen.getByPlaceholderText('tools.mcp.modal.serverUrlPlaceholder')
+      const nameInput = screen.getByPlaceholderText('tools.mcp.modal.namePlaceholder')
+      const identifierInput = screen.getByPlaceholderText('tools.mcp.modal.serverIdentifierPlaceholder')
+
+      fireEvent.change(urlInput, { target: { value: 'https://example.com/mcp' } })
+      fireEvent.change(nameInput, { target: { value: 'Test Server' } })
+      fireEvent.change(identifierInput, { target: { value: 'test-server' } })
+
+      const confirmButton = screen.getByText('tools.mcp.modal.confirm')
+      fireEvent.click(confirmButton)
+
+      await waitFor(() => {
+        expect(onConfirm).toHaveBeenCalledWith(
+          expect.objectContaining({
+            name: 'Test Server',
+            server_url: 'https://example.com/mcp',
+            server_identifier: 'test-server',
+          }),
+        )
+      })
+    })
+
+    it('should not call onConfirm with invalid URL', async () => {
+      const onConfirm = vi.fn()
+      render(<MCPModal {...defaultProps} onConfirm={onConfirm} />, { wrapper: createWrapper() })
+
+      // Fill fields with invalid URL
+      const urlInput = screen.getByPlaceholderText('tools.mcp.modal.serverUrlPlaceholder')
+      const nameInput = screen.getByPlaceholderText('tools.mcp.modal.namePlaceholder')
+      const identifierInput = screen.getByPlaceholderText('tools.mcp.modal.serverIdentifierPlaceholder')
+
+      fireEvent.change(urlInput, { target: { value: 'not-a-valid-url' } })
+      fireEvent.change(nameInput, { target: { value: 'Test Server' } })
+      fireEvent.change(identifierInput, { target: { value: 'test-server' } })
+
+      const confirmButton = screen.getByText('tools.mcp.modal.confirm')
+      fireEvent.click(confirmButton)
+
+      // Wait a bit and verify onConfirm was not called
+      await new Promise(resolve => setTimeout(resolve, 100))
+      expect(onConfirm).not.toHaveBeenCalled()
+    })
+
+    it('should not call onConfirm with invalid server identifier', async () => {
+      const onConfirm = vi.fn()
+      render(<MCPModal {...defaultProps} onConfirm={onConfirm} />, { wrapper: createWrapper() })
+
+      // Fill fields with invalid server identifier
+      const urlInput = screen.getByPlaceholderText('tools.mcp.modal.serverUrlPlaceholder')
+      const nameInput = screen.getByPlaceholderText('tools.mcp.modal.namePlaceholder')
+      const identifierInput = screen.getByPlaceholderText('tools.mcp.modal.serverIdentifierPlaceholder')
+
+      fireEvent.change(urlInput, { target: { value: 'https://example.com/mcp' } })
+      fireEvent.change(nameInput, { target: { value: 'Test Server' } })
+      fireEvent.change(identifierInput, { target: { value: 'Invalid Server ID!' } })
+
+      const confirmButton = screen.getByText('tools.mcp.modal.confirm')
+      fireEvent.click(confirmButton)
+
+      // Wait a bit and verify onConfirm was not called
+      await new Promise(resolve => setTimeout(resolve, 100))
+      expect(onConfirm).not.toHaveBeenCalled()
+    })
+  })
+
+  describe('Edit Mode', () => {
+    const mockData = {
+      id: 'test-id',
+      name: 'Existing Server',
+      server_url: 'https://existing.com/mcp',
+      server_identifier: 'existing-server',
+      icon: { content: '🚀', background: '#FF0000' },
+      configuration: {
+        timeout: 60,
+        sse_read_timeout: 600,
+      },
+      masked_headers: {
+        Authorization: '***',
+      },
+      is_dynamic_registration: false,
+      authentication: {
+        client_id: 'client-123',
+        client_secret: 'secret-456',
+      },
+    } as unknown as ToolWithProvider
+
+    it('should populate form with existing data', () => {
+      render(<MCPModal {...defaultProps} data={mockData} />, { wrapper: createWrapper() })
+
+      expect(screen.getByDisplayValue('https://existing.com/mcp')).toBeInTheDocument()
+      expect(screen.getByDisplayValue('Existing Server')).toBeInTheDocument()
+      expect(screen.getByDisplayValue('existing-server')).toBeInTheDocument()
+    })
+
+    it('should show warning when URL is changed', () => {
+      render(<MCPModal {...defaultProps} data={mockData} />, { wrapper: createWrapper() })
+
+      const urlInput = screen.getByDisplayValue('https://existing.com/mcp')
+      fireEvent.change(urlInput, { target: { value: 'https://new.com/mcp' } })
+
+      expect(screen.getByText('tools.mcp.modal.serverUrlWarning')).toBeInTheDocument()
+    })
+
+    it('should show warning when server identifier is changed', () => {
+      render(<MCPModal {...defaultProps} data={mockData} />, { wrapper: createWrapper() })
+
+      const identifierInput = screen.getByDisplayValue('existing-server')
+      fireEvent.change(identifierInput, { target: { value: 'new-server' } })
+
+      expect(screen.getByText('tools.mcp.modal.serverIdentifierWarning')).toBeInTheDocument()
+    })
+  })
+
+  describe('Form Key Reset', () => {
+    it('should reset form when switching from create to edit mode', () => {
+      const { rerender } = render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+
+      // Fill some data in create mode
+      const nameInput = screen.getByPlaceholderText('tools.mcp.modal.namePlaceholder')
+      fireEvent.change(nameInput, { target: { value: 'New Server' } })
+
+      // Switch to edit mode with different data
+      const mockData = {
+        id: 'edit-id',
+        name: 'Edit Server',
+        icon: { content: '🔗', background: '#6366F1' },
+      } as unknown as ToolWithProvider
+
+      rerender(<MCPModal {...defaultProps} data={mockData} />)
+
+      // Should show edit mode data
+      expect(screen.getByDisplayValue('Edit Server')).toBeInTheDocument()
+    })
+  })
+
+  describe('URL Blur Handler', () => {
+    it('should trigger URL blur handler when URL input loses focus', () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+
+      const urlInput = screen.getByPlaceholderText('tools.mcp.modal.serverUrlPlaceholder')
+      fireEvent.change(urlInput, { target: { value: ' https://test.com/mcp ' } })
+      fireEvent.blur(urlInput)
+
+      // The blur handler uses a trimmed copy internally; the displayed value is unchanged
+      expect(urlInput).toHaveValue(' https://test.com/mcp ')
+    })
+
+    it('should handle URL blur with empty value', () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+
+      const urlInput = screen.getByPlaceholderText('tools.mcp.modal.serverUrlPlaceholder')
+      fireEvent.change(urlInput, { target: { value: '' } })
+      fireEvent.blur(urlInput)
+
+      expect(urlInput).toHaveValue('')
+    })
+  })
+
+  describe('App Icon', () => {
+    it('should render app icon with default emoji', () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+
+      // The app icon should be rendered
+      const appIcons = document.querySelectorAll('[class*="rounded-2xl"]')
+      expect(appIcons.length).toBeGreaterThan(0)
+    })
+
+    it('should render app icon in edit mode with custom icon', () => {
+      const mockData = {
+        id: 'test-id',
+        name: 'Test Server',
+        server_url: 'https://example.com/mcp',
+        server_identifier: 'test-server',
+        icon: { content: '🚀', background: '#FF0000' },
+      } as unknown as ToolWithProvider
+
+      render(<MCPModal {...defaultProps} data={mockData} />, { wrapper: createWrapper() })
+
+      // The app icon should be rendered
+      const appIcons = document.querySelectorAll('[class*="rounded-2xl"]')
+      expect(appIcons.length).toBeGreaterThan(0)
+    })
+  })
+
+  describe('Form Submission with Headers', () => {
+    it('should submit form with headers data', async () => {
+      const onConfirm = vi.fn().mockResolvedValue(undefined)
+      render(<MCPModal {...defaultProps} onConfirm={onConfirm} />, { wrapper: createWrapper() })
+
+      // Fill required fields
+      const urlInput = screen.getByPlaceholderText('tools.mcp.modal.serverUrlPlaceholder')
+      const nameInput = screen.getByPlaceholderText('tools.mcp.modal.namePlaceholder')
+      const identifierInput = screen.getByPlaceholderText('tools.mcp.modal.serverIdentifierPlaceholder')
+
+      fireEvent.change(urlInput, { target: { value: 'https://example.com/mcp' } })
+      fireEvent.change(nameInput, { target: { value: 'Test Server' } })
+      fireEvent.change(identifierInput, { target: { value: 'test-server' } })
+
+      // Switch to headers tab and add a header
+      const headersTab = screen.getByText('tools.mcp.modal.headers')
+      fireEvent.click(headersTab)
+
+      await waitFor(() => {
+        expect(screen.getByText('tools.mcp.modal.headersTip')).toBeInTheDocument()
+      })
+
+      const confirmButton = screen.getByText('tools.mcp.modal.confirm')
+      fireEvent.click(confirmButton)
+
+      await waitFor(() => {
+        expect(onConfirm).toHaveBeenCalledWith(
+          expect.objectContaining({
+            name: 'Test Server',
+            server_url: 'https://example.com/mcp',
+            server_identifier: 'test-server',
+          }),
+        )
+      })
+    })
+
+    it('should submit with authentication data', async () => {
+      const onConfirm = vi.fn().mockResolvedValue(undefined)
+      render(<MCPModal {...defaultProps} onConfirm={onConfirm} />, { wrapper: createWrapper() })
+
+      // Fill required fields
+      const urlInput = screen.getByPlaceholderText('tools.mcp.modal.serverUrlPlaceholder')
+      const nameInput = screen.getByPlaceholderText('tools.mcp.modal.namePlaceholder')
+      const identifierInput = screen.getByPlaceholderText('tools.mcp.modal.serverIdentifierPlaceholder')
+
+      fireEvent.change(urlInput, { target: { value: 'https://example.com/mcp' } })
+      fireEvent.change(nameInput, { target: { value: 'Test Server' } })
+      fireEvent.change(identifierInput, { target: { value: 'test-server' } })
+
+      // Submit form
+      const confirmButton = screen.getByText('tools.mcp.modal.confirm')
+      fireEvent.click(confirmButton)
+
+      await waitFor(() => {
+        expect(onConfirm).toHaveBeenCalledWith(
+          expect.objectContaining({
+            authentication: expect.objectContaining({
+              client_id: '',
+              client_secret: '',
+            }),
+          }),
+        )
+      })
+    })
+
+    it('should format headers correctly when submitting with header keys', async () => {
+      const onConfirm = vi.fn().mockResolvedValue(undefined)
+      const mockData = {
+        id: 'test-id',
+        name: 'Test Server',
+        server_url: 'https://example.com/mcp',
+        server_identifier: 'test-server',
+        icon: { content: '🔗', background: '#6366F1' },
+        masked_headers: {
+          'Authorization': 'Bearer token',
+          'X-Custom': 'value',
+        },
+      } as unknown as ToolWithProvider
+
+      render(<MCPModal {...defaultProps} data={mockData} onConfirm={onConfirm} />, { wrapper: createWrapper() })
+
+      // Switch to headers tab
+      const headersTab = screen.getByText('tools.mcp.modal.headers')
+      fireEvent.click(headersTab)
+
+      await waitFor(() => {
+        expect(screen.getByText('tools.mcp.modal.headersTip')).toBeInTheDocument()
+      })
+
+      // Submit form
+      const saveButton = screen.getByText('tools.mcp.modal.save')
+      fireEvent.click(saveButton)
+
+      await waitFor(() => {
+        expect(onConfirm).toHaveBeenCalledWith(
+          expect.objectContaining({
+            headers: expect.objectContaining({
+              Authorization: expect.any(String),
+            }),
+          }),
+        )
+      })
+    })
+  })
+
+  describe('Edit Mode Submission', () => {
+    it('should send hidden URL when URL is unchanged in edit mode', async () => {
+      const onConfirm = vi.fn().mockResolvedValue(undefined)
+      const mockData = {
+        id: 'test-id',
+        name: 'Existing Server',
+        server_url: 'https://existing.com/mcp',
+        server_identifier: 'existing-server',
+        icon: { content: '🚀', background: '#FF0000' },
+      } as unknown as ToolWithProvider
+
+      render(<MCPModal {...defaultProps} data={mockData} onConfirm={onConfirm} />, { wrapper: createWrapper() })
+
+      // Don't change the URL, just submit
+      const saveButton = screen.getByText('tools.mcp.modal.save')
+      fireEvent.click(saveButton)
+
+      await waitFor(() => {
+        expect(onConfirm).toHaveBeenCalledWith(
+          expect.objectContaining({
+            server_url: '[__HIDDEN__]',
+          }),
+        )
+      })
+    })
+
+    it('should send new URL when URL is changed in edit mode', async () => {
+      const onConfirm = vi.fn().mockResolvedValue(undefined)
+      const mockData = {
+        id: 'test-id',
+        name: 'Existing Server',
+        server_url: 'https://existing.com/mcp',
+        server_identifier: 'existing-server',
+        icon: { content: '🚀', background: '#FF0000' },
+      } as unknown as ToolWithProvider
+
+      render(<MCPModal {...defaultProps} data={mockData} onConfirm={onConfirm} />, { wrapper: createWrapper() })
+
+      // Change the URL
+      const urlInput = screen.getByDisplayValue('https://existing.com/mcp')
+      fireEvent.change(urlInput, { target: { value: 'https://new.com/mcp' } })
+
+      const saveButton = screen.getByText('tools.mcp.modal.save')
+      fireEvent.click(saveButton)
+
+      await waitFor(() => {
+        expect(onConfirm).toHaveBeenCalledWith(
+          expect.objectContaining({
+            server_url: 'https://new.com/mcp',
+          }),
+        )
+      })
+    })
+  })
+
+  describe('Configuration Section', () => {
+    it('should submit with default timeout values', async () => {
+      const onConfirm = vi.fn().mockResolvedValue(undefined)
+      render(<MCPModal {...defaultProps} onConfirm={onConfirm} />, { wrapper: createWrapper() })
+
+      // Fill required fields
+      const urlInput = screen.getByPlaceholderText('tools.mcp.modal.serverUrlPlaceholder')
+      const nameInput = screen.getByPlaceholderText('tools.mcp.modal.namePlaceholder')
+      const identifierInput = screen.getByPlaceholderText('tools.mcp.modal.serverIdentifierPlaceholder')
+
+      fireEvent.change(urlInput, { target: { value: 'https://example.com/mcp' } })
+      fireEvent.change(nameInput, { target: { value: 'Test Server' } })
+      fireEvent.change(identifierInput, { target: { value: 'test-server' } })
+
+      const confirmButton = screen.getByText('tools.mcp.modal.confirm')
+      fireEvent.click(confirmButton)
+
+      await waitFor(() => {
+        expect(onConfirm).toHaveBeenCalledWith(
+          expect.objectContaining({
+            configuration: expect.objectContaining({
+              timeout: 30,
+              sse_read_timeout: 300,
+            }),
+          }),
+        )
+      })
+    })
+
+    it('should submit with custom timeout values', async () => {
+      const onConfirm = vi.fn().mockResolvedValue(undefined)
+      render(<MCPModal {...defaultProps} onConfirm={onConfirm} />, { wrapper: createWrapper() })
+
+      // Fill required fields
+      const urlInput = screen.getByPlaceholderText('tools.mcp.modal.serverUrlPlaceholder')
+      const nameInput = screen.getByPlaceholderText('tools.mcp.modal.namePlaceholder')
+      const identifierInput = screen.getByPlaceholderText('tools.mcp.modal.serverIdentifierPlaceholder')
+
+      fireEvent.change(urlInput, { target: { value: 'https://example.com/mcp' } })
+      fireEvent.change(nameInput, { target: { value: 'Test Server' } })
+      fireEvent.change(identifierInput, { target: { value: 'test-server' } })
+
+      // Switch to configurations tab
+      const configTab = screen.getByText('tools.mcp.modal.configurations')
+      fireEvent.click(configTab)
+
+      await waitFor(() => {
+        expect(screen.getByText('tools.mcp.modal.timeout')).toBeInTheDocument()
+      })
+
+      const confirmButton = screen.getByText('tools.mcp.modal.confirm')
+      fireEvent.click(confirmButton)
+
+      await waitFor(() => {
+        expect(onConfirm).toHaveBeenCalled()
+      })
+    })
+  })
+
+  describe('Dynamic Registration', () => {
+    it('should toggle dynamic registration', async () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+
+      // Find the switch for dynamic registration
+      const switchElements = screen.getAllByRole('switch')
+      expect(switchElements.length).toBeGreaterThan(0)
+
+      // Click the first switch (dynamic registration)
+      fireEvent.click(switchElements[0])
+
+      // The switch should toggle
+      expect(switchElements[0]).toBeInTheDocument()
+    })
+  })
+
+  describe('App Icon Picker Interactions', () => {
+    it('should open app icon picker when app icon is clicked', async () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+
+      // Find the app icon container with cursor-pointer and rounded-2xl classes
+      const appIconContainer = document.querySelector('[class*="rounded-2xl"][class*="cursor-pointer"]')
+
+      if (appIconContainer) {
+        fireEvent.click(appIconContainer)
+
+        // The mocked AppIconPicker should now be visible
+        await waitFor(() => {
+          expect(screen.getByTestId('app-icon-picker')).toBeInTheDocument()
+        })
+      }
+    })
+
+    it('should close app icon picker and update icon when selecting an icon', async () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+
+      // Open the icon picker
+      const appIconContainer = document.querySelector('[class*="rounded-2xl"][class*="cursor-pointer"]')
+
+      if (appIconContainer) {
+        fireEvent.click(appIconContainer)
+
+        await waitFor(() => {
+          expect(screen.getByTestId('app-icon-picker')).toBeInTheDocument()
+        })
+
+        // Click the select emoji button
+        const selectBtn = screen.getByTestId('select-emoji-btn')
+        fireEvent.click(selectBtn)
+
+        // The picker should be closed
+        await waitFor(() => {
+          expect(screen.queryByTestId('app-icon-picker')).not.toBeInTheDocument()
+        })
+      }
+    })
+
+    it('should close app icon picker and reset icon when close button is clicked', async () => {
+      render(<MCPModal {...defaultProps} />, { wrapper: createWrapper() })
+
+      // Open the icon picker
+      const appIconContainer = document.querySelector('[class*="rounded-2xl"][class*="cursor-pointer"]')
+
+      if (appIconContainer) {
+        fireEvent.click(appIconContainer)
+
+        await waitFor(() => {
+          expect(screen.getByTestId('app-icon-picker')).toBeInTheDocument()
+        })
+
+        // Click the close button
+        const closeBtn = screen.getByTestId('close-picker-btn')
+        fireEvent.click(closeBtn)
+
+        // The picker should be closed
+        await waitFor(() => {
+          expect(screen.queryByTestId('app-icon-picker')).not.toBeInTheDocument()
+        })
+      }
+    })
+  })
+})
diff --git a/web/app/components/tools/mcp/modal.tsx b/web/app/components/tools/mcp/modal.tsx
index c5cde65674..76ba42f2bf 100644
--- a/web/app/components/tools/mcp/modal.tsx
+++ b/web/app/components/tools/mcp/modal.tsx
@@ -1,429 +1,298 @@
 'use client'
-import type { HeaderItem } from './headers-input'
+import type { FC } from 'react'
 import type { AppIconSelection } from
'@/app/components/base/app-icon-picker'
 import type { ToolWithProvider } from '@/app/components/workflow/types'
 import type { AppIconType } from '@/types/app'
 import { RiCloseLine, RiEditLine } from '@remixicon/react'
 import { useHover } from 'ahooks'
-import { noop } from 'es-toolkit/compat'
-import * as React from 'react'
-import { useCallback, useRef, useState } from 'react'
+import { noop } from 'es-toolkit/function'
 import { useTranslation } from 'react-i18next'
-import { getDomain } from 'tldts'
-import { v4 as uuid } from 'uuid'
 import AppIcon from '@/app/components/base/app-icon'
 import AppIconPicker from '@/app/components/base/app-icon-picker'
 import Button from '@/app/components/base/button'
 import { Mcp } from '@/app/components/base/icons/src/vender/other'
-import AlertTriangle from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback/AlertTriangle'
 import Input from '@/app/components/base/input'
 import Modal from '@/app/components/base/modal'
-import Switch from '@/app/components/base/switch'
 import TabSlider from '@/app/components/base/tab-slider'
 import Toast from '@/app/components/base/toast'
 import { MCPAuthMethod } from '@/app/components/tools/types'
-import { API_PREFIX } from '@/config'
-import { uploadRemoteFileInfo } from '@/service/common'
 import { cn } from '@/utils/classnames'
 import { shouldUseMcpIconForAppIcon } from '@/utils/mcp'
-import HeadersInput from './headers-input'
+import { isValidServerID, isValidUrl, useMCPModalForm } from './hooks/use-mcp-modal-form'
+import AuthenticationSection from './sections/authentication-section'
+import ConfigurationsSection from './sections/configurations-section'
+import HeadersSection from './sections/headers-section'
+
+export type MCPModalConfirmPayload = {
+  name: string
+  server_url: string
+  icon_type: AppIconType
+  icon: string
+  icon_background?: string | null
+  server_identifier: string
+  headers?: Record<string, string>
+  is_dynamic_registration?: boolean
+  authentication?: {
+    client_id?: string
+    client_secret?: string
+    grant_type?: string
+  }
+  configuration: {
+    timeout: number
+    sse_read_timeout: number
+  }
+}
 
 export type DuplicateAppModalProps = {
   data?: ToolWithProvider
   show: boolean
-  onConfirm: (info: {
-    name: string
-    server_url: string
-    icon_type: AppIconType
-    icon: string
-    icon_background?: string | null
-    server_identifier: string
-    headers?: Record<string, string>
-    is_dynamic_registration?: boolean
-    authentication?: {
-      client_id?: string
-      client_secret?: string
-      grant_type?: string
-    }
-    configuration: {
-      timeout: number
-      sse_read_timeout: number
-    }
-  }) => void
+  onConfirm: (info: MCPModalConfirmPayload) => void
   onHide: () => void
 }
 
-const DEFAULT_ICON = { type: 'emoji', icon: '🔗', background: '#6366F1' }
-const extractFileId = (url: string) => {
-  const match = url.match(/files\/(.+?)\/file-preview/)
-  return match ? match[1] : null
-}
-const getIcon = (data?: ToolWithProvider) => {
-  if (!data)
-    return DEFAULT_ICON as AppIconSelection
-  if (typeof data.icon === 'string')
-    return { type: 'image', url: data.icon, fileId: extractFileId(data.icon) } as AppIconSelection
-  return {
-    ...data.icon,
-    icon: data.icon.content,
-    type: 'emoji',
-  } as unknown as AppIconSelection
+type MCPModalContentProps = {
+  data?: ToolWithProvider
+  onConfirm: (info: MCPModalConfirmPayload) => void
+  onHide: () => void
 }
 
-const MCPModal = ({
+const MCPModalContent: FC<MCPModalContentProps> = ({
   data,
-  show,
   onConfirm,
   onHide,
-}: DuplicateAppModalProps) => {
+}) => {
   const { t } = useTranslation()
-  const isCreate = !data
+
+  const {
+    isCreate,
+    originalServerUrl,
+    originalServerID,
+    appIconRef,
+    state,
+    actions,
+  } = useMCPModalForm(data)
+
+  const isHovering = useHover(appIconRef)
 
   const authMethods = [
-    {
-      text: t('tools.mcp.modal.authentication'),
-      value: MCPAuthMethod.authentication,
-    },
-    {
-      text: t('tools.mcp.modal.headers'),
-      value: MCPAuthMethod.headers,
-    },
-    {
-      text: t('tools.mcp.modal.configurations'),
-      value: MCPAuthMethod.configurations,
-    },
+    { text: t('mcp.modal.authentication', { ns: 'tools' }), value: MCPAuthMethod.authentication },
+    { text: t('mcp.modal.headers', { ns: 'tools' }), value: MCPAuthMethod.headers },
+    { text: t('mcp.modal.configurations', { ns: 'tools' }), value: MCPAuthMethod.configurations },
   ]
 
-  const originalServerUrl = data?.server_url
-  const originalServerID = data?.server_identifier
-  const [url, setUrl] = React.useState(data?.server_url || '')
-  const [name, setName] = React.useState(data?.name || '')
-  const [appIcon, setAppIcon] = useState(() => getIcon(data))
-  const [showAppIconPicker, setShowAppIconPicker] = useState(false)
-  const [serverIdentifier, setServerIdentifier] = React.useState(data?.server_identifier || '')
-  const [timeout, setMcpTimeout] = React.useState(data?.configuration?.timeout || 30)
-  const [sseReadTimeout, setSseReadTimeout] = React.useState(data?.configuration?.sse_read_timeout || 300)
-  const [headers, setHeaders] = React.useState<HeaderItem[]>(
-    Object.entries(data?.masked_headers || {}).map(([key, value]) => ({ id: uuid(), key, value })),
-  )
-  const [isFetchingIcon, setIsFetchingIcon] = useState(false)
-  const appIconRef = useRef(null)
-  const isHovering = useHover(appIconRef)
-  const [authMethod, setAuthMethod] = useState(MCPAuthMethod.authentication)
-  const [isDynamicRegistration, setIsDynamicRegistration] = useState(isCreate ? true : data?.is_dynamic_registration)
-  const [clientID, setClientID] = useState(data?.authentication?.client_id || '')
-  const [credentials, setCredentials] = useState(data?.authentication?.client_secret || '')
-
-  // Update states when data changes (for edit mode)
-  React.useEffect(() => {
-    if (data) {
-      setUrl(data.server_url || '')
-      setName(data.name || '')
-      setServerIdentifier(data.server_identifier || '')
-      setMcpTimeout(data.configuration?.timeout || 30)
-      setSseReadTimeout(data.configuration?.sse_read_timeout || 300)
-      setHeaders(Object.entries(data.masked_headers || {}).map(([key, value]) => ({ id: uuid(), key, value })))
-      setAppIcon(getIcon(data))
-      setIsDynamicRegistration(data.is_dynamic_registration)
-      setClientID(data.authentication?.client_id || '')
-      setCredentials(data.authentication?.client_secret || '')
-    }
-    else {
-      // Reset for create mode
-      setUrl('')
-      setName('')
-      setServerIdentifier('')
-      setMcpTimeout(30)
-      setSseReadTimeout(300)
-      setHeaders([])
-      setAppIcon(DEFAULT_ICON as AppIconSelection)
-      setIsDynamicRegistration(true)
-      setClientID('')
-      setCredentials('')
-    }
-  }, [data])
-
-  const isValidUrl = (string: string) => {
-    try {
-      const url = new URL(string)
-      return url.protocol === 'http:' || url.protocol === 'https:'
-    }
-    catch {
-      return false
-    }
-  }
-
-  const isValidServerID = (str: string) => {
-    return /^[a-z0-9_-]{1,24}$/.test(str)
-  }
-
-  const handleBlur = async (url: string) => {
-    if (data)
-      return
-    if (!isValidUrl(url))
-      return
-    const domain = getDomain(url)
-    const remoteIcon = `https://www.google.com/s2/favicons?domain=${domain}&sz=128`
-    setIsFetchingIcon(true)
-    try {
-      const res = await uploadRemoteFileInfo(remoteIcon, undefined, true)
-      setAppIcon({ type: 'image', url: res.url, fileId: extractFileId(res.url) || '' })
-    }
-    catch (e) {
-      let errorMessage = 'Failed to fetch remote icon'
-      const errorData = await (e as Response).json()
-      if (errorData?.code)
-        errorMessage = `Upload failed: ${errorData.code}`
-      console.error('Failed to fetch remote icon:', e)
-      Toast.notify({ type: 'warning', message: errorMessage })
-    }
-    finally {
-      setIsFetchingIcon(false)
-    }
-  }
 
   const submit = async () => {
-    if (!isValidUrl(url)) {
+    if (!isValidUrl(state.url)) {
       Toast.notify({ type: 'error', message: 'invalid server url' })
       return
     }
-    if (!isValidServerID(serverIdentifier.trim())) {
+    if (!isValidServerID(state.serverIdentifier.trim())) {
       Toast.notify({ type: 'error', message: 'invalid server identifier' })
       return
     }
-    const formattedHeaders = headers.reduce((acc, item) => {
+    const formattedHeaders = state.headers.reduce((acc, item) => {
       if (item.key.trim())
         acc[item.key.trim()] = item.value
       return acc
     }, {} as Record<string, string>)
+
     await onConfirm({
-      server_url: originalServerUrl === url ? '[__HIDDEN__]' : url.trim(),
-      name,
-      icon_type: appIcon.type,
-      icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId,
-      icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined,
-      server_identifier: serverIdentifier.trim(),
+      server_url: originalServerUrl === state.url ? '[__HIDDEN__]' : state.url.trim(),
+      name: state.name,
+      icon_type: state.appIcon.type,
+      icon: state.appIcon.type === 'emoji' ? state.appIcon.icon : state.appIcon.fileId,
+      icon_background: state.appIcon.type === 'emoji' ? state.appIcon.background : undefined,
+      server_identifier: state.serverIdentifier.trim(),
       headers: Object.keys(formattedHeaders).length > 0 ? formattedHeaders : undefined,
-      is_dynamic_registration: isDynamicRegistration,
+      is_dynamic_registration: state.isDynamicRegistration,
       authentication: {
-        client_id: clientID,
-        client_secret: credentials,
+        client_id: state.clientID,
+        client_secret: state.credentials,
       },
       configuration: {
-        timeout: timeout || 30,
-        sse_read_timeout: sseReadTimeout || 300,
+        timeout: state.timeout || 30,
+        sse_read_timeout: state.sseReadTimeout || 300,
       },
     })
     if (isCreate)
       onHide()
   }
 
-  const handleAuthMethodChange = useCallback((value: string) => {
-    setAuthMethod(value as MCPAuthMethod)
-  }, [])
+  const handleIconSelect = (payload: AppIconSelection) => {
+    actions.setAppIcon(payload)
+    actions.setShowAppIconPicker(false)
+  }
+
+  const handleIconClose = () => {
+    actions.resetIcon()
+    actions.setShowAppIconPicker(false)
+  }
+
+  const isSubmitDisabled = !state.name || !state.url || !state.serverIdentifier || state.isFetchingIcon
 
   return (
     <>
-
- -
-
{!isCreate ? t('tools.mcp.modal.editTitle') : t('tools.mcp.modal.title')}
-
-
-
- {t('tools.mcp.modal.serverUrl')} -
- setUrl(e.target.value)} - onBlur={e => handleBlur(e.target.value.trim())} - placeholder={t('tools.mcp.modal.serverUrlPlaceholder')} - /> - {originalServerUrl && originalServerUrl !== url && ( -
- {t('tools.mcp.modal.serverUrlWarning')} -
- )} +
+ +
+
+ {!isCreate ? t('mcp.modal.editTitle', { ns: 'tools' }) : t('mcp.modal.title', { ns: 'tools' })} +
+ +
+ {/* Server URL */} +
+
+ {t('mcp.modal.serverUrl', { ns: 'tools' })}
-
-
-
- {t('tools.mcp.modal.name')} -
- setName(e.target.value)} - placeholder={t('tools.mcp.modal.namePlaceholder')} - /> -
-
- : undefined} - size="xxl" - className="relative cursor-pointer rounded-2xl" - coverElement={ - isHovering - ? ( -
- -
- ) - : null - } - onClick={() => { setShowAppIconPicker(true) }} - /> -
-
-
-
- {t('tools.mcp.modal.serverIdentifier')} -
-
{t('tools.mcp.modal.serverIdentifierTip')}
- setServerIdentifier(e.target.value)} - placeholder={t('tools.mcp.modal.serverIdentifierPlaceholder')} - /> - {originalServerID && originalServerID !== serverIdentifier && ( -
- {t('tools.mcp.modal.serverIdentifierWarning')} -
- )} -
- { - return `flex-1 ${isActive && 'text-text-accent-light-mode-only'}` - }} - value={authMethod} - onChange={handleAuthMethodChange} - options={authMethods} + actions.setUrl(e.target.value)} + onBlur={e => actions.handleUrlBlur(e.target.value.trim())} + placeholder={t('mcp.modal.serverUrlPlaceholder', { ns: 'tools' })} /> - { - authMethod === MCPAuthMethod.authentication && ( - <> -
-
- - {t('tools.mcp.modal.useDynamicClientRegistration')} -
- {!isDynamicRegistration && ( -
- -
-
{t('tools.mcp.modal.redirectUrlWarning')}
- - {`${API_PREFIX}/mcp/oauth/callback`} - + {originalServerUrl && originalServerUrl !== state.url && ( +
+ {t('mcp.modal.serverUrlWarning', { ns: 'tools' })} +
+ )} +
+ + {/* Name and Icon */} +
+
+
+ {t('mcp.modal.name', { ns: 'tools' })} +
+ actions.setName(e.target.value)} + placeholder={t('mcp.modal.namePlaceholder', { ns: 'tools' })} + /> +
+
+ : undefined} + size="xxl" + className="relative cursor-pointer rounded-2xl" + coverElement={ + isHovering + ? ( +
+
-
- )} -
-
-
- {t('tools.mcp.modal.clientID')} -
- setClientID(e.target.value)} - onBlur={e => handleBlur(e.target.value.trim())} - placeholder={t('tools.mcp.modal.clientID')} - disabled={isDynamicRegistration} - /> -
-
-
- {t('tools.mcp.modal.clientSecret')} -
- setCredentials(e.target.value)} - onBlur={e => handleBlur(e.target.value.trim())} - placeholder={t('tools.mcp.modal.clientSecretPlaceholder')} - disabled={isDynamicRegistration} - /> -
- - ) - } - { - authMethod === MCPAuthMethod.headers && ( -
-
- {t('tools.mcp.modal.headers')} -
-
{t('tools.mcp.modal.headersTip')}
- item.key.trim()).length > 0} - /> -
- ) - } - { - authMethod === MCPAuthMethod.configurations && ( - <> -
-
- {t('tools.mcp.modal.timeout')} -
- setMcpTimeout(Number(e.target.value))} - onBlur={e => handleBlur(e.target.value.trim())} - placeholder={t('tools.mcp.modal.timeoutPlaceholder')} - /> -
-
-
- {t('tools.mcp.modal.sseReadTimeout')} -
- setSseReadTimeout(Number(e.target.value))} - onBlur={e => handleBlur(e.target.value.trim())} - placeholder={t('tools.mcp.modal.timeoutPlaceholder')} - /> -
- - ) - } + ) + : null + } + onClick={() => actions.setShowAppIconPicker(true)} + /> +
-
- - + + {/* Server Identifier */} +
+
+ {t('mcp.modal.serverIdentifier', { ns: 'tools' })} +
+
{t('mcp.modal.serverIdentifierTip', { ns: 'tools' })}
+ actions.setServerIdentifier(e.target.value)} + placeholder={t('mcp.modal.serverIdentifierPlaceholder', { ns: 'tools' })} + /> + {originalServerID && originalServerID !== state.serverIdentifier && ( +
+ {t('mcp.modal.serverIdentifierWarning', { ns: 'tools' })} +
+ )}
- - {showAppIconPicker && ( + + {/* Auth Method Tabs */} + `flex-1 ${isActive && 'text-text-accent-light-mode-only'}`} + value={state.authMethod} + onChange={actions.setAuthMethod} + options={authMethods} + /> + + {/* Tab Content */} + {state.authMethod === MCPAuthMethod.authentication && ( + + )} + {state.authMethod === MCPAuthMethod.headers && ( + + )} + {state.authMethod === MCPAuthMethod.configurations && ( + + )} +
+ + {/* Actions */} +
+ + +
+ + {state.showAppIconPicker && ( { - setAppIcon(payload) - setShowAppIconPicker(false) - }} - onClose={() => { - setAppIcon(getIcon(data)) - setShowAppIconPicker(false) - }} + onSelect={handleIconSelect} + onClose={handleIconClose} /> )} + ) +} +/** + * MCP Modal component for creating and editing MCP server configurations. + * + * Uses a keyed inner component to ensure form state resets when switching + * between create mode and edit mode with different data. + */ +const MCPModal: FC = ({ + data, + show, + onConfirm, + onHide, +}) => { + // Use data ID as key to reset form state when switching between items + const formKey = data?.id ?? 'create' + + return ( + + + ) } diff --git a/web/app/components/tools/mcp/provider-card.spec.tsx b/web/app/components/tools/mcp/provider-card.spec.tsx new file mode 100644 index 0000000000..216607ce5a --- /dev/null +++ b/web/app/components/tools/mcp/provider-card.spec.tsx @@ -0,0 +1,524 @@ +import type { ReactNode } from 'react' +import type { ToolWithProvider } from '@/app/components/workflow/types' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import * as React from 'react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import MCPCard from './provider-card' + +// Mutable mock functions +const mockUpdateMCP = vi.fn().mockResolvedValue({ result: 'success' }) +const mockDeleteMCP = vi.fn().mockResolvedValue({ result: 'success' }) + +// Mock the services +vi.mock('@/service/use-tools', () => ({ + useUpdateMCP: () => ({ + mutateAsync: mockUpdateMCP, + }), + useDeleteMCP: () => ({ + mutateAsync: mockDeleteMCP, + }), +})) + +// Mock the MCPModal +type MCPModalForm = { + name: string + server_url: string +} + +type MCPModalProps = { + show: boolean + onConfirm: (form: MCPModalForm) => void + onHide: () => void +} + +vi.mock('./modal', () => ({ + default: ({ show, onConfirm, onHide }: MCPModalProps) => { + if (!show) 
+ return null + return ( +
+ + +
+ ) + }, +})) + +// Mock the Confirm dialog +type ConfirmDialogProps = { + isShow: boolean + onConfirm: () => void + onCancel: () => void + isLoading: boolean +} + +vi.mock('@/app/components/base/confirm', () => ({ + default: ({ isShow, onConfirm, onCancel, isLoading }: ConfirmDialogProps) => { + if (!isShow) + return null + return ( +
+ + +
+ ) + }, +})) + +// Mock the OperationDropdown +type OperationDropdownProps = { + onEdit: () => void + onRemove: () => void + onOpenChange: (open: boolean) => void +} + +vi.mock('./detail/operation-dropdown', () => ({ + default: ({ onEdit, onRemove, onOpenChange }: OperationDropdownProps) => ( +
+ + +
+ ), +})) + +// Mock the app context +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + isCurrentWorkspaceManager: true, + isCurrentWorkspaceEditor: true, + }), +})) + +// Mock the format time hook +vi.mock('@/hooks/use-format-time-from-now', () => ({ + useFormatTimeFromNow: () => ({ + formatTimeFromNow: (_timestamp: number) => '2 hours ago', + }), +})) + +// Mock the plugins service +vi.mock('@/service/use-plugins', () => ({ + useInstalledPluginList: () => ({ + data: { pages: [] }, + hasNextPage: false, + isFetchingNextPage: false, + fetchNextPage: vi.fn(), + isLoading: false, + isSuccess: true, + }), +})) + +// Mock common service +vi.mock('@/service/common', () => ({ + uploadRemoteFileInfo: vi.fn().mockResolvedValue({ url: 'https://example.com/icon.png' }), +})) + +describe('MCPCard', () => { + const createWrapper = () => { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + }, + }) + return ({ children }: { children: ReactNode }) => + React.createElement(QueryClientProvider, { client: queryClient }, children) + } + + const createMockData = (overrides = {}): ToolWithProvider => ({ + id: 'mcp-1', + name: 'Test MCP Server', + server_identifier: 'test-server', + icon: { content: '🔧', background: '#FF0000' }, + tools: [ + { name: 'tool1', description: 'Tool 1' }, + { name: 'tool2', description: 'Tool 2' }, + ], + is_team_authorization: true, + updated_at: Date.now() / 1000, + ...overrides, + } as unknown as ToolWithProvider) + + const defaultProps = { + data: createMockData(), + handleSelect: vi.fn(), + onUpdate: vi.fn(), + onDeleted: vi.fn(), + } + + beforeEach(() => { + mockUpdateMCP.mockClear() + mockDeleteMCP.mockClear() + mockUpdateMCP.mockResolvedValue({ result: 'success' }) + mockDeleteMCP.mockResolvedValue({ result: 'success' }) + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('Test MCP 
Server')).toBeInTheDocument() + }) + + it('should display MCP name', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('Test MCP Server')).toBeInTheDocument() + }) + + it('should display server identifier', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('test-server')).toBeInTheDocument() + }) + + it('should display tools count', () => { + render(, { wrapper: createWrapper() }) + // The tools count uses i18n with count parameter + expect(screen.getByText(/tools.mcp.toolsCount/)).toBeInTheDocument() + }) + + it('should display update time', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText(/tools.mcp.updateTime/)).toBeInTheDocument() + }) + }) + + describe('No Tools State', () => { + it('should show no tools message when tools array is empty', () => { + const dataWithNoTools = createMockData({ tools: [] }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText('tools.mcp.noTools')).toBeInTheDocument() + }) + + it('should show not configured badge when not authorized', () => { + const dataNotAuthorized = createMockData({ is_team_authorization: false }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText('tools.mcp.noConfigured')).toBeInTheDocument() + }) + + it('should show not configured badge when no tools', () => { + const dataWithNoTools = createMockData({ tools: [], is_team_authorization: true }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText('tools.mcp.noConfigured')).toBeInTheDocument() + }) + }) + + describe('Selected State', () => { + it('should apply selected styles when current provider matches', () => { + render( + , + { wrapper: createWrapper() }, + ) + const card = document.querySelector('[class*="border-components-option-card-option-selected-border"]') + expect(card).toBeInTheDocument() + }) + + it('should not apply selected styles when different provider', () => { + const differentProvider = 
createMockData({ id: 'different-id' }) + render( + , + { wrapper: createWrapper() }, + ) + const card = document.querySelector('[class*="border-components-option-card-option-selected-border"]') + expect(card).not.toBeInTheDocument() + }) + }) + + describe('User Interactions', () => { + it('should call handleSelect when card is clicked', () => { + const handleSelect = vi.fn() + render( + , + { wrapper: createWrapper() }, + ) + + const card = screen.getByText('Test MCP Server').closest('[class*="cursor-pointer"]') + if (card) { + fireEvent.click(card) + expect(handleSelect).toHaveBeenCalledWith('mcp-1') + } + }) + }) + + describe('Card Icon', () => { + it('should render card icon', () => { + render(, { wrapper: createWrapper() }) + // Icon component is rendered + const iconContainer = document.querySelector('[class*="rounded-xl"][class*="border"]') + expect(iconContainer).toBeInTheDocument() + }) + }) + + describe('Status Indicator', () => { + it('should show green indicator when authorized and has tools', () => { + const data = createMockData({ is_team_authorization: true, tools: [{ name: 'tool1' }] }) + render( + , + { wrapper: createWrapper() }, + ) + // Should have green indicator (not showing red badge) + expect(screen.queryByText('tools.mcp.noConfigured')).not.toBeInTheDocument() + }) + + it('should show red indicator when not configured', () => { + const data = createMockData({ is_team_authorization: false }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText('tools.mcp.noConfigured')).toBeInTheDocument() + }) + }) + + describe('Edge Cases', () => { + it('should handle long MCP name', () => { + const longName = 'A'.repeat(100) + const data = createMockData({ name: longName }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText(longName)).toBeInTheDocument() + }) + + it('should handle special characters in name', () => { + const data = createMockData({ name: 'Test