diff --git a/.claude/skills/component-refactoring/SKILL.md b/.agents/skills/component-refactoring/SKILL.md
similarity index 99%
rename from .claude/skills/component-refactoring/SKILL.md
rename to .agents/skills/component-refactoring/SKILL.md
index ea695ea442..7006c382c8 100644
--- a/.claude/skills/component-refactoring/SKILL.md
+++ b/.agents/skills/component-refactoring/SKILL.md
@@ -187,7 +187,7 @@ const Template = useMemo(() => {
**When**: Component directly handles API calls, data transformation, or complex async operations.
-**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks. Project is migrating from SWR to React Query.
+**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks.
```typescript
// ❌ Before: API logic in component
diff --git a/.claude/skills/component-refactoring/references/complexity-patterns.md b/.agents/skills/component-refactoring/references/complexity-patterns.md
similarity index 100%
rename from .claude/skills/component-refactoring/references/complexity-patterns.md
rename to .agents/skills/component-refactoring/references/complexity-patterns.md
diff --git a/.claude/skills/component-refactoring/references/component-splitting.md b/.agents/skills/component-refactoring/references/component-splitting.md
similarity index 100%
rename from .claude/skills/component-refactoring/references/component-splitting.md
rename to .agents/skills/component-refactoring/references/component-splitting.md
diff --git a/.claude/skills/component-refactoring/references/hook-extraction.md b/.agents/skills/component-refactoring/references/hook-extraction.md
similarity index 100%
rename from .claude/skills/component-refactoring/references/hook-extraction.md
rename to .agents/skills/component-refactoring/references/hook-extraction.md
diff --git a/.agents/skills/frontend-code-review/SKILL.md b/.agents/skills/frontend-code-review/SKILL.md
new file mode 100644
index 0000000000..6cc23ca171
--- /dev/null
+++ b/.agents/skills/frontend-code-review/SKILL.md
@@ -0,0 +1,73 @@
+---
+name: frontend-code-review
+description: "Trigger when the user requests a review of frontend files (e.g., `.tsx`, `.ts`, `.js`). Support both pending-change reviews and focused file reviews while applying the checklist rules."
+---
+
+# Frontend Code Review
+
+## Intent
+Use this skill whenever the user asks to review frontend code (especially `.tsx`, `.ts`, or `.js` files). Support two review modes:
+
+1. **Pending-change review** – inspect staged/working-tree files slated for commit and flag checklist violations before submission.
+2. **File-targeted review** – review the specific file(s) the user names and report the relevant checklist findings.
+
+Stick to the checklist below for every applicable file and mode.
+
+## Checklist
+See [references/code-quality.md](references/code-quality.md), [references/performance.md](references/performance.md), [references/business-logic.md](references/business-logic.md) for the living checklist split by category—treat it as the canonical set of rules to follow.
+
+Flag each rule violation with urgency metadata so future reviewers can prioritize fixes.
+
+## Review Process
+1. Open the relevant component/module. Gather lines that relate to class names, React Flow hooks, prop memoization, and styling.
+2. For each rule in the review point, note where the code deviates and capture a representative snippet.
+3. Compose the review section per the template below. Group violations first by **Urgent** flag, then by category order (Code Quality, Performance, Business Logic).
+
+## Required output
+When invoked, the response must exactly follow one of the two templates:
+
+### Template A (any findings)
+```
+# Code review
+Found urgent issues need to be fixed:
+
+## 1
+FilePath: line
+
+
+
+### Suggested fix
+
+
+---
+... (repeat for each urgent issue) ...
+
+Found suggestions for improvement:
+
+## 1
+FilePath: line
+
+
+
+### Suggested fix
+
+
+---
+
+... (repeat for each suggestion) ...
+```
+
+If there are no urgent issues, omit that section. If there are no suggestions, omit that section.
+
+If the issue number is more than 10, summarize as "10+ urgent issues" or "10+ suggestions" and just output the first 10 issues.
+
+Don't compress the blank lines between sections; keep them as-is for readability.
+
+If you use Template A (i.e., there are issues to fix) and at least one issue requires code changes, append a brief follow-up question after the structured output asking whether the user wants you to apply the suggested fix(es). For example: "Would you like me to use the Suggested fix section to address these issues?"
+
+### Template B (no issues)
+```
+## Code review
+No issues found.
+```
+
diff --git a/.agents/skills/frontend-code-review/references/business-logic.md b/.agents/skills/frontend-code-review/references/business-logic.md
new file mode 100644
index 0000000000..4584f99dfc
--- /dev/null
+++ b/.agents/skills/frontend-code-review/references/business-logic.md
@@ -0,0 +1,15 @@
+# Rule Catalog — Business Logic
+
+## Can't use workflowStore in Node components
+
+IsUrgent: True
+
+### Description
+
+File path pattern of node components: `web/app/components/workflow/nodes/[nodeName]/node.tsx`
+
+Node components are also used when creating a RAG Pipe from a template, but in that context there is no workflowStore Provider, which results in a blank screen. [This Issue](https://github.com/langgenius/dify/issues/29168) was caused by exactly this reason.
+
+### Suggested Fix
+
+Use `import { useNodes } from 'reactflow'` instead of `import useNodes from '@/app/components/workflow/store/workflow/use-nodes'`.
diff --git a/.agents/skills/frontend-code-review/references/code-quality.md b/.agents/skills/frontend-code-review/references/code-quality.md
new file mode 100644
index 0000000000..afdd40deb3
--- /dev/null
+++ b/.agents/skills/frontend-code-review/references/code-quality.md
@@ -0,0 +1,44 @@
+# Rule Catalog — Code Quality
+
+## Conditional class names use utility function
+
+IsUrgent: True
+Category: Code Quality
+
+### Description
+
+Ensure conditional CSS is handled via the shared `classNames` instead of custom ternaries, string concatenation, or template strings. Centralizing class logic keeps components consistent and easier to maintain.
+
+### Suggested Fix
+
+```ts
+import { cn } from '@/utils/classnames'
+const classNames = cn(isActive ? 'text-primary-600' : 'text-gray-500')
+```
+
+## Tailwind-first styling
+
+IsUrgent: True
+Category: Code Quality
+
+### Description
+
+Favor Tailwind CSS utility classes instead of adding new `.module.css` files unless a Tailwind combination cannot achieve the required styling. Keeping styles in Tailwind improves consistency and reduces maintenance overhead.
+
+Update this file when adding, editing, or removing Code Quality rules so the catalog remains accurate.
+
+## Classname ordering for easy overrides
+
+### Description
+
+When writing components, always place the incoming `className` prop after the component’s own class values so that downstream consumers can override or extend the styling. This keeps your component’s defaults but still lets external callers change or remove specific styles.
+
+Example:
+
+```tsx
+import { cn } from '@/utils/classnames'
+
+const Button = ({ className }) => {
+ return
+}
+```
diff --git a/.agents/skills/frontend-code-review/references/performance.md b/.agents/skills/frontend-code-review/references/performance.md
new file mode 100644
index 0000000000..2d60072f5c
--- /dev/null
+++ b/.agents/skills/frontend-code-review/references/performance.md
@@ -0,0 +1,45 @@
+# Rule Catalog — Performance
+
+## React Flow data usage
+
+IsUrgent: True
+Category: Performance
+
+### Description
+
+When rendering React Flow, prefer `useNodes`/`useEdges` for UI consumption and rely on `useStoreApi` inside callbacks that mutate or read node/edge state. Avoid manually pulling Flow data outside of these hooks.
+
+## Complex prop memoization
+
+IsUrgent: True
+Category: Performance
+
+### Description
+
+Wrap complex prop values (objects, arrays, maps) in `useMemo` prior to passing them into child components to guarantee stable references and prevent unnecessary renders.
+
+Update this file when adding, editing, or removing Performance rules so the catalog remains accurate.
+
+Wrong:
+
+```tsx
+
+```
+
+Right:
+
+```tsx
+const config = useMemo(() => ({
+ provider: ...,
+ detail: ...
+}), [provider, detail]);
+
+
+```
diff --git a/.claude/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md
similarity index 98%
rename from .claude/skills/frontend-testing/SKILL.md
rename to .agents/skills/frontend-testing/SKILL.md
index dd9677a78e..0716c81ef7 100644
--- a/.claude/skills/frontend-testing/SKILL.md
+++ b/.agents/skills/frontend-testing/SKILL.md
@@ -83,6 +83,9 @@ vi.mock('next/navigation', () => ({
usePathname: () => '/test',
}))
+// ✅ Zustand stores: Use real stores (auto-mocked globally)
+// Set test state with: useAppStore.setState({ ... })
+
// Shared state for mocks (if needed)
let mockSharedState = false
@@ -296,7 +299,7 @@ For each test file generated, aim for:
For more detailed information, refer to:
- `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
-- `references/mocking.md` - Mock patterns and best practices
+- `references/mocking.md` - Mock patterns, Zustand store testing, and best practices
- `references/async-testing.md` - Async operations and API calls
- `references/domain-components.md` - Workflow, Dataset, Configuration testing
- `references/common-patterns.md` - Frequently used testing patterns
diff --git a/.claude/skills/frontend-testing/assets/component-test.template.tsx b/.agents/skills/frontend-testing/assets/component-test.template.tsx
similarity index 97%
rename from .claude/skills/frontend-testing/assets/component-test.template.tsx
rename to .agents/skills/frontend-testing/assets/component-test.template.tsx
index c39baff916..6b7803bd4b 100644
--- a/.claude/skills/frontend-testing/assets/component-test.template.tsx
+++ b/.agents/skills/frontend-testing/assets/component-test.template.tsx
@@ -28,17 +28,14 @@ import userEvent from '@testing-library/user-event'
// i18n (automatically mocked)
// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
-// No explicit mock needed - it returns translation keys as-is
+// The global mock provides: useTranslation, Trans, useMixedTranslation, useGetLanguage
+// No explicit mock needed for most tests
+//
// Override only if custom translations are required:
-// vi.mock('react-i18next', () => ({
-// useTranslation: () => ({
-// t: (key: string) => {
-// const customTranslations: Record = {
-// 'my.custom.key': 'Custom Translation',
-// }
-// return customTranslations[key] || key
-// },
-// }),
+// import { createReactI18nextMock } from '@/test/i18n-mock'
+// vi.mock('react-i18next', () => createReactI18nextMock({
+// 'my.custom.key': 'Custom Translation',
+// 'button.save': 'Save',
// }))
// Router (if component uses useRouter, usePathname, useSearchParams)
diff --git a/.claude/skills/frontend-testing/assets/hook-test.template.ts b/.agents/skills/frontend-testing/assets/hook-test.template.ts
similarity index 100%
rename from .claude/skills/frontend-testing/assets/hook-test.template.ts
rename to .agents/skills/frontend-testing/assets/hook-test.template.ts
diff --git a/.claude/skills/frontend-testing/assets/utility-test.template.ts b/.agents/skills/frontend-testing/assets/utility-test.template.ts
similarity index 100%
rename from .claude/skills/frontend-testing/assets/utility-test.template.ts
rename to .agents/skills/frontend-testing/assets/utility-test.template.ts
diff --git a/.claude/skills/frontend-testing/references/async-testing.md b/.agents/skills/frontend-testing/references/async-testing.md
similarity index 100%
rename from .claude/skills/frontend-testing/references/async-testing.md
rename to .agents/skills/frontend-testing/references/async-testing.md
diff --git a/.claude/skills/frontend-testing/references/checklist.md b/.agents/skills/frontend-testing/references/checklist.md
similarity index 100%
rename from .claude/skills/frontend-testing/references/checklist.md
rename to .agents/skills/frontend-testing/references/checklist.md
diff --git a/.claude/skills/frontend-testing/references/common-patterns.md b/.agents/skills/frontend-testing/references/common-patterns.md
similarity index 100%
rename from .claude/skills/frontend-testing/references/common-patterns.md
rename to .agents/skills/frontend-testing/references/common-patterns.md
diff --git a/.claude/skills/frontend-testing/references/domain-components.md b/.agents/skills/frontend-testing/references/domain-components.md
similarity index 100%
rename from .claude/skills/frontend-testing/references/domain-components.md
rename to .agents/skills/frontend-testing/references/domain-components.md
diff --git a/.claude/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md
similarity index 59%
rename from .claude/skills/frontend-testing/references/mocking.md
rename to .agents/skills/frontend-testing/references/mocking.md
index 23889c8d3d..86bd375987 100644
--- a/.claude/skills/frontend-testing/references/mocking.md
+++ b/.agents/skills/frontend-testing/references/mocking.md
@@ -37,38 +37,64 @@ Only mock these categories:
1. **Third-party libraries with side effects** - `next/navigation`, external SDKs
1. **i18n** - Always mock to return keys
+### Zustand Stores - DO NOT Mock Manually
+
+**Zustand is globally mocked** in `web/vitest.setup.ts`. Use real stores with `setState()`:
+
+```typescript
+// ✅ CORRECT: Use real store, set test state
+import { useAppStore } from '@/app/components/app/store'
+
+useAppStore.setState({ appDetail: { id: 'test', name: 'Test' } })
+render()
+
+// ❌ WRONG: Don't mock the store module
+vi.mock('@/app/components/app/store', () => ({ ... }))
+```
+
+See [Zustand Store Testing](#zustand-store-testing) section for full details.
+
## Mock Placement
| Location | Purpose |
|----------|---------|
-| `web/vitest.setup.ts` | Global mocks shared by all tests (for example `react-i18next`, `next/image`) |
+| `web/vitest.setup.ts` | Global mocks shared by all tests (`react-i18next`, `next/image`, `zustand`) |
+| `web/__mocks__/zustand.ts` | Zustand mock implementation (auto-resets stores after each test) |
| `web/__mocks__/` | Reusable mock factories shared across multiple test files |
| Test file | Test-specific mocks, inline with `vi.mock()` |
Modules are not mocked automatically. Use `vi.mock` in test files, or add global mocks in `web/vitest.setup.ts`.
+**Note**: Zustand is special - it's globally mocked but you should NOT mock store modules manually. See [Zustand Store Testing](#zustand-store-testing).
+
## Essential Mocks
### 1. i18n (Auto-loaded via Global Mock)
A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
-**No explicit mock needed** for most tests - it returns translation keys as-is.
-For tests requiring custom translations, override the mock:
+The global mock provides:
+
+- `useTranslation` - returns translation keys with namespace prefix
+- `Trans` component - renders i18nKey and components
+- `useMixedTranslation` (from `@/app/components/plugins/marketplace/hooks`)
+- `useGetLanguage` (from `@/context/i18n`) - returns `'en-US'`
+
+**Default behavior**: Most tests should use the global mock (no local override needed).
+
+**For custom translations**: Use the helper function from `@/test/i18n-mock`:
```typescript
-vi.mock('react-i18next', () => ({
- useTranslation: () => ({
- t: (key: string) => {
- const translations: Record = {
- 'my.custom.key': 'Custom translation',
- }
- return translations[key] || key
- },
- }),
+import { createReactI18nextMock } from '@/test/i18n-mock'
+
+vi.mock('react-i18next', () => createReactI18nextMock({
+ 'my.custom.key': 'Custom translation',
+ 'button.save': 'Save',
}))
```
+**Avoid**: Manually defining `useTranslation` mocks that just return the key - the global mock already does this.
+
### 2. Next.js Router
```typescript
@@ -270,6 +296,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
1. **Use real base components** - Import from `@/app/components/base/` directly
1. **Use real project components** - Prefer importing over mocking
+1. **Use real Zustand stores** - Set test state via `store.setState()`
1. **Reset mocks in `beforeEach`**, not `afterEach`
1. **Match actual component behavior** in mocks (when mocking is necessary)
1. **Use factory functions** for complex mock data
@@ -279,6 +306,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
### ❌ DON'T
1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
+1. **Don't mock Zustand store modules** - Use real stores with `setState()`
1. Don't mock components you can import directly
1. Don't create overly simplified mocks that miss conditional logic
1. Don't forget to clean up nock after each test
@@ -302,10 +330,151 @@ Need to use a component in test?
├─ Is it a third-party lib with side effects?
│ └─ YES → Mock it (next/navigation, external SDKs)
│
+├─ Is it a Zustand store?
+│ └─ YES → DO NOT mock the module!
+│ Use real store + setState() to set test state
+│ (Global mock handles auto-reset)
+│
└─ Is it i18n?
└─ YES → Uses shared mock (auto-loaded). Override only for custom translations
```
+## Zustand Store Testing
+
+### Global Zustand Mock (Auto-loaded)
+
+Zustand is globally mocked in `web/vitest.setup.ts` following the [official Zustand testing guide](https://zustand.docs.pmnd.rs/guides/testing). The mock in `web/__mocks__/zustand.ts` provides:
+
+- Real store behavior with `getState()`, `setState()`, `subscribe()` methods
+- Automatic store reset after each test via `afterEach`
+- Proper test isolation between tests
+
+### ✅ Recommended: Use Real Stores (Official Best Practice)
+
+**DO NOT mock store modules manually.** Import and use the real store, then use `setState()` to set test state:
+
+```typescript
+// ✅ CORRECT: Use real store with setState
+import { useAppStore } from '@/app/components/app/store'
+
+describe('MyComponent', () => {
+ it('should render app details', () => {
+ // Arrange: Set test state via setState
+ useAppStore.setState({
+ appDetail: {
+ id: 'test-app',
+ name: 'Test App',
+ mode: 'chat',
+ },
+ })
+
+ // Act
+ render()
+
+ // Assert
+ expect(screen.getByText('Test App')).toBeInTheDocument()
+ // Can also verify store state directly
+ expect(useAppStore.getState().appDetail?.name).toBe('Test App')
+ })
+
+ // No cleanup needed - global mock auto-resets after each test
+})
+```
+
+### ❌ Avoid: Manual Store Module Mocking
+
+Manual mocking conflicts with the global Zustand mock and loses store functionality:
+
+```typescript
+// ❌ WRONG: Don't mock the store module
+vi.mock('@/app/components/app/store', () => ({
+ useStore: (selector) => mockSelector(selector), // Missing getState, setState!
+}))
+
+// ❌ WRONG: This conflicts with global zustand mock
+vi.mock('@/app/components/workflow/store', () => ({
+ useWorkflowStore: vi.fn(() => mockState),
+}))
+```
+
+**Problems with manual mocking:**
+
+1. Loses `getState()`, `setState()`, `subscribe()` methods
+1. Conflicts with global Zustand mock behavior
+1. Requires manual maintenance of store API
+1. Tests don't reflect actual store behavior
+
+### When Manual Store Mocking is Necessary
+
+In rare cases where the store has complex initialization or side effects, you can mock it, but ensure you provide the full store API:
+
+```typescript
+// If you MUST mock (rare), include full store API
+const mockStore = {
+ appDetail: { id: 'test', name: 'Test' },
+ setAppDetail: vi.fn(),
+}
+
+vi.mock('@/app/components/app/store', () => ({
+ useStore: Object.assign(
+ (selector: (state: typeof mockStore) => unknown) => selector(mockStore),
+ {
+ getState: () => mockStore,
+ setState: vi.fn(),
+ subscribe: vi.fn(),
+ },
+ ),
+}))
+```
+
+### Store Testing Decision Tree
+
+```
+Need to test a component using Zustand store?
+│
+├─ Can you use the real store?
+│ └─ YES → Use real store + setState (RECOMMENDED)
+│ useAppStore.setState({ ... })
+│
+├─ Does the store have complex initialization/side effects?
+│ └─ YES → Consider mocking, but include full API
+│ (getState, setState, subscribe)
+│
+└─ Are you testing the store itself (not a component)?
+ └─ YES → Test store directly with getState/setState
+ const store = useMyStore
+ store.setState({ count: 0 })
+ store.getState().increment()
+ expect(store.getState().count).toBe(1)
+```
+
+### Example: Testing Store Actions
+
+```typescript
+import { useCounterStore } from '@/stores/counter'
+
+describe('Counter Store', () => {
+ it('should increment count', () => {
+ // Initial state (auto-reset by global mock)
+ expect(useCounterStore.getState().count).toBe(0)
+
+ // Call action
+ useCounterStore.getState().increment()
+
+ // Verify state change
+ expect(useCounterStore.getState().count).toBe(1)
+ })
+
+ it('should reset to initial state', () => {
+ // Set some state
+ useCounterStore.setState({ count: 100 })
+ expect(useCounterStore.getState().count).toBe(100)
+
+ // After this test, global mock will reset to initial state
+ })
+})
+```
+
## Factory Function Pattern
```typescript
diff --git a/.claude/skills/frontend-testing/references/workflow.md b/.agents/skills/frontend-testing/references/workflow.md
similarity index 100%
rename from .claude/skills/frontend-testing/references/workflow.md
rename to .agents/skills/frontend-testing/references/workflow.md
diff --git a/.agents/skills/orpc-contract-first/SKILL.md b/.agents/skills/orpc-contract-first/SKILL.md
new file mode 100644
index 0000000000..4e3bfc7a37
--- /dev/null
+++ b/.agents/skills/orpc-contract-first/SKILL.md
@@ -0,0 +1,46 @@
+---
+name: orpc-contract-first
+description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Triggers when creating new API contracts, adding service endpoints, integrating TanStack Query with typed contracts, or migrating legacy service calls to oRPC. Use for all API layer work in web/contract and web/service directories.
+---
+
+# oRPC Contract-First Development
+
+## Project Structure
+
+```
+web/contract/
+├── base.ts # Base contract (inputStructure: 'detailed')
+├── router.ts # Router composition & type exports
+├── marketplace.ts # Marketplace contracts
+└── console/ # Console contracts by domain
+ ├── system.ts
+ └── billing.ts
+```
+
+## Workflow
+
+1. **Create contract** in `web/contract/console/{domain}.ts`
+ - Import `base` from `../base` and `type` from `@orpc/contract`
+ - Define route with `path`, `method`, `input`, `output`
+
+2. **Register in router** at `web/contract/router.ts`
+ - Import directly from domain file (no barrel files)
+ - Nest by API prefix: `billing: { invoices, bindPartnerStack }`
+
+3. **Create hooks** in `web/service/use-{domain}.ts`
+ - Use `consoleQuery.{group}.{contract}.queryKey()` for query keys
+ - Use `consoleClient.{group}.{contract}()` for API calls
+
+## Key Rules
+
+- **Input structure**: Always use `{ params, query?, body? }` format
+- **Path params**: Use `{paramName}` in path, match in `params` object
+- **Router nesting**: Group by API prefix (e.g., `/billing/*` → `billing: {}`)
+- **No barrel files**: Import directly from specific files
+- **Types**: Import from `@/types/`, use `type()` helper
+
+## Type Export
+
+```typescript
+export type ConsoleInputs = InferContractRouterInputs
+```
diff --git a/.claude/settings.json b/.claude/settings.json
new file mode 100644
index 0000000000..fe108722be
--- /dev/null
+++ b/.claude/settings.json
@@ -0,0 +1,15 @@
+{
+ "hooks": {
+ "PreToolUse": [
+ {
+ "matcher": "Bash",
+ "hooks": [
+ {
+ "type": "command",
+ "command": "npx -y block-no-verify@1.1.1"
+ }
+ ]
+ }
+ ]
+ }
+}
diff --git a/.claude/settings.json.example b/.claude/settings.json.example
deleted file mode 100644
index 1149895340..0000000000
--- a/.claude/settings.json.example
+++ /dev/null
@@ -1,19 +0,0 @@
-{
- "permissions": {
- "allow": [],
- "deny": []
- },
- "env": {
- "__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
- "GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
- },
- "enabledMcpjsonServers": [
- "context7",
- "sequential-thinking",
- "github",
- "fetch",
- "playwright",
- "ide"
- ],
- "enableAllProjectMcpServers": true
- }
\ No newline at end of file
diff --git a/.claude/skills/component-refactoring b/.claude/skills/component-refactoring
new file mode 120000
index 0000000000..53ae67e2f2
--- /dev/null
+++ b/.claude/skills/component-refactoring
@@ -0,0 +1 @@
+../../.agents/skills/component-refactoring
\ No newline at end of file
diff --git a/.claude/skills/frontend-code-review b/.claude/skills/frontend-code-review
new file mode 120000
index 0000000000..55654ffbd7
--- /dev/null
+++ b/.claude/skills/frontend-code-review
@@ -0,0 +1 @@
+../../.agents/skills/frontend-code-review
\ No newline at end of file
diff --git a/.claude/skills/frontend-testing b/.claude/skills/frontend-testing
new file mode 120000
index 0000000000..092cec7745
--- /dev/null
+++ b/.claude/skills/frontend-testing
@@ -0,0 +1 @@
+../../.agents/skills/frontend-testing
\ No newline at end of file
diff --git a/.claude/skills/orpc-contract-first b/.claude/skills/orpc-contract-first
new file mode 120000
index 0000000000..da47b335c7
--- /dev/null
+++ b/.claude/skills/orpc-contract-first
@@ -0,0 +1 @@
+../../.agents/skills/orpc-contract-first
\ No newline at end of file
diff --git a/.codex/skills b/.codex/skills
deleted file mode 120000
index 454b8427cd..0000000000
--- a/.codex/skills
+++ /dev/null
@@ -1 +0,0 @@
-../.claude/skills
\ No newline at end of file
diff --git a/.codex/skills/component-refactoring b/.codex/skills/component-refactoring
new file mode 120000
index 0000000000..53ae67e2f2
--- /dev/null
+++ b/.codex/skills/component-refactoring
@@ -0,0 +1 @@
+../../.agents/skills/component-refactoring
\ No newline at end of file
diff --git a/.codex/skills/frontend-code-review b/.codex/skills/frontend-code-review
new file mode 120000
index 0000000000..55654ffbd7
--- /dev/null
+++ b/.codex/skills/frontend-code-review
@@ -0,0 +1 @@
+../../.agents/skills/frontend-code-review
\ No newline at end of file
diff --git a/.codex/skills/frontend-testing b/.codex/skills/frontend-testing
new file mode 120000
index 0000000000..092cec7745
--- /dev/null
+++ b/.codex/skills/frontend-testing
@@ -0,0 +1 @@
+../../.agents/skills/frontend-testing
\ No newline at end of file
diff --git a/.codex/skills/orpc-contract-first b/.codex/skills/orpc-contract-first
new file mode 120000
index 0000000000..da47b335c7
--- /dev/null
+++ b/.codex/skills/orpc-contract-first
@@ -0,0 +1 @@
+../../.agents/skills/orpc-contract-first
\ No newline at end of file
diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh
index 220f77e5ce..637593b9de 100755
--- a/.devcontainer/post_create_command.sh
+++ b/.devcontainer/post_create_command.sh
@@ -8,7 +8,7 @@ pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
-echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
+echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc
diff --git a/.github/labeler.yml b/.github/labeler.yml
new file mode 100644
index 0000000000..d1d324d381
--- /dev/null
+++ b/.github/labeler.yml
@@ -0,0 +1,3 @@
+web:
+ - changed-files:
+ - any-glob-to-any-file: 'web/**'
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index aa5a50918a..50dbde2aee 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -20,4 +20,4 @@
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [x] I've updated the documentation accordingly.
-- [x] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
+- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods
diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml
index 76cbf64fca..190e00d9fe 100644
--- a/.github/workflows/api-tests.yml
+++ b/.github/workflows/api-tests.yml
@@ -22,12 +22,12 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@v6
+ uses: astral-sh/setup-uv@v7
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -39,12 +39,6 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- - name: Run pyrefly check
- run: |
- cd api
- uv add --dev pyrefly
- uv run pyrefly check || true
-
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py
@@ -57,7 +51,7 @@ jobs:
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Sandbox
- uses: hoverkraft-tech/compose-action@v2.0.2
+ uses: hoverkraft-tech/compose-action@v2
with:
compose-file: |
docker/docker-compose.middleware.yaml
diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml
index 97027c2218..4a8c61e7d2 100644
--- a/.github/workflows/autofix.yml
+++ b/.github/workflows/autofix.yml
@@ -12,22 +12,22 @@ jobs:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
- name: Check Docker Compose inputs
id: docker-compose-changes
- uses: tj-actions/changed-files@v46
+ uses: tj-actions/changed-files@v47
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- - uses: actions/setup-python@v5
+ - uses: actions/setup-python@v6
with:
python-version: "3.11"
- - uses: astral-sh/setup-uv@v6
+ - uses: astral-sh/setup-uv@v7
- name: Generate Docker Compose
if: steps.docker-compose-changes.outputs.any_changed == 'true'
@@ -79,9 +79,32 @@ jobs:
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
find . -name "*.py.bak" -type f -delete
+ - name: Install pnpm
+ uses: pnpm/action-setup@v4
+ with:
+ package_json_file: web/package.json
+ run_install: false
+
+ - name: Setup Node.js
+ uses: actions/setup-node@v6
+ with:
+ node-version: 24
+ cache: pnpm
+ cache-dependency-path: ./web/pnpm-lock.yaml
+
+ - name: Install web dependencies
+ run: |
+ cd web
+ pnpm install --frozen-lockfile
+
+ - name: ESLint autofix
+ run: |
+ cd web
+ pnpm lint:fix || true
+
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
- name: mdformat
run: |
- uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"
+ uvx --python 3.13 mdformat . --exclude ".agents/skills/**"
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml
index 44c9a92ab7..ac7f3a6b48 100644
--- a/.github/workflows/build-push.yml
+++ b/.github/workflows/build-push.yml
@@ -91,7 +91,7 @@ jobs:
touch "/tmp/digests/${sanitized_digest}"
- name: Upload digest
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
@@ -113,7 +113,7 @@ jobs:
context: "web"
steps:
- name: Download digests
- uses: actions/download-artifact@v4
+ uses: actions/download-artifact@v7
with:
path: /tmp/digests
pattern: digests-${{ matrix.context }}-*
diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml
index 101d973466..e20cf9850b 100644
--- a/.github/workflows/db-migration-test.yml
+++ b/.github/workflows/db-migration-test.yml
@@ -13,13 +13,13 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@v6
+ uses: astral-sh/setup-uv@v7
with:
enable-cache: true
python-version: "3.12"
@@ -63,13 +63,13 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@v6
+ uses: astral-sh/setup-uv@v7
with:
enable-cache: true
python-version: "3.12"
diff --git a/.github/workflows/deploy-trigger-dev.yml b/.github/workflows/deploy-agent-dev.yml
similarity index 69%
rename from .github/workflows/deploy-trigger-dev.yml
rename to .github/workflows/deploy-agent-dev.yml
index 2d9a904fc5..dd759f7ba5 100644
--- a/.github/workflows/deploy-trigger-dev.yml
+++ b/.github/workflows/deploy-agent-dev.yml
@@ -1,4 +1,4 @@
-name: Deploy Trigger Dev
+name: Deploy Agent Dev
permissions:
contents: read
@@ -7,7 +7,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- - "deploy/trigger-dev"
+ - "deploy/agent-dev"
types:
- completed
@@ -16,12 +16,12 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
- github.event.workflow_run.head_branch == 'deploy/trigger-dev'
+ github.event.workflow_run.head_branch == 'deploy/agent-dev'
steps:
- name: Deploy to server
- uses: appleboy/ssh-action@v0.1.8
+ uses: appleboy/ssh-action@v1
with:
- host: ${{ secrets.TRIGGER_SSH_HOST }}
+ host: ${{ secrets.AGENT_DEV_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |
diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml
index cd1c86e668..38fa0b9a7f 100644
--- a/.github/workflows/deploy-dev.yml
+++ b/.github/workflows/deploy-dev.yml
@@ -16,7 +16,7 @@ jobs:
github.event.workflow_run.head_branch == 'deploy/dev'
steps:
- name: Deploy to server
- uses: appleboy/ssh-action@v0.1.8
+ uses: appleboy/ssh-action@v1
with:
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}
diff --git a/.github/workflows/deploy-hitl.yml b/.github/workflows/deploy-hitl.yml
new file mode 100644
index 0000000000..7d5f0a22e7
--- /dev/null
+++ b/.github/workflows/deploy-hitl.yml
@@ -0,0 +1,29 @@
+name: Deploy HITL
+
+on:
+ workflow_run:
+ workflows: ["Build and Push API & Web"]
+ branches:
+ - "feat/hitl-frontend"
+ - "feat/hitl-backend"
+ types:
+ - completed
+
+jobs:
+ deploy:
+ runs-on: ubuntu-latest
+ if: |
+ github.event.workflow_run.conclusion == 'success' &&
+ (
+ github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
+ github.event.workflow_run.head_branch == 'feat/hitl-backend'
+ )
+ steps:
+ - name: Deploy to server
+ uses: appleboy/ssh-action@v1
+ with:
+ host: ${{ secrets.HITL_SSH_HOST }}
+ username: ${{ secrets.SSH_USER }}
+ key: ${{ secrets.SSH_PRIVATE_KEY }}
+ script: |
+ ${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
new file mode 100644
index 0000000000..06782b53c1
--- /dev/null
+++ b/.github/workflows/labeler.yml
@@ -0,0 +1,14 @@
+name: "Pull Request Labeler"
+on:
+ pull_request_target:
+
+jobs:
+ labeler:
+ permissions:
+ contents: read
+ pull-requests: write
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/labeler@v6
+ with:
+ sync-labels: true
diff --git a/.github/workflows/main-ci.yml b/.github/workflows/main-ci.yml
index 876ec23a3d..d6653de950 100644
--- a/.github/workflows/main-ci.yml
+++ b/.github/workflows/main-ci.yml
@@ -27,7 +27,7 @@ jobs:
vdb-changed: ${{ steps.changes.outputs.vdb }}
migration-changed: ${{ steps.changes.outputs.migration }}
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
- uses: dorny/paths-filter@v3
id: changes
with:
@@ -38,6 +38,7 @@ jobs:
- '.github/workflows/api-tests.yml'
web:
- 'web/**'
+ - '.github/workflows/web-tests.yml'
vdb:
- 'api/core/rag/datasource/**'
- 'docker/**'
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index 1870b1f670..b6df1d7e93 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- - uses: actions/stale@v5
+ - uses: actions/stale@v10
with:
days-before-issue-stale: 15
days-before-issue-close: 3
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index 8710f422fc..fdc05d1d65 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -19,13 +19,13 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
persist-credentials: false
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v46
+ uses: tj-actions/changed-files@v47
with:
files: |
api/**
@@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
- uses: astral-sh/setup-uv@v6
+ uses: astral-sh/setup-uv@v7
with:
enable-cache: false
python-version: "3.12"
@@ -65,18 +65,23 @@ jobs:
defaults:
run:
working-directory: ./web
+ permissions:
+ checks: write
+ pull-requests: read
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
persist-credentials: false
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v46
+ uses: tj-actions/changed-files@v47
with:
- files: web/**
+ files: |
+ web/**
+ .github/workflows/style.yml
- name: Install pnpm
uses: pnpm/action-setup@v4
@@ -85,10 +90,10 @@ jobs:
run_install: false
- name: Setup NodeJS
- uses: actions/setup-node@v4
+ uses: actions/setup-node@v6
if: steps.changed-files.outputs.any_changed == 'true'
with:
- node-version: 22
+ node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
@@ -101,12 +106,31 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: |
- pnpm run lint
+ pnpm run lint:ci
+ # pnpm run lint:report
+ # continue-on-error: true
+
+ # - name: Annotate Code
+ # if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request'
+ # uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae
+ # with:
+ # eslint-report: web/eslint_report.json
+ # github-token: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Web tsslint
+ if: steps.changed-files.outputs.any_changed == 'true'
+ working-directory: ./web
+ run: pnpm run lint:tss
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
- run: pnpm run type-check:tsgo
+ run: pnpm run type-check
+
+ - name: Web dead code check
+ if: steps.changed-files.outputs.any_changed == 'true'
+ working-directory: ./web
+ run: pnpm run knip
superlinter:
name: SuperLinter
@@ -114,14 +138,14 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v46
+ uses: tj-actions/changed-files@v47
with:
files: |
**.sh
diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml
index b1ccd7417a..ec392cb3b2 100644
--- a/.github/workflows/tool-test-sdks.yaml
+++ b/.github/workflows/tool-test-sdks.yaml
@@ -16,23 +16,19 @@ jobs:
name: unit test for Node.js SDK
runs-on: ubuntu-latest
- strategy:
- matrix:
- node-version: [16, 18, 20, 22]
-
defaults:
run:
working-directory: sdks/nodejs-client
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
with:
persist-credentials: false
- - name: Use Node.js ${{ matrix.node-version }}
- uses: actions/setup-node@v4
+ - name: Use Node.js
+ uses: actions/setup-node@v6
with:
- node-version: ${{ matrix.node-version }}
+ node-version: 24
cache: ''
cache-dependency-path: 'pnpm-lock.yaml'
diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml
deleted file mode 100644
index 87e24a4f90..0000000000
--- a/.github/workflows/translate-i18n-base-on-english.yml
+++ /dev/null
@@ -1,85 +0,0 @@
-name: Translate i18n Files Based on English
-
-on:
- push:
- branches: [main]
- paths:
- - 'web/i18n/en-US/*.ts'
-
-permissions:
- contents: write
- pull-requests: write
-
-jobs:
- check-and-update:
- if: github.repository == 'langgenius/dify'
- runs-on: ubuntu-latest
- defaults:
- run:
- working-directory: web
- steps:
- - uses: actions/checkout@v4
- with:
- fetch-depth: 0
- token: ${{ secrets.GITHUB_TOKEN }}
-
- - name: Check for file changes in i18n/en-US
- id: check_files
- run: |
- git fetch origin "${{ github.event.before }}" || true
- git fetch origin "${{ github.sha }}" || true
- changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.ts')
- echo "Changed files: $changed_files"
- if [ -n "$changed_files" ]; then
- echo "FILES_CHANGED=true" >> $GITHUB_ENV
- file_args=""
- for file in $changed_files; do
- filename=$(basename "$file" .ts)
- file_args="$file_args --file $filename"
- done
- echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
- echo "File arguments: $file_args"
- else
- echo "FILES_CHANGED=false" >> $GITHUB_ENV
- fi
-
- - name: Install pnpm
- uses: pnpm/action-setup@v4
- with:
- package_json_file: web/package.json
- run_install: false
-
- - name: Set up Node.js
- if: env.FILES_CHANGED == 'true'
- uses: actions/setup-node@v4
- with:
- node-version: 'lts/*'
- cache: pnpm
- cache-dependency-path: ./web/pnpm-lock.yaml
-
- - name: Install dependencies
- if: env.FILES_CHANGED == 'true'
- working-directory: ./web
- run: pnpm install --frozen-lockfile
-
- - name: Generate i18n translations
- if: env.FILES_CHANGED == 'true'
- working-directory: ./web
- run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
-
- - name: Create Pull Request
- if: env.FILES_CHANGED == 'true'
- uses: peter-evans/create-pull-request@v6
- with:
- token: ${{ secrets.GITHUB_TOKEN }}
- commit-message: 'chore(i18n): update translations based on en-US changes'
- title: 'chore(i18n): translate i18n files based on en-US changes'
- body: |
- This PR was automatically created to update i18n translation files based on changes in en-US locale.
-
- **Triggered by:** ${{ github.sha }}
-
- **Changes included:**
- - Updated translation files for all locales
- branch: chore/automated-i18n-updates-${{ github.sha }}
- delete-branch: true
diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml
new file mode 100644
index 0000000000..5d9440ff35
--- /dev/null
+++ b/.github/workflows/translate-i18n-claude.yml
@@ -0,0 +1,440 @@
+name: Translate i18n Files with Claude Code
+
+# Note: claude-code-action doesn't support push events directly.
+# Push events are handled by trigger-i18n-sync.yml which sends repository_dispatch.
+# See: https://github.com/langgenius/dify/issues/30743
+
+on:
+ repository_dispatch:
+ types: [i18n-sync]
+ workflow_dispatch:
+ inputs:
+ files:
+ description: 'Specific files to translate (space-separated, e.g., "app common"). Leave empty for all files.'
+ required: false
+ type: string
+ languages:
+ description: 'Specific languages to translate (space-separated, e.g., "zh-Hans ja-JP"). Leave empty for all supported languages.'
+ required: false
+ type: string
+ mode:
+ description: 'Sync mode: incremental (only changes) or full (re-check all keys)'
+ required: false
+ default: 'incremental'
+ type: choice
+ options:
+ - incremental
+ - full
+
+permissions:
+ contents: write
+ pull-requests: write
+
+jobs:
+ translate:
+ if: github.repository == 'langgenius/dify'
+ runs-on: ubuntu-latest
+ timeout-minutes: 60
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v6
+ with:
+ fetch-depth: 0
+ token: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Configure Git
+ run: |
+ git config --global user.name "github-actions[bot]"
+ git config --global user.email "github-actions[bot]@users.noreply.github.com"
+
+ - name: Install pnpm
+ uses: pnpm/action-setup@v4
+ with:
+ package_json_file: web/package.json
+ run_install: false
+
+ - name: Set up Node.js
+ uses: actions/setup-node@v6
+ with:
+ node-version: 24
+ cache: pnpm
+ cache-dependency-path: ./web/pnpm-lock.yaml
+
+ - name: Detect changed files and generate diff
+ id: detect_changes
+ run: |
+ if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
+ # Manual trigger
+ if [ -n "${{ github.event.inputs.files }}" ]; then
+ echo "CHANGED_FILES=${{ github.event.inputs.files }}" >> $GITHUB_OUTPUT
+ else
+ # Get all JSON files in en-US directory
+ files=$(ls web/i18n/en-US/*.json 2>/dev/null | xargs -n1 basename | sed 's/.json$//' | tr '\n' ' ')
+ echo "CHANGED_FILES=$files" >> $GITHUB_OUTPUT
+ fi
+ echo "TARGET_LANGS=${{ github.event.inputs.languages }}" >> $GITHUB_OUTPUT
+ echo "SYNC_MODE=${{ github.event.inputs.mode || 'incremental' }}" >> $GITHUB_OUTPUT
+
+ # For manual trigger with incremental mode, get diff from last commit
+ # For full mode, we'll do a complete check anyway
+ if [ "${{ github.event.inputs.mode }}" == "full" ]; then
+ echo "Full mode: will check all keys" > /tmp/i18n-diff.txt
+ echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
+ else
+ git diff HEAD~1..HEAD -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
+ if [ -s /tmp/i18n-diff.txt ]; then
+ echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
+ else
+ echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
+ fi
+ fi
+ elif [ "${{ github.event_name }}" == "repository_dispatch" ]; then
+ # Triggered by push via trigger-i18n-sync.yml workflow
+ # Validate required payload fields
+ if [ -z "${{ github.event.client_payload.changed_files }}" ]; then
+ echo "Error: repository_dispatch payload missing required 'changed_files' field" >&2
+ exit 1
+ fi
+ echo "CHANGED_FILES=${{ github.event.client_payload.changed_files }}" >> $GITHUB_OUTPUT
+ echo "TARGET_LANGS=" >> $GITHUB_OUTPUT
+ echo "SYNC_MODE=${{ github.event.client_payload.sync_mode || 'incremental' }}" >> $GITHUB_OUTPUT
+
+ # Decode the base64-encoded diff from the trigger workflow
+ if [ -n "${{ github.event.client_payload.diff_base64 }}" ]; then
+ if ! echo "${{ github.event.client_payload.diff_base64 }}" | base64 -d > /tmp/i18n-diff.txt 2>&1; then
+ echo "Warning: Failed to decode base64 diff payload" >&2
+ echo "" > /tmp/i18n-diff.txt
+ echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
+ elif [ -s /tmp/i18n-diff.txt ]; then
+ echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
+ else
+ echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
+ fi
+ else
+ echo "" > /tmp/i18n-diff.txt
+ echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
+ fi
+ else
+ echo "Unsupported event type: ${{ github.event_name }}"
+ exit 1
+ fi
+
+ # Truncate diff if too large (keep first 50KB)
+ if [ -f /tmp/i18n-diff.txt ]; then
+ head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt
+ mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt
+ fi
+
+ echo "Detected files: $(cat $GITHUB_OUTPUT | grep CHANGED_FILES || echo 'none')"
+
+ - name: Run Claude Code for Translation Sync
+ if: steps.detect_changes.outputs.CHANGED_FILES != ''
+ uses: anthropics/claude-code-action@v1
+ with:
+ anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
+ github_token: ${{ secrets.GITHUB_TOKEN }}
+ # Allow github-actions bot to trigger this workflow via repository_dispatch
+ # See: https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
+ allowed_bots: 'github-actions[bot]'
+ prompt: |
+ You are a professional i18n synchronization engineer for the Dify project.
+ Your task is to keep all language translations in sync with the English source (en-US).
+
+ ## CRITICAL TOOL RESTRICTIONS
+ - Use **Read** tool to read files (NOT cat or bash)
+ - Use **Edit** tool to modify JSON files (NOT node, jq, or bash scripts)
+ - Use **Bash** ONLY for: git commands, gh commands, pnpm commands
+ - Run bash commands ONE BY ONE, never combine with && or ||
+ - NEVER use `$()` command substitution - it's not supported. Split into separate commands instead.
+
+ ## WORKING DIRECTORY & ABSOLUTE PATHS
+ Claude Code sandbox working directory may vary. Always use absolute paths:
+ - For pnpm: `pnpm --dir ${{ github.workspace }}/web `
+ - For git: `git -C ${{ github.workspace }} `
+ - For gh: `gh --repo ${{ github.repository }} `
+ - For file paths: `${{ github.workspace }}/web/i18n/`
+
+ ## EFFICIENCY RULES
+ - **ONE Edit per language file** - batch all key additions into a single Edit
+ - Insert new keys at the beginning of JSON (after `{`), lint:fix will sort them
+ - Translate ALL keys for a language mentally first, then do ONE Edit
+
+ ## Context
+ - Changed/target files: ${{ steps.detect_changes.outputs.CHANGED_FILES }}
+ - Target languages (empty means all supported): ${{ steps.detect_changes.outputs.TARGET_LANGS }}
+ - Sync mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}
+ - Translation files are located in: ${{ github.workspace }}/web/i18n/{locale}/{filename}.json
+ - Language configuration is in: ${{ github.workspace }}/web/i18n-config/languages.ts
+ - Git diff is available: ${{ steps.detect_changes.outputs.DIFF_AVAILABLE }}
+
+ ## CRITICAL DESIGN: Verify First, Then Sync
+
+ You MUST follow this three-phase approach:
+
+ ═══════════════════════════════════════════════════════════════
+ ║ PHASE 1: VERIFY - Analyze and Generate Change Report ║
+ ═══════════════════════════════════════════════════════════════
+
+ ### Step 1.1: Analyze Git Diff (for incremental mode)
+ Use the Read tool to read `/tmp/i18n-diff.txt` to see the git diff.
+
+ Parse the diff to categorize changes:
+ - Lines with `+` (not `+++`): Added or modified values
+ - Lines with `-` (not `---`): Removed or old values
+ - Identify specific keys for each category:
+ * ADD: Keys that appear only in `+` lines (new keys)
+ * UPDATE: Keys that appear in both `-` and `+` lines (value changed)
+ * DELETE: Keys that appear only in `-` lines (removed keys)
+
+ ### Step 1.2: Read Language Configuration
+ Use the Read tool to read `${{ github.workspace }}/web/i18n-config/languages.ts`.
+ Extract all languages with `supported: true`.
+
+ ### Step 1.3: Run i18n:check for Each Language
+ ```bash
+ pnpm --dir ${{ github.workspace }}/web install --frozen-lockfile
+ ```
+ ```bash
+ pnpm --dir ${{ github.workspace }}/web run i18n:check
+ ```
+
+ This will report:
+ - Missing keys (need to ADD)
+ - Extra keys (need to DELETE)
+
+ ### Step 1.4: Generate Change Report
+
+ Create a structured report identifying:
+ ```
+ ╔══════════════════════════════════════════════════════════════╗
+ ║ I18N SYNC CHANGE REPORT ║
+ ╠══════════════════════════════════════════════════════════════╣
+ ║ Files to process: [list] ║
+ ║ Languages to sync: [list] ║
+ ╠══════════════════════════════════════════════════════════════╣
+ ║ ADD (New Keys): ║
+ ║ - [filename].[key]: "English value" ║
+ ║ ... ║
+ ╠══════════════════════════════════════════════════════════════╣
+ ║ UPDATE (Modified Keys - MUST re-translate): ║
+ ║ - [filename].[key]: "Old value" → "New value" ║
+ ║ ... ║
+ ╠══════════════════════════════════════════════════════════════╣
+ ║ DELETE (Extra Keys): ║
+ ║ - [language]/[filename].[key] ║
+ ║ ... ║
+ ╚══════════════════════════════════════════════════════════════╝
+ ```
+
+ **IMPORTANT**: For UPDATE detection, compare git diff to find keys where
+ the English value changed. These MUST be re-translated even if target
+ language already has a translation (it's now stale!).
+
+ ═══════════════════════════════════════════════════════════════
+ ║ PHASE 2: SYNC - Execute Changes Based on Report ║
+ ═══════════════════════════════════════════════════════════════
+
+ ### Step 2.1: Process ADD Operations (BATCH per language file)
+
+ **CRITICAL WORKFLOW for efficiency:**
+ 1. First, translate ALL new keys for ALL languages mentally
+ 2. Then, for EACH language file, do ONE Edit operation:
+ - Read the file once
+ - Insert ALL new keys at the beginning (right after the opening `{`)
+ - Don't worry about alphabetical order - lint:fix will sort them later
+
+ Example Edit (adding 3 keys to zh-Hans/app.json):
+ ```
+ old_string: '{\n "accessControl"'
+ new_string: '{\n "newKey1": "translation1",\n "newKey2": "translation2",\n "newKey3": "translation3",\n "accessControl"'
+ ```
+
+ **IMPORTANT**:
+ - ONE Edit per language file (not one Edit per key!)
+ - Always use the Edit tool. NEVER use bash scripts, node, or jq.
+
+ ### Step 2.2: Process UPDATE Operations
+
+ **IMPORTANT: Special handling for zh-Hans and ja-JP**
+ If zh-Hans or ja-JP files were ALSO modified in the same push:
+ - Run: `git -C ${{ github.workspace }} diff HEAD~1 --name-only` and check for zh-Hans or ja-JP files
+ - If found, it means someone manually translated them. Apply these rules:
+
+ 1. **Missing keys**: Still ADD them (completeness required)
+ 2. **Existing translations**: Compare with the NEW English value:
+ - If translation is **completely wrong** or **unrelated** → Update it
+ - If translation is **roughly correct** (captures the meaning) → Keep it, respect manual work
+ - When in doubt, **keep the manual translation**
+
+ Example:
+ - English changed: "Save" → "Save Changes"
+ - Manual translation: "保存更改" → Keep it (correct meaning)
+ - Manual translation: "删除" → Update it (completely wrong)
+
+ For other languages:
+ Use Edit tool to replace the old value with the new translation.
+ You can batch multiple updates in one Edit if they are adjacent.
+
+ ### Step 2.3: Process DELETE Operations
+ For extra keys reported by i18n:check:
+ - Run: `pnpm --dir ${{ github.workspace }}/web run i18n:check --auto-remove`
+ - Or manually remove from target language JSON files
+
+ ## Translation Guidelines
+
+ - PRESERVE all placeholders exactly as-is:
+ - `{{variable}}` - Mustache interpolation
+ - `${variable}` - Template literal
+ - `content` - HTML tags
+ - `_one`, `_other` - Pluralization suffixes (these are KEY suffixes, not values)
+
+ **CRITICAL: Variable names and tag names MUST stay in English - NEVER translate them**
+
+ ✅ CORRECT examples:
+ - English: "{{count}} items" → Japanese: "{{count}} 個のアイテム"
+ - English: "{{name}} updated" → Korean: "{{name}} 업데이트됨"
+ - English: "{{email}}" → Chinese: "{{email}}"
+ - English: "Marketplace" → Japanese: "マーケットプレイス"
+
+ ❌ WRONG examples (NEVER do this - will break the application):
+ - "{{count}}" → "{{カウント}}" ❌ (variable name translated to Japanese)
+ - "{{name}}" → "{{이름}}" ❌ (variable name translated to Korean)
+ - "{{email}}" → "{{邮箱}}" ❌ (variable name translated to Chinese)
+ - "" → "<メール>" ❌ (tag name translated)
+ - "" → "<自定义链接>" ❌ (component name translated)
+
+ - Use appropriate language register (formal/informal) based on existing translations
+ - Match existing translation style in each language
+ - Technical terms: check existing conventions per language
+ - For CJK languages: no spaces between characters unless necessary
+ - For RTL languages (ar-TN, fa-IR): ensure proper text handling
+
+ ## Output Format Requirements
+ - Alphabetical key ordering (if original file uses it)
+ - 2-space indentation
+ - Trailing newline at end of file
+ - Valid JSON (use proper escaping for special characters)
+
+ ═══════════════════════════════════════════════════════════════
+ ║ PHASE 3: RE-VERIFY - Confirm All Issues Resolved ║
+ ═══════════════════════════════════════════════════════════════
+
+ ### Step 3.1: Run Lint Fix (IMPORTANT!)
+ ```bash
+ pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- 'i18n/**/*.json'
+ ```
+ This ensures:
+ - JSON keys are sorted alphabetically (jsonc/sort-keys rule)
+ - Valid i18n keys (dify-i18n/valid-i18n-keys rule)
+ - No extra keys (dify-i18n/no-extra-keys rule)
+
+ ### Step 3.2: Run Final i18n Check
+ ```bash
+ pnpm --dir ${{ github.workspace }}/web run i18n:check
+ ```
+
+ ### Step 3.3: Fix Any Remaining Issues
+ If check reports issues:
+ - Go back to PHASE 2 for unresolved items
+ - Repeat until check passes
+
+ ### Step 3.4: Generate Final Summary
+ ```
+ ╔══════════════════════════════════════════════════════════════╗
+ ║ SYNC COMPLETED SUMMARY ║
+ ╠══════════════════════════════════════════════════════════════╣
+ ║ Language │ Added │ Updated │ Deleted │ Status ║
+ ╠══════════════════════════════════════════════════════════════╣
+ ║ zh-Hans │ 5 │ 2 │ 1 │ ✓ Complete ║
+ ║ ja-JP │ 5 │ 2 │ 1 │ ✓ Complete ║
+ ║ ... │ ... │ ... │ ... │ ... ║
+ ╠══════════════════════════════════════════════════════════════╣
+ ║ i18n:check │ PASSED - All keys in sync ║
+ ╚══════════════════════════════════════════════════════════════╝
+ ```
+
+ ## Mode-Specific Behavior
+
+ **SYNC_MODE = "incremental"** (default):
+ - Focus on keys identified from git diff
+ - Also check i18n:check output for any missing/extra keys
+ - Efficient for small changes
+
+ **SYNC_MODE = "full"**:
+ - Compare ALL keys between en-US and each language
+ - Run i18n:check to identify all discrepancies
+ - Use for first-time sync or fixing historical issues
+
+ ## Important Notes
+
+ 1. Always run i18n:check BEFORE and AFTER making changes
+ 2. The check script is the source of truth for missing/extra keys
+ 3. For UPDATE scenario: git diff is the source of truth for changed values
+ 4. Create a single commit with all translation changes
+ 5. If any translation fails, continue with others and report failures
+
+ ═══════════════════════════════════════════════════════════════
+ ║ PHASE 4: COMMIT AND CREATE PR ║
+ ═══════════════════════════════════════════════════════════════
+
+ After all translations are complete and verified:
+
+ ### Step 4.1: Check for changes
+ ```bash
+ git -C ${{ github.workspace }} status --porcelain
+ ```
+
+ If there are changes:
+
+ ### Step 4.2: Create a new branch and commit
+ Run these git commands ONE BY ONE (not combined with &&).
+ **IMPORTANT**: Do NOT use `$()` command substitution. Use two separate commands:
+
+ 1. First, get the timestamp:
+ ```bash
+ date +%Y%m%d-%H%M%S
+ ```
+ (Note the output, e.g., "20260115-143052")
+
+ 2. Then create branch using the timestamp value:
+ ```bash
+ git -C ${{ github.workspace }} checkout -b chore/i18n-sync-20260115-143052
+ ```
+ (Replace "20260115-143052" with the actual timestamp from step 1)
+
+ 3. Stage changes:
+ ```bash
+ git -C ${{ github.workspace }} add web/i18n/
+ ```
+
+ 4. Commit:
+ ```bash
+ git -C ${{ github.workspace }} commit -m "chore(i18n): sync translations with en-US - Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}"
+ ```
+
+ 5. Push:
+ ```bash
+ git -C ${{ github.workspace }} push origin HEAD
+ ```
+
+ ### Step 4.3: Create Pull Request
+ ```bash
+ gh pr create --repo ${{ github.repository }} --title "chore(i18n): sync translations with en-US" --body "## Summary
+
+ This PR was automatically generated to sync i18n translation files.
+
+ ### Changes
+ - Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}
+ - Files processed: ${{ steps.detect_changes.outputs.CHANGED_FILES }}
+
+ ### Verification
+ - [x] \`i18n:check\` passed
+ - [x] \`lint:fix\` applied
+
+ 🤖 Generated with Claude Code GitHub Action" --base main
+ ```
+
+ claude_args: |
+ --max-turns 150
+ --allowedTools "Read,Write,Edit,Bash(git *),Bash(git:*),Bash(gh *),Bash(gh:*),Bash(pnpm *),Bash(pnpm:*),Bash(date *),Bash(date:*),Glob,Grep"
diff --git a/.github/workflows/trigger-i18n-sync.yml b/.github/workflows/trigger-i18n-sync.yml
new file mode 100644
index 0000000000..66a29453b4
--- /dev/null
+++ b/.github/workflows/trigger-i18n-sync.yml
@@ -0,0 +1,66 @@
+name: Trigger i18n Sync on Push
+
+# This workflow bridges the push event to repository_dispatch
+# because claude-code-action doesn't support push events directly.
+# See: https://github.com/langgenius/dify/issues/30743
+
+on:
+ push:
+ branches: [main]
+ paths:
+ - 'web/i18n/en-US/*.json'
+
+permissions:
+ contents: write
+
+jobs:
+ trigger:
+ if: github.repository == 'langgenius/dify'
+ runs-on: ubuntu-latest
+ timeout-minutes: 5
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v6
+ with:
+ fetch-depth: 0
+
+ - name: Detect changed files and generate diff
+ id: detect
+ run: |
+ BEFORE_SHA="${{ github.event.before }}"
+ # Handle edge case: force push may have null/zero SHA
+ if [ -z "$BEFORE_SHA" ] || [ "$BEFORE_SHA" = "0000000000000000000000000000000000000000" ]; then
+ BEFORE_SHA="HEAD~1"
+ fi
+
+ # Detect changed i18n files
+ changed=$(git diff --name-only "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' 2>/dev/null | xargs -n1 basename 2>/dev/null | sed 's/.json$//' | tr '\n' ' ' || echo "")
+ echo "changed_files=$changed" >> $GITHUB_OUTPUT
+
+ # Generate diff for context
+ git diff "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
+
+ # Truncate if too large (keep first 50KB to match receiving workflow)
+ head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt
+ mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt
+
+ # Base64 encode the diff for safe JSON transport (portable, single-line)
+ diff_base64=$(base64 < /tmp/i18n-diff.txt | tr -d '\n')
+ echo "diff_base64=$diff_base64" >> $GITHUB_OUTPUT
+
+ if [ -n "$changed" ]; then
+ echo "has_changes=true" >> $GITHUB_OUTPUT
+ echo "Detected changed files: $changed"
+ else
+ echo "has_changes=false" >> $GITHUB_OUTPUT
+ echo "No i18n changes detected"
+ fi
+
+ - name: Trigger i18n sync workflow
+ if: steps.detect.outputs.has_changes == 'true'
+ uses: peter-evans/repository-dispatch@v3
+ with:
+ token: ${{ secrets.GITHUB_TOKEN }}
+ event-type: i18n-sync
+ client-payload: '{"changed_files": "${{ steps.detect.outputs.changed_files }}", "diff_base64": "${{ steps.detect.outputs.diff_base64 }}", "sync_mode": "incremental", "trigger_sha": "${{ github.sha }}"}'
diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml
index 291171e5c7..7735afdaca 100644
--- a/.github/workflows/vdb-tests.yml
+++ b/.github/workflows/vdb-tests.yml
@@ -19,19 +19,19 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
persist-credentials: false
- name: Free Disk Space
- uses: endersonmenezes/free-disk-space@v2
+ uses: endersonmenezes/free-disk-space@v3
with:
remove_dotnet: true
remove_haskell: true
remove_tool_cache: true
- name: Setup UV and Python
- uses: astral-sh/setup-uv@v6
+ uses: astral-sh/setup-uv@v7
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml
index 1a8925e38d..191ce56aaa 100644
--- a/.github/workflows/web-tests.yml
+++ b/.github/workflows/web-tests.yml
@@ -18,7 +18,7 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
persist-credentials: false
@@ -29,9 +29,9 @@ jobs:
run_install: false
- name: Setup Node.js
- uses: actions/setup-node@v4
+ uses: actions/setup-node@v6
with:
- node-version: 22
+ node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
@@ -360,9 +360,54 @@ jobs:
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: web-coverage-report
path: web/coverage
retention-days: 30
if-no-files-found: error
+
+ web-build:
+ name: Web Build
+ runs-on: ubuntu-latest
+ defaults:
+ run:
+ working-directory: ./web
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v6
+ with:
+ persist-credentials: false
+
+ - name: Check changed files
+ id: changed-files
+ uses: tj-actions/changed-files@v47
+ with:
+ files: |
+ web/**
+ .github/workflows/web-tests.yml
+
+ - name: Install pnpm
+ uses: pnpm/action-setup@v4
+ with:
+ package_json_file: web/package.json
+ run_install: false
+
+ - name: Setup NodeJS
+ uses: actions/setup-node@v6
+ if: steps.changed-files.outputs.any_changed == 'true'
+ with:
+ node-version: 24
+ cache: pnpm
+ cache-dependency-path: ./web/pnpm-lock.yaml
+
+ - name: Web dependencies
+ if: steps.changed-files.outputs.any_changed == 'true'
+ working-directory: ./web
+ run: pnpm install --frozen-lockfile
+
+ - name: Web build check
+ if: steps.changed-files.outputs.any_changed == 'true'
+ working-directory: ./web
+ run: pnpm run build
diff --git a/.gitignore b/.gitignore
index 17a2bd5b7b..7bd919f095 100644
--- a/.gitignore
+++ b/.gitignore
@@ -235,3 +235,4 @@ scripts/stress-test/reports/
# settings
*.local.json
+*.local.md
diff --git a/.mcp.json b/.mcp.json
deleted file mode 100644
index 8eceaf9ead..0000000000
--- a/.mcp.json
+++ /dev/null
@@ -1,34 +0,0 @@
-{
- "mcpServers": {
- "context7": {
- "type": "http",
- "url": "https://mcp.context7.com/mcp"
- },
- "sequential-thinking": {
- "type": "stdio",
- "command": "npx",
- "args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
- "env": {}
- },
- "github": {
- "type": "stdio",
- "command": "npx",
- "args": ["-y", "@modelcontextprotocol/server-github"],
- "env": {
- "GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
- }
- },
- "fetch": {
- "type": "stdio",
- "command": "uvx",
- "args": ["mcp-server-fetch"],
- "env": {}
- },
- "playwright": {
- "type": "stdio",
- "command": "npx",
- "args": ["-y", "@playwright/mcp@latest"],
- "env": {}
- }
- }
- }
\ No newline at end of file
diff --git a/.nvmrc b/.nvmrc
deleted file mode 100644
index 7af24b7ddb..0000000000
--- a/.nvmrc
+++ /dev/null
@@ -1 +0,0 @@
-22.11.0
diff --git a/AGENTS.md b/AGENTS.md
index 782861ad36..7d96ac3a6d 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -12,12 +12,8 @@ The codebase is split into:
## Backend Workflow
+- Read `api/AGENTS.md` for details
- Run backend CLI commands through `uv run --project api `.
-
-- Before submission, all backend modifications must pass local checks: `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
-
-- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.
-
- Integration tests are CI-only and are not expected to run in the local environment.
## Frontend Workflow
@@ -29,6 +25,30 @@ pnpm type-check:tsgo
pnpm test
```
+### Frontend Linting
+
+ESLint is used for frontend code quality. Available commands:
+
+```bash
+# Lint all files (report only)
+pnpm lint
+
+# Lint and auto-fix issues
+pnpm lint:fix
+
+# Lint specific files or directories
+pnpm lint:fix app/components/base/button/
+pnpm lint:fix app/components/base/button/index.tsx
+
+# Lint quietly (errors only, no warnings)
+pnpm lint:quiet
+
+# Check code complexity
+pnpm lint:complexity
+```
+
+**Important**: Always run `pnpm lint:fix` before committing. The pre-commit hook runs `lint-staged` which only lints staged files.
+
## Testing & Quality Practices
- Follow TDD: red → green → refactor.
diff --git a/Makefile b/Makefile
index 07afd8187e..e92a7b1314 100644
--- a/Makefile
+++ b/Makefile
@@ -60,9 +60,11 @@ check:
@echo "✅ Code check complete"
lint:
- @echo "🔧 Running ruff format, check with fixes, and import linter..."
- @uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
+ @echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
+ @uv run --project api --dev ruff format ./api
+ @uv run --project api --dev ruff check --fix ./api
@uv run --directory api --dev lint-imports
+ @uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
@echo "✅ Linting complete"
type-check:
@@ -72,7 +74,12 @@ type-check:
test:
@echo "🧪 Running backend unit tests..."
- @uv run --project api --dev dev/pytest/pytest_unit_tests.sh
+ @if [ -n "$(TARGET_TESTS)" ]; then \
+ echo "Target: $(TARGET_TESTS)"; \
+ uv run --project api --dev pytest $(TARGET_TESTS); \
+ else \
+ uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
+ fi
@echo "✅ Tests complete"
# Build Docker images
@@ -122,9 +129,9 @@ help:
@echo "Backend Code Quality:"
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
- @echo " make lint - Format and fix code with ruff"
+ @echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make type-check - Run type checking with basedpyright"
- @echo " make test - Run backend unit tests"
+ @echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/)"
@echo ""
@echo "Docker Build Targets:"
@echo " make build-web - Build web Docker image"
diff --git a/api/.env.example b/api/.env.example
index c195f1ae87..bf998a6cdc 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -101,6 +101,15 @@ S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key
S3_REGION=your-region
+# Workflow run and Conversation archive storage (S3-compatible)
+ARCHIVE_STORAGE_ENABLED=false
+ARCHIVE_STORAGE_ENDPOINT=
+ARCHIVE_STORAGE_ARCHIVE_BUCKET=
+ARCHIVE_STORAGE_EXPORT_BUCKET=
+ARCHIVE_STORAGE_ACCESS_KEY=
+ARCHIVE_STORAGE_SECRET_KEY=
+ARCHIVE_STORAGE_REGION=auto
+
# Azure Blob Storage configuration
AZURE_BLOB_ACCOUNT_NAME=your-account-name
AZURE_BLOB_ACCOUNT_KEY=your-account-key
@@ -128,6 +137,7 @@ TENCENT_COS_SECRET_KEY=your-secret-key
TENCENT_COS_SECRET_ID=your-secret-id
TENCENT_COS_REGION=your-region
TENCENT_COS_SCHEME=your-scheme
+TENCENT_COS_CUSTOM_DOMAIN=your-custom-domain
# Huawei OBS Storage Configuration
HUAWEI_OBS_BUCKET_NAME=your-bucket-name
@@ -407,6 +417,8 @@ SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false
+# Optional: override the local hostname used for SMTP HELO/EHLO
+SMTP_LOCAL_HOSTNAME=
# Sendgid configuration
SENDGRID_API_KEY=
# Sentry configuration
@@ -492,6 +504,8 @@ LOG_FILE_BACKUP_COUNT=5
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
# Log Timezone
LOG_TZ=UTC
+# Log output format: text or json
+LOG_OUTPUT_FORMAT=text
# Log format
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
@@ -563,6 +577,10 @@ LOGSTORE_DUAL_WRITE_ENABLED=false
# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
# Useful for migration scenarios where historical data exists only in SQL database
LOGSTORE_DUAL_READ_ENABLED=true
+# Control flag for whether to write the `graph` field to LogStore.
+# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
+# otherwise write an empty {} instead. Defaults to writing the `graph` field.
+LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
@@ -573,6 +591,7 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
ENABLE_CLEAN_MESSAGES=false
+ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
@@ -721,3 +740,7 @@ PUBSUB_REDIS_USE_CLUSTERS=false
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
# Human input timeout check interval in minutes
HUMAN_INPUT_TIMEOUT_TASK_INTERVAL=1
+
+
+SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000
+
diff --git a/api/.importlinter b/api/.importlinter
index c99c72c5e3..d7f3767442 100644
--- a/api/.importlinter
+++ b/api/.importlinter
@@ -3,9 +3,11 @@ root_packages =
core
configs
controllers
+ extensions
models
tasks
services
+include_external_packages = True
[importlinter:contract:workflow]
name = Workflow
@@ -25,7 +27,9 @@ ignore_imports =
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
- core.workflow.nodes.node_factory -> core.workflow.graph
+ core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
+ core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
+
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
@@ -35,6 +39,270 @@ ignore_imports =
# TODO(QuantumGhost): fix the import violation later
core.workflow.entities.pause_reason -> core.workflow.nodes.human_input.entities
+[importlinter:contract:workflow-infrastructure-dependencies]
+name = Workflow Infrastructure Dependencies
+type = forbidden
+source_modules =
+ core.workflow
+forbidden_modules =
+ extensions.ext_database
+ extensions.ext_redis
+allow_indirect_imports = True
+ignore_imports =
+ core.workflow.nodes.agent.agent_node -> extensions.ext_database
+ core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
+ core.workflow.nodes.llm.file_saver -> extensions.ext_database
+ core.workflow.nodes.llm.llm_utils -> extensions.ext_database
+ core.workflow.nodes.llm.node -> extensions.ext_database
+ core.workflow.nodes.tool.tool_node -> extensions.ext_database
+ core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
+ core.workflow.graph_engine.manager -> extensions.ext_redis
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
+
+[importlinter:contract:workflow-external-imports]
+name = Workflow External Imports
+type = forbidden
+source_modules =
+ core.workflow
+forbidden_modules =
+ configs
+ controllers
+ extensions
+ models
+ services
+ tasks
+ core.agent
+ core.app
+ core.base
+ core.callback_handler
+ core.datasource
+ core.db
+ core.entities
+ core.errors
+ core.extension
+ core.external_data_tool
+ core.file
+ core.helper
+ core.hosting_configuration
+ core.indexing_runner
+ core.llm_generator
+ core.logging
+ core.mcp
+ core.memory
+ core.model_manager
+ core.moderation
+ core.ops
+ core.plugin
+ core.prompt
+ core.provider_manager
+ core.rag
+ core.repositories
+ core.schemas
+ core.tools
+ core.trigger
+ core.variables
+ignore_imports =
+ core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
+ core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
+ core.workflow.workflow_entry -> core.app.workflow.layers.observability
+ core.workflow.graph_engine.worker_management.worker_pool -> configs
+ core.workflow.nodes.agent.agent_node -> core.model_manager
+ core.workflow.nodes.agent.agent_node -> core.provider_manager
+ core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
+ core.workflow.nodes.code.code_node -> core.helper.code_executor.code_executor
+ core.workflow.nodes.datasource.datasource_node -> models.model
+ core.workflow.nodes.datasource.datasource_node -> models.tools
+ core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
+ core.workflow.nodes.document_extractor.node -> configs
+ core.workflow.nodes.document_extractor.node -> core.file.file_manager
+ core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
+ core.workflow.nodes.http_request.entities -> configs
+ core.workflow.nodes.http_request.executor -> configs
+ core.workflow.nodes.http_request.executor -> core.file.file_manager
+ core.workflow.nodes.http_request.node -> configs
+ core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
+ core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.datasource.retrieval_service
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.dataset_retrieval
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> models.dataset
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> services.feature_service
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_runtime.model_providers.__base.large_language_model
+ core.workflow.nodes.llm.llm_utils -> configs
+ core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.llm.llm_utils -> core.file.models
+ core.workflow.nodes.llm.llm_utils -> core.model_manager
+ core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
+ core.workflow.nodes.llm.llm_utils -> models.model
+ core.workflow.nodes.llm.llm_utils -> models.provider
+ core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
+ core.workflow.nodes.llm.node -> core.tools.signature
+ core.workflow.nodes.template_transform.template_transform_node -> configs
+ core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
+ core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
+ core.workflow.nodes.tool.tool_node -> core.tools.tool_manager
+ core.workflow.workflow_entry -> configs
+ core.workflow.workflow_entry -> models.workflow
+ core.workflow.nodes.agent.agent_node -> core.agent.entities
+ core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities
+ core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.advanced_prompt_transform
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
+ core.workflow.nodes.start.entities -> core.app.app_config.entities
+ core.workflow.nodes.start.start_node -> core.app.app_config.entities
+ core.workflow.workflow_entry -> core.app.apps.exc
+ core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
+ core.workflow.workflow_entry -> core.app.workflow.node_factory
+ core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
+ core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.agent_entities
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.model_entities
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_manager
+ core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
+ core.workflow.node_events.node -> core.file
+ core.workflow.nodes.agent.agent_node -> core.file
+ core.workflow.nodes.datasource.datasource_node -> core.file
+ core.workflow.nodes.datasource.datasource_node -> core.file.enums
+ core.workflow.nodes.document_extractor.node -> core.file
+ core.workflow.nodes.http_request.executor -> core.file.enums
+ core.workflow.nodes.http_request.node -> core.file
+ core.workflow.nodes.http_request.node -> core.file.file_manager
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.file.models
+ core.workflow.nodes.list_operator.node -> core.file
+ core.workflow.nodes.llm.file_saver -> core.file
+ core.workflow.nodes.llm.llm_utils -> core.variables.segments
+ core.workflow.nodes.llm.node -> core.file
+ core.workflow.nodes.llm.node -> core.file.file_manager
+ core.workflow.nodes.llm.node -> core.file.models
+ core.workflow.nodes.loop.entities -> core.variables.types
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.file
+ core.workflow.nodes.protocols -> core.file
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.file.models
+ core.workflow.nodes.tool.tool_node -> core.file
+ core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
+ core.workflow.nodes.tool.tool_node -> models
+ core.workflow.nodes.trigger_webhook.node -> core.file
+ core.workflow.runtime.variable_pool -> core.file
+ core.workflow.runtime.variable_pool -> core.file.file_manager
+ core.workflow.system_variable -> core.file.models
+ core.workflow.utils.condition.processor -> core.file
+ core.workflow.utils.condition.processor -> core.file.file_manager
+ core.workflow.workflow_entry -> core.file.models
+ core.workflow.workflow_type_encoder -> core.file.models
+ core.workflow.nodes.agent.agent_node -> models.model
+ core.workflow.nodes.code.code_node -> core.helper.code_executor.code_node_provider
+ core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider
+ core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider
+ core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor
+ core.workflow.nodes.datasource.datasource_node -> core.variables.variables
+ core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy
+ core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy
+ core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy
+ core.workflow.nodes.llm.node -> core.helper.code_executor
+ core.workflow.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor
+ core.workflow.nodes.llm.node -> core.llm_generator.output_parser.errors
+ core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
+ core.workflow.nodes.llm.node -> core.model_manager
+ core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform
+ core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util
+ core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
+ core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
+ core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods
+ core.workflow.nodes.llm.node -> models.dataset
+ core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
+ core.workflow.nodes.llm.file_saver -> core.tools.signature
+ core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager
+ core.workflow.nodes.tool.tool_node -> core.tools.errors
+ core.workflow.conversation_variable_updater -> core.variables
+ core.workflow.graph_engine.entities.commands -> core.variables.variables
+ core.workflow.nodes.agent.agent_node -> core.variables.segments
+ core.workflow.nodes.answer.answer_node -> core.variables
+ core.workflow.nodes.code.code_node -> core.variables.segments
+ core.workflow.nodes.code.code_node -> core.variables.types
+ core.workflow.nodes.code.entities -> core.variables.types
+ core.workflow.nodes.datasource.datasource_node -> core.variables.segments
+ core.workflow.nodes.document_extractor.node -> core.variables
+ core.workflow.nodes.document_extractor.node -> core.variables.segments
+ core.workflow.nodes.http_request.executor -> core.variables.segments
+ core.workflow.nodes.http_request.node -> core.variables.segments
+ core.workflow.nodes.iteration.iteration_node -> core.variables
+ core.workflow.nodes.iteration.iteration_node -> core.variables.segments
+ core.workflow.nodes.iteration.iteration_node -> core.variables.variables
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables.segments
+ core.workflow.nodes.list_operator.node -> core.variables
+ core.workflow.nodes.list_operator.node -> core.variables.segments
+ core.workflow.nodes.llm.node -> core.variables
+ core.workflow.nodes.loop.loop_node -> core.variables
+ core.workflow.nodes.parameter_extractor.entities -> core.variables.types
+ core.workflow.nodes.parameter_extractor.exc -> core.variables.types
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.variables.types
+ core.workflow.nodes.tool.tool_node -> core.variables.segments
+ core.workflow.nodes.tool.tool_node -> core.variables.variables
+ core.workflow.nodes.trigger_webhook.node -> core.variables.types
+ core.workflow.nodes.trigger_webhook.node -> core.variables.variables
+ core.workflow.nodes.variable_aggregator.entities -> core.variables.types
+ core.workflow.nodes.variable_aggregator.variable_aggregator_node -> core.variables.segments
+ core.workflow.nodes.variable_assigner.common.helpers -> core.variables
+ core.workflow.nodes.variable_assigner.common.helpers -> core.variables.consts
+ core.workflow.nodes.variable_assigner.common.helpers -> core.variables.types
+ core.workflow.nodes.variable_assigner.v1.node -> core.variables
+ core.workflow.nodes.variable_assigner.v2.helpers -> core.variables
+ core.workflow.nodes.variable_assigner.v2.node -> core.variables
+ core.workflow.nodes.variable_assigner.v2.node -> core.variables.consts
+ core.workflow.runtime.graph_runtime_state_protocol -> core.variables.segments
+ core.workflow.runtime.read_only_wrappers -> core.variables.segments
+ core.workflow.runtime.variable_pool -> core.variables
+ core.workflow.runtime.variable_pool -> core.variables.consts
+ core.workflow.runtime.variable_pool -> core.variables.segments
+ core.workflow.runtime.variable_pool -> core.variables.variables
+ core.workflow.utils.condition.processor -> core.variables
+ core.workflow.utils.condition.processor -> core.variables.segments
+ core.workflow.variable_loader -> core.variables
+ core.workflow.variable_loader -> core.variables.consts
+ core.workflow.workflow_type_encoder -> core.variables
+ core.workflow.graph_engine.manager -> extensions.ext_redis
+ core.workflow.nodes.agent.agent_node -> extensions.ext_database
+ core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
+ core.workflow.nodes.llm.file_saver -> extensions.ext_database
+ core.workflow.nodes.llm.llm_utils -> extensions.ext_database
+ core.workflow.nodes.llm.node -> extensions.ext_database
+ core.workflow.nodes.tool.tool_node -> extensions.ext_database
+ core.workflow.workflow_entry -> extensions.otel.runtime
+ core.workflow.nodes.agent.agent_node -> models
+ core.workflow.nodes.base.node -> models.enums
+ core.workflow.nodes.llm.llm_utils -> models.provider_ids
+ core.workflow.nodes.llm.node -> models.model
+ core.workflow.workflow_entry -> models.enums
+ core.workflow.nodes.agent.agent_node -> services
+ core.workflow.nodes.tool.tool_node -> services
+
[importlinter:contract:rsc]
name = RSC
type = layers
diff --git a/api/.ruff.toml b/api/.ruff.toml
index 7206f7fa0f..8db0cbcb21 100644
--- a/api/.ruff.toml
+++ b/api/.ruff.toml
@@ -1,4 +1,8 @@
-exclude = ["migrations/*"]
+exclude = [
+ "migrations/*",
+ ".git",
+ ".git/**",
+]
line-length = 120
[format]
diff --git a/api/AGENTS.md b/api/AGENTS.md
index 17398ec4b8..13adb42276 100644
--- a/api/AGENTS.md
+++ b/api/AGENTS.md
@@ -1,62 +1,186 @@
-# Agent Skill Index
+# API Agent Guide
-Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it.
+## Notes for Agent (must-check)
-______________________________________________________________________
+Before changing any backend code under `api/`, you MUST read the surrounding docstrings and comments. These notes contain required context (invariants, edge cases, trade-offs) and are treated as part of the spec.
-## Platform Foundations
+Look for:
-- **[Infrastructure Overview](agent_skills/infra.md)**\
- When to read this:
+- The module (file) docstring at the top of a source code file
+- Docstrings on classes and functions/methods
+- Paragraph/block comments for non-obvious logic
- - You need to understand where a feature belongs in the architecture.
- - You’re wiring storage, Redis, vector stores, or OTEL.
- - You’re about to add CLI commands or async jobs.\
- What it covers: configuration stack (`configs/app_config.py`, remote settings), storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`), Redis conventions (`extensions/ext_redis.py`), plugin runtime topology, vector-store factory (`core/rag/datasource/vdb/*`), observability hooks, SSRF proxy usage, and core CLI commands.
+### What to write where
-- **[Coding Style](agent_skills/coding_style.md)**\
- When to read this:
+- Keep notes scoped: module notes cover module-wide context, class notes cover class-wide context, function/method notes cover behavioural contracts, and paragraph/block comments cover local “why”. Avoid duplicating the same content across scopes unless repetition prevents misuse.
+- **Module (file) docstring**: purpose, boundaries, key invariants, and “gotchas” that a new reader must know before editing.
+ - Include cross-links to the key collaborators (modules/services) when discovery is otherwise hard.
+ - Prefer stable facts (invariants, contracts) over ephemeral “today we…” notes.
+- **Class docstring**: responsibility, lifecycle, invariants, and how it should be used (or not used).
+ - If the class is intentionally stateful, note what state exists and what methods mutate it.
+ - If concurrency/async assumptions matter, state them explicitly.
+- **Function/method docstring**: behavioural contract.
+ - Document arguments, return shape, side effects (DB writes, external I/O, task dispatch), and raised domain exceptions.
+ - Add examples only when they prevent misuse.
+- **Paragraph/block comments**: explain *why* (trade-offs, historical constraints, surprising edge cases), not what the code already states.
+ - Keep comments adjacent to the logic they justify; delete or rewrite comments that no longer match reality.
- - You’re writing or reviewing backend code and need the authoritative checklist.
- - You’re unsure about Pydantic validators, SQLAlchemy session usage, or logging patterns.
- - You want the exact lint/type/test commands used in PRs.\
- Includes: Ruff & BasedPyright commands, no-annotation policy, session examples (`with Session(db.engine, ...)`), `@field_validator` usage, logging expectations, and the rule set for file size, helpers, and package management.
+### Rules (must follow)
-______________________________________________________________________
+In this section, “notes” means module/class/function docstrings plus any relevant paragraph/block comments.
-## Plugin & Extension Development
+- **Before working**
+ - Read the notes in the area you’ll touch; treat them as part of the spec.
+ - If a docstring or comment conflicts with the current code, treat the **code as the single source of truth** and update the docstring or comment to match reality.
+ - If important intent/invariants/edge cases are missing, add them in the closest docstring or comment (module for overall scope, function for behaviour).
+- **During working**
+ - Keep the notes in sync as you discover constraints, make decisions, or change approach.
+ - If you move/rename responsibilities across modules/classes, update the affected docstrings and comments so readers can still find the “why” and the invariants.
+ - Record non-obvious edge cases, trade-offs, and the test/verification plan in the nearest docstring or comment that will stay correct.
+ - Keep the notes **coherent**: integrate new findings into the relevant docstrings and comments; avoid append-only “recent fix” / changelog-style additions.
+- **When finishing**
+ - Update the notes to reflect what changed, why, and any new edge cases/tests.
+ - Remove or rewrite any comments that could be mistaken as current guidance but no longer apply.
+ - Keep docstrings and comments concise and accurate; they are meant to prevent repeated rediscovery.
-- **[Plugin Systems](agent_skills/plugin.md)**\
- When to read this:
+## Coding Style
- - You’re building or debugging a marketplace plugin.
- - You need to know how manifests, providers, daemons, and migrations fit together.\
- What it covers: plugin manifests (`core/plugin/entities/plugin.py`), installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands), runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent), daemon coordination (`core/plugin/entities/plugin_daemon.py`), and how provider registries surface capabilities to the rest of the platform.
+This is the default standard for backend code in this repo. Follow it for new code and use it as the checklist when reviewing changes.
-- **[Plugin OAuth](agent_skills/plugin_oauth.md)**\
- When to read this:
+### Linting & Formatting
- - You must integrate OAuth for a plugin or datasource.
- - You’re handling credential encryption or refresh flows.\
- Topics: credential storage, encryption helpers (`core/helper/provider_encryption.py`), OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`), and how console/API layers expose the flows.
+- Use Ruff for formatting and linting (follow `.ruff.toml`).
+- Keep each line under 120 characters (including spaces).
-______________________________________________________________________
+### Naming Conventions
-## Workflow Entry & Execution
+- Use `snake_case` for variables and functions.
+- Use `PascalCase` for classes.
+- Use `UPPER_CASE` for constants.
-- **[Trigger Concepts](agent_skills/trigger.md)**\
- When to read this:
- - You’re debugging why a workflow didn’t start.
- - You’re adding a new trigger type or hook.
- - You need to trace async execution, draft debugging, or webhook/schedule pipelines.\
- Details: Start-node taxonomy, webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`), async orchestration (`services/async_workflow_service.py`, Celery queues), debug event bus, and storage/logging interactions.
+### Typing & Class Layout
-______________________________________________________________________
+- Code should usually include type annotations that match the repo’s current Python version (avoid untyped public APIs and “mystery” values).
+- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless there’s a strong reason.
+- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
-## Additional Notes for Agents
+```python
+from datetime import datetime
-- All skill docs assume you follow the coding style guide—run Ruff/BasedPyright/tests listed there before submitting changes.
-- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`).
-- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules.
-- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`.
-- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently.
+
+class Example:
+ user_id: str
+ created_at: datetime
+
+ def __init__(self, user_id: str, created_at: datetime) -> None:
+ self.user_id = user_id
+ self.created_at = created_at
+```
+
+### General Rules
+
+- Use Pydantic v2 conventions.
+- Use `uv` for Python package management in this repo (usually with `--project api`).
+- Prefer simple functions over small “utility classes” for lightweight helpers.
+- Avoid implementing dunder methods unless it’s clearly needed and matches existing patterns.
+- Never start long-running services as part of agent work (`uv run app.py`, `flask run`, etc.); running tests is allowed.
+- Keep files below ~800 lines; split when necessary.
+- Keep code readable and explicit—avoid clever hacks.
+
+### Architecture & Boundaries
+
+- Mirror the layered architecture: controller → service → core/domain.
+- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
+- Optimise for observability: deterministic control flow, clear logging, actionable errors.
+
+### Logging & Errors
+
+- Never use `print`; use a module-level logger:
+ - `logger = logging.getLogger(__name__)`
+- Include tenant/app/workflow identifiers in log context when relevant.
+- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate them into HTTP responses in controllers.
+- Log retryable events at `warning`, terminal failures at `error`.
+
+### SQLAlchemy Patterns
+
+- Models inherit from `models.base.TypeBase`; do not create ad-hoc metadata or engines.
+- Open sessions with context managers:
+
+```python
+from sqlalchemy.orm import Session
+
+with Session(db.engine, expire_on_commit=False) as session:
+ stmt = select(Workflow).where(
+ Workflow.id == workflow_id,
+ Workflow.tenant_id == tenant_id,
+ )
+ workflow = session.execute(stmt).scalar_one_or_none()
+```
+
+- Prefer SQLAlchemy expressions; avoid raw SQL unless necessary.
+- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
+- Introduce repository abstractions only for very large tables (e.g., workflow executions) or when alternative storage strategies are required.
+
+### Storage & External I/O
+
+- Access storage via `extensions.ext_storage.storage`.
+- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
+- Background tasks that touch storage must be idempotent, and should log relevant object identifiers.
+
+### Pydantic Usage
+
+- Define DTOs with Pydantic v2 models and forbid extras by default.
+- Use `@field_validator` / `@model_validator` for domain rules.
+
+Example:
+
+```python
+from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
+
+
+class TriggerConfig(BaseModel):
+ endpoint: HttpUrl
+ secret: str
+
+ model_config = ConfigDict(extra="forbid")
+
+ @field_validator("secret")
+ def ensure_secret_prefix(cls, value: str) -> str:
+ if not value.startswith("dify_"):
+ raise ValueError("secret must start with dify_")
+ return value
+```
+
+### Generics & Protocols
+
+- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
+- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
+- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
+
+### Tooling & Checks
+
+Quick checks while iterating:
+
+- Format: `make format`
+- Lint (includes auto-fix): `make lint`
+- Type check: `make type-check`
+- Targeted tests: `make test TARGET_TESTS=./api/tests/`
+
+Before opening a PR / submitting:
+
+- `make lint`
+- `make type-check`
+- `make test`
+
+### Controllers & Services
+
+- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
+- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
+- Document non-obvious behaviour with concise docstrings and comments.
+
+### Miscellaneous
+
+- Use `configs.dify_config` for configuration—never read environment variables directly.
+- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
+- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
+- Keep experimental scripts under `dev/`; do not ship them in production builds.
diff --git a/api/Dockerfile b/api/Dockerfile
index 02df91bfc1..a08d4e3aab 100644
--- a/api/Dockerfile
+++ b/api/Dockerfile
@@ -50,16 +50,33 @@ WORKDIR /app/api
# Create non-root user
ARG dify_uid=1001
+ARG NODE_MAJOR=22
+ARG NODE_PACKAGE_VERSION=22.21.0-1nodesource1
+ARG NODESOURCE_KEY_FPR=6F71F525282841EEDAF851B42F59B5F99B1BE0B4
RUN groupadd -r -g ${dify_uid} dify && \
useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \
chown -R dify:dify /app
RUN \
apt-get update \
+ && apt-get install -y --no-install-recommends \
+ ca-certificates \
+ curl \
+ gnupg \
+ && mkdir -p /etc/apt/keyrings \
+ && curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key -o /tmp/nodesource.gpg \
+ && gpg --show-keys --with-colons /tmp/nodesource.gpg \
+ | awk -F: '/^fpr:/ {print $10}' \
+ | grep -Fx "${NODESOURCE_KEY_FPR}" \
+ && gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg /tmp/nodesource.gpg \
+ && rm -f /tmp/nodesource.gpg \
+ && echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" \
+ > /etc/apt/sources.list.d/nodesource.list \
+ && apt-get update \
# Install dependencies
&& apt-get install -y --no-install-recommends \
# basic environment
- curl nodejs \
+ nodejs=${NODE_PACKAGE_VERSION} \
# for gmpy2 \
libgmp-dev libmpfr-dev libmpc-dev \
# For Security
@@ -79,7 +96,8 @@ COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
-RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
+RUN mkdir -p /usr/local/share/nltk_data \
+ && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \
&& chmod -R 755 /usr/local/share/nltk_data
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
diff --git a/api/README.md b/api/README.md
index 794b05d3af..9d89b490b0 100644
--- a/api/README.md
+++ b/api/README.md
@@ -1,6 +1,6 @@
# Dify Backend API
-## Usage
+## Setup and Run
> [!IMPORTANT]
>
@@ -8,48 +8,77 @@
> [`uv`](https://docs.astral.sh/uv/) as the package manager
> for Dify API backend service.
-1. Start the docker-compose stack
+`uv` and `pnpm` are required to run the setup and development commands below.
- The backend require some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
+### Using scripts (recommended)
+
+The scripts resolve paths relative to their location, so you can run them from anywhere.
+
+1. Run setup (copies env files and installs dependencies).
```bash
- cd ../docker
- cp middleware.env.example middleware.env
- # change the profile to mysql if you are not using postgres,change the profile to other vector database if you are not using weaviate
- docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
- cd ../api
+ ./dev/setup
```
-1. Copy `.env.example` to `.env`
+1. Review `api/.env`, `web/.env.local`, and `docker/middleware.env` values (see the `SECRET_KEY` note below).
- ```cli
- cp .env.example .env
+1. Start middleware (PostgreSQL/Redis/Weaviate).
+
+ ```bash
+ ./dev/start-docker-compose
```
-> [!IMPORTANT]
->
-> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies.
+1. Start backend (runs migrations first).
-1. Generate a `SECRET_KEY` in the `.env` file.
-
- bash for Linux
-
- ```bash for Linux
- sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
+ ```bash
+ ./dev/start-api
```
- bash for Mac
+1. Start Dify [web](../web) service.
- ```bash for Mac
- secret_key=$(openssl rand -base64 42)
- sed -i '' "/^SECRET_KEY=/c\\
- SECRET_KEY=${secret_key}" .env
+ ```bash
+ ./dev/start-web
```
-1. Create environment.
+1. Set up your application by visiting `http://localhost:3000`.
- Dify API service uses [UV](https://docs.astral.sh/uv/) to manage dependencies.
- First, you need to add the uv package manager, if you don't have it already.
+1. Optional: start the worker service (async tasks, runs from `api`).
+
+ ```bash
+ ./dev/start-worker
+ ```
+
+1. Optional: start Celery Beat (scheduled tasks).
+
+ ```bash
+ ./dev/start-beat
+ ```
+
+### Manual commands
+
+
+Show manual setup and run steps
+
+These commands assume you start from the repository root.
+
+1. Start the docker-compose stack.
+
+ The backend requires middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
+
+ ```bash
+ cp docker/middleware.env.example docker/middleware.env
+ # Use mysql or another vector database profile if you are not using postgres/weaviate.
+ docker compose -f docker/docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
+ ```
+
+1. Copy env files.
+
+ ```bash
+ cp api/.env.example api/.env
+ cp web/.env.example web/.env.local
+ ```
+
+1. Install UV if needed.
```bash
pip install uv
@@ -57,60 +86,96 @@
brew install uv
```
-1. Install dependencies
+1. Install API dependencies.
```bash
- uv sync --dev
+ cd api
+ uv sync --group dev
```
-1. Run migrate
-
- Before the first launch, migrate the database to the latest version.
+1. Install web dependencies.
```bash
+ cd web
+ pnpm install
+ cd ..
+ ```
+
+1. Start backend (runs migrations first, in a new terminal).
+
+ ```bash
+ cd api
uv run flask db upgrade
- ```
-
-1. Start backend
-
- ```bash
uv run flask run --host 0.0.0.0 --port=5001 --debug
```
-1. Start Dify [web](../web) service.
+1. Start Dify [web](../web) service (in a new terminal).
-1. Setup your application by visiting `http://localhost:3000`.
+ ```bash
+ cd web
+ pnpm dev:inspect
+ ```
-1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
+1. Set up your application by visiting `http://localhost:3000`.
-```bash
-uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
-```
+1. Optional: start the worker service (async tasks, in a new terminal).
-Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
+ ```bash
+ cd api
+ uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
+ ```
-```bash
-uv run celery -A app.celery beat
-```
+1. Optional: start Celery Beat (scheduled tasks, in a new terminal).
+
+ ```bash
+ cd api
+ uv run celery -A app.celery beat
+ ```
+
+
+
+### Environment notes
+
+> [!IMPORTANT]
+>
+> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies.
+
+- Generate a `SECRET_KEY` in the `.env` file.
+
+ bash for Linux
+
+ ```bash
+ sed -i "/^SECRET_KEY=/c\\SECRET_KEY=$(openssl rand -base64 42)" .env
+ ```
+
+ bash for Mac
+
+ ```bash
+ secret_key=$(openssl rand -base64 42)
+ sed -i '' "/^SECRET_KEY=/c\\
+ SECRET_KEY=${secret_key}" .env
+ ```
## Testing
1. Install dependencies for both the backend and the test environment
```bash
- uv sync --dev
+ cd api
+ uv sync --group dev
```
1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`, more can check [Claude.md](../CLAUDE.md)
```bash
+ cd api
uv run pytest # Run all tests
uv run pytest tests/unit_tests/ # Unit tests only
uv run pytest tests/integration_tests/ # Integration tests
# Code quality
- ../dev/reformat # Run all formatters and linters
- uv run ruff check --fix ./ # Fix linting issues
- uv run ruff format ./ # Format code
- uv run basedpyright . # Type checking
+ ./dev/reformat # Run all formatters and linters
+ uv run ruff check --fix ./ # Fix linting issues
+ uv run ruff format ./ # Format code
+ uv run basedpyright . # Type checking
```
diff --git a/api/agent_skills/coding_style.md b/api/agent_skills/coding_style.md
deleted file mode 100644
index a2b66f0bd5..0000000000
--- a/api/agent_skills/coding_style.md
+++ /dev/null
@@ -1,115 +0,0 @@
-## Linter
-
-- Always follow `.ruff.toml`.
-- Run `uv run ruff check --fix --unsafe-fixes`.
-- Keep each line under 100 characters (including spaces).
-
-## Code Style
-
-- `snake_case` for variables and functions.
-- `PascalCase` for classes.
-- `UPPER_CASE` for constants.
-
-## Rules
-
-- Use Pydantic v2 standard.
-- Use `uv` for package management.
-- Do not override dunder methods like `__init__`, `__iadd__`, etc.
-- Never launch services (`uv run app.py`, `flask run`, etc.); running tests under `tests/` is allowed.
-- Prefer simple functions over classes for lightweight helpers.
-- Keep files below 800 lines; split when necessary.
-- Keep code readable—no clever hacks.
-- Never use `print`; log with `logger = logging.getLogger(__name__)`.
-
-## Guiding Principles
-
-- Mirror the project’s layered architecture: controller → service → core/domain.
-- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
-- Optimise for observability: deterministic control flow, clear logging, actionable errors.
-
-## SQLAlchemy Patterns
-
-- Models inherit from `models.base.Base`; never create ad-hoc metadata or engines.
-
-- Open sessions with context managers:
-
- ```python
- from sqlalchemy.orm import Session
-
- with Session(db.engine, expire_on_commit=False) as session:
- stmt = select(Workflow).where(
- Workflow.id == workflow_id,
- Workflow.tenant_id == tenant_id,
- )
- workflow = session.execute(stmt).scalar_one_or_none()
- ```
-
-- Use SQLAlchemy expressions; avoid raw SQL unless necessary.
-
-- Introduce repository abstractions only for very large tables (e.g., workflow executions) to support alternative storage strategies.
-
-- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
-
-## Storage & External IO
-
-- Access storage via `extensions.ext_storage.storage`.
-- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
-- Background tasks that touch storage must be idempotent and log the relevant object identifiers.
-
-## Pydantic Usage
-
-- Define DTOs with Pydantic v2 models and forbid extras by default.
-
-- Use `@field_validator` / `@model_validator` for domain rules.
-
-- Example:
-
- ```python
- from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
-
- class TriggerConfig(BaseModel):
- endpoint: HttpUrl
- secret: str
-
- model_config = ConfigDict(extra="forbid")
-
- @field_validator("secret")
- def ensure_secret_prefix(cls, value: str) -> str:
- if not value.startswith("dify_"):
- raise ValueError("secret must start with dify_")
- return value
- ```
-
-## Generics & Protocols
-
-- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
-- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
-- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
-
-## Error Handling & Logging
-
-- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate to HTTP responses in controllers.
-- Declare `logger = logging.getLogger(__name__)` at module top.
-- Include tenant/app/workflow identifiers in log context.
-- Log retryable events at `warning`, terminal failures at `error`.
-
-## Tooling & Checks
-
-- Format/lint: `uv run --project api --dev ruff format ./api` and `uv run --project api --dev ruff check --fix --unsafe-fixes ./api`.
-- Type checks: `uv run --directory api --dev basedpyright`.
-- Tests: `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
-- Run all of the above before submitting your work.
-
-## Controllers & Services
-
-- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
-- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
-- Avoid repositories unless necessary; direct SQLAlchemy usage is preferred for typical tables.
-- Document non-obvious behaviour with concise comments.
-
-## Miscellaneous
-
-- Use `configs.dify_config` for configuration—never read environment variables directly.
-- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
-- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
-- Keep experimental scripts under `dev/`; do not ship them in production builds.
diff --git a/api/agent_skills/infra.md b/api/agent_skills/infra.md
deleted file mode 100644
index bc36c7bf64..0000000000
--- a/api/agent_skills/infra.md
+++ /dev/null
@@ -1,96 +0,0 @@
-## Configuration
-
-- Import `configs.dify_config` for every runtime toggle. Do not read environment variables directly.
-- Add new settings to the proper mixin inside `configs/` (deployment, feature, middleware, etc.) so they load through `DifyConfig`.
-- Remote overrides come from the optional providers in `configs/remote_settings_sources`; keep defaults in code safe when the value is missing.
-- Example: logging pulls targets from `extensions/ext_logging.py`, and model provider URLs are assembled in `services/entities/model_provider_entities.py`.
-
-## Dependencies
-
-- Runtime dependencies live in `[project].dependencies` inside `pyproject.toml`. Optional clients go into the `storage`, `tools`, or `vdb` groups under `[dependency-groups]`.
-- Always pin versions and keep the list alphabetised. Shared tooling (lint, typing, pytest) belongs in the `dev` group.
-- When code needs a new package, explain why in the PR and run `uv lock` so the lockfile stays current.
-
-## Storage & Files
-
-- Use `extensions.ext_storage.storage` for all blob IO; it already respects the configured backend.
-- Convert files for workflows with helpers in `core/file/file_manager.py`; they handle signed URLs and multimodal payloads.
-- When writing controller logic, delegate upload quotas and metadata to `services/file_service.py` instead of touching storage directly.
-- All outbound HTTP fetches (webhooks, remote files) must go through the SSRF-safe client in `core/helper/ssrf_proxy.py`; it wraps `httpx` with the allow/deny rules configured for the platform.
-
-## Redis & Shared State
-
-- Access Redis through `extensions.ext_redis.redis_client`. For locking, reuse `redis_client.lock`.
-- Prefer higher-level helpers when available: rate limits use `libs.helper.RateLimiter`, provider metadata uses caches in `core/helper/provider_cache.py`.
-
-## Models
-
-- SQLAlchemy models sit in `models/` and inherit from the shared declarative `Base` defined in `models/base.py` (metadata configured via `models/engine.py`).
-- `models/__init__.py` exposes grouped aggregates: account/tenant models, app and conversation tables, datasets, providers, workflow runs, triggers, etc. Import from there to avoid deep path churn.
-- Follow the DDD boundary: persistence objects live in `models/`, repositories under `repositories/` translate them into domain entities, and services consume those repositories.
-- When adding a table, create the model class, register it in `models/__init__.py`, wire a repository if needed, and generate an Alembic migration as described below.
-
-## Vector Stores
-
-- Vector client implementations live in `core/rag/datasource/vdb/`, with a common factory in `core/rag/datasource/vdb/vector_factory.py` and enums in `core/rag/datasource/vdb/vector_type.py`.
-- Retrieval pipelines call these providers through `core/rag/datasource/retrieval_service.py` and dataset ingestion flows in `services/dataset_service.py`.
-- The CLI helper `flask vdb-migrate` orchestrates bulk migrations using routines in `commands.py`; reuse that pattern when adding new backend transitions.
-- To add another store, mirror the provider layout, register it with the factory, and include any schema changes in Alembic migrations.
-
-## Observability & OTEL
-
-- OpenTelemetry settings live under the observability mixin in `configs/observability`. Toggle exporters and sampling via `dify_config`, not ad-hoc env reads.
-- HTTP, Celery, Redis, SQLAlchemy, and httpx instrumentation is initialised in `extensions/ext_app_metrics.py` and `extensions/ext_request_logging.py`; reuse these hooks when adding new workers or entrypoints.
-- When creating background tasks or external calls, propagate tracing context with helpers in the existing instrumented clients (e.g. use the shared `httpx` session from `core/helper/http_client_pooling.py`).
-- If you add a new external integration, ensure spans and metrics are emitted by wiring the appropriate OTEL instrumentation package in `pyproject.toml` and configuring it in `extensions/`.
-
-## Ops Integrations
-
-- Langfuse support and other tracing bridges live under `core/ops/opik_trace`. Config toggles sit in `configs/observability`, while exporters are initialised in the OTEL extensions mentioned above.
-- External monitoring services should follow this pattern: keep client code in `core/ops`, expose switches via `dify_config`, and hook initialisation in `extensions/ext_app_metrics.py` or sibling modules.
-- Before instrumenting new code paths, check whether existing context helpers (e.g. `extensions/ext_request_logging.py`) already capture the necessary metadata.
-
-## Controllers, Services, Core
-
-- Controllers only parse HTTP input and call a service method. Keep business rules in `services/`.
-- Services enforce tenant rules, quotas, and orchestration, then call into `core/` engines (workflow execution, tools, LLMs).
-- When adding a new endpoint, search for an existing service to extend before introducing a new layer. Example: workflow APIs pipe through `services/workflow_service.py` into `core/workflow`.
-
-## Plugins, Tools, Providers
-
-- In Dify a plugin is a tenant-installable bundle that declares one or more providers (tool, model, datasource, trigger, endpoint, agent strategy) plus its resource needs and version metadata. The manifest (`core/plugin/entities/plugin.py`) mirrors what you see in the marketplace documentation.
-- Installation, upgrades, and migrations are orchestrated by `services/plugin/plugin_service.py` together with helpers such as `services/plugin/plugin_migration.py`.
-- Runtime loading happens through the implementations under `core/plugin/impl/*` (tool/model/datasource/trigger/endpoint/agent). These modules normalise plugin providers so that downstream systems (`core/tools/tool_manager.py`, `services/model_provider_service.py`, `services/trigger/*`) can treat builtin and plugin capabilities the same way.
-- For remote execution, plugin daemons (`core/plugin/entities/plugin_daemon.py`, `core/plugin/impl/plugin.py`) manage lifecycle hooks, credential forwarding, and background workers that keep plugin processes in sync with the main application.
-- Acquire tool implementations through `core/tools/tool_manager.py`; it resolves builtin, plugin, and workflow-as-tool providers uniformly, injecting the right context (tenant, credentials, runtime config).
-- To add a new plugin capability, extend the relevant `core/plugin/entities` schema and register the implementation in the matching `core/plugin/impl` module rather than importing the provider directly.
-
-## Async Workloads
-
-see `agent_skills/trigger.md` for more detailed documentation.
-
-- Enqueue background work through `services/async_workflow_service.py`. It routes jobs to the tiered Celery queues defined in `tasks/`.
-- Workers boot from `celery_entrypoint.py` and execute functions in `tasks/workflow_execution_tasks.py`, `tasks/trigger_processing_tasks.py`, etc.
-- Scheduled workflows poll from `schedule/workflow_schedule_tasks.py`. Follow the same pattern if you need new periodic jobs.
-
-## Database & Migrations
-
-- SQLAlchemy models live under `models/` and map directly to migration files in `migrations/versions`.
-- Generate migrations with `uv run --project api flask db revision --autogenerate -m ""`, then review the diff; never hand-edit the database outside Alembic.
-- Apply migrations locally using `uv run --project api flask db upgrade`; production deploys expect the same history.
-- If you add tenant-scoped data, confirm the upgrade includes tenant filters or defaults consistent with the service logic touching those tables.
-
-## CLI Commands
-
-- Maintenance commands from `commands.py` are registered on the Flask CLI. Run them via `uv run --project api flask `.
-- Use the built-in `db` commands from Flask-Migrate for schema operations (`flask db upgrade`, `flask db stamp`, etc.). Only fall back to custom helpers if you need their extra behaviour.
-- Custom entries such as `flask reset-password`, `flask reset-email`, and `flask vdb-migrate` handle self-hosted account recovery and vector database migrations.
-- Before adding a new command, check whether an existing service can be reused and ensure the command guards edition-specific behaviour (many enforce `SELF_HOSTED`). Document any additions in the PR.
-- Ruff helpers are run directly with `uv`: `uv run --project api --dev ruff format ./api` for formatting and `uv run --project api --dev ruff check ./api` (add `--fix` if you want automatic fixes).
-
-## When You Add Features
-
-- Check for an existing helper or service before writing a new util.
-- Uphold tenancy: every service method should receive the tenant ID from controller wrappers such as `controllers/console/wraps.py`.
-- Update or create tests alongside behaviour changes (`tests/unit_tests` for fast coverage, `tests/integration_tests` when touching orchestrations).
-- Run `uv run --project api --dev ruff check ./api`, `uv run --directory api --dev basedpyright`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before submitting changes.
diff --git a/api/agent_skills/plugin.md b/api/agent_skills/plugin.md
deleted file mode 100644
index 954ddd236b..0000000000
--- a/api/agent_skills/plugin.md
+++ /dev/null
@@ -1 +0,0 @@
-// TBD
diff --git a/api/agent_skills/plugin_oauth.md b/api/agent_skills/plugin_oauth.md
deleted file mode 100644
index 954ddd236b..0000000000
--- a/api/agent_skills/plugin_oauth.md
+++ /dev/null
@@ -1 +0,0 @@
-// TBD
diff --git a/api/agent_skills/trigger.md b/api/agent_skills/trigger.md
deleted file mode 100644
index f4b076332c..0000000000
--- a/api/agent_skills/trigger.md
+++ /dev/null
@@ -1,53 +0,0 @@
-## Overview
-
-Trigger is a collection of nodes that we called `Start` nodes, also, the concept of `Start` is the same as `RootNode` in the workflow engine `core/workflow/graph_engine`, On the other hand, `Start` node is the entry point of workflows, every workflow run always starts from a `Start` node.
-
-## Trigger nodes
-
-- `UserInput`
-- `Trigger Webhook`
-- `Trigger Schedule`
-- `Trigger Plugin`
-
-### UserInput
-
-Before `Trigger` concept is introduced, it's what we called `Start` node, but now, to avoid confusion, it was renamed to `UserInput` node, has a strong relation with `ServiceAPI` in `controllers/service_api/app`
-
-1. `UserInput` node introduces a list of arguments that need to be provided by the user, finally it will be converted into variables in the workflow variable pool.
-1. `ServiceAPI` accept those arguments, and pass through them into `UserInput` node.
-1. For its detailed implementation, please refer to `core/workflow/nodes/start`
-
-### Trigger Webhook
-
-Inside Webhook Node, Dify provided a UI panel that allows user define a HTTP manifest `core/workflow/nodes/trigger_webhook/entities.py`.`WebhookData`, also, Dify generates a random webhook id for each `Trigger Webhook` node, the implementation was implemented in `core/trigger/utils/endpoint.py`, as you can see, `webhook-debug` is a debug mode for webhook, you may find it in `controllers/trigger/webhook.py`.
-
-Finally, requests to `webhook` endpoint will be converted into variables in workflow variable pool during workflow execution.
-
-### Trigger Schedule
-
-`Trigger Schedule` node is a node that allows user define a schedule to trigger the workflow, detailed manifest is here `core/workflow/nodes/trigger_schedule/entities.py`, we have a poller and executor to handle millions of schedules, see `docker/entrypoint.sh` / `schedule/workflow_schedule_task.py` for help.
-
-To Achieve this, a `WorkflowSchedulePlan` model was introduced in `models/trigger.py`, and a `events/event_handlers/sync_workflow_schedule_when_app_published.py` was used to sync workflow schedule plans when app is published.
-
-### Trigger Plugin
-
-`Trigger Plugin` node allows user define there own distributed trigger plugin, whenever a request was received, Dify forwards it to the plugin and wait for parsed variables from it.
-
-1. Requests were saved in storage by `services/trigger/trigger_request_service.py`, referenced by `services/trigger/trigger_service.py`.`TriggerService`.`process_endpoint`
-1. Plugins accept those requests and parse variables from it, see `core/plugin/impl/trigger.py` for details.
-
-A `subscription` concept was out here by Dify, it means an endpoint address from Dify was bound to thirdparty webhook service like `Github` `Slack` `Linear` `GoogleDrive` `Gmail` etc. Once a subscription was created, Dify continually receives requests from the platforms and handle them one by one.
-
-## Worker Pool / Async Task
-
-All the events that triggered a new workflow run is always in async mode, a unified entrypoint can be found here `services/async_workflow_service.py`.`AsyncWorkflowService`.`trigger_workflow_async`.
-
-The infrastructure we used is `celery`, we've already configured it in `docker/entrypoint.sh`, and the consumers are in `tasks/async_workflow_tasks.py`, 3 queues were used to handle different tiers of users, `PROFESSIONAL_QUEUE` `TEAM_QUEUE` `SANDBOX_QUEUE`.
-
-## Debug Strategy
-
-Dify divided users into 2 groups: builders / end users.
-
-Builders are the users who create workflows, in this stage, debugging a workflow becomes a critical part of the workflow development process, as the start node in workflows, trigger nodes can `listen` to the events from `WebhookDebug` `Schedule` `Plugin`, debugging process was created in `controllers/console/app/workflow.py`.`DraftWorkflowTriggerNodeApi`.
-
-A polling process can be considered as combine of few single `poll` operations, each `poll` operation fetches events cached in `Redis`, returns `None` if no event was found, more detailed implemented: `core/trigger/debug/event_bus.py` was used to handle the polling process, and `core/trigger/debug/event_selectors.py` was used to select the event poller based on the trigger type.
diff --git a/api/app_factory.py b/api/app_factory.py
index bcad88e9e0..07859a3758 100644
--- a/api/app_factory.py
+++ b/api/app_factory.py
@@ -2,9 +2,11 @@ import logging
import time
from opentelemetry.trace import get_current_span
+from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
+from core.logging.context import init_request_context
from dify_app import DifyApp
logger = logging.getLogger(__name__)
@@ -25,28 +27,35 @@ def create_flask_app_with_configs() -> DifyApp:
# add before request hook
@dify_app.before_request
def before_request():
- # add an unique identifier to each request
+ # Initialize logging context for this request
+ init_request_context()
RecyclableContextVar.increment_thread_recycles()
- # add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
+ # add after request hook for injecting trace headers from OpenTelemetry span context
+ # Only adds headers when OTEL is enabled and has valid context
@dify_app.after_request
- def add_trace_id_header(response):
+ def add_trace_headers(response):
try:
span = get_current_span()
ctx = span.get_span_context() if span else None
- if ctx and ctx.is_valid:
- trace_id_hex = format(ctx.trace_id, "032x")
- # Avoid duplicates if some middleware added it
- if "X-Trace-Id" not in response.headers:
- response.headers["X-Trace-Id"] = trace_id_hex
+
+ if not ctx or not ctx.is_valid:
+ return response
+
+ # Inject trace headers from OTEL context
+ if ctx.trace_id != INVALID_TRACE_ID and "X-Trace-Id" not in response.headers:
+ response.headers["X-Trace-Id"] = format(ctx.trace_id, "032x")
+ if ctx.span_id != INVALID_SPAN_ID and "X-Span-Id" not in response.headers:
+ response.headers["X-Span-Id"] = format(ctx.span_id, "016x")
+
except Exception:
# Never break the response due to tracing header injection
- logger.warning("Failed to add trace ID to response header", exc_info=True)
+ logger.warning("Failed to add trace headers to response", exc_info=True)
return response
# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request
- _ = add_trace_id_header
+ _ = add_trace_headers
return dify_app
@@ -62,6 +71,8 @@ def create_app() -> DifyApp:
def initialize_extensions(app: DifyApp):
+ # Initialize Flask context capture for workflow execution
+ from context.flask_app_context import init_flask_context
from extensions import (
ext_app_metrics,
ext_blueprints,
@@ -70,6 +81,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_compress,
ext_database,
+ ext_fastopenapi,
ext_forward_refs,
ext_hosting_provider,
ext_import_modules,
@@ -91,6 +103,8 @@ def initialize_extensions(app: DifyApp):
ext_warnings,
)
+ init_flask_context()
+
extensions = [
ext_timezone,
ext_logging,
@@ -115,6 +129,7 @@ def initialize_extensions(app: DifyApp):
ext_proxy_fix,
ext_blueprints,
ext_commands,
+ ext_fastopenapi,
ext_otel,
ext_request_logging,
ext_session_factory,
diff --git a/api/commands.py b/api/commands.py
index a8d89ac200..3d68de4cb4 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -1,7 +1,9 @@
import base64
+import datetime
import json
import logging
import secrets
+import time
from typing import Any
import click
@@ -34,7 +36,7 @@ from libs.rsa import generate_key_pair
from models import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
-from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
+from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from models.provider import Provider, ProviderModel
from models.provider_ids import DatasourceProviderID, ToolProviderID
@@ -45,6 +47,9 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
+from services.retention.conversation.messages_clean_policy import create_message_clean_policy
+from services.retention.conversation.messages_clean_service import MessagesCleanService
+from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
logger = logging.getLogger(__name__)
@@ -62,8 +67,10 @@ def reset_password(email, new_password, password_confirm):
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style("Passwords do not match.", fg="red"))
return
+ normalized_email = email.strip().lower()
+
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
- account = session.query(Account).where(Account.email == email).one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
@@ -84,7 +91,7 @@ def reset_password(email, new_password, password_confirm):
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
- AccountService.reset_login_error_rate_limit(email)
+ AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
@@ -100,20 +107,22 @@ def reset_email(email, new_email, email_confirm):
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style("New emails do not match.", fg="red"))
return
+ normalized_new_email = new_email.strip().lower()
+
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
- account = session.query(Account).where(Account.email == email).one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
try:
- email_validate(new_email)
+ email_validate(normalized_new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
- account.email = new_email
+ account.email = normalized_new_email
click.echo(click.style("Email updated successfully.", fg="green"))
@@ -235,7 +244,7 @@ def migrate_annotation_vector_database():
if annotations:
for annotation in annotations:
document = Document(
- page_content=annotation.question,
+ page_content=annotation.question_text,
metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id},
)
documents.append(document)
@@ -658,7 +667,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
return
# Create account
- email = email.strip()
+ email = email.strip().lower()
if "@" not in email:
click.echo(click.style("Invalid email address.", fg="red"))
@@ -852,6 +861,435 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green"))
+@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.")
+@click.option(
+ "--before-days",
+ "--days",
+ default=30,
+ show_default=True,
+ type=click.IntRange(min=0),
+ help="Delete workflow runs created before N days ago.",
+)
+@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.")
+@click.option(
+ "--from-days-ago",
+ default=None,
+ type=click.IntRange(min=0),
+ help="Lower bound in days ago (older). Must be paired with --to-days-ago.",
+)
+@click.option(
+ "--to-days-ago",
+ default=None,
+ type=click.IntRange(min=0),
+ help="Upper bound in days ago (newer). Must be paired with --from-days-ago.",
+)
+@click.option(
+ "--start-from",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
+)
+@click.option(
+ "--end-before",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
+)
+@click.option(
+ "--dry-run",
+ is_flag=True,
+ help="Preview cleanup results without deleting any workflow run data.",
+)
+def clean_workflow_runs(
+ before_days: int,
+ batch_size: int,
+ from_days_ago: int | None,
+ to_days_ago: int | None,
+ start_from: datetime.datetime | None,
+ end_before: datetime.datetime | None,
+ dry_run: bool,
+):
+ """
+ Clean workflow runs and related workflow data for free tenants.
+ """
+ if (start_from is None) ^ (end_before is None):
+ raise click.UsageError("--start-from and --end-before must be provided together.")
+
+ if (from_days_ago is None) ^ (to_days_ago is None):
+ raise click.UsageError("--from-days-ago and --to-days-ago must be provided together.")
+
+ if from_days_ago is not None and to_days_ago is not None:
+ if start_from or end_before:
+ raise click.UsageError("Choose either day offsets or explicit dates, not both.")
+ if from_days_ago <= to_days_ago:
+ raise click.UsageError("--from-days-ago must be greater than --to-days-ago.")
+ now = datetime.datetime.now()
+ start_from = now - datetime.timedelta(days=from_days_ago)
+ end_before = now - datetime.timedelta(days=to_days_ago)
+ before_days = 0
+
+ start_time = datetime.datetime.now(datetime.UTC)
+ click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
+
+ WorkflowRunCleanup(
+ days=before_days,
+ batch_size=batch_size,
+ start_from=start_from,
+ end_before=end_before,
+ dry_run=dry_run,
+ ).run()
+
+ end_time = datetime.datetime.now(datetime.UTC)
+ elapsed = end_time - start_time
+ click.echo(
+ click.style(
+ f"Workflow run cleanup completed. start={start_time.isoformat()} "
+ f"end={end_time.isoformat()} duration={elapsed}",
+ fg="green",
+ )
+ )
+
+
+@click.command(
+ "archive-workflow-runs",
+ help="Archive workflow runs for paid plan tenants to S3-compatible storage.",
+)
+@click.option("--tenant-ids", default=None, help="Optional comma-separated tenant IDs for grayscale rollout.")
+@click.option("--before-days", default=90, show_default=True, help="Archive runs older than N days.")
+@click.option(
+ "--from-days-ago",
+ default=None,
+ type=click.IntRange(min=0),
+ help="Lower bound in days ago (older). Must be paired with --to-days-ago.",
+)
+@click.option(
+ "--to-days-ago",
+ default=None,
+ type=click.IntRange(min=0),
+ help="Upper bound in days ago (newer). Must be paired with --from-days-ago.",
+)
+@click.option(
+ "--start-from",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Archive runs created at or after this timestamp (UTC if no timezone).",
+)
+@click.option(
+ "--end-before",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Archive runs created before this timestamp (UTC if no timezone).",
+)
+@click.option("--batch-size", default=100, show_default=True, help="Batch size for processing.")
+@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to archive.")
+@click.option("--limit", default=None, type=int, help="Maximum number of runs to archive.")
+@click.option("--dry-run", is_flag=True, help="Preview without archiving.")
+@click.option("--delete-after-archive", is_flag=True, help="Delete runs and related data after archiving.")
+def archive_workflow_runs(
+ tenant_ids: str | None,
+ before_days: int,
+ from_days_ago: int | None,
+ to_days_ago: int | None,
+ start_from: datetime.datetime | None,
+ end_before: datetime.datetime | None,
+ batch_size: int,
+ workers: int,
+ limit: int | None,
+ dry_run: bool,
+ delete_after_archive: bool,
+):
+ """
+ Archive workflow runs for paid plan tenants older than the specified days.
+
+ This command archives the following tables to storage:
+ - workflow_node_executions
+ - workflow_node_execution_offload
+ - workflow_pauses
+ - workflow_pause_reasons
+ - workflow_trigger_logs
+
+ The workflow_runs and workflow_app_logs tables are preserved for UI listing.
+ """
+ from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver
+
+ run_started_at = datetime.datetime.now(datetime.UTC)
+ click.echo(
+ click.style(
+ f"Starting workflow run archiving at {run_started_at.isoformat()}.",
+ fg="white",
+ )
+ )
+
+ if (start_from is None) ^ (end_before is None):
+ click.echo(click.style("start-from and end-before must be provided together.", fg="red"))
+ return
+
+ if (from_days_ago is None) ^ (to_days_ago is None):
+ click.echo(click.style("from-days-ago and to-days-ago must be provided together.", fg="red"))
+ return
+
+ if from_days_ago is not None and to_days_ago is not None:
+ if start_from or end_before:
+ click.echo(click.style("Choose either day offsets or explicit dates, not both.", fg="red"))
+ return
+ if from_days_ago <= to_days_ago:
+ click.echo(click.style("from-days-ago must be greater than to-days-ago.", fg="red"))
+ return
+ now = datetime.datetime.now()
+ start_from = now - datetime.timedelta(days=from_days_ago)
+ end_before = now - datetime.timedelta(days=to_days_ago)
+ before_days = 0
+
+ if start_from and end_before and start_from >= end_before:
+ click.echo(click.style("start-from must be earlier than end-before.", fg="red"))
+ return
+ if workers < 1:
+ click.echo(click.style("workers must be at least 1.", fg="red"))
+ return
+
+ archiver = WorkflowRunArchiver(
+ days=before_days,
+ batch_size=batch_size,
+ start_from=start_from,
+ end_before=end_before,
+ workers=workers,
+ tenant_ids=[tid.strip() for tid in tenant_ids.split(",")] if tenant_ids else None,
+ limit=limit,
+ dry_run=dry_run,
+ delete_after_archive=delete_after_archive,
+ )
+ summary = archiver.run()
+ click.echo(
+ click.style(
+ f"Summary: processed={summary.total_runs_processed}, archived={summary.runs_archived}, "
+ f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, "
+ f"time={summary.total_elapsed_time:.2f}s",
+ fg="cyan",
+ )
+ )
+
+ run_finished_at = datetime.datetime.now(datetime.UTC)
+ elapsed = run_finished_at - run_started_at
+ click.echo(
+ click.style(
+ f"Workflow run archiving completed. start={run_started_at.isoformat()} "
+ f"end={run_finished_at.isoformat()} duration={elapsed}",
+ fg="green",
+ )
+ )
+
+
+@click.command(
+ "restore-workflow-runs",
+ help="Restore archived workflow runs from S3-compatible storage.",
+)
+@click.option(
+ "--tenant-ids",
+ required=False,
+ help="Tenant IDs (comma-separated).",
+)
+@click.option("--run-id", required=False, help="Workflow run ID to restore.")
+@click.option(
+ "--start-from",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
+)
+@click.option(
+ "--end-before",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
+)
+@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to restore.")
+@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to restore.")
+@click.option("--dry-run", is_flag=True, help="Preview without restoring.")
+def restore_workflow_runs(
+ tenant_ids: str | None,
+ run_id: str | None,
+ start_from: datetime.datetime | None,
+ end_before: datetime.datetime | None,
+ workers: int,
+ limit: int,
+ dry_run: bool,
+):
+ """
+ Restore an archived workflow run from storage to the database.
+
+ This restores the following tables:
+ - workflow_node_executions
+ - workflow_node_execution_offload
+ - workflow_pauses
+ - workflow_pause_reasons
+ - workflow_trigger_logs
+ """
+ from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore
+
+ parsed_tenant_ids = None
+ if tenant_ids:
+ parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()]
+ if not parsed_tenant_ids:
+ raise click.BadParameter("tenant-ids must not be empty")
+
+ if (start_from is None) ^ (end_before is None):
+ raise click.UsageError("--start-from and --end-before must be provided together.")
+ if run_id is None and (start_from is None or end_before is None):
+ raise click.UsageError("--start-from and --end-before are required for batch restore.")
+ if workers < 1:
+ raise click.BadParameter("workers must be at least 1")
+
+ start_time = datetime.datetime.now(datetime.UTC)
+ click.echo(
+ click.style(
+ f"Starting restore of workflow run {run_id} at {start_time.isoformat()}.",
+ fg="white",
+ )
+ )
+
+ restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers)
+ if run_id:
+ results = [restorer.restore_by_run_id(run_id)]
+ else:
+ assert start_from is not None
+ assert end_before is not None
+ results = restorer.restore_batch(
+ parsed_tenant_ids,
+ start_date=start_from,
+ end_date=end_before,
+ limit=limit,
+ )
+
+ end_time = datetime.datetime.now(datetime.UTC)
+ elapsed = end_time - start_time
+
+ successes = sum(1 for result in results if result.success)
+ failures = len(results) - successes
+
+ if failures == 0:
+ click.echo(
+ click.style(
+ f"Restore completed successfully. success={successes} duration={elapsed}",
+ fg="green",
+ )
+ )
+ else:
+ click.echo(
+ click.style(
+ f"Restore completed with failures. success={successes} failed={failures} duration={elapsed}",
+ fg="red",
+ )
+ )
+
+
+@click.command(
+ "delete-archived-workflow-runs",
+ help="Delete archived workflow runs from the database.",
+)
+@click.option(
+ "--tenant-ids",
+ required=False,
+ help="Tenant IDs (comma-separated).",
+)
+@click.option("--run-id", required=False, help="Workflow run ID to delete.")
+@click.option(
+ "--start-from",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
+)
+@click.option(
+ "--end-before",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
+)
+@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to delete.")
+@click.option("--dry-run", is_flag=True, help="Preview without deleting.")
+def delete_archived_workflow_runs(
+ tenant_ids: str | None,
+ run_id: str | None,
+ start_from: datetime.datetime | None,
+ end_before: datetime.datetime | None,
+ limit: int,
+ dry_run: bool,
+):
+ """
+ Delete archived workflow runs from the database.
+ """
+ from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
+
+ parsed_tenant_ids = None
+ if tenant_ids:
+ parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()]
+ if not parsed_tenant_ids:
+ raise click.BadParameter("tenant-ids must not be empty")
+
+ if (start_from is None) ^ (end_before is None):
+ raise click.UsageError("--start-from and --end-before must be provided together.")
+ if run_id is None and (start_from is None or end_before is None):
+ raise click.UsageError("--start-from and --end-before are required for batch delete.")
+
+ start_time = datetime.datetime.now(datetime.UTC)
+ target_desc = f"workflow run {run_id}" if run_id else "workflow runs"
+ click.echo(
+ click.style(
+ f"Starting delete of {target_desc} at {start_time.isoformat()}.",
+ fg="white",
+ )
+ )
+
+ deleter = ArchivedWorkflowRunDeletion(dry_run=dry_run)
+ if run_id:
+ results = [deleter.delete_by_run_id(run_id)]
+ else:
+ assert start_from is not None
+ assert end_before is not None
+ results = deleter.delete_batch(
+ parsed_tenant_ids,
+ start_date=start_from,
+ end_date=end_before,
+ limit=limit,
+ )
+
+ for result in results:
+ if result.success:
+ click.echo(
+ click.style(
+ f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} "
+ f"workflow run {result.run_id} (tenant={result.tenant_id})",
+ fg="green",
+ )
+ )
+ else:
+ click.echo(
+ click.style(
+ f"Failed to delete workflow run {result.run_id}: {result.error}",
+ fg="red",
+ )
+ )
+
+ end_time = datetime.datetime.now(datetime.UTC)
+ elapsed = end_time - start_time
+
+ successes = sum(1 for result in results if result.success)
+ failures = len(results) - successes
+
+ if failures == 0:
+ click.echo(
+ click.style(
+ f"Delete completed successfully. success={successes} duration={elapsed}",
+ fg="green",
+ )
+ )
+ else:
+ click.echo(
+ click.style(
+ f"Delete completed with failures. success={successes} failed={failures} duration={elapsed}",
+ fg="red",
+ )
+ )
+
+
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
@click.command("clear-orphaned-file-records", help="Clear orphaned file records.")
def clear_orphaned_file_records(force: bool):
@@ -1184,6 +1622,217 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
+@click.command("file-usage", help="Query file usages and show where files are referenced.")
+@click.option("--file-id", type=str, default=None, help="Filter by file UUID.")
+@click.option("--key", type=str, default=None, help="Filter by storage key.")
+@click.option("--src", type=str, default=None, help="Filter by table.column pattern (e.g., 'documents.%' or '%.icon').")
+@click.option("--limit", type=int, default=100, help="Limit number of results (default: 100).")
+@click.option("--offset", type=int, default=0, help="Offset for pagination (default: 0).")
+@click.option("--json", "output_json", is_flag=True, help="Output results in JSON format.")
+def file_usage(
+ file_id: str | None,
+ key: str | None,
+ src: str | None,
+ limit: int,
+ offset: int,
+ output_json: bool,
+):
+ """
+ Query file usages and show where files are referenced in the database.
+
+ This command reuses the same reference checking logic as clear-orphaned-file-records
+ and displays detailed information about where each file is referenced.
+ """
+ # define tables and columns to process
+ files_tables = [
+ {"table": "upload_files", "id_column": "id", "key_column": "key"},
+ {"table": "tool_files", "id_column": "id", "key_column": "file_key"},
+ ]
+ ids_tables = [
+ {"type": "uuid", "table": "message_files", "column": "upload_file_id", "pk_column": "id"},
+ {"type": "text", "table": "documents", "column": "data_source_info", "pk_column": "id"},
+ {"type": "text", "table": "document_segments", "column": "content", "pk_column": "id"},
+ {"type": "text", "table": "messages", "column": "answer", "pk_column": "id"},
+ {"type": "text", "table": "workflow_node_executions", "column": "inputs", "pk_column": "id"},
+ {"type": "text", "table": "workflow_node_executions", "column": "process_data", "pk_column": "id"},
+ {"type": "text", "table": "workflow_node_executions", "column": "outputs", "pk_column": "id"},
+ {"type": "text", "table": "conversations", "column": "introduction", "pk_column": "id"},
+ {"type": "text", "table": "conversations", "column": "system_instruction", "pk_column": "id"},
+ {"type": "text", "table": "accounts", "column": "avatar", "pk_column": "id"},
+ {"type": "text", "table": "apps", "column": "icon", "pk_column": "id"},
+ {"type": "text", "table": "sites", "column": "icon", "pk_column": "id"},
+ {"type": "json", "table": "messages", "column": "inputs", "pk_column": "id"},
+ {"type": "json", "table": "messages", "column": "message", "pk_column": "id"},
+ ]
+
+ # Stream file usages with pagination to avoid holding all results in memory
+ paginated_usages = []
+ total_count = 0
+
+ # First, build a mapping of file_id -> storage_key from the base tables
+ file_key_map = {}
+ for files_table in files_tables:
+ query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
+ with db.engine.begin() as conn:
+ rs = conn.execute(sa.text(query))
+ for row in rs:
+ file_key_map[str(row[0])] = f"{files_table['table']}:{row[1]}"
+
+ # If filtering by key or file_id, verify it exists
+ if file_id and file_id not in file_key_map:
+ if output_json:
+ click.echo(json.dumps({"error": f"File ID {file_id} not found in base tables"}))
+ else:
+ click.echo(click.style(f"File ID {file_id} not found in base tables.", fg="red"))
+ return
+
+ if key:
+ valid_prefixes = {f"upload_files:{key}", f"tool_files:{key}"}
+ matching_file_ids = [fid for fid, fkey in file_key_map.items() if fkey in valid_prefixes]
+ if not matching_file_ids:
+ if output_json:
+ click.echo(json.dumps({"error": f"Key {key} not found in base tables"}))
+ else:
+ click.echo(click.style(f"Key {key} not found in base tables.", fg="red"))
+ return
+
+ guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
+
+ # For each reference table/column, find matching file IDs and record the references
+ for ids_table in ids_tables:
+ src_filter = f"{ids_table['table']}.{ids_table['column']}"
+
+ # Skip if src filter doesn't match (use fnmatch for wildcard patterns)
+ if src:
+ if "%" in src or "_" in src:
+ import fnmatch
+
+ # Convert SQL LIKE wildcards to fnmatch wildcards (% -> *, _ -> ?)
+ pattern = src.replace("%", "*").replace("_", "?")
+ if not fnmatch.fnmatch(src_filter, pattern):
+ continue
+ else:
+ if src_filter != src:
+ continue
+
+ if ids_table["type"] == "uuid":
+ # Direct UUID match
+ query = (
+ f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
+ f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
+ )
+ with db.engine.begin() as conn:
+ rs = conn.execute(sa.text(query))
+ for row in rs:
+ record_id = str(row[0])
+ ref_file_id = str(row[1])
+ if ref_file_id not in file_key_map:
+ continue
+ storage_key = file_key_map[ref_file_id]
+
+ # Apply filters
+ if file_id and ref_file_id != file_id:
+ continue
+ if key and not storage_key.endswith(key):
+ continue
+
+ # Only collect items within the requested page range
+ if offset <= total_count < offset + limit:
+ paginated_usages.append(
+ {
+ "src": f"{ids_table['table']}.{ids_table['column']}",
+ "record_id": record_id,
+ "file_id": ref_file_id,
+ "key": storage_key,
+ }
+ )
+ total_count += 1
+
+ elif ids_table["type"] in ("text", "json"):
+ # Extract UUIDs from text/json content
+ column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
+ query = (
+ f"SELECT {ids_table['pk_column']}, {column_cast} "
+ f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
+ )
+ with db.engine.begin() as conn:
+ rs = conn.execute(sa.text(query))
+ for row in rs:
+ record_id = str(row[0])
+ content = str(row[1])
+
+ # Find all UUIDs in the content
+ import re
+
+ uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
+ matches = uuid_pattern.findall(content)
+
+ for ref_file_id in matches:
+ if ref_file_id not in file_key_map:
+ continue
+ storage_key = file_key_map[ref_file_id]
+
+ # Apply filters
+ if file_id and ref_file_id != file_id:
+ continue
+ if key and not storage_key.endswith(key):
+ continue
+
+ # Only collect items within the requested page range
+ if offset <= total_count < offset + limit:
+ paginated_usages.append(
+ {
+ "src": f"{ids_table['table']}.{ids_table['column']}",
+ "record_id": record_id,
+ "file_id": ref_file_id,
+ "key": storage_key,
+ }
+ )
+ total_count += 1
+
+ # Output results
+ if output_json:
+ result = {
+ "total": total_count,
+ "offset": offset,
+ "limit": limit,
+ "usages": paginated_usages,
+ }
+ click.echo(json.dumps(result, indent=2))
+ else:
+ click.echo(
+ click.style(f"Found {total_count} file usages (showing {len(paginated_usages)} results)", fg="white")
+ )
+ click.echo("")
+
+ if not paginated_usages:
+ click.echo(click.style("No file usages found matching the specified criteria.", fg="yellow"))
+ return
+
+ # Print table header
+ click.echo(
+ click.style(
+ f"{'Src (Table.Column)':<50} {'Record ID':<40} {'File ID':<40} {'Storage Key':<60}",
+ fg="cyan",
+ )
+ )
+ click.echo(click.style("-" * 190, fg="white"))
+
+ # Print each usage
+ for usage in paginated_usages:
+ click.echo(f"{usage['src']:<50} {usage['record_id']:<40} {usage['file_id']:<40} {usage['key']:<60}")
+
+ # Show pagination info
+ if offset + limit < total_count:
+ click.echo("")
+ click.echo(
+ click.style(
+ f"Showing {offset + 1}-{offset + len(paginated_usages)} of {total_count} results", fg="white"
+ )
+ )
+ click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white"))
+
+
@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
@@ -1900,3 +2549,79 @@ def migrate_oss(
except Exception as e:
db.session.rollback()
click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))
+
+
+@click.command("clean-expired-messages", help="Clean expired messages.")
+@click.option(
+ "--start-from",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ required=True,
+ help="Lower bound (inclusive) for created_at.",
+)
+@click.option(
+ "--end-before",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ required=True,
+ help="Upper bound (exclusive) for created_at.",
+)
+@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
+@click.option(
+ "--graceful-period",
+ default=21,
+ show_default=True,
+ help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.",
+)
+@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting")
+def clean_expired_messages(
+ batch_size: int,
+ graceful_period: int,
+ start_from: datetime.datetime,
+ end_before: datetime.datetime,
+ dry_run: bool,
+):
+ """
+ Clean expired messages and related data for tenants based on clean policy.
+ """
+ click.echo(click.style("clean_messages: start clean messages.", fg="green"))
+
+ start_at = time.perf_counter()
+
+ try:
+ # Create policy based on billing configuration
+ # NOTE: graceful_period will be ignored when billing is disabled.
+ policy = create_message_clean_policy(graceful_period_days=graceful_period)
+
+ # Create and run the cleanup service
+ service = MessagesCleanService.from_time_range(
+ policy=policy,
+ start_from=start_from,
+ end_before=end_before,
+ batch_size=batch_size,
+ dry_run=dry_run,
+ )
+ stats = service.run()
+
+ end_at = time.perf_counter()
+ click.echo(
+ click.style(
+ f"clean_messages: completed successfully\n"
+ f" - Latency: {end_at - start_at:.2f}s\n"
+ f" - Batches processed: {stats['batches']}\n"
+ f" - Total messages scanned: {stats['total_messages']}\n"
+ f" - Messages filtered: {stats['filtered_messages']}\n"
+ f" - Messages deleted: {stats['total_deleted']}",
+ fg="green",
+ )
+ )
+ except Exception as e:
+ end_at = time.perf_counter()
+ logger.exception("clean_messages failed")
+ click.echo(
+ click.style(
+ f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
+ fg="red",
+ )
+ )
+ raise
+
+ click.echo(click.style("messages cleanup completed.", fg="green"))
diff --git a/api/configs/extra/__init__.py b/api/configs/extra/__init__.py
index 4543b5389d..de97adfc0e 100644
--- a/api/configs/extra/__init__.py
+++ b/api/configs/extra/__init__.py
@@ -1,9 +1,11 @@
+from configs.extra.archive_config import ArchiveStorageConfig
from configs.extra.notion_config import NotionConfig
from configs.extra.sentry_config import SentryConfig
class ExtraServiceConfig(
# place the configs in alphabet order
+ ArchiveStorageConfig,
NotionConfig,
SentryConfig,
):
diff --git a/api/configs/extra/archive_config.py b/api/configs/extra/archive_config.py
new file mode 100644
index 0000000000..a85628fa61
--- /dev/null
+++ b/api/configs/extra/archive_config.py
@@ -0,0 +1,43 @@
+from pydantic import Field
+from pydantic_settings import BaseSettings
+
+
+class ArchiveStorageConfig(BaseSettings):
+ """
+ Configuration settings for workflow run logs archiving storage.
+ """
+
+ ARCHIVE_STORAGE_ENABLED: bool = Field(
+ description="Enable workflow run logs archiving to S3-compatible storage",
+ default=False,
+ )
+
+ ARCHIVE_STORAGE_ENDPOINT: str | None = Field(
+ description="URL of the S3-compatible storage endpoint (e.g., 'https://storage.example.com')",
+ default=None,
+ )
+
+ ARCHIVE_STORAGE_ARCHIVE_BUCKET: str | None = Field(
+ description="Name of the bucket to store archived workflow logs",
+ default=None,
+ )
+
+ ARCHIVE_STORAGE_EXPORT_BUCKET: str | None = Field(
+ description="Name of the bucket to store exported workflow runs",
+ default=None,
+ )
+
+ ARCHIVE_STORAGE_ACCESS_KEY: str | None = Field(
+ description="Access key ID for authenticating with storage",
+ default=None,
+ )
+
+ ARCHIVE_STORAGE_SECRET_KEY: str | None = Field(
+ description="Secret access key for authenticating with storage",
+ default=None,
+ )
+
+ ARCHIVE_STORAGE_REGION: str = Field(
+ description="Region for storage (use 'auto' if the provider supports it)",
+ default="auto",
+ )
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index 0e20cc9f9e..f427466c1e 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -604,6 +604,11 @@ class LoggingConfig(BaseSettings):
default="INFO",
)
+ LOG_OUTPUT_FORMAT: Literal["text", "json"] = Field(
+ description="Log output format: 'text' for human-readable, 'json' for structured JSON logs.",
+ default="text",
+ )
+
LOG_FILE: str | None = Field(
description="File path for log output.",
default=None,
@@ -961,6 +966,12 @@ class MailConfig(BaseSettings):
default=False,
)
+ SMTP_LOCAL_HOSTNAME: str | None = Field(
+ description="Override the local hostname used in SMTP HELO/EHLO. "
+ "Useful behind NAT or when the default hostname causes rejections.",
+ default=None,
+ )
+
EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field(
description="Maximum number of emails allowed to be sent from the same IP address in a minute",
default=50,
@@ -971,6 +982,16 @@ class MailConfig(BaseSettings):
default=None,
)
+ ENABLE_TRIAL_APP: bool = Field(
+ description="Enable trial app",
+ default=False,
+ )
+
+ ENABLE_EXPLORE_BANNER: bool = Field(
+ description="Enable explore banner",
+ default=False,
+ )
+
class RagEtlConfig(BaseSettings):
"""
@@ -1113,6 +1134,10 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable clean messages task",
default=False,
)
+ ENABLE_WORKFLOW_RUN_CLEANUP_TASK: bool = Field(
+ description="Enable scheduled workflow run cleanup task",
+ default=False,
+ )
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
description="Enable mail clean document notify task",
default=False,
@@ -1308,6 +1333,10 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings):
description="Retention days for sandbox expired workflow_run records and message records",
default=30,
)
+ SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: PositiveInt = Field(
+ description="Lock TTL for sandbox expired records clean task in seconds",
+ default=90000,
+ )
class FeatureConfig(
diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py
index 4ad30014c7..42ede718c4 100644
--- a/api/configs/feature/hosted_service/__init__.py
+++ b/api/configs/feature/hosted_service/__init__.py
@@ -8,6 +8,11 @@ class HostedCreditConfig(BaseSettings):
default="",
)
+ HOSTED_POOL_CREDITS: int = Field(
+ description="Pool credits for hosted service",
+ default=200,
+ )
+
def get_model_credits(self, model_name: str) -> int:
"""
Get credit value for a specific model name.
@@ -60,19 +65,46 @@ class HostedOpenAiConfig(BaseSettings):
HOSTED_OPENAI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
- default="gpt-3.5-turbo,"
- "gpt-3.5-turbo-1106,"
- "gpt-3.5-turbo-instruct,"
+ default="gpt-4,"
+ "gpt-4-turbo-preview,"
+ "gpt-4-turbo-2024-04-09,"
+ "gpt-4-1106-preview,"
+ "gpt-4-0125-preview,"
+ "gpt-4-turbo,"
+ "gpt-4.1,"
+ "gpt-4.1-2025-04-14,"
+ "gpt-4.1-mini,"
+ "gpt-4.1-mini-2025-04-14,"
+ "gpt-4.1-nano,"
+ "gpt-4.1-nano-2025-04-14,"
+ "gpt-3.5-turbo,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
+ "gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
- "text-davinci-003",
- )
-
- HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
- description="Quota limit for hosted OpenAI service usage",
- default=200,
+ "gpt-3.5-turbo-instruct,"
+ "text-davinci-003,"
+ "chatgpt-4o-latest,"
+ "gpt-4o,"
+ "gpt-4o-2024-05-13,"
+ "gpt-4o-2024-08-06,"
+ "gpt-4o-2024-11-20,"
+ "gpt-4o-audio-preview,"
+ "gpt-4o-audio-preview-2025-06-03,"
+ "gpt-4o-mini,"
+ "gpt-4o-mini-2024-07-18,"
+ "o3-mini,"
+ "o3-mini-2025-01-31,"
+ "gpt-5-mini-2025-08-07,"
+ "gpt-5-mini,"
+ "o4-mini,"
+ "o4-mini-2025-04-16,"
+ "gpt-5-chat-latest,"
+ "gpt-5,"
+ "gpt-5-2025-08-07,"
+ "gpt-5-nano,"
+ "gpt-5-nano-2025-08-07",
)
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
@@ -87,6 +119,13 @@ class HostedOpenAiConfig(BaseSettings):
"gpt-4-turbo-2024-04-09,"
"gpt-4-1106-preview,"
"gpt-4-0125-preview,"
+ "gpt-4-turbo,"
+ "gpt-4.1,"
+ "gpt-4.1-2025-04-14,"
+ "gpt-4.1-mini,"
+ "gpt-4.1-mini-2025-04-14,"
+ "gpt-4.1-nano,"
+ "gpt-4.1-nano-2025-04-14,"
"gpt-3.5-turbo,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
@@ -94,7 +133,150 @@ class HostedOpenAiConfig(BaseSettings):
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"gpt-3.5-turbo-instruct,"
- "text-davinci-003",
+ "text-davinci-003,"
+ "chatgpt-4o-latest,"
+ "gpt-4o,"
+ "gpt-4o-2024-05-13,"
+ "gpt-4o-2024-08-06,"
+ "gpt-4o-2024-11-20,"
+ "gpt-4o-audio-preview,"
+ "gpt-4o-audio-preview-2025-06-03,"
+ "gpt-4o-mini,"
+ "gpt-4o-mini-2024-07-18,"
+ "o3-mini,"
+ "o3-mini-2025-01-31,"
+ "gpt-5-mini-2025-08-07,"
+ "gpt-5-mini,"
+ "o4-mini,"
+ "o4-mini-2025-04-16,"
+ "gpt-5-chat-latest,"
+ "gpt-5,"
+ "gpt-5-2025-08-07,"
+ "gpt-5-nano,"
+ "gpt-5-nano-2025-08-07",
+ )
+
+
+class HostedGeminiConfig(BaseSettings):
+ """
+ Configuration for fetching Gemini service
+ """
+
+ HOSTED_GEMINI_API_KEY: str | None = Field(
+ description="API key for hosted Gemini service",
+ default=None,
+ )
+
+ HOSTED_GEMINI_API_BASE: str | None = Field(
+ description="Base URL for hosted Gemini API",
+ default=None,
+ )
+
+ HOSTED_GEMINI_API_ORGANIZATION: str | None = Field(
+ description="Organization ID for hosted Gemini service",
+ default=None,
+ )
+
+ HOSTED_GEMINI_TRIAL_ENABLED: bool = Field(
+ description="Enable trial access to hosted Gemini service",
+ default=False,
+ )
+
+ HOSTED_GEMINI_TRIAL_MODELS: str = Field(
+ description="Comma-separated list of available models for trial access",
+ default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
+ )
+
+ HOSTED_GEMINI_PAID_ENABLED: bool = Field(
+ description="Enable paid access to hosted gemini service",
+ default=False,
+ )
+
+ HOSTED_GEMINI_PAID_MODELS: str = Field(
+ description="Comma-separated list of available models for paid access",
+ default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
+ )
+
+
+class HostedXAIConfig(BaseSettings):
+ """
+ Configuration for fetching XAI service
+ """
+
+ HOSTED_XAI_API_KEY: str | None = Field(
+ description="API key for hosted XAI service",
+ default=None,
+ )
+
+ HOSTED_XAI_API_BASE: str | None = Field(
+ description="Base URL for hosted XAI API",
+ default=None,
+ )
+
+ HOSTED_XAI_API_ORGANIZATION: str | None = Field(
+ description="Organization ID for hosted XAI service",
+ default=None,
+ )
+
+ HOSTED_XAI_TRIAL_ENABLED: bool = Field(
+ description="Enable trial access to hosted XAI service",
+ default=False,
+ )
+
+ HOSTED_XAI_TRIAL_MODELS: str = Field(
+ description="Comma-separated list of available models for trial access",
+ default="grok-3,grok-3-mini,grok-3-mini-fast",
+ )
+
+ HOSTED_XAI_PAID_ENABLED: bool = Field(
+ description="Enable paid access to hosted XAI service",
+ default=False,
+ )
+
+ HOSTED_XAI_PAID_MODELS: str = Field(
+ description="Comma-separated list of available models for paid access",
+ default="grok-3,grok-3-mini,grok-3-mini-fast",
+ )
+
+
+class HostedDeepseekConfig(BaseSettings):
+ """
+ Configuration for fetching Deepseek service
+ """
+
+ HOSTED_DEEPSEEK_API_KEY: str | None = Field(
+ description="API key for hosted Deepseek service",
+ default=None,
+ )
+
+ HOSTED_DEEPSEEK_API_BASE: str | None = Field(
+ description="Base URL for hosted Deepseek API",
+ default=None,
+ )
+
+ HOSTED_DEEPSEEK_API_ORGANIZATION: str | None = Field(
+ description="Organization ID for hosted Deepseek service",
+ default=None,
+ )
+
+ HOSTED_DEEPSEEK_TRIAL_ENABLED: bool = Field(
+ description="Enable trial access to hosted Deepseek service",
+ default=False,
+ )
+
+ HOSTED_DEEPSEEK_TRIAL_MODELS: str = Field(
+ description="Comma-separated list of available models for trial access",
+ default="deepseek-chat,deepseek-reasoner",
+ )
+
+ HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field(
+ description="Enable paid access to hosted Deepseek service",
+ default=False,
+ )
+
+ HOSTED_DEEPSEEK_PAID_MODELS: str = Field(
+ description="Comma-separated list of available models for paid access",
+ default="deepseek-chat,deepseek-reasoner",
)
@@ -144,16 +326,66 @@ class HostedAnthropicConfig(BaseSettings):
default=False,
)
- HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
- description="Quota limit for hosted Anthropic service usage",
- default=600000,
- )
-
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Anthropic service",
default=False,
)
+ HOSTED_ANTHROPIC_TRIAL_MODELS: str = Field(
+ description="Comma-separated list of available models for paid access",
+ default="claude-opus-4-20250514,"
+ "claude-sonnet-4-20250514,"
+ "claude-3-5-haiku-20241022,"
+ "claude-3-opus-20240229,"
+ "claude-3-7-sonnet-20250219,"
+ "claude-3-haiku-20240307",
+ )
+ HOSTED_ANTHROPIC_PAID_MODELS: str = Field(
+ description="Comma-separated list of available models for paid access",
+ default="claude-opus-4-20250514,"
+ "claude-sonnet-4-20250514,"
+ "claude-3-5-haiku-20241022,"
+ "claude-3-opus-20240229,"
+ "claude-3-7-sonnet-20250219,"
+ "claude-3-haiku-20240307",
+ )
+
+
+class HostedTongyiConfig(BaseSettings):
+ """
+ Configuration for hosted Tongyi service
+ """
+
+ HOSTED_TONGYI_API_KEY: str | None = Field(
+ description="API key for hosted Tongyi service",
+ default=None,
+ )
+
+ HOSTED_TONGYI_USE_INTERNATIONAL_ENDPOINT: bool = Field(
+ description="Use international endpoint for hosted Tongyi service",
+ default=False,
+ )
+
+ HOSTED_TONGYI_TRIAL_ENABLED: bool = Field(
+ description="Enable trial access to hosted Tongyi service",
+ default=False,
+ )
+
+ HOSTED_TONGYI_PAID_ENABLED: bool = Field(
+ description="Enable paid access to hosted Anthropic service",
+ default=False,
+ )
+
+ HOSTED_TONGYI_TRIAL_MODELS: str = Field(
+ description="Comma-separated list of available models for trial access",
+ default="",
+ )
+
+ HOSTED_TONGYI_PAID_MODELS: str = Field(
+ description="Comma-separated list of available models for paid access",
+ default="",
+ )
+
class HostedMinmaxConfig(BaseSettings):
"""
@@ -246,9 +478,13 @@ class HostedServiceConfig(
HostedOpenAiConfig,
HostedSparkConfig,
HostedZhipuAIConfig,
+ HostedTongyiConfig,
# moderation
HostedModerationConfig,
# credit config
HostedCreditConfig,
+ HostedGeminiConfig,
+ HostedXAIConfig,
+ HostedDeepseekConfig,
):
pass
diff --git a/api/configs/middleware/storage/tencent_cos_storage_config.py b/api/configs/middleware/storage/tencent_cos_storage_config.py
index e297e748e9..cdd10740f8 100644
--- a/api/configs/middleware/storage/tencent_cos_storage_config.py
+++ b/api/configs/middleware/storage/tencent_cos_storage_config.py
@@ -31,3 +31,8 @@ class TencentCloudCOSStorageConfig(BaseSettings):
description="Protocol scheme for COS requests: 'https' (recommended) or 'http'",
default=None,
)
+
+ TENCENT_COS_CUSTOM_DOMAIN: str | None = Field(
+ description="Tencent Cloud COS custom domain setting",
+ default=None,
+ )
diff --git a/api/configs/middleware/storage/volcengine_tos_storage_config.py b/api/configs/middleware/storage/volcengine_tos_storage_config.py
index be01f2dc36..2a35300401 100644
--- a/api/configs/middleware/storage/volcengine_tos_storage_config.py
+++ b/api/configs/middleware/storage/volcengine_tos_storage_config.py
@@ -4,7 +4,7 @@ from pydantic_settings import BaseSettings
class VolcengineTOSStorageConfig(BaseSettings):
"""
- Configuration settings for Volcengine Tinder Object Storage (TOS)
+ Configuration settings for Volcengine Torch Object Storage (TOS)
"""
VOLCENGINE_TOS_BUCKET_NAME: str | None = Field(
diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py
index 05cee51cc9..eb9b0ac2ab 100644
--- a/api/configs/middleware/vdb/milvus_config.py
+++ b/api/configs/middleware/vdb/milvus_config.py
@@ -16,7 +16,6 @@ class MilvusConfig(BaseSettings):
description="Authentication token for Milvus, if token-based authentication is enabled",
default=None,
)
-
MILVUS_USER: str | None = Field(
description="Username for authenticating with Milvus, if username/password authentication is enabled",
default=None,
diff --git a/api/context/__init__.py b/api/context/__init__.py
new file mode 100644
index 0000000000..aebf9750ce
--- /dev/null
+++ b/api/context/__init__.py
@@ -0,0 +1,74 @@
+"""
+Core Context - Framework-agnostic context management.
+
+This module provides context management that is independent of any specific
+web framework. Framework-specific implementations register their context
+capture functions at application initialization time.
+
+This ensures the workflow layer remains completely decoupled from Flask
+or any other web framework.
+"""
+
+import contextvars
+from collections.abc import Callable
+
+from core.workflow.context.execution_context import (
+ ExecutionContext,
+ IExecutionContext,
+ NullAppContext,
+)
+
+# Global capturer function - set by framework-specific modules
+_capturer: Callable[[], IExecutionContext] | None = None
+
+
+def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
+ """
+ Register a context capture function.
+
+ This should be called by framework-specific modules (e.g., Flask)
+ during application initialization.
+
+ Args:
+ capturer: Function that captures current context and returns IExecutionContext
+ """
+ global _capturer
+ _capturer = capturer
+
+
+def capture_current_context() -> IExecutionContext:
+ """
+ Capture current execution context.
+
+ This function uses the registered context capturer. If no capturer
+ is registered, it returns a minimal context with only contextvars
+ (suitable for non-framework environments like tests or standalone scripts).
+
+ Returns:
+ IExecutionContext with captured context
+ """
+ if _capturer is None:
+ # No framework registered - return minimal context
+ return ExecutionContext(
+ app_context=NullAppContext(),
+ context_vars=contextvars.copy_context(),
+ )
+
+ return _capturer()
+
+
+def reset_context_provider() -> None:
+ """
+ Reset the context capturer.
+
+ This is primarily useful for testing to ensure a clean state.
+ """
+ global _capturer
+ _capturer = None
+
+
+__all__ = [
+ "capture_current_context",
+ "register_context_capturer",
+ "reset_context_provider",
+]
diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py
new file mode 100644
index 0000000000..2d465c8cf4
--- /dev/null
+++ b/api/context/flask_app_context.py
@@ -0,0 +1,192 @@
+"""
+Flask App Context - Flask implementation of AppContext interface.
+"""
+
+import contextvars
+import threading
+from collections.abc import Generator
+from contextlib import contextmanager
+from typing import Any, final
+
+from flask import Flask, current_app, g
+
+from core.workflow.context import register_context_capturer
+from core.workflow.context.execution_context import (
+ AppContext,
+ IExecutionContext,
+)
+
+
+@final
+class FlaskAppContext(AppContext):
+ """
+ Flask implementation of AppContext.
+
+ This adapts Flask's app context to the AppContext interface.
+ """
+
+ def __init__(self, flask_app: Flask) -> None:
+ """
+ Initialize Flask app context.
+
+ Args:
+ flask_app: The Flask application instance
+ """
+ self._flask_app = flask_app
+
+ def get_config(self, key: str, default: Any = None) -> Any:
+ """Get configuration value from Flask app config."""
+ return self._flask_app.config.get(key, default)
+
+ def get_extension(self, name: str) -> Any:
+ """Get Flask extension by name."""
+ return self._flask_app.extensions.get(name)
+
+ @contextmanager
+ def enter(self) -> Generator[None, None, None]:
+ """Enter Flask app context."""
+ with self._flask_app.app_context():
+ yield
+
+ @property
+ def flask_app(self) -> Flask:
+ """Get the underlying Flask app instance."""
+ return self._flask_app
+
+
+def capture_flask_context(user: Any = None) -> IExecutionContext:
+ """
+ Capture current Flask execution context.
+
+ This function captures the Flask app context and contextvars from the
+ current environment. It should be called from within a Flask request or
+ app context.
+
+ Args:
+ user: Optional user object to include in context
+
+ Returns:
+ IExecutionContext with captured Flask context
+
+ Raises:
+ RuntimeError: If called outside Flask context
+ """
+ # Get Flask app instance
+ flask_app = current_app._get_current_object() # type: ignore
+
+ # Save current user if available
+ saved_user = user
+ if saved_user is None:
+ # Check for user in g (flask-login)
+ if hasattr(g, "_login_user"):
+ saved_user = g._login_user
+
+ # Capture contextvars
+ context_vars = contextvars.copy_context()
+
+ return FlaskExecutionContext(
+ flask_app=flask_app,
+ context_vars=context_vars,
+ user=saved_user,
+ )
+
+
+@final
+class FlaskExecutionContext:
+ """
+ Flask-specific execution context.
+
+ This is a specialized version of ExecutionContext that includes Flask app
+ context. It provides the same interface as ExecutionContext but with
+ Flask-specific implementation.
+ """
+
+ def __init__(
+ self,
+ flask_app: Flask,
+ context_vars: contextvars.Context,
+ user: Any = None,
+ ) -> None:
+ """
+ Initialize Flask execution context.
+
+ Args:
+ flask_app: Flask application instance
+ context_vars: Python contextvars
+ user: Optional user object
+ """
+ self._app_context = FlaskAppContext(flask_app)
+ self._context_vars = context_vars
+ self._user = user
+ self._flask_app = flask_app
+ self._local = threading.local()
+
+ @property
+ def app_context(self) -> FlaskAppContext:
+ """Get Flask app context."""
+ return self._app_context
+
+ @property
+ def context_vars(self) -> contextvars.Context:
+ """Get context variables."""
+ return self._context_vars
+
+ @property
+ def user(self) -> Any:
+ """Get user object."""
+ return self._user
+
+ def __enter__(self) -> "FlaskExecutionContext":
+ """Enter the Flask execution context."""
+ # Restore non-Flask context variables to avoid leaking Flask tokens across threads
+ for var, val in self._context_vars.items():
+ var.set(val)
+
+ # Enter Flask app context
+ cm = self._app_context.enter()
+ self._local.cm = cm
+ cm.__enter__()
+
+ # Restore user in new app context
+ if self._user is not None:
+ g._login_user = self._user
+
+ return self
+
+ def __exit__(self, *args: Any) -> None:
+ """Exit the Flask execution context."""
+ cm = getattr(self._local, "cm", None)
+ if cm is not None:
+ cm.__exit__(*args)
+
+ @contextmanager
+ def enter(self) -> Generator[None, None, None]:
+ """Enter Flask execution context as context manager."""
+ # Restore non-Flask context variables to avoid leaking Flask tokens across threads
+ for var, val in self._context_vars.items():
+ var.set(val)
+
+ # Enter Flask app context
+ with self._flask_app.app_context():
+ # Restore user in new app context
+ if self._user is not None:
+ g._login_user = self._user
+ yield
+
+
+def init_flask_context() -> None:
+ """
+ Initialize Flask context capture by registering the capturer.
+
+ This function should be called during Flask application initialization
+ to register the Flask-specific context capturer with the core context module.
+
+ Example:
+ app = Flask(__name__)
+ init_flask_context() # Register Flask context capturer
+
+ Note:
+ This function does not need the app instance as it uses Flask's
+ `current_app` to get the app when capturing context.
+ """
+ register_context_capturer(capture_flask_context)
diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py
index df9de825de..c16a23fac8 100644
--- a/api/controllers/common/fields.py
+++ b/api/controllers/common/fields.py
@@ -1,62 +1,59 @@
-from flask_restx import Api, Namespace, fields
+from __future__ import annotations
-from libs.helper import AppIconUrlField
+from typing import Any, TypeAlias
-parameters__system_parameters = {
- "image_file_size_limit": fields.Integer,
- "video_file_size_limit": fields.Integer,
- "audio_file_size_limit": fields.Integer,
- "file_size_limit": fields.Integer,
- "workflow_file_upload_limit": fields.Integer,
-}
+from pydantic import BaseModel, ConfigDict, computed_field
+
+from core.file import helpers as file_helpers
+from models.model import IconType
+
+JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
+JSONObject: TypeAlias = dict[str, Any]
-def build_system_parameters_model(api_or_ns: Api | Namespace):
- """Build the system parameters model for the API or Namespace."""
- return api_or_ns.model("SystemParameters", parameters__system_parameters)
+class SystemParameters(BaseModel):
+ image_file_size_limit: int
+ video_file_size_limit: int
+ audio_file_size_limit: int
+ file_size_limit: int
+ workflow_file_upload_limit: int
-parameters_fields = {
- "opening_statement": fields.String,
- "suggested_questions": fields.Raw,
- "suggested_questions_after_answer": fields.Raw,
- "speech_to_text": fields.Raw,
- "text_to_speech": fields.Raw,
- "retriever_resource": fields.Raw,
- "annotation_reply": fields.Raw,
- "more_like_this": fields.Raw,
- "user_input_form": fields.Raw,
- "sensitive_word_avoidance": fields.Raw,
- "file_upload": fields.Raw,
- "system_parameters": fields.Nested(parameters__system_parameters),
-}
+class Parameters(BaseModel):
+ opening_statement: str | None = None
+ suggested_questions: list[str]
+ suggested_questions_after_answer: JSONObject
+ speech_to_text: JSONObject
+ text_to_speech: JSONObject
+ retriever_resource: JSONObject
+ annotation_reply: JSONObject
+ more_like_this: JSONObject
+ user_input_form: list[JSONObject]
+ sensitive_word_avoidance: JSONObject
+ file_upload: JSONObject
+ system_parameters: SystemParameters
-def build_parameters_model(api_or_ns: Api | Namespace):
- """Build the parameters model for the API or Namespace."""
- copied_fields = parameters_fields.copy()
- copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns))
- return api_or_ns.model("Parameters", copied_fields)
+class Site(BaseModel):
+ model_config = ConfigDict(from_attributes=True)
+ title: str
+ chat_color_theme: str | None = None
+ chat_color_theme_inverted: bool
+ icon_type: str | None = None
+ icon: str | None = None
+ icon_background: str | None = None
+ description: str | None = None
+ copyright: str | None = None
+ privacy_policy: str | None = None
+ custom_disclaimer: str | None = None
+ default_language: str
+ show_workflow_steps: bool
+ use_icon_as_answer_icon: bool
-site_fields = {
- "title": fields.String,
- "chat_color_theme": fields.String,
- "chat_color_theme_inverted": fields.Boolean,
- "icon_type": fields.String,
- "icon": fields.String,
- "icon_background": fields.String,
- "icon_url": AppIconUrlField,
- "description": fields.String,
- "copyright": fields.String,
- "privacy_policy": fields.String,
- "custom_disclaimer": fields.String,
- "default_language": fields.String,
- "show_workflow_steps": fields.Boolean,
- "use_icon_as_answer_icon": fields.Boolean,
-}
-
-
-def build_site_model(api_or_ns: Api | Namespace):
- """Build the site model for the API or Namespace."""
- return api_or_ns.model("Site", site_fields)
+ @computed_field(return_type=str | None) # type: ignore
+ @property
+ def icon_url(self) -> str | None:
+ if self.icon and self.icon_type == IconType.IMAGE:
+ return file_helpers.get_signed_file_url(self.icon)
+ return None
diff --git a/api/controllers/common/schema.py b/api/controllers/common/schema.py
index e0896a8dc2..a5a3e4ebbd 100644
--- a/api/controllers/common/schema.py
+++ b/api/controllers/common/schema.py
@@ -1,7 +1,11 @@
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
+from enum import StrEnum
+
from flask_restx import Namespace
-from pydantic import BaseModel
+from pydantic import BaseModel, TypeAdapter
+
+from controllers.console import console_ns
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -19,8 +23,25 @@ def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> No
register_schema_model(namespace, model)
+def get_or_create_model(model_name: str, field_def):
+ existing = console_ns.models.get(model_name)
+ if existing is None:
+ existing = console_ns.model(model_name, field_def)
+ return existing
+
+
+def register_enum_models(namespace: Namespace, *models: type[StrEnum]) -> None:
+ """Register multiple StrEnum with a namespace."""
+ for model in models:
+ namespace.schema_model(
+ model.__name__, TypeAdapter(model).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+ )
+
+
__all__ = [
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
+ "get_or_create_model",
+ "register_enum_models",
"register_schema_model",
"register_schema_models",
]
diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index 2cd7e230f4..902d67174b 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -108,10 +108,12 @@ from .datasets.rag_pipeline import (
# Import explore controllers
from .explore import (
+ banner,
installed_app,
parameter,
recommended_app,
saved_message,
+ trial,
)
# Import tag controllers
@@ -146,6 +148,7 @@ __all__ = [
"apikey",
"app",
"audio",
+ "banner",
"billing",
"bp",
"completion",
@@ -200,6 +203,7 @@ __all__ = [
"statistic",
"tags",
"tool_providers",
+ "trial",
"trigger_providers",
"version",
"website",
diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py
index a25ca5ef51..e1ee2c24b8 100644
--- a/api/controllers/console/admin.py
+++ b/api/controllers/console/admin.py
@@ -15,7 +15,7 @@ from controllers.console.wraps import only_edition_cloud
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
-from models.model import App, InstalledApp, RecommendedApp
+from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
P = ParamSpec("P")
R = TypeVar("R")
@@ -32,6 +32,8 @@ class InsertExploreAppPayload(BaseModel):
language: str = Field(...)
category: str = Field(...)
position: int = Field(...)
+ can_trial: bool = Field(default=False)
+ trial_limit: int = Field(default=0)
@field_validator("language")
@classmethod
@@ -39,11 +41,33 @@ class InsertExploreAppPayload(BaseModel):
return supported_language(value)
+class InsertExploreBannerPayload(BaseModel):
+ category: str = Field(...)
+ title: str = Field(...)
+ description: str = Field(...)
+ img_src: str = Field(..., alias="img-src")
+ language: str = Field(default="en-US")
+ link: str = Field(...)
+ sort: int = Field(...)
+
+ @field_validator("language")
+ @classmethod
+ def validate_language(cls, value: str) -> str:
+ return supported_language(value)
+
+ model_config = {"populate_by_name": True}
+
+
console_ns.schema_model(
InsertExploreAppPayload.__name__,
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
+console_ns.schema_model(
+ InsertExploreBannerPayload.__name__,
+ InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
def admin_required(view: Callable[P, R]):
@wraps(view)
@@ -109,6 +133,20 @@ class InsertExploreAppListApi(Resource):
)
db.session.add(recommended_app)
+ if payload.can_trial:
+ trial_app = db.session.execute(
+ select(TrialApp).where(TrialApp.app_id == payload.app_id)
+ ).scalar_one_or_none()
+ if not trial_app:
+ db.session.add(
+ TrialApp(
+ app_id=payload.app_id,
+ tenant_id=app.tenant_id,
+ trial_limit=payload.trial_limit,
+ )
+ )
+ else:
+ trial_app.trial_limit = payload.trial_limit
app.is_public = True
db.session.commit()
@@ -123,6 +161,20 @@ class InsertExploreAppListApi(Resource):
recommended_app.category = payload.category
recommended_app.position = payload.position
+ if payload.can_trial:
+ trial_app = db.session.execute(
+ select(TrialApp).where(TrialApp.app_id == payload.app_id)
+ ).scalar_one_or_none()
+ if not trial_app:
+ db.session.add(
+ TrialApp(
+ app_id=payload.app_id,
+ tenant_id=app.tenant_id,
+ trial_limit=payload.trial_limit,
+ )
+ )
+ else:
+ trial_app.trial_limit = payload.trial_limit
app.is_public = True
db.session.commit()
@@ -168,7 +220,62 @@ class InsertExploreAppApi(Resource):
for installed_app in installed_apps:
session.delete(installed_app)
+ trial_app = session.execute(
+ select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
+ ).scalar_one_or_none()
+ if trial_app:
+ session.delete(trial_app)
+
db.session.delete(recommended_app)
db.session.commit()
return {"result": "success"}, 204
+
+
+@console_ns.route("/admin/insert-explore-banner")
+class InsertExploreBannerApi(Resource):
+ @console_ns.doc("insert_explore_banner")
+ @console_ns.doc(description="Insert an explore banner")
+ @console_ns.expect(console_ns.models[InsertExploreBannerPayload.__name__])
+ @console_ns.response(201, "Banner inserted successfully")
+ @only_edition_cloud
+ @admin_required
+ def post(self):
+ payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
+
+ content = {
+ "category": payload.category,
+ "title": payload.title,
+ "description": payload.description,
+ "img-src": payload.img_src,
+ }
+
+ banner = ExporleBanner(
+ content=content,
+ link=payload.link,
+ sort=payload.sort,
+ language=payload.language,
+ )
+ db.session.add(banner)
+ db.session.commit()
+
+ return {"result": "success"}, 201
+
+
+@console_ns.route("/admin/delete-explore-banner/")
+class DeleteExploreBannerApi(Resource):
+ @console_ns.doc("delete_explore_banner")
+ @console_ns.doc(description="Delete an explore banner")
+ @console_ns.doc(params={"banner_id": "Banner ID to delete"})
+ @console_ns.response(204, "Banner deleted successfully")
+ @only_edition_cloud
+ @admin_required
+ def delete(self, banner_id):
+ banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
+ if not banner:
+ raise NotFound(f"Banner '{banner_id}' is not found")
+
+ db.session.delete(banner)
+ db.session.commit()
+
+ return {"result": "success"}, 204
diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py
index 9b0d4b1a78..c81709e985 100644
--- a/api/controllers/console/apikey.py
+++ b/api/controllers/console/apikey.py
@@ -22,10 +22,10 @@ api_key_fields = {
"created_at": TimestampField,
}
-api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}
-
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
+api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
+
api_key_list_model = console_ns.model(
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
)
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 62e997dae2..8c371da596 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -1,15 +1,19 @@
import uuid
-from typing import Literal
+from datetime import datetime
+from typing import Any, Literal, TypeAlias
from flask import request
-from flask_restx import Resource, fields, marshal, marshal_with
-from pydantic import BaseModel, Field, field_validator
+from flask_restx import Resource
+from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
+from controllers.common.helpers import FileInfo
+from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
+from controllers.console.workspace.models import LoadBalancingPayload
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
@@ -18,27 +22,37 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
)
+from core.file import helpers as file_helpers
from core.ops.ops_trace_manager import OpsTraceManager
-from core.workflow.enums import NodeType
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
+from core.workflow.enums import NodeType, WorkflowExecutionStatus
from extensions.ext_database import db
-from fields.app_fields import (
- deleted_tool_fields,
- model_config_fields,
- model_config_partial_fields,
- site_fields,
- tag_fields,
-)
-from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
-from libs.helper import AppIconUrlField, TimestampField
from libs.login import current_account_with_tenant, login_required
-from models import App, Workflow
+from models import App, DatasetPermissionEnum, Workflow
+from models.model import IconType
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
+from services.entities.knowledge_entities.knowledge_entities import (
+ DataSource,
+ InfoList,
+ NotionIcon,
+ NotionInfo,
+ NotionPage,
+ PreProcessingRule,
+ RerankingModel,
+ Rule,
+ Segmentation,
+ WebsiteInfo,
+ WeightKeywordSetting,
+ WeightModel,
+ WeightVectorSetting,
+)
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+register_enum_models(console_ns, IconType)
class AppListQuery(BaseModel):
@@ -134,124 +148,310 @@ class AppTracePayload(BaseModel):
return value
-def reg(cls: type[BaseModel]):
- console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+JSONValue: TypeAlias = Any
-reg(AppListQuery)
-reg(CreateAppPayload)
-reg(UpdateAppPayload)
-reg(CopyAppPayload)
-reg(AppExportQuery)
-reg(AppNamePayload)
-reg(AppIconPayload)
-reg(AppSiteStatusPayload)
-reg(AppApiStatusPayload)
-reg(AppTracePayload)
+class ResponseModel(BaseModel):
+ model_config = ConfigDict(
+ from_attributes=True,
+ extra="ignore",
+ populate_by_name=True,
+ serialize_by_alias=True,
+ protected_namespaces=(),
+ )
-# Register models for flask_restx to avoid dict type issues in Swagger
-# Register base models first
-tag_model = console_ns.model("Tag", tag_fields)
-workflow_partial_model = console_ns.model("WorkflowPartial", _workflow_partial_fields_dict)
+def _to_timestamp(value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return int(value.timestamp())
+ return value
-model_config_model = console_ns.model("ModelConfig", model_config_fields)
-model_config_partial_model = console_ns.model("ModelConfigPartial", model_config_partial_fields)
+def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
+ if icon is None or icon_type is None:
+ return None
+ icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
+ if icon_type_value.lower() != IconType.IMAGE:
+ return None
+ return file_helpers.get_signed_file_url(icon)
-deleted_tool_model = console_ns.model("DeletedTool", deleted_tool_fields)
-site_model = console_ns.model("Site", site_fields)
+class Tag(ResponseModel):
+ id: str
+ name: str
+ type: str
-app_partial_model = console_ns.model(
- "AppPartial",
- {
- "id": fields.String,
- "name": fields.String,
- "max_active_requests": fields.Raw(),
- "description": fields.String(attribute="desc_or_prompt"),
- "mode": fields.String(attribute="mode_compatible_with_agent"),
- "icon_type": fields.String,
- "icon": fields.String,
- "icon_background": fields.String,
- "icon_url": AppIconUrlField,
- "model_config": fields.Nested(model_config_partial_model, attribute="app_model_config", allow_null=True),
- "workflow": fields.Nested(workflow_partial_model, allow_null=True),
- "use_icon_as_answer_icon": fields.Boolean,
- "created_by": fields.String,
- "created_at": TimestampField,
- "updated_by": fields.String,
- "updated_at": TimestampField,
- "tags": fields.List(fields.Nested(tag_model)),
- "access_mode": fields.String,
- "create_user_name": fields.String,
- "author_name": fields.String,
- "has_draft_trigger": fields.Boolean,
- },
-)
-app_detail_model = console_ns.model(
- "AppDetail",
- {
- "id": fields.String,
- "name": fields.String,
- "description": fields.String,
- "mode": fields.String(attribute="mode_compatible_with_agent"),
- "icon": fields.String,
- "icon_background": fields.String,
- "enable_site": fields.Boolean,
- "enable_api": fields.Boolean,
- "model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
- "workflow": fields.Nested(workflow_partial_model, allow_null=True),
- "tracing": fields.Raw,
- "use_icon_as_answer_icon": fields.Boolean,
- "created_by": fields.String,
- "created_at": TimestampField,
- "updated_by": fields.String,
- "updated_at": TimestampField,
- "access_mode": fields.String,
- "tags": fields.List(fields.Nested(tag_model)),
- },
-)
+class WorkflowPartial(ResponseModel):
+ id: str
+ created_by: str | None = None
+ created_at: int | None = None
+ updated_by: str | None = None
+ updated_at: int | None = None
-app_detail_with_site_model = console_ns.model(
- "AppDetailWithSite",
- {
- "id": fields.String,
- "name": fields.String,
- "description": fields.String,
- "mode": fields.String(attribute="mode_compatible_with_agent"),
- "icon_type": fields.String,
- "icon": fields.String,
- "icon_background": fields.String,
- "icon_url": AppIconUrlField,
- "enable_site": fields.Boolean,
- "enable_api": fields.Boolean,
- "model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
- "workflow": fields.Nested(workflow_partial_model, allow_null=True),
- "api_base_url": fields.String,
- "use_icon_as_answer_icon": fields.Boolean,
- "max_active_requests": fields.Integer,
- "created_by": fields.String,
- "created_at": TimestampField,
- "updated_by": fields.String,
- "updated_at": TimestampField,
- "deleted_tools": fields.List(fields.Nested(deleted_tool_model)),
- "access_mode": fields.String,
- "tags": fields.List(fields.Nested(tag_model)),
- "site": fields.Nested(site_model),
- },
-)
+ @field_validator("created_at", "updated_at", mode="before")
+ @classmethod
+ def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
-app_pagination_model = console_ns.model(
- "AppPagination",
- {
- "page": fields.Integer,
- "limit": fields.Integer(attribute="per_page"),
- "total": fields.Integer,
- "has_more": fields.Boolean(attribute="has_next"),
- "data": fields.List(fields.Nested(app_partial_model), attribute="items"),
- },
+
+class ModelConfigPartial(ResponseModel):
+ model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model"))
+ pre_prompt: str | None = None
+ created_by: str | None = None
+ created_at: int | None = None
+ updated_by: str | None = None
+ updated_at: int | None = None
+
+ @field_validator("created_at", "updated_at", mode="before")
+ @classmethod
+ def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
+
+
+class ModelConfig(ResponseModel):
+ opening_statement: str | None = None
+ suggested_questions: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("suggested_questions_list", "suggested_questions")
+ )
+ suggested_questions_after_answer: JSONValue | None = Field(
+ default=None,
+ validation_alias=AliasChoices("suggested_questions_after_answer_dict", "suggested_questions_after_answer"),
+ )
+ speech_to_text: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("speech_to_text_dict", "speech_to_text")
+ )
+ text_to_speech: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("text_to_speech_dict", "text_to_speech")
+ )
+ retriever_resource: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("retriever_resource_dict", "retriever_resource")
+ )
+ annotation_reply: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("annotation_reply_dict", "annotation_reply")
+ )
+ more_like_this: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("more_like_this_dict", "more_like_this")
+ )
+ sensitive_word_avoidance: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("sensitive_word_avoidance_dict", "sensitive_word_avoidance")
+ )
+ external_data_tools: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("external_data_tools_list", "external_data_tools")
+ )
+ model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model"))
+ user_input_form: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("user_input_form_list", "user_input_form")
+ )
+ dataset_query_variable: str | None = None
+ pre_prompt: str | None = None
+ agent_mode: JSONValue | None = Field(default=None, validation_alias=AliasChoices("agent_mode_dict", "agent_mode"))
+ prompt_type: str | None = None
+ chat_prompt_config: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("chat_prompt_config_dict", "chat_prompt_config")
+ )
+ completion_prompt_config: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("completion_prompt_config_dict", "completion_prompt_config")
+ )
+ dataset_configs: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("dataset_configs_dict", "dataset_configs")
+ )
+ file_upload: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("file_upload_dict", "file_upload")
+ )
+ created_by: str | None = None
+ created_at: int | None = None
+ updated_by: str | None = None
+ updated_at: int | None = None
+
+ @field_validator("created_at", "updated_at", mode="before")
+ @classmethod
+ def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
+
+
+class Site(ResponseModel):
+ access_token: str | None = Field(default=None, validation_alias="code")
+ code: str | None = None
+ title: str | None = None
+ icon_type: str | IconType | None = None
+ icon: str | None = None
+ icon_background: str | None = None
+ description: str | None = None
+ default_language: str | None = None
+ chat_color_theme: str | None = None
+ chat_color_theme_inverted: bool | None = None
+ customize_domain: str | None = None
+ copyright: str | None = None
+ privacy_policy: str | None = None
+ custom_disclaimer: str | None = None
+ customize_token_strategy: str | None = None
+ prompt_public: bool | None = None
+ app_base_url: str | None = None
+ show_workflow_steps: bool | None = None
+ use_icon_as_answer_icon: bool | None = None
+ created_by: str | None = None
+ created_at: int | None = None
+ updated_by: str | None = None
+ updated_at: int | None = None
+
+ @computed_field(return_type=str | None) # type: ignore
+ @property
+ def icon_url(self) -> str | None:
+ return _build_icon_url(self.icon_type, self.icon)
+
+ @field_validator("icon_type", mode="before")
+ @classmethod
+ def _normalize_icon_type(cls, value: str | IconType | None) -> str | None:
+ if isinstance(value, IconType):
+ return value.value
+ return value
+
+ @field_validator("created_at", "updated_at", mode="before")
+ @classmethod
+ def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
+
+
+class DeletedTool(ResponseModel):
+ type: str
+ tool_name: str
+ provider_id: str
+
+
+class AppPartial(ResponseModel):
+ id: str
+ name: str
+ max_active_requests: int | None = None
+ description: str | None = Field(default=None, validation_alias=AliasChoices("desc_or_prompt", "description"))
+ mode: str = Field(validation_alias="mode_compatible_with_agent")
+ icon_type: str | None = None
+ icon: str | None = None
+ icon_background: str | None = None
+ model_config_: ModelConfigPartial | None = Field(
+ default=None,
+ validation_alias=AliasChoices("app_model_config", "model_config"),
+ alias="model_config",
+ )
+ workflow: WorkflowPartial | None = None
+ use_icon_as_answer_icon: bool | None = None
+ created_by: str | None = None
+ created_at: int | None = None
+ updated_by: str | None = None
+ updated_at: int | None = None
+ tags: list[Tag] = Field(default_factory=list)
+ access_mode: str | None = None
+ create_user_name: str | None = None
+ author_name: str | None = None
+ has_draft_trigger: bool | None = None
+
+ @computed_field(return_type=str | None) # type: ignore
+ @property
+ def icon_url(self) -> str | None:
+ return _build_icon_url(self.icon_type, self.icon)
+
+ @field_validator("created_at", "updated_at", mode="before")
+ @classmethod
+ def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
+
+
+class AppDetail(ResponseModel):
+ id: str
+ name: str
+ description: str | None = None
+ mode: str = Field(validation_alias="mode_compatible_with_agent")
+ icon: str | None = None
+ icon_background: str | None = None
+ enable_site: bool
+ enable_api: bool
+ model_config_: ModelConfig | None = Field(
+ default=None,
+ validation_alias=AliasChoices("app_model_config", "model_config"),
+ alias="model_config",
+ )
+ workflow: WorkflowPartial | None = None
+ tracing: JSONValue | None = None
+ use_icon_as_answer_icon: bool | None = None
+ created_by: str | None = None
+ created_at: int | None = None
+ updated_by: str | None = None
+ updated_at: int | None = None
+ access_mode: str | None = None
+ tags: list[Tag] = Field(default_factory=list)
+
+ @field_validator("created_at", "updated_at", mode="before")
+ @classmethod
+ def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
+
+
+class AppDetailWithSite(AppDetail):
+ icon_type: str | None = None
+ api_base_url: str | None = None
+ max_active_requests: int | None = None
+ deleted_tools: list[DeletedTool] = Field(default_factory=list)
+ site: Site | None = None
+
+ @computed_field(return_type=str | None) # type: ignore
+ @property
+ def icon_url(self) -> str | None:
+ return _build_icon_url(self.icon_type, self.icon)
+
+
+class AppPagination(ResponseModel):
+ page: int
+ limit: int = Field(validation_alias=AliasChoices("per_page", "limit"))
+ total: int
+ has_more: bool = Field(validation_alias=AliasChoices("has_next", "has_more"))
+ data: list[AppPartial] = Field(validation_alias=AliasChoices("items", "data"))
+
+
+class AppExportResponse(ResponseModel):
+ data: str
+
+
+register_enum_models(console_ns, RetrievalMethod, WorkflowExecutionStatus, DatasetPermissionEnum)
+
+register_schema_models(
+ console_ns,
+ AppListQuery,
+ CreateAppPayload,
+ UpdateAppPayload,
+ CopyAppPayload,
+ AppExportQuery,
+ AppNamePayload,
+ AppIconPayload,
+ AppSiteStatusPayload,
+ AppApiStatusPayload,
+ AppTracePayload,
+ Tag,
+ WorkflowPartial,
+ ModelConfigPartial,
+ ModelConfig,
+ Site,
+ DeletedTool,
+ AppPartial,
+ AppDetail,
+ AppDetailWithSite,
+ AppPagination,
+ AppExportResponse,
+ Segmentation,
+ PreProcessingRule,
+ Rule,
+ WeightVectorSetting,
+ WeightKeywordSetting,
+ WeightModel,
+ RerankingModel,
+ InfoList,
+ NotionInfo,
+ FileInfo,
+ WebsiteInfo,
+ NotionPage,
+ NotionIcon,
+ RerankingModel,
+ DataSource,
+ LoadBalancingPayload,
)
@@ -260,7 +460,7 @@ class AppListApi(Resource):
@console_ns.doc("list_apps")
@console_ns.doc(description="Get list of applications with pagination and filtering")
@console_ns.expect(console_ns.models[AppListQuery.__name__])
- @console_ns.response(200, "Success", app_pagination_model)
+ @console_ns.response(200, "Success", console_ns.models[AppPagination.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -276,7 +476,8 @@ class AppListApi(Resource):
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
if not app_pagination:
- return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
+ empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
+ return empty.model_dump(mode="json"), 200
if FeatureService.get_system_features().webapp_auth.enabled:
app_ids = [str(app.id) for app in app_pagination.items]
@@ -320,18 +521,18 @@ class AppListApi(Resource):
for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
- return marshal(app_pagination, app_pagination_model), 200
+ pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True)
+ return pagination_model.model_dump(mode="json"), 200
@console_ns.doc("create_app")
@console_ns.doc(description="Create a new application")
@console_ns.expect(console_ns.models[CreateAppPayload.__name__])
- @console_ns.response(201, "App created successfully", app_detail_model)
+ @console_ns.response(201, "App created successfully", console_ns.models[AppDetail.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
- @marshal_with(app_detail_model)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
@@ -341,8 +542,8 @@ class AppListApi(Resource):
app_service = AppService()
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
-
- return app, 201
+ app_detail = AppDetail.model_validate(app, from_attributes=True)
+ return app_detail.model_dump(mode="json"), 201
@console_ns.route("/apps/")
@@ -350,13 +551,12 @@ class AppApi(Resource):
@console_ns.doc("get_app_detail")
@console_ns.doc(description="Get application details")
@console_ns.doc(params={"app_id": "Application ID"})
- @console_ns.response(200, "Success", app_detail_with_site_model)
+ @console_ns.response(200, "Success", console_ns.models[AppDetailWithSite.__name__])
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
- @get_app_model
- @marshal_with(app_detail_with_site_model)
+ @get_app_model(mode=None)
def get(self, app_model):
"""Get app detail"""
app_service = AppService()
@@ -367,21 +567,21 @@ class AppApi(Resource):
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
app_model.access_mode = app_setting.access_mode
- return app_model
+ response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
+ return response_model.model_dump(mode="json")
@console_ns.doc("update_app")
@console_ns.doc(description="Update application details")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
- @console_ns.response(200, "App updated successfully", app_detail_with_site_model)
+ @console_ns.response(200, "App updated successfully", console_ns.models[AppDetailWithSite.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
- @get_app_model
+ @get_app_model(mode=None)
@edit_permission_required
- @marshal_with(app_detail_with_site_model)
def put(self, app_model):
"""Update app"""
args = UpdateAppPayload.model_validate(console_ns.payload)
@@ -398,8 +598,8 @@ class AppApi(Resource):
"max_active_requests": args.max_active_requests or 0,
}
app_model = app_service.update_app(app_model, args_dict)
-
- return app_model
+ response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
+ return response_model.model_dump(mode="json")
@console_ns.doc("delete_app")
@console_ns.doc(description="Delete application")
@@ -425,14 +625,13 @@ class AppCopyApi(Resource):
@console_ns.doc(description="Create a copy of an existing application")
@console_ns.doc(params={"app_id": "Application ID to copy"})
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
- @console_ns.response(201, "App copied successfully", app_detail_with_site_model)
+ @console_ns.response(201, "App copied successfully", console_ns.models[AppDetailWithSite.__name__])
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
- @get_app_model
+ @get_app_model(mode=None)
@edit_permission_required
- @marshal_with(app_detail_with_site_model)
def post(self, app_model):
"""Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor
@@ -458,7 +657,8 @@ class AppCopyApi(Resource):
stmt = select(App).where(App.id == result.app_id)
app = session.scalar(stmt)
- return app, 201
+ response_model = AppDetailWithSite.model_validate(app, from_attributes=True)
+ return response_model.model_dump(mode="json"), 201
@console_ns.route("/apps//export")
@@ -467,11 +667,7 @@ class AppExportApi(Resource):
@console_ns.doc(description="Export application configuration as DSL")
@console_ns.doc(params={"app_id": "Application ID to export"})
@console_ns.expect(console_ns.models[AppExportQuery.__name__])
- @console_ns.response(
- 200,
- "App exported successfully",
- console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
- )
+ @console_ns.response(200, "App exported successfully", console_ns.models[AppExportResponse.__name__])
@console_ns.response(403, "Insufficient permissions")
@get_app_model
@setup_required
@@ -482,13 +678,14 @@ class AppExportApi(Resource):
"""Export app"""
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
- return {
- "data": AppDslService.export_dsl(
+ payload = AppExportResponse(
+ data=AppDslService.export_dsl(
app_model=app_model,
include_secret=args.include_secret,
workflow_id=args.workflow_id,
)
- }
+ )
+ return payload.model_dump(mode="json")
@console_ns.route("/apps//name")
@@ -497,20 +694,19 @@ class AppNameApi(Resource):
@console_ns.doc(description="Check if app name is available")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppNamePayload.__name__])
- @console_ns.response(200, "Name availability checked")
+ @console_ns.response(200, "Name availability checked", console_ns.models[AppDetail.__name__])
@setup_required
@login_required
@account_initialization_required
- @get_app_model
- @marshal_with(app_detail_model)
+ @get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
args = AppNamePayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_name(app_model, args.name)
-
- return app_model
+ response_model = AppDetail.model_validate(app_model, from_attributes=True)
+ return response_model.model_dump(mode="json")
@console_ns.route("/apps//icon")
@@ -524,16 +720,15 @@ class AppIconApi(Resource):
@setup_required
@login_required
@account_initialization_required
- @get_app_model
- @marshal_with(app_detail_model)
+ @get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
args = AppIconPayload.model_validate(console_ns.payload or {})
app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
-
- return app_model
+ response_model = AppDetail.model_validate(app_model, from_attributes=True)
+ return response_model.model_dump(mode="json")
@console_ns.route("/apps//site-enable")
@@ -542,21 +737,20 @@ class AppSiteStatus(Resource):
@console_ns.doc(description="Enable or disable app site")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
- @console_ns.response(200, "Site status updated successfully", app_detail_model)
+ @console_ns.response(200, "Site status updated successfully", console_ns.models[AppDetail.__name__])
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
- @get_app_model
- @marshal_with(app_detail_model)
+ @get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
args = AppSiteStatusPayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args.enable_site)
-
- return app_model
+ response_model = AppDetail.model_validate(app_model, from_attributes=True)
+ return response_model.model_dump(mode="json")
@console_ns.route("/apps//api-enable")
@@ -565,21 +759,20 @@ class AppApiStatus(Resource):
@console_ns.doc(description="Enable or disable app API")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
- @console_ns.response(200, "API status updated successfully", app_detail_model)
+ @console_ns.response(200, "API status updated successfully", console_ns.models[AppDetail.__name__])
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
- @get_app_model
- @marshal_with(app_detail_model)
+ @get_app_model(mode=None)
def post(self, app_model):
args = AppApiStatusPayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args.enable_api)
-
- return app_model
+ response_model = AppDetail.model_validate(app_model, from_attributes=True)
+ return response_model.model_dump(mode="json")
@console_ns.route("/apps//trace")
diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py
index 22e2aeb720..fdef54ba5a 100644
--- a/api/controllers/console/app/app_import.py
+++ b/api/controllers/console/app/app_import.py
@@ -41,14 +41,14 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppImportPayload(BaseModel):
mode: str = Field(..., description="Import mode")
- yaml_content: str | None = None
- yaml_url: str | None = None
- name: str | None = None
- description: str | None = None
- icon_type: str | None = None
- icon: str | None = None
- icon_background: str | None = None
- app_id: str | None = None
+ yaml_content: str | None = Field(None)
+ yaml_url: str | None = Field(None)
+ name: str | None = Field(None)
+ description: str | None = Field(None)
+ icon_type: str | None = Field(None)
+ icon: str | None = Field(None)
+ icon_background: str | None = Field(None)
+ app_id: str | None = Field(None)
console_ns.schema_model(
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index 1501d39485..14910c5895 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -13,7 +13,6 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
-from fields.conversation_fields import MessageTextField
from fields.raws import FilesContainedField
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import TimestampField
@@ -178,6 +177,12 @@ annotation_hit_history_model = console_ns.model(
},
)
+
+class MessageTextField(fields.Raw):
+ def format(self, value):
+ return value[0]["text"] if value else ""
+
+
# Simple message detail model
simple_message_detail_model = console_ns.model(
"SimpleMessageDetail",
@@ -344,10 +349,13 @@ class CompletionConversationApi(Resource):
)
if args.keyword:
+ from libs.helper import escape_like_pattern
+
+ escaped_keyword = escape_like_pattern(args.keyword)
query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_(
- Message.query.ilike(f"%{args.keyword}%"),
- Message.answer.ilike(f"%{args.keyword}%"),
+ Message.query.ilike(f"%{escaped_keyword}%", escape="\\"),
+ Message.answer.ilike(f"%{escaped_keyword}%", escape="\\"),
)
)
@@ -456,7 +464,10 @@ class ChatConversationApi(Resource):
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
if args.keyword:
- keyword_filter = f"%{args.keyword}%"
+ from libs.helper import escape_like_pattern
+
+ escaped_keyword = escape_like_pattern(args.keyword)
+ keyword_filter = f"%{escaped_keyword}%"
query = (
query.join(
Message,
@@ -465,11 +476,11 @@ class ChatConversationApi(Resource):
.join(subquery, subquery.c.conversation_id == Conversation.id)
.where(
or_(
- Message.query.ilike(keyword_filter),
- Message.answer.ilike(keyword_filter),
- Conversation.name.ilike(keyword_filter),
- Conversation.introduction.ilike(keyword_filter),
- subquery.c.from_end_user_session_id.ilike(keyword_filter),
+ Message.query.ilike(keyword_filter, escape="\\"),
+ Message.answer.ilike(keyword_filter, escape="\\"),
+ Conversation.name.ilike(keyword_filter, escape="\\"),
+ Conversation.introduction.ilike(keyword_filter, escape="\\"),
+ subquery.c.from_end_user_session_id.ilike(keyword_filter, escape="\\"),
),
)
.group_by(Conversation.id)
@@ -582,9 +593,12 @@ def _get_conversation(app_model, conversation_id):
if not conversation:
raise NotFound("Conversation Not Exists.")
- if not conversation.read_at:
- conversation.read_at = naive_utc_now()
- conversation.read_account_id = current_user.id
- db.session.commit()
+ db.session.execute(
+ sa.update(Conversation)
+ .where(Conversation.id == conversation_id, Conversation.read_at.is_(None))
+ .values(read_at=naive_utc_now(), read_account_id=current_user.id)
+ )
+ db.session.commit()
+ db.session.refresh(conversation)
return conversation
diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py
index fbd7901646..3fa15d6d6d 100644
--- a/api/controllers/console/app/error.py
+++ b/api/controllers/console/app/error.py
@@ -82,13 +82,13 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
class DraftWorkflowNotExist(BaseHTTPException):
error_code = "draft_workflow_not_exist"
description = "Draft workflow need to be initialized."
- code = 400
+ code = 404
class DraftWorkflowNotSync(BaseHTTPException):
error_code = "draft_workflow_not_sync"
description = "Workflow graph might have been modified, please refresh and resubmit."
- code = 400
+ code = 409
class TracingConfigNotExist(BaseHTTPException):
@@ -115,3 +115,9 @@ class InvokeRateLimitError(BaseHTTPException):
error_code = "rate_limit_error"
description = "Rate Limit Error"
code = 429
+
+
+class NeedAddIdsError(BaseHTTPException):
+ error_code = "need_add_ids"
+ description = "Need to add ids."
+ code = 400
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index b0bdf2657d..94b9693929 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -12,6 +12,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
+from controllers.console.app.workflow_run import workflow_run_node_execution_model
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError
@@ -36,7 +37,6 @@ from extensions.ext_database import db
from factories import file_factory, variable_factory
from fields.member_fields import simple_account_fields
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
-from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, uuid_value
@@ -89,26 +89,6 @@ workflow_pagination_fields_copy = workflow_pagination_fields.copy()
workflow_pagination_fields_copy["items"] = fields.List(fields.Nested(workflow_model), attribute="items")
workflow_pagination_model = console_ns.model("WorkflowPagination", workflow_pagination_fields_copy)
-# Reuse workflow_run_node_execution_model from workflow_run.py if already registered
-# Otherwise register it here
-from fields.end_user_fields import simple_end_user_fields
-
-simple_end_user_model = None
-try:
- simple_end_user_model = console_ns.models.get("SimpleEndUser")
-except AttributeError:
- pass
-if simple_end_user_model is None:
- simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields)
-
-workflow_run_node_execution_model = None
-try:
- workflow_run_node_execution_model = console_ns.models.get("WorkflowRunNodeExecution")
-except AttributeError:
- pass
-if workflow_run_node_execution_model is None:
- workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)
-
class SyncDraftWorkflowPayload(BaseModel):
graph: dict[str, Any]
@@ -471,7 +451,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
- args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
+ args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
try:
response = AppGenerateService.generate_single_loop(
@@ -509,7 +489,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
- args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
+ args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
try:
response = AppGenerateService.generate_single_loop(
@@ -1176,6 +1156,7 @@ class DraftWorkflowTriggerRunApi(Resource):
if not event:
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
workflow_args = dict(event.workflow_args)
+
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
return helper.compact_generate_response(
AppGenerateService.generate(
@@ -1324,6 +1305,7 @@ class DraftWorkflowTriggerRunAllApi(Resource):
try:
workflow_args = dict(trigger_debug_event.workflow_args)
+
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
response = AppGenerateService.generate(
app_model=app_model,
diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py
index fa67fb8154..6736f24a2e 100644
--- a/api/controllers/console/app/workflow_app_log.py
+++ b/api/controllers/console/app/workflow_app_log.py
@@ -11,7 +11,10 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_database import db
-from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
+from fields.workflow_app_log_fields import (
+ build_workflow_app_log_pagination_model,
+ build_workflow_archived_log_pagination_model,
+)
from libs.login import login_required
from models import App
from models.model import AppMode
@@ -61,6 +64,7 @@ console_ns.schema_model(
# Register model for flask_restx to avoid dict type issues in Swagger
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
+workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns)
@console_ns.route("/apps//workflow-app-logs")
@@ -99,3 +103,33 @@ class WorkflowAppLogApi(Resource):
)
return workflow_app_log_pagination
+
+
+@console_ns.route("/apps//workflow-archived-logs")
+class WorkflowArchivedLogApi(Resource):
+ @console_ns.doc("get_workflow_archived_logs")
+ @console_ns.doc(description="Get workflow archived execution logs")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
+ @console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model)
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @get_app_model(mode=[AppMode.WORKFLOW])
+ @marshal_with(workflow_archived_log_pagination_model)
+ def get(self, app_model: App):
+ """
+ Get workflow archived logs
+ """
+ args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
+
+ workflow_app_service = WorkflowAppService()
+ with Session(db.engine) as session:
+ workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
+ session=session,
+ app_model=app_model,
+ page=args.page,
+ limit=args.limit,
+ )
+
+ return workflow_app_log_pagination
diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py
index 36b2e40928..585ede6f5d 100644
--- a/api/controllers/console/app/workflow_run.py
+++ b/api/controllers/console/app/workflow_run.py
@@ -1,8 +1,10 @@
+from datetime import UTC, datetime, timedelta
from typing import Literal, cast
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
+from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
@@ -25,12 +27,14 @@ from fields.workflow_run_fields import (
workflow_run_node_execution_list_fields,
workflow_run_pagination_fields,
)
+from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
from libs.custom_inputs import time_duration
from libs.helper import uuid_value
from libs.login import current_user, login_required
-from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
+from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
+from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
from services.workflow_run_service import WorkflowRunService
@@ -45,6 +49,7 @@ def _build_backstage_input_url(form_token: str | None) -> str | None:
# Workflow run status choices for filtering
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
+EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@@ -111,6 +116,15 @@ workflow_run_node_execution_list_model = console_ns.model(
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
)
+workflow_run_export_fields = console_ns.model(
+ "WorkflowRunExport",
+ {
+ "status": fields.String(description="Export status: success/failed"),
+ "presigned_url": fields.String(description="Pre-signed URL for download", required=False),
+ "presigned_url_expires_at": fields.String(description="Pre-signed URL expiration time", required=False),
+ },
+)
+
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -199,6 +213,56 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
return result
+@console_ns.route("/apps//workflow-runs//export")
+class WorkflowRunExportApi(Resource):
+ @console_ns.doc("get_workflow_run_export_url")
+ @console_ns.doc(description="Generate a download URL for an archived workflow run.")
+ @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
+ @console_ns.response(200, "Export URL generated", workflow_run_export_fields)
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @get_app_model()
+ def get(self, app_model: App, run_id: str):
+ tenant_id = str(app_model.tenant_id)
+ app_id = str(app_model.id)
+ run_id_str = str(run_id)
+
+ run_created_at = db.session.scalar(
+ select(WorkflowArchiveLog.run_created_at)
+ .where(
+ WorkflowArchiveLog.tenant_id == tenant_id,
+ WorkflowArchiveLog.app_id == app_id,
+ WorkflowArchiveLog.workflow_run_id == run_id_str,
+ )
+ .limit(1)
+ )
+ if not run_created_at:
+ return {"code": "archive_log_not_found", "message": "workflow run archive not found"}, 404
+
+ prefix = (
+ f"{tenant_id}/app_id={app_id}/year={run_created_at.strftime('%Y')}/"
+ f"month={run_created_at.strftime('%m')}/workflow_run_id={run_id_str}"
+ )
+ archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
+
+ try:
+ archive_storage = get_archive_storage()
+ except ArchiveStorageNotConfiguredError as e:
+ return {"code": "archive_storage_not_configured", "message": str(e)}, 500
+
+ presigned_url = archive_storage.generate_presigned_url(
+ archive_key,
+ expires_in=EXPORT_SIGNED_URL_EXPIRE_SECONDS,
+ )
+ expires_at = datetime.now(UTC) + timedelta(seconds=EXPORT_SIGNED_URL_EXPIRE_SECONDS)
+ return {
+ "status": "success",
+ "presigned_url": presigned_url,
+ "presigned_url_expires_at": expires_at.isoformat(),
+ }, 200
+
+
@console_ns.route("/apps//advanced-chat/workflow-runs/count")
class AdvancedChatAppWorkflowRunCountApi(Resource):
@console_ns.doc("get_advanced_chat_workflow_runs_count")
diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py
index 9433b732e4..8236e766ae 100644
--- a/api/controllers/console/app/workflow_trigger.py
+++ b/api/controllers/console/app/workflow_trigger.py
@@ -1,13 +1,14 @@
import logging
from flask import request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from configs import dify_config
+from controllers.common.schema import get_or_create_model
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required
@@ -22,6 +23,14 @@ from ..wraps import account_initialization_required, edit_permission_required, s
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields)
+
+triggers_list_fields_copy = triggers_list_fields.copy()
+triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model))
+triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy)
+
+webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields)
+
class Parser(BaseModel):
node_id: str
@@ -48,7 +57,7 @@ class WebhookTriggerApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
- @marshal_with(webhook_trigger_fields)
+ @marshal_with(webhook_trigger_model)
def get(self, app_model: App):
"""Get webhook trigger for a node"""
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
@@ -80,7 +89,7 @@ class AppTriggersApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
- @marshal_with(triggers_list_fields)
+ @marshal_with(triggers_list_model)
def get(self, app_model: App):
"""Get app triggers list"""
assert isinstance(current_user, Account)
@@ -120,7 +129,7 @@ class AppTriggerEnableApi(Resource):
@account_initialization_required
@edit_permission_required
@get_app_model(mode=AppMode.WORKFLOW)
- @marshal_with(trigger_fields)
+ @marshal_with(trigger_model)
def post(self, app_model: App):
"""Update app trigger (enable/disable)"""
args = ParserEnable.model_validate(console_ns.payload)
diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py
index 9bb2718f89..e687d980fa 100644
--- a/api/controllers/console/app/wraps.py
+++ b/api/controllers/console/app/wraps.py
@@ -23,6 +23,11 @@ def _load_app_model(app_id: str) -> App | None:
return app_model
+def _load_app_model_with_trial(app_id: str) -> App | None:
+ app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
+ return app_model
+
+
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P1, R1]):
@wraps(view_func)
@@ -62,3 +67,44 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
return decorator
else:
return decorator(view)
+
+
+def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
+ def decorator(view_func: Callable[P, R]):
+ @wraps(view_func)
+ def decorated_view(*args: P.args, **kwargs: P.kwargs):
+ if not kwargs.get("app_id"):
+ raise ValueError("missing app_id in path parameters")
+
+ app_id = kwargs.get("app_id")
+ app_id = str(app_id)
+
+ del kwargs["app_id"]
+
+ app_model = _load_app_model_with_trial(app_id)
+
+ if not app_model:
+ raise AppNotFoundError()
+
+ app_mode = AppMode.value_of(app_model.mode)
+
+ if mode is not None:
+ if isinstance(mode, list):
+ modes = mode
+ else:
+ modes = [mode]
+
+ if app_mode not in modes:
+ mode_values = {m.value for m in modes}
+ raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
+
+ kwargs["app_model"] = app_model
+
+ return view_func(*args, **kwargs)
+
+ return decorated_view
+
+ if view is None:
+ return decorator
+ else:
+ return decorator(view)
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index fe70d930fb..f741107b87 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -63,13 +63,19 @@ class ActivateCheckApi(Resource):
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workspaceId = args.workspace_id
- reg_email = args.email
token = args.token
- invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
+ invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
if invitation:
data = invitation.get("data", {})
tenant = invitation.get("tenant", None)
+
+ # Check workspace permission
+ if tenant:
+ from libs.workspace_permission import check_workspace_member_invite_permission
+
+ check_workspace_member_invite_permission(tenant.id)
+
workspace_name = tenant.name if tenant else None
workspace_id = tenant.id if tenant else None
invitee_email = data.get("email") if data else None
@@ -100,11 +106,12 @@ class ActivateApi(Resource):
def post(self):
args = ActivatePayload.model_validate(console_ns.payload)
- invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
+ normalized_request_email = args.email.lower() if args.email else None
+ invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
if invitation is None:
raise AlreadyActivateError()
- RegisterService.revoke_token(args.workspace_id, args.email, args.token)
+ RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token)
account = invitation["account"]
account.name = args.name
diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py
index fa082c735d..c2a95ddad2 100644
--- a/api/controllers/console/auth/email_register.py
+++ b/api/controllers/console/auth/email_register.py
@@ -1,7 +1,6 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
-from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
@@ -62,6 +61,7 @@ class EmailRegisterSendEmailApi(Resource):
@email_register_enabled
def post(self):
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -70,13 +70,12 @@ class EmailRegisterSendEmailApi(Resource):
if args.language in languages:
language = args.language
- if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
+ if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
- token = None
- token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
+ account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
+ token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
return {"result": "success", "data": token}
@@ -88,9 +87,9 @@ class EmailRegisterCheckApi(Resource):
def post(self):
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
- user_email = args.email
+ user_email = args.email.lower()
- is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
+ is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email)
if is_email_register_error_rate_limit:
raise EmailRegisterLimitError()
@@ -98,11 +97,14 @@ class EmailRegisterCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
- if user_email != token_data.get("email"):
+ token_email = token_data.get("email")
+ normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
+
+ if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
- AccountService.add_email_register_error_rate_limit(args.email)
+ AccountService.add_email_register_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -113,8 +115,8 @@ class EmailRegisterCheckApi(Resource):
user_email, code=args.code, additional_data={"phase": "register"}
)
- AccountService.reset_email_register_error_rate_limit(args.email)
- return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+ AccountService.reset_email_register_error_rate_limit(user_email)
+ return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/email-register")
@@ -141,22 +143,23 @@ class EmailRegisterResetApi(Resource):
AccountService.revoke_email_register_token(args.token)
email = register_data.get("email", "")
+ normalized_email = email.lower()
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
raise EmailAlreadyInUseError()
else:
- account = self._create_new_account(email, args.password_confirm)
+ account = self._create_new_account(normalized_email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
- AccountService.reset_login_error_rate_limit(email)
+ AccountService.reset_login_error_rate_limit(normalized_email)
return {"result": "success", "data": token_pair.model_dump()}
- def _create_new_account(self, email, password) -> Account | None:
+ def _create_new_account(self, email: str, password: str) -> Account | None:
# Create new account if allowed
account = None
try:
diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py
index 661f591182..394f205d93 100644
--- a/api/controllers/console/auth/forgot_password.py
+++ b/api/controllers/console/auth/forgot_password.py
@@ -4,7 +4,6 @@ import secrets
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
-from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console import console_ns
@@ -21,7 +20,6 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
-from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService
@@ -76,6 +74,7 @@ class ForgotPasswordSendEmailApi(Resource):
@email_password_login_enabled
def post(self):
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -87,11 +86,11 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_reset_password_email(
account=account,
- email=args.email,
+ email=normalized_email,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
)
@@ -122,9 +121,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self):
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
- user_email = args.email
+ user_email = args.email.lower()
- is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
+ is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
@@ -132,11 +131,16 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
- if user_email != token_data.get("email"):
+ token_email = token_data.get("email")
+ if not isinstance(token_email, str):
+ raise InvalidEmailError()
+ normalized_token_email = token_email.lower()
+
+ if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
- AccountService.add_forgot_password_error_rate_limit(args.email)
+ AccountService.add_forgot_password_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -144,11 +148,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
- user_email, code=args.code, additional_data={"phase": "reset"}
+ token_email, code=args.code, additional_data={"phase": "reset"}
)
- AccountService.reset_forgot_password_error_rate_limit(args.email)
- return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+ AccountService.reset_forgot_password_error_rate_limit(user_email)
+ return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/forgot-password/resets")
@@ -187,9 +191,8 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")
-
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt, session)
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index 772d98822e..400df138b8 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -1,3 +1,5 @@
+from typing import Any
+
import flask_login
from flask import make_response, request
from flask_restx import Resource
@@ -88,33 +90,38 @@ class LoginApi(Resource):
def post(self):
"""Authenticate user and login."""
args = LoginPayload.model_validate(console_ns.payload)
+ request_email = args.email
+ normalized_email = request_email.lower()
- if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
+ if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
- is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
+ is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
- # TODO: why invitation is re-assigned with different type?
- invitation = args.invite_token # type: ignore
- if invitation:
- invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore
+ invite_token = args.invite_token
+ invitation_data: dict[str, Any] | None = None
+ if invite_token:
+ invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
+ if invitation_data is None:
+ invite_token = None
try:
- if invitation:
- data = invitation.get("data", {}) # type: ignore
+ if invitation_data:
+ data = invitation_data.get("data", {})
invitee_email = data.get("email") if data else None
- if invitee_email != args.email:
+ invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
+ if invitee_email_normalized != normalized_email:
raise InvalidEmailError()
- account = AccountService.authenticate(args.email, args.password, args.invite_token)
- else:
- account = AccountService.authenticate(args.email, args.password)
+ account = _authenticate_account_with_case_fallback(
+ request_email, normalized_email, args.password, invite_token
+ )
except services.errors.account.AccountLoginError:
raise AccountBannedError()
- except services.errors.account.AccountPasswordError:
- AccountService.add_login_error_rate_limit(args.email)
- raise AuthenticationFailedError()
+ except services.errors.account.AccountPasswordError as exc:
+ AccountService.add_login_error_rate_limit(normalized_email)
+ raise AuthenticationFailedError() from exc
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
@@ -129,7 +136,7 @@ class LoginApi(Resource):
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
- AccountService.reset_login_error_rate_limit(args.email)
+ AccountService.reset_login_error_rate_limit(normalized_email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@@ -169,18 +176,19 @@ class ResetPasswordSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
- account = AccountService.get_user_through_email(args.email)
+ account = _get_account_with_case_fallback(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
token = AccountService.send_reset_password_email(
- email=args.email,
+ email=normalized_email,
account=account,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
@@ -195,6 +203,7 @@ class EmailCodeLoginSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -205,13 +214,13 @@ class EmailCodeLoginSendEmailApi(Resource):
else:
language = "en-US"
try:
- account = AccountService.get_user_through_email(args.email)
+ account = _get_account_with_case_fallback(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
if FeatureService.get_system_features().is_allow_register:
- token = AccountService.send_email_code_login_email(email=args.email, language=language)
+ token = AccountService.send_email_code_login_email(email=normalized_email, language=language)
else:
raise AccountNotFound()
else:
@@ -228,14 +237,17 @@ class EmailCodeLoginApi(Resource):
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
- user_email = args.email
+ original_email = args.email
+ user_email = original_email.lower()
language = args.language
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
raise InvalidTokenError()
- if token_data["email"] != args.email:
+ token_email = token_data.get("email")
+ normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
+ if normalized_token_email != user_email:
raise InvalidEmailError()
if token_data["code"] != args.code:
@@ -243,7 +255,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args.token)
try:
- account = AccountService.get_user_through_email(user_email)
+ account = _get_account_with_case_fallback(original_email)
except AccountRegisterError:
raise AccountInFreezeError()
if account:
@@ -274,7 +286,7 @@ class EmailCodeLoginApi(Resource):
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
- AccountService.reset_login_error_rate_limit(args.email)
+ AccountService.reset_login_error_rate_limit(user_email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@@ -308,3 +320,22 @@ class RefreshTokenApi(Resource):
return response
except Exception as e:
return {"result": "fail", "message": str(e)}, 401
+
+
+def _get_account_with_case_fallback(email: str):
+ account = AccountService.get_user_through_email(email)
+ if account or email == email.lower():
+ return account
+
+ return AccountService.get_user_through_email(email.lower())
+
+
+def _authenticate_account_with_case_fallback(
+ original_email: str, normalized_email: str, password: str, invite_token: str | None
+):
+ try:
+ return AccountService.authenticate(original_email, password, invite_token)
+ except services.errors.account.AccountPasswordError:
+ if original_email == normalized_email:
+ raise
+ return AccountService.authenticate(normalized_email, password, invite_token)
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index 7ad1e56373..112e152432 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -3,7 +3,6 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
-from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@@ -118,13 +117,16 @@ class OAuthCallback(Resource):
invitation = RegisterService.get_invitation_by_token(token=invite_token)
if invitation:
invitation_email = invitation.get("email", None)
- if invitation_email != user_info.email:
+ invitation_email_normalized = (
+ invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
+ )
+ if invitation_email_normalized != user_info.email.lower():
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
try:
- account = _generate_account(provider, user_info)
+ account, oauth_new_user = _generate_account(provider, user_info)
except AccountNotFoundError:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
@@ -159,7 +161,10 @@ class OAuthCallback(Resource):
ip_address=extract_remote_ip(request),
)
- response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
+ base_url = dify_config.CONSOLE_WEB_URL
+ query_char = "&" if "?" in base_url else "?"
+ target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"
+ response = redirect(target_url)
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
@@ -172,14 +177,15 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
if not account:
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
return account
-def _generate_account(provider: str, user_info: OAuthUserInfo):
+def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]:
# Get account by openid or email.
account = _get_account_by_openid_or_email(provider, user_info)
+ oauth_new_user = False
if account:
tenants = TenantService.get_join_tenants(account)
@@ -193,8 +199,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
tenant_was_created.send(new_tenant)
if not account:
+ normalized_email = user_info.email.lower()
+ oauth_new_user = True
if not FeatureService.get_system_features().is_allow_register:
- if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
+ if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountRegisterError(
description=(
"This email account has been deleted within the past "
@@ -205,7 +213,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
raise AccountRegisterError(description=("Invalid email or password"))
account_name = user_info.name or "Dify"
account = RegisterService.register(
- email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
+ email=normalized_email,
+ name=account_name,
+ password=None,
+ open_id=user_info.id,
+ provider=provider,
)
# Set interface language
@@ -220,4 +232,4 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
# Link account
AccountService.link_account_integrate(provider, user_info.id, account)
- return account
+ return account, oauth_new_user
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index cd958bbb36..01e9bf77c0 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -3,13 +3,13 @@ from collections.abc import Generator
from typing import Any, cast
from flask import request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
-from controllers.common.schema import register_schema_model
+from controllers.common.schema import get_or_create_model, register_schema_model
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
@@ -17,7 +17,14 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
-from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
+from fields.data_source_fields import (
+ integrate_fields,
+ integrate_icon_fields,
+ integrate_list_fields,
+ integrate_notion_info_list_fields,
+ integrate_page_fields,
+ integrate_workspace_fields,
+)
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import DataSourceOauthBinding, Document
@@ -36,9 +43,62 @@ class NotionEstimatePayload(BaseModel):
doc_language: str = Field(default="English")
+class DataSourceNotionListQuery(BaseModel):
+ dataset_id: str | None = Field(default=None, description="Dataset ID")
+ credential_id: str = Field(..., description="Credential ID", min_length=1)
+ datasource_parameters: dict[str, Any] | None = Field(default=None, description="Datasource parameters JSON string")
+
+
+class DataSourceNotionPreviewQuery(BaseModel):
+ credential_id: str = Field(..., description="Credential ID", min_length=1)
+
+
register_schema_model(console_ns, NotionEstimatePayload)
+integrate_icon_model = get_or_create_model("DataSourceIntegrateIcon", integrate_icon_fields)
+
+integrate_page_fields_copy = integrate_page_fields.copy()
+integrate_page_fields_copy["page_icon"] = fields.Nested(integrate_icon_model, allow_null=True)
+integrate_page_model = get_or_create_model("DataSourceIntegratePage", integrate_page_fields_copy)
+
+integrate_workspace_fields_copy = integrate_workspace_fields.copy()
+integrate_workspace_fields_copy["pages"] = fields.List(fields.Nested(integrate_page_model))
+integrate_workspace_model = get_or_create_model("DataSourceIntegrateWorkspace", integrate_workspace_fields_copy)
+
+integrate_fields_copy = integrate_fields.copy()
+integrate_fields_copy["source_info"] = fields.Nested(integrate_workspace_model)
+integrate_model = get_or_create_model("DataSourceIntegrate", integrate_fields_copy)
+
+integrate_list_fields_copy = integrate_list_fields.copy()
+integrate_list_fields_copy["data"] = fields.List(fields.Nested(integrate_model))
+integrate_list_model = get_or_create_model("DataSourceIntegrateList", integrate_list_fields_copy)
+
+notion_page_fields = {
+ "page_name": fields.String,
+ "page_id": fields.String,
+ "page_icon": fields.Nested(integrate_icon_model, allow_null=True),
+ "is_bound": fields.Boolean,
+ "parent_id": fields.String,
+ "type": fields.String,
+}
+notion_page_model = get_or_create_model("NotionIntegratePage", notion_page_fields)
+
+notion_workspace_fields = {
+ "workspace_name": fields.String,
+ "workspace_id": fields.String,
+ "workspace_icon": fields.String,
+ "pages": fields.List(fields.Nested(notion_page_model)),
+}
+notion_workspace_model = get_or_create_model("NotionIntegrateWorkspace", notion_workspace_fields)
+
+integrate_notion_info_list_fields_copy = integrate_notion_info_list_fields.copy()
+integrate_notion_info_list_fields_copy["notion_info"] = fields.List(fields.Nested(notion_workspace_model))
+integrate_notion_info_list_model = get_or_create_model(
+ "NotionIntegrateInfoList", integrate_notion_info_list_fields_copy
+)
+
+
@console_ns.route(
"/data-source/integrates",
"/data-source/integrates//",
@@ -47,7 +107,7 @@ class DataSourceApi(Resource):
@setup_required
@login_required
@account_initialization_required
- @marshal_with(integrate_list_fields)
+ @marshal_with(integrate_list_model)
def get(self):
_, current_tenant_id = current_account_with_tenant()
@@ -132,30 +192,19 @@ class DataSourceNotionListApi(Resource):
@setup_required
@login_required
@account_initialization_required
- @marshal_with(integrate_notion_info_list_fields)
+ @marshal_with(integrate_notion_info_list_model)
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
- dataset_id = request.args.get("dataset_id", default=None, type=str)
- credential_id = request.args.get("credential_id", default=None, type=str)
- if not credential_id:
- raise ValueError("Credential id is required.")
+ query = DataSourceNotionListQuery.model_validate(request.args.to_dict())
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
- datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str)
- datasource_parameters = {}
- if datasource_parameters_str:
- try:
- datasource_parameters = json.loads(datasource_parameters_str)
- if not isinstance(datasource_parameters, dict):
- raise ValueError("datasource_parameters must be a JSON object.")
- except json.JSONDecodeError:
- raise ValueError("Invalid datasource_parameters JSON format.")
+ datasource_parameters = query.datasource_parameters or {}
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_tenant_id,
- credential_id=credential_id,
+ credential_id=query.credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
@@ -164,8 +213,8 @@ class DataSourceNotionListApi(Resource):
exist_page_ids = []
with Session(db.engine) as session:
# import notion in the exist dataset
- if dataset_id:
- dataset = DatasetService.get_dataset(dataset_id)
+ if query.dataset_id:
+ dataset = DatasetService.get_dataset(query.dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
if dataset.data_source_type != "notion_import":
@@ -173,7 +222,7 @@ class DataSourceNotionListApi(Resource):
documents = session.scalars(
select(Document).filter_by(
- dataset_id=dataset_id,
+ dataset_id=query.dataset_id,
tenant_id=current_tenant_id,
data_source_type="notion_import",
enabled=True,
@@ -240,13 +289,12 @@ class DataSourceNotionApi(Resource):
def get(self, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
- credential_id = request.args.get("credential_id", default=None, type=str)
- if not credential_id:
- raise ValueError("Credential id is required.")
+ query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict())
+
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_tenant_id,
- credential_id=credential_id,
+ credential_id=query.credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 8ceb896d4f..8fbbc51e21 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -8,7 +8,7 @@ from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
-from controllers.common.schema import register_schema_models
+from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import (
api_key_item_model,
@@ -34,6 +34,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import app_detail_kernel_fields, related_app_list
from fields.dataset_fields import (
+ content_fields,
dataset_detail_fields,
dataset_fields,
dataset_query_detail_fields,
@@ -41,6 +42,7 @@ from fields.dataset_fields import (
doc_metadata_fields,
external_knowledge_info_fields,
external_retrieval_model_fields,
+ file_info_fields,
icon_info_fields,
keyword_setting_fields,
reranking_model_fields,
@@ -55,41 +57,33 @@ from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
-
-def _get_or_create_model(model_name: str, field_def):
- existing = console_ns.models.get(model_name)
- if existing is None:
- existing = console_ns.model(model_name, field_def)
- return existing
-
-
# Register models for flask_restx to avoid dict type issues in Swagger
-dataset_base_model = _get_or_create_model("DatasetBase", dataset_fields)
+dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
-tag_model = _get_or_create_model("Tag", tag_fields)
+tag_model = get_or_create_model("Tag", tag_fields)
-keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
-vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
+keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
+vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
weighted_score_fields_copy = weighted_score_fields.copy()
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
-weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
+weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
-reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
+reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
-dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
+dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
-external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
+external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
-external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
+external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
-doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
+doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
-icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
+icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
dataset_detail_fields_copy = dataset_detail_fields.copy()
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
@@ -98,14 +92,22 @@ dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_k
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
-dataset_detail_model = _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
+dataset_detail_model = get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
-dataset_query_detail_model = _get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields)
+file_info_model = get_or_create_model("DatasetFileInfo", file_info_fields)
-app_detail_kernel_model = _get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
+content_fields_copy = content_fields.copy()
+content_fields_copy["file_info"] = fields.Nested(file_info_model, allow_null=True)
+content_model = get_or_create_model("DatasetContent", content_fields_copy)
+
+dataset_query_detail_fields_copy = dataset_query_detail_fields.copy()
+dataset_query_detail_fields_copy["queries"] = fields.Nested(content_model)
+dataset_query_detail_model = get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields_copy)
+
+app_detail_kernel_model = get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
related_app_list_copy = related_app_list.copy()
related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model))
-related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
+related_app_list_model = get_or_create_model("RelatedAppList", related_app_list_copy)
def _validate_indexing_technique(value: str | None) -> str | None:
@@ -176,7 +178,18 @@ class IndexingEstimatePayload(BaseModel):
return result
-register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload)
+class ConsoleDatasetListQuery(BaseModel):
+ page: int = Field(default=1, description="Page number")
+ limit: int = Field(default=20, description="Number of items per page")
+ keyword: str | None = Field(default=None, description="Search keyword")
+ include_all: bool = Field(default=False, description="Include all datasets")
+ ids: list[str] = Field(default_factory=list, description="Filter by dataset IDs")
+ tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
+
+
+register_schema_models(
+ console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload, ConsoleDatasetListQuery
+)
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
@@ -275,18 +288,19 @@ class DatasetListApi(Resource):
@enterprise_license_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
- page = request.args.get("page", default=1, type=int)
- limit = request.args.get("limit", default=20, type=int)
- ids = request.args.getlist("ids")
+ query = ConsoleDatasetListQuery.model_validate(request.args.to_dict())
# provider = request.args.get("provider", default="vendor")
- search = request.args.get("keyword", default=None, type=str)
- tag_ids = request.args.getlist("tag_ids")
- include_all = request.args.get("include_all", default="false").lower() == "true"
- if ids:
- datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id)
+ if query.ids:
+ datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(
- page, limit, current_tenant_id, current_user, search, tag_ids, include_all
+ query.page,
+ query.limit,
+ current_tenant_id,
+ current_user,
+ query.keyword,
+ query.tag_ids,
+ query.include_all,
)
# check embedding setting
@@ -318,7 +332,13 @@ class DatasetListApi(Resource):
else:
item.update({"partial_member_list": []})
- response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
+ response = {
+ "data": data,
+ "has_more": len(datasets) == query.limit,
+ "limit": query.limit,
+ "total": total,
+ "page": query.page,
+ }
return response, 200
@console_ns.doc("create_dataset")
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index e94768f985..57fb9abf29 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -2,17 +2,19 @@ import json
import logging
from argparse import ArgumentTypeError
from collections.abc import Sequence
-from typing import Literal, cast
+from contextlib import ExitStack
+from typing import Any, Literal, cast
+from uuid import UUID
import sqlalchemy as sa
-from flask import request
+from flask import request, send_file
from flask_restx import Resource, fields, marshal, marshal_with
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
import services
-from controllers.common.schema import register_schema_models
+from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from core.errors.error import (
LLMBadRequestError,
@@ -42,6 +44,7 @@ from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
+from services.file_service import FileService
from ..app.error import (
ProviderModelCurrentlyNotSupportError,
@@ -65,35 +68,31 @@ from ..wraps import (
logger = logging.getLogger(__name__)
-
-def _get_or_create_model(model_name: str, field_def):
- existing = console_ns.models.get(model_name)
- if existing is None:
- existing = console_ns.model(model_name, field_def)
- return existing
+# NOTE: Keep constants near the top of the module for discoverability.
+DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
# Register models for flask_restx to avoid dict type issues in Swagger
-dataset_model = _get_or_create_model("Dataset", dataset_fields)
+dataset_model = get_or_create_model("Dataset", dataset_fields)
-document_metadata_model = _get_or_create_model("DocumentMetadata", document_metadata_fields)
+document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields)
document_fields_copy = document_fields.copy()
document_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
-document_model = _get_or_create_model("Document", document_fields_copy)
+document_model = get_or_create_model("Document", document_fields_copy)
document_with_segments_fields_copy = document_with_segments_fields.copy()
document_with_segments_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
-document_with_segments_model = _get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
+document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
dataset_and_document_fields_copy = dataset_and_document_fields.copy()
dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
-dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
+dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
class DocumentRetryPayload(BaseModel):
@@ -104,6 +103,21 @@ class DocumentRenamePayload(BaseModel):
name: str
+class DocumentBatchDownloadZipPayload(BaseModel):
+ """Request payload for bulk downloading documents as a zip archive."""
+
+ document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
+
+
+class DocumentDatasetListParam(BaseModel):
+ page: int = Field(1, title="Page", description="Page number.")
+ limit: int = Field(20, title="Limit", description="Page size.")
+ search: str | None = Field(None, alias="keyword", title="Search", description="Search keyword.")
+ sort_by: str = Field("-created_at", alias="sort", title="SortBy", description="Sort by field.")
+ status: str | None = Field(None, title="Status", description="Document status.")
+ fetch_val: str = Field("false", alias="fetch")
+
+
register_schema_models(
console_ns,
KnowledgeConfig,
@@ -111,6 +125,7 @@ register_schema_models(
RetrievalModel,
DocumentRetryPayload,
DocumentRenamePayload,
+ DocumentBatchDownloadZipPayload,
)
@@ -225,14 +240,16 @@ class DatasetDocumentListApi(Resource):
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
- page = request.args.get("page", default=1, type=int)
- limit = request.args.get("limit", default=20, type=int)
- search = request.args.get("keyword", default=None, type=str)
- sort = request.args.get("sort", default="-created_at", type=str)
- status = request.args.get("status", default=None, type=str)
+ raw_args = request.args.to_dict()
+ param = DocumentDatasetListParam.model_validate(raw_args)
+ page = param.page
+ limit = param.limit
+ search = param.search
+ sort = param.sort_by
+ status = param.status
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try:
- fetch_val = request.args.get("fetch", default="false")
+ fetch_val = param.fetch_val
if isinstance(fetch_val, bool):
fetch = fetch_val
else:
@@ -751,12 +768,12 @@ class DocumentApi(DocumentResource):
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
- data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
"position": document.position,
"data_source_type": document.data_source_type,
- "data_source_info": data_source_info,
+ "data_source_info": document.data_source_info_dict,
+ "data_source_detail_dict": document.data_source_detail_dict,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
@@ -784,12 +801,12 @@ class DocumentApi(DocumentResource):
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
- data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
"position": document.position,
"data_source_type": document.data_source_type,
- "data_source_info": data_source_info,
+ "data_source_info": document.data_source_info_dict,
+ "data_source_detail_dict": document.data_source_detail_dict,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
@@ -842,6 +859,62 @@ class DocumentApi(DocumentResource):
return {"result": "success"}, 204
+@console_ns.route("/datasets//documents//download")
+class DocumentDownloadApi(DocumentResource):
+ """Return a signed download URL for a dataset document's original uploaded file."""
+
+ @console_ns.doc("get_dataset_document_download_url")
+ @console_ns.doc(description="Get a signed download URL for a dataset document's original uploaded file")
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @cloud_edition_billing_rate_limit_check("knowledge")
+ def get(self, dataset_id: str, document_id: str) -> dict[str, Any]:
+ # Reuse the shared permission/tenant checks implemented in DocumentResource.
+ document = self.get_document(str(dataset_id), str(document_id))
+ return {"url": DocumentService.get_document_download_url(document)}
+
+
+@console_ns.route("/datasets//documents/download-zip")
+class DocumentBatchDownloadZipApi(DocumentResource):
+ """Download multiple uploaded-file documents as a single ZIP (avoids browser multi-download limits)."""
+
+ @console_ns.doc("download_dataset_documents_as_zip")
+ @console_ns.doc(description="Download selected dataset documents as a single ZIP archive (upload-file only)")
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @cloud_edition_billing_rate_limit_check("knowledge")
+ @console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
+ def post(self, dataset_id: str):
+ """Stream a ZIP archive containing the requested uploaded documents."""
+ # Parse and validate request payload.
+ payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
+
+ current_user, current_tenant_id = current_account_with_tenant()
+ dataset_id = str(dataset_id)
+ document_ids: list[str] = [str(document_id) for document_id in payload.document_ids]
+ upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
+ dataset_id=dataset_id,
+ document_ids=document_ids,
+ tenant_id=current_tenant_id,
+ current_user=current_user,
+ )
+
+ # Delegate ZIP packing to FileService, but keep Flask response+cleanup in the route.
+ with ExitStack() as stack:
+ zip_path = stack.enter_context(FileService.build_upload_files_zip_tempfile(upload_files=upload_files))
+ response = send_file(
+ zip_path,
+ mimetype="application/zip",
+ as_attachment=True,
+ download_name=download_name,
+ )
+ cleanup = stack.pop_all()
+ response.call_on_close(cleanup.close)
+ return response
+
+
@console_ns.route("/datasets//documents//processing/")
class DocumentProcessingApi(DocumentResource):
@console_ns.doc("update_document_processing")
@@ -1098,7 +1171,7 @@ class DocumentRenameApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
- @marshal_with(document_fields)
+ @marshal_with(document_model)
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index e73abc2555..08e1ddd3e0 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -3,10 +3,12 @@ import uuid
from flask import request
from flask_restx import Resource, marshal
from pydantic import BaseModel, Field
-from sqlalchemy import select
+from sqlalchemy import String, cast, func, or_, select
+from sqlalchemy.dialects.postgresql import JSONB
from werkzeug.exceptions import Forbidden, NotFound
import services
+from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
@@ -28,6 +30,7 @@ from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
+from libs.helper import escape_like_pattern
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
@@ -87,6 +90,7 @@ register_schema_models(
ChildChunkCreatePayload,
ChildChunkUpdatePayload,
ChildChunkBatchUpdatePayload,
+ ChildChunkUpdateArgs,
)
@@ -143,7 +147,31 @@ class DatasetDocumentSegmentListApi(Resource):
query = query.where(DocumentSegment.hit_count >= hit_count_gte)
if keyword:
- query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
+ # Escape special characters in keyword to prevent SQL injection via LIKE wildcards
+ escaped_keyword = escape_like_pattern(keyword)
+ # Search in both content and keywords fields
+ # Use database-specific methods for JSON array search
+ if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
+ # PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
+ keywords_condition = func.array_to_string(
+ func.array(
+ select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
+ .correlate(DocumentSegment)
+ .scalar_subquery()
+ ),
+ ",",
+ ).ilike(f"%{escaped_keyword}%", escape="\\")
+ else:
+ # MySQL: Cast JSON to string for pattern matching
+ # MySQL stores Chinese text directly in JSON without Unicode escaping
+ keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{escaped_keyword}%", escape="\\")
+
+ query = query.where(
+ or_(
+ DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"),
+ keywords_condition,
+ )
+ )
if args.enabled.lower() != "all":
if args.enabled.lower() == "true":
diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py
index 89c9fcad36..86090bcd10 100644
--- a/api/controllers/console/datasets/external.py
+++ b/api/controllers/console/datasets/external.py
@@ -4,7 +4,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
-from controllers.common.schema import register_schema_models
+from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@@ -28,34 +28,27 @@ from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService
-def _get_or_create_model(model_name: str, field_def):
- existing = console_ns.models.get(model_name)
- if existing is None:
- existing = console_ns.model(model_name, field_def)
- return existing
-
-
def _build_dataset_detail_model():
- keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
- vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
+ keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
+ vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
weighted_score_fields_copy = weighted_score_fields.copy()
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
- weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
+ weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
- reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
+ reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
- dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
+ dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
- tag_model = _get_or_create_model("Tag", tag_fields)
- doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
- external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
- external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
- icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
+ tag_model = get_or_create_model("Tag", tag_fields)
+ doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
+ external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
+ external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
+ icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
dataset_detail_fields_copy = dataset_detail_fields.copy()
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
@@ -64,7 +57,7 @@ def _build_dataset_detail_model():
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
- return _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
+ return get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
try:
@@ -81,7 +74,7 @@ class ExternalKnowledgeApiPayload(BaseModel):
class ExternalDatasetCreatePayload(BaseModel):
external_knowledge_api_id: str
external_knowledge_id: str
- name: str = Field(..., min_length=1, max_length=40)
+ name: str = Field(..., min_length=1, max_length=100)
description: str | None = Field(None, max_length=400)
external_retrieval_model: dict[str, object] | None = None
@@ -98,12 +91,19 @@ class BedrockRetrievalPayload(BaseModel):
knowledge_id: str
+class ExternalApiTemplateListQuery(BaseModel):
+ page: int = Field(default=1, description="Page number")
+ limit: int = Field(default=20, description="Number of items per page")
+ keyword: str | None = Field(default=None, description="Search keyword")
+
+
register_schema_models(
console_ns,
ExternalKnowledgeApiPayload,
ExternalDatasetCreatePayload,
ExternalHitTestingPayload,
BedrockRetrievalPayload,
+ ExternalApiTemplateListQuery,
)
@@ -124,19 +124,17 @@ class ExternalApiTemplateListApi(Resource):
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
- page = request.args.get("page", default=1, type=int)
- limit = request.args.get("limit", default=20, type=int)
- search = request.args.get("keyword", default=None, type=str)
+ query = ExternalApiTemplateListQuery.model_validate(request.args.to_dict())
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
- page, limit, current_tenant_id, search
+ query.page, query.limit, current_tenant_id, query.keyword
)
response = {
"data": [item.to_dict() for item in external_knowledge_apis],
- "has_more": len(external_knowledge_apis) == limit,
- "limit": limit,
+ "has_more": len(external_knowledge_apis) == query.limit,
+ "limit": query.limit,
"total": total,
- "page": page,
+ "page": query.page,
}
return response, 200
diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py
index db7c50f422..db1a874437 100644
--- a/api/controllers/console/datasets/hit_testing_base.py
+++ b/api/controllers/console/datasets/hit_testing_base.py
@@ -1,7 +1,7 @@
import logging
from typing import Any
-from flask_restx import marshal, reqparse
+from flask_restx import marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -56,15 +56,10 @@ class DatasetsHitTestingBase:
HitTestingService.hit_testing_args_check(args)
@staticmethod
- def parse_args():
- parser = (
- reqparse.RequestParser()
- .add_argument("query", type=str, required=False, location="json")
- .add_argument("attachment_ids", type=list, required=False, location="json")
- .add_argument("retrieval_model", type=dict, required=False, location="json")
- .add_argument("external_retrieval_model", type=dict, required=False, location="json")
- )
- return parser.parse_args()
+ def parse_args(payload: dict[str, Any]) -> dict[str, Any]:
+ """Validate and return hit-testing arguments from an incoming payload."""
+ hit_testing_payload = HitTestingPayload.model_validate(payload or {})
+ return hit_testing_payload.model_dump(exclude_none=True)
@staticmethod
def perform_hit_testing(dataset, args):
diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py
index 8eead1696a..05fc4cd714 100644
--- a/api/controllers/console/datasets/metadata.py
+++ b/api/controllers/console/datasets/metadata.py
@@ -4,14 +4,16 @@ from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from werkzeug.exceptions import NotFound
-from controllers.common.schema import register_schema_model, register_schema_models
+from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
+ DocumentMetadataOperation,
MetadataArgs,
+ MetadataDetail,
MetadataOperationData,
)
from services.metadata_service import MetadataService
@@ -21,8 +23,9 @@ class MetadataUpdatePayload(BaseModel):
name: str
-register_schema_models(console_ns, MetadataArgs, MetadataOperationData)
-register_schema_model(console_ns, MetadataUpdatePayload)
+register_schema_models(
+ console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
+)
@console_ns.route("/datasets//metadata")
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
index 720e2ce365..2911b1cf18 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
@@ -2,7 +2,7 @@ import logging
from typing import Any, NoReturn
from flask import Response, request
-from flask_restx import Resource, fields, marshal, marshal_with
+from flask_restx import Resource, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@@ -14,7 +14,9 @@ from controllers.console.app.error import (
)
from controllers.console.app.workflow_draft_variable import (
_WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage]
- _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage]
+ workflow_draft_variable_list_model,
+ workflow_draft_variable_list_without_value_model,
+ workflow_draft_variable_model,
)
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
@@ -27,7 +29,6 @@ from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import Account
from models.dataset import Pipeline
-from models.workflow import WorkflowDraftVariable
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
@@ -52,20 +53,6 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
-def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
- return var_list.variables
-
-
-_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
- "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
- "total": fields.Raw(),
-}
-
-_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
- "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
-}
-
-
def _api_prerequisite(f):
"""Common prerequisites for all draft workflow variable APIs.
@@ -92,7 +79,7 @@ def _api_prerequisite(f):
@console_ns.route("/rag/pipelines//workflows/draft/variables")
class RagPipelineVariableCollectionApi(Resource):
@_api_prerequisite
- @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
+ @marshal_with(workflow_draft_variable_list_without_value_model)
def get(self, pipeline: Pipeline):
"""
Get draft workflow
@@ -150,7 +137,7 @@ def validate_node_id(node_id: str) -> NoReturn | None:
@console_ns.route("/rag/pipelines//workflows/draft/nodes//variables")
class RagPipelineNodeVariableCollectionApi(Resource):
@_api_prerequisite
- @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
+ @marshal_with(workflow_draft_variable_list_model)
def get(self, pipeline: Pipeline, node_id: str):
validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session:
@@ -176,7 +163,7 @@ class RagPipelineVariableApi(Resource):
_PATCH_VALUE_FIELD = "value"
@_api_prerequisite
- @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
+ @marshal_with(workflow_draft_variable_model)
def get(self, pipeline: Pipeline, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
@@ -189,7 +176,7 @@ class RagPipelineVariableApi(Resource):
return variable
@_api_prerequisite
- @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
+ @marshal_with(workflow_draft_variable_model)
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
def patch(self, pipeline: Pipeline, variable_id: str):
# Request payload for file types:
@@ -307,7 +294,7 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
@console_ns.route("/rag/pipelines//workflows/draft/system-variables")
class RagPipelineSystemVariableCollectionApi(Resource):
@_api_prerequisite
- @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
+ @marshal_with(workflow_draft_variable_list_model)
def get(self, pipeline: Pipeline):
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py
index d43ee9a6e0..af142b4646 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py
@@ -1,9 +1,9 @@
from flask import request
-from flask_restx import Resource, marshal_with # type: ignore
+from flask_restx import Resource, fields, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
-from controllers.common.schema import register_schema_models
+from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
@@ -12,7 +12,11 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
-from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
+from fields.rag_pipeline_fields import (
+ leaked_dependency_fields,
+ pipeline_import_check_dependencies_fields,
+ pipeline_import_fields,
+)
from libs.login import current_account_with_tenant, login_required
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
@@ -38,13 +42,25 @@ class IncludeSecretQuery(BaseModel):
register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
+pipeline_import_model = get_or_create_model("RagPipelineImport", pipeline_import_fields)
+
+leaked_dependency_model = get_or_create_model("RagPipelineLeakedDependency", leaked_dependency_fields)
+pipeline_import_check_dependencies_fields_copy = pipeline_import_check_dependencies_fields.copy()
+pipeline_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(
+ fields.Nested(leaked_dependency_model)
+)
+pipeline_import_check_dependencies_model = get_or_create_model(
+ "RagPipelineImportCheckDependencies", pipeline_import_check_dependencies_fields_copy
+)
+
+
@console_ns.route("/rag/pipelines/imports")
class RagPipelineImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
- @marshal_with(pipeline_import_fields)
+ @marshal_with(pipeline_import_model)
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
def post(self):
# Check user role first
@@ -81,7 +97,7 @@ class RagPipelineImportConfirmApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
- @marshal_with(pipeline_import_fields)
+ @marshal_with(pipeline_import_model)
def post(self, import_id):
current_user, _ = current_account_with_tenant()
@@ -106,7 +122,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@get_rag_pipeline
@account_initialization_required
@edit_permission_required
- @marshal_with(pipeline_import_check_dependencies_fields)
+ @marshal_with(pipeline_import_check_dependencies_model)
def get(self, pipeline: Pipeline):
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index 46d67f0581..d34fd5088d 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -17,6 +17,13 @@ from controllers.console.app.error import (
DraftWorkflowNotExist,
DraftWorkflowNotSync,
)
+from controllers.console.app.workflow import workflow_model, workflow_pagination_model
+from controllers.console.app.workflow_run import (
+ workflow_run_detail_model,
+ workflow_run_node_execution_list_model,
+ workflow_run_node_execution_model,
+ workflow_run_pagination_model,
+)
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
@@ -30,13 +37,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from factories import variable_factory
-from fields.workflow_fields import workflow_fields, workflow_pagination_fields
-from fields.workflow_run_fields import (
- workflow_run_detail_fields,
- workflow_run_node_execution_fields,
- workflow_run_node_execution_list_fields,
- workflow_run_pagination_fields,
-)
from libs import helper
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, current_user, login_required
@@ -145,7 +145,7 @@ class DraftRagPipelineApi(Resource):
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
- @marshal_with(workflow_fields)
+ @marshal_with(workflow_model)
def get(self, pipeline: Pipeline):
"""
Get draft rag pipeline's workflow
@@ -355,7 +355,7 @@ class PublishedRagPipelineRunApi(Resource):
pipeline=pipeline,
user=current_user,
args=args,
- invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED,
+ invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED_PIPELINE,
streaming=streaming,
)
@@ -521,7 +521,7 @@ class RagPipelineDraftNodeRunApi(Resource):
@edit_permission_required
@account_initialization_required
@get_rag_pipeline
- @marshal_with(workflow_run_node_execution_fields)
+ @marshal_with(workflow_run_node_execution_model)
def post(self, pipeline: Pipeline, node_id: str):
"""
Run draft workflow node
@@ -569,7 +569,7 @@ class PublishedRagPipelineApi(Resource):
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
- @marshal_with(workflow_fields)
+ @marshal_with(workflow_model)
def get(self, pipeline: Pipeline):
"""
Get published pipeline
@@ -664,7 +664,7 @@ class PublishedAllRagPipelineApi(Resource):
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
- @marshal_with(workflow_pagination_fields)
+ @marshal_with(workflow_pagination_model)
def get(self, pipeline: Pipeline):
"""
Get published workflows
@@ -708,7 +708,7 @@ class RagPipelineByIdApi(Resource):
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
- @marshal_with(workflow_fields)
+ @marshal_with(workflow_model)
def patch(self, pipeline: Pipeline, workflow_id: str):
"""
Update workflow attributes
@@ -830,7 +830,7 @@ class RagPipelineWorkflowRunListApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
- @marshal_with(workflow_run_pagination_fields)
+ @marshal_with(workflow_run_pagination_model)
def get(self, pipeline: Pipeline):
"""
Get workflow run list
@@ -858,7 +858,7 @@ class RagPipelineWorkflowRunDetailApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
- @marshal_with(workflow_run_detail_fields)
+ @marshal_with(workflow_run_detail_model)
def get(self, pipeline: Pipeline, run_id):
"""
Get workflow run detail
@@ -877,7 +877,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
- @marshal_with(workflow_run_node_execution_list_fields)
+ @marshal_with(workflow_run_node_execution_list_model)
def get(self, pipeline: Pipeline, run_id: str):
"""
Get workflow run node execution list
@@ -911,7 +911,7 @@ class RagPipelineWorkflowLastRunApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
- @marshal_with(workflow_run_node_execution_fields)
+ @marshal_with(workflow_run_node_execution_model)
def get(self, pipeline: Pipeline, node_id: str):
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
@@ -952,7 +952,7 @@ class RagPipelineDatasourceVariableApi(Resource):
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
- @marshal_with(workflow_run_node_execution_fields)
+ @marshal_with(workflow_run_node_execution_model)
def post(self, pipeline: Pipeline):
"""
Set datasource variables
diff --git a/api/controllers/console/explore/banner.py b/api/controllers/console/explore/banner.py
new file mode 100644
index 0000000000..da306fbc9d
--- /dev/null
+++ b/api/controllers/console/explore/banner.py
@@ -0,0 +1,43 @@
+from flask import request
+from flask_restx import Resource
+
+from controllers.console import api
+from controllers.console.explore.wraps import explore_banner_enabled
+from extensions.ext_database import db
+from models.model import ExporleBanner
+
+
+class BannerApi(Resource):
+ """Resource for banner list."""
+
+ @explore_banner_enabled
+ def get(self):
+ """Get banner list."""
+ language = request.args.get("language", "en-US")
+
+ # Build base query for enabled banners
+ base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
+
+ # Try to get banners in the requested language
+ banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
+
+ # Fallback to en-US if no banners found and language is not en-US
+ if not banners and language != "en-US":
+ banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
+ # Convert banners to serializable format
+ result = []
+ for banner in banners:
+ banner_data = {
+ "id": banner.id,
+ "content": banner.content, # Already parsed as JSON by SQLAlchemy
+ "link": banner.link,
+ "sort": banner.sort,
+ "status": banner.status,
+ "created_at": banner.created_at.isoformat() if banner.created_at else None,
+ }
+ result.append(banner_data)
+
+ return result
+
+
+api.add_resource(BannerApi, "/explore/banners")
diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py
index 51995b8b8a..933c80f509 100644
--- a/api/controllers/console/explore/conversation.py
+++ b/api/controllers/console/explore/conversation.py
@@ -1,8 +1,7 @@
from typing import Any
from flask import request
-from flask_restx import marshal_with
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, TypeAdapter, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@@ -11,7 +10,11 @@ from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
-from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
+from fields.conversation_fields import (
+ ConversationInfiniteScrollPagination,
+ ResultResponse,
+ SimpleConversation,
+)
from libs.helper import UUIDStrOrEmpty
from libs.login import current_user
from models import Account
@@ -49,7 +52,6 @@ register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayl
endpoint="installed_app_conversations",
)
class ConversationListApi(InstalledAppResource):
- @marshal_with(conversation_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
def get(self, installed_app):
app_model = installed_app.app
@@ -73,7 +75,7 @@ class ConversationListApi(InstalledAppResource):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
with Session(db.engine) as session:
- return WebConversationService.pagination_by_last_id(
+ pagination = WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=current_user,
@@ -82,6 +84,13 @@ class ConversationListApi(InstalledAppResource):
invoke_from=InvokeFrom.EXPLORE,
pinned=args.pinned,
)
+ adapter = TypeAdapter(SimpleConversation)
+ conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
+ return ConversationInfiniteScrollPagination(
+ limit=pagination.limit,
+ has_more=pagination.has_more,
+ data=conversations,
+ ).model_dump(mode="json")
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@@ -105,7 +114,7 @@ class ConversationApi(InstalledAppResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
- return {"result": "success"}, 204
+ return ResultResponse(result="success").model_dump(mode="json"), 204
@console_ns.route(
@@ -113,7 +122,6 @@ class ConversationApi(InstalledAppResource):
endpoint="installed_app_conversation_rename",
)
class ConversationRenameApi(InstalledAppResource):
- @marshal_with(simple_conversation_fields)
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
def post(self, installed_app, c_id):
app_model = installed_app.app
@@ -128,9 +136,14 @@ class ConversationRenameApi(InstalledAppResource):
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
- return ConversationService.rename(
+ conversation = ConversationService.rename(
app_model, conversation_id, current_user, payload.name, payload.auto_generate
)
+ return (
+ TypeAdapter(SimpleConversation)
+ .validate_python(conversation, from_attributes=True)
+ .model_dump(mode="json")
+ )
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@@ -155,7 +168,7 @@ class ConversationPinApi(InstalledAppResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
@console_ns.route(
@@ -174,4 +187,4 @@ class ConversationUnPinApi(InstalledAppResource):
raise ValueError("current_user must be an Account instance")
WebConversationService.unpin(app_model, conversation_id, current_user)
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py
index 1e05ff4206..e96fa64f84 100644
--- a/api/controllers/console/explore/error.py
+++ b/api/controllers/console/explore/error.py
@@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException):
error_code = "access_denied"
description = "App access denied."
code = 403
+
+
+class TrialAppNotAllowed(BaseHTTPException):
+ """*403* `Trial App Not Allowed`
+
+ Raise if the user has reached the trial app limit.
+ """
+
+ error_code = "trial_app_not_allowed"
+ code = 403
+ description = "the app is not allowed to be trial."
+
+
+class TrialAppLimitExceeded(BaseHTTPException):
+ """*403* `Trial App Limit Exceeded`
+
+ Raise if the user has exceeded the trial app limit.
+ """
+
+ error_code = "trial_app_limit_exceeded"
+ code = 403
+ description = "The user has exceeded the trial app limit."
diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py
index e42db10ba6..aca766567f 100644
--- a/api/controllers/console/explore/installed_app.py
+++ b/api/controllers/console/explore/installed_app.py
@@ -2,16 +2,17 @@ import logging
from typing import Any
from flask import request
-from flask_restx import Resource, marshal_with
-from pydantic import BaseModel
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field
from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
+from controllers.common.schema import get_or_create_model
from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
-from fields.installed_app_fields import installed_app_list_fields
+from fields.installed_app_fields import app_fields, installed_app_fields, installed_app_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
@@ -28,22 +29,37 @@ class InstalledAppUpdatePayload(BaseModel):
is_pinned: bool | None = None
+class InstalledAppsListQuery(BaseModel):
+ app_id: str | None = Field(default=None, description="App ID to filter by")
+
+
logger = logging.getLogger(__name__)
+app_model = get_or_create_model("InstalledAppInfo", app_fields)
+
+installed_app_fields_copy = installed_app_fields.copy()
+installed_app_fields_copy["app"] = fields.Nested(app_model)
+installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy)
+
+installed_app_list_fields_copy = installed_app_list_fields.copy()
+installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model))
+installed_app_list_model = get_or_create_model("InstalledAppList", installed_app_list_fields_copy)
+
+
@console_ns.route("/installed-apps")
class InstalledAppsListApi(Resource):
@login_required
@account_initialization_required
- @marshal_with(installed_app_list_fields)
+ @marshal_with(installed_app_list_model)
def get(self):
- app_id = request.args.get("app_id", default=None, type=str)
+ query = InstalledAppsListQuery.model_validate(request.args.to_dict())
current_user, current_tenant_id = current_account_with_tenant()
- if app_id:
+ if query.app_id:
installed_apps = db.session.scalars(
select(InstalledApp).where(
- and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)
+ and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == query.app_id)
)
).all()
else:
diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py
index 229b7c8865..88487ac96f 100644
--- a/api/controllers/console/explore/message.py
+++ b/api/controllers/console/explore/message.py
@@ -1,10 +1,8 @@
import logging
from typing import Literal
-from uuid import UUID
from flask import request
-from flask_restx import marshal_with
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
@@ -24,8 +22,10 @@ from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
-from fields.message_fields import message_infinite_scroll_pagination_fields
+from fields.conversation_fields import ResultResponse
+from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
from libs import helper
+from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
- conversation_id: UUID
- first_id: UUID | None = None
+ conversation_id: UUIDStrOrEmpty
+ first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
@@ -66,7 +66,6 @@ register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, Mor
endpoint="installed_app_messages",
)
class MessageListApi(InstalledAppResource):
- @marshal_with(message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
@@ -78,13 +77,20 @@ class MessageListApi(InstalledAppResource):
args = MessageListQuery.model_validate(request.args.to_dict())
try:
- return MessageService.pagination_by_first_id(
+ pagination = MessageService.pagination_by_first_id(
app_model,
current_user,
str(args.conversation_id),
str(args.first_id) if args.first_id else None,
args.limit,
)
+ adapter = TypeAdapter(MessageListItem)
+ items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
+ return MessageInfiniteScrollPagination(
+ limit=pagination.limit,
+ has_more=pagination.has_more,
+ data=items,
+ ).model_dump(mode="json")
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except FirstMessageNotExistsError:
@@ -116,7 +122,7 @@ class MessageFeedbackApi(InstalledAppResource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
@console_ns.route(
@@ -201,4 +207,4 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
logger.exception("internal server error.")
raise InternalServerError()
- return {"data": questions}
+ return SuggestedQuestionsResponse(data=questions).model_dump(mode="json")
diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py
index 9c6b2aedfb..660a4d5aea 100644
--- a/api/controllers/console/explore/parameter.py
+++ b/api/controllers/console/explore/parameter.py
@@ -1,5 +1,3 @@
-from flask_restx import marshal_with
-
from controllers.common import fields
from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError
@@ -13,7 +11,6 @@ from services.app_service import AppService
class AppParameterApi(InstalledAppResource):
"""Resource for app variables."""
- @marshal_with(fields.parameters_fields)
def get(self, installed_app: InstalledApp):
"""Retrieve app parameters."""
app_model = installed_app.app
@@ -37,7 +34,8 @@ class AppParameterApi(InstalledAppResource):
user_input_form = features_dict.get("user_input_form", [])
- return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+ parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+ return fields.Parameters.model_validate(parameters).model_dump(mode="json")
@console_ns.route("/installed-apps//meta", endpoint="installed_app_meta")
diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py
index 2b2f807694..c9920c97cf 100644
--- a/api/controllers/console/explore/recommended_app.py
+++ b/api/controllers/console/explore/recommended_app.py
@@ -3,6 +3,7 @@ from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from constants.languages import languages
+from controllers.common.schema import get_or_create_model
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField
@@ -19,8 +20,10 @@ app_fields = {
"icon_background": fields.String,
}
+app_model = get_or_create_model("RecommendedAppInfo", app_fields)
+
recommended_app_fields = {
- "app": fields.Nested(app_fields, attribute="app"),
+ "app": fields.Nested(app_model, attribute="app"),
"app_id": fields.String,
"description": fields.String(attribute="description"),
"copyright": fields.String,
@@ -29,13 +32,18 @@ recommended_app_fields = {
"category": fields.String,
"position": fields.Integer,
"is_listed": fields.Boolean,
+ "can_trial": fields.Boolean,
}
+recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields)
+
recommended_app_list_fields = {
- "recommended_apps": fields.List(fields.Nested(recommended_app_fields)),
+ "recommended_apps": fields.List(fields.Nested(recommended_app_model)),
"categories": fields.List(fields.String),
}
+recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields)
+
class RecommendedAppsQuery(BaseModel):
language: str | None = Field(default=None)
@@ -52,7 +60,7 @@ class RecommendedAppListApi(Resource):
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
@login_required
@account_initialization_required
- @marshal_with(recommended_app_list_fields)
+ @marshal_with(recommended_app_list_model)
def get(self):
# language args
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py
index 6a9e274a0e..ea3de91741 100644
--- a/api/controllers/console/explore/saved_message.py
+++ b/api/controllers/console/explore/saved_message.py
@@ -1,55 +1,33 @@
-from uuid import UUID
-
from flask import request
-from flask_restx import fields, marshal_with
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
-from fields.conversation_fields import message_file_fields
-from libs.helper import TimestampField
+from fields.conversation_fields import ResultResponse
+from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
+from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
- last_id: UUID | None = None
+ last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
- message_id: UUID
+ message_id: UUIDStrOrEmpty
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
-feedback_fields = {"rating": fields.String}
-
-message_fields = {
- "id": fields.String,
- "inputs": fields.Raw,
- "query": fields.String,
- "answer": fields.String,
- "message_files": fields.List(fields.Nested(message_file_fields)),
- "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
- "created_at": TimestampField,
-}
-
-
@console_ns.route("/installed-apps//saved-messages", endpoint="installed_app_saved_messages")
class SavedMessageListApi(InstalledAppResource):
- saved_message_infinite_scroll_pagination_fields = {
- "limit": fields.Integer,
- "has_more": fields.Boolean,
- "data": fields.List(fields.Nested(message_fields)),
- }
-
- @marshal_with(saved_message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
@@ -59,12 +37,19 @@ class SavedMessageListApi(InstalledAppResource):
args = SavedMessageListQuery.model_validate(request.args.to_dict())
- return SavedMessageService.pagination_by_last_id(
+ pagination = SavedMessageService.pagination_by_last_id(
app_model,
current_user,
str(args.last_id) if args.last_id else None,
args.limit,
)
+ adapter = TypeAdapter(SavedMessageItem)
+ items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
+ return SavedMessageInfiniteScrollPagination(
+ limit=pagination.limit,
+ has_more=pagination.has_more,
+ data=items,
+ ).model_dump(mode="json")
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
def post(self, installed_app):
@@ -80,7 +65,7 @@ class SavedMessageListApi(InstalledAppResource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
@console_ns.route(
@@ -98,4 +83,4 @@ class SavedMessageApi(InstalledAppResource):
SavedMessageService.delete(app_model, current_user, message_id)
- return {"result": "success"}, 204
+ return ResultResponse(result="success").model_dump(mode="json"), 204
diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py
new file mode 100644
index 0000000000..1eb0cdb019
--- /dev/null
+++ b/api/controllers/console/explore/trial.py
@@ -0,0 +1,555 @@
+import logging
+from typing import Any, cast
+
+from flask import request
+from flask_restx import Resource, fields, marshal, marshal_with, reqparse
+from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
+
+import services
+from controllers.common.fields import Parameters as ParametersResponse
+from controllers.common.fields import Site as SiteResponse
+from controllers.common.schema import get_or_create_model
+from controllers.console import api, console_ns
+from controllers.console.app.error import (
+ AppUnavailableError,
+ AudioTooLargeError,
+ CompletionRequestError,
+ ConversationCompletedError,
+ NeedAddIdsError,
+ NoAudioUploadedError,
+ ProviderModelCurrentlyNotSupportError,
+ ProviderNotInitializeError,
+ ProviderNotSupportSpeechToTextError,
+ ProviderQuotaExceededError,
+ UnsupportedAudioTypeError,
+)
+from controllers.console.app.wraps import get_app_model_with_trial
+from controllers.console.explore.error import (
+ AppSuggestedQuestionsAfterAnswerDisabledError,
+ NotChatAppError,
+ NotCompletionAppError,
+ NotWorkflowAppError,
+)
+from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
+from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
+from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
+from core.app.apps.base_app_queue_manager import AppQueueManager
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.errors.error import (
+ ModelCurrentlyNotSupportError,
+ ProviderTokenNotInitError,
+ QuotaExceededError,
+)
+from core.model_runtime.errors.invoke import InvokeError
+from core.workflow.graph_engine.manager import GraphEngineManager
+from extensions.ext_database import db
+from fields.app_fields import (
+ app_detail_fields_with_site,
+ deleted_tool_fields,
+ model_config_fields,
+ site_fields,
+ tag_fields,
+)
+from fields.dataset_fields import dataset_fields
+from fields.member_fields import build_simple_account_model
+from fields.workflow_fields import (
+ conversation_variable_fields,
+ pipeline_variable_fields,
+ workflow_fields,
+ workflow_partial_fields,
+)
+from libs import helper
+from libs.helper import uuid_value
+from libs.login import current_user
+from models import Account
+from models.account import TenantStatus
+from models.model import AppMode, Site
+from models.workflow import Workflow
+from services.app_generate_service import AppGenerateService
+from services.app_service import AppService
+from services.audio_service import AudioService
+from services.dataset_service import DatasetService
+from services.errors.audio import (
+ AudioTooLargeServiceError,
+ NoAudioUploadedServiceError,
+ ProviderNotSupportSpeechToTextServiceError,
+ UnsupportedAudioTypeServiceError,
+)
+from services.errors.conversation import ConversationNotExistsError
+from services.errors.llm import InvokeRateLimitError
+from services.errors.message import (
+ MessageNotExistsError,
+ SuggestedQuestionsAfterAnswerDisabledError,
+)
+from services.message_service import MessageService
+from services.recommended_app_service import RecommendedAppService
+
+logger = logging.getLogger(__name__)
+
+
+model_config_model = get_or_create_model("TrialAppModelConfig", model_config_fields)
+workflow_partial_model = get_or_create_model("TrialWorkflowPartial", workflow_partial_fields)
+deleted_tool_model = get_or_create_model("TrialDeletedTool", deleted_tool_fields)
+tag_model = get_or_create_model("TrialTag", tag_fields)
+site_model = get_or_create_model("TrialSite", site_fields)
+
+app_detail_fields_with_site_copy = app_detail_fields_with_site.copy()
+app_detail_fields_with_site_copy["model_config"] = fields.Nested(
+ model_config_model, attribute="app_model_config", allow_null=True
+)
+app_detail_fields_with_site_copy["workflow"] = fields.Nested(workflow_partial_model, allow_null=True)
+app_detail_fields_with_site_copy["deleted_tools"] = fields.List(fields.Nested(deleted_tool_model))
+app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model))
+app_detail_fields_with_site_copy["site"] = fields.Nested(site_model)
+app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy)
+
+simple_account_model = build_simple_account_model(console_ns)
+conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields)
+pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields)
+
+workflow_fields_copy = workflow_fields.copy()
+workflow_fields_copy["created_by"] = fields.Nested(simple_account_model, attribute="created_by_account")
+workflow_fields_copy["updated_by"] = fields.Nested(
+ simple_account_model, attribute="updated_by_account", allow_null=True
+)
+workflow_fields_copy["conversation_variables"] = fields.List(fields.Nested(conversation_variable_model))
+workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipeline_variable_model))
+workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy)
+
+
+class TrialAppWorkflowRunApi(TrialAppResource):
+ def post(self, trial_app):
+ """
+ Run workflow
+ """
+ app_model = trial_app
+ if not app_model:
+ raise NotWorkflowAppError()
+ app_mode = AppMode.value_of(app_model.mode)
+ if app_mode != AppMode.WORKFLOW:
+ raise NotWorkflowAppError()
+
+ parser = reqparse.RequestParser()
+ parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+ parser.add_argument("files", type=list, required=False, location="json")
+ args = parser.parse_args()
+ assert current_user is not None
+ try:
+ app_id = app_model.id
+ user_id = current_user.id
+ response = AppGenerateService.generate(
+ app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
+ )
+ RecommendedAppService.add_trial_app_record(app_id, user_id)
+ return helper.compact_generate_response(response)
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except InvokeRateLimitError as ex:
+ raise InvokeRateLimitHttpError(ex.description)
+ except ValueError as e:
+ raise e
+ except Exception:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+
+class TrialAppWorkflowTaskStopApi(TrialAppResource):
+ def post(self, trial_app, task_id: str):
+ """
+ Stop workflow task
+ """
+ app_model = trial_app
+ if not app_model:
+ raise NotWorkflowAppError()
+ app_mode = AppMode.value_of(app_model.mode)
+ if app_mode != AppMode.WORKFLOW:
+ raise NotWorkflowAppError()
+ assert current_user is not None
+
+ # Stop using both mechanisms for backward compatibility
+ # Legacy stop flag mechanism (without user check)
+ AppQueueManager.set_stop_flag_no_user_check(task_id)
+
+ # New graph engine command channel mechanism
+ GraphEngineManager.send_stop_command(task_id)
+
+ return {"result": "success"}
+
+
+class TrialChatApi(TrialAppResource):
+ @trial_feature_enable
+ def post(self, trial_app):
+ app_model = trial_app
+ app_mode = AppMode.value_of(app_model.mode)
+ if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+ raise NotChatAppError()
+
+ parser = reqparse.RequestParser()
+ parser.add_argument("inputs", type=dict, required=True, location="json")
+ parser.add_argument("query", type=str, required=True, location="json")
+ parser.add_argument("files", type=list, required=False, location="json")
+ parser.add_argument("conversation_id", type=uuid_value, location="json")
+ parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
+ parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
+ args = parser.parse_args()
+
+ args["auto_generate_name"] = False
+
+ try:
+ if not isinstance(current_user, Account):
+ raise ValueError("current_user must be an Account instance")
+
+ # Get IDs before they might be detached from session
+ app_id = app_model.id
+ user_id = current_user.id
+
+ response = AppGenerateService.generate(
+ app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
+ )
+ RecommendedAppService.add_trial_app_record(app_id, user_id)
+ return helper.compact_generate_response(response)
+ except services.errors.conversation.ConversationNotExistsError:
+ raise NotFound("Conversation Not Exists.")
+ except services.errors.conversation.ConversationCompletedError:
+ raise ConversationCompletedError()
+ except services.errors.app_model_config.AppModelConfigBrokenError:
+ logger.exception("App model config broken.")
+ raise AppUnavailableError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except InvokeRateLimitError as ex:
+ raise InvokeRateLimitHttpError(ex.description)
+ except ValueError as e:
+ raise e
+ except Exception:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+
+class TrialMessageSuggestedQuestionApi(TrialAppResource):
+ @trial_feature_enable
+ def get(self, trial_app, message_id):
+ app_model = trial_app
+ app_mode = AppMode.value_of(app_model.mode)
+ if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+ raise NotChatAppError()
+
+ message_id = str(message_id)
+
+ try:
+ if not isinstance(current_user, Account):
+ raise ValueError("current_user must be an Account instance")
+ questions = MessageService.get_suggested_questions_after_answer(
+ app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
+ )
+ except MessageNotExistsError:
+ raise NotFound("Message not found")
+ except ConversationNotExistsError:
+ raise NotFound("Conversation not found")
+ except SuggestedQuestionsAfterAnswerDisabledError:
+ raise AppSuggestedQuestionsAfterAnswerDisabledError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except Exception:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+ return {"data": questions}
+
+
+class TrialChatAudioApi(TrialAppResource):
+ @trial_feature_enable
+ def post(self, trial_app):
+ app_model = trial_app
+
+ file = request.files["file"]
+
+ try:
+ if not isinstance(current_user, Account):
+ raise ValueError("current_user must be an Account instance")
+
+ # Get IDs before they might be detached from session
+ app_id = app_model.id
+ user_id = current_user.id
+
+ response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
+ RecommendedAppService.add_trial_app_record(app_id, user_id)
+ return response
+ except services.errors.app_model_config.AppModelConfigBrokenError:
+ logger.exception("App model config broken.")
+ raise AppUnavailableError()
+ except NoAudioUploadedServiceError:
+ raise NoAudioUploadedError()
+ except AudioTooLargeServiceError as e:
+ raise AudioTooLargeError(str(e))
+ except UnsupportedAudioTypeServiceError:
+ raise UnsupportedAudioTypeError()
+ except ProviderNotSupportSpeechToTextServiceError:
+ raise ProviderNotSupportSpeechToTextError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except ValueError as e:
+ raise e
+ except Exception as e:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+
+class TrialChatTextApi(TrialAppResource):
+ @trial_feature_enable
+ def post(self, trial_app):
+ app_model = trial_app
+ try:
+ parser = reqparse.RequestParser()
+ parser.add_argument("message_id", type=str, required=False, location="json")
+ parser.add_argument("voice", type=str, location="json")
+ parser.add_argument("text", type=str, location="json")
+ parser.add_argument("streaming", type=bool, location="json")
+ args = parser.parse_args()
+
+ message_id = args.get("message_id", None)
+ text = args.get("text", None)
+ voice = args.get("voice", None)
+ if not isinstance(current_user, Account):
+ raise ValueError("current_user must be an Account instance")
+
+ # Get IDs before they might be detached from session
+ app_id = app_model.id
+ user_id = current_user.id
+
+ response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
+ RecommendedAppService.add_trial_app_record(app_id, user_id)
+ return response
+ except services.errors.app_model_config.AppModelConfigBrokenError:
+ logger.exception("App model config broken.")
+ raise AppUnavailableError()
+ except NoAudioUploadedServiceError:
+ raise NoAudioUploadedError()
+ except AudioTooLargeServiceError as e:
+ raise AudioTooLargeError(str(e))
+ except UnsupportedAudioTypeServiceError:
+ raise UnsupportedAudioTypeError()
+ except ProviderNotSupportSpeechToTextServiceError:
+ raise ProviderNotSupportSpeechToTextError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except ValueError as e:
+ raise e
+ except Exception as e:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+
+class TrialCompletionApi(TrialAppResource):
+ @trial_feature_enable
+ def post(self, trial_app):
+ app_model = trial_app
+ if app_model.mode != "completion":
+ raise NotCompletionAppError()
+
+ parser = reqparse.RequestParser()
+ parser.add_argument("inputs", type=dict, required=True, location="json")
+ parser.add_argument("query", type=str, location="json", default="")
+ parser.add_argument("files", type=list, required=False, location="json")
+ parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+ parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
+ args = parser.parse_args()
+
+ streaming = args["response_mode"] == "streaming"
+ args["auto_generate_name"] = False
+
+ try:
+ if not isinstance(current_user, Account):
+ raise ValueError("current_user must be an Account instance")
+
+ # Get IDs before they might be detached from session
+ app_id = app_model.id
+ user_id = current_user.id
+
+ response = AppGenerateService.generate(
+ app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
+ )
+
+ RecommendedAppService.add_trial_app_record(app_id, user_id)
+ return helper.compact_generate_response(response)
+ except services.errors.conversation.ConversationNotExistsError:
+ raise NotFound("Conversation Not Exists.")
+ except services.errors.conversation.ConversationCompletedError:
+ raise ConversationCompletedError()
+ except services.errors.app_model_config.AppModelConfigBrokenError:
+ logger.exception("App model config broken.")
+ raise AppUnavailableError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except ValueError as e:
+ raise e
+ except Exception:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+
+class TrialSitApi(Resource):
+ """Resource for trial app sites."""
+
+ @trial_feature_enable
+ @get_app_model_with_trial
+ def get(self, app_model):
+ """Retrieve app site info.
+
+ Returns the site configuration for the application including theme, icons, and text.
+ """
+ site = db.session.query(Site).where(Site.app_id == app_model.id).first()
+
+ if not site:
+ raise Forbidden()
+
+ assert app_model.tenant
+ if app_model.tenant.status == TenantStatus.ARCHIVE:
+ raise Forbidden()
+
+ return SiteResponse.model_validate(site).model_dump(mode="json")
+
+
+class TrialAppParameterApi(Resource):
+ """Resource for app variables."""
+
+ @trial_feature_enable
+ @get_app_model_with_trial
+ def get(self, app_model):
+ """Retrieve app parameters."""
+
+ if app_model is None:
+ raise AppUnavailableError()
+
+ if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
+ workflow = app_model.workflow
+ if workflow is None:
+ raise AppUnavailableError()
+
+ features_dict = workflow.features_dict
+ user_input_form = workflow.user_input_form(to_old_structure=True)
+ else:
+ app_model_config = app_model.app_model_config
+ if app_model_config is None:
+ raise AppUnavailableError()
+
+ features_dict = app_model_config.to_dict()
+
+ user_input_form = features_dict.get("user_input_form", [])
+
+ parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+ return ParametersResponse.model_validate(parameters).model_dump(mode="json")
+
+
+class AppApi(Resource):
+ @trial_feature_enable
+ @get_app_model_with_trial
+ @marshal_with(app_detail_with_site_model)
+ def get(self, app_model):
+ """Get app detail"""
+
+ app_service = AppService()
+ app_model = app_service.get_app(app_model)
+
+ return app_model
+
+
+class AppWorkflowApi(Resource):
+ @trial_feature_enable
+ @get_app_model_with_trial
+ @marshal_with(workflow_model)
+ def get(self, app_model):
+ """Get workflow detail"""
+ if not app_model.workflow_id:
+ raise AppUnavailableError()
+
+ workflow = (
+ db.session.query(Workflow)
+ .where(
+ Workflow.id == app_model.workflow_id,
+ )
+ .first()
+ )
+ return workflow
+
+
+class DatasetListApi(Resource):
+ @trial_feature_enable
+ @get_app_model_with_trial
+ def get(self, app_model):
+ page = request.args.get("page", default=1, type=int)
+ limit = request.args.get("limit", default=20, type=int)
+ ids = request.args.getlist("ids")
+
+ tenant_id = app_model.tenant_id
+ if ids:
+ datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
+ else:
+ raise NeedAddIdsError()
+
+ data = cast(list[dict[str, Any]], marshal(datasets, dataset_fields))
+
+ response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
+ return response
+
+
+api.add_resource(TrialChatApi, "/trial-apps//chat-messages", endpoint="trial_app_chat_completion")
+
+api.add_resource(
+ TrialMessageSuggestedQuestionApi,
+ "/trial-apps//messages//suggested-questions",
+ endpoint="trial_app_suggested_question",
+)
+
+api.add_resource(TrialChatAudioApi, "/trial-apps//audio-to-text", endpoint="trial_app_audio")
+api.add_resource(TrialChatTextApi, "/trial-apps//text-to-audio", endpoint="trial_app_text")
+
+api.add_resource(TrialCompletionApi, "/trial-apps//completion-messages", endpoint="trial_app_completion")
+
+api.add_resource(TrialSitApi, "/trial-apps//site")
+
+api.add_resource(TrialAppParameterApi, "/trial-apps//parameters", endpoint="trial_app_parameters")
+
+api.add_resource(AppApi, "/trial-apps/", endpoint="trial_app")
+
+api.add_resource(TrialAppWorkflowRunApi, "/trial-apps//workflows/run", endpoint="trial_app_workflow_run")
+api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps//workflows/tasks//stop")
+
+api.add_resource(AppWorkflowApi, "/trial-apps//workflows", endpoint="trial_app_workflow")
+api.add_resource(DatasetListApi, "/trial-apps//datasets", endpoint="trial_app_datasets")
diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py
index 2a97d312aa..38f0a04904 100644
--- a/api/controllers/console/explore/wraps.py
+++ b/api/controllers/console/explore/wraps.py
@@ -2,14 +2,15 @@ from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
+from flask import abort
from flask_restx import Resource
from werkzeug.exceptions import NotFound
-from controllers.console.explore.error import AppAccessDeniedError
+from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
-from models import InstalledApp
+from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@@ -71,6 +72,61 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
return decorator
+def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
+ def decorator(view: Callable[Concatenate[App, P], R]):
+ @wraps(view)
+ def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
+ current_user, _ = current_account_with_tenant()
+
+ trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
+
+ if trial_app is None:
+ raise TrialAppNotAllowed()
+ app = trial_app.app
+
+ if app is None:
+ raise TrialAppNotAllowed()
+
+ account_trial_app_record = (
+ db.session.query(AccountTrialAppRecord)
+ .where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
+ .first()
+ )
+ if account_trial_app_record:
+ if account_trial_app_record.count >= trial_app.trial_limit:
+ raise TrialAppLimitExceeded()
+
+ return view(app, *args, **kwargs)
+
+ return decorated
+
+ if view:
+ return decorator(view)
+ return decorator
+
+
+def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
+ @wraps(view)
+ def decorated(*args, **kwargs):
+ features = FeatureService.get_system_features()
+ if not features.enable_trial_app:
+ abort(403, "Trial app feature is not enabled.")
+ return view(*args, **kwargs)
+
+ return decorated
+
+
+def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
+ @wraps(view)
+ def decorated(*args, **kwargs):
+ features = FeatureService.get_system_features()
+ if not features.enable_explore_banner:
+ abort(403, "Explore banner feature is not enabled.")
+ return view(*args, **kwargs)
+
+ return decorated
+
+
class InstalledAppResource(Resource):
# must be reversed if there are multiple decorators
@@ -80,3 +136,13 @@ class InstalledAppResource(Resource):
account_initialization_required,
login_required,
]
+
+
+class TrialAppResource(Resource):
+ # must be reversed if there are multiple decorators
+
+ method_decorators = [
+ trial_app_required,
+ account_initialization_required,
+ login_required,
+ ]
diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py
index 6951c906e9..d3811e2d1b 100644
--- a/api/controllers/console/feature.py
+++ b/api/controllers/console/feature.py
@@ -1,6 +1,7 @@
from flask_restx import Resource, fields
+from werkzeug.exceptions import Unauthorized
-from libs.login import current_account_with_tenant, login_required
+from libs.login import current_account_with_tenant, current_user, login_required
from services.feature_service import FeatureService
from . import console_ns
@@ -39,5 +40,21 @@ class SystemFeatureApi(Resource):
),
)
def get(self):
- """Get system-wide feature configuration"""
- return FeatureService.get_system_features().model_dump()
+ """Get system-wide feature configuration
+
+ NOTE: This endpoint is unauthenticated by design, as it provides system features
+ data required for dashboard initialization.
+
+ Authentication would create circular dependency (can't login without dashboard loading).
+
+ Only non-sensitive configuration data should be returned by this endpoint.
+ """
+ # NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
+ # without a try-catch. However, due to the implementation of user loader (the `load_user_from_request`
+ # in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
+ # raise `Unauthorized` exception if authentication token is not provided.
+ try:
+ is_authenticated = current_user.is_authenticated
+ except Unauthorized:
+ is_authenticated = False
+ return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump()
diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py
index 29417dc896..109a3cd0d3 100644
--- a/api/controllers/console/files.py
+++ b/api/controllers/console/files.py
@@ -1,7 +1,7 @@
from typing import Literal
from flask import request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource
from werkzeug.exceptions import Forbidden
import services
@@ -15,18 +15,21 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
+from controllers.common.schema import register_schema_models
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
)
from extensions.ext_database import db
-from fields.file_fields import file_fields, upload_config_fields
+from fields.file_fields import FileResponse, UploadConfig
from libs.login import current_account_with_tenant, login_required
from services.file_service import FileService
from . import console_ns
+register_schema_models(console_ns, UploadConfig, FileResponse)
+
PREVIEW_WORDS_LIMIT = 3000
@@ -35,26 +38,27 @@ class FileApi(Resource):
@setup_required
@login_required
@account_initialization_required
- @marshal_with(upload_config_fields)
+ @console_ns.response(200, "Success", console_ns.models[UploadConfig.__name__])
def get(self):
- return {
- "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
- "batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
- "file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT,
- "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
- "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
- "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
- "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
- "image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
- "single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
- "attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
- }, 200
+ config = UploadConfig(
+ file_size_limit=dify_config.UPLOAD_FILE_SIZE_LIMIT,
+ batch_count_limit=dify_config.UPLOAD_FILE_BATCH_LIMIT,
+ file_upload_limit=dify_config.BATCH_UPLOAD_LIMIT,
+ image_file_size_limit=dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
+ video_file_size_limit=dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
+ audio_file_size_limit=dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
+ workflow_file_upload_limit=dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
+ image_file_batch_limit=dify_config.IMAGE_FILE_BATCH_LIMIT,
+ single_chunk_attachment_limit=dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
+ attachment_image_file_size_limit=dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
+ )
+ return config.model_dump(mode="json"), 200
@setup_required
@login_required
@account_initialization_required
- @marshal_with(file_fields)
@cloud_edition_billing_resource_check("documents")
+ @console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
source_str = request.form.get("source")
@@ -90,7 +94,8 @@ class FileApi(Resource):
except services.errors.file.BlockedFileExtensionError as blocked_extension_error:
raise BlockedFileExtensionError(blocked_extension_error.description)
- return upload_file, 201
+ response = FileResponse.model_validate(upload_file, from_attributes=True)
+ return response.model_dump(mode="json"), 201
@console_ns.route("/files//preview")
diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py
index 25a3d80522..d480af312b 100644
--- a/api/controllers/console/ping.py
+++ b/api/controllers/console/ping.py
@@ -1,17 +1,17 @@
-from flask_restx import Resource, fields
+from pydantic import BaseModel, Field
-from . import console_ns
+from controllers.fastopenapi import console_router
-@console_ns.route("/ping")
-class PingApi(Resource):
- @console_ns.doc("health_check")
- @console_ns.doc(description="Health check endpoint for connection testing")
- @console_ns.response(
- 200,
- "Success",
- console_ns.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
- )
- def get(self):
- """Health check endpoint for connection testing"""
- return {"result": "pong"}
+class PingResponse(BaseModel):
+ result: str = Field(description="Health check result", examples=["pong"])
+
+
+@console_router.get(
+ "/ping",
+ response_model=PingResponse,
+ tags=["console"],
+)
+def ping() -> PingResponse:
+ """Health check endpoint for connection testing."""
+ return PingResponse(result="pong")
diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py
index 47eef7eb7e..70c7b80ffa 100644
--- a/api/controllers/console/remote_files.py
+++ b/api/controllers/console/remote_files.py
@@ -1,7 +1,7 @@
import urllib.parse
import httpx
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource
from pydantic import BaseModel, Field
import services
@@ -11,19 +11,22 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
+from controllers.common.schema import register_schema_models
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
-from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
+from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from libs.login import current_account_with_tenant
from services.file_service import FileService
from . import console_ns
+register_schema_models(console_ns, RemoteFileInfo, FileWithSignedUrl)
+
@console_ns.route("/remote-files/")
class RemoteFileInfoApi(Resource):
- @marshal_with(remote_file_info_fields)
+ @console_ns.response(200, "Remote file info", console_ns.models[RemoteFileInfo.__name__])
def get(self, url):
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
@@ -31,10 +34,11 @@ class RemoteFileInfoApi(Resource):
# failed back to get method
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
- return {
- "file_type": resp.headers.get("Content-Type", "application/octet-stream"),
- "file_length": int(resp.headers.get("Content-Length", 0)),
- }
+ info = RemoteFileInfo(
+ file_type=resp.headers.get("Content-Type", "application/octet-stream"),
+ file_length=int(resp.headers.get("Content-Length", 0)),
+ )
+ return info.model_dump(mode="json")
class RemoteFileUploadPayload(BaseModel):
@@ -50,7 +54,7 @@ console_ns.schema_model(
@console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource):
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
- @marshal_with(file_fields_with_signed_url)
+ @console_ns.response(201, "Remote file uploaded", console_ns.models[FileWithSignedUrl.__name__])
def post(self):
args = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = args.url
@@ -85,13 +89,14 @@ class RemoteFileUploadApi(Resource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
- return {
- "id": upload_file.id,
- "name": upload_file.name,
- "size": upload_file.size,
- "extension": upload_file.extension,
- "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
- "mime_type": upload_file.mime_type,
- "created_by": upload_file.created_by,
- "created_at": upload_file.created_at,
- }, 201
+ payload = FileWithSignedUrl(
+ id=upload_file.id,
+ name=upload_file.name,
+ size=upload_file.size,
+ extension=upload_file.extension,
+ url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
+ mime_type=upload_file.mime_type,
+ created_by=upload_file.created_by,
+ created_at=int(upload_file.created_at.timestamp()),
+ )
+ return payload.model_dump(mode="json"), 201
diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py
index 7fa02ae280..e1ea007232 100644
--- a/api/controllers/console/setup.py
+++ b/api/controllers/console/setup.py
@@ -1,20 +1,19 @@
+from typing import Literal
+
from flask import request
-from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from configs import dify_config
+from controllers.fastopenapi import console_router
from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password
from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService
-from . import console_ns
from .error import AlreadySetupError, NotInitValidateError
from .init_validate import get_init_validate_status
from .wraps import only_edition_self_hosted
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
class SetupRequestPayload(BaseModel):
email: EmailStr = Field(..., description="Admin email address")
@@ -28,77 +27,66 @@ class SetupRequestPayload(BaseModel):
return valid_password(value)
-console_ns.schema_model(
- SetupRequestPayload.__name__,
- SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+class SetupStatusResponse(BaseModel):
+ step: Literal["not_started", "finished"] = Field(description="Setup step status")
+ setup_at: str | None = Field(default=None, description="Setup completion time (ISO format)")
+
+
+class SetupResponse(BaseModel):
+ result: str = Field(description="Setup result", examples=["success"])
+
+
+@console_router.get(
+ "/setup",
+ response_model=SetupStatusResponse,
+ tags=["console"],
)
+def get_setup_status_api() -> SetupStatusResponse:
+ """Get system setup status."""
+ if dify_config.EDITION == "SELF_HOSTED":
+ setup_status = get_setup_status()
+ if setup_status and not isinstance(setup_status, bool):
+ return SetupStatusResponse(step="finished", setup_at=setup_status.setup_at.isoformat())
+ if setup_status:
+ return SetupStatusResponse(step="finished")
+ return SetupStatusResponse(step="not_started")
+ return SetupStatusResponse(step="finished")
-@console_ns.route("/setup")
-class SetupApi(Resource):
- @console_ns.doc("get_setup_status")
- @console_ns.doc(description="Get system setup status")
- @console_ns.response(
- 200,
- "Success",
- console_ns.model(
- "SetupStatusResponse",
- {
- "step": fields.String(description="Setup step status", enum=["not_started", "finished"]),
- "setup_at": fields.String(description="Setup completion time (ISO format)", required=False),
- },
- ),
+@console_router.post(
+ "/setup",
+ response_model=SetupResponse,
+ tags=["console"],
+ status_code=201,
+)
+@only_edition_self_hosted
+def setup_system(payload: SetupRequestPayload) -> SetupResponse:
+ """Initialize system setup with admin account."""
+ if get_setup_status():
+ raise AlreadySetupError()
+
+ tenant_count = TenantService.get_tenant_count()
+ if tenant_count > 0:
+ raise AlreadySetupError()
+
+ if not get_init_validate_status():
+ raise NotInitValidateError()
+
+ normalized_email = payload.email.lower()
+
+ RegisterService.setup(
+ email=normalized_email,
+ name=payload.name,
+ password=payload.password,
+ ip_address=extract_remote_ip(request),
+ language=payload.language,
)
- def get(self):
- """Get system setup status"""
- if dify_config.EDITION == "SELF_HOSTED":
- setup_status = get_setup_status()
- # Check if setup_status is a DifySetup object rather than a bool
- if setup_status and not isinstance(setup_status, bool):
- return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()}
- elif setup_status:
- return {"step": "finished"}
- return {"step": "not_started"}
- return {"step": "finished"}
- @console_ns.doc("setup_system")
- @console_ns.doc(description="Initialize system setup with admin account")
- @console_ns.expect(console_ns.models[SetupRequestPayload.__name__])
- @console_ns.response(
- 201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
- )
- @console_ns.response(400, "Already setup or validation failed")
- @only_edition_self_hosted
- def post(self):
- """Initialize system setup with admin account"""
- # is set up
- if get_setup_status():
- raise AlreadySetupError()
-
- # is tenant created
- tenant_count = TenantService.get_tenant_count()
- if tenant_count > 0:
- raise AlreadySetupError()
-
- if not get_init_validate_status():
- raise NotInitValidateError()
-
- args = SetupRequestPayload.model_validate(console_ns.payload)
-
- # setup
- RegisterService.setup(
- email=args.email,
- name=args.name,
- password=args.password,
- ip_address=extract_remote_ip(request),
- language=args.language,
- )
-
- return {"result": "success"}, 201
+ return SetupResponse(result="success")
-def get_setup_status():
+def get_setup_status() -> DifySetup | bool | None:
if dify_config.EDITION == "SELF_HOSTED":
return db.session.query(DifySetup).first()
- else:
- return True
+
+ return True
diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py
index e9fbb515e4..9988524a80 100644
--- a/api/controllers/console/tag/tags.py
+++ b/api/controllers/console/tag/tags.py
@@ -30,11 +30,17 @@ class TagBindingRemovePayload(BaseModel):
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
+class TagListQueryParam(BaseModel):
+ type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
+ keyword: str | None = Field(None, description="Search keyword")
+
+
register_schema_models(
console_ns,
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
+ TagListQueryParam,
)
@@ -43,12 +49,15 @@ class TagListApi(Resource):
@setup_required
@login_required
@account_initialization_required
+ @console_ns.doc(
+ params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
+ )
@marshal_with(dataset_tag_fields)
def get(self):
_, current_tenant_id = current_account_with_tenant()
- tag_type = request.args.get("type", type=str, default="")
- keyword = request.args.get("keyword", default=None, type=str)
- tags = TagService.get_tags(tag_type, current_tenant_id, keyword)
+ raw_args = request.args.to_dict()
+ param = TagListQueryParam.model_validate(raw_args)
+ tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
return tags, 200
diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py
index 419261ba2a..fdb23acf52 100644
--- a/api/controllers/console/version.py
+++ b/api/controllers/console/version.py
@@ -1,15 +1,11 @@
-import json
import logging
import httpx
-from flask import request
-from flask_restx import Resource, fields
from packaging import version
from pydantic import BaseModel, Field
from configs import dify_config
-
-from . import console_ns
+from controllers.fastopenapi import console_router
logger = logging.getLogger(__name__)
@@ -18,69 +14,61 @@ class VersionQuery(BaseModel):
current_version: str = Field(..., description="Current application version")
-console_ns.schema_model(
- VersionQuery.__name__,
- VersionQuery.model_json_schema(ref_template="#/definitions/{model}"),
+class VersionFeatures(BaseModel):
+ can_replace_logo: bool = Field(description="Whether logo replacement is supported")
+ model_load_balancing_enabled: bool = Field(description="Whether model load balancing is enabled")
+
+
+class VersionResponse(BaseModel):
+ version: str = Field(description="Latest version number")
+ release_date: str = Field(description="Release date of latest version")
+ release_notes: str = Field(description="Release notes for latest version")
+ can_auto_update: bool = Field(description="Whether auto-update is supported")
+ features: VersionFeatures = Field(description="Feature flags and capabilities")
+
+
+@console_router.get(
+ "/version",
+ response_model=VersionResponse,
+ tags=["console"],
)
+def check_version_update(query: VersionQuery) -> VersionResponse:
+ """Check for application version updates."""
+ check_update_url = dify_config.CHECK_UPDATE_URL
-
-@console_ns.route("/version")
-class VersionApi(Resource):
- @console_ns.doc("check_version_update")
- @console_ns.doc(description="Check for application version updates")
- @console_ns.expect(console_ns.models[VersionQuery.__name__])
- @console_ns.response(
- 200,
- "Success",
- console_ns.model(
- "VersionResponse",
- {
- "version": fields.String(description="Latest version number"),
- "release_date": fields.String(description="Release date of latest version"),
- "release_notes": fields.String(description="Release notes for latest version"),
- "can_auto_update": fields.Boolean(description="Whether auto-update is supported"),
- "features": fields.Raw(description="Feature flags and capabilities"),
- },
+ result = VersionResponse(
+ version=dify_config.project.version,
+ release_date="",
+ release_notes="",
+ can_auto_update=False,
+ features=VersionFeatures(
+ can_replace_logo=dify_config.CAN_REPLACE_LOGO,
+ model_load_balancing_enabled=dify_config.MODEL_LB_ENABLED,
),
)
- def get(self):
- """Check for application version updates"""
- args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
- check_update_url = dify_config.CHECK_UPDATE_URL
- result = {
- "version": dify_config.project.version,
- "release_date": "",
- "release_notes": "",
- "can_auto_update": False,
- "features": {
- "can_replace_logo": dify_config.CAN_REPLACE_LOGO,
- "model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED,
- },
- }
-
- if not check_update_url:
- return result
-
- try:
- response = httpx.get(
- check_update_url,
- params={"current_version": args.current_version},
- timeout=httpx.Timeout(timeout=10.0, connect=3.0),
- )
- except Exception as error:
- logger.warning("Check update version error: %s.", str(error))
- result["version"] = args.current_version
- return result
-
- content = json.loads(response.content)
- if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"):
- result["version"] = content["version"]
- result["release_date"] = content["releaseDate"]
- result["release_notes"] = content["releaseNotes"]
- result["can_auto_update"] = content["canAutoUpdate"]
+ if not check_update_url:
return result
+ try:
+ response = httpx.get(
+ check_update_url,
+ params={"current_version": query.current_version},
+ timeout=httpx.Timeout(timeout=10.0, connect=3.0),
+ )
+ content = response.json()
+ except Exception as error:
+ logger.warning("Check update version error: %s.", str(error))
+ result.version = query.current_version
+ return result
+ latest_version = content.get("version", result.version)
+ if _has_new_version(latest_version=latest_version, current_version=f"{query.current_version}"):
+ result.version = latest_version
+ result.release_date = content.get("releaseDate", "")
+ result.release_notes = content.get("releaseNotes", "")
+ result.can_auto_update = content.get("canAutoUpdate", False)
+ return result
+
def _has_new_version(*, latest_version: str, current_version: str) -> bool:
try:
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index 55eaa2f09f..38c66525b3 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from datetime import datetime
from typing import Literal
@@ -39,7 +41,7 @@ from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
-from models import Account, AccountIntegrate, InvitationCode
+from models import AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@@ -99,7 +101,7 @@ class AccountPasswordPayload(BaseModel):
repeat_new_password: str
@model_validator(mode="after")
- def check_passwords_match(self) -> "AccountPasswordPayload":
+ def check_passwords_match(self) -> AccountPasswordPayload:
if self.new_password != self.repeat_new_password:
raise RepeatPasswordNotMatchError()
return self
@@ -169,6 +171,19 @@ reg(ChangeEmailValidityPayload)
reg(ChangeEmailResetPayload)
reg(CheckEmailUniquePayload)
+integrate_fields = {
+ "provider": fields.String,
+ "created_at": TimestampField,
+ "is_bound": fields.Boolean,
+ "link": fields.String,
+}
+
+integrate_model = console_ns.model("AccountIntegrate", integrate_fields)
+integrate_list_model = console_ns.model(
+ "AccountIntegrateList",
+ {"data": fields.List(fields.Nested(integrate_model))},
+)
+
@console_ns.route("/account/init")
class AccountInitApi(Resource):
@@ -334,21 +349,10 @@ class AccountPasswordApi(Resource):
@console_ns.route("/account/integrates")
class AccountIntegrateApi(Resource):
- integrate_fields = {
- "provider": fields.String,
- "created_at": TimestampField,
- "is_bound": fields.Boolean,
- "link": fields.String,
- }
-
- integrate_list_fields = {
- "data": fields.List(fields.Nested(integrate_fields)),
- }
-
@setup_required
@login_required
@account_initialization_required
- @marshal_with(integrate_list_fields)
+ @marshal_with(integrate_list_model)
def get(self):
account, _ = current_account_with_tenant()
@@ -534,7 +538,8 @@ class ChangeEmailSendEmailApi(Resource):
else:
language = "en-US"
account = None
- user_email = args.email
+ user_email = None
+ email_for_sending = args.email.lower()
if args.phase is not None and args.phase == "new_email":
if args.token is None:
raise InvalidTokenError()
@@ -544,16 +549,24 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError()
user_email = reset_data.get("email", "")
- if user_email != current_user.email:
+ if user_email.lower() != current_user.email.lower():
raise InvalidEmailError()
+
+ user_email = current_user.email
else:
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
if account is None:
raise AccountNotFound()
+ email_for_sending = account.email
+ user_email = account.email
token = AccountService.send_change_email_email(
- account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
+ account=account,
+ email=email_for_sending,
+ old_email=user_email,
+ language=language,
+ phase=args.phase,
)
return {"result": "success", "data": token}
@@ -569,9 +582,9 @@ class ChangeEmailCheckApi(Resource):
payload = console_ns.payload or {}
args = ChangeEmailValidityPayload.model_validate(payload)
- user_email = args.email
+ user_email = args.email.lower()
- is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
+ is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(user_email)
if is_change_email_error_rate_limit:
raise EmailChangeLimitError()
@@ -579,11 +592,13 @@ class ChangeEmailCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
- if user_email != token_data.get("email"):
+ token_email = token_data.get("email")
+ normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
+ if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
- AccountService.add_change_email_error_rate_limit(args.email)
+ AccountService.add_change_email_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -594,8 +609,8 @@ class ChangeEmailCheckApi(Resource):
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
)
- AccountService.reset_change_email_error_rate_limit(args.email)
- return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+ AccountService.reset_change_email_error_rate_limit(user_email)
+ return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/account/change-email/reset")
@@ -609,11 +624,12 @@ class ChangeEmailResetApi(Resource):
def post(self):
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
+ normalized_new_email = args.new_email.lower()
- if AccountService.is_account_in_freeze(args.new_email):
+ if AccountService.is_account_in_freeze(normalized_new_email):
raise AccountInFreezeError()
- if not AccountService.check_email_unique(args.new_email):
+ if not AccountService.check_email_unique(normalized_new_email):
raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args.token)
@@ -624,13 +640,13 @@ class ChangeEmailResetApi(Resource):
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
- if current_user.email != old_email:
+ if current_user.email.lower() != old_email.lower():
raise AccountNotFound()
- updated_account = AccountService.update_account_email(current_user, email=args.new_email)
+ updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
AccountService.send_change_email_completed_notify_email(
- email=args.new_email,
+ email=normalized_new_email,
)
return updated_account
@@ -643,8 +659,9 @@ class CheckEmailUnique(Resource):
def post(self):
payload = console_ns.payload or {}
args = CheckEmailUniquePayload.model_validate(payload)
- if AccountService.is_account_in_freeze(args.email):
+ normalized_email = args.email.lower()
+ if AccountService.is_account_in_freeze(normalized_email):
raise AccountInFreezeError()
- if not AccountService.check_email_unique(args.email):
+ if not AccountService.check_email_unique(normalized_email):
raise EmailAlreadyInUseError()
return {"result": "success"}
diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py
index 9bf393ea2e..ccb60b1461 100644
--- a/api/controllers/console/workspace/load_balancing_config.py
+++ b/api/controllers/console/workspace/load_balancing_config.py
@@ -1,6 +1,8 @@
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel
from werkzeug.exceptions import Forbidden
+from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
@@ -10,10 +12,20 @@ from models import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
+class LoadBalancingCredentialPayload(BaseModel):
+ model: str
+ model_type: ModelType
+ credentials: dict[str, object]
+
+
+register_schema_models(console_ns, LoadBalancingCredentialPayload)
+
+
@console_ns.route(
"/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate"
)
class LoadBalancingCredentialsValidateApi(Resource):
+ @console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -24,20 +36,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
tenant_id = current_tenant_id
- parser = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
- .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
- )
- args = parser.parse_args()
+ payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
# validate model load balancing credentials
model_load_balancing_service = ModelLoadBalancingService()
@@ -49,9 +48,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
- model=args["model"],
- model_type=args["model_type"],
- credentials=args["credentials"],
+ model=payload.model,
+ model_type=payload.model_type,
+ credentials=payload.credentials,
)
except CredentialsValidateFailedError as ex:
result = False
@@ -69,6 +68,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
"/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate"
)
class LoadBalancingConfigCredentialsValidateApi(Resource):
+ @console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -79,20 +79,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
tenant_id = current_tenant_id
- parser = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
- .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
- )
- args = parser.parse_args()
+ payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
# validate model load balancing config credentials
model_load_balancing_service = ModelLoadBalancingService()
@@ -104,9 +91,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
- model=args["model"],
- model_type=args["model_type"],
- credentials=args["credentials"],
+ model=payload.model,
+ model_type=payload.model_type,
+ credentials=payload.credentials,
config_id=config_id,
)
except CredentialsValidateFailedError as ex:
diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py
index 0142e14fb0..271cdce3c3 100644
--- a/api/controllers/console/workspace/members.py
+++ b/api/controllers/console/workspace/members.py
@@ -1,11 +1,12 @@
from urllib import parse
from flask import abort, request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
import services
from configs import dify_config
+from controllers.common.schema import get_or_create_model, register_enum_models
from controllers.console import console_ns
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
@@ -24,7 +25,7 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
-from fields.member_fields import account_with_role_list_fields
+from fields.member_fields import account_with_role_fields, account_with_role_list_fields
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole
@@ -67,6 +68,13 @@ reg(MemberRoleUpdatePayload)
reg(OwnerTransferEmailPayload)
reg(OwnerTransferCheckPayload)
reg(OwnerTransferPayload)
+register_enum_models(console_ns, TenantAccountRole)
+
+account_with_role_model = get_or_create_model("AccountWithRole", account_with_role_fields)
+
+account_with_role_list_fields_copy = account_with_role_list_fields.copy()
+account_with_role_list_fields_copy["accounts"] = fields.List(fields.Nested(account_with_role_model))
+account_with_role_list_model = get_or_create_model("AccountWithRoleList", account_with_role_list_fields_copy)
@console_ns.route("/workspaces/current/members")
@@ -76,7 +84,7 @@ class MemberListApi(Resource):
@setup_required
@login_required
@account_initialization_required
- @marshal_with(account_with_role_list_fields)
+ @marshal_with(account_with_role_list_model)
def get(self):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
@@ -107,6 +115,12 @@ class MemberInviteEmailApi(Resource):
inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
+
+ # Check workspace permission for member invitations
+ from libs.workspace_permission import check_workspace_member_invite_permission
+
+ check_workspace_member_invite_permission(inviter.current_tenant.id)
+
invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL
@@ -116,26 +130,31 @@ class MemberInviteEmailApi(Resource):
raise WorkspaceMembersLimitExceeded()
for invitee_email in invitee_emails:
+ normalized_invitee_email = invitee_email.lower()
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
- inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
+ tenant=inviter.current_tenant,
+ email=invitee_email,
+ language=interface_language,
+ role=invitee_role,
+ inviter=inviter,
)
- encoded_invitee_email = parse.quote(invitee_email)
+ encoded_invitee_email = parse.quote(normalized_invitee_email)
invitation_results.append(
{
"status": "success",
- "email": invitee_email,
+ "email": normalized_invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
- {"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
+ {"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
- invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
+ invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
return {
"result": "success",
@@ -216,7 +235,7 @@ class DatasetOperatorMemberListApi(Resource):
@setup_required
@login_required
@account_initialization_required
- @marshal_with(account_with_role_list_fields)
+ @marshal_with(account_with_role_list_model)
def get(self):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py
index 2def57ed7b..583e3e3057 100644
--- a/api/controllers/console/workspace/models.py
+++ b/api/controllers/console/workspace/models.py
@@ -5,6 +5,7 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
+from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
@@ -23,12 +24,13 @@ class ParserGetDefault(BaseModel):
model_type: ModelType
-class ParserPostDefault(BaseModel):
- class Inner(BaseModel):
- model_type: ModelType
- model: str | None = None
- provider: str | None = None
+class Inner(BaseModel):
+ model_type: ModelType
+ model: str | None = None
+ provider: str | None = None
+
+class ParserPostDefault(BaseModel):
model_settings: list[Inner]
@@ -105,19 +107,21 @@ class ParserParameter(BaseModel):
model: str
-def reg(cls: type[BaseModel]):
- console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+register_schema_models(
+ console_ns,
+ ParserGetDefault,
+ ParserPostDefault,
+ ParserDeleteModels,
+ ParserPostModels,
+ ParserGetCredentials,
+ ParserCreateCredential,
+ ParserUpdateCredential,
+ ParserDeleteCredential,
+ ParserParameter,
+ Inner,
+)
-
-reg(ParserGetDefault)
-reg(ParserPostDefault)
-reg(ParserDeleteModels)
-reg(ParserPostModels)
-reg(ParserGetCredentials)
-reg(ParserCreateCredential)
-reg(ParserUpdateCredential)
-reg(ParserDeleteCredential)
-reg(ParserParameter)
+register_enum_models(console_ns, ModelType)
@console_ns.route("/workspaces/current/default-model")
diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py
index ea74fc0337..d1485bc1c0 100644
--- a/api/controllers/console/workspace/plugin.py
+++ b/api/controllers/console/workspace/plugin.py
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from configs import dify_config
+from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
@@ -20,57 +21,12 @@ from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
-
-def reg(cls: type[BaseModel]):
- console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
-
-
-@console_ns.route("/workspaces/current/plugin/debugging-key")
-class PluginDebuggingKeyApi(Resource):
- @setup_required
- @login_required
- @account_initialization_required
- @plugin_permission_required(debug_required=True)
- def get(self):
- _, tenant_id = current_account_with_tenant()
-
- try:
- return {
- "key": PluginService.get_debugging_key(tenant_id),
- "host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
- "port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
- }
- except PluginDaemonClientSideError as e:
- raise ValueError(e)
-
class ParserList(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
-reg(ParserList)
-
-
-@console_ns.route("/workspaces/current/plugin/list")
-class PluginListApi(Resource):
- @console_ns.expect(console_ns.models[ParserList.__name__])
- @setup_required
- @login_required
- @account_initialization_required
- def get(self):
- _, tenant_id = current_account_with_tenant()
- args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore
- try:
- plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
- except PluginDaemonClientSideError as e:
- raise ValueError(e)
-
- return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
-
-
class ParserLatest(BaseModel):
plugin_ids: list[str]
@@ -180,23 +136,73 @@ class ParserReadme(BaseModel):
language: str = Field(default="en-US")
-reg(ParserLatest)
-reg(ParserIcon)
-reg(ParserAsset)
-reg(ParserGithubUpload)
-reg(ParserPluginIdentifiers)
-reg(ParserGithubInstall)
-reg(ParserPluginIdentifierQuery)
-reg(ParserTasks)
-reg(ParserMarketplaceUpgrade)
-reg(ParserGithubUpgrade)
-reg(ParserUninstall)
-reg(ParserPermissionChange)
-reg(ParserDynamicOptions)
-reg(ParserDynamicOptionsWithCredentials)
-reg(ParserPreferencesChange)
-reg(ParserExcludePlugin)
-reg(ParserReadme)
+register_schema_models(
+ console_ns,
+ ParserList,
+ PluginAutoUpgradeSettingsPayload,
+ PluginPermissionSettingsPayload,
+ ParserLatest,
+ ParserIcon,
+ ParserAsset,
+ ParserGithubUpload,
+ ParserPluginIdentifiers,
+ ParserGithubInstall,
+ ParserPluginIdentifierQuery,
+ ParserTasks,
+ ParserMarketplaceUpgrade,
+ ParserGithubUpgrade,
+ ParserUninstall,
+ ParserPermissionChange,
+ ParserDynamicOptions,
+ ParserDynamicOptionsWithCredentials,
+ ParserPreferencesChange,
+ ParserExcludePlugin,
+ ParserReadme,
+)
+
+register_enum_models(
+ console_ns,
+ TenantPluginPermission.DebugPermission,
+ TenantPluginAutoUpgradeStrategy.UpgradeMode,
+ TenantPluginAutoUpgradeStrategy.StrategySetting,
+ TenantPluginPermission.InstallPermission,
+)
+
+
+@console_ns.route("/workspaces/current/plugin/debugging-key")
+class PluginDebuggingKeyApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @plugin_permission_required(debug_required=True)
+ def get(self):
+ _, tenant_id = current_account_with_tenant()
+
+ try:
+ return {
+ "key": PluginService.get_debugging_key(tenant_id),
+ "host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
+ "port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
+ }
+ except PluginDaemonClientSideError as e:
+ raise ValueError(e)
+
+
+@console_ns.route("/workspaces/current/plugin/list")
+class PluginListApi(Resource):
+ @console_ns.expect(console_ns.models[ParserList.__name__])
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self):
+ _, tenant_id = current_account_with_tenant()
+ args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore
+ try:
+ plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
+ except PluginDaemonClientSideError as e:
+ raise ValueError(e)
+
+ return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index cb711d16e4..e9e7b72718 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -1,4 +1,5 @@
import io
+import logging
from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file
@@ -17,8 +18,8 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
)
+from core.db.session_factory import session_factory
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
-from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
from core.mcp.mcp_client import MCPClient
@@ -40,6 +41,8 @@ from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
+logger = logging.getLogger(__name__)
+
def is_valid_url(url: str) -> bool:
if not url:
@@ -945,8 +948,8 @@ class ToolProviderMCPApi(Resource):
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
- # Create provider in transaction
- with Session(db.engine) as session, session.begin():
+ # 1) Create provider in a short transaction (no network I/O inside)
+ with session_factory.create_session() as session, session.begin():
service = MCPToolManageService(session=session)
result = service.create_provider(
tenant_id=tenant_id,
@@ -962,8 +965,26 @@ class ToolProviderMCPApi(Resource):
authentication=authentication,
)
- # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
- ToolProviderListCache.invalidate_cache(tenant_id)
+ # 2) Try to fetch tools immediately after creation so they appear without a second save.
+ # Perform network I/O outside any DB session to avoid holding locks.
+ try:
+ reconnect = MCPToolManageService.reconnect_with_url(
+ server_url=args["server_url"],
+ headers=args.get("headers") or {},
+ timeout=configuration.timeout,
+ sse_read_timeout=configuration.sse_read_timeout,
+ )
+ # Update just-created provider with authed/tools in a new short transaction
+ with session_factory.create_session() as session, session.begin():
+ service = MCPToolManageService(session=session)
+ db_provider = service.get_provider(provider_id=result.id, tenant_id=tenant_id)
+ db_provider.authed = reconnect.authed
+ db_provider.tools = reconnect.tools
+
+ result = ToolTransformService.mcp_provider_to_user_provider(db_provider, for_list=True)
+ except Exception:
+ # Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
+ logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
return jsonable_encoder(result)
@@ -1011,9 +1032,6 @@ class ToolProviderMCPApi(Resource):
validation_result=validation_result,
)
- # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
- ToolProviderListCache.invalidate_cache(current_tenant_id)
-
return {"result": "success"}
@console_ns.expect(parser_mcp_delete)
@@ -1028,9 +1046,6 @@ class ToolProviderMCPApi(Resource):
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
- # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
- ToolProviderListCache.invalidate_cache(current_tenant_id)
-
return {"result": "success"}
@@ -1081,8 +1096,6 @@ class ToolMCPAuthApi(Resource):
credentials=provider_entity.credentials,
authed=True,
)
- # Invalidate cache after updating credentials
- ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
except MCPAuthError as e:
try:
@@ -1096,22 +1109,16 @@ class ToolMCPAuthApi(Resource):
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
response = service.execute_auth_actions(auth_result)
- # Invalidate cache after auth actions may have updated provider state
- ToolProviderListCache.invalidate_cache(tenant_id)
return response
except MCPRefreshTokenError as e:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
- # Invalidate cache after clearing credentials
- ToolProviderListCache.invalidate_cache(tenant_id)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except (MCPError, ValueError) as e:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
- # Invalidate cache after clearing credentials
- ToolProviderListCache.invalidate_cache(tenant_id)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py
index 497e62b790..6b642af613 100644
--- a/api/controllers/console/workspace/trigger_providers.py
+++ b/api/controllers/console/workspace/trigger_providers.py
@@ -1,15 +1,14 @@
import logging
-from collections.abc import Mapping
from typing import Any
from flask import make_response, redirect, request
-from flask_restx import Resource, reqparse
-from pydantic import BaseModel, Field
+from flask_restx import Resource
+from pydantic import BaseModel, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
-from constants import HIDDEN_VALUE, UNKNOWN_VALUE
+from controllers.common.schema import register_schema_models
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
@@ -36,29 +35,38 @@ from ..wraps import (
logger = logging.getLogger(__name__)
-class TriggerSubscriptionUpdateRequest(BaseModel):
- """Request payload for updating a trigger subscription"""
-
- name: str | None = Field(default=None, description="The name for the subscription")
- credentials: Mapping[str, Any] | None = Field(default=None, description="The credentials for the subscription")
- parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
- properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
+class TriggerSubscriptionBuilderCreatePayload(BaseModel):
+ credential_type: str = CredentialType.UNAUTHORIZED
-class TriggerSubscriptionVerifyRequest(BaseModel):
- """Request payload for verifying subscription credentials."""
-
- credentials: Mapping[str, Any] = Field(description="The credentials to verify")
+class TriggerSubscriptionBuilderVerifyPayload(BaseModel):
+ credentials: dict[str, Any]
-console_ns.schema_model(
- TriggerSubscriptionUpdateRequest.__name__,
- TriggerSubscriptionUpdateRequest.model_json_schema(ref_template="#/definitions/{model}"),
-)
+class TriggerSubscriptionBuilderUpdatePayload(BaseModel):
+ name: str | None = None
+ parameters: dict[str, Any] | None = None
+ properties: dict[str, Any] | None = None
+ credentials: dict[str, Any] | None = None
-console_ns.schema_model(
- TriggerSubscriptionVerifyRequest.__name__,
- TriggerSubscriptionVerifyRequest.model_json_schema(ref_template="#/definitions/{model}"),
+ @model_validator(mode="after")
+ def check_at_least_one_field(self):
+ if all(v is None for v in self.model_dump().values()):
+ raise ValueError("At least one of name, credentials, parameters, or properties must be provided")
+ return self
+
+
+class TriggerOAuthClientPayload(BaseModel):
+ client_params: dict[str, Any] | None = None
+ enabled: bool | None = None
+
+
+register_schema_models(
+ console_ns,
+ TriggerSubscriptionBuilderCreatePayload,
+ TriggerSubscriptionBuilderVerifyPayload,
+ TriggerSubscriptionBuilderUpdatePayload,
+ TriggerOAuthClientPayload,
)
@@ -127,16 +135,11 @@ class TriggerSubscriptionListApi(Resource):
raise
-parser = reqparse.RequestParser().add_argument(
- "credential_type", type=str, required=False, nullable=True, location="json"
-)
-
-
@console_ns.route(
"/workspaces/current/trigger-provider//subscriptions/builder/create",
)
class TriggerSubscriptionBuilderCreateApi(Resource):
- @console_ns.expect(parser)
+ @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderCreatePayload.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -146,10 +149,10 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
user = current_user
assert user.current_tenant_id is not None
- args = parser.parse_args()
+ payload = TriggerSubscriptionBuilderCreatePayload.model_validate(console_ns.payload or {})
try:
- credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
+ credential_type = CredentialType.of(payload.credential_type)
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
@@ -177,18 +180,11 @@ class TriggerSubscriptionBuilderGetApi(Resource):
)
-parser_api = (
- reqparse.RequestParser()
- # The credentials of the subscription builder
- .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
-)
-
-
@console_ns.route(
"/workspaces/current/trigger-provider//subscriptions/builder/verify-and-update/",
)
-class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
- @console_ns.expect(parser_api)
+class TriggerSubscriptionBuilderVerifyApi(Resource):
+ @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -198,7 +194,7 @@ class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
user = current_user
assert user.current_tenant_id is not None
- args = parser_api.parse_args()
+ payload = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
try:
# Use atomic update_and_verify to prevent race conditions
@@ -208,7 +204,7 @@ class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
- credentials=args.get("credentials", None),
+ credentials=payload.credentials,
),
)
except Exception as e:
@@ -216,24 +212,11 @@ class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
raise ValueError(str(e)) from e
-parser_update_api = (
- reqparse.RequestParser()
- # The name of the subscription builder
- .add_argument("name", type=str, required=False, nullable=True, location="json")
- # The parameters of the subscription builder
- .add_argument("parameters", type=dict, required=False, nullable=True, location="json")
- # The properties of the subscription builder
- .add_argument("properties", type=dict, required=False, nullable=True, location="json")
- # The credentials of the subscription builder
- .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
-)
-
-
@console_ns.route(
"/workspaces/current/trigger-provider//subscriptions/builder/update/",
)
class TriggerSubscriptionBuilderUpdateApi(Resource):
- @console_ns.expect(parser_update_api)
+ @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -244,7 +227,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
assert isinstance(user, Account)
assert user.current_tenant_id is not None
- args = parser_update_api.parse_args()
+ payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
try:
return jsonable_encoder(
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
@@ -252,10 +235,10 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
- name=args.get("name", None),
- parameters=args.get("parameters", None),
- properties=args.get("properties", None),
- credentials=args.get("credentials", None),
+ name=payload.name,
+ parameters=payload.parameters,
+ properties=payload.properties,
+ credentials=payload.credentials,
),
)
)
@@ -290,7 +273,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
"/workspaces/current/trigger-provider//subscriptions/builder/build/",
)
class TriggerSubscriptionBuilderBuildApi(Resource):
- @console_ns.expect(parser_update_api)
+ @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -299,7 +282,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
- args = parser_update_api.parse_args()
+ payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
try:
# Use atomic update_and_build to prevent race conditions
TriggerSubscriptionBuilderService.update_and_build_builder(
@@ -308,9 +291,9 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
- name=args.get("name", None),
- parameters=args.get("parameters", None),
- properties=args.get("properties", None),
+ name=payload.name,
+ parameters=payload.parameters,
+ properties=payload.properties,
),
)
return 200
@@ -323,7 +306,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
"/workspaces/current/trigger-provider//subscriptions/update",
)
class TriggerSubscriptionUpdateApi(Resource):
- @console_ns.expect(console_ns.models[TriggerSubscriptionUpdateRequest.__name__])
+ @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -333,7 +316,7 @@ class TriggerSubscriptionUpdateApi(Resource):
user = current_user
assert user.current_tenant_id is not None
- args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
+ request = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=user.current_tenant_id,
@@ -345,50 +328,32 @@ class TriggerSubscriptionUpdateApi(Resource):
provider_id = TriggerProviderID(subscription.provider_id)
try:
- # rename only
- if (
- args.name is not None
- and args.credentials is None
- and args.parameters is None
- and args.properties is None
- ):
+ # For rename only, just update the name
+ rename = request.name is not None and not any((request.credentials, request.parameters, request.properties))
+ # When credential type is UNAUTHORIZED, it indicates the subscription was manually created
+ # For Manually created subscription, they dont have credentials, parameters
+ # They only have name and properties(which is input by user)
+ manually_created = subscription.credential_type == CredentialType.UNAUTHORIZED
+ if rename or manually_created:
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
- name=args.name,
+ name=request.name,
+ properties=request.properties,
)
return 200
- # rebuild for create automatically by the provider
- match subscription.credential_type:
- case CredentialType.UNAUTHORIZED:
- TriggerProviderService.update_trigger_subscription(
- tenant_id=user.current_tenant_id,
- subscription_id=subscription_id,
- name=args.name,
- properties=args.properties,
- )
- return 200
- case CredentialType.API_KEY | CredentialType.OAUTH2:
- if args.credentials:
- new_credentials: dict[str, Any] = {
- key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
- for key, value in args.credentials.items()
- }
- else:
- new_credentials = subscription.credentials
-
- TriggerProviderService.rebuild_trigger_subscription(
- tenant_id=user.current_tenant_id,
- name=args.name,
- provider_id=provider_id,
- subscription_id=subscription_id,
- credentials=new_credentials,
- parameters=args.parameters or subscription.parameters,
- )
- return 200
- case _:
- raise BadRequest("Invalid credential type")
+ # For the rest cases(API_KEY, OAUTH2)
+ # we need to call third party provider(e.g. GitHub) to rebuild the subscription
+ TriggerProviderService.rebuild_trigger_subscription(
+ tenant_id=user.current_tenant_id,
+ name=request.name,
+ provider_id=provider_id,
+ subscription_id=subscription_id,
+ credentials=request.credentials or subscription.credentials,
+ parameters=request.parameters or subscription.parameters,
+ )
+ return 200
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
@@ -581,13 +546,6 @@ class TriggerOAuthCallbackApi(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
-parser_oauth_client = (
- reqparse.RequestParser()
- .add_argument("client_params", type=dict, required=False, nullable=True, location="json")
- .add_argument("enabled", type=bool, required=False, nullable=True, location="json")
-)
-
-
@console_ns.route("/workspaces/current/trigger-provider//oauth/client")
class TriggerOAuthClientManageApi(Resource):
@setup_required
@@ -635,7 +593,7 @@ class TriggerOAuthClientManageApi(Resource):
logger.exception("Error getting OAuth client", exc_info=e)
raise
- @console_ns.expect(parser_oauth_client)
+ @console_ns.expect(console_ns.models[TriggerOAuthClientPayload.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -645,15 +603,15 @@ class TriggerOAuthClientManageApi(Resource):
user = current_user
assert user.current_tenant_id is not None
- args = parser_oauth_client.parse_args()
+ payload = TriggerOAuthClientPayload.model_validate(console_ns.payload or {})
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
- client_params=args.get("client_params"),
- enabled=args.get("enabled"),
+ client_params=payload.client_params,
+ enabled=payload.enabled,
)
except ValueError as e:
@@ -689,7 +647,7 @@ class TriggerOAuthClientManageApi(Resource):
"/workspaces/current/trigger-provider//subscriptions/verify/",
)
class TriggerSubscriptionVerifyApi(Resource):
- @console_ns.expect(console_ns.models[TriggerSubscriptionVerifyRequest.__name__])
+ @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -699,9 +657,7 @@ class TriggerSubscriptionVerifyApi(Resource):
user = current_user
assert user.current_tenant_id is not None
- verify_request: TriggerSubscriptionVerifyRequest = TriggerSubscriptionVerifyRequest.model_validate(
- console_ns.payload
- )
+ verify_request = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
try:
result = TriggerProviderService.verify_subscription_credentials(
diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py
index 909a5ce201..94be81d94f 100644
--- a/api/controllers/console/workspace/workspace.py
+++ b/api/controllers/console/workspace/workspace.py
@@ -20,6 +20,7 @@ from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
+ only_edition_enterprise,
setup_required,
)
from enums.cloud_plan import CloudPlan
@@ -28,6 +29,7 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
+from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workspace_service import WorkspaceService
@@ -80,6 +82,9 @@ tenant_fields = {
"in_trial": fields.Boolean,
"trial_end_reason": fields.String,
"custom_config": fields.Raw(attribute="custom_config"),
+ "trial_credits": fields.Integer,
+ "trial_credits_used": fields.Integer,
+ "next_credit_reset_date": fields.Integer,
}
tenants_fields = {
@@ -285,3 +290,31 @@ class WorkspaceInfoApi(Resource):
db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
+
+
+@console_ns.route("/workspaces/current/permission")
+class WorkspacePermissionApi(Resource):
+ """Get workspace permissions for the current workspace."""
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @only_edition_enterprise
+ def get(self):
+ """
+ Get workspace permission settings.
+ Returns permission flags that control workspace features like member invitations and owner transfer.
+ """
+ _, current_tenant_id = current_account_with_tenant()
+
+ if not current_tenant_id:
+ raise ValueError("No current tenant")
+
+ # Get workspace permissions from enterprise service
+ permission = EnterpriseService.WorkspacePermissionService.get_permission(current_tenant_id)
+
+ return {
+ "workspace_id": permission.workspace_id,
+ "allow_member_invite": permission.allow_member_invite,
+ "allow_owner_transfer": permission.allow_owner_transfer,
+ }, 200
diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py
index 95fc006a12..fd928b077d 100644
--- a/api/controllers/console/wraps.py
+++ b/api/controllers/console/wraps.py
@@ -286,13 +286,12 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
- _, current_tenant_id = current_account_with_tenant()
- features = FeatureService.get_features(current_tenant_id)
- if features.is_allow_transfer_workspace:
- return view(*args, **kwargs)
+ from libs.workspace_permission import check_workspace_owner_transfer_permission
- # otherwise, return 403
- abort(403)
+ _, current_tenant_id = current_account_with_tenant()
+ # Check both billing/plan level and workspace policy level permissions
+ check_workspace_owner_transfer_permission(current_tenant_id)
+ return view(*args, **kwargs)
return decorated
diff --git a/api/controllers/fastopenapi.py b/api/controllers/fastopenapi.py
new file mode 100644
index 0000000000..c13f22338b
--- /dev/null
+++ b/api/controllers/fastopenapi.py
@@ -0,0 +1,3 @@
+from fastopenapi.routers import FlaskRouter
+
+console_router = FlaskRouter()
diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py
index 6096a87c56..28ec4b3935 100644
--- a/api/controllers/files/upload.py
+++ b/api/controllers/files/upload.py
@@ -4,18 +4,18 @@ from flask import request
from flask_restx import Resource
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
-from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden
import services
from core.file.helpers import verify_plugin_file_signature
from core.tools.tool_file_manager import ToolFileManager
-from fields.file_fields import build_file_model
+from fields.file_fields import FileResponse
from ..common.errors import (
FileTooLargeError,
UnsupportedFileTypeError,
)
+from ..common.schema import register_schema_models
from ..console.wraps import setup_required
from ..files import files_ns
from ..inner_api.plugin.wraps import get_user
@@ -35,6 +35,8 @@ files_ns.schema_model(
PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
+register_schema_models(files_ns, FileResponse)
+
@files_ns.route("/upload/for-plugin")
class PluginUploadFileApi(Resource):
@@ -51,7 +53,7 @@ class PluginUploadFileApi(Resource):
415: "Unsupported file type",
}
)
- @files_ns.marshal_with(build_file_model(files_ns), code=HTTPStatus.CREATED)
+ @files_ns.response(HTTPStatus.CREATED, "File uploaded", files_ns.models[FileResponse.__name__])
def post(self):
"""Upload a file for plugin usage.
@@ -69,7 +71,7 @@ class PluginUploadFileApi(Resource):
"""
args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
- file: FileStorage | None = request.files.get("file")
+ file = request.files.get("file")
if file is None:
raise Forbidden("File is required.")
@@ -80,8 +82,8 @@ class PluginUploadFileApi(Resource):
user_id = args.user_id
user = get_user(tenant_id, user_id)
- filename: str | None = file.filename
- mimetype: str | None = file.mimetype
+ filename = file.filename
+ mimetype = file.mimetype
if not filename or not mimetype:
raise Forbidden("Invalid request.")
@@ -111,22 +113,22 @@ class PluginUploadFileApi(Resource):
preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension)
# Create a dictionary with all the necessary attributes
- result = {
- "id": tool_file.id,
- "user_id": tool_file.user_id,
- "tenant_id": tool_file.tenant_id,
- "conversation_id": tool_file.conversation_id,
- "file_key": tool_file.file_key,
- "mimetype": tool_file.mimetype,
- "original_url": tool_file.original_url,
- "name": tool_file.name,
- "size": tool_file.size,
- "mime_type": mimetype,
- "extension": extension,
- "preview_url": preview_url,
- }
+ result = FileResponse(
+ id=tool_file.id,
+ name=tool_file.name,
+ size=tool_file.size,
+ extension=extension,
+ mime_type=mimetype,
+ preview_url=preview_url,
+ source_url=tool_file.original_url,
+ original_url=tool_file.original_url,
+ user_id=tool_file.user_id,
+ tenant_id=tool_file.tenant_id,
+ conversation_id=tool_file.conversation_id,
+ file_key=tool_file.file_key,
+ )
- return result, 201
+ return result.model_dump(mode="json"), 201
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py
index 63c373b50f..85ac9336d6 100644
--- a/api/controllers/service_api/app/annotation.py
+++ b/api/controllers/service_api/app/annotation.py
@@ -1,7 +1,7 @@
from typing import Literal
from flask import request
-from flask_restx import Api, Namespace, Resource, fields
+from flask_restx import Namespace, Resource, fields
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
@@ -92,7 +92,7 @@ annotation_list_fields = {
}
-def build_annotation_list_model(api_or_ns: Api | Namespace):
+def build_annotation_list_model(api_or_ns: Namespace):
"""Build the annotation list model for the API or Namespace."""
copied_annotation_list_fields = annotation_list_fields.copy()
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py
index 25d7ccccec..562f5e33cc 100644
--- a/api/controllers/service_api/app/app.py
+++ b/api/controllers/service_api/app/app.py
@@ -1,6 +1,6 @@
from flask_restx import Resource
-from controllers.common.fields import build_parameters_model
+from controllers.common.fields import Parameters
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
@@ -23,7 +23,6 @@ class AppParameterApi(Resource):
}
)
@validate_app_token
- @service_api_ns.marshal_with(build_parameters_model(service_api_ns))
def get(self, app_model: App):
"""Retrieve app parameters.
@@ -45,7 +44,8 @@ class AppParameterApi(Resource):
user_input_form = features_dict.get("user_input_form", [])
- return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+ parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+ return Parameters.model_validate(parameters).model_dump(mode="json")
@service_api_ns.route("/meta")
diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py
index 40e4bde389..62e8258e25 100644
--- a/api/controllers/service_api/app/conversation.py
+++ b/api/controllers/service_api/app/conversation.py
@@ -3,8 +3,7 @@ from uuid import UUID
from flask import request
from flask_restx import Resource
-from flask_restx._http import HTTPStatus
-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
@@ -16,9 +15,9 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
- build_conversation_delete_model,
- build_conversation_infinite_scroll_pagination_model,
- build_simple_conversation_model,
+ ConversationDelete,
+ ConversationInfiniteScrollPagination,
+ SimpleConversation,
)
from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model,
@@ -105,7 +104,6 @@ class ConversationApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
- @service_api_ns.marshal_with(build_conversation_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser):
"""List all conversations for the current user.
@@ -120,7 +118,7 @@ class ConversationApi(Resource):
try:
with Session(db.engine) as session:
- return ConversationService.pagination_by_last_id(
+ pagination = ConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=end_user,
@@ -129,6 +127,13 @@ class ConversationApi(Resource):
invoke_from=InvokeFrom.SERVICE_API,
sort_by=query_args.sort_by,
)
+ adapter = TypeAdapter(SimpleConversation)
+ conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
+ return ConversationInfiniteScrollPagination(
+ limit=pagination.limit,
+ has_more=pagination.has_more,
+ data=conversations,
+ ).model_dump(mode="json")
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@@ -146,7 +151,6 @@ class ConversationDetailApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
- @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT)
def delete(self, app_model: App, end_user: EndUser, c_id):
"""Delete a specific conversation."""
app_mode = AppMode.value_of(app_model.mode)
@@ -159,7 +163,7 @@ class ConversationDetailApi(Resource):
ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
- return {"result": "success"}, 204
+ return ConversationDelete(result="success").model_dump(mode="json"), 204
@service_api_ns.route("/conversations//name")
@@ -176,7 +180,6 @@ class ConversationRenameApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
- @service_api_ns.marshal_with(build_simple_conversation_model(service_api_ns))
def post(self, app_model: App, end_user: EndUser, c_id):
"""Rename a conversation or auto-generate a name."""
app_mode = AppMode.value_of(app_model.mode)
@@ -188,7 +191,14 @@ class ConversationRenameApi(Resource):
payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {})
try:
- return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate)
+ conversation = ConversationService.rename(
+ app_model, conversation_id, end_user, payload.name, payload.auto_generate
+ )
+ return (
+ TypeAdapter(SimpleConversation)
+ .validate_python(conversation, from_attributes=True)
+ .model_dump(mode="json")
+ )
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py
index ffe4e0b492..6f6dadf768 100644
--- a/api/controllers/service_api/app/file.py
+++ b/api/controllers/service_api/app/file.py
@@ -10,13 +10,16 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
+from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_database import db
-from fields.file_fields import build_file_model
+from fields.file_fields import FileResponse
from models import App, EndUser
from services.file_service import FileService
+register_schema_models(service_api_ns, FileResponse)
+
@service_api_ns.route("/files/upload")
class FileApi(Resource):
@@ -31,8 +34,8 @@ class FileApi(Resource):
415: "Unsupported file type",
}
)
- @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
- @service_api_ns.marshal_with(build_file_model(service_api_ns), code=HTTPStatus.CREATED)
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) # type: ignore
+ @service_api_ns.response(HTTPStatus.CREATED, "File uploaded", service_api_ns.models[FileResponse.__name__])
def post(self, app_model: App, end_user: EndUser):
"""Upload a file for use in conversations.
@@ -64,4 +67,5 @@ class FileApi(Resource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
- return upload_file, 201
+ response = FileResponse.model_validate(upload_file, from_attributes=True)
+ return response.model_dump(mode="json"), 201
diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py
index d342f4e661..8981bbd7d5 100644
--- a/api/controllers/service_api/app/message.py
+++ b/api/controllers/service_api/app/message.py
@@ -1,11 +1,10 @@
-import json
import logging
from typing import Literal
from uuid import UUID
from flask import request
-from flask_restx import Namespace, Resource, fields
-from pydantic import BaseModel, Field
+from flask_restx import Resource
+from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
@@ -14,10 +13,8 @@ from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
-from fields.conversation_fields import build_message_file_model
-from fields.message_fields import build_agent_thought_model, build_feedback_model
-from fields.raws import FilesContainedField
-from libs.helper import TimestampField
+from fields.conversation_fields import ResultResponse
+from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from models.model import App, AppMode, EndUser
from services.errors.message import (
FirstMessageNotExistsError,
@@ -48,49 +45,6 @@ class FeedbackListQuery(BaseModel):
register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery)
-def build_message_model(api_or_ns: Namespace):
- """Build the message model for the API or Namespace."""
- # First build the nested models
- feedback_model = build_feedback_model(api_or_ns)
- agent_thought_model = build_agent_thought_model(api_or_ns)
- message_file_model = build_message_file_model(api_or_ns)
-
- # Then build the message fields with nested models
- message_fields = {
- "id": fields.String,
- "conversation_id": fields.String,
- "parent_message_id": fields.String,
- "inputs": FilesContainedField,
- "query": fields.String,
- "answer": fields.String(attribute="re_sign_file_url_answer"),
- "message_files": fields.List(fields.Nested(message_file_model)),
- "feedback": fields.Nested(feedback_model, attribute="user_feedback", allow_null=True),
- "retriever_resources": fields.Raw(
- attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", [])
- if obj.message_metadata
- else []
- ),
- "created_at": TimestampField,
- "agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
- "status": fields.String,
- "error": fields.String,
- }
- return api_or_ns.model("Message", message_fields)
-
-
-def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace):
- """Build the message infinite scroll pagination model for the API or Namespace."""
- # Build the nested message model first
- message_model = build_message_model(api_or_ns)
-
- message_infinite_scroll_pagination_fields = {
- "limit": fields.Integer,
- "has_more": fields.Boolean,
- "data": fields.List(fields.Nested(message_model)),
- }
- return api_or_ns.model("MessageInfiniteScrollPagination", message_infinite_scroll_pagination_fields)
-
-
@service_api_ns.route("/messages")
class MessageListApi(Resource):
@service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__])
@@ -104,7 +58,6 @@ class MessageListApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
- @service_api_ns.marshal_with(build_message_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser):
"""List messages in a conversation.
@@ -119,9 +72,16 @@ class MessageListApi(Resource):
first_id = str(query_args.first_id) if query_args.first_id else None
try:
- return MessageService.pagination_by_first_id(
+ pagination = MessageService.pagination_by_first_id(
app_model, end_user, conversation_id, first_id, query_args.limit
)
+ adapter = TypeAdapter(MessageListItem)
+ items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
+ return MessageInfiniteScrollPagination(
+ limit=pagination.limit,
+ has_more=pagination.has_more,
+ data=items,
+ ).model_dump(mode="json")
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except FirstMessageNotExistsError:
@@ -162,7 +122,7 @@ class MessageFeedbackApi(Resource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
@service_api_ns.route("/app/feedbacks")
diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py
index 9f8324a84e..8b47a887bb 100644
--- a/api/controllers/service_api/app/site.py
+++ b/api/controllers/service_api/app/site.py
@@ -1,7 +1,7 @@
from flask_restx import Resource
from werkzeug.exceptions import Forbidden
-from controllers.common.fields import build_site_model
+from controllers.common.fields import Site as SiteResponse
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
@@ -23,7 +23,6 @@ class AppSiteApi(Resource):
}
)
@validate_app_token
- @service_api_ns.marshal_with(build_site_model(service_api_ns))
def get(self, app_model: App):
"""Retrieve app site info.
@@ -38,4 +37,4 @@ class AppSiteApi(Resource):
if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
- return site
+ return SiteResponse.model_validate(site).model_dump(mode="json")
diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py
index 5d10c3e8fd..2d24e0f5f9 100644
--- a/api/controllers/service_api/app/workflow.py
+++ b/api/controllers/service_api/app/workflow.py
@@ -3,7 +3,7 @@ from typing import Any, Literal
from dateutil.parser import isoparse
from flask import request
-from flask_restx import Api, Namespace, Resource, fields
+from flask_restx import Namespace, Resource, fields
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
@@ -110,7 +110,7 @@ workflow_run_fields = {
}
-def build_workflow_run_model(api_or_ns: Api | Namespace):
+def build_workflow_run_model(api_or_ns: Namespace):
"""Build the workflow run model for the API or Namespace."""
return api_or_ns.model("WorkflowRun", workflow_run_fields)
diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py
index 4f91f40c55..28864a140a 100644
--- a/api/controllers/service_api/dataset/dataset.py
+++ b/api/controllers/service_api/dataset/dataset.py
@@ -2,7 +2,7 @@ from typing import Any, Literal, cast
from flask import request
from flask_restx import marshal
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, TypeAdapter, field_validator
from werkzeug.exceptions import Forbidden, NotFound
import services
@@ -13,7 +13,6 @@ from controllers.service_api.dataset.error import DatasetInUseError, DatasetName
from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_rate_limit_check,
- validate_dataset_token,
)
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
@@ -27,6 +26,14 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+service_api_ns.schema_model(
+ DatasetPermissionEnum.__name__,
+ TypeAdapter(DatasetPermissionEnum).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
class DatasetCreatePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
@@ -88,6 +95,14 @@ class TagUnbindingPayload(BaseModel):
target_id: str
+class DatasetListQuery(BaseModel):
+ page: int = Field(default=1, description="Page number")
+ limit: int = Field(default=20, description="Number of items per page")
+ keyword: str | None = Field(default=None, description="Search keyword")
+ include_all: bool = Field(default=False, description="Include all datasets")
+ tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
+
+
register_schema_models(
service_api_ns,
DatasetCreatePayload,
@@ -97,6 +112,7 @@ register_schema_models(
TagDeletePayload,
TagBindingPayload,
TagUnbindingPayload,
+ DatasetListQuery,
)
@@ -114,15 +130,11 @@ class DatasetListApi(DatasetApiResource):
)
def get(self, tenant_id):
"""Resource for getting datasets."""
- page = request.args.get("page", default=1, type=int)
- limit = request.args.get("limit", default=20, type=int)
+ query = DatasetListQuery.model_validate(request.args.to_dict())
# provider = request.args.get("provider", default="vendor")
- search = request.args.get("keyword", default=None, type=str)
- tag_ids = request.args.getlist("tag_ids")
- include_all = request.args.get("include_all", default="false").lower() == "true"
datasets, total = DatasetService.get_datasets(
- page, limit, tenant_id, current_user, search, tag_ids, include_all
+ query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all
)
# check embedding setting
provider_manager = ProviderManager()
@@ -148,7 +160,13 @@ class DatasetListApi(DatasetApiResource):
item["embedding_available"] = False
else:
item["embedding_available"] = True
- response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
+ response = {
+ "data": data,
+ "has_more": len(datasets) == query.limit,
+ "limit": query.limit,
+ "total": total,
+ "page": query.page,
+ }
return response, 200
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
@@ -460,9 +478,8 @@ class DatasetTagsApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
- @validate_dataset_token
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
- def get(self, _, dataset_id):
+ def get(self, _):
"""Get all knowledge type tags."""
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
@@ -482,8 +499,7 @@ class DatasetTagsApi(DatasetApiResource):
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
- @validate_dataset_token
- def post(self, _, dataset_id):
+ def post(self, _):
"""Add a knowledge type tag."""
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@@ -506,8 +522,7 @@ class DatasetTagsApi(DatasetApiResource):
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
- @validate_dataset_token
- def patch(self, _, dataset_id):
+ def patch(self, _):
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@@ -533,9 +548,8 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
- @validate_dataset_token
@edit_permission_required
- def delete(self, _, dataset_id):
+ def delete(self, _):
"""Delete a knowledge type tag."""
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(payload.tag_id)
@@ -555,8 +569,7 @@ class DatasetTagBindingApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
- @validate_dataset_token
- def post(self, _, dataset_id):
+ def post(self, _):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@@ -580,8 +593,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
- @validate_dataset_token
- def post(self, _, dataset_id):
+ def post(self, _):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@@ -604,7 +616,6 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
- @validate_dataset_token
def get(self, _, *args, **kwargs):
"""Get all knowledge type tags."""
dataset_id = kwargs.get("dataset_id")
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index c800c0e4e1..c85c1cf81e 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -16,6 +16,7 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
+from controllers.common.schema import register_enum_models, register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import (
@@ -29,12 +30,20 @@ from controllers.service_api.wraps import (
cloud_edition_billing_resource_check,
)
from core.errors.error import ProviderTokenNotInitError
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DatasetService, DocumentService
-from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
+from services.entities.knowledge_entities.knowledge_entities import (
+ KnowledgeConfig,
+ PreProcessingRule,
+ ProcessRule,
+ RetrievalModel,
+ Rule,
+ Segmentation,
+)
from services.file_service import FileService
@@ -69,8 +78,26 @@ class DocumentTextUpdate(BaseModel):
return self
-for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate]:
- service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore
+class DocumentListQuery(BaseModel):
+ page: int = Field(default=1, description="Page number")
+ limit: int = Field(default=20, description="Number of items per page")
+ keyword: str | None = Field(default=None, description="Search keyword")
+ status: str | None = Field(default=None, description="Document status filter")
+
+
+register_enum_models(service_api_ns, RetrievalMethod)
+
+register_schema_models(
+ service_api_ns,
+ ProcessRule,
+ RetrievalModel,
+ DocumentTextCreatePayload,
+ DocumentTextUpdate,
+ DocumentListQuery,
+ Rule,
+ PreProcessingRule,
+ Segmentation,
+)
@service_api_ns.route(
@@ -261,17 +288,6 @@ class DocumentAddByFileApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Create document by upload file."""
- args = {}
- if "data" in request.form:
- args = json.loads(request.form["data"])
- if "doc_form" not in args:
- args["doc_form"] = "text_model"
- if "doc_language" not in args:
- args["doc_language"] = "English"
-
- # get dataset info
- dataset_id = str(dataset_id)
- tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
@@ -280,6 +296,18 @@ class DocumentAddByFileApi(DatasetApiResource):
if dataset.provider == "external":
raise ValueError("External datasets are not supported.")
+ args = {}
+ if "data" in request.form:
+ args = json.loads(request.form["data"])
+ if "doc_form" not in args:
+ args["doc_form"] = dataset.chunk_structure or "text_model"
+ if "doc_language" not in args:
+ args["doc_language"] = "English"
+
+ # get dataset info
+ dataset_id = str(dataset_id)
+ tenant_id = str(tenant_id)
+
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
if not indexing_technique:
raise ValueError("indexing_technique is required.")
@@ -370,17 +398,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id):
"""Update document by upload file."""
- args = {}
- if "data" in request.form:
- args = json.loads(request.form["data"])
- if "doc_form" not in args:
- args["doc_form"] = "text_model"
- if "doc_language" not in args:
- args["doc_language"] = "English"
-
- # get dataset info
- dataset_id = str(dataset_id)
- tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
@@ -389,6 +406,18 @@ class DocumentUpdateByFileApi(DatasetApiResource):
if dataset.provider == "external":
raise ValueError("External datasets are not supported.")
+ args = {}
+ if "data" in request.form:
+ args = json.loads(request.form["data"])
+ if "doc_form" not in args:
+ args["doc_form"] = dataset.chunk_structure or "text_model"
+ if "doc_language" not in args:
+ args["doc_language"] = "English"
+
+ # get dataset info
+ dataset_id = str(dataset_id)
+ tenant_id = str(tenant_id)
+
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
@@ -458,34 +487,33 @@ class DocumentListApi(DatasetApiResource):
def get(self, tenant_id, dataset_id):
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
- page = request.args.get("page", default=1, type=int)
- limit = request.args.get("limit", default=20, type=int)
- search = request.args.get("keyword", default=None, type=str)
- status = request.args.get("status", default=None, type=str)
+ query_params = DocumentListQuery.model_validate(request.args.to_dict())
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
- if status:
- query = DocumentService.apply_display_status_filter(query, status)
+ if query_params.status:
+ query = DocumentService.apply_display_status_filter(query, query_params.status)
- if search:
- search = f"%{search}%"
+ if query_params.keyword:
+ search = f"%{query_params.keyword}%"
query = query.where(Document.name.like(search))
query = query.order_by(desc(Document.created_at), desc(Document.position))
- paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
+ paginated_documents = db.paginate(
+ select=query, page=query_params.page, per_page=query_params.limit, max_per_page=100, error_out=False
+ )
documents = paginated_documents.items
response = {
"data": marshal(documents, document_fields),
- "has_more": len(documents) == limit,
- "limit": limit,
+ "has_more": len(documents) == query_params.limit,
+ "limit": query_params.limit,
"total": paginated_documents.total,
- "page": page,
+ "page": query_params.page,
}
return response
diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py
index d81287d56f..8dbb690901 100644
--- a/api/controllers/service_api/dataset/hit_testing.py
+++ b/api/controllers/service_api/dataset/hit_testing.py
@@ -24,7 +24,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str)
- args = self.parse_args()
+ args = self.parse_args(service_api_ns.payload)
self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args)
diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py
index aab25c1af3..b8d9508004 100644
--- a/api/controllers/service_api/dataset/metadata.py
+++ b/api/controllers/service_api/dataset/metadata.py
@@ -11,7 +11,9 @@ from controllers.service_api.wraps import DatasetApiResource, cloud_edition_bill
from fields.dataset_fields import dataset_metadata_fields
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
+ DocumentMetadataOperation,
MetadataArgs,
+ MetadataDetail,
MetadataOperationData,
)
from services.metadata_service import MetadataService
@@ -22,7 +24,13 @@ class MetadataUpdatePayload(BaseModel):
register_schema_model(service_api_ns, MetadataUpdatePayload)
-register_schema_models(service_api_ns, MetadataArgs, MetadataOperationData)
+register_schema_models(
+ service_api_ns,
+ MetadataArgs,
+ MetadataDetail,
+ DocumentMetadataOperation,
+ MetadataOperationData,
+)
@service_api_ns.route("/datasets//metadata")
diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
index 0a2017e2bd..70b5030237 100644
--- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
@@ -174,7 +174,7 @@ class PipelineRunApi(DatasetApiResource):
pipeline=pipeline,
user=current_user,
args=payload.model_dump(),
- invoke_from=InvokeFrom.PUBLISHED if payload.is_published else InvokeFrom.DEBUGGER,
+ invoke_from=InvokeFrom.PUBLISHED_PIPELINE if payload.is_published else InvokeFrom.DEBUGGER,
streaming=payload.response_mode == "streaming",
)
diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py
index b242fd2c3e..95679e6fcb 100644
--- a/api/controllers/service_api/dataset/segment.py
+++ b/api/controllers/service_api/dataset/segment.py
@@ -60,6 +60,7 @@ register_schema_models(
service_api_ns,
SegmentCreatePayload,
SegmentListQuery,
+ SegmentUpdateArgs,
SegmentUpdatePayload,
ChildChunkCreatePayload,
ChildChunkListQuery,
diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py
index db3b93a4dc..62ea532eac 100644
--- a/api/controllers/web/app.py
+++ b/api/controllers/web/app.py
@@ -1,7 +1,7 @@
import logging
from flask import request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource
from pydantic import BaseModel, ConfigDict, Field
from werkzeug.exceptions import Unauthorized
@@ -50,7 +50,6 @@ class AppParameterApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(fields.parameters_fields)
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
@@ -69,7 +68,8 @@ class AppParameterApi(WebApiResource):
user_input_form = features_dict.get("user_input_form", [])
- return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+ parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+ return fields.Parameters.model_validate(parameters).model_dump(mode="json")
@web_ns.route("/meta")
diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py
index 86e19423e5..e76649495a 100644
--- a/api/controllers/web/conversation.py
+++ b/api/controllers/web/conversation.py
@@ -1,14 +1,21 @@
-from flask_restx import fields, marshal_with, reqparse
-from flask_restx.inputs import int_range
+from typing import Literal
+
+from flask import request
+from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
+from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
-from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
+from fields.conversation_fields import (
+ ConversationInfiniteScrollPagination,
+ ResultResponse,
+ SimpleConversation,
+)
from libs.helper import uuid_value
from models.model import AppMode
from services.conversation_service import ConversationService
@@ -16,6 +23,35 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers
from services.web_conversation_service import WebConversationService
+class ConversationListQuery(BaseModel):
+ last_id: str | None = None
+ limit: int = Field(default=20, ge=1, le=100)
+ pinned: bool | None = None
+ sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = "-updated_at"
+
+ @field_validator("last_id")
+ @classmethod
+ def validate_last_id(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return uuid_value(value)
+
+
+class ConversationRenamePayload(BaseModel):
+ name: str | None = None
+ auto_generate: bool = False
+
+ @model_validator(mode="after")
+ def validate_name_requirement(self):
+ if not self.auto_generate:
+ if self.name is None or not self.name.strip():
+ raise ValueError("name is required when auto_generate is false")
+ return self
+
+
+register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)
+
+
@web_ns.route("/conversations")
class ConversationListApi(WebApiResource):
@web_ns.doc("Get Conversation List")
@@ -54,54 +90,39 @@ class ConversationListApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
- parser = (
- reqparse.RequestParser()
- .add_argument("last_id", type=uuid_value, location="args")
- .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
- .add_argument("pinned", type=str, choices=["true", "false", None], location="args")
- .add_argument(
- "sort_by",
- type=str,
- choices=["created_at", "-created_at", "updated_at", "-updated_at"],
- required=False,
- default="-updated_at",
- location="args",
- )
- )
- args = parser.parse_args()
-
- pinned = None
- if "pinned" in args and args["pinned"] is not None:
- pinned = args["pinned"] == "true"
+ raw_args = request.args.to_dict()
+ query = ConversationListQuery.model_validate(raw_args)
try:
with Session(db.engine) as session:
- return WebConversationService.pagination_by_last_id(
+ pagination = WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=end_user,
- last_id=args["last_id"],
- limit=args["limit"],
+ last_id=query.last_id,
+ limit=query.limit,
invoke_from=InvokeFrom.WEB_APP,
- pinned=pinned,
- sort_by=args["sort_by"],
+ pinned=query.pinned,
+ sort_by=query.sort_by,
)
+ adapter = TypeAdapter(SimpleConversation)
+ conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
+ return ConversationInfiniteScrollPagination(
+ limit=pagination.limit,
+ has_more=pagination.has_more,
+ data=conversations,
+ ).model_dump(mode="json")
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@web_ns.route("/conversations/")
class ConversationApi(WebApiResource):
- delete_response_fields = {
- "result": fields.String,
- }
-
@web_ns.doc("Delete Conversation")
@web_ns.doc(description="Delete a specific conversation.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@@ -115,7 +136,6 @@ class ConversationApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(delete_response_fields)
def delete(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -126,7 +146,7 @@ class ConversationApi(WebApiResource):
ConversationService.delete(app_model, conversation_id, end_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
- return {"result": "success"}, 204
+ return ResultResponse(result="success").model_dump(mode="json"), 204
@web_ns.route("/conversations//name")
@@ -155,7 +175,6 @@ class ConversationRenameApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -163,25 +182,23 @@ class ConversationRenameApi(WebApiResource):
conversation_id = str(c_id)
- parser = (
- reqparse.RequestParser()
- .add_argument("name", type=str, required=False, location="json")
- .add_argument("auto_generate", type=bool, required=False, default=False, location="json")
- )
- args = parser.parse_args()
+ payload = ConversationRenamePayload.model_validate(web_ns.payload or {})
try:
- return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
+ conversation = ConversationService.rename(
+ app_model, conversation_id, end_user, payload.name, payload.auto_generate
+ )
+ return (
+ TypeAdapter(SimpleConversation)
+ .validate_python(conversation, from_attributes=True)
+ .model_dump(mode="json")
+ )
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@web_ns.route("/conversations//pin")
class ConversationPinApi(WebApiResource):
- pin_response_fields = {
- "result": fields.String,
- }
-
@web_ns.doc("Pin Conversation")
@web_ns.doc(description="Pin a specific conversation to keep it at the top of the list.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@@ -195,7 +212,6 @@ class ConversationPinApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(pin_response_fields)
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -208,15 +224,11 @@ class ConversationPinApi(WebApiResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
@web_ns.route("/conversations//unpin")
class ConversationUnPinApi(WebApiResource):
- unpin_response_fields = {
- "result": fields.String,
- }
-
@web_ns.doc("Unpin Conversation")
@web_ns.doc(description="Unpin a specific conversation to remove it from the top of the list.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@@ -230,7 +242,6 @@ class ConversationUnPinApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(unpin_response_fields)
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -239,4 +250,4 @@ class ConversationUnPinApi(WebApiResource):
conversation_id = str(c_id)
WebConversationService.unpin(app_model, conversation_id, end_user)
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py
index cce3dae95d..2540bf02f4 100644
--- a/api/controllers/web/feature.py
+++ b/api/controllers/web/feature.py
@@ -17,5 +17,15 @@ class SystemFeatureApi(Resource):
Returns:
dict: System feature configuration object
+
+ This endpoint is akin to the `SystemFeatureApi` endpoint in api/controllers/console/feature.py,
+ except it is intended for use by the web app, instead of the console dashboard.
+
+ NOTE: This endpoint is unauthenticated by design, as it provides system features
+ data required for webapp initialization.
+
+ Authentication would create circular dependency (can't authenticate without webapp loading).
+
+ Only non-sensitive configuration data should be returned by this endpoint.
"""
return FeatureService.get_system_features().model_dump()
diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py
index 80ad61e549..0036c90800 100644
--- a/api/controllers/web/files.py
+++ b/api/controllers/web/files.py
@@ -1,5 +1,4 @@
from flask import request
-from flask_restx import marshal_with
import services
from controllers.common.errors import (
@@ -9,12 +8,15 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
+from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
-from fields.file_fields import build_file_model
+from fields.file_fields import FileResponse
from services.file_service import FileService
+register_schema_models(web_ns, FileResponse)
+
@web_ns.route("/files/upload")
class FileApi(WebApiResource):
@@ -28,7 +30,7 @@ class FileApi(WebApiResource):
415: "Unsupported file type",
}
)
- @marshal_with(build_file_model(web_ns))
+ @web_ns.response(201, "File uploaded successfully", web_ns.models[FileResponse.__name__])
def post(self, app_model, end_user):
"""Upload a file for use in web applications.
@@ -81,4 +83,5 @@ class FileApi(WebApiResource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
- return upload_file, 201
+ response = FileResponse.model_validate(upload_file, from_attributes=True)
+ return response.model_dump(mode="json"), 201
diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py
index 690b76655f..91d206f727 100644
--- a/api/controllers/web/forgot_password.py
+++ b/api/controllers/web/forgot_password.py
@@ -4,7 +4,6 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
-from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
@@ -22,7 +21,7 @@ from controllers.web import web_ns
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
-from models import Account
+from models.account import Account
from services.account_service import AccountService
@@ -70,6 +69,9 @@ class ForgotPasswordSendEmailApi(Resource):
def post(self):
payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
+ request_email = payload.email
+ normalized_email = request_email.lower()
+
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
@@ -80,12 +82,12 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
token = None
if account is None:
raise AuthenticationFailedError()
else:
- token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)
+ token = AccountService.send_reset_password_email(account=account, email=normalized_email, language=language)
return {"result": "success", "data": token}
@@ -104,9 +106,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self):
payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
- user_email = payload.email
+ user_email = payload.email.lower()
- is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
+ is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
@@ -114,11 +116,16 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
- if user_email != token_data.get("email"):
+ token_email = token_data.get("email")
+ if not isinstance(token_email, str):
+ raise InvalidEmailError()
+ normalized_token_email = token_email.lower()
+
+ if user_email != normalized_token_email:
raise InvalidEmailError()
if payload.code != token_data.get("code"):
- AccountService.add_forgot_password_error_rate_limit(payload.email)
+ AccountService.add_forgot_password_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -126,11 +133,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
- user_email, code=payload.code, additional_data={"phase": "reset"}
+ token_email, code=payload.code, additional_data={"phase": "reset"}
)
- AccountService.reset_forgot_password_error_rate_limit(payload.email)
- return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+ AccountService.reset_forgot_password_error_rate_limit(user_email)
+ return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@web_ns.route("/forgot-password/resets")
@@ -174,7 +181,7 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "")
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt, session)
diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py
index 538d0c44be..a824f6d487 100644
--- a/api/controllers/web/login.py
+++ b/api/controllers/web/login.py
@@ -1,19 +1,26 @@
from flask import make_response, request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
from jwt import InvalidTokenError
+from pydantic import BaseModel, Field, field_validator
import services
from configs import dify_config
+from controllers.common.schema import register_schema_models
from controllers.console.auth.error import (
AuthenticationFailedError,
EmailCodeError,
InvalidEmailError,
)
from controllers.console.error import AccountBannedError
-from controllers.console.wraps import only_edition_enterprise, setup_required
+from controllers.console.wraps import (
+ decrypt_code_field,
+ decrypt_password_field,
+ only_edition_enterprise,
+ setup_required,
+)
from controllers.web import web_ns
from controllers.web.wraps import decode_jwt_token
-from libs.helper import email
+from libs.helper import EmailStr
from libs.passport import PassportService
from libs.password import valid_password
from libs.token import (
@@ -25,10 +32,35 @@ from services.app_service import AppService
from services.webapp_auth_service import WebAppAuthService
+class LoginPayload(BaseModel):
+ email: EmailStr
+ password: str
+
+ @field_validator("password")
+ @classmethod
+ def validate_password(cls, value: str) -> str:
+ return valid_password(value)
+
+
+class EmailCodeLoginSendPayload(BaseModel):
+ email: EmailStr
+ language: str | None = None
+
+
+class EmailCodeLoginVerifyPayload(BaseModel):
+ email: EmailStr
+ code: str
+ token: str = Field(min_length=1)
+
+
+register_schema_models(web_ns, LoginPayload, EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload)
+
+
@web_ns.route("/login")
class LoginApi(Resource):
"""Resource for web app email/password login."""
+ @web_ns.expect(web_ns.models[LoginPayload.__name__])
@setup_required
@only_edition_enterprise
@web_ns.doc("web_app_login")
@@ -42,17 +74,13 @@ class LoginApi(Resource):
404: "Account not found",
}
)
+ @decrypt_password_field
def post(self):
"""Authenticate user and login."""
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=email, required=True, location="json")
- .add_argument("password", type=valid_password, required=True, location="json")
- )
- args = parser.parse_args()
+ payload = LoginPayload.model_validate(web_ns.payload or {})
try:
- account = WebAppAuthService.authenticate(args["email"], args["password"])
+ account = WebAppAuthService.authenticate(payload.email, payload.password)
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
@@ -139,6 +167,7 @@ class EmailCodeLoginSendEmailApi(Resource):
@only_edition_enterprise
@web_ns.doc("send_email_code_login")
@web_ns.doc(description="Send email verification code for login")
+ @web_ns.expect(web_ns.models[EmailCodeLoginSendPayload.__name__])
@web_ns.doc(
responses={
200: "Email code sent successfully",
@@ -147,19 +176,14 @@ class EmailCodeLoginSendEmailApi(Resource):
}
)
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=email, required=True, location="json")
- .add_argument("language", type=str, required=False, location="json")
- )
- args = parser.parse_args()
+ payload = EmailCodeLoginSendPayload.model_validate(web_ns.payload or {})
- if args["language"] is not None and args["language"] == "zh-Hans":
+ if payload.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
- account = WebAppAuthService.get_user_through_email(args["email"])
+ account = WebAppAuthService.get_user_through_email(payload.email)
if account is None:
raise AuthenticationFailedError()
else:
@@ -173,6 +197,7 @@ class EmailCodeLoginApi(Resource):
@only_edition_enterprise
@web_ns.doc("verify_email_code_login")
@web_ns.doc(description="Verify email code and complete login")
+ @web_ns.expect(web_ns.models[EmailCodeLoginVerifyPayload.__name__])
@web_ns.doc(
responses={
200: "Email code verified and login successful",
@@ -181,34 +206,33 @@ class EmailCodeLoginApi(Resource):
404: "Account not found",
}
)
+ @decrypt_code_field
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=str, required=True, location="json")
- .add_argument("code", type=str, required=True, location="json")
- .add_argument("token", type=str, required=True, location="json")
- )
- args = parser.parse_args()
+ payload = EmailCodeLoginVerifyPayload.model_validate(web_ns.payload or {})
- user_email = args["email"]
+ user_email = payload.email.lower()
- token_data = WebAppAuthService.get_email_code_login_data(args["token"])
+ token_data = WebAppAuthService.get_email_code_login_data(payload.token)
if token_data is None:
raise InvalidTokenError()
- if token_data["email"] != args["email"]:
+ token_email = token_data.get("email")
+ if not isinstance(token_email, str):
+ raise InvalidEmailError()
+ normalized_token_email = token_email.lower()
+ if normalized_token_email != user_email:
raise InvalidEmailError()
- if token_data["code"] != args["code"]:
+ if token_data["code"] != payload.code:
raise EmailCodeError()
- WebAppAuthService.revoke_email_code_login_token(args["token"])
- account = WebAppAuthService.get_user_through_email(user_email)
+ WebAppAuthService.revoke_email_code_login_token(payload.token)
+ account = WebAppAuthService.get_user_through_email(token_email)
if not account:
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
- AccountService.reset_login_error_rate_limit(args["email"])
+ AccountService.reset_login_error_rate_limit(user_email)
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response
diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py
index a02d226762..80035ba818 100644
--- a/api/controllers/web/message.py
+++ b/api/controllers/web/message.py
@@ -2,8 +2,7 @@ import logging
from typing import Literal
from flask import request
-from flask_restx import fields, marshal_with
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, TypeAdapter, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
@@ -22,11 +21,10 @@ from controllers.web.wraps import WebApiResource
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
-from fields.conversation_fields import message_file_fields
-from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields
-from fields.raws import FilesContainedField
+from fields.conversation_fields import ResultResponse
+from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
from libs import helper
-from libs.helper import TimestampField, uuid_value
+from libs.helper import uuid_value
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
@@ -70,30 +68,6 @@ register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, Message
@web_ns.route("/messages")
class MessageListApi(WebApiResource):
- message_fields = {
- "id": fields.String,
- "conversation_id": fields.String,
- "parent_message_id": fields.String,
- "inputs": FilesContainedField,
- "query": fields.String,
- "answer": fields.String(attribute="re_sign_file_url_answer"),
- "message_files": fields.List(fields.Nested(message_file_fields)),
- "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
- "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
- "extra_contents": fields.List(fields.Raw),
- "created_at": TimestampField,
- "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
- "metadata": fields.Raw(attribute="message_metadata_dict"),
- "status": fields.String,
- "error": fields.String,
- }
-
- message_infinite_scroll_pagination_fields = {
- "limit": fields.Integer,
- "has_more": fields.Boolean,
- "data": fields.List(fields.Nested(message_fields)),
- }
-
@web_ns.doc("Get Message List")
@web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.")
@web_ns.doc(
@@ -122,7 +96,6 @@ class MessageListApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -132,9 +105,16 @@ class MessageListApi(WebApiResource):
query = MessageListQuery.model_validate(raw_args)
try:
- return MessageService.pagination_by_first_id(
+ pagination = MessageService.pagination_by_first_id(
app_model, end_user, query.conversation_id, query.first_id, query.limit
)
+ adapter = TypeAdapter(WebMessageListItem)
+ items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
+ return WebMessageInfiniteScrollPagination(
+ limit=pagination.limit,
+ has_more=pagination.has_more,
+ data=items,
+ ).model_dump(mode="json")
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except FirstMessageNotExistsError:
@@ -143,10 +123,6 @@ class MessageListApi(WebApiResource):
@web_ns.route("/messages//feedbacks")
class MessageFeedbackApi(WebApiResource):
- feedback_response_fields = {
- "result": fields.String,
- }
-
@web_ns.doc("Create Message Feedback")
@web_ns.doc(description="Submit feedback (like/dislike) for a specific message.")
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
@@ -171,7 +147,6 @@ class MessageFeedbackApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(feedback_response_fields)
def post(self, app_model, end_user, message_id):
message_id = str(message_id)
@@ -188,7 +163,7 @@ class MessageFeedbackApi(WebApiResource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
@web_ns.route("/messages//more-like-this")
@@ -248,10 +223,6 @@ class MessageMoreLikeThisApi(WebApiResource):
@web_ns.route("/messages//suggested-questions")
class MessageSuggestedQuestionApi(WebApiResource):
- suggested_questions_response_fields = {
- "data": fields.List(fields.String),
- }
-
@web_ns.doc("Get Suggested Questions")
@web_ns.doc(description="Get suggested follow-up questions after a message (chat apps only).")
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
@@ -265,7 +236,6 @@ class MessageSuggestedQuestionApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(suggested_questions_response_fields)
def get(self, app_model, end_user, message_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -278,7 +248,6 @@ class MessageSuggestedQuestionApi(WebApiResource):
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP
)
# questions is a list of strings, not a list of Message objects
- # so we can directly return it
except MessageNotExistsError:
raise NotFound("Message not found")
except ConversationNotExistsError:
@@ -297,4 +266,4 @@ class MessageSuggestedQuestionApi(WebApiResource):
logger.exception("internal server error.")
raise InternalServerError()
- return {"data": questions}
+ return SuggestedQuestionsResponse(data=questions).model_dump(mode="json")
diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py
index c1f976829f..b08b3fe858 100644
--- a/api/controllers/web/remote_files.py
+++ b/api/controllers/web/remote_files.py
@@ -1,7 +1,6 @@
import urllib.parse
import httpx
-from flask_restx import marshal_with
from pydantic import BaseModel, Field, HttpUrl
import services
@@ -14,7 +13,7 @@ from controllers.common.errors import (
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
-from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
+from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from services.file_service import FileService
from ..common.schema import register_schema_models
@@ -26,7 +25,7 @@ class RemoteFileUploadPayload(BaseModel):
url: HttpUrl = Field(description="Remote file URL")
-register_schema_models(web_ns, RemoteFileUploadPayload)
+register_schema_models(web_ns, RemoteFileUploadPayload, RemoteFileInfo, FileWithSignedUrl)
@web_ns.route("/remote-files/")
@@ -41,7 +40,7 @@ class RemoteFileInfoApi(WebApiResource):
500: "Failed to fetch remote file",
}
)
- @marshal_with(build_remote_file_info_model(web_ns))
+ @web_ns.response(200, "Remote file info", web_ns.models[RemoteFileInfo.__name__])
def get(self, app_model, end_user, url):
"""Get information about a remote file.
@@ -65,10 +64,11 @@ class RemoteFileInfoApi(WebApiResource):
# failed back to get method
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
- return {
- "file_type": resp.headers.get("Content-Type", "application/octet-stream"),
- "file_length": int(resp.headers.get("Content-Length", -1)),
- }
+ info = RemoteFileInfo(
+ file_type=resp.headers.get("Content-Type", "application/octet-stream"),
+ file_length=int(resp.headers.get("Content-Length", -1)),
+ )
+ return info.model_dump(mode="json")
@web_ns.route("/remote-files/upload")
@@ -84,7 +84,7 @@ class RemoteFileUploadApi(WebApiResource):
500: "Failed to fetch remote file",
}
)
- @marshal_with(build_file_with_signed_url_model(web_ns))
+ @web_ns.response(201, "Remote file uploaded", web_ns.models[FileWithSignedUrl.__name__])
def post(self, app_model, end_user):
"""Upload a file from a remote URL.
@@ -139,13 +139,14 @@ class RemoteFileUploadApi(WebApiResource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError
- return {
- "id": upload_file.id,
- "name": upload_file.name,
- "size": upload_file.size,
- "extension": upload_file.extension,
- "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
- "mime_type": upload_file.mime_type,
- "created_by": upload_file.created_by,
- "created_at": upload_file.created_at,
- }, 201
+ payload1 = FileWithSignedUrl(
+ id=upload_file.id,
+ name=upload_file.name,
+ size=upload_file.size,
+ extension=upload_file.extension,
+ url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
+ mime_type=upload_file.mime_type,
+ created_by=upload_file.created_by,
+ created_at=int(upload_file.created_at.timestamp()),
+ )
+ return payload1.model_dump(mode="json"), 201
diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py
index 865f3610a7..29993100f6 100644
--- a/api/controllers/web/saved_message.py
+++ b/api/controllers/web/saved_message.py
@@ -1,40 +1,32 @@
-from flask_restx import fields, marshal_with, reqparse
-from flask_restx.inputs import int_range
+from flask import request
+from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import NotFound
+from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource
-from fields.conversation_fields import message_file_fields
-from libs.helper import TimestampField, uuid_value
+from fields.conversation_fields import ResultResponse
+from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
+from libs.helper import UUIDStrOrEmpty
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
-feedback_fields = {"rating": fields.String}
-message_fields = {
- "id": fields.String,
- "inputs": fields.Raw,
- "query": fields.String,
- "answer": fields.String,
- "message_files": fields.List(fields.Nested(message_file_fields)),
- "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
- "created_at": TimestampField,
-}
+class SavedMessageListQuery(BaseModel):
+ last_id: UUIDStrOrEmpty | None = None
+ limit: int = Field(default=20, ge=1, le=100)
+
+
+class SavedMessageCreatePayload(BaseModel):
+ message_id: UUIDStrOrEmpty
+
+
+register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)
@web_ns.route("/saved-messages")
class SavedMessageListApi(WebApiResource):
- saved_message_infinite_scroll_pagination_fields = {
- "limit": fields.Integer,
- "has_more": fields.Boolean,
- "data": fields.List(fields.Nested(message_fields)),
- }
-
- post_response_fields = {
- "result": fields.String,
- }
-
@web_ns.doc("Get Saved Messages")
@web_ns.doc(description="Retrieve paginated list of saved messages for a completion application.")
@web_ns.doc(
@@ -58,19 +50,21 @@ class SavedMessageListApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(saved_message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
if app_model.mode != "completion":
raise NotCompletionAppError()
- parser = (
- reqparse.RequestParser()
- .add_argument("last_id", type=uuid_value, location="args")
- .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
- )
- args = parser.parse_args()
+ raw_args = request.args.to_dict()
+ query = SavedMessageListQuery.model_validate(raw_args)
- return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
+ pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit)
+ adapter = TypeAdapter(SavedMessageItem)
+ items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
+ return SavedMessageInfiniteScrollPagination(
+ limit=pagination.limit,
+ has_more=pagination.has_more,
+ data=items,
+ ).model_dump(mode="json")
@web_ns.doc("Save Message")
@web_ns.doc(description="Save a specific message for later reference.")
@@ -89,28 +83,22 @@ class SavedMessageListApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(post_response_fields)
def post(self, app_model, end_user):
if app_model.mode != "completion":
raise NotCompletionAppError()
- parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
- args = parser.parse_args()
+ payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {})
try:
- SavedMessageService.save(app_model, end_user, args["message_id"])
+ SavedMessageService.save(app_model, end_user, payload.message_id)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
@web_ns.route("/saved-messages/")
class SavedMessageApi(WebApiResource):
- delete_response_fields = {
- "result": fields.String,
- }
-
@web_ns.doc("Delete Saved Message")
@web_ns.doc(description="Remove a message from saved messages.")
@web_ns.doc(params={"message_id": {"description": "Message UUID to delete", "type": "string", "required": True}})
@@ -124,7 +112,6 @@ class SavedMessageApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(delete_response_fields)
def delete(self, app_model, end_user, message_id):
message_id = str(message_id)
@@ -133,4 +120,4 @@ class SavedMessageApi(WebApiResource):
SavedMessageService.delete(app_model, end_user, message_id)
- return {"result": "success"}, 204
+ return ResultResponse(result="success").model_dump(mode="json"), 204
diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py
index 3cbb07a296..95d8c6d5a5 100644
--- a/api/controllers/web/workflow.py
+++ b/api/controllers/web/workflow.py
@@ -1,8 +1,10 @@
import logging
+from typing import Any
-from flask_restx import reqparse
+from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
+from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
CompletionRequestError,
@@ -27,19 +29,22 @@ from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
+
+class WorkflowRunPayload(BaseModel):
+ inputs: dict[str, Any] = Field(description="Input variables for the workflow")
+ files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow")
+
+
logger = logging.getLogger(__name__)
+register_schema_models(web_ns, WorkflowRunPayload)
+
@web_ns.route("/workflows/run")
class WorkflowRunApi(WebApiResource):
@web_ns.doc("Run Workflow")
@web_ns.doc(description="Execute a workflow with provided inputs and files.")
- @web_ns.doc(
- params={
- "inputs": {"description": "Input variables for the workflow", "type": "object", "required": True},
- "files": {"description": "Files to be processed by the workflow", "type": "array", "required": False},
- }
- )
+ @web_ns.expect(web_ns.models[WorkflowRunPayload.__name__])
@web_ns.doc(
responses={
200: "Success",
@@ -58,12 +63,8 @@ class WorkflowRunApi(WebApiResource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
- parser = (
- reqparse.RequestParser()
- .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
- .add_argument("files", type=list, required=False, location="json")
- )
- args = parser.parse_args()
+ payload = WorkflowRunPayload.model_validate(web_ns.payload or {})
+ args = payload.model_dump(exclude_none=True)
try:
response = AppGenerateService.generate(
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index c196dbbdf1..3c6d36afe4 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -1,6 +1,7 @@
import json
import logging
import uuid
+from decimal import Decimal
from typing import Union, cast
from sqlalchemy import select
@@ -41,6 +42,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
+from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile
logger = logging.getLogger(__name__)
@@ -289,6 +291,7 @@ class BaseAgentRunner(AppRunner):
thought = MessageAgentThought(
message_id=message_id,
message_chain_id=None,
+ tool_process_data=None,
thought="",
tool=tool_name,
tool_labels_str="{}",
@@ -296,20 +299,20 @@ class BaseAgentRunner(AppRunner):
tool_input=tool_input,
message=message,
message_token=0,
- message_unit_price=0,
- message_price_unit=0,
+ message_unit_price=Decimal(0),
+ message_price_unit=Decimal("0.001"),
message_files=json.dumps(messages_ids) if messages_ids else "",
answer="",
observation="",
answer_token=0,
- answer_unit_price=0,
- answer_price_unit=0,
+ answer_unit_price=Decimal(0),
+ answer_price_unit=Decimal("0.001"),
tokens=0,
- total_price=0,
+ total_price=Decimal(0),
position=self.agent_thought_count + 1,
currency="USD",
latency=0,
- created_by_role="account",
+ created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
)
@@ -342,7 +345,8 @@ class BaseAgentRunner(AppRunner):
raise ValueError("agent thought not found")
if thought:
- agent_thought.thought += thought
+ existing_thought = agent_thought.thought or ""
+ agent_thought.thought = f"{existing_thought}{thought}"
if tool_name:
agent_thought.tool = tool_name
@@ -440,21 +444,30 @@ class BaseAgentRunner(AppRunner):
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
- tools = agent_thought.tool
- if tools:
- tools = tools.split(";")
+ tool_names_raw = agent_thought.tool
+ if tool_names_raw:
+ tool_names = tool_names_raw.split(";")
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
- try:
- tool_inputs = json.loads(agent_thought.tool_input)
- except Exception:
- tool_inputs = {tool: {} for tool in tools}
- try:
- tool_responses = json.loads(agent_thought.observation)
- except Exception:
- tool_responses = dict.fromkeys(tools, agent_thought.observation)
+ tool_input_payload = agent_thought.tool_input
+ if tool_input_payload:
+ try:
+ tool_inputs = json.loads(tool_input_payload)
+ except Exception:
+ tool_inputs = {tool: {} for tool in tool_names}
+ else:
+ tool_inputs = {tool: {} for tool in tool_names}
- for tool in tools:
+ observation_payload = agent_thought.observation
+ if observation_payload:
+ try:
+ tool_responses = json.loads(observation_payload)
+ except Exception:
+ tool_responses = dict.fromkeys(tool_names, observation_payload)
+ else:
+ tool_responses = dict.fromkeys(tool_names, observation_payload)
+
+ for tool in tool_names:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(
@@ -484,7 +497,7 @@ class BaseAgentRunner(AppRunner):
*tool_call_response,
]
)
- if not tools:
+ if not tool_names_raw:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:
diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index b32e35d0ca..a55f2d0f5f 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -22,6 +22,7 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
+from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)
@@ -165,6 +166,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad)
+ # Check if max iteration is reached and model still wants to call tools
+ if iteration_step == max_iteration_steps and scratchpad.action:
+ if scratchpad.action.action_name.lower() != "final answer":
+ raise AgentMaxIterationError(app_config.agent.max_iteration)
+
# get llm usage
if "usage" in usage_dict:
if usage_dict["usage"] is not None:
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index dcc1326b33..7c5c9136a7 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -25,6 +25,7 @@ from core.model_runtime.entities.message_entities import ImagePromptMessageConte
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
+from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)
@@ -187,7 +188,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
),
)
- assistant_message = AssistantPromptMessage(content="", tool_calls=[])
+ assistant_message = AssistantPromptMessage(content=response, tool_calls=[])
if tool_calls:
assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall(
@@ -199,8 +200,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
)
for tool_call in tool_calls
]
- else:
- assistant_message.content = response
self._current_thoughts.append(assistant_message)
@@ -222,6 +221,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
final_answer += response + "\n"
+ # Check if max iteration is reached and model still wants to call tools
+ if iteration_step == max_iteration_steps and tool_calls:
+ raise AgentMaxIterationError(app_config.agent.max_iteration)
+
# call tools
tool_responses = []
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py
index 307af3747c..13c51529cc 100644
--- a/api/core/app/app_config/entities.py
+++ b/api/core/app/app_config/entities.py
@@ -1,4 +1,3 @@
-import json
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
@@ -121,7 +120,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
- json_schema: str | None = Field(default=None)
+ json_schema: dict | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
@@ -135,17 +134,11 @@ class VariableEntity(BaseModel):
@field_validator("json_schema")
@classmethod
- def validate_json_schema(cls, schema: str | None) -> str | None:
+ def validate_json_schema(cls, schema: dict | None) -> dict | None:
if schema is None:
return None
-
try:
- json_schema = json.loads(schema)
- except json.JSONDecodeError:
- raise ValueError(f"invalid json_schema value {schema}")
-
- try:
- Draft7Validator.check_schema(json_schema)
+ Draft7Validator.check_schema(schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema
diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py
index e4b308a6f6..c21c494efe 100644
--- a/api/core/app/apps/advanced_chat/app_config_manager.py
+++ b/api/core/app/apps/advanced_chat/app_config_manager.py
@@ -26,7 +26,6 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
-
app_mode = AppMode.value_of(app_model.mode)
app_config = AdvancedChatAppConfig(
tenant_id=app_model.tenant_id,
diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index fd913b807d..9249e3cc70 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
import contextvars
import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
-from typing import Any, Literal, TypeVar, Union, overload
+from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -13,6 +15,9 @@ from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
from constants import UUID_NIL
+
+if TYPE_CHECKING:
+ from controllers.console.app.workflow import LoopNodeRunPayload
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
@@ -347,7 +352,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow: Workflow,
node_id: str,
user: Account | EndUser,
- args: Mapping,
+ args: LoopNodeRunPayload,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
@@ -363,7 +368,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
if not node_id:
raise ValueError("node_id is required")
- if args.get("inputs") is None:
+ if args.inputs is None:
raise ValueError("inputs is required")
# convert to app config
@@ -381,7 +386,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
- single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
+ single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index f06a6f9e9b..f8cd68a9bf 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -20,13 +20,15 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
+from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
+from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
+from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
-from core.variables.variables import VariableUnion
+from core.variables.variables import Variable
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
-from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
@@ -37,9 +39,9 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.otel import WorkflowAppRunnerHandler, trace_span
from models import Workflow
-from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable
+from services.conversation_variable_updater import ConversationVariableUpdater
logger = logging.getLogger(__name__)
@@ -105,6 +107,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not app_record:
raise ValueError("App not found")
+ invoke_from = self.application_generate_entity.invoke_from
+ if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
+ invoke_from = InvokeFrom.DEBUGGER
+ user_from = self._resolve_user_from(invoke_from)
+
resume_state = self._resume_graph_runtime_state
if resume_state is not None:
@@ -156,8 +163,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=self._workflow.environment_variables,
- # Based on the definition of `VariableUnion`,
- # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
+ # Based on the definition of `Variable`,
+ # `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=conversation_variables,
)
@@ -169,6 +176,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
+ user_from=user_from,
+ invoke_from=invoke_from,
)
db.session.close()
@@ -186,12 +195,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
- user_from=(
- UserFrom.ACCOUNT
- if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
- else UserFrom.END_USER
- ),
- invoke_from=self.application_generate_entity.invoke_from,
+ user_from=user_from,
+ invoke_from=invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
@@ -214,6 +219,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
workflow_entry.graph_engine.layer(persistence_layer)
+ conversation_variable_layer = ConversationVariablePersistenceLayer(
+ ConversationVariableUpdater(session_factory.get_session_maker())
+ )
+ workflow_entry.graph_engine.layer(conversation_variable_layer)
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)
@@ -323,7 +332,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
trace_manager=app_generate_entity.trace_manager,
)
- def _initialize_conversation_variables(self) -> list[VariableUnion]:
+ def _initialize_conversation_variables(self) -> list[Variable]:
"""
Initialize conversation variables for the current conversation.
@@ -348,7 +357,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation_variables = [var.to_variable() for var in existing_variables]
session.commit()
- return cast(list[VariableUnion], conversation_variables)
+ return cast(list[Variable], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
"""
diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py
index 2760466a3b..8b6b8f227b 100644
--- a/api/core/app/apps/agent_chat/app_runner.py
+++ b/api/core/app/apps/agent_chat/app_runner.py
@@ -236,4 +236,7 @@ class AgentChatAppRunner(AppRunner):
queue_manager=queue_manager,
stream=application_generate_entity.stream,
agent=True,
+ message_id=message.id,
+ user_id=application_generate_entity.user_id,
+ tenant_id=app_config.tenant_id,
)
diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py
index a6aace168e..07bae66867 100644
--- a/api/core/app/apps/base_app_generator.py
+++ b/api/core/app/apps/base_app_generator.py
@@ -1,4 +1,3 @@
-import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, final
@@ -76,12 +75,24 @@ class BaseAppGenerator:
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
# Check if all files are converted to File
- if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
- raise ValueError("Invalid input type")
- if any(
- filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
- ):
- raise ValueError("Invalid input type")
+ invalid_dict_keys = [
+ k
+ for k, v in user_inputs.items()
+ if isinstance(v, dict)
+ and entity_dictionary[k].type not in {VariableEntityType.FILE, VariableEntityType.JSON_OBJECT}
+ ]
+ if invalid_dict_keys:
+ raise ValueError(f"Invalid input type for {invalid_dict_keys}")
+
+ invalid_list_dict_keys = [
+ k
+ for k, v in user_inputs.items()
+ if isinstance(v, list)
+ and any(isinstance(item, dict) for item in v)
+ and entity_dictionary[k].type != VariableEntityType.FILE_LIST
+ ]
+ if invalid_list_dict_keys:
+ raise ValueError(f"Invalid input type for {invalid_list_dict_keys}")
return user_inputs
@@ -178,12 +189,8 @@ class BaseAppGenerator:
elif value == 0:
value = False
case VariableEntityType.JSON_OBJECT:
- if not isinstance(value, str):
- raise ValueError(f"{variable_entity.variable} in input form must be a string")
- try:
- json.loads(value)
- except json.JSONDecodeError:
- raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
+ if value and not isinstance(value, dict):
+ raise ValueError(f"{variable_entity.variable} in input form must be a dict")
case _:
raise AssertionError("this statement should be unreachable.")
diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py
index 698eee9894..b41bedbea4 100644
--- a/api/core/app/apps/base_app_queue_manager.py
+++ b/api/core/app/apps/base_app_queue_manager.py
@@ -90,6 +90,7 @@ class AppQueueManager:
"""
self._clear_task_belong_cache()
self._q.put(None)
+ self._graph_runtime_state = None # Release reference to allow GC to reclaim memory
def _clear_task_belong_cache(self) -> None:
"""
diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py
index e2e6c11480..617515945b 100644
--- a/api/core/app/apps/base_app_runner.py
+++ b/api/core/app/apps/base_app_runner.py
@@ -1,6 +1,8 @@
+import base64
import logging
import time
from collections.abc import Generator, Mapping, Sequence
+from mimetypes import guess_extension
from typing import TYPE_CHECKING, Any, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
@@ -11,10 +13,16 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
ModelConfigWithCredentialsEntity,
)
-from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
+from core.app.entities.queue_entities import (
+ QueueAgentMessageEvent,
+ QueueLLMChunkEvent,
+ QueueMessageEndEvent,
+ QueueMessageFileEvent,
+)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch
+from core.file.enums import FileTransferMethod, FileType
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -22,6 +30,7 @@ from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
+ TextPromptMessageContent,
)
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
@@ -29,7 +38,10 @@ from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
-from models.model import App, AppMode, Message, MessageAnnotation
+from core.tools.tool_file_manager import ToolFileManager
+from extensions.ext_database import db
+from models.enums import CreatorUserRole
+from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
if TYPE_CHECKING:
from core.file.models import File
@@ -203,6 +215,9 @@ class AppRunner:
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
+ message_id: str | None = None,
+ user_id: str | None = None,
+ tenant_id: str | None = None,
):
"""
Handle invoke result
@@ -210,21 +225,41 @@ class AppRunner:
:param queue_manager: application queue manager
:param stream: stream
:param agent: agent
+ :param message_id: message id for multimodal output
+ :param user_id: user id for multimodal output
+ :param tenant_id: tenant id for multimodal output
:return:
"""
if not stream and isinstance(invoke_result, LLMResult):
- self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
+ self._handle_invoke_result_direct(
+ invoke_result=invoke_result,
+ queue_manager=queue_manager,
+ )
elif stream and isinstance(invoke_result, Generator):
- self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
+ self._handle_invoke_result_stream(
+ invoke_result=invoke_result,
+ queue_manager=queue_manager,
+ agent=agent,
+ message_id=message_id,
+ user_id=user_id,
+ tenant_id=tenant_id,
+ )
else:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
- def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool):
+ def _handle_invoke_result_direct(
+ self,
+ invoke_result: LLMResult,
+ queue_manager: AppQueueManager,
+ ):
"""
Handle invoke result direct
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param agent: agent
+ :param message_id: message id for multimodal output
+ :param user_id: user id for multimodal output
+ :param tenant_id: tenant id for multimodal output
:return:
"""
queue_manager.publish(
@@ -235,13 +270,22 @@ class AppRunner:
)
def _handle_invoke_result_stream(
- self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
+ self,
+ invoke_result: Generator[LLMResultChunk, None, None],
+ queue_manager: AppQueueManager,
+ agent: bool,
+ message_id: str | None = None,
+ user_id: str | None = None,
+ tenant_id: str | None = None,
):
"""
Handle invoke result
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param agent: agent
+ :param message_id: message id for multimodal output
+ :param user_id: user id for multimodal output
+ :param tenant_id: tenant id for multimodal output
:return:
"""
model: str = ""
@@ -259,12 +303,26 @@ class AppRunner:
text += message.content
elif isinstance(message.content, list):
for content in message.content:
- if not isinstance(content, str):
- # TODO(QuantumGhost): Add multimodal output support for easy ui.
- _logger.warning("received multimodal output, type=%s", type(content))
+ if isinstance(content, str):
+ text += content
+ elif isinstance(content, TextPromptMessageContent):
text += content.data
+ elif isinstance(content, ImagePromptMessageContent):
+ if message_id and user_id and tenant_id:
+ try:
+ self._handle_multimodal_image_content(
+ content=content,
+ message_id=message_id,
+ user_id=user_id,
+ tenant_id=tenant_id,
+ queue_manager=queue_manager,
+ )
+ except Exception:
+ _logger.exception("Failed to handle multimodal image output")
+ else:
+ _logger.warning("Received multimodal output but missing required parameters")
else:
- text += content # failback to str
+ text += content.data if hasattr(content, "data") else str(content)
if not model:
model = result.model
@@ -289,6 +347,101 @@ class AppRunner:
PublishFrom.APPLICATION_MANAGER,
)
+ def _handle_multimodal_image_content(
+ self,
+ content: ImagePromptMessageContent,
+ message_id: str,
+ user_id: str,
+ tenant_id: str,
+ queue_manager: AppQueueManager,
+ ):
+ """
+ Handle multimodal image content from LLM response.
+ Save the image and create a MessageFile record.
+
+ :param content: ImagePromptMessageContent instance
+ :param message_id: message id
+ :param user_id: user id
+ :param tenant_id: tenant id
+ :param queue_manager: queue manager
+ :return:
+ """
+ _logger.info("Handling multimodal image content for message %s", message_id)
+
+ image_url = content.url
+ base64_data = content.base64_data
+
+ _logger.info("Image URL: %s, Base64 data present: %s", image_url, base64_data)
+
+ if not image_url and not base64_data:
+ _logger.warning("Image content has neither URL nor base64 data")
+ return
+
+ tool_file_manager = ToolFileManager()
+
+ # Save the image file
+ try:
+ if image_url:
+ # Download image from URL
+ _logger.info("Downloading image from URL: %s", image_url)
+ tool_file = tool_file_manager.create_file_by_url(
+ user_id=user_id,
+ tenant_id=tenant_id,
+ file_url=image_url,
+ conversation_id=None,
+ )
+ _logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
+ elif base64_data:
+ if base64_data.startswith("data:"):
+ base64_data = base64_data.split(",", 1)[1]
+
+ image_binary = base64.b64decode(base64_data)
+ mimetype = content.mime_type or "image/png"
+ extension = guess_extension(mimetype) or ".png"
+
+ tool_file = tool_file_manager.create_file_by_raw(
+ user_id=user_id,
+ tenant_id=tenant_id,
+ conversation_id=None,
+ file_binary=image_binary,
+ mimetype=mimetype,
+ filename=f"generated_image{extension}",
+ )
+ _logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
+ else:
+ return
+ except Exception:
+ _logger.exception("Failed to save image file")
+ return
+
+ # Create MessageFile record
+ message_file = MessageFile(
+ message_id=message_id,
+ type=FileType.IMAGE,
+ transfer_method=FileTransferMethod.TOOL_FILE,
+ belongs_to="assistant",
+ url=f"/files/tools/{tool_file.id}",
+ upload_file_id=tool_file.id,
+ created_by_role=(
+ CreatorUserRole.ACCOUNT
+ if queue_manager.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}
+ else CreatorUserRole.END_USER
+ ),
+ created_by=user_id,
+ )
+
+ db.session.add(message_file)
+ db.session.commit()
+ db.session.refresh(message_file)
+
+ # Publish QueueMessageFileEvent
+ queue_manager.publish(
+ QueueMessageFileEvent(message_file_id=message_file.id),
+ PublishFrom.APPLICATION_MANAGER,
+ )
+
+ _logger.info("QueueMessageFileEvent published for message_file_id: %s", message_file.id)
+
def moderation_for_inputs(
self,
*,
diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py
index f8338b226b..7d1a4c619f 100644
--- a/api/core/app/apps/chat/app_runner.py
+++ b/api/core/app/apps/chat/app_runner.py
@@ -226,5 +226,10 @@ class ChatAppRunner(AppRunner):
# handle invoke result
self._handle_invoke_result(
- invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
+ invoke_result=invoke_result,
+ queue_manager=queue_manager,
+ stream=application_generate_entity.stream,
+ message_id=message.id,
+ user_id=application_generate_entity.user_id,
+ tenant_id=app_config.tenant_id,
)
diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py
index ddfb5725b4..a872c2e1f7 100644
--- a/api/core/app/apps/completion/app_runner.py
+++ b/api/core/app/apps/completion/app_runner.py
@@ -184,5 +184,10 @@ class CompletionAppRunner(AppRunner):
# handle invoke result
self._handle_invoke_result(
- invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
+ invoke_result=invoke_result,
+ queue_manager=queue_manager,
+ stream=application_generate_entity.stream,
+ message_id=message.id,
+ user_id=application_generate_entity.user_id,
+ tenant_id=app_config.tenant_id,
)
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py
index d67c4846aa..2959771940 100644
--- a/api/core/app/apps/message_based_app_generator.py
+++ b/api/core/app/apps/message_based_app_generator.py
@@ -1,3 +1,4 @@
+import json
import logging
import uuid
from collections.abc import Callable, Generator, Mapping
diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py
index 13eb40fd60..ea4441b5d8 100644
--- a/api/core/app/apps/pipeline/pipeline_generator.py
+++ b/api/core/app/apps/pipeline/pipeline_generator.py
@@ -130,7 +130,7 @@ class PipelineGenerator(BaseAppGenerator):
pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
)
documents: list[Document] = []
- if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
+ if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry and not args.get("original_document_id"):
from services.dataset_service import DocumentService
for datasource_info in datasource_info_list:
@@ -156,7 +156,7 @@ class PipelineGenerator(BaseAppGenerator):
for i, datasource_info in enumerate(datasource_info_list):
workflow_run_id = str(uuid.uuid4())
document_id = args.get("original_document_id") or None
- if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
+ if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry:
document_id = document_id or documents[i].id
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document_id,
diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py
index 4be9e01fbf..8ea34344b2 100644
--- a/api/core/app/apps/pipeline/pipeline_runner.py
+++ b/api/core/app/apps/pipeline/pipeline_runner.py
@@ -9,13 +9,13 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
RagPipelineGenerateEntity,
)
+from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
+from core.app.workflow.node_factory import DifyNodeFactory
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.enums import WorkflowType
from core.workflow.graph import Graph
-from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
-from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
@@ -73,9 +73,15 @@ class PipelineRunner(WorkflowBasedAppRunner):
"""
app_config = self.application_generate_entity.app_config
app_config = cast(PipelineConfig, app_config)
+ invoke_from = self.application_generate_entity.invoke_from
+
+ if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
+ invoke_from = InvokeFrom.DEBUGGER
+
+ user_from = self._resolve_user_from(invoke_from)
user_id = None
- if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
+ if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
@@ -117,7 +123,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
dataset_id=self.application_generate_entity.dataset_id,
datasource_type=self.application_generate_entity.datasource_type,
datasource_info=self.application_generate_entity.datasource_info,
- invoke_from=self.application_generate_entity.invoke_from.value,
+ invoke_from=invoke_from.value,
)
rag_pipeline_variables = []
@@ -149,6 +155,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
graph_runtime_state=graph_runtime_state,
start_node_id=self.application_generate_entity.start_node_id,
workflow=workflow,
+ user_from=user_from,
+ invoke_from=invoke_from,
)
# RUN WORKFLOW
@@ -159,12 +167,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
- user_from=(
- UserFrom.ACCOUNT
- if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
- else UserFrom.END_USER
- ),
- invoke_from=self.application_generate_entity.invoke_from,
+ user_from=user_from,
+ invoke_from=invoke_from,
call_depth=self.application_generate_entity.call_depth,
graph_runtime_state=graph_runtime_state,
variable_pool=variable_pool,
@@ -210,7 +214,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
return workflow
def _init_rag_pipeline_graph(
- self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None
+ self,
+ workflow: Workflow,
+ graph_runtime_state: GraphRuntimeState,
+ start_node_id: str | None = None,
+ user_from: UserFrom = UserFrom.ACCOUNT,
+ invoke_from: InvokeFrom = InvokeFrom.SERVICE_API,
) -> Graph:
"""
Init pipeline graph
@@ -253,8 +262,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
workflow_id=workflow.id,
graph_config=graph_config,
user_id=self.application_generate_entity.user_id,
- user_from=UserFrom.ACCOUNT,
- invoke_from=InvokeFrom.SERVICE_API,
+ user_from=user_from,
+ invoke_from=invoke_from,
call_depth=0,
)
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index 05ba53149b..dc5852d552 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -1,14 +1,16 @@
+from __future__ import annotations
+
import contextvars
import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
-from typing import Any, Literal, Union, overload
+from typing import TYPE_CHECKING, Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
-from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.orm import sessionmaker
import contexts
from configs import dify_config
@@ -24,6 +26,7 @@ from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTas
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
+from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
@@ -43,6 +46,9 @@ from models.model import App, EndUser
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
+if TYPE_CHECKING:
+ from controllers.console.app.workflow import LoopNodeRunPayload
+
SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
logger = logging.getLogger(__name__)
@@ -434,7 +440,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow: Workflow,
node_id: str,
user: Account | EndUser,
- args: Mapping[str, Any],
+ args: LoopNodeRunPayload,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
@@ -450,7 +456,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
if not node_id:
raise ValueError("node_id is required")
- if args.get("inputs") is None:
+ if args.inputs is None:
raise ValueError("inputs is required")
# convert to app config
@@ -466,7 +472,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
- single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
+ single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs or {}),
workflow_execution_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
@@ -532,7 +538,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
- with Session(db.engine, expire_on_commit=False) as session:
+ with session_factory.create_session() as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py
index 44e63c7c4d..077b321104 100644
--- a/api/core/app/apps/workflow/app_runner.py
+++ b/api/core/app/apps/workflow/app_runner.py
@@ -7,10 +7,10 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
+from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
-from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
@@ -20,7 +20,6 @@ from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_redis import redis_client
from extensions.otel import WorkflowAppRunnerHandler, trace_span
from libs.datetime_utils import naive_utc_now
-from models.enums import UserFrom
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@@ -66,6 +65,11 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
"""
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
+ invoke_from = self.application_generate_entity.invoke_from
+ # if only single iteration or single loop run is requested
+ if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
+ invoke_from = InvokeFrom.DEBUGGER
+ user_from = self._resolve_user_from(invoke_from)
resume_state = self._resume_graph_runtime_state
@@ -109,6 +113,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
+ user_from=user_from,
+ invoke_from=invoke_from,
root_node_id=self._root_node_id,
)
@@ -127,12 +133,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
- user_from=(
- UserFrom.ACCOUNT
- if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
- else UserFrom.END_USER
- ),
- invoke_from=self.application_generate_entity.invoke_from,
+ user_from=user_from,
+ invoke_from=invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py
index b09385aad3..c9d7464c17 100644
--- a/api/core/app/apps/workflow_app_runner.py
+++ b/api/core/app/apps/workflow_app_runner.py
@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
@@ -28,6 +29,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
+from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.graph import Graph
@@ -60,7 +62,6 @@ from core.workflow.graph_events import (
)
from core.workflow.graph_events.graph import GraphRunAbortedEvent
from core.workflow.nodes import NodeType
-from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
@@ -87,10 +88,18 @@ class WorkflowBasedAppRunner:
self._app_id = app_id
self._graph_engine_layers = graph_engine_layers
+ @staticmethod
+ def _resolve_user_from(invoke_from: InvokeFrom) -> UserFrom:
+ if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}:
+ return UserFrom.ACCOUNT
+ return UserFrom.END_USER
+
def _init_graph(
self,
graph_config: Mapping[str, Any],
graph_runtime_state: GraphRuntimeState,
+ user_from: UserFrom,
+ invoke_from: InvokeFrom,
workflow_id: str = "",
tenant_id: str = "",
user_id: str = "",
@@ -115,8 +124,8 @@ class WorkflowBasedAppRunner:
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
- user_from=UserFrom.ACCOUNT,
- invoke_from=self._queue_manager.invoke_from,
+ user_from=user_from,
+ invoke_from=invoke_from,
call_depth=0,
)
@@ -159,7 +168,7 @@ class WorkflowBasedAppRunner:
# Create initial runtime state with variable pool containing environment variables
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
environment_variables=workflow.environment_variables,
),
@@ -168,18 +177,22 @@ class WorkflowBasedAppRunner:
# Determine which type of single node execution and get graph/variable_pool
if single_iteration_run:
- graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
+ graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=single_iteration_run.node_id,
user_inputs=dict(single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
+ node_type_filter_key="iteration_id",
+ node_type_label="iteration",
)
elif single_loop_run:
- graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
+ graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=single_loop_run.node_id,
user_inputs=dict(single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
+ node_type_filter_key="loop_id",
+ node_type_label="loop",
)
else:
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
@@ -260,7 +273,7 @@ class WorkflowBasedAppRunner:
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT,
- invoke_from=self._queue_manager.invoke_from,
+ invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
@@ -270,7 +283,9 @@ class WorkflowBasedAppRunner:
)
# init graph
- graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
+ graph = Graph.init(
+ graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True
+ )
if not graph:
raise ValueError("graph not found in workflow")
@@ -316,44 +331,6 @@ class WorkflowBasedAppRunner:
return graph, variable_pool
- def _get_graph_and_variable_pool_of_single_iteration(
- self,
- workflow: Workflow,
- node_id: str,
- user_inputs: dict[str, Any],
- graph_runtime_state: GraphRuntimeState,
- ) -> tuple[Graph, VariablePool]:
- """
- Get variable pool of single iteration
- """
- return self._get_graph_and_variable_pool_for_single_node_run(
- workflow=workflow,
- node_id=node_id,
- user_inputs=user_inputs,
- graph_runtime_state=graph_runtime_state,
- node_type_filter_key="iteration_id",
- node_type_label="iteration",
- )
-
- def _get_graph_and_variable_pool_of_single_loop(
- self,
- workflow: Workflow,
- node_id: str,
- user_inputs: dict[str, Any],
- graph_runtime_state: GraphRuntimeState,
- ) -> tuple[Graph, VariablePool]:
- """
- Get variable pool of single loop
- """
- return self._get_graph_and_variable_pool_for_single_node_run(
- workflow=workflow,
- node_id=node_id,
- user_inputs=user_inputs,
- graph_runtime_state=graph_runtime_state,
- node_type_filter_key="loop_id",
- node_type_label="loop",
- )
-
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
"""
Handle event
diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py
index 221a69cd3f..0e68e554c8 100644
--- a/api/core/app/entities/app_invoke_entities.py
+++ b/api/core/app/entities/app_invoke_entities.py
@@ -42,7 +42,8 @@ class InvokeFrom(StrEnum):
# DEBUGGER indicates that this invocation is from
# the workflow (or chatflow) edit page.
DEBUGGER = "debugger"
- PUBLISHED = "published"
+ # PUBLISHED_PIPELINE indicates that this invocation runs a published RAG pipeline workflow.
+ PUBLISHED_PIPELINE = "published"
# VALIDATION indicates that this invocation is from validation.
VALIDATION = "validation"
diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py
index 79fbafe39e..3f9f3da9b2 100644
--- a/api/core/app/features/annotation_reply/annotation_reply.py
+++ b/api/core/app/features/annotation_reply/annotation_reply.py
@@ -75,7 +75,7 @@ class AnnotationReplyFeature:
AppAnnotationService.add_annotation_history(
annotation.id,
app_record.id,
- annotation.question,
+ annotation.question_text,
annotation.content,
query,
user_id,
diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py
new file mode 100644
index 0000000000..c070845b73
--- /dev/null
+++ b/api/core/app/layers/conversation_variable_persist_layer.py
@@ -0,0 +1,60 @@
+import logging
+
+from core.variables import VariableBase
+from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
+from core.workflow.conversation_variable_updater import ConversationVariableUpdater
+from core.workflow.enums import NodeType
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
+from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
+from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
+
+logger = logging.getLogger(__name__)
+
+
+class ConversationVariablePersistenceLayer(GraphEngineLayer):
+ def __init__(self, conversation_variable_updater: ConversationVariableUpdater) -> None:
+ super().__init__()
+ self._conversation_variable_updater = conversation_variable_updater
+
+ def on_graph_start(self) -> None:
+ pass
+
+ def on_event(self, event: GraphEngineEvent) -> None:
+ if not isinstance(event, NodeRunSucceededEvent):
+ return
+ if event.node_type != NodeType.VARIABLE_ASSIGNER:
+ return
+ if self.graph_runtime_state is None:
+ return
+
+ updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or []
+ if not updated_variables:
+ return
+
+ conversation_id = self.graph_runtime_state.system_variable.conversation_id
+ if conversation_id is None:
+ return
+
+ updated_any = False
+ for item in updated_variables:
+ selector = item.selector
+ if len(selector) < 2:
+ logger.warning("Conversation variable selector invalid. selector=%s", selector)
+ continue
+ if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
+ continue
+ variable = self.graph_runtime_state.variable_pool.get(selector)
+ if not isinstance(variable, VariableBase):
+ logger.warning(
+ "Conversation variable not found in variable pool. selector=%s",
+ selector,
+ )
+ continue
+ self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable)
+ updated_any = True
+
+ if updated_any:
+ self._conversation_variable_updater.flush()
+
+ def on_graph_end(self, error: Exception | None) -> None:
+ pass
diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py
index e1b5352c2a..1c267091a4 100644
--- a/api/core/app/layers/pause_state_persist_layer.py
+++ b/api/core/app/layers/pause_state_persist_layer.py
@@ -75,6 +75,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
"""
if isinstance(session_factory, Engine):
session_factory = sessionmaker(session_factory)
+ super().__init__()
self._session_maker = session_factory
self._state_owner_user_id = state_owner_user_id
self._generate_entity = generate_entity
@@ -107,8 +108,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
if not isinstance(event, GraphRunPausedEvent):
return
- assert self.graph_runtime_state is not None
-
entity_wrapper: _GenerateEntityUnion
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py
index fe1a46a945..a7ea9ef446 100644
--- a/api/core/app/layers/trigger_post_layer.py
+++ b/api/core/app/layers/trigger_post_layer.py
@@ -3,8 +3,8 @@ from datetime import UTC, datetime
from typing import Any, ClassVar
from pydantic import TypeAdapter
-from sqlalchemy.orm import Session, sessionmaker
+from core.db.session_factory import session_factory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
@@ -31,12 +31,11 @@ class TriggerPostLayer(GraphEngineLayer):
cfs_plan_scheduler_entity: AsyncWorkflowCFSPlanEntity,
start_time: datetime,
trigger_log_id: str,
- session_maker: sessionmaker[Session],
):
+ super().__init__()
self.trigger_log_id = trigger_log_id
self.start_time = start_time
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
- self.session_maker = session_maker
def on_graph_start(self):
pass
@@ -46,7 +45,7 @@ class TriggerPostLayer(GraphEngineLayer):
Update trigger log with success or failure.
"""
if isinstance(event, tuple(self._STATUS_MAP.keys())):
- with self.session_maker() as session:
+ with session_factory.create_session() as session:
repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = repo.get_by_id(self.trigger_log_id)
if not trigger_log:
@@ -57,10 +56,6 @@ class TriggerPostLayer(GraphEngineLayer):
elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()
# Extract relevant data from result
- if not self.graph_runtime_state:
- logger.exception("Graph runtime state is not set")
- return
-
outputs = self.graph_runtime_state.outputs
# BASICLY, workflow_execution_id is the same as workflow_run_id
diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index 5bb93fa44a..6c997753fa 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -39,6 +39,7 @@ from core.app.entities.task_entities import (
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
+ StreamEvent,
StreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@@ -70,6 +71,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
_task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
+ _precomputed_event_type: StreamEvent | None = None
def __init__(
self,
@@ -342,11 +344,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent):
- event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id)
+ # Determine the event type once, on first LLM chunk, and reuse for subsequent chunks
+ if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None:
+ self._precomputed_event_type = self._message_cycle_manager.get_message_event_type(
+ message_id=self._message_id
+ )
yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
- event_type=event_type,
+ event_type=self._precomputed_event_type,
)
else:
yield self._agent_message_to_stream_response(
diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py
index c6e89a7663..d682083f34 100644
--- a/api/core/app/task_pipeline/message_cycle_manager.py
+++ b/api/core/app/task_pipeline/message_cycle_manager.py
@@ -5,7 +5,7 @@ from threading import Thread
from typing import Union
from flask import Flask, current_app
-from sqlalchemy import exists, select
+from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
@@ -30,6 +30,7 @@ from core.app.entities.task_entities import (
StreamEvent,
WorkflowTaskState,
)
+from core.db.session_factory import session_factory
from core.llm_generator.llm_generator import LLMGenerator
from core.tools.signature import sign_tool_file
from extensions.ext_database import db
@@ -57,13 +58,15 @@ class MessageCycleManager:
self._message_has_file: set[str] = set()
def get_message_event_type(self, message_id: str) -> StreamEvent:
+ # Fast path: cached determination from prior QueueMessageFileEvent
if message_id in self._message_has_file:
return StreamEvent.MESSAGE_FILE
- with Session(db.engine, expire_on_commit=False) as session:
- has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar()
+ # Use SQLAlchemy 2.x style session.scalar(select(...))
+ with session_factory.create_session() as session:
+ message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id))
- if has_file:
+ if message_file:
self._message_has_file.add(message_id)
return StreamEvent.MESSAGE_FILE
@@ -201,6 +204,8 @@ class MessageCycleManager:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
if message_file and message_file.url is not None:
+ self._message_has_file.add(message_file.message_id)
+
# get tool file id
tool_file_id = message_file.url.split("/")[-1]
# trim extension
diff --git a/api/core/app/workflow/__init__.py b/api/core/app/workflow/__init__.py
new file mode 100644
index 0000000000..172ee5d703
--- /dev/null
+++ b/api/core/app/workflow/__init__.py
@@ -0,0 +1,3 @@
+from .node_factory import DifyNodeFactory
+
+__all__ = ["DifyNodeFactory"]
diff --git a/api/core/app/workflow/layers/__init__.py b/api/core/app/workflow/layers/__init__.py
new file mode 100644
index 0000000000..945f75303c
--- /dev/null
+++ b/api/core/app/workflow/layers/__init__.py
@@ -0,0 +1,10 @@
+"""Workflow-level GraphEngine layers that depend on outer infrastructure."""
+
+from .observability import ObservabilityLayer
+from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
+
+__all__ = [
+ "ObservabilityLayer",
+ "PersistenceWorkflowInfo",
+ "WorkflowPersistenceLayer",
+]
diff --git a/api/core/workflow/graph_engine/layers/observability.py b/api/core/app/workflow/layers/observability.py
similarity index 91%
rename from api/core/workflow/graph_engine/layers/observability.py
rename to api/core/app/workflow/layers/observability.py
index a674816884..94839c8ae3 100644
--- a/api/core/workflow/graph_engine/layers/observability.py
+++ b/api/core/app/workflow/layers/observability.py
@@ -18,12 +18,15 @@ from typing_extensions import override
from configs import dify_config
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
-from core.workflow.graph_engine.layers.node_parsers import (
+from core.workflow.graph_events import GraphNodeEventBase
+from core.workflow.nodes.base.node import Node
+from extensions.otel.parser import (
DefaultNodeOTelParser,
+ LLMNodeOTelParser,
NodeOTelParser,
+ RetrievalNodeOTelParser,
ToolNodeOTelParser,
)
-from core.workflow.nodes.base.node import Node
from extensions.otel.runtime import is_instrument_flag_enabled
logger = logging.getLogger(__name__)
@@ -72,6 +75,8 @@ class ObservabilityLayer(GraphEngineLayer):
"""Initialize parser registry for node types."""
self._parsers = {
NodeType.TOOL: ToolNodeOTelParser(),
+ NodeType.LLM: LLMNodeOTelParser(),
+ NodeType.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(),
}
def _get_parser(self, node: Node) -> NodeOTelParser:
@@ -119,7 +124,9 @@ class ObservabilityLayer(GraphEngineLayer):
logger.warning("Failed to create OpenTelemetry span for node %s: %s", node.id, e)
@override
- def on_node_run_end(self, node: Node, error: Exception | None) -> None:
+ def on_node_run_end(
+ self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
+ ) -> None:
"""
Called when a node finishes execution.
@@ -139,7 +146,7 @@ class ObservabilityLayer(GraphEngineLayer):
span = node_context.span
parser = self._get_parser(node)
try:
- parser.parse(node=node, span=span, error=error)
+ parser.parse(node=node, span=span, error=error, result_event=result_event)
span.end()
finally:
token = node_context.token
diff --git a/api/core/workflow/graph_engine/layers/persistence.py b/api/core/app/workflow/layers/persistence.py
similarity index 99%
rename from api/core/workflow/graph_engine/layers/persistence.py
rename to api/core/app/workflow/layers/persistence.py
index b70f36ec9e..41052b4f52 100644
--- a/api/core/workflow/graph_engine/layers/persistence.py
+++ b/api/core/app/workflow/layers/persistence.py
@@ -45,7 +45,6 @@ from core.workflow.graph_events import (
from core.workflow.node_events import NodeRunResult
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
-from core.workflow.workflow_entry import WorkflowEntry
from libs.datetime_utils import naive_utc_now
@@ -316,6 +315,9 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
# workflow inputs stay reusable without binding future runs to this conversation.
continue
inputs[f"sys.{field_name}"] = value
+ # Local import to avoid circular dependency during app bootstrapping.
+ from core.workflow.workflow_entry import WorkflowEntry
+
handled = WorkflowEntry.handle_special_values(inputs)
return handled or {}
@@ -337,8 +339,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
if update_finished:
execution.finished_at = naive_utc_now()
runtime_state = self.graph_runtime_state
- if runtime_state is None:
- return
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
@@ -404,6 +404,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
def _system_variables(self) -> Mapping[str, Any]:
runtime_state = self.graph_runtime_state
- if runtime_state is None:
- return {}
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
diff --git a/api/core/app/workflow/node_factory.py b/api/core/app/workflow/node_factory.py
new file mode 100644
index 0000000000..e0a0059a38
--- /dev/null
+++ b/api/core/app/workflow/node_factory.py
@@ -0,0 +1,152 @@
+from collections.abc import Callable, Sequence
+from typing import TYPE_CHECKING, final
+
+from typing_extensions import override
+
+from configs import dify_config
+from core.file import file_manager
+from core.helper import ssrf_proxy
+from core.helper.code_executor.code_executor import CodeExecutor
+from core.helper.code_executor.code_node_provider import CodeNodeProvider
+from core.tools.tool_file_manager import ToolFileManager
+from core.workflow.enums import NodeType
+from core.workflow.graph import NodeFactory
+from core.workflow.nodes.base.node import Node
+from core.workflow.nodes.code.code_node import CodeNode
+from core.workflow.nodes.code.limits import CodeNodeLimits
+from core.workflow.nodes.http_request.node import HttpRequestNode
+from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
+from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
+from core.workflow.nodes.template_transform.template_renderer import (
+ CodeExecutorJinja2TemplateRenderer,
+ Jinja2TemplateRenderer,
+)
+from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
+from libs.typing import is_str, is_str_dict
+
+if TYPE_CHECKING:
+ from core.workflow.entities import GraphInitParams
+ from core.workflow.runtime import GraphRuntimeState
+
+
+@final
+class DifyNodeFactory(NodeFactory):
+ """
+ Default implementation of NodeFactory that uses the traditional node mapping.
+
+ This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
+ and instantiating the appropriate node class.
+ """
+
+ def __init__(
+ self,
+ graph_init_params: "GraphInitParams",
+ graph_runtime_state: "GraphRuntimeState",
+ *,
+ code_executor: type[CodeExecutor] | None = None,
+ code_providers: Sequence[type[CodeNodeProvider]] | None = None,
+ code_limits: CodeNodeLimits | None = None,
+ template_renderer: Jinja2TemplateRenderer | None = None,
+ http_request_http_client: HttpClientProtocol = ssrf_proxy,
+ http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
+ http_request_file_manager: FileManagerProtocol = file_manager,
+ ) -> None:
+ self.graph_init_params = graph_init_params
+ self.graph_runtime_state = graph_runtime_state
+ self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
+ self._code_providers: tuple[type[CodeNodeProvider], ...] = (
+ tuple(code_providers) if code_providers else CodeNode.default_code_providers()
+ )
+ self._code_limits = code_limits or CodeNodeLimits(
+ max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
+ max_number=dify_config.CODE_MAX_NUMBER,
+ min_number=dify_config.CODE_MIN_NUMBER,
+ max_precision=dify_config.CODE_MAX_PRECISION,
+ max_depth=dify_config.CODE_MAX_DEPTH,
+ max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
+ max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
+ max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
+ )
+ self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
+ self._http_request_http_client = http_request_http_client
+ self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
+ self._http_request_file_manager = http_request_file_manager
+
+ @override
+ def create_node(self, node_config: dict[str, object]) -> Node:
+ """
+ Create a Node instance from node configuration data using the traditional mapping.
+
+ :param node_config: node configuration dictionary containing type and other data
+ :return: initialized Node instance
+ :raises ValueError: if node type is unknown or configuration is invalid
+ """
+ # Get node_id from config
+ node_id = node_config.get("id")
+ if not is_str(node_id):
+ raise ValueError("Node config missing id")
+
+ # Get node type from config
+ node_data = node_config.get("data", {})
+ if not is_str_dict(node_data):
+ raise ValueError(f"Node {node_id} missing data information")
+
+ node_type_str = node_data.get("type")
+ if not is_str(node_type_str):
+ raise ValueError(f"Node {node_id} missing or invalid type information")
+
+ try:
+ node_type = NodeType(node_type_str)
+ except ValueError:
+ raise ValueError(f"Unknown node type: {node_type_str}")
+
+ # Get node class
+ node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
+ if not node_mapping:
+ raise ValueError(f"No class mapping found for node type: {node_type}")
+
+ latest_node_class = node_mapping.get(LATEST_VERSION)
+ node_version = str(node_data.get("version", "1"))
+ matched_node_class = node_mapping.get(node_version)
+ node_class = matched_node_class or latest_node_class
+ if not node_class:
+ raise ValueError(f"No latest version class found for node type: {node_type}")
+
+ # Create node instance
+ if node_type == NodeType.CODE:
+ return CodeNode(
+ id=node_id,
+ config=node_config,
+ graph_init_params=self.graph_init_params,
+ graph_runtime_state=self.graph_runtime_state,
+ code_executor=self._code_executor,
+ code_providers=self._code_providers,
+ code_limits=self._code_limits,
+ )
+
+ if node_type == NodeType.TEMPLATE_TRANSFORM:
+ return TemplateTransformNode(
+ id=node_id,
+ config=node_config,
+ graph_init_params=self.graph_init_params,
+ graph_runtime_state=self.graph_runtime_state,
+ template_renderer=self._template_renderer,
+ )
+
+ if node_type == NodeType.HTTP_REQUEST:
+ return HttpRequestNode(
+ id=node_id,
+ config=node_config,
+ graph_init_params=self.graph_init_params,
+ graph_runtime_state=self.graph_runtime_state,
+ http_client=self._http_request_http_client,
+ tool_file_manager_factory=self._http_request_tool_file_manager_factory,
+ file_manager=self._http_request_file_manager,
+ )
+
+ return node_class(
+ id=node_id,
+ config=node_config,
+ graph_init_params=self.graph_init_params,
+ graph_runtime_state=self.graph_runtime_state,
+ )
diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py
index 50c7249fe4..451e4fda0e 100644
--- a/api/core/datasource/__base/datasource_plugin.py
+++ b/api/core/datasource/__base/datasource_plugin.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from abc import ABC, abstractmethod
from configs import dify_config
@@ -30,7 +32,7 @@ class DatasourcePlugin(ABC):
"""
return DatasourceProviderType.LOCAL_FILE
- def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
+ def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> DatasourcePlugin:
return self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,
diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py
index 260dcf04f5..dde7d59726 100644
--- a/api/core/datasource/entities/datasource_entities.py
+++ b/api/core/datasource/entities/datasource_entities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import enum
from enum import StrEnum
from typing import Any
@@ -31,7 +33,7 @@ class DatasourceProviderType(enum.StrEnum):
ONLINE_DRIVE = "online_drive"
@classmethod
- def value_of(cls, value: str) -> "DatasourceProviderType":
+ def value_of(cls, value: str) -> DatasourceProviderType:
"""
Get value of given mode.
@@ -81,7 +83,7 @@ class DatasourceParameter(PluginParameter):
typ: DatasourceParameterType,
required: bool,
options: list[str] | None = None,
- ) -> "DatasourceParameter":
+ ) -> DatasourceParameter:
"""
get a simple datasource parameter
@@ -187,14 +189,14 @@ class DatasourceInvokeMeta(BaseModel):
tool_config: dict | None = None
@classmethod
- def empty(cls) -> "DatasourceInvokeMeta":
+ def empty(cls) -> DatasourceInvokeMeta:
"""
Get an empty instance of DatasourceInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
- def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
+ def error_instance(cls, error: str) -> DatasourceInvokeMeta:
"""
Get an instance of DatasourceInvokeMeta with error
"""
diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py
index 98ea15e3fc..ce23da1e09 100644
--- a/api/core/datasource/online_document/online_document_plugin.py
+++ b/api/core/datasource/online_document/online_document_plugin.py
@@ -1,4 +1,4 @@
-from collections.abc import Generator, Mapping
+from collections.abc import Generator
from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin
@@ -34,7 +34,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
def get_online_document_pages(
self,
user_id: str,
- datasource_parameters: Mapping[str, Any],
+ datasource_parameters: dict[str, Any],
provider_type: str,
) -> Generator[OnlineDocumentPagesMessage, None, None]:
manager = PluginDatasourceManager()
diff --git a/api/core/db/session_factory.py b/api/core/db/session_factory.py
index 1dae2eafd4..45d4bc4594 100644
--- a/api/core/db/session_factory.py
+++ b/api/core/db/session_factory.py
@@ -1,7 +1,7 @@
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker
-_session_maker: sessionmaker | None = None
+_session_maker: sessionmaker[Session] | None = None
def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
@@ -10,7 +10,7 @@ def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
_session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
-def get_session_maker() -> sessionmaker:
+def get_session_maker() -> sessionmaker[Session]:
if _session_maker is None:
raise RuntimeError("Session factory not configured. Call configure_session_factory() first.")
return _session_maker
@@ -27,7 +27,7 @@ class SessionFactory:
configure_session_factory(engine, expire_on_commit)
@staticmethod
- def get_session_maker() -> sessionmaker:
+ def get_session_maker() -> sessionmaker[Session]:
return get_session_maker()
@staticmethod
diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py
index 7fdf5e4be6..135d2a4945 100644
--- a/api/core/entities/mcp_provider.py
+++ b/api/core/entities/mcp_provider.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
from datetime import datetime
from enum import StrEnum
@@ -75,7 +77,7 @@ class MCPProviderEntity(BaseModel):
updated_at: datetime
@classmethod
- def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
+ def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity:
"""Create entity from database model with decryption"""
return cls(
diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py
index 12431976f0..a123fb0321 100644
--- a/api/core/entities/model_entities.py
+++ b/api/core/entities/model_entities.py
@@ -30,7 +30,6 @@ class SimpleModelProviderEntity(BaseModel):
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
- icon_large: I18nObject | None = None
supported_model_types: list[ModelType]
def __init__(self, provider_entity: ProviderEntity):
@@ -44,7 +43,6 @@ class SimpleModelProviderEntity(BaseModel):
label=provider_entity.label,
icon_small=provider_entity.icon_small,
icon_small_dark=provider_entity.icon_small_dark,
- icon_large=provider_entity.icon_large,
supported_model_types=provider_entity.supported_model_types,
)
@@ -94,7 +92,6 @@ class DefaultModelProviderEntity(BaseModel):
provider: str
label: I18nObject
icon_small: I18nObject | None = None
- icon_large: I18nObject | None = None
supported_model_types: Sequence[ModelType] = []
diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py
index 8a8067332d..0078ec7e4f 100644
--- a/api/core/entities/provider_entities.py
+++ b/api/core/entities/provider_entities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from enum import StrEnum, auto
from typing import Union
@@ -178,7 +180,7 @@ class BasicProviderConfig(BaseModel):
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
@classmethod
- def value_of(cls, value: str) -> "ProviderConfig.Type":
+ def value_of(cls, value: str) -> ProviderConfig.Type:
"""
Get value of given mode.
diff --git a/api/core/file/helpers.py b/api/core/file/helpers.py
index 6d553d7dc6..2ac483673a 100644
--- a/api/core/file/helpers.py
+++ b/api/core/file/helpers.py
@@ -8,8 +8,9 @@ import urllib.parse
from configs import dify_config
-def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str:
- url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
+def get_signed_file_url(upload_file_id: str, as_attachment=False, for_external: bool = True) -> str:
+ base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
+ url = f"{base_url}/files/{upload_file_id}/file-preview"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
diff --git a/api/core/file/models.py b/api/core/file/models.py
index d149205d77..6324523b22 100644
--- a/api/core/file/models.py
+++ b/api/core/file/models.py
@@ -112,17 +112,17 @@ class File(BaseModel):
return text
- def generate_url(self) -> str | None:
+ def generate_url(self, for_external: bool = True) -> str | None:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.remote_url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
if self.related_id is None:
raise ValueError("Missing file related_id")
- return helpers.get_signed_file_url(upload_file_id=self.related_id)
+ return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external)
elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]:
assert self.related_id is not None
assert self.extension is not None
- return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
+ return sign_tool_file(tool_file_id=self.related_id, extension=self.extension, for_external=for_external)
return None
def to_plugin_parameter(self) -> dict[str, Any]:
@@ -133,7 +133,7 @@ class File(BaseModel):
"extension": self.extension,
"size": self.size,
"type": self.type,
- "url": self.generate_url(),
+ "url": self.generate_url(for_external=False),
}
@model_validator(mode="after")
diff --git a/api/core/helper/code_executor/jinja2/jinja2_transformer.py b/api/core/helper/code_executor/jinja2/jinja2_transformer.py
index 969125d2f7..5e4807401e 100644
--- a/api/core/helper/code_executor/jinja2/jinja2_transformer.py
+++ b/api/core/helper/code_executor/jinja2/jinja2_transformer.py
@@ -1,9 +1,14 @@
+from collections.abc import Mapping
from textwrap import dedent
+from typing import Any
from core.helper.code_executor.template_transformer import TemplateTransformer
class Jinja2TemplateTransformer(TemplateTransformer):
+ # Use separate placeholder for base64-encoded template to avoid confusion
+ _template_b64_placeholder: str = "{{template_b64}}"
+
@classmethod
def transform_response(cls, response: str):
"""
@@ -13,18 +18,35 @@ class Jinja2TemplateTransformer(TemplateTransformer):
"""
return {"result": cls.extract_result_str_from_response(response)}
+ @classmethod
+ def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
+ """
+ Override base class to use base64 encoding for template code.
+ This prevents issues with special characters (quotes, newlines) in templates
+ breaking the generated Python script. Fixes #26818.
+ """
+ script = cls.get_runner_script()
+ # Encode template as base64 to safely embed any content including quotes
+ code_b64 = cls.serialize_code(code)
+ script = script.replace(cls._template_b64_placeholder, code_b64)
+ inputs_str = cls.serialize_inputs(inputs)
+ script = script.replace(cls._inputs_placeholder, inputs_str)
+ return script
+
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(f"""
- # declare main function
- def main(**inputs):
- import jinja2
- template = jinja2.Template('''{cls._code_placeholder}''')
- return template.render(**inputs)
-
+ import jinja2
import json
from base64 import b64decode
+ # declare main function
+ def main(**inputs):
+ # Decode base64-encoded template to handle special characters safely
+ template_code = b64decode('{cls._template_b64_placeholder}').decode('utf-8')
+ template = jinja2.Template(template_code)
+ return template.render(**inputs)
+
# decode and prepare input dict
inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))
diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py
index 3965f8cb31..5cdea19a8d 100644
--- a/api/core/helper/code_executor/template_transformer.py
+++ b/api/core/helper/code_executor/template_transformer.py
@@ -13,6 +13,15 @@ class TemplateTransformer(ABC):
_inputs_placeholder: str = "{{inputs}}"
_result_tag: str = "<>"
+ @classmethod
+ def serialize_code(cls, code: str) -> str:
+ """
+ Serialize template code to base64 to safely embed in generated script.
+ This prevents issues with special characters like quotes breaking the script.
+ """
+ code_bytes = code.encode("utf-8")
+ return b64encode(code_bytes).decode("utf-8")
+
@classmethod
def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, str]:
"""
@@ -67,7 +76,7 @@ class TemplateTransformer(ABC):
Post-process the result to convert scientific notation strings back to numbers
"""
- def convert_scientific_notation(value):
+ def convert_scientific_notation(value: Any) -> Any:
if isinstance(value, str):
# Check if the string looks like scientific notation
if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
@@ -81,7 +90,7 @@ class TemplateTransformer(ABC):
return [convert_scientific_notation(v) for v in value]
return value
- return convert_scientific_notation(result) # type: ignore[no-any-return]
+ return convert_scientific_notation(result)
@classmethod
@abstractmethod
diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py
index 0b36969cf9..128c64ff2c 100644
--- a/api/core/helper/ssrf_proxy.py
+++ b/api/core/helper/ssrf_proxy.py
@@ -33,6 +33,10 @@ class MaxRetriesExceededError(ValueError):
pass
+request_error = httpx.RequestError
+max_retries_exceeded_error = MaxRetriesExceededError
+
+
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
return {
"http://": httpx.HTTPTransport(
@@ -88,7 +92,41 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
return None
+def _inject_trace_headers(headers: dict | None) -> dict:
+ """
+ Inject W3C traceparent header for distributed tracing.
+
+ When OTEL is enabled, HTTPXClientInstrumentor handles trace propagation automatically.
+ When OTEL is disabled, we manually inject the traceparent header.
+ """
+ if headers is None:
+ headers = {}
+
+ # Skip if already present (case-insensitive check)
+ for key in headers:
+ if key.lower() == "traceparent":
+ return headers
+
+ # Skip if OTEL is enabled - HTTPXClientInstrumentor handles this automatically
+ if dify_config.ENABLE_OTEL:
+ return headers
+
+ # Generate and inject traceparent for non-OTEL scenarios
+ try:
+ from core.helper.trace_id_helper import generate_traceparent_header
+
+ traceparent = generate_traceparent_header()
+ if traceparent:
+ headers["traceparent"] = traceparent
+ except Exception:
+ # Silently ignore errors to avoid breaking requests
+ logger.debug("Failed to generate traceparent header", exc_info=True)
+
+ return headers
+
+
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+ # Convert requests-style allow_redirects to httpx-style follow_redirects
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
if "follow_redirects" not in kwargs:
@@ -106,18 +144,21 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
client = _get_ssrf_client(verify_option)
+ # Inject traceparent header for distributed tracing (when OTEL is not enabled)
+ headers = kwargs.get("headers") or {}
+ headers = _inject_trace_headers(headers)
+ kwargs["headers"] = headers
+
# Preserve user-provided Host header
# When using a forward proxy, httpx may override the Host header based on the URL.
# We extract and preserve any explicitly set Host header to support virtual hosting.
- headers = kwargs.get("headers", {})
user_provided_host = _get_user_provided_host_header(headers)
retries = 0
while retries <= max_retries:
try:
- # Build the request manually to preserve the Host header
- # httpx may override the Host header when using a proxy, so we use
- # the request API to explicitly set headers before sending
+ # Preserve the user-provided Host header
+ # httpx may override the Host header when using a proxy
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
if user_provided_host is not None:
headers["host"] = user_provided_host
diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py
deleted file mode 100644
index c5447c2b3f..0000000000
--- a/api/core/helper/tool_provider_cache.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import json
-import logging
-from typing import Any, cast
-
-from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
-from extensions.ext_redis import redis_client, redis_fallback
-
-logger = logging.getLogger(__name__)
-
-
-class ToolProviderListCache:
- """Cache for tool provider lists"""
-
- CACHE_TTL = 300 # 5 minutes
-
- @staticmethod
- def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str:
- """Generate cache key for tool providers list"""
- type_filter = typ or "all"
- return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}"
-
- @staticmethod
- @redis_fallback(default_return=None)
- def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None:
- """Get cached tool providers"""
- cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
- cached_data = redis_client.get(cache_key)
- if cached_data:
- try:
- return json.loads(cached_data.decode("utf-8"))
- except (json.JSONDecodeError, UnicodeDecodeError):
- logger.warning("Failed to decode cached tool providers data")
- return None
- return None
-
- @staticmethod
- @redis_fallback()
- def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]):
- """Cache tool providers"""
- cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
- redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers))
-
- @staticmethod
- @redis_fallback()
- def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
- """Invalidate cache for tool providers"""
- if typ:
- # Invalidate specific type cache
- cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
- redis_client.delete(cache_key)
- else:
- # Invalidate all caches for this tenant
- keys = ["builtin", "model", "api", "workflow", "mcp"]
- pipeline = redis_client.pipeline()
- for key in keys:
- cache_key = ToolProviderListCache._generate_cache_key(tenant_id, cast(ToolProviderTypeApiLiteral, key))
- pipeline.delete(cache_key)
- pipeline.execute()
diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py
index 820502e558..e827859109 100644
--- a/api/core/helper/trace_id_helper.py
+++ b/api/core/helper/trace_id_helper.py
@@ -103,3 +103,60 @@ def parse_traceparent_header(traceparent: str) -> str | None:
if len(parts) == 4 and len(parts[1]) == 32:
return parts[1]
return None
+
+
+def get_span_id_from_otel_context() -> str | None:
+ """
+ Retrieve the current span ID from the active OpenTelemetry trace context.
+
+ Returns:
+ A 16-character hex string representing the span ID, or None if not available.
+ """
+ try:
+ from opentelemetry.trace import get_current_span
+ from opentelemetry.trace.span import INVALID_SPAN_ID
+
+ span = get_current_span()
+ if not span:
+ return None
+
+ span_context = span.get_span_context()
+ if not span_context or span_context.span_id == INVALID_SPAN_ID:
+ return None
+
+ return f"{span_context.span_id:016x}"
+ except Exception:
+ return None
+
+
+def generate_traceparent_header() -> str | None:
+ """
+ Generate a W3C traceparent header from the current context.
+
+ Uses OpenTelemetry context if available, otherwise uses the
+ ContextVar-based trace_id from the logging context.
+
+ Format: {version}-{trace_id}-{span_id}-{flags}
+ Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01
+
+ Returns:
+ A valid traceparent header string, or None if generation fails.
+ """
+ import uuid
+
+ # Try OTEL context first
+ trace_id = get_trace_id_from_otel_context()
+ span_id = get_span_id_from_otel_context()
+
+ if trace_id and span_id:
+ return f"00-{trace_id}-{span_id}-01"
+
+ # Fallback: use ContextVar-based trace_id or generate new one
+ from core.logging.context import get_trace_id as get_logging_trace_id
+
+ trace_id = get_logging_trace_id() or uuid.uuid4().hex
+
+ # Generate a new span_id (16 hex chars)
+ span_id = uuid.uuid4().hex[:16]
+
+ return f"00-{trace_id}-{span_id}-01"
diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py
index af860a1070..370e64e385 100644
--- a/api/core/hosting_configuration.py
+++ b/api/core/hosting_configuration.py
@@ -56,6 +56,10 @@ class HostingConfiguration:
self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
+ self.provider_map[f"{DEFAULT_PLUGIN_ID}/gemini/google"] = self.init_gemini()
+ self.provider_map[f"{DEFAULT_PLUGIN_ID}/x/x"] = self.init_xai()
+ self.provider_map[f"{DEFAULT_PLUGIN_ID}/deepseek/deepseek"] = self.init_deepseek()
+ self.provider_map[f"{DEFAULT_PLUGIN_ID}/tongyi/tongyi"] = self.init_tongyi()
self.moderation_config = self.init_moderation_config()
@@ -128,7 +132,7 @@ class HostingConfiguration:
quotas: list[HostingQuota] = []
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
- hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
+ hosted_quota_limit = 0
trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota)
@@ -156,18 +160,49 @@ class HostingConfiguration:
quota_unit=quota_unit,
)
- @staticmethod
- def init_anthropic() -> HostingProvider:
- quota_unit = QuotaUnit.TOKENS
+ def init_gemini(self) -> HostingProvider:
+ quota_unit = QuotaUnit.CREDITS
+ quotas: list[HostingQuota] = []
+
+ if dify_config.HOSTED_GEMINI_TRIAL_ENABLED:
+ hosted_quota_limit = 0
+ trial_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_TRIAL_MODELS")
+ trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
+ quotas.append(trial_quota)
+
+ if dify_config.HOSTED_GEMINI_PAID_ENABLED:
+ paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS")
+ paid_quota = PaidHostingQuota(restrict_models=paid_models)
+ quotas.append(paid_quota)
+
+ if len(quotas) > 0:
+ credentials = {
+ "google_api_key": dify_config.HOSTED_GEMINI_API_KEY,
+ }
+
+ if dify_config.HOSTED_GEMINI_API_BASE:
+ credentials["google_base_url"] = dify_config.HOSTED_GEMINI_API_BASE
+
+ return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
+
+ return HostingProvider(
+ enabled=False,
+ quota_unit=quota_unit,
+ )
+
+ def init_anthropic(self) -> HostingProvider:
+ quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
- hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
- trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
+ hosted_quota_limit = 0
+ trail_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_TRIAL_MODELS")
+ trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
- paid_quota = PaidHostingQuota()
+ paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS")
+ paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
@@ -185,6 +220,94 @@ class HostingConfiguration:
quota_unit=quota_unit,
)
+ def init_tongyi(self) -> HostingProvider:
+ quota_unit = QuotaUnit.CREDITS
+ quotas: list[HostingQuota] = []
+
+ if dify_config.HOSTED_TONGYI_TRIAL_ENABLED:
+ hosted_quota_limit = 0
+ trail_models = self.parse_restrict_models_from_env("HOSTED_TONGYI_TRIAL_MODELS")
+ trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
+ quotas.append(trial_quota)
+
+ if dify_config.HOSTED_TONGYI_PAID_ENABLED:
+ paid_models = self.parse_restrict_models_from_env("HOSTED_TONGYI_PAID_MODELS")
+ paid_quota = PaidHostingQuota(restrict_models=paid_models)
+ quotas.append(paid_quota)
+
+ if len(quotas) > 0:
+ credentials = {
+ "dashscope_api_key": dify_config.HOSTED_TONGYI_API_KEY,
+ "use_international_endpoint": dify_config.HOSTED_TONGYI_USE_INTERNATIONAL_ENDPOINT,
+ }
+
+ return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
+
+ return HostingProvider(
+ enabled=False,
+ quota_unit=quota_unit,
+ )
+
+ def init_xai(self) -> HostingProvider:
+ quota_unit = QuotaUnit.CREDITS
+ quotas: list[HostingQuota] = []
+
+ if dify_config.HOSTED_XAI_TRIAL_ENABLED:
+ hosted_quota_limit = 0
+ trail_models = self.parse_restrict_models_from_env("HOSTED_XAI_TRIAL_MODELS")
+ trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
+ quotas.append(trial_quota)
+
+ if dify_config.HOSTED_XAI_PAID_ENABLED:
+ paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS")
+ paid_quota = PaidHostingQuota(restrict_models=paid_models)
+ quotas.append(paid_quota)
+
+ if len(quotas) > 0:
+ credentials = {
+ "api_key": dify_config.HOSTED_XAI_API_KEY,
+ }
+
+ if dify_config.HOSTED_XAI_API_BASE:
+ credentials["endpoint_url"] = dify_config.HOSTED_XAI_API_BASE
+
+ return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
+
+ return HostingProvider(
+ enabled=False,
+ quota_unit=quota_unit,
+ )
+
+ def init_deepseek(self) -> HostingProvider:
+ quota_unit = QuotaUnit.CREDITS
+ quotas: list[HostingQuota] = []
+
+ if dify_config.HOSTED_DEEPSEEK_TRIAL_ENABLED:
+ hosted_quota_limit = 0
+ trail_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_TRIAL_MODELS")
+ trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
+ quotas.append(trial_quota)
+
+ if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED:
+ paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS")
+ paid_quota = PaidHostingQuota(restrict_models=paid_models)
+ quotas.append(paid_quota)
+
+ if len(quotas) > 0:
+ credentials = {
+ "api_key": dify_config.HOSTED_DEEPSEEK_API_KEY,
+ }
+
+ if dify_config.HOSTED_DEEPSEEK_API_BASE:
+ credentials["endpoint_url"] = dify_config.HOSTED_DEEPSEEK_API_BASE
+
+ return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
+
+ return HostingProvider(
+ enabled=False,
+ quota_unit=quota_unit,
+ )
+
@staticmethod
def init_minimax() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index b4c3ec1caf..be1e306d47 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -71,8 +71,8 @@ class LLMGenerator:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
)
- answer = cast(str, response.message.content)
- if answer is None:
+ answer = response.message.get_text_content()
+ if answer == "":
return ""
try:
result_dict = json.loads(answer)
@@ -184,7 +184,7 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
- rule_config["prompt"] = cast(str, response.message.content)
+ rule_config["prompt"] = response.message.get_text_content()
except InvokeError as e:
error = str(e)
@@ -237,13 +237,11 @@ class LLMGenerator:
return rule_config
- rule_config["prompt"] = cast(str, prompt_content.message.content)
+ rule_config["prompt"] = prompt_content.message.get_text_content()
- if not isinstance(prompt_content.message.content, str):
- raise NotImplementedError("prompt content is not a string")
parameter_generate_prompt = parameter_template.format(
inputs={
- "INPUT_TEXT": prompt_content.message.content,
+ "INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
)
@@ -253,7 +251,7 @@ class LLMGenerator:
statement_generate_prompt = statement_template.format(
inputs={
"TASK_DESCRIPTION": instruction,
- "INPUT_TEXT": prompt_content.message.content,
+ "INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
)
@@ -263,7 +261,7 @@ class LLMGenerator:
parameter_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
)
- rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
+ rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content())
except InvokeError as e:
error = str(e)
error_step = "generate variables"
@@ -272,7 +270,7 @@ class LLMGenerator:
statement_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
)
- rule_config["opening_statement"] = cast(str, statement_content.message.content)
+ rule_config["opening_statement"] = statement_content.message.get_text_content()
except InvokeError as e:
error = str(e)
error_step = "generate conversation opener"
@@ -315,7 +313,7 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
- generated_code = cast(str, response.message.content)
+ generated_code = response.message.get_text_content()
return {"code": generated_code, "language": code_language, "error": ""}
except InvokeError as e:
@@ -351,7 +349,7 @@ class LLMGenerator:
raise TypeError("Expected LLMResult when stream=False")
response = result
- answer = cast(str, response.message.content)
+ answer = response.message.get_text_content()
return answer.strip()
@classmethod
@@ -375,10 +373,7 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
- raw_content = response.message.content
-
- if not isinstance(raw_content, str):
- raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}")
+ raw_content = response.message.get_text_content()
try:
parsed_content = json.loads(raw_content)
diff --git a/api/core/logging/__init__.py b/api/core/logging/__init__.py
new file mode 100644
index 0000000000..db046cc9fa
--- /dev/null
+++ b/api/core/logging/__init__.py
@@ -0,0 +1,20 @@
+"""Structured logging components for Dify."""
+
+from core.logging.context import (
+ clear_request_context,
+ get_request_id,
+ get_trace_id,
+ init_request_context,
+)
+from core.logging.filters import IdentityContextFilter, TraceContextFilter
+from core.logging.structured_formatter import StructuredJSONFormatter
+
+__all__ = [
+ "IdentityContextFilter",
+ "StructuredJSONFormatter",
+ "TraceContextFilter",
+ "clear_request_context",
+ "get_request_id",
+ "get_trace_id",
+ "init_request_context",
+]
diff --git a/api/core/logging/context.py b/api/core/logging/context.py
new file mode 100644
index 0000000000..18633a0b05
--- /dev/null
+++ b/api/core/logging/context.py
@@ -0,0 +1,35 @@
+"""Request context for logging - framework agnostic.
+
+This module provides request-scoped context variables for logging,
+using Python's contextvars for thread-safe and async-safe storage.
+"""
+
+import uuid
+from contextvars import ContextVar
+
+_request_id: ContextVar[str] = ContextVar("log_request_id", default="")
+_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="")
+
+
+def get_request_id() -> str:
+ """Get current request ID (10 hex chars)."""
+ return _request_id.get()
+
+
+def get_trace_id() -> str:
+ """Get fallback trace ID when OTEL is unavailable (32 hex chars)."""
+ return _trace_id.get()
+
+
+def init_request_context() -> None:
+ """Initialize request context. Call at start of each request."""
+ req_id = uuid.uuid4().hex[:10]
+ trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex
+ _request_id.set(req_id)
+ _trace_id.set(trace_id)
+
+
+def clear_request_context() -> None:
+ """Clear request context. Call at end of request (optional)."""
+ _request_id.set("")
+ _trace_id.set("")
diff --git a/api/core/logging/filters.py b/api/core/logging/filters.py
new file mode 100644
index 0000000000..1e8aa8d566
--- /dev/null
+++ b/api/core/logging/filters.py
@@ -0,0 +1,94 @@
+"""Logging filters for structured logging."""
+
+import contextlib
+import logging
+
+import flask
+
+from core.logging.context import get_request_id, get_trace_id
+
+
+class TraceContextFilter(logging.Filter):
+ """
+ Filter that adds trace_id and span_id to log records.
+ Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
+ """
+
+ def filter(self, record: logging.LogRecord) -> bool:
+ # Get trace context from OpenTelemetry
+ trace_id, span_id = self._get_otel_context()
+
+ # Set trace_id (fallback to ContextVar if no OTEL context)
+ if trace_id:
+ record.trace_id = trace_id
+ else:
+ record.trace_id = get_trace_id()
+
+ record.span_id = span_id or ""
+
+ # For backward compatibility, also set req_id
+ record.req_id = get_request_id()
+
+ return True
+
+ def _get_otel_context(self) -> tuple[str, str]:
+ """Extract trace_id and span_id from OpenTelemetry context."""
+ with contextlib.suppress(Exception):
+ from opentelemetry.trace import get_current_span
+ from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
+
+ span = get_current_span()
+ if span and span.get_span_context():
+ ctx = span.get_span_context()
+ if ctx.is_valid and ctx.trace_id != INVALID_TRACE_ID:
+ trace_id = f"{ctx.trace_id:032x}"
+ span_id = f"{ctx.span_id:016x}" if ctx.span_id != INVALID_SPAN_ID else ""
+ return trace_id, span_id
+ return "", ""
+
+
+class IdentityContextFilter(logging.Filter):
+ """
+ Filter that adds user identity context to log records.
+ Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
+ """
+
+ def filter(self, record: logging.LogRecord) -> bool:
+ identity = self._extract_identity()
+ record.tenant_id = identity.get("tenant_id", "")
+ record.user_id = identity.get("user_id", "")
+ record.user_type = identity.get("user_type", "")
+ return True
+
+ def _extract_identity(self) -> dict[str, str]:
+ """Extract identity from current_user if in request context."""
+ try:
+ if not flask.has_request_context():
+ return {}
+ from flask_login import current_user
+
+ # Check if user is authenticated using the proxy
+ if not current_user.is_authenticated:
+ return {}
+
+ # Access the underlying user object
+ user = current_user
+
+ from models import Account
+ from models.model import EndUser
+
+ identity: dict[str, str] = {}
+
+ if isinstance(user, Account):
+ if user.current_tenant_id:
+ identity["tenant_id"] = user.current_tenant_id
+ identity["user_id"] = user.id
+ identity["user_type"] = "account"
+ elif isinstance(user, EndUser):
+ identity["tenant_id"] = user.tenant_id
+ identity["user_id"] = user.id
+ identity["user_type"] = user.type or "end_user"
+
+ return identity
+ except Exception:
+ return {}
diff --git a/api/core/logging/structured_formatter.py b/api/core/logging/structured_formatter.py
new file mode 100644
index 0000000000..4295d2dd34
--- /dev/null
+++ b/api/core/logging/structured_formatter.py
@@ -0,0 +1,107 @@
+"""Structured JSON log formatter for Dify."""
+
+import logging
+import traceback
+from datetime import UTC, datetime
+from typing import Any
+
+import orjson
+
+from configs import dify_config
+
+
+class StructuredJSONFormatter(logging.Formatter):
+ """
+ JSON log formatter following the specified schema:
+ {
+ "ts": "ISO 8601 UTC",
+ "severity": "INFO|ERROR|WARN|DEBUG",
+ "service": "service name",
+ "caller": "file:line",
+ "trace_id": "hex 32",
+ "span_id": "hex 16",
+ "identity": { "tenant_id", "user_id", "user_type" },
+ "message": "log message",
+ "attributes": { ... },
+ "stack_trace": "..."
+ }
+ """
+
+ SEVERITY_MAP: dict[int, str] = {
+ logging.DEBUG: "DEBUG",
+ logging.INFO: "INFO",
+ logging.WARNING: "WARN",
+ logging.ERROR: "ERROR",
+ logging.CRITICAL: "ERROR",
+ }
+
+ def __init__(self, service_name: str | None = None):
+ super().__init__()
+ self._service_name = service_name or dify_config.APPLICATION_NAME
+
+ def format(self, record: logging.LogRecord) -> str:
+ log_dict = self._build_log_dict(record)
+ try:
+ return orjson.dumps(log_dict).decode("utf-8")
+ except TypeError:
+ # Fallback: convert non-serializable objects to string
+ import json
+
+ return json.dumps(log_dict, default=str, ensure_ascii=False)
+
+ def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
+ # Core fields
+ log_dict: dict[str, Any] = {
+ "ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
+ "severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
+ "service": self._service_name,
+ "caller": f"{record.filename}:{record.lineno}",
+ "message": record.getMessage(),
+ }
+
+ # Trace context (from TraceContextFilter)
+ trace_id = getattr(record, "trace_id", "")
+ span_id = getattr(record, "span_id", "")
+
+ if trace_id:
+ log_dict["trace_id"] = trace_id
+ if span_id:
+ log_dict["span_id"] = span_id
+
+ # Identity context (from IdentityContextFilter)
+ identity = self._extract_identity(record)
+ if identity:
+ log_dict["identity"] = identity
+
+ # Dynamic attributes
+ attributes = getattr(record, "attributes", None)
+ if attributes:
+ log_dict["attributes"] = attributes
+
+ # Stack trace for errors with exceptions
+ if record.exc_info and record.levelno >= logging.ERROR:
+ log_dict["stack_trace"] = self._format_exception(record.exc_info)
+
+ return log_dict
+
+ def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
+ tenant_id = getattr(record, "tenant_id", None)
+ user_id = getattr(record, "user_id", None)
+ user_type = getattr(record, "user_type", None)
+
+ if not any([tenant_id, user_id, user_type]):
+ return None
+
+ identity: dict[str, str] = {}
+ if tenant_id:
+ identity["tenant_id"] = tenant_id
+ if user_id:
+ identity["user_id"] = user_id
+ if user_type:
+ identity["user_type"] = user_type
+ return identity
+
+ def _format_exception(self, exc_info: tuple[Any, ...]) -> str:
+ if exc_info and exc_info[0] is not None:
+ return "".join(traceback.format_exception(*exc_info))
+ return ""
diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py
index f81e7cead8..5c3cd0d8f8 100644
--- a/api/core/mcp/client/streamable_client.py
+++ b/api/core/mcp/client/streamable_client.py
@@ -313,17 +313,20 @@ class StreamableHTTPTransport:
if is_initialization:
self._maybe_extract_session_id_from_response(response)
- content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
+ # Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
+ # The server MUST NOT send a response to notifications.
+ if isinstance(message.root, JSONRPCRequest):
+ content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
- if content_type.startswith(JSON):
- self._handle_json_response(response, ctx.server_to_client_queue)
- elif content_type.startswith(SSE):
- self._handle_sse_response(response, ctx)
- else:
- self._handle_unexpected_content_type(
- content_type,
- ctx.server_to_client_queue,
- )
+ if content_type.startswith(JSON):
+ self._handle_json_response(response, ctx.server_to_client_queue)
+ elif content_type.startswith(SSE):
+ self._handle_sse_response(response, ctx)
+ else:
+ self._handle_unexpected_content_type(
+ content_type,
+ ctx.server_to_client_queue,
+ )
def _handle_json_response(
self,
diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py
index c97ae6eac7..84a6fd0d1f 100644
--- a/api/core/mcp/session/base_session.py
+++ b/api/core/mcp/session/base_session.py
@@ -68,13 +68,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
request_id: RequestId,
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
- session: """BaseSession[
- SendRequestT,
- SendNotificationT,
- SendResultT,
- ReceiveRequestT,
- ReceiveNotificationT
- ]""",
+ session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
):
self.request_id = request_id
diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py
index 89dae2dbff..9e46d72893 100644
--- a/api/core/model_runtime/entities/message_entities.py
+++ b/api/core/model_runtime/entities/message_entities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from abc import ABC
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
@@ -17,7 +19,7 @@ class PromptMessageRole(StrEnum):
TOOL = auto()
@classmethod
- def value_of(cls, value: str) -> "PromptMessageRole":
+ def value_of(cls, value: str) -> PromptMessageRole:
"""
Get value of given mode.
@@ -249,10 +251,7 @@ class AssistantPromptMessage(PromptMessage):
:return: True if prompt message is empty, False otherwise
"""
- if not super().is_empty() and not self.tool_calls:
- return False
-
- return True
+ return super().is_empty() and not self.tool_calls
class SystemPromptMessage(PromptMessage):
diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py
index aee6ce1108..19194d162c 100644
--- a/api/core/model_runtime/entities/model_entities.py
+++ b/api/core/model_runtime/entities/model_entities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from decimal import Decimal
from enum import StrEnum, auto
from typing import Any
@@ -20,7 +22,7 @@ class ModelType(StrEnum):
TTS = auto()
@classmethod
- def value_of(cls, origin_model_type: str) -> "ModelType":
+ def value_of(cls, origin_model_type: str) -> ModelType:
"""
Get model type from origin model type.
@@ -103,7 +105,7 @@ class DefaultParameterName(StrEnum):
JSON_SCHEMA = auto()
@classmethod
- def value_of(cls, value: Any) -> "DefaultParameterName":
+ def value_of(cls, value: Any) -> DefaultParameterName:
"""
Get parameter name from value.
diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py
index 648b209ef1..2d88751668 100644
--- a/api/core/model_runtime/entities/provider_entities.py
+++ b/api/core/model_runtime/entities/provider_entities.py
@@ -100,7 +100,6 @@ class SimpleProviderEntity(BaseModel):
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
- icon_large: I18nObject | None = None
supported_model_types: Sequence[ModelType]
models: list[AIModelEntity] = []
@@ -123,7 +122,6 @@ class ProviderEntity(BaseModel):
label: I18nObject
description: I18nObject | None = None
icon_small: I18nObject | None = None
- icon_large: I18nObject | None = None
icon_small_dark: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
@@ -157,7 +155,6 @@ class ProviderEntity(BaseModel):
provider=self.provider,
label=self.label,
icon_small=self.icon_small,
- icon_large=self.icon_large,
supported_model_types=self.supported_model_types,
models=self.models,
)
diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index c0f4c504d9..7a0757f219 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -1,7 +1,7 @@
import logging
import time
import uuid
-from collections.abc import Generator, Sequence
+from collections.abc import Callable, Generator, Iterator, Sequence
from typing import Union
from pydantic import ConfigDict
@@ -30,6 +30,142 @@ def _gen_tool_call_id() -> str:
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
+def _run_callbacks(callbacks: Sequence[Callback] | None, *, event: str, invoke: Callable[[Callback], None]) -> None:
+ if not callbacks:
+ return
+
+ for callback in callbacks:
+ try:
+ invoke(callback)
+ except Exception as e:
+ if callback.raise_error:
+ raise
+ logger.warning("Callback %s %s failed with error %s", callback.__class__.__name__, event, e)
+
+
+def _get_or_create_tool_call(
+ existing_tools_calls: list[AssistantPromptMessage.ToolCall],
+ tool_call_id: str,
+) -> AssistantPromptMessage.ToolCall:
+ """
+ Get or create a tool call by ID.
+
+ If `tool_call_id` is empty, returns the most recently created tool call.
+ """
+ if not tool_call_id:
+ if not existing_tools_calls:
+ raise ValueError("tool_call_id is empty but no existing tool call is available to apply the delta")
+ return existing_tools_calls[-1]
+
+ tool_call = next((tool_call for tool_call in existing_tools_calls if tool_call.id == tool_call_id), None)
+ if tool_call is None:
+ tool_call = AssistantPromptMessage.ToolCall(
+ id=tool_call_id,
+ type="function",
+ function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
+ )
+ existing_tools_calls.append(tool_call)
+
+ return tool_call
+
+
+def _merge_tool_call_delta(
+ tool_call: AssistantPromptMessage.ToolCall,
+ delta: AssistantPromptMessage.ToolCall,
+) -> None:
+ if delta.id:
+ tool_call.id = delta.id
+ if delta.type:
+ tool_call.type = delta.type
+ if delta.function.name:
+ tool_call.function.name = delta.function.name
+ if delta.function.arguments:
+ tool_call.function.arguments += delta.function.arguments
+
+
+def _build_llm_result_from_first_chunk(
+ model: str,
+ prompt_messages: Sequence[PromptMessage],
+ chunks: Iterator[LLMResultChunk],
+) -> LLMResult:
+ """
+ Build a single `LLMResult` from the first returned chunk.
+
+ This is used for `stream=False` because the plugin side may still implement the response via a chunked stream.
+ """
+ content = ""
+ content_list: list[PromptMessageContentUnionTypes] = []
+ usage = LLMUsage.empty_usage()
+ system_fingerprint: str | None = None
+ tools_calls: list[AssistantPromptMessage.ToolCall] = []
+
+ first_chunk = next(chunks, None)
+ if first_chunk is not None:
+ if isinstance(first_chunk.delta.message.content, str):
+ content += first_chunk.delta.message.content
+ elif isinstance(first_chunk.delta.message.content, list):
+ content_list.extend(first_chunk.delta.message.content)
+
+ if first_chunk.delta.message.tool_calls:
+ _increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls)
+
+ usage = first_chunk.delta.usage or LLMUsage.empty_usage()
+ system_fingerprint = first_chunk.system_fingerprint
+
+ return LLMResult(
+ model=model,
+ prompt_messages=prompt_messages,
+ message=AssistantPromptMessage(
+ content=content or content_list,
+ tool_calls=tools_calls,
+ ),
+ usage=usage,
+ system_fingerprint=system_fingerprint,
+ )
+
+
+def _invoke_llm_via_plugin(
+ *,
+ tenant_id: str,
+ user_id: str,
+ plugin_id: str,
+ provider: str,
+ model: str,
+ credentials: dict,
+ model_parameters: dict,
+ prompt_messages: Sequence[PromptMessage],
+ tools: list[PromptMessageTool] | None,
+ stop: Sequence[str] | None,
+ stream: bool,
+) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
+ from core.plugin.impl.model import PluginModelClient
+
+ plugin_model_manager = PluginModelClient()
+ return plugin_model_manager.invoke_llm(
+ tenant_id=tenant_id,
+ user_id=user_id,
+ plugin_id=plugin_id,
+ provider=provider,
+ model=model,
+ credentials=credentials,
+ model_parameters=model_parameters,
+ prompt_messages=list(prompt_messages),
+ tools=tools,
+ stop=list(stop) if stop else None,
+ stream=stream,
+ )
+
+
+def _normalize_non_stream_plugin_result(
+ model: str,
+ prompt_messages: Sequence[PromptMessage],
+ result: Union[LLMResult, Iterator[LLMResultChunk]],
+) -> LLMResult:
+ if isinstance(result, LLMResult):
+ return result
+ return _build_llm_result_from_first_chunk(model=model, prompt_messages=prompt_messages, chunks=result)
+
+
def _increase_tool_call(
new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
):
@@ -40,42 +176,13 @@ def _increase_tool_call(
:param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
"""
- def get_tool_call(tool_call_id: str):
- """
- Get or create a tool call by ID
-
- :param tool_call_id: tool call ID
- :return: existing or new tool call
- """
- if not tool_call_id:
- return existing_tools_calls[-1]
-
- _tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None)
- if _tool_call is None:
- _tool_call = AssistantPromptMessage.ToolCall(
- id=tool_call_id,
- type="function",
- function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
- )
- existing_tools_calls.append(_tool_call)
-
- return _tool_call
-
for new_tool_call in new_tool_calls:
# generate ID for tool calls with function name but no ID to track them
if new_tool_call.function.name and not new_tool_call.id:
new_tool_call.id = _gen_tool_call_id()
- # get tool call
- tool_call = get_tool_call(new_tool_call.id)
- # update tool call
- if new_tool_call.id:
- tool_call.id = new_tool_call.id
- if new_tool_call.type:
- tool_call.type = new_tool_call.type
- if new_tool_call.function.name:
- tool_call.function.name = new_tool_call.function.name
- if new_tool_call.function.arguments:
- tool_call.function.arguments += new_tool_call.function.arguments
+
+ tool_call = _get_or_create_tool_call(existing_tools_calls, new_tool_call.id)
+ _merge_tool_call_delta(tool_call, new_tool_call)
class LargeLanguageModel(AIModel):
@@ -141,10 +248,7 @@ class LargeLanguageModel(AIModel):
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
try:
- from core.plugin.impl.model import PluginModelClient
-
- plugin_model_manager = PluginModelClient()
- result = plugin_model_manager.invoke_llm(
+ result = _invoke_llm_via_plugin(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
@@ -154,38 +258,13 @@ class LargeLanguageModel(AIModel):
model_parameters=model_parameters,
prompt_messages=prompt_messages,
tools=tools,
- stop=list(stop) if stop else None,
+ stop=stop,
stream=stream,
)
if not stream:
- content = ""
- content_list = []
- usage = LLMUsage.empty_usage()
- system_fingerprint = None
- tools_calls: list[AssistantPromptMessage.ToolCall] = []
-
- for chunk in result:
- if isinstance(chunk.delta.message.content, str):
- content += chunk.delta.message.content
- elif isinstance(chunk.delta.message.content, list):
- content_list.extend(chunk.delta.message.content)
- if chunk.delta.message.tool_calls:
- _increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
-
- usage = chunk.delta.usage or LLMUsage.empty_usage()
- system_fingerprint = chunk.system_fingerprint
- break
-
- result = LLMResult(
- model=model,
- prompt_messages=prompt_messages,
- message=AssistantPromptMessage(
- content=content or content_list,
- tool_calls=tools_calls,
- ),
- usage=usage,
- system_fingerprint=system_fingerprint,
+ result = _normalize_non_stream_plugin_result(
+ model=model, prompt_messages=prompt_messages, result=result
)
except Exception as e:
self._trigger_invoke_error_callbacks(
@@ -425,27 +504,21 @@ class LargeLanguageModel(AIModel):
:param user: unique user id
:param callbacks: callbacks
"""
- if callbacks:
- for callback in callbacks:
- try:
- callback.on_before_invoke(
- llm_instance=self,
- model=model,
- credentials=credentials,
- prompt_messages=prompt_messages,
- model_parameters=model_parameters,
- tools=tools,
- stop=stop,
- stream=stream,
- user=user,
- )
- except Exception as e:
- if callback.raise_error:
- raise e
- else:
- logger.warning(
- "Callback %s on_before_invoke failed with error %s", callback.__class__.__name__, e
- )
+ _run_callbacks(
+ callbacks,
+ event="on_before_invoke",
+ invoke=lambda callback: callback.on_before_invoke(
+ llm_instance=self,
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ ),
+ )
def _trigger_new_chunk_callbacks(
self,
@@ -473,26 +546,22 @@ class LargeLanguageModel(AIModel):
:param stream: is stream response
:param user: unique user id
"""
- if callbacks:
- for callback in callbacks:
- try:
- callback.on_new_chunk(
- llm_instance=self,
- chunk=chunk,
- model=model,
- credentials=credentials,
- prompt_messages=prompt_messages,
- model_parameters=model_parameters,
- tools=tools,
- stop=stop,
- stream=stream,
- user=user,
- )
- except Exception as e:
- if callback.raise_error:
- raise e
- else:
- logger.warning("Callback %s on_new_chunk failed with error %s", callback.__class__.__name__, e)
+ _run_callbacks(
+ callbacks,
+ event="on_new_chunk",
+ invoke=lambda callback: callback.on_new_chunk(
+ llm_instance=self,
+ chunk=chunk,
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ ),
+ )
def _trigger_after_invoke_callbacks(
self,
@@ -521,28 +590,22 @@ class LargeLanguageModel(AIModel):
:param user: unique user id
:param callbacks: callbacks
"""
- if callbacks:
- for callback in callbacks:
- try:
- callback.on_after_invoke(
- llm_instance=self,
- result=result,
- model=model,
- credentials=credentials,
- prompt_messages=prompt_messages,
- model_parameters=model_parameters,
- tools=tools,
- stop=stop,
- stream=stream,
- user=user,
- )
- except Exception as e:
- if callback.raise_error:
- raise e
- else:
- logger.warning(
- "Callback %s on_after_invoke failed with error %s", callback.__class__.__name__, e
- )
+ _run_callbacks(
+ callbacks,
+ event="on_after_invoke",
+ invoke=lambda callback: callback.on_after_invoke(
+ llm_instance=self,
+ result=result,
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ ),
+ )
def _trigger_invoke_error_callbacks(
self,
@@ -571,25 +634,19 @@ class LargeLanguageModel(AIModel):
:param user: unique user id
:param callbacks: callbacks
"""
- if callbacks:
- for callback in callbacks:
- try:
- callback.on_invoke_error(
- llm_instance=self,
- ex=ex,
- model=model,
- credentials=credentials,
- prompt_messages=prompt_messages,
- model_parameters=model_parameters,
- tools=tools,
- stop=stop,
- stream=stream,
- user=user,
- )
- except Exception as e:
- if callback.raise_error:
- raise e
- else:
- logger.warning(
- "Callback %s on_invoke_error failed with error %s", callback.__class__.__name__, e
- )
+ _run_callbacks(
+ callbacks,
+ event="on_invoke_error",
+ invoke=lambda callback: callback.on_invoke_error(
+ llm_instance=self,
+ ex=ex,
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ ),
+ )
diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py
index b8704ef4ed..28f162a928 100644
--- a/api/core/model_runtime/model_providers/model_provider_factory.py
+++ b/api/core/model_runtime/model_providers/model_provider_factory.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import hashlib
import logging
from collections.abc import Sequence
@@ -38,7 +40,7 @@ class ModelProviderFactory:
plugin_providers = self.get_plugin_model_providers()
return [provider.declaration for provider in plugin_providers]
- def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
+ def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
"""
Get all plugin model providers
:return: list of plugin model providers
@@ -76,7 +78,7 @@ class ModelProviderFactory:
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
return plugin_model_provider_entity.declaration
- def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity":
+ def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
"""
Get plugin model provider
:param provider: provider name
@@ -285,7 +287,7 @@ class ModelProviderFactory:
"""
Get provider icon
:param provider: provider name
- :param icon_type: icon type (icon_small or icon_large)
+ :param icon_type: icon type (icon_small or icon_small_dark)
:param lang: language (zh_Hans or en_US)
:return: provider icon
"""
@@ -309,13 +311,7 @@ class ModelProviderFactory:
else:
file_name = provider_schema.icon_small_dark.en_US
else:
- if not provider_schema.icon_large:
- raise ValueError(f"Provider {provider} does not have large icon.")
-
- if lang.lower() == "zh_hans":
- file_name = provider_schema.icon_large.zh_Hans
- else:
- file_name = provider_schema.icon_large.en_US
+ raise ValueError(f"Unsupported icon type: {icon_type}.")
if not file_name:
raise ValueError(f"Provider {provider} does not have icon.")
diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py
index d6bd4d2015..22ad756c91 100644
--- a/api/core/ops/aliyun_trace/aliyun_trace.py
+++ b/api/core/ops/aliyun_trace/aliyun_trace.py
@@ -1,6 +1,7 @@
import logging
from collections.abc import Sequence
+from opentelemetry.trace import SpanKind
from sqlalchemy.orm import sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
@@ -54,7 +55,7 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
@@ -151,6 +152,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
+ span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@@ -273,7 +275,7 @@ class AliyunDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
session_factory = sessionmaker(bind=db.engine)
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,
@@ -456,6 +458,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
+ span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@@ -475,6 +478,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
+ span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
)
self.trace_client.add_span(workflow_span)
diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py
index d3324f8f82..7624586367 100644
--- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py
+++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py
@@ -166,7 +166,7 @@ class SpanBuilder:
attributes=span_data.attributes,
events=span_data.events,
links=span_data.links,
- kind=trace_api.SpanKind.INTERNAL,
+ kind=span_data.span_kind,
status=span_data.status,
start_time=span_data.start_time,
end_time=span_data.end_time,
diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
index 20ff2d0875..9078031490 100644
--- a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
+++ b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
@@ -4,7 +4,7 @@ from typing import Any
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
-from opentelemetry.trace import Status, StatusCode
+from opentelemetry.trace import SpanKind, Status, StatusCode
from pydantic import BaseModel, Field
@@ -34,3 +34,4 @@ class SpanData(BaseModel):
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
+ span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.")
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index c36391c940..549e428f88 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -471,6 +471,9 @@ class TraceTask:
if cls._workflow_run_repo is None:
with cls._repo_lock:
if cls._workflow_run_repo is None:
+ # Lazy import to avoid circular import during module initialization
+ from repositories.factory import DifyAPIRepositoryFactory
+
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return cls._workflow_run_repo
diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py
index 3b83121357..6674228dc0 100644
--- a/api/core/plugin/entities/plugin_daemon.py
+++ b/api/core/plugin/entities/plugin_daemon.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import enum
from collections.abc import Mapping, Sequence
from datetime import datetime
@@ -242,7 +244,7 @@ class CredentialType(enum.StrEnum):
return [item.value for item in cls]
@classmethod
- def of(cls, credential_type: str) -> "CredentialType":
+ def of(cls, credential_type: str) -> CredentialType:
type_name = credential_type.lower()
if type_name in {"api-key", "api_key"}:
return cls.API_KEY
diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py
index 7bb2749afa..7a6a598a2f 100644
--- a/api/core/plugin/impl/base.py
+++ b/api/core/plugin/impl/base.py
@@ -103,6 +103,9 @@ class BasePluginClient:
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
+ # Inject traceparent header for distributed tracing
+ self._inject_trace_headers(prepared_headers)
+
prepared_data: bytes | dict[str, Any] | str | None = (
data if isinstance(data, (bytes, str, dict)) or data is None else None
)
@@ -114,6 +117,31 @@ class BasePluginClient:
return str(url), prepared_headers, prepared_data, params, files
+ def _inject_trace_headers(self, headers: dict[str, str]) -> None:
+ """
+ Inject W3C traceparent header for distributed tracing.
+
+ This ensures trace context is propagated to plugin daemon even if
+ HTTPXClientInstrumentor doesn't cover module-level httpx functions.
+ """
+ if not dify_config.ENABLE_OTEL:
+ return
+
+ import contextlib
+
+ # Skip if already present (case-insensitive check)
+ for key in headers:
+ if key.lower() == "traceparent":
+ return
+
+ # Inject traceparent - works as fallback when OTEL instrumentation doesn't cover this call
+ with contextlib.suppress(Exception):
+ from core.helper.trace_id_helper import generate_traceparent_header
+
+ traceparent = generate_traceparent_header()
+ if traceparent:
+ headers["traceparent"] = traceparent
+
def _stream_request(
self,
method: str,
@@ -292,18 +320,17 @@ class BasePluginClient:
case PluginInvokeError.__name__:
error_object = json.loads(message)
invoke_error_type = error_object.get("error_type")
- args = error_object.get("args")
match invoke_error_type:
case InvokeRateLimitError.__name__:
- raise InvokeRateLimitError(description=args.get("description"))
+ raise InvokeRateLimitError(description=error_object.get("message"))
case InvokeAuthorizationError.__name__:
- raise InvokeAuthorizationError(description=args.get("description"))
+ raise InvokeAuthorizationError(description=error_object.get("message"))
case InvokeBadRequestError.__name__:
- raise InvokeBadRequestError(description=args.get("description"))
+ raise InvokeBadRequestError(description=error_object.get("message"))
case InvokeConnectionError.__name__:
- raise InvokeConnectionError(description=args.get("description"))
+ raise InvokeConnectionError(description=error_object.get("message"))
case InvokeServerUnavailableError.__name__:
- raise InvokeServerUnavailableError(description=args.get("description"))
+ raise InvokeServerUnavailableError(description=error_object.get("message"))
case CredentialsValidateFailedError.__name__:
raise CredentialsValidateFailedError(error_object.get("message"))
case EndpointSetupFailedError.__name__:
@@ -311,11 +338,11 @@ class BasePluginClient:
case TriggerProviderCredentialValidationError.__name__:
raise TriggerProviderCredentialValidationError(error_object.get("message"))
case TriggerPluginInvokeError.__name__:
- raise TriggerPluginInvokeError(description=error_object.get("description"))
+ raise TriggerPluginInvokeError(description=error_object.get("message"))
case TriggerInvokeError.__name__:
raise TriggerInvokeError(error_object.get("message"))
case EventIgnoreError.__name__:
- raise EventIgnoreError(description=error_object.get("description"))
+ raise EventIgnoreError(description=error_object.get("message"))
case _:
raise PluginInvokeError(description=message)
case PluginDaemonInternalServerError.__name__:
diff --git a/api/core/plugin/impl/endpoint.py b/api/core/plugin/impl/endpoint.py
index 5b88742be5..2db5185a2c 100644
--- a/api/core/plugin/impl/endpoint.py
+++ b/api/core/plugin/impl/endpoint.py
@@ -1,5 +1,6 @@
from core.plugin.entities.endpoint import EndpointEntityWithInstance
from core.plugin.impl.base import BasePluginClient
+from core.plugin.impl.exc import PluginDaemonInternalServerError
class PluginEndpointClient(BasePluginClient):
@@ -70,18 +71,27 @@ class PluginEndpointClient(BasePluginClient):
def delete_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
"""
Delete the given endpoint.
+
+ This operation is idempotent: if the endpoint is already deleted (record not found),
+ it will return True instead of raising an error.
"""
- return self._request_with_plugin_daemon_response(
- "POST",
- f"plugin/{tenant_id}/endpoint/remove",
- bool,
- data={
- "endpoint_id": endpoint_id,
- },
- headers={
- "Content-Type": "application/json",
- },
- )
+ try:
+ return self._request_with_plugin_daemon_response(
+ "POST",
+ f"plugin/{tenant_id}/endpoint/remove",
+ bool,
+ data={
+ "endpoint_id": endpoint_id,
+ },
+ headers={
+ "Content-Type": "application/json",
+ },
+ )
+ except PluginDaemonInternalServerError as e:
+ # Make delete idempotent: if record is not found, consider it a success
+ if "record not found" in str(e.description).lower():
+ return True
+ raise
def enable_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
"""
diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py
index 6c818bdc8b..fdbfca4330 100644
--- a/api/core/provider_manager.py
+++ b/api/core/provider_manager.py
@@ -331,7 +331,6 @@ class ProviderManager:
provider=provider_schema.provider,
label=provider_schema.label,
icon_small=provider_schema.icon_small,
- icon_large=provider_schema.icon_large,
supported_model_types=provider_schema.supported_model_types,
),
)
@@ -619,18 +618,18 @@ class ProviderManager:
)
for quota in configuration.quotas:
- if quota.quota_type == ProviderQuotaType.TRIAL:
+ if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
# Init trial provider records if not exists
- if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
+ if quota.quota_type not in provider_quota_to_provider_record_dict:
try:
# FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic
new_provider_record = Provider(
tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
- provider_type=ProviderType.SYSTEM,
- quota_type=ProviderQuotaType.TRIAL,
- quota_limit=quota.quota_limit, # type: ignore
+ provider_type=ProviderType.SYSTEM.value,
+ quota_type=quota.quota_type,
+ quota_limit=0, # type: ignore
quota_used=0,
is_valid=True,
)
@@ -642,8 +641,8 @@ class ProviderManager:
stmt = select(Provider).where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
- Provider.provider_type == ProviderType.SYSTEM,
- Provider.quota_type == ProviderQuotaType.TRIAL,
+ Provider.provider_type == ProviderType.SYSTEM.value,
+ Provider.quota_type == quota.quota_type,
)
existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record:
@@ -913,6 +912,22 @@ class ProviderManager:
provider_record
)
quota_configurations = []
+
+ if dify_config.EDITION == "CLOUD":
+ from services.credit_pool_service import CreditPoolService
+
+ trail_pool = CreditPoolService.get_pool(
+ tenant_id=tenant_id,
+ pool_type=ProviderQuotaType.TRIAL.value,
+ )
+ paid_pool = CreditPoolService.get_pool(
+ tenant_id=tenant_id,
+ pool_type=ProviderQuotaType.PAID.value,
+ )
+ else:
+ trail_pool = None
+ paid_pool = None
+
for provider_quota in provider_hosting_configuration.quotas:
if provider_quota.quota_type not in quota_type_to_provider_records_dict:
if provider_quota.quota_type == ProviderQuotaType.FREE:
@@ -933,16 +948,36 @@ class ProviderManager:
raise ValueError("quota_used is None")
if provider_record.quota_limit is None:
raise ValueError("quota_limit is None")
+ if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None:
+ quota_configuration = QuotaConfiguration(
+ quota_type=provider_quota.quota_type,
+ quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
+ quota_used=trail_pool.quota_used,
+ quota_limit=trail_pool.quota_limit,
+ is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
+ restrict_models=provider_quota.restrict_models,
+ )
- quota_configuration = QuotaConfiguration(
- quota_type=provider_quota.quota_type,
- quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
- quota_used=provider_record.quota_used,
- quota_limit=provider_record.quota_limit,
- is_valid=provider_record.quota_limit > provider_record.quota_used
- or provider_record.quota_limit == -1,
- restrict_models=provider_quota.restrict_models,
- )
+ elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
+ quota_configuration = QuotaConfiguration(
+ quota_type=provider_quota.quota_type,
+ quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
+ quota_used=paid_pool.quota_used,
+ quota_limit=paid_pool.quota_limit,
+ is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
+ restrict_models=provider_quota.restrict_models,
+ )
+
+ else:
+ quota_configuration = QuotaConfiguration(
+ quota_type=provider_quota.quota_type,
+ quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
+ quota_used=provider_record.quota_used,
+ quota_limit=provider_record.quota_limit,
+ is_valid=provider_record.quota_limit > provider_record.quota_used
+ or provider_record.quota_limit == -1,
+ restrict_models=provider_quota.restrict_models,
+ )
quota_configurations.append(quota_configuration)
diff --git a/api/core/rag/cleaner/clean_processor.py b/api/core/rag/cleaner/clean_processor.py
index 9cb009035b..e182c35b99 100644
--- a/api/core/rag/cleaner/clean_processor.py
+++ b/api/core/rag/cleaner/clean_processor.py
@@ -27,26 +27,44 @@ class CleanProcessor:
pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
text = re.sub(pattern, "", text)
- # Remove URL but keep Markdown image URLs
- # First, temporarily replace Markdown image URLs with a placeholder
- markdown_image_pattern = r"!\[.*?\]\((https?://[^\s)]+)\)"
- placeholders: list[str] = []
+ # Remove URL but keep Markdown image URLs and link URLs
+ # Replace the ENTIRE markdown link/image with a single placeholder to protect
+ # the link text (which might also be a URL) from being removed
+ markdown_link_pattern = r"\[([^\]]*)\]\((https?://[^)]+)\)"
+ markdown_image_pattern = r"!\[.*?\]\((https?://[^)]+)\)"
+ placeholders: list[tuple[str, str, str]] = [] # (type, text, url)
- def replace_with_placeholder(match, placeholders=placeholders):
+ def replace_markdown_with_placeholder(match, placeholders=placeholders):
+ link_type = "link"
+ link_text = match.group(1)
+ url = match.group(2)
+ placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
+ placeholders.append((link_type, link_text, url))
+ return placeholder
+
+ def replace_image_with_placeholder(match, placeholders=placeholders):
+ link_type = "image"
url = match.group(1)
- placeholder = f"__MARKDOWN_IMAGE_URL_{len(placeholders)}__"
- placeholders.append(url)
- return f""
+ placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
+ placeholders.append((link_type, "image", url))
+ return placeholder
- text = re.sub(markdown_image_pattern, replace_with_placeholder, text)
+ # Protect markdown links first
+ text = re.sub(markdown_link_pattern, replace_markdown_with_placeholder, text)
+ # Then protect markdown images
+ text = re.sub(markdown_image_pattern, replace_image_with_placeholder, text)
# Now remove all remaining URLs
- url_pattern = r"https?://[^\s)]+"
+ url_pattern = r"https?://\S+"
text = re.sub(url_pattern, "", text)
- # Finally, restore the Markdown image URLs
- for i, url in enumerate(placeholders):
- text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url)
+ # Restore the Markdown links and images
+ for i, (link_type, text_or_alt, url) in enumerate(placeholders):
+ placeholder = f"__MARKDOWN_PLACEHOLDER_{i}__"
+ if link_type == "link":
+ text = text.replace(placeholder, f"[{text_or_alt}]({url})")
+ else: # image
+ text = text.replace(placeholder, f"")
return text
def filter_string(self, text):
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 9807cb4e6a..8ec1ce6242 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -1,4 +1,5 @@
import concurrent.futures
+import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any
@@ -13,7 +14,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
-from core.rag.embedding.retrieval import RetrievalSegments
+from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
@@ -36,6 +37,8 @@ default_retrieval_model = {
"score_threshold_enabled": False,
}
+logger = logging.getLogger(__name__)
+
class RetrievalService:
# Cache precompiled regular expressions to avoid repeated compilation
@@ -106,7 +109,12 @@ class RetrievalService:
)
)
- concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED)
+ if futures:
+ for future in concurrent.futures.as_completed(futures, timeout=3600):
+ if exceptions:
+ for f in futures:
+ f.cancel()
+ break
if exceptions:
raise ValueError(";\n".join(exceptions))
@@ -210,6 +218,7 @@ class RetrievalService:
)
all_documents.extend(documents)
except Exception as e:
+ logger.error(e, exc_info=True)
exceptions.append(str(e))
@classmethod
@@ -303,6 +312,7 @@ class RetrievalService:
else:
all_documents.extend(documents)
except Exception as e:
+ logger.error(e, exc_info=True)
exceptions.append(str(e))
@classmethod
@@ -351,6 +361,7 @@ class RetrievalService:
else:
all_documents.extend(documents)
except Exception as e:
+ logger.error(e, exc_info=True)
exceptions.append(str(e))
@staticmethod
@@ -381,10 +392,9 @@ class RetrievalService:
records = []
include_segment_ids = set()
segment_child_map = {}
- segment_file_map = {}
valid_dataset_documents = {}
- image_doc_ids = []
+ image_doc_ids: list[Any] = []
child_index_node_ids = []
index_node_ids = []
doc_to_document_map = {}
@@ -417,28 +427,39 @@ class RetrievalService:
child_index_node_ids = [i for i in child_index_node_ids if i]
index_node_ids = [i for i in index_node_ids if i]
- segment_ids = []
+ segment_ids: list[str] = []
index_node_segments: list[DocumentSegment] = []
segments: list[DocumentSegment] = []
- attachment_map = {}
- child_chunk_map = {}
- doc_segment_map = {}
+ attachment_map: dict[str, list[dict[str, Any]]] = {}
+ child_chunk_map: dict[str, list[ChildChunk]] = {}
+ doc_segment_map: dict[str, list[str]] = {}
with session_factory.create_session() as session:
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
for attachment in attachments:
segment_ids.append(attachment["segment_id"])
- attachment_map[attachment["segment_id"]] = attachment
- doc_segment_map[attachment["segment_id"]] = attachment["attachment_id"]
-
+ if attachment["segment_id"] in attachment_map:
+ attachment_map[attachment["segment_id"]].append(attachment["attachment_info"])
+ else:
+ attachment_map[attachment["segment_id"]] = [attachment["attachment_info"]]
+ if attachment["segment_id"] in doc_segment_map:
+ doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
+ else:
+ doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
for i in child_index_nodes:
segment_ids.append(i.segment_id)
- child_chunk_map[i.segment_id] = i
- doc_segment_map[i.segment_id] = i.index_node_id
+ if i.segment_id in child_chunk_map:
+ child_chunk_map[i.segment_id].append(i)
+ else:
+ child_chunk_map[i.segment_id] = [i]
+ if i.segment_id in doc_segment_map:
+ doc_segment_map[i.segment_id].append(i.index_node_id)
+ else:
+ doc_segment_map[i.segment_id] = [i.index_node_id]
if index_node_ids:
document_segment_stmt = select(DocumentSegment).where(
@@ -448,7 +469,7 @@ class RetrievalService:
)
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
for index_node_segment in index_node_segments:
- doc_segment_map[index_node_segment.id] = index_node_segment.index_node_id
+ doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
if segment_ids:
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
@@ -461,95 +482,86 @@ class RetrievalService:
segments.extend(index_node_segments)
for segment in segments:
- doc_id = doc_segment_map.get(segment.id)
- child_chunk = child_chunk_map.get(segment.id)
- attachment_info = attachment_map.get(segment.id)
+ child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
+ attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
+ ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
- if doc_id:
- document = doc_to_document_map[doc_id]
- ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(
- document.metadata.get("document_id")
- )
-
- if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
- if segment.id not in include_segment_ids:
- include_segment_ids.add(segment.id)
- if child_chunk:
+ if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ if segment.id not in include_segment_ids:
+ include_segment_ids.add(segment.id)
+ if child_chunks or attachment_infos:
+ child_chunk_details = []
+ max_score = 0.0
+ for child_chunk in child_chunks:
+ document = doc_to_document_map[child_chunk.index_node_id]
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0) if document else 0.0,
}
- map_detail = {
- "max_score": document.metadata.get("score", 0.0) if document else 0.0,
- "child_chunks": [child_chunk_detail],
- }
- segment_child_map[segment.id] = map_detail
- record = {
- "segment": segment,
+ child_chunk_details.append(child_chunk_detail)
+ max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0)
+ for attachment_info in attachment_infos:
+ file_document = doc_to_document_map[attachment_info["id"]]
+ max_score = max(
+ max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
+ )
+
+ map_detail = {
+ "max_score": max_score,
+ "child_chunks": child_chunk_details,
}
- if attachment_info:
- segment_file_map[segment.id] = [attachment_info]
- records.append(record)
- else:
- if child_chunk:
- child_chunk_detail = {
- "id": child_chunk.id,
- "content": child_chunk.content,
- "position": child_chunk.position,
- "score": document.metadata.get("score", 0.0),
- }
- if segment.id in segment_child_map:
- segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) # type: ignore
- segment_child_map[segment.id]["max_score"] = max(
- segment_child_map[segment.id]["max_score"],
- document.metadata.get("score", 0.0) if document else 0.0,
- )
- else:
- segment_child_map[segment.id] = {
- "max_score": document.metadata.get("score", 0.0) if document else 0.0,
- "child_chunks": [child_chunk_detail],
- }
- if attachment_info:
- if segment.id in segment_file_map:
- segment_file_map[segment.id].append(attachment_info)
- else:
- segment_file_map[segment.id] = [attachment_info]
- else:
- if segment.id not in include_segment_ids:
- include_segment_ids.add(segment.id)
- record = {
- "segment": segment,
- "score": document.metadata.get("score", 0.0), # type: ignore
- }
- if attachment_info:
- segment_file_map[segment.id] = [attachment_info]
- records.append(record)
- else:
- if attachment_info:
- attachment_infos = segment_file_map.get(segment.id, [])
- if attachment_info not in attachment_infos:
- attachment_infos.append(attachment_info)
- segment_file_map[segment.id] = attachment_infos
+ segment_child_map[segment.id] = map_detail
+ record: dict[str, Any] = {
+ "segment": segment,
+ }
+ records.append(record)
+ else:
+ if segment.id not in include_segment_ids:
+ include_segment_ids.add(segment.id)
+ max_score = 0.0
+ segment_document = doc_to_document_map.get(segment.index_node_id)
+ if segment_document:
+ max_score = max(max_score, segment_document.metadata.get("score", 0.0))
+ for attachment_info in attachment_infos:
+ file_doc = doc_to_document_map.get(attachment_info["id"])
+ if file_doc:
+ max_score = max(max_score, file_doc.metadata.get("score", 0.0))
+ record = {
+ "segment": segment,
+ "score": max_score,
+ }
+ records.append(record)
# Add child chunks information to records
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
- if record["segment"].id in segment_file_map:
- record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
+ if record["segment"].id in attachment_map:
+ record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
- result = []
+ result: list[RetrievalSegments] = []
for record in records:
# Extract segment
segment = record["segment"]
# Extract child_chunks, ensuring it's a list or None
- child_chunks = record.get("child_chunks")
- if not isinstance(child_chunks, list):
- child_chunks = None
+ raw_child_chunks = record.get("child_chunks")
+ child_chunks_list: list[RetrievalChildChunk] | None = None
+ if isinstance(raw_child_chunks, list):
+ # Sort by score descending
+ sorted_chunks = sorted(raw_child_chunks, key=lambda x: x.get("score", 0.0), reverse=True)
+ child_chunks_list = [
+ RetrievalChildChunk(
+ id=chunk["id"],
+ content=chunk["content"],
+ score=chunk.get("score", 0.0),
+ position=chunk["position"],
+ )
+ for chunk in sorted_chunks
+ ]
# Extract files, ensuring it's a list or None
files = record.get("files")
@@ -566,11 +578,11 @@ class RetrievalService:
# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(
- segment=segment, child_chunks=child_chunks, score=score, files=files
+ segment=segment, child_chunks=child_chunks_list, score=score, files=files
)
result.append(retrieval_segment)
- return result
+ return sorted(result, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)
except Exception as e:
db.session.rollback()
raise e
@@ -662,7 +674,14 @@ class RetrievalService:
document_ids_filter=document_ids_filter,
)
)
- concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED)
+ # Use as_completed for early error propagation - cancel remaining futures on first error
+ if futures:
+ for future in concurrent.futures.as_completed(futures, timeout=300):
+ if future.exception():
+ # Cancel remaining futures to avoid unnecessary waiting
+ for f in futures:
+ f.cancel()
+ break
if exceptions:
raise ValueError(";\n".join(exceptions))
diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
index a306f9ba0c..91bb71bfa6 100644
--- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
+++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import contextlib
import json
import logging
@@ -6,7 +8,7 @@ import re
import threading
import time
import uuid
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
import clickzetta # type: ignore
from pydantic import BaseModel, model_validator
@@ -76,7 +78,7 @@ class ClickzettaConnectionPool:
Manages connection reuse across ClickzettaVector instances.
"""
- _instance: Optional["ClickzettaConnectionPool"] = None
+ _instance: ClickzettaConnectionPool | None = None
_lock = threading.Lock()
def __init__(self):
@@ -89,7 +91,7 @@ class ClickzettaConnectionPool:
self._start_cleanup_thread()
@classmethod
- def get_instance(cls) -> "ClickzettaConnectionPool":
+ def get_instance(cls) -> ClickzettaConnectionPool:
"""Get singleton instance of connection pool."""
if cls._instance is None:
with cls._lock:
@@ -104,7 +106,7 @@ class ClickzettaConnectionPool:
f"{config.workspace}:{config.vcluster}:{config.schema_name}"
)
- def _create_connection(self, config: ClickzettaConfig) -> "Connection":
+ def _create_connection(self, config: ClickzettaConfig) -> Connection:
"""Create a new ClickZetta connection."""
max_retries = 3
retry_delay = 1.0
@@ -134,7 +136,7 @@ class ClickzettaConnectionPool:
raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts")
- def _configure_connection(self, connection: "Connection"):
+ def _configure_connection(self, connection: Connection):
"""Configure connection session settings."""
try:
with connection.cursor() as cursor:
@@ -181,7 +183,7 @@ class ClickzettaConnectionPool:
except Exception:
logger.exception("Failed to configure connection, continuing with defaults")
- def _is_connection_valid(self, connection: "Connection") -> bool:
+ def _is_connection_valid(self, connection: Connection) -> bool:
"""Check if connection is still valid."""
try:
with connection.cursor() as cursor:
@@ -190,7 +192,7 @@ class ClickzettaConnectionPool:
except Exception:
return False
- def get_connection(self, config: ClickzettaConfig) -> "Connection":
+ def get_connection(self, config: ClickzettaConfig) -> Connection:
"""Get a connection from the pool or create a new one."""
config_key = self._get_config_key(config)
@@ -221,7 +223,7 @@ class ClickzettaConnectionPool:
# No valid connection found, create new one
return self._create_connection(config)
- def return_connection(self, config: ClickzettaConfig, connection: "Connection"):
+ def return_connection(self, config: ClickzettaConfig, connection: Connection):
"""Return a connection to the pool."""
config_key = self._get_config_key(config)
@@ -315,22 +317,22 @@ class ClickzettaVector(BaseVector):
self._connection_pool = ClickzettaConnectionPool.get_instance()
self._init_write_queue()
- def _get_connection(self) -> "Connection":
+ def _get_connection(self) -> Connection:
"""Get a connection from the pool."""
return self._connection_pool.get_connection(self._config)
- def _return_connection(self, connection: "Connection"):
+ def _return_connection(self, connection: Connection):
"""Return a connection to the pool."""
self._connection_pool.return_connection(self._config, connection)
class ConnectionContext:
"""Context manager for borrowing and returning connections."""
- def __init__(self, vector_instance: "ClickzettaVector"):
+ def __init__(self, vector_instance: ClickzettaVector):
self.vector = vector_instance
self.connection: Connection | None = None
- def __enter__(self) -> "Connection":
+ def __enter__(self) -> Connection:
self.connection = self.vector._get_connection()
return self.connection
@@ -338,7 +340,7 @@ class ClickzettaVector(BaseVector):
if self.connection:
self.vector._return_connection(self.connection)
- def get_connection_context(self) -> "ClickzettaVector.ConnectionContext":
+ def get_connection_context(self) -> ClickzettaVector.ConnectionContext:
"""Get a connection context manager."""
return self.ConnectionContext(self)
@@ -437,7 +439,7 @@ class ClickzettaVector(BaseVector):
"""Return the vector database type."""
return "clickzetta"
- def _ensure_connection(self) -> "Connection":
+ def _ensure_connection(self) -> Connection:
"""Get a connection from the pool."""
return self._get_connection()
@@ -984,9 +986,11 @@ class ClickzettaVector(BaseVector):
# No need for dataset_id filter since each dataset has its own table
- # Use simple quote escaping for LIKE clause
- escaped_query = query.replace("'", "''")
- filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
+ # Escape special characters for LIKE clause to prevent SQL injection
+ from libs.helper import escape_like_pattern
+
+ escaped_query = escape_like_pattern(query).replace("'", "''")
+ filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'")
where_clause = " AND ".join(filter_clauses)
search_sql = f"""
diff --git a/api/core/rag/datasource/vdb/iris/iris_vector.py b/api/core/rag/datasource/vdb/iris/iris_vector.py
index b1bfabb76e..50bb2429ec 100644
--- a/api/core/rag/datasource/vdb/iris/iris_vector.py
+++ b/api/core/rag/datasource/vdb/iris/iris_vector.py
@@ -154,7 +154,7 @@ class IrisConnectionPool:
# Add to cache to skip future checks
self._schemas_initialized.add(schema)
- except Exception as e:
+ except Exception:
conn.rollback()
logger.exception("Failed to ensure schema %s exists", schema)
raise
@@ -177,6 +177,9 @@ class IrisConnectionPool:
class IrisVector(BaseVector):
"""IRIS vector database implementation using native VECTOR type and HNSW indexing."""
+ # Fallback score for full-text search when Rank function unavailable or TEXT_INDEX disabled
+ _FULL_TEXT_FALLBACK_SCORE = 0.5
+
def __init__(self, collection_name: str, config: IrisVectorConfig) -> None:
super().__init__(collection_name)
self.config = config
@@ -272,37 +275,131 @@ class IrisVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
- """Search documents by full-text using iFind index or fallback to LIKE search."""
+ """Search documents by full-text using iFind index with BM25 relevance scoring.
+
+ When IRIS_TEXT_INDEX is enabled, this method uses the auto-generated Rank
+ function from %iFind.Index.Basic to calculate BM25 relevance scores. The Rank
+ function is automatically created with naming: {schema}.{table_name}_{index}Rank
+
+ Args:
+ query: Search query string
+ **kwargs: Optional parameters including top_k, document_ids_filter
+
+ Returns:
+ List of Document objects with relevance scores in metadata["score"]
+ """
top_k = kwargs.get("top_k", 5)
+ document_ids_filter = kwargs.get("document_ids_filter")
with self._get_cursor() as cursor:
if self.config.IRIS_TEXT_INDEX:
- # Use iFind full-text search with index
+ # Use iFind full-text search with auto-generated Rank function
text_index_name = f"idx_{self.table_name}_text"
+ # IRIS removes underscores from function names
+ table_no_underscore = self.table_name.replace("_", "")
+ index_no_underscore = text_index_name.replace("_", "")
+ rank_function = f"{self.schema}.{table_no_underscore}_{index_no_underscore}Rank"
+
+ # Build WHERE clause with document ID filter if provided
+ where_clause = f"WHERE %ID %FIND search_index({text_index_name}, ?)"
+ # First param for Rank function, second for FIND
+ params = [query, query]
+
+ if document_ids_filter:
+ # Add document ID filter
+ placeholders = ",".join("?" * len(document_ids_filter))
+ where_clause += f" AND JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
+ params.extend(document_ids_filter)
+
sql = f"""
- SELECT TOP {top_k} id, text, meta
+ SELECT TOP {top_k}
+ id,
+ text,
+ meta,
+ {rank_function}(%ID, ?) AS score
FROM {self.schema}.{self.table_name}
- WHERE %ID %FIND search_index({text_index_name}, ?)
+ {where_clause}
+ ORDER BY score DESC
"""
- cursor.execute(sql, (query,))
+
+ logger.debug(
+ "iFind search: query='%s', index='%s', rank='%s'",
+ query,
+ text_index_name,
+ rank_function,
+ )
+
+ try:
+ cursor.execute(sql, params)
+ except Exception: # pylint: disable=broad-exception-caught
+ # Fallback to query without Rank function if it fails
+ logger.warning(
+ "Rank function '%s' failed, using fixed score",
+ rank_function,
+ exc_info=True,
+ )
+ sql_fallback = f"""
+ SELECT TOP {top_k} id, text, meta, {self._FULL_TEXT_FALLBACK_SCORE} AS score
+ FROM {self.schema}.{self.table_name}
+ {where_clause}
+ """
+ # Skip first param (for Rank function)
+ cursor.execute(sql_fallback, params[1:])
else:
- # Fallback to LIKE search (inefficient for large datasets)
- query_pattern = f"%{query}%"
+ # Fallback to LIKE search (IRIS_TEXT_INDEX disabled)
+ from libs.helper import ( # pylint: disable=import-outside-toplevel
+ escape_like_pattern,
+ )
+
+ escaped_query = escape_like_pattern(query)
+ query_pattern = f"%{escaped_query}%"
+
+ # Build WHERE clause with document ID filter if provided
+ where_clause = "WHERE text LIKE ? ESCAPE '\\\\'"
+ params = [query_pattern]
+
+ if document_ids_filter:
+ placeholders = ",".join("?" * len(document_ids_filter))
+ where_clause += f" AND JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
+ params.extend(document_ids_filter)
+
sql = f"""
- SELECT TOP {top_k} id, text, meta
+ SELECT TOP {top_k} id, text, meta, {self._FULL_TEXT_FALLBACK_SCORE} AS score
FROM {self.schema}.{self.table_name}
- WHERE text LIKE ?
+ {where_clause}
+ ORDER BY LENGTH(text) ASC
"""
- cursor.execute(sql, (query_pattern,))
+
+ logger.debug(
+ "LIKE fallback (TEXT_INDEX disabled): query='%s'",
+ query_pattern,
+ )
+ cursor.execute(sql, params)
docs = []
for row in cursor.fetchall():
- if len(row) >= 3:
- metadata = json.loads(row[2]) if row[2] else {}
- docs.append(Document(page_content=row[1], metadata=metadata))
+ # Expecting 4 columns: id, text, meta, score
+ if len(row) >= 4:
+ text_content = row[1]
+ meta_str = row[2]
+ score_value = row[3]
+
+ metadata = json.loads(meta_str) if meta_str else {}
+ # Add score to metadata for hybrid search compatibility
+ score = float(score_value) if score_value is not None else 0.0
+ metadata["score"] = score
+
+ docs.append(Document(page_content=text_content, metadata=metadata))
+
+ logger.info(
+ "Full-text search completed: query='%s', results=%d/%d",
+ query,
+ len(docs),
+ top_k,
+ )
if not docs:
- logger.info("Full-text search for '%s' returned no results", query)
+ logger.warning("Full-text search for '%s' returned no results", query)
return docs
@@ -366,7 +463,11 @@ class IrisVector(BaseVector):
AS %iFind.Index.Basic
(LANGUAGE = '{language}', LOWER = 1, INDEXOPTION = 0)
"""
- logger.info("Creating text index: %s with language: %s", text_index_name, language)
+ logger.info(
+ "Creating text index: %s with language: %s",
+ text_index_name,
+ language,
+ )
logger.info("SQL for text index: %s", sql_text_index)
cursor.execute(sql_text_index)
logger.info("Text index created successfully: %s", text_index_name)
diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py
index 445a0a7f8b..0615b8312c 100644
--- a/api/core/rag/datasource/vdb/pgvector/pgvector.py
+++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py
@@ -255,7 +255,10 @@ class PGVector(BaseVector):
return
with self._get_cursor() as cur:
- cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
+ cur.execute("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
+ if not cur.fetchone():
+ cur.execute("CREATE EXTENSION vector")
+
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# PG hnsw index only support 2000 dimension or less
# ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
index 84d1e26b34..b48dd93f04 100644
--- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
+++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
@@ -66,6 +66,8 @@ class WeaviateVector(BaseVector):
in a Weaviate collection.
"""
+ _DOCUMENT_ID_PROPERTY = "document_id"
+
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
"""
Initializes the Weaviate vector store.
@@ -353,15 +355,12 @@ class WeaviateVector(BaseVector):
return []
col = self._client.collections.use(self._collection_name)
- props = list({*self._attributes, "document_id", Field.TEXT_KEY.value})
+ props = list({*self._attributes, self._DOCUMENT_ID_PROPERTY, Field.TEXT_KEY.value})
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
- ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
- where = ors[0]
- for f in ors[1:]:
- where = where | f
+ where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
top_k = int(kwargs.get("top_k", 4))
score_threshold = float(kwargs.get("score_threshold") or 0.0)
@@ -408,10 +407,7 @@ class WeaviateVector(BaseVector):
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
- ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
- where = ors[0]
- for f in ors[1:]:
- where = where | f
+ where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
top_k = int(kwargs.get("top_k", 4))
diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py
index 1fe74d3042..69adac522d 100644
--- a/api/core/rag/docstore/dataset_docstore.py
+++ b/api/core/rag/docstore/dataset_docstore.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from collections.abc import Sequence
from typing import Any
@@ -22,7 +24,7 @@ class DatasetDocumentStore:
self._document_id = document_id
@classmethod
- def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
+ def from_dict(cls, config_dict: dict[str, Any]) -> DatasetDocumentStore:
return cls(**config_dict)
def to_dict(self) -> dict[str, Any]:
diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py
index 013c287248..6d28ce25bc 100644
--- a/api/core/rag/extractor/extract_processor.py
+++ b/api/core/rag/extractor/extract_processor.py
@@ -112,7 +112,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
- extractor = PdfExtractor(file_path)
+ extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = (
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
@@ -148,7 +148,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
- extractor = PdfExtractor(file_path)
+ extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in {".htm", ".html"}:
diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py
index 80530d99a6..6aabcac704 100644
--- a/api/core/rag/extractor/pdf_extractor.py
+++ b/api/core/rag/extractor/pdf_extractor.py
@@ -1,25 +1,57 @@
"""Abstract interface for document loader implementations."""
import contextlib
+import io
+import logging
+import uuid
from collections.abc import Iterator
+import pypdfium2
+import pypdfium2.raw as pdfium_c
+
+from configs import dify_config
from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
+from extensions.ext_database import db
from extensions.ext_storage import storage
+from libs.datetime_utils import naive_utc_now
+from models.enums import CreatorUserRole
+from models.model import UploadFile
+
+logger = logging.getLogger(__name__)
class PdfExtractor(BaseExtractor):
- """Load pdf files.
-
+ """
+ PdfExtractor is used to extract text and images from PDF files.
Args:
- file_path: Path to the file to load.
+ file_path: Path to the PDF file.
+ tenant_id: Workspace ID.
+ user_id: ID of the user performing the extraction.
+ file_cache_key: Optional cache key for the extracted text.
"""
- def __init__(self, file_path: str, file_cache_key: str | None = None):
- """Initialize with file path."""
+ # Magic bytes for image format detection: (magic_bytes, extension, mime_type)
+ IMAGE_FORMATS = [
+ (b"\xff\xd8\xff", "jpg", "image/jpeg"),
+ (b"\x89PNG\r\n\x1a\n", "png", "image/png"),
+ (b"\x00\x00\x00\x0c\x6a\x50\x20\x20\x0d\x0a\x87\x0a", "jp2", "image/jp2"),
+ (b"GIF8", "gif", "image/gif"),
+ (b"BM", "bmp", "image/bmp"),
+ (b"II*\x00", "tiff", "image/tiff"),
+ (b"MM\x00*", "tiff", "image/tiff"),
+ (b"II+\x00", "tiff", "image/tiff"),
+ (b"MM\x00+", "tiff", "image/tiff"),
+ ]
+ MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)
+
+ def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None):
+ """Initialize PdfExtractor."""
self._file_path = file_path
+ self._tenant_id = tenant_id
+ self._user_id = user_id
self._file_cache_key = file_cache_key
def extract(self) -> list[Document]:
@@ -50,7 +82,6 @@ class PdfExtractor(BaseExtractor):
def parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""
- import pypdfium2 # type: ignore
with blob.as_bytes_io() as file_path:
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
@@ -59,8 +90,87 @@ class PdfExtractor(BaseExtractor):
text_page = page.get_textpage()
content = text_page.get_text_range()
text_page.close()
+
+ image_content = self._extract_images(page)
+ if image_content:
+ content += "\n" + image_content
+
page.close()
metadata = {"source": blob.source, "page": page_number}
yield Document(page_content=content, metadata=metadata)
finally:
pdf_reader.close()
+
+ def _extract_images(self, page) -> str:
+ """
+ Extract images from a PDF page, save them to storage and database,
+ and return markdown image links.
+
+ Args:
+ page: pypdfium2 page object.
+
+ Returns:
+ Markdown string containing links to the extracted images.
+ """
+ image_content = []
+ upload_files = []
+ base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
+
+ try:
+ image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,))
+ for obj in image_objects:
+ try:
+ # Extract image bytes
+ img_byte_arr = io.BytesIO()
+ # Extract DCTDecode (JPEG) and JPXDecode (JPEG 2000) images directly
+ # Fallback to png for other formats
+ obj.extract(img_byte_arr, fb_format="png")
+ img_bytes = img_byte_arr.getvalue()
+
+ if not img_bytes:
+ continue
+
+ header = img_bytes[: self.MAX_MAGIC_LEN]
+ image_ext = None
+ mime_type = None
+ for magic, ext, mime in self.IMAGE_FORMATS:
+ if header.startswith(magic):
+ image_ext = ext
+ mime_type = mime
+ break
+
+ if not image_ext or not mime_type:
+ continue
+
+ file_uuid = str(uuid.uuid4())
+ file_key = "image_files/" + self._tenant_id + "/" + file_uuid + "." + image_ext
+
+ storage.save(file_key, img_bytes)
+
+ # save file to db
+ upload_file = UploadFile(
+ tenant_id=self._tenant_id,
+ storage_type=dify_config.STORAGE_TYPE,
+ key=file_key,
+ name=file_key,
+ size=len(img_bytes),
+ extension=image_ext,
+ mime_type=mime_type,
+ created_by=self._user_id,
+ created_by_role=CreatorUserRole.ACCOUNT,
+ created_at=naive_utc_now(),
+ used=True,
+ used_by=self._user_id,
+ used_at=naive_utc_now(),
+ )
+ upload_files.append(upload_file)
+ image_content.append(f"")
+ except Exception as e:
+ logger.warning("Failed to extract image from PDF: %s", e)
+ continue
+ except Exception as e:
+ logger.warning("Failed to get objects from PDF page: %s", e)
+ if upload_files:
+ db.session.add_all(upload_files)
+ db.session.commit()
+ return "\n".join(image_content)
diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py
index f67f613e9d..511f5a698d 100644
--- a/api/core/rag/extractor/word_extractor.py
+++ b/api/core/rag/extractor/word_extractor.py
@@ -7,10 +7,11 @@ import re
import tempfile
import uuid
from urllib.parse import urlparse
-from xml.etree import ElementTree
import httpx
from docx import Document as DocxDocument
+from docx.oxml.ns import qn
+from docx.text.run import Run
from configs import dify_config
from core.helper import ssrf_proxy
@@ -229,44 +230,20 @@ class WordExtractor(BaseExtractor):
image_map = self._extract_images_from_docx(doc)
- hyperlinks_url = None
- url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
- for para in doc.paragraphs:
- for run in para.runs:
- if run.text and hyperlinks_url:
- result = f" [{run.text}]({hyperlinks_url}) "
- run.text = result
- hyperlinks_url = None
- if "HYPERLINK" in run.element.xml:
- try:
- xml = ElementTree.XML(run.element.xml)
- x_child = [c for c in xml.iter() if c is not None]
- for x in x_child:
- if x is None:
- continue
- if x.tag.endswith("instrText"):
- if x.text is None:
- continue
- for i in url_pattern.findall(x.text):
- hyperlinks_url = str(i)
- except Exception:
- logger.exception("Failed to parse HYPERLINK xml")
-
def parse_paragraph(paragraph):
- paragraph_content = []
-
- def append_image_link(image_id, has_drawing):
+ def append_image_link(image_id, has_drawing, target_buffer):
"""Helper to append image link from image_map based on relationship type."""
rel = doc.part.rels[image_id]
if rel.is_external:
if image_id in image_map and not has_drawing:
- paragraph_content.append(image_map[image_id])
+ target_buffer.append(image_map[image_id])
else:
image_part = rel.target_part
if image_part in image_map and not has_drawing:
- paragraph_content.append(image_map[image_part])
+ target_buffer.append(image_map[image_part])
- for run in paragraph.runs:
+ def process_run(run, target_buffer):
+ # Helper to extract text and embedded images from a run element and append them to target_buffer
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
# Process drawing type images
drawing_elements = run.element.findall(
@@ -287,13 +264,13 @@ class WordExtractor(BaseExtractor):
# External image: use embed_id as key
if embed_id in image_map:
has_drawing = True
- paragraph_content.append(image_map[embed_id])
+ target_buffer.append(image_map[embed_id])
else:
# Internal image: use target_part as key
image_part = doc.part.related_parts.get(embed_id)
if image_part in image_map:
has_drawing = True
- paragraph_content.append(image_map[image_part])
+ target_buffer.append(image_map[image_part])
# Process pict type images
shape_elements = run.element.findall(
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
@@ -308,7 +285,7 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
- append_image_link(image_id, has_drawing)
+ append_image_link(image_id, has_drawing, target_buffer)
# Find imagedata element in VML
image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
if image_data is not None:
@@ -316,9 +293,93 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
- append_image_link(image_id, has_drawing)
+ append_image_link(image_id, has_drawing, target_buffer)
if run.text.strip():
- paragraph_content.append(run.text.strip())
+ target_buffer.append(run.text.strip())
+
+ def process_hyperlink(hyperlink_elem, target_buffer):
+ # Helper to extract text from a hyperlink element and append it to target_buffer
+ r_id = hyperlink_elem.get(qn("r:id"))
+
+ # Extract text from runs inside the hyperlink
+ link_text_parts = []
+ for run_elem in hyperlink_elem.findall(qn("w:r")):
+ run = Run(run_elem, paragraph)
+ # Hyperlink text may be split across multiple runs (e.g., with different formatting),
+ # so collect all run texts first
+ if run.text:
+ link_text_parts.append(run.text)
+
+ link_text = "".join(link_text_parts).strip()
+
+ # Resolve URL
+ if r_id:
+ try:
+ rel = doc.part.rels.get(r_id)
+ if rel and rel.is_external:
+ link_text = f"[{link_text or rel.target_ref}]({rel.target_ref})"
+ except Exception:
+ logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id)
+
+ if link_text:
+ target_buffer.append(link_text)
+
+ paragraph_content = []
+ # State for legacy HYPERLINK fields
+ hyperlink_field_url = None
+ hyperlink_field_text_parts: list = []
+ is_collecting_field_text = False
+ # Iterate through paragraph elements in document order
+ for child in paragraph._element:
+ tag = child.tag
+ if tag == qn("w:r"):
+ # Regular run
+ run = Run(child, paragraph)
+
+ # Check for fldChar (begin/end/separate) and instrText for legacy hyperlinks
+ fld_chars = child.findall(qn("w:fldChar"))
+ instr_texts = child.findall(qn("w:instrText"))
+
+ # Handle Fields
+ if fld_chars or instr_texts:
+ # Process instrText to find HYPERLINK "url"
+ for instr in instr_texts:
+ if instr.text and "HYPERLINK" in instr.text:
+ # Quick regex to extract URL
+ match = re.search(r'HYPERLINK\s+"([^"]+)"', instr.text, re.IGNORECASE)
+ if match:
+ hyperlink_field_url = match.group(1)
+
+ # Process fldChar
+ for fld_char in fld_chars:
+ fld_char_type = fld_char.get(qn("w:fldCharType"))
+ if fld_char_type == "begin":
+ # Start of a field: reset legacy link state
+ hyperlink_field_url = None
+ hyperlink_field_text_parts = []
+ is_collecting_field_text = False
+ elif fld_char_type == "separate":
+ # Separator: if we found a URL, start collecting visible text
+ if hyperlink_field_url:
+ is_collecting_field_text = True
+ elif fld_char_type == "end":
+ # End of field
+ if is_collecting_field_text and hyperlink_field_url:
+ # Create markdown link and append to main content
+ display_text = "".join(hyperlink_field_text_parts).strip()
+ if display_text:
+ link_md = f"[{display_text}]({hyperlink_field_url})"
+ paragraph_content.append(link_md)
+ # Reset state
+ hyperlink_field_url = None
+ hyperlink_field_text_parts = []
+ is_collecting_field_text = False
+
+ # Decide where to append content
+ target_buffer = hyperlink_field_text_parts if is_collecting_field_text else paragraph_content
+ process_run(run, target_buffer)
+ elif tag == qn("w:hyperlink"):
+ process_hyperlink(child, paragraph_content)
return "".join(paragraph_content) if paragraph_content else ""
paragraphs = doc.paragraphs.copy()
diff --git a/api/core/rag/pipeline/queue.py b/api/core/rag/pipeline/queue.py
index 7472598a7f..bf8db95b4e 100644
--- a/api/core/rag/pipeline/queue.py
+++ b/api/core/rag/pipeline/queue.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
from collections.abc import Sequence
from typing import Any
@@ -16,7 +18,7 @@ class TaskWrapper(BaseModel):
return self.model_dump_json()
@classmethod
- def deserialize(cls, serialized_data: str) -> "TaskWrapper":
+ def deserialize(cls, serialized_data: str) -> TaskWrapper:
return cls.model_validate_json(serialized_data)
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index baf879df95..f8f85d141a 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from flask import Flask, current_app
-from sqlalchemy import and_, or_, select
+from sqlalchemy import and_, literal, or_, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import (
@@ -515,7 +515,11 @@ class DatasetRetrieval:
0
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
+ dataset_count = len(available_datasets)
with measure_time() as timer:
+ cancel_event = threading.Event()
+ thread_exceptions: list[Exception] = []
+
if query:
query_thread = threading.Thread(
target=self._multiple_retrieve_thread,
@@ -534,6 +538,9 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": query,
"attachment_id": None,
+ "dataset_count": dataset_count,
+ "cancel_event": cancel_event,
+ "thread_exceptions": thread_exceptions,
},
)
all_threads.append(query_thread)
@@ -557,12 +564,26 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": None,
"attachment_id": attachment_id,
+ "dataset_count": dataset_count,
+ "cancel_event": cancel_event,
+ "thread_exceptions": thread_exceptions,
},
)
all_threads.append(attachment_thread)
attachment_thread.start()
- for thread in all_threads:
- thread.join()
+
+ # Poll threads with short timeout to detect errors quickly (fail-fast)
+ while any(t.is_alive() for t in all_threads):
+ for thread in all_threads:
+ thread.join(timeout=0.1)
+ if thread_exceptions:
+ cancel_event.set()
+ break
+ if thread_exceptions:
+ break
+
+ if thread_exceptions:
+ raise thread_exceptions[0]
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
if all_documents:
@@ -1036,7 +1057,7 @@ class DatasetRetrieval:
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
- self._process_metadata_filter_func(
+ self.process_metadata_filter_func(
sequence,
filter.get("condition"), # type: ignore
filter.get("metadata_name"), # type: ignore
@@ -1072,7 +1093,7 @@ class DatasetRetrieval:
value=expected_value,
)
)
- filters = self._process_metadata_filter_func(
+ filters = self.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
@@ -1168,26 +1189,33 @@ class DatasetRetrieval:
return None
return automatic_metadata_filters
- def _process_metadata_filter_func(
- self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
+ @classmethod
+ def process_metadata_filter_func(
+ cls, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
):
if value is None and condition not in ("empty", "not empty"):
return filters
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
+ from libs.helper import escape_like_pattern
+
match condition:
case "contains":
- filters.append(json_field.like(f"%{value}%"))
+ escaped_value = escape_like_pattern(str(value))
+ filters.append(json_field.like(f"%{escaped_value}%", escape="\\"))
case "not contains":
- filters.append(json_field.notlike(f"%{value}%"))
+ escaped_value = escape_like_pattern(str(value))
+ filters.append(json_field.notlike(f"%{escaped_value}%", escape="\\"))
case "start with":
- filters.append(json_field.like(f"{value}%"))
+ escaped_value = escape_like_pattern(str(value))
+ filters.append(json_field.like(f"{escaped_value}%", escape="\\"))
case "end with":
- filters.append(json_field.like(f"%{value}"))
+ escaped_value = escape_like_pattern(str(value))
+ filters.append(json_field.like(f"%{escaped_value}", escape="\\"))
case "is" | "=":
if isinstance(value, str):
@@ -1218,6 +1246,20 @@ class DatasetRetrieval:
case "≥" | ">=":
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
+ case "in" | "not in":
+ if isinstance(value, str):
+ value_list = [v.strip() for v in value.split(",") if v.strip()]
+ elif isinstance(value, (list, tuple)):
+ value_list = [str(v) for v in value if v is not None]
+ else:
+ value_list = [str(value)] if value is not None else []
+
+ if not value_list:
+ # `field in []` is False, `field not in []` is True
+ filters.append(literal(condition == "not in"))
+ else:
+ op = json_field.in_ if condition == "in" else json_field.notin_
+ filters.append(op(value_list))
case _:
pass
@@ -1389,69 +1431,89 @@ class DatasetRetrieval:
score_threshold: float,
query: str | None,
attachment_id: str | None,
+ dataset_count: int,
+ cancel_event: threading.Event | None = None,
+ thread_exceptions: list[Exception] | None = None,
):
- with flask_app.app_context():
- threads = []
- all_documents_item: list[Document] = []
- index_type = None
- for dataset in available_datasets:
- index_type = dataset.indexing_technique
- document_ids_filter = None
- if dataset.provider != "external":
- if metadata_condition and not metadata_filter_document_ids:
- continue
- if metadata_filter_document_ids:
- document_ids = metadata_filter_document_ids.get(dataset.id, [])
- if document_ids:
- document_ids_filter = document_ids
- else:
+ try:
+ with flask_app.app_context():
+ threads = []
+ all_documents_item: list[Document] = []
+ index_type = None
+ for dataset in available_datasets:
+ # Check for cancellation signal
+ if cancel_event and cancel_event.is_set():
+ break
+ index_type = dataset.indexing_technique
+ document_ids_filter = None
+ if dataset.provider != "external":
+ if metadata_condition and not metadata_filter_document_ids:
continue
- retrieval_thread = threading.Thread(
- target=self._retriever,
- kwargs={
- "flask_app": flask_app,
- "dataset_id": dataset.id,
- "query": query,
- "top_k": top_k,
- "all_documents": all_documents_item,
- "document_ids_filter": document_ids_filter,
- "metadata_condition": metadata_condition,
- "attachment_ids": [attachment_id] if attachment_id else None,
- },
- )
- threads.append(retrieval_thread)
- retrieval_thread.start()
- for thread in threads:
- thread.join()
+ if metadata_filter_document_ids:
+ document_ids = metadata_filter_document_ids.get(dataset.id, [])
+ if document_ids:
+ document_ids_filter = document_ids
+ else:
+ continue
+ retrieval_thread = threading.Thread(
+ target=self._retriever,
+ kwargs={
+ "flask_app": flask_app,
+ "dataset_id": dataset.id,
+ "query": query,
+ "top_k": top_k,
+ "all_documents": all_documents_item,
+ "document_ids_filter": document_ids_filter,
+ "metadata_condition": metadata_condition,
+ "attachment_ids": [attachment_id] if attachment_id else None,
+ },
+ )
+ threads.append(retrieval_thread)
+ retrieval_thread.start()
- if reranking_enable:
- # do rerank for searched documents
- data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
- if query:
- all_documents_item = data_post_processor.invoke(
- query=query,
- documents=all_documents_item,
- score_threshold=score_threshold,
- top_n=top_k,
- query_type=QueryType.TEXT_QUERY,
- )
- if attachment_id:
- all_documents_item = data_post_processor.invoke(
- documents=all_documents_item,
- score_threshold=score_threshold,
- top_n=top_k,
- query_type=QueryType.IMAGE_QUERY,
- query=attachment_id,
- )
- else:
- if index_type == IndexTechniqueType.ECONOMY:
- if not query:
- all_documents_item = []
- else:
- all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
- elif index_type == IndexTechniqueType.HIGH_QUALITY:
- all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
+ # Poll threads with short timeout to respond quickly to cancellation
+ while any(t.is_alive() for t in threads):
+ for thread in threads:
+ thread.join(timeout=0.1)
+ if cancel_event and cancel_event.is_set():
+ break
+ if cancel_event and cancel_event.is_set():
+ break
+
+ # Skip second reranking when there is only one dataset
+ if reranking_enable and dataset_count > 1:
+ # do rerank for searched documents
+ data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
+ if query:
+ all_documents_item = data_post_processor.invoke(
+ query=query,
+ documents=all_documents_item,
+ score_threshold=score_threshold,
+ top_n=top_k,
+ query_type=QueryType.TEXT_QUERY,
+ )
+ if attachment_id:
+ all_documents_item = data_post_processor.invoke(
+ documents=all_documents_item,
+ score_threshold=score_threshold,
+ top_n=top_k,
+ query_type=QueryType.IMAGE_QUERY,
+ query=attachment_id,
+ )
else:
- all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
- if all_documents_item:
- all_documents.extend(all_documents_item)
+ if index_type == IndexTechniqueType.ECONOMY:
+ if not query:
+ all_documents_item = []
+ else:
+ all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
+ elif index_type == IndexTechniqueType.HIGH_QUALITY:
+ all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
+ else:
+ all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
+ if all_documents_item:
+ all_documents.extend(all_documents_item)
+ except Exception as e:
+ if cancel_event:
+ cancel_event.set()
+ if thread_exceptions is not None:
+ thread_exceptions.append(e)
diff --git a/api/core/schemas/registry.py b/api/core/schemas/registry.py
index 51bfae1cd3..b4ecfe47ff 100644
--- a/api/core/schemas/registry.py
+++ b/api/core/schemas/registry.py
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
import json
import logging
import threading
from collections.abc import Mapping, MutableMapping
from pathlib import Path
-from typing import Any, ClassVar, Optional
+from typing import Any, ClassVar
class SchemaRegistry:
@@ -11,7 +13,7 @@ class SchemaRegistry:
logger: ClassVar[logging.Logger] = logging.getLogger(__name__)
- _default_instance: ClassVar[Optional["SchemaRegistry"]] = None
+ _default_instance: ClassVar[SchemaRegistry | None] = None
_lock: ClassVar[threading.Lock] = threading.Lock()
def __init__(self, base_dir: str):
@@ -20,7 +22,7 @@ class SchemaRegistry:
self.metadata: MutableMapping[str, MutableMapping[str, Any]] = {}
@classmethod
- def default_registry(cls) -> "SchemaRegistry":
+ def default_registry(cls) -> SchemaRegistry:
"""Returns the default schema registry for builtin schemas (thread-safe singleton)"""
if cls._default_instance is None:
with cls._lock:
diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py
index 8ca4eabb7a..ebd200a822 100644
--- a/api/core/tools/__base/tool.py
+++ b/api/core/tools/__base/tool.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from abc import ABC, abstractmethod
from collections.abc import Generator
from copy import deepcopy
@@ -24,7 +26,7 @@ class Tool(ABC):
self.entity = entity
self.runtime = runtime
- def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
+ def fork_tool_runtime(self, runtime: ToolRuntime) -> Tool:
"""
fork a new tool with metadata
:return: the new tool
@@ -166,7 +168,7 @@ class Tool(ABC):
type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
)
- def create_file_message(self, file: "File") -> ToolInvokeMessage:
+ def create_file_message(self, file: File) -> ToolInvokeMessage:
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.FILE,
message=ToolInvokeMessage.FileMessage(),
diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py
index 84efefba07..51b0407886 100644
--- a/api/core/tools/builtin_tool/tool.py
+++ b/api/core/tools/builtin_tool/tool.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.tools.__base.tool import Tool
@@ -24,7 +26,7 @@ class BuiltinTool(Tool):
super().__init__(**kwargs)
self.provider = provider
- def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
+ def fork_tool_runtime(self, runtime: ToolRuntime) -> BuiltinTool:
"""
fork a new tool with metadata
:return: the new tool
diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py
index 0cc992155a..e2f6c00555 100644
--- a/api/core/tools/custom_tool/provider.py
+++ b/api/core/tools/custom_tool/provider.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from pydantic import Field
from sqlalchemy import select
@@ -32,7 +34,7 @@ class ApiToolProviderController(ToolProviderController):
self.tools = []
@classmethod
- def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
+ def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> ApiToolProviderController:
credentials_schema = [
ProviderConfig(
name="auth_type",
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index 583a3584f7..96268d029e 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import base64
import contextlib
from collections.abc import Mapping
@@ -55,7 +57,7 @@ class ToolProviderType(StrEnum):
MCP = auto()
@classmethod
- def value_of(cls, value: str) -> "ToolProviderType":
+ def value_of(cls, value: str) -> ToolProviderType:
"""
Get value of given mode.
@@ -79,7 +81,7 @@ class ApiProviderSchemaType(StrEnum):
OPENAI_ACTIONS = auto()
@classmethod
- def value_of(cls, value: str) -> "ApiProviderSchemaType":
+ def value_of(cls, value: str) -> ApiProviderSchemaType:
"""
Get value of given mode.
@@ -102,7 +104,7 @@ class ApiProviderAuthType(StrEnum):
API_KEY_QUERY = auto()
@classmethod
- def value_of(cls, value: str) -> "ApiProviderAuthType":
+ def value_of(cls, value: str) -> ApiProviderAuthType:
"""
Get value of given mode.
@@ -128,7 +130,7 @@ class ToolInvokeMessage(BaseModel):
text: str
class JsonMessage(BaseModel):
- json_object: dict
+ json_object: dict | list
suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")
class BlobMessage(BaseModel):
@@ -142,7 +144,14 @@ class ToolInvokeMessage(BaseModel):
end: bool = Field(..., description="Whether the chunk is the last chunk")
class FileMessage(BaseModel):
- pass
+ file_marker: str = Field(default="file_marker")
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_file_message(cls, values):
+ if isinstance(values, dict) and "file_marker" not in values:
+ raise ValueError("Invalid FileMessage: missing file_marker")
+ return values
class VariableMessage(BaseModel):
variable_name: str = Field(..., description="The name of the variable")
@@ -232,10 +241,22 @@ class ToolInvokeMessage(BaseModel):
@field_validator("message", mode="before")
@classmethod
- def decode_blob_message(cls, v):
+ def decode_blob_message(cls, v, info: ValidationInfo):
+ # 处理 blob 解码
if isinstance(v, dict) and "blob" in v:
with contextlib.suppress(Exception):
v["blob"] = base64.b64decode(v["blob"])
+
+ # Force correct message type based on type field
+ # Only wrap dict types to avoid wrapping already parsed Pydantic model objects
+ if info.data and isinstance(info.data, dict) and isinstance(v, dict):
+ msg_type = info.data.get("type")
+ if msg_type == cls.MessageType.JSON:
+ if "json_object" not in v:
+ v = {"json_object": v}
+ elif msg_type == cls.MessageType.FILE:
+ v = {"file_marker": "file_marker"}
+
return v
@field_serializer("message")
@@ -307,7 +328,7 @@ class ToolParameter(PluginParameter):
typ: ToolParameterType,
required: bool,
options: list[str] | None = None,
- ) -> "ToolParameter":
+ ) -> ToolParameter:
"""
get a simple tool parameter
@@ -429,14 +450,14 @@ class ToolInvokeMeta(BaseModel):
tool_config: dict | None = None
@classmethod
- def empty(cls) -> "ToolInvokeMeta":
+ def empty(cls) -> ToolInvokeMeta:
"""
Get an empty instance of ToolInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
- def error_instance(cls, error: str) -> "ToolInvokeMeta":
+ def error_instance(cls, error: str) -> ToolInvokeMeta:
"""
Get an instance of ToolInvokeMeta with error
"""
diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py
index fbaf31ad09..ef9e9c103a 100644
--- a/api/core/tools/mcp_tool/tool.py
+++ b/api/core/tools/mcp_tool/tool.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import base64
import json
import logging
@@ -6,7 +8,15 @@ from typing import Any
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
-from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
+from core.mcp.types import (
+ AudioContent,
+ BlobResourceContents,
+ CallToolResult,
+ EmbeddedResource,
+ ImageContent,
+ TextContent,
+ TextResourceContents,
+)
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
@@ -53,10 +63,19 @@ class MCPTool(Tool):
for content in result.content:
if isinstance(content, TextContent):
yield from self._process_text_content(content)
- elif isinstance(content, ImageContent):
- yield self._process_image_content(content)
- elif isinstance(content, AudioContent):
- yield self._process_audio_content(content)
+ elif isinstance(content, ImageContent | AudioContent):
+ yield self.create_blob_message(
+ blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}
+ )
+ elif isinstance(content, EmbeddedResource):
+ resource = content.resource
+ if isinstance(resource, TextResourceContents):
+ yield self.create_text_message(resource.text)
+ elif isinstance(resource, BlobResourceContents):
+ mime_type = resource.mimeType or "application/octet-stream"
+ yield self.create_blob_message(blob=base64.b64decode(resource.blob), meta={"mime_type": mime_type})
+ else:
+ raise ToolInvokeError(f"Unsupported embedded resource type: {type(resource)}")
else:
logger.warning("Unsupported content type=%s", type(content))
@@ -101,15 +120,7 @@ class MCPTool(Tool):
for item in json_list:
yield self.create_json_message(item)
- def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage:
- """Process image content and return a blob message."""
- return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
-
- def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
- """Process audio content and return a blob message."""
- return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
-
- def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
+ def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
return MCPTool(
entity=self.entity,
runtime=runtime,
diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py
index 828dc3b810..d3a2ad488c 100644
--- a/api/core/tools/plugin_tool/tool.py
+++ b/api/core/tools/plugin_tool/tool.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from collections.abc import Generator
from typing import Any
@@ -46,7 +48,7 @@ class PluginTool(Tool):
message_id=message_id,
)
- def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
+ def fork_tool_runtime(self, runtime: ToolRuntime) -> PluginTool:
return PluginTool(
entity=self.entity,
runtime=runtime,
diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py
index fef3157f27..22e099deba 100644
--- a/api/core/tools/signature.py
+++ b/api/core/tools/signature.py
@@ -7,12 +7,12 @@ import time
from configs import dify_config
-def sign_tool_file(tool_file_id: str, extension: str) -> str:
+def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) -> str:
"""
sign file to get a temporary url for plugin access
"""
- # Use internal URL for plugin/tool file access in Docker environments
- base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
+ # Use internal URL for plugin/tool file access in Docker environments, unless for_external is True
+ base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
timestamp = str(int(time.time()))
diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py
index 13fd579e20..3f57a346cd 100644
--- a/api/core/tools/tool_engine.py
+++ b/api/core/tools/tool_engine.py
@@ -1,5 +1,6 @@
import contextlib
import json
+import logging
from collections.abc import Generator, Iterable
from copy import deepcopy
from datetime import UTC, datetime
@@ -36,6 +37,8 @@ from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import Message, MessageFile
+logger = logging.getLogger(__name__)
+
class ToolEngine:
"""
@@ -123,25 +126,31 @@ class ToolEngine:
# transform tool invoke message to get LLM friendly message
return plain_text, message_files, meta
except ToolProviderCredentialValidationError as e:
+ logger.error(e, exc_info=True)
error_response = "Please check your tool provider credentials"
agent_tool_callback.on_tool_error(e)
except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
error_response = f"there is not a tool named {tool.entity.identity.name}"
+ logger.error(e, exc_info=True)
agent_tool_callback.on_tool_error(e)
except ToolParameterValidationError as e:
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
agent_tool_callback.on_tool_error(e)
+ logger.error(e, exc_info=True)
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
agent_tool_callback.on_tool_error(e)
+ logger.error(e, exc_info=True)
except ToolEngineInvokeError as e:
meta = e.meta
error_response = f"tool invoke error: {meta.error}"
agent_tool_callback.on_tool_error(e)
+ logger.error(e, exc_info=True)
return error_response, [], meta
except Exception as e:
error_response = f"unknown error: {e}"
agent_tool_callback.on_tool_error(e)
+ logger.error(e, exc_info=True)
return error_response, [], ToolInvokeMeta.error_instance(error_response)
diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py
index 3486182192..584975de05 100644
--- a/api/core/tools/utils/parser.py
+++ b/api/core/tools/utils/parser.py
@@ -378,7 +378,7 @@ class ApiBasedToolSchemaParser:
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
- ) -> tuple[list[ApiToolBundle], str]:
+ ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
auto parse to tool bundle
diff --git a/api/core/tools/utils/text_processing_utils.py b/api/core/tools/utils/text_processing_utils.py
index 0f9a91a111..4bfaa5e49b 100644
--- a/api/core/tools/utils/text_processing_utils.py
+++ b/api/core/tools/utils/text_processing_utils.py
@@ -4,6 +4,7 @@ import re
def remove_leading_symbols(text: str) -> str:
"""
Remove leading punctuation or symbols from the given text.
+ Preserves markdown links like [text](url) at the start.
Args:
text (str): The input text to process.
@@ -11,6 +12,11 @@ def remove_leading_symbols(text: str) -> str:
Returns:
str: The text with leading punctuation or symbols removed.
"""
+ # Check if text starts with a markdown link - preserve it
+ markdown_link_pattern = r"^\[([^\]]+)\]\((https?://[^)]+)\)"
+ if re.match(markdown_link_pattern, text):
+ return text
+
# Match Unicode ranges for punctuation and symbols
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'
diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py
index 2bd973f831..a706f101ca 100644
--- a/api/core/tools/workflow_as_tool/provider.py
+++ b/api/core/tools/workflow_as_tool/provider.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from collections.abc import Mapping
from pydantic import Field
@@ -47,14 +49,13 @@ class WorkflowToolProviderController(ToolProviderController):
self.provider_id = provider_id
@classmethod
- def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
+ def from_db(cls, db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
with session_factory.create_session() as session, session.begin():
app = session.get(App, db_provider.app_id)
if not app:
raise ValueError("app not found")
user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
-
controller = WorkflowToolProviderController(
entity=ToolProviderEntity(
identity=ToolProviderIdentity(
@@ -67,7 +68,7 @@ class WorkflowToolProviderController(ToolProviderController):
credentials_schema=[],
plugin_id=None,
),
- provider_id="",
+ provider_id=db_provider.id,
)
controller.tools = [
diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py
index 30334f5da8..9c1ceff145 100644
--- a/api/core/tools/workflow_as_tool/tool.py
+++ b/api/core/tools/workflow_as_tool/tool.py
@@ -1,12 +1,13 @@
+from __future__ import annotations
+
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
-from flask import has_request_context
from sqlalchemy import select
-from sqlalchemy.orm import Session
+from core.db.session_factory import session_factory
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
@@ -18,9 +19,7 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.errors import ToolInvokeError
-from extensions.ext_database import db
from factories.file_factory import build_from_mapping
-from libs.login import current_user
from models import Account, Tenant
from models.model import App, EndUser
from models.workflow import Workflow
@@ -181,7 +180,7 @@ class WorkflowTool(Tool):
return found
return None
- def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
+ def fork_tool_runtime(self, runtime: ToolRuntime) -> WorkflowTool:
"""
fork a new tool with metadata
@@ -208,50 +207,38 @@ class WorkflowTool(Tool):
Returns:
Account | EndUser | None: The resolved user object, or None if resolution fails.
"""
- if has_request_context():
- return self._resolve_user_from_request()
- else:
- return self._resolve_user_from_database(user_id=user_id)
-
- def _resolve_user_from_request(self) -> Account | EndUser | None:
- """
- Resolve user from Flask request context.
- """
- try:
- # Note: `current_user` is a LocalProxy. Never compare it with None directly.
- return getattr(current_user, "_get_current_object", lambda: current_user)()
- except Exception as e:
- logger.warning("Failed to resolve user from request context: %s", e)
- return None
+ return self._resolve_user_from_database(user_id=user_id)
def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None:
"""
Resolve user from database (worker/Celery context).
"""
+ with session_factory.create_session() as session:
+ tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
+ tenant = session.scalar(tenant_stmt)
+ if not tenant:
+ return None
+
+ user_stmt = select(Account).where(Account.id == user_id)
+ user = session.scalar(user_stmt)
+ if user:
+ user.current_tenant = tenant
+ session.expunge(user)
+ return user
+
+ end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
+ end_user = session.scalar(end_user_stmt)
+ if end_user:
+ session.expunge(end_user)
+ return end_user
- tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
- tenant = db.session.scalar(tenant_stmt)
- if not tenant:
return None
- user_stmt = select(Account).where(Account.id == user_id)
- user = db.session.scalar(user_stmt)
- if user:
- user.current_tenant = tenant
- return user
-
- end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
- end_user = db.session.scalar(end_user_stmt)
- if end_user:
- return end_user
-
- return None
-
def _get_workflow(self, app_id: str, version: str) -> Workflow:
"""
get the workflow by app id and version
"""
- with Session(db.engine, expire_on_commit=False) as session, session.begin():
+ with session_factory.create_session() as session, session.begin():
if not version:
stmt = (
select(Workflow)
@@ -263,22 +250,24 @@ class WorkflowTool(Tool):
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = session.scalar(stmt)
- if not workflow:
- raise ValueError("workflow not found or not published")
+ if not workflow:
+ raise ValueError("workflow not found or not published")
- return workflow
+ session.expunge(workflow)
+ return workflow
def _get_app(self, app_id: str) -> App:
"""
get the app by app id
"""
stmt = select(App).where(App.id == app_id)
- with Session(db.engine, expire_on_commit=False) as session, session.begin():
+ with session_factory.create_session() as session, session.begin():
app = session.scalar(stmt)
- if not app:
- raise ValueError("app not found")
+ if not app:
+ raise ValueError("app not found")
- return app
+ session.expunge(app)
+ return app
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
"""
diff --git a/api/core/variables/__init__.py b/api/core/variables/__init__.py
index 7a1cbf9940..7498224923 100644
--- a/api/core/variables/__init__.py
+++ b/api/core/variables/__init__.py
@@ -30,6 +30,7 @@ from .variables import (
SecretVariable,
StringVariable,
Variable,
+ VariableBase,
)
__all__ = [
@@ -62,4 +63,5 @@ __all__ = [
"StringSegment",
"StringVariable",
"Variable",
+ "VariableBase",
]
diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py
index 406b4e6f93..8330f1fe19 100644
--- a/api/core/variables/segments.py
+++ b/api/core/variables/segments.py
@@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None:
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
# - `SegmentGroup`, which is not added to the variable pool.
-# - `Variable` and its subclasses, which are handled by `VariableUnion`.
+# - `VariableBase` and its subclasses, which are handled by `Variable`.
SegmentUnion: TypeAlias = Annotated[
(
Annotated[NoneSegment, Tag(SegmentType.NONE)]
diff --git a/api/core/variables/types.py b/api/core/variables/types.py
index ce71711344..13b926c978 100644
--- a/api/core/variables/types.py
+++ b/api/core/variables/types.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
from collections.abc import Mapping
from enum import StrEnum
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
from core.file.models import File
@@ -52,7 +54,7 @@ class SegmentType(StrEnum):
return self in _ARRAY_TYPES
@classmethod
- def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
+ def infer_segment_type(cls, value: Any) -> SegmentType | None:
"""
Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.
@@ -173,7 +175,7 @@ class SegmentType(StrEnum):
raise AssertionError("this statement should be unreachable.")
@staticmethod
- def cast_value(value: Any, type_: "SegmentType"):
+ def cast_value(value: Any, type_: SegmentType):
# Cast Python's `bool` type to `int` when the runtime type requires
# an integer or number.
#
@@ -193,7 +195,7 @@ class SegmentType(StrEnum):
return [int(i) for i in value]
return value
- def exposed_type(self) -> "SegmentType":
+ def exposed_type(self) -> SegmentType:
"""Returns the type exposed to the frontend.
The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
@@ -202,7 +204,7 @@ class SegmentType(StrEnum):
return SegmentType.NUMBER
return self
- def element_type(self) -> "SegmentType | None":
+ def element_type(self) -> SegmentType | None:
"""Return the element type of the current segment type, or `None` if the element type is undefined.
Raises:
@@ -217,7 +219,7 @@ class SegmentType(StrEnum):
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
@staticmethod
- def get_zero_value(t: "SegmentType"):
+ def get_zero_value(t: SegmentType):
# Lazy import to avoid circular dependency
from factories import variable_factory
diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py
index 9fd0bbc5b2..a19c53918d 100644
--- a/api/core/variables/variables.py
+++ b/api/core/variables/variables.py
@@ -27,7 +27,7 @@ from .segments import (
from .types import SegmentType
-class Variable(Segment):
+class VariableBase(Segment):
"""
A variable is a segment that has a name.
@@ -45,23 +45,23 @@ class Variable(Segment):
selector: Sequence[str] = Field(default_factory=list)
-class StringVariable(StringSegment, Variable):
+class StringVariable(StringSegment, VariableBase):
pass
-class FloatVariable(FloatSegment, Variable):
+class FloatVariable(FloatSegment, VariableBase):
pass
-class IntegerVariable(IntegerSegment, Variable):
+class IntegerVariable(IntegerSegment, VariableBase):
pass
-class ObjectVariable(ObjectSegment, Variable):
+class ObjectVariable(ObjectSegment, VariableBase):
pass
-class ArrayVariable(ArraySegment, Variable):
+class ArrayVariable(ArraySegment, VariableBase):
pass
@@ -89,16 +89,16 @@ class SecretVariable(StringVariable):
return encrypter.obfuscated_token(self.value)
-class NoneVariable(NoneSegment, Variable):
+class NoneVariable(NoneSegment, VariableBase):
value_type: SegmentType = SegmentType.NONE
value: None = None
-class FileVariable(FileSegment, Variable):
+class FileVariable(FileSegment, VariableBase):
pass
-class BooleanVariable(BooleanSegment, Variable):
+class BooleanVariable(BooleanSegment, VariableBase):
pass
@@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel):
value: Any
-# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
-# Use `Variable` for type hinting when serialization is not required.
+# The `Variable` type is used to enable serialization and deserialization with Pydantic.
+# Use `VariableBase` for type hinting when serialization is not required.
#
# Note:
-# - All variants in `VariableUnion` must inherit from the `Variable` class.
-# - The union must include all non-abstract subclasses of `Segment`, except:
-VariableUnion: TypeAlias = Annotated[
+# - All variants in `Variable` must inherit from the `VariableBase` class.
+# - The union must include all non-abstract subclasses of `VariableBase`.
+Variable: TypeAlias = Annotated[
(
Annotated[NoneVariable, Tag(SegmentType.NONE)]
| Annotated[StringVariable, Tag(SegmentType.STRING)]
diff --git a/api/core/workflow/README.md b/api/core/workflow/README.md
index 72f5dbe1e2..9a39f976a6 100644
--- a/api/core/workflow/README.md
+++ b/api/core/workflow/README.md
@@ -64,6 +64,9 @@ engine.layer(DebugLoggingLayer(level="INFO"))
engine.layer(ExecutionLimitsLayer(max_nodes=100))
```
+`engine.layer()` binds the read-only runtime state before execution, so layer hooks
+can assume `graph_runtime_state` is available.
+
### Event-Driven Architecture
All node executions emit events for monitoring and integration:
diff --git a/api/core/workflow/context/__init__.py b/api/core/workflow/context/__init__.py
new file mode 100644
index 0000000000..1237d6a017
--- /dev/null
+++ b/api/core/workflow/context/__init__.py
@@ -0,0 +1,34 @@
+"""
+Execution Context - Context management for workflow execution.
+
+This package provides Flask-independent context management for workflow
+execution in multi-threaded environments.
+"""
+
+from core.workflow.context.execution_context import (
+ AppContext,
+ ContextProviderNotFoundError,
+ ExecutionContext,
+ IExecutionContext,
+ NullAppContext,
+ capture_current_context,
+ read_context,
+ register_context,
+ register_context_capturer,
+ reset_context_provider,
+)
+from core.workflow.context.models import SandboxContext
+
+__all__ = [
+ "AppContext",
+ "ContextProviderNotFoundError",
+ "ExecutionContext",
+ "IExecutionContext",
+ "NullAppContext",
+ "SandboxContext",
+ "capture_current_context",
+ "read_context",
+ "register_context",
+ "register_context_capturer",
+ "reset_context_provider",
+]
diff --git a/api/core/workflow/context/execution_context.py b/api/core/workflow/context/execution_context.py
new file mode 100644
index 0000000000..e3007530f0
--- /dev/null
+++ b/api/core/workflow/context/execution_context.py
@@ -0,0 +1,284 @@
+"""
+Execution Context - Abstracted context management for workflow execution.
+"""
+
+import contextvars
+import threading
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Generator
+from contextlib import AbstractContextManager, contextmanager
+from typing import Any, Protocol, TypeVar, final, runtime_checkable
+
+from pydantic import BaseModel
+
+
+class AppContext(ABC):
+ """
+ Abstract application context interface.
+
+ This abstraction allows workflow execution to work with or without Flask
+ by providing a common interface for application context management.
+ """
+
+ @abstractmethod
+ def get_config(self, key: str, default: Any = None) -> Any:
+ """Get configuration value by key."""
+ pass
+
+ @abstractmethod
+ def get_extension(self, name: str) -> Any:
+ """Get Flask extension by name (e.g., 'db', 'cache')."""
+ pass
+
+ @abstractmethod
+ def enter(self) -> AbstractContextManager[None]:
+ """Enter the application context."""
+ pass
+
+
+@runtime_checkable
+class IExecutionContext(Protocol):
+ """
+ Protocol for execution context.
+
+ This protocol defines the interface that all execution contexts must implement,
+ allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably.
+ """
+
+ def __enter__(self) -> "IExecutionContext":
+ """Enter the execution context."""
+ ...
+
+ def __exit__(self, *args: Any) -> None:
+ """Exit the execution context."""
+ ...
+
+ @property
+ def user(self) -> Any:
+ """Get user object."""
+ ...
+
+
+@final
+class ExecutionContext:
+ """
+ Execution context for workflow execution in worker threads.
+
+ This class encapsulates all context needed for workflow execution:
+ - Application context (Flask app or standalone)
+ - Context variables for Python contextvars
+ - User information (optional)
+
+ It is designed to be serializable and passable to worker threads.
+ """
+
+ def __init__(
+ self,
+ app_context: AppContext | None = None,
+ context_vars: contextvars.Context | None = None,
+ user: Any = None,
+ ) -> None:
+ """
+ Initialize execution context.
+
+ Args:
+ app_context: Application context (Flask or standalone)
+ context_vars: Python contextvars to preserve
+ user: User object (optional)
+ """
+ self._app_context = app_context
+ self._context_vars = context_vars
+ self._user = user
+ self._local = threading.local()
+
+ @property
+ def app_context(self) -> AppContext | None:
+ """Get application context."""
+ return self._app_context
+
+ @property
+ def context_vars(self) -> contextvars.Context | None:
+ """Get context variables."""
+ return self._context_vars
+
+ @property
+ def user(self) -> Any:
+ """Get user object."""
+ return self._user
+
+ @contextmanager
+ def enter(self) -> Generator[None, None, None]:
+ """
+ Enter this execution context.
+
+ This is a convenience method that creates a context manager.
+ """
+ # Restore context variables if provided
+ if self._context_vars:
+ for var, val in self._context_vars.items():
+ var.set(val)
+
+ # Enter app context if available
+ if self._app_context is not None:
+ with self._app_context.enter():
+ yield
+ else:
+ yield
+
+ def __enter__(self) -> "ExecutionContext":
+ """Enter the execution context."""
+ cm = self.enter()
+ self._local.cm = cm
+ cm.__enter__()
+ return self
+
+ def __exit__(self, *args: Any) -> None:
+ """Exit the execution context."""
+ cm = getattr(self._local, "cm", None)
+ if cm is not None:
+ cm.__exit__(*args)
+
+
+class NullAppContext(AppContext):
+ """
+ Null implementation of AppContext for non-Flask environments.
+
+ This is used when running without Flask (e.g., in tests or standalone mode).
+ """
+
+ def __init__(self, config: dict[str, Any] | None = None) -> None:
+ """
+ Initialize null app context.
+
+ Args:
+ config: Optional configuration dictionary
+ """
+ self._config = config or {}
+ self._extensions: dict[str, Any] = {}
+
+ def get_config(self, key: str, default: Any = None) -> Any:
+ """Get configuration value by key."""
+ return self._config.get(key, default)
+
+ def get_extension(self, name: str) -> Any:
+ """Get extension by name."""
+ return self._extensions.get(name)
+
+ def set_extension(self, name: str, extension: Any) -> None:
+ """Set extension by name."""
+ self._extensions[name] = extension
+
+ @contextmanager
+ def enter(self) -> Generator[None, None, None]:
+ """Enter null context (no-op)."""
+ yield
+
+
+class ExecutionContextBuilder:
+ """
+ Builder for creating ExecutionContext instances.
+
+ This provides a fluent API for building execution contexts.
+ """
+
+ def __init__(self) -> None:
+ self._app_context: AppContext | None = None
+ self._context_vars: contextvars.Context | None = None
+ self._user: Any = None
+
+ def with_app_context(self, app_context: AppContext) -> "ExecutionContextBuilder":
+ """Set application context."""
+ self._app_context = app_context
+ return self
+
+ def with_context_vars(self, context_vars: contextvars.Context) -> "ExecutionContextBuilder":
+ """Set context variables."""
+ self._context_vars = context_vars
+ return self
+
+ def with_user(self, user: Any) -> "ExecutionContextBuilder":
+ """Set user."""
+ self._user = user
+ return self
+
+ def build(self) -> ExecutionContext:
+ """Build the execution context."""
+ return ExecutionContext(
+ app_context=self._app_context,
+ context_vars=self._context_vars,
+ user=self._user,
+ )
+
+
+_capturer: Callable[[], IExecutionContext] | None = None
+
+# Tenant-scoped providers using tuple keys for clarity and constant-time lookup.
+# Key mapping:
+# (name, tenant_id) -> provider
+# - name: namespaced identifier (recommend prefixing, e.g. "workflow.sandbox")
+# - tenant_id: tenant identifier string
+# Value:
+# provider: Callable[[], BaseModel] returning the typed context value
+# Type-safety note:
+# - This registry cannot enforce that all providers for a given name return the same BaseModel type.
+# - Implementors SHOULD provide typed wrappers around register/read (like Go's context best practice),
+# e.g. def register_sandbox_ctx(tenant_id: str, p: Callable[[], SandboxContext]) and
+# def read_sandbox_ctx(tenant_id: str) -> SandboxContext.
+_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class ContextProviderNotFoundError(KeyError):
+ """Raised when a tenant-scoped context provider is missing for a given (name, tenant_id)."""
+
+ pass
+
+
+def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
+ """Register a single enterable execution context capturer (e.g., Flask)."""
+ global _capturer
+ _capturer = capturer
+
+
+def register_context(name: str, tenant_id: str, provider: Callable[[], BaseModel]) -> None:
+ """Register a tenant-specific provider for a named context.
+
+ Tip: use a namespaced "name" (e.g., "workflow.sandbox") to avoid key collisions.
+ Consider adding a typed wrapper for this registration in your feature module.
+ """
+ _tenant_context_providers[(name, tenant_id)] = provider
+
+
+def read_context(name: str, *, tenant_id: str) -> BaseModel:
+ """
+ Read a context value for a specific tenant.
+
+ Raises KeyError if the provider for (name, tenant_id) is not registered.
+ """
+ prov = _tenant_context_providers.get((name, tenant_id))
+ if prov is None:
+ raise ContextProviderNotFoundError(f"Context provider '{name}' not registered for tenant '{tenant_id}'")
+ return prov()
+
+
+def capture_current_context() -> IExecutionContext:
+ """
+ Capture current execution context from the calling environment.
+
+ If a capturer is registered (e.g., Flask), use it. Otherwise, return a minimal
+ context with NullAppContext + copy of current contextvars.
+ """
+ if _capturer is None:
+ return ExecutionContext(
+ app_context=NullAppContext(),
+ context_vars=contextvars.copy_context(),
+ )
+ return _capturer()
+
+
+def reset_context_provider() -> None:
+ """Reset the capturer and all tenant-scoped context providers (primarily for tests)."""
+ global _capturer
+ _capturer = None
+ _tenant_context_providers.clear()
diff --git a/api/core/workflow/context/models.py b/api/core/workflow/context/models.py
new file mode 100644
index 0000000000..af5a4b2614
--- /dev/null
+++ b/api/core/workflow/context/models.py
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+from pydantic import AnyHttpUrl, BaseModel
+
+
+class SandboxContext(BaseModel):
+ """Typed context for sandbox integration. All fields optional by design."""
+
+ sandbox_url: AnyHttpUrl | None = None
+ sandbox_token: str | None = None # optional, if later needed for auth
+
+
+__all__ = ["SandboxContext"]
diff --git a/api/core/workflow/conversation_variable_updater.py b/api/core/workflow/conversation_variable_updater.py
index fd78248c17..75f47691da 100644
--- a/api/core/workflow/conversation_variable_updater.py
+++ b/api/core/workflow/conversation_variable_updater.py
@@ -1,7 +1,7 @@
import abc
from typing import Protocol
-from core.variables import Variable
+from core.variables import VariableBase
class ConversationVariableUpdater(Protocol):
@@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol):
"""
@abc.abstractmethod
- def update(self, conversation_id: str, variable: "Variable"):
+ def update(self, conversation_id: str, variable: "VariableBase"):
"""
Updates the value of the specified conversation variable in the underlying storage.
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
- :param variable: The `Variable` instance containing the updated value.
+ :param variable: The `VariableBase` instance containing the updated value.
"""
pass
diff --git a/api/core/workflow/entities/workflow_execution.py b/api/core/workflow/entities/workflow_execution.py
index a8a86d3db2..1b3fb36f1f 100644
--- a/api/core/workflow/entities/workflow_execution.py
+++ b/api/core/workflow/entities/workflow_execution.py
@@ -5,6 +5,8 @@ Models are independent of the storage mechanism and don't contain
implementation details like tenant_id, app_id, etc.
"""
+from __future__ import annotations
+
from collections.abc import Mapping
from datetime import datetime
from typing import Any
@@ -59,7 +61,7 @@ class WorkflowExecution(BaseModel):
graph: Mapping[str, Any],
inputs: Mapping[str, Any],
started_at: datetime,
- ) -> "WorkflowExecution":
+ ) -> WorkflowExecution:
return WorkflowExecution(
id_=id_,
workflow_id=workflow_id,
diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py
index c08b62a253..bb3b13e8c6 100644
--- a/api/core/workflow/enums.py
+++ b/api/core/workflow/enums.py
@@ -211,6 +211,10 @@ class WorkflowExecutionStatus(StrEnum):
def is_ended(self) -> bool:
return self in _END_STATE
+ @classmethod
+ def ended_values(cls) -> list[str]:
+ return [status.value for status in _END_STATE]
+
_END_STATE = frozenset(
[
diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py
index ba5a01fc94..31bf6f3b27 100644
--- a/api/core/workflow/graph/graph.py
+++ b/api/core/workflow/graph/graph.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import logging
from collections import defaultdict
from collections.abc import Mapping, Sequence
@@ -175,7 +177,7 @@ class Graph:
def _create_node_instances(
cls,
node_configs_map: dict[str, dict[str, object]],
- node_factory: "NodeFactory",
+ node_factory: NodeFactory,
) -> dict[str, Node]:
"""
Create node instances from configurations using the node factory.
@@ -197,7 +199,7 @@ class Graph:
return nodes
@classmethod
- def new(cls) -> "GraphBuilder":
+ def new(cls) -> GraphBuilder:
"""Create a fluent builder for assembling a graph programmatically."""
return GraphBuilder(graph_cls=cls)
@@ -284,9 +286,10 @@ class Graph:
cls,
*,
graph_config: Mapping[str, object],
- node_factory: "NodeFactory",
+ node_factory: NodeFactory,
root_node_id: str | None = None,
- ) -> "Graph":
+ skip_validation: bool = False,
+ ) -> Graph:
"""
Initialize graph
@@ -337,8 +340,9 @@ class Graph:
root_node=root_node,
)
- # Validate the graph structure using built-in validators
- get_graph_validator().validate(graph)
+ if not skip_validation:
+ # Validate the graph structure using built-in validators
+ get_graph_validator().validate(graph)
return graph
@@ -383,7 +387,7 @@ class GraphBuilder:
self._edges: list[Edge] = []
self._edge_counter = 0
- def add_root(self, node: Node) -> "GraphBuilder":
+ def add_root(self, node: Node) -> GraphBuilder:
"""Register the root node. Must be called exactly once."""
if self._nodes:
@@ -398,7 +402,7 @@ class GraphBuilder:
*,
from_node_id: str | None = None,
source_handle: str = "source",
- ) -> "GraphBuilder":
+ ) -> GraphBuilder:
"""Append a node and connect it from the specified predecessor."""
if not self._nodes:
@@ -419,7 +423,7 @@ class GraphBuilder:
return self
- def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
+ def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder:
"""Connect two existing nodes without adding a new node."""
if tail not in self._nodes_by_id:
diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/core/workflow/graph_engine/command_channels/redis_channel.py
index 4be3adb8f8..0fccd4a0fd 100644
--- a/api/core/workflow/graph_engine/command_channels/redis_channel.py
+++ b/api/core/workflow/graph_engine/command_channels/redis_channel.py
@@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
import json
from typing import TYPE_CHECKING, Any, final
-from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
+from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
@@ -113,6 +113,8 @@ class RedisChannel:
return AbortCommand.model_validate(data)
if command_type == CommandType.PAUSE:
return PauseCommand.model_validate(data)
+ if command_type == CommandType.UPDATE_VARIABLES:
+ return UpdateVariablesCommand.model_validate(data)
# For other command types, use base class
return GraphEngineCommand.model_validate(data)
diff --git a/api/core/workflow/graph_engine/command_processing/__init__.py b/api/core/workflow/graph_engine/command_processing/__init__.py
index 837f5e55fd..7b4f0dfff7 100644
--- a/api/core/workflow/graph_engine/command_processing/__init__.py
+++ b/api/core/workflow/graph_engine/command_processing/__init__.py
@@ -5,11 +5,12 @@ This package handles external commands sent to the engine
during execution.
"""
-from .command_handlers import AbortCommandHandler, PauseCommandHandler
+from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler
from .command_processor import CommandProcessor
__all__ = [
"AbortCommandHandler",
"CommandProcessor",
"PauseCommandHandler",
+ "UpdateVariablesCommandHandler",
]
diff --git a/api/core/workflow/graph_engine/command_processing/command_handlers.py b/api/core/workflow/graph_engine/command_processing/command_handlers.py
index e9f109c88c..cfe856d9e8 100644
--- a/api/core/workflow/graph_engine/command_processing/command_handlers.py
+++ b/api/core/workflow/graph_engine/command_processing/command_handlers.py
@@ -4,9 +4,10 @@ from typing import final
from typing_extensions import override
from core.workflow.entities.pause_reason import SchedulingPause
+from core.workflow.runtime import VariablePool
from ..domain.graph_execution import GraphExecution
-from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
+from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
from .command_processor import CommandHandler
logger = logging.getLogger(__name__)
@@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler):
reason = command.reason
pause_reason = SchedulingPause(message=reason)
execution.pause(pause_reason)
+
+
+@final
+class UpdateVariablesCommandHandler(CommandHandler):
+ def __init__(self, variable_pool: VariablePool) -> None:
+ self._variable_pool = variable_pool
+
+ @override
+ def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
+ assert isinstance(command, UpdateVariablesCommand)
+ for update in command.updates:
+ try:
+ variable = update.value
+ self._variable_pool.add(variable.selector, variable)
+ logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id)
+ except ValueError as exc:
+ logger.warning(
+ "Skipping invalid variable selector %s for workflow %s: %s",
+ getattr(update.value, "selector", None),
+ execution.workflow_id,
+ exc,
+ )
diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py
index 0d51b2b716..41276eb444 100644
--- a/api/core/workflow/graph_engine/entities/commands.py
+++ b/api/core/workflow/graph_engine/entities/commands.py
@@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine
instance to control its execution flow.
"""
-from enum import StrEnum
+from collections.abc import Sequence
+from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
+from core.variables.variables import Variable
+
class CommandType(StrEnum):
"""Types of commands that can be sent to GraphEngine."""
- ABORT = "abort"
- PAUSE = "pause"
+ ABORT = auto()
+ PAUSE = auto()
+ UPDATE_VARIABLES = auto()
class GraphEngineCommand(BaseModel):
@@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand):
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
reason: str = Field(default="unknown reason", description="reason for pause")
+
+
+class VariableUpdate(BaseModel):
+ """Represents a single variable update instruction."""
+
+ value: Variable = Field(description="New variable value")
+
+
+class UpdateVariablesCommand(GraphEngineCommand):
+ """Command to update a group of variables in the variable pool."""
+
+ command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command")
+ updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates")
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index 744013cb04..187dfcf021 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -5,14 +5,15 @@ This engine uses a modular architecture with separated packages following
Domain-Driven Design principles for improved maintainability and testability.
"""
-import contextvars
+from __future__ import annotations
+
import logging
import queue
+import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
-from flask import Flask, current_app
-
+from core.workflow.context import capture_current_context
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
@@ -31,8 +32,13 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
from core.workflow.runtime.graph_runtime_state import GraphProtocol
-from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
-from .entities.commands import AbortCommand, PauseCommand
+from .command_processing import (
+ AbortCommandHandler,
+ CommandProcessor,
+ PauseCommandHandler,
+ UpdateVariablesCommandHandler,
+)
+from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
from .error_handler import ErrorHandler
from .event_management import EventHandler, EventManager
from .graph_state_manager import GraphStateManager
@@ -71,10 +77,13 @@ class GraphEngine:
scale_down_idle_time: float | None = None,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
+ # stop event
+ self._stop_event = threading.Event()
# Bind runtime state to current workflow context
self._graph = graph
self._graph_runtime_state = graph_runtime_state
+ self._graph_runtime_state.stop_event = self._stop_event
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
@@ -141,22 +150,16 @@ class GraphEngine:
pause_handler = PauseCommandHandler()
self._command_processor.register_handler(PauseCommand, pause_handler)
+ update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool)
+ self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler)
+
# === Extensibility ===
# Layers allow plugins to extend engine functionality
self._layers: list[GraphEngineLayer] = []
# === Worker Pool Setup ===
- # Capture Flask app context for worker threads
- flask_app: Flask | None = None
- try:
- app = current_app._get_current_object() # type: ignore
- if isinstance(app, Flask):
- flask_app = app
- except RuntimeError:
- pass
-
- # Capture context variables for worker threads
- context_vars = contextvars.copy_context()
+ # Capture execution context for worker threads
+ execution_context = capture_current_context()
# Create worker pool for parallel node execution
self._worker_pool = WorkerPool(
@@ -164,12 +167,12 @@ class GraphEngine:
event_queue=self._event_queue,
graph=self._graph,
layers=self._layers,
- flask_app=flask_app,
- context_vars=context_vars,
+ execution_context=execution_context,
min_workers=self._min_workers,
max_workers=self._max_workers,
scale_up_threshold=self._scale_up_threshold,
scale_down_idle_time=self._scale_down_idle_time,
+ stop_event=self._stop_event,
)
# === Orchestration ===
@@ -200,6 +203,7 @@ class GraphEngine:
event_handler=self._event_handler_registry,
execution_coordinator=self._execution_coordinator,
event_emitter=self._event_manager,
+ stop_event=self._stop_event,
)
# === Validation ===
@@ -213,9 +217,16 @@ class GraphEngine:
if id(node.graph_runtime_state) != expected_state_id:
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
- def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
+ def _bind_layer_context(
+ self,
+ layer: GraphEngineLayer,
+ ) -> None:
+ layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
+
+ def layer(self, layer: GraphEngineLayer) -> GraphEngine:
"""Add a layer for extending functionality."""
self._layers.append(layer)
+ self._bind_layer_context(layer)
return self
def run(self) -> Generator[GraphEngineEvent, None, None]:
@@ -304,14 +315,7 @@ class GraphEngine:
def _initialize_layers(self) -> None:
"""Initialize layers with context."""
self._event_manager.set_layers(self._layers)
- # Create a read-only wrapper for the runtime state
- read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state)
for layer in self._layers:
- try:
- layer.initialize(read_only_state, self._command_channel)
- except Exception:
- logger.exception("Failed to initialize layer %s", layer.__class__.__name__)
-
try:
layer.on_graph_start()
except Exception:
@@ -319,6 +323,7 @@ class GraphEngine:
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
+ self._stop_event.clear()
paused_nodes: list[str] = []
deferred_nodes: list[str] = []
if resume:
@@ -352,13 +357,12 @@ class GraphEngine:
def _stop_execution(self) -> None:
"""Stop execution subsystems."""
+ self._stop_event.set()
self._dispatcher.stop()
self._worker_pool.stop()
# Don't mark complete here as the dispatcher already does it
# Notify layers
- logger = logging.getLogger(__name__)
-
for layer in self._layers:
try:
layer.on_graph_end(self._graph_execution.error)
diff --git a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py b/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py
index 78f8ecdcdf..b9c9243963 100644
--- a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py
+++ b/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py
@@ -60,6 +60,7 @@ class SkipPropagator:
if edge_states["has_taken"]:
# Enqueue node
self._state_manager.enqueue_node(downstream_node_id)
+ self._state_manager.start_execution(downstream_node_id)
return
# All edges are skipped, propagate skip to this node
diff --git a/api/core/workflow/graph_engine/layers/README.md b/api/core/workflow/graph_engine/layers/README.md
index 17845ee1f0..b0f295037c 100644
--- a/api/core/workflow/graph_engine/layers/README.md
+++ b/api/core/workflow/graph_engine/layers/README.md
@@ -8,7 +8,7 @@ Pluggable middleware for engine extensions.
Abstract base class for layers.
-- `initialize()` - Receive runtime context
+- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks)
- `on_graph_start()` - Execution start hook
- `on_event()` - Process all events
- `on_graph_end()` - Execution end hook
@@ -34,6 +34,9 @@ engine.layer(debug_layer)
engine.run()
```
+`engine.layer()` binds the read-only runtime state before execution, so
+`graph_runtime_state` is always available inside layer hooks.
+
## Custom Layers
```python
diff --git a/api/core/workflow/graph_engine/layers/__init__.py b/api/core/workflow/graph_engine/layers/__init__.py
index 772433e48c..0a29a52993 100644
--- a/api/core/workflow/graph_engine/layers/__init__.py
+++ b/api/core/workflow/graph_engine/layers/__init__.py
@@ -8,11 +8,9 @@ with middleware-like components that can observe events and interact with execut
from .base import GraphEngineLayer
from .debug_logging import DebugLoggingLayer
from .execution_limits import ExecutionLimitsLayer
-from .observability import ObservabilityLayer
__all__ = [
"DebugLoggingLayer",
"ExecutionLimitsLayer",
"GraphEngineLayer",
- "ObservabilityLayer",
]
diff --git a/api/core/workflow/graph_engine/layers/base.py b/api/core/workflow/graph_engine/layers/base.py
index 780f92a0f4..ff4a483aed 100644
--- a/api/core/workflow/graph_engine/layers/base.py
+++ b/api/core/workflow/graph_engine/layers/base.py
@@ -8,11 +8,19 @@ intercept and respond to GraphEngine events.
from abc import ABC, abstractmethod
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
-from core.workflow.graph_events import GraphEngineEvent
+from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from core.workflow.runtime import ReadOnlyGraphRuntimeState
+class GraphEngineLayerNotInitializedError(Exception):
+ """Raised when a layer's runtime state is accessed before initialization."""
+
+ def __init__(self, layer_name: str | None = None) -> None:
+ name = layer_name or "GraphEngineLayer"
+ super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.")
+
+
class GraphEngineLayer(ABC):
"""
Abstract base class for GraphEngine layers.
@@ -28,22 +36,27 @@ class GraphEngineLayer(ABC):
def __init__(self) -> None:
"""Initialize the layer. Subclasses can override with custom parameters."""
- self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
+ self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self.command_channel: CommandChannel | None = None
+ @property
+ def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState:
+ if self._graph_runtime_state is None:
+ raise GraphEngineLayerNotInitializedError(type(self).__name__)
+ return self._graph_runtime_state
+
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
"""
Initialize the layer with engine dependencies.
- Called by GraphEngine before execution starts to inject the read-only runtime state
- and command channel. This allows layers to observe engine context and send
- commands, but prevents direct state modification.
-
+ Called by GraphEngine to inject the read-only runtime state and command channel.
+ This is invoked when the layer is registered with a `GraphEngine` instance.
+ Implementations should be idempotent.
Args:
graph_runtime_state: Read-only view of the runtime state
command_channel: Channel for sending commands to the engine
"""
- self.graph_runtime_state = graph_runtime_state
+ self._graph_runtime_state = graph_runtime_state
self.command_channel = command_channel
@abstractmethod
@@ -85,7 +98,7 @@ class GraphEngineLayer(ABC):
"""
pass
- def on_node_run_start(self, node: Node) -> None: # noqa: B027
+ def on_node_run_start(self, node: Node) -> None:
"""
Called immediately before a node begins execution.
@@ -96,9 +109,11 @@ class GraphEngineLayer(ABC):
Args:
node: The node instance about to be executed
"""
- pass
+ return
- def on_node_run_end(self, node: Node, error: Exception | None) -> None: # noqa: B027
+ def on_node_run_end(
+ self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
+ ) -> None:
"""
Called after a node finishes execution.
@@ -108,5 +123,6 @@ class GraphEngineLayer(ABC):
Args:
node: The node instance that just finished execution
error: Exception instance if the node failed, otherwise None
+ result_event: The final result event from node execution (succeeded/failed/paused), if any
"""
- pass
+ return
diff --git a/api/core/workflow/graph_engine/layers/debug_logging.py b/api/core/workflow/graph_engine/layers/debug_logging.py
index 034ebcf54f..e0402cd09c 100644
--- a/api/core/workflow/graph_engine/layers/debug_logging.py
+++ b/api/core/workflow/graph_engine/layers/debug_logging.py
@@ -109,10 +109,8 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info("=" * 80)
self.logger.info("🚀 GRAPH EXECUTION STARTED")
self.logger.info("=" * 80)
-
- if self.graph_runtime_state:
- # Log initial state
- self.logger.info("Initial State:")
+ # Log initial state
+ self.logger.info("Initial State:")
@override
def on_event(self, event: GraphEngineEvent) -> None:
@@ -243,8 +241,7 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info(" Node retries: %s", self.retry_count)
# Log final state if available
- if self.graph_runtime_state and self.include_outputs:
- if self.graph_runtime_state.outputs:
- self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
+ if self.include_outputs and self.graph_runtime_state.outputs:
+ self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
self.logger.info("=" * 80)
diff --git a/api/core/workflow/graph_engine/layers/node_parsers.py b/api/core/workflow/graph_engine/layers/node_parsers.py
deleted file mode 100644
index b6bac794df..0000000000
--- a/api/core/workflow/graph_engine/layers/node_parsers.py
+++ /dev/null
@@ -1,61 +0,0 @@
-"""
-Node-level OpenTelemetry parser interfaces and defaults.
-"""
-
-import json
-from typing import Protocol
-
-from opentelemetry.trace import Span
-from opentelemetry.trace.status import Status, StatusCode
-
-from core.workflow.nodes.base.node import Node
-from core.workflow.nodes.tool.entities import ToolNodeData
-
-
-class NodeOTelParser(Protocol):
- """Parser interface for node-specific OpenTelemetry enrichment."""
-
- def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: ...
-
-
-class DefaultNodeOTelParser:
- """Fallback parser used when no node-specific parser is registered."""
-
- def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
- span.set_attribute("node.id", node.id)
- if node.execution_id:
- span.set_attribute("node.execution_id", node.execution_id)
- if hasattr(node, "node_type") and node.node_type:
- span.set_attribute("node.type", node.node_type.value)
-
- if error:
- span.record_exception(error)
- span.set_status(Status(StatusCode.ERROR, str(error)))
- else:
- span.set_status(Status(StatusCode.OK))
-
-
-class ToolNodeOTelParser:
- """Parser for tool nodes that captures tool-specific metadata."""
-
- def __init__(self) -> None:
- self._delegate = DefaultNodeOTelParser()
-
- def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
- self._delegate.parse(node=node, span=span, error=error)
-
- tool_data = getattr(node, "_node_data", None)
- if not isinstance(tool_data, ToolNodeData):
- return
-
- span.set_attribute("tool.provider.id", tool_data.provider_id)
- span.set_attribute("tool.provider.type", tool_data.provider_type.value)
- span.set_attribute("tool.provider.name", tool_data.provider_name)
- span.set_attribute("tool.name", tool_data.tool_name)
- span.set_attribute("tool.label", tool_data.tool_label)
- if tool_data.plugin_unique_identifier:
- span.set_attribute("tool.plugin.id", tool_data.plugin_unique_identifier)
- if tool_data.credential_id:
- span.set_attribute("tool.credential.id", tool_data.credential_id)
- if tool_data.tool_configurations:
- span.set_attribute("tool.config", json.dumps(tool_data.tool_configurations, ensure_ascii=False))
diff --git a/api/core/workflow/graph_engine/manager.py b/api/core/workflow/graph_engine/manager.py
index 0577ba8f02..d2cfa755d9 100644
--- a/api/core/workflow/graph_engine/manager.py
+++ b/api/core/workflow/graph_engine/manager.py
@@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel.
This module provides a simplified interface for controlling workflow executions
using the new Redis command channel, without requiring user permission checks.
-Supports stop, pause, and resume operations.
"""
import logging
+from collections.abc import Sequence
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
-from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
+from core.workflow.graph_engine.entities.commands import (
+ AbortCommand,
+ GraphEngineCommand,
+ PauseCommand,
+ UpdateVariablesCommand,
+ VariableUpdate,
+)
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@@ -23,7 +29,6 @@ class GraphEngineManager:
This class provides a simple interface for controlling workflow executions
by sending commands through Redis channels, without user validation.
- Supports stop and pause operations.
"""
@staticmethod
@@ -45,6 +50,16 @@ class GraphEngineManager:
pause_command = PauseCommand(reason=reason or "User requested pause")
GraphEngineManager._send_command(task_id, pause_command)
+ @staticmethod
+ def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
+ """Send a command to update variables in a running workflow."""
+
+ if not updates:
+ return
+
+ update_command = UpdateVariablesCommand(updates=updates)
+ GraphEngineManager._send_command(task_id, update_command)
+
@staticmethod
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
"""Send a command to the workflow-specific Redis channel."""
diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py
index 685347fbce..d40d15c545 100644
--- a/api/core/workflow/graph_engine/orchestration/dispatcher.py
+++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py
@@ -44,6 +44,7 @@ class Dispatcher:
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandler",
execution_coordinator: ExecutionCoordinator,
+ stop_event: threading.Event,
event_emitter: EventManager | None = None,
) -> None:
"""
@@ -61,7 +62,7 @@ class Dispatcher:
self._event_emitter = event_emitter
self._thread: threading.Thread | None = None
- self._stop_event = threading.Event()
+ self._stop_event = stop_event
self._start_time: float | None = None
def start(self) -> None:
@@ -69,16 +70,14 @@ class Dispatcher:
if self._thread and self._thread.is_alive():
return
- self._stop_event.clear()
self._start_time = time.time()
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
self._thread.start()
def stop(self) -> None:
"""Stop the dispatcher thread."""
- self._stop_event.set()
if self._thread and self._thread.is_alive():
- self._thread.join(timeout=10.0)
+ self._thread.join(timeout=2.0)
def _dispatcher_loop(self) -> None:
"""Main dispatcher loop."""
diff --git a/api/core/workflow/graph_engine/ready_queue/factory.py b/api/core/workflow/graph_engine/ready_queue/factory.py
index 1144e1de69..a9d4f470e5 100644
--- a/api/core/workflow/graph_engine/ready_queue/factory.py
+++ b/api/core/workflow/graph_engine/ready_queue/factory.py
@@ -2,6 +2,8 @@
Factory for creating ReadyQueue instances from serialized state.
"""
+from __future__ import annotations
+
from typing import TYPE_CHECKING
from .in_memory import InMemoryReadyQueue
@@ -11,7 +13,7 @@ if TYPE_CHECKING:
from .protocol import ReadyQueue
-def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
+def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue:
"""
Create a ReadyQueue instance from a serialized state.
diff --git a/api/core/workflow/graph_engine/response_coordinator/session.py b/api/core/workflow/graph_engine/response_coordinator/session.py
index 8b7c2e441e..8ceaa428c3 100644
--- a/api/core/workflow/graph_engine/response_coordinator/session.py
+++ b/api/core/workflow/graph_engine/response_coordinator/session.py
@@ -5,6 +5,8 @@ This module contains the private ResponseSession class used internally
by ResponseStreamCoordinator to manage streaming sessions.
"""
+from __future__ import annotations
+
from dataclasses import dataclass
from core.workflow.nodes.answer.answer_node import AnswerNode
@@ -27,7 +29,7 @@ class ResponseSession:
index: int = 0 # Current position in the template segments
@classmethod
- def from_node(cls, node: Node) -> "ResponseSession":
+ def from_node(cls, node: Node) -> ResponseSession:
"""
Create a ResponseSession from an AnswerNode or EndNode.
diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py
index e37a08ae47..512df6ff86 100644
--- a/api/core/workflow/graph_engine/worker.py
+++ b/api/core/workflow/graph_engine/worker.py
@@ -5,26 +5,26 @@ Workers pull node IDs from the ready_queue, execute nodes, and push events
to the event_queue for the dispatcher to process.
"""
-import contextvars
import queue
import threading
import time
from collections.abc import Sequence
from datetime import datetime
-from typing import final
-from uuid import uuid4
+from typing import TYPE_CHECKING, final
-from flask import Flask
from typing_extensions import override
+from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
-from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
+from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event
from core.workflow.nodes.base.node import Node
-from libs.flask_utils import preserve_flask_contexts
from .ready_queue import ReadyQueue
+if TYPE_CHECKING:
+ pass
+
@final
class Worker(threading.Thread):
@@ -42,9 +42,9 @@ class Worker(threading.Thread):
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: Sequence[GraphEngineLayer],
+ stop_event: threading.Event,
worker_id: int = 0,
- flask_app: Flask | None = None,
- context_vars: contextvars.Context | None = None,
+ execution_context: IExecutionContext | None = None,
) -> None:
"""
Initialize worker thread.
@@ -55,23 +55,24 @@ class Worker(threading.Thread):
graph: Graph containing nodes to execute
layers: Graph engine layers for node execution hooks
worker_id: Unique identifier for this worker
- flask_app: Optional Flask application for context preservation
- context_vars: Optional context variables to preserve in worker thread
+ execution_context: Optional execution context for context preservation
"""
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._worker_id = worker_id
- self._flask_app = flask_app
- self._context_vars = context_vars
- self._stop_event = threading.Event()
- self._last_task_time = time.time()
+ self._execution_context = execution_context
+ self._stop_event = stop_event
self._layers = layers if layers is not None else []
+ self._last_task_time = time.time()
def stop(self) -> None:
- """Signal the worker to stop processing."""
- self._stop_event.set()
+ """Worker is controlled via shared stop_event from GraphEngine.
+
+ This method is a no-op retained for backward compatibility.
+ """
+ pass
@property
def is_idle(self) -> bool:
@@ -111,7 +112,7 @@ class Worker(threading.Thread):
self._ready_queue.task_done()
except Exception as e:
error_event = NodeRunFailedEvent(
- id=str(uuid4()),
+ id=node.execution_id,
node_id=node.id,
node_type=node.node_type,
in_iteration_id=None,
@@ -130,33 +131,36 @@ class Worker(threading.Thread):
node.ensure_execution_id()
error: Exception | None = None
+ result_event: GraphNodeEventBase | None = None
- if self._flask_app and self._context_vars:
- with preserve_flask_contexts(
- flask_app=self._flask_app,
- context_vars=self._context_vars,
- ):
+ # Execute the node with preserved context if execution context is provided
+ if self._execution_context is not None:
+ with self._execution_context:
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()
for event in node_events:
self._event_queue.put(event)
+ if is_node_result_event(event):
+ result_event = event
except Exception as exc:
error = exc
raise
finally:
- self._invoke_node_run_end_hooks(node, error)
+ self._invoke_node_run_end_hooks(node, error, result_event)
else:
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()
for event in node_events:
self._event_queue.put(event)
+ if is_node_result_event(event):
+ result_event = event
except Exception as exc:
error = exc
raise
finally:
- self._invoke_node_run_end_hooks(node, error)
+ self._invoke_node_run_end_hooks(node, error, result_event)
def _invoke_node_run_start_hooks(self, node: Node) -> None:
"""Invoke on_node_run_start hooks for all layers."""
@@ -167,11 +171,13 @@ class Worker(threading.Thread):
# Silently ignore layer errors to prevent disrupting node execution
continue
- def _invoke_node_run_end_hooks(self, node: Node, error: Exception | None) -> None:
+ def _invoke_node_run_end_hooks(
+ self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
+ ) -> None:
"""Invoke on_node_run_end hooks for all layers."""
for layer in self._layers:
try:
- layer.on_node_run_end(node, error)
+ layer.on_node_run_end(node, error, result_event)
except Exception:
# Silently ignore layer errors to prevent disrupting node execution
continue
diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py
index 5b9234586b..9ce7d16e93 100644
--- a/api/core/workflow/graph_engine/worker_management/worker_pool.py
+++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py
@@ -8,9 +8,10 @@ DynamicScaler, and WorkerFactory into a single class.
import logging
import queue
import threading
-from typing import TYPE_CHECKING, final
+from typing import final
from configs import dify_config
+from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
@@ -20,11 +21,6 @@ from ..worker import Worker
logger = logging.getLogger(__name__)
-if TYPE_CHECKING:
- from contextvars import Context
-
- from flask import Flask
-
@final
class WorkerPool:
@@ -41,8 +37,8 @@ class WorkerPool:
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: list[GraphEngineLayer],
- flask_app: "Flask | None" = None,
- context_vars: "Context | None" = None,
+ stop_event: threading.Event,
+ execution_context: IExecutionContext | None = None,
min_workers: int | None = None,
max_workers: int | None = None,
scale_up_threshold: int | None = None,
@@ -56,8 +52,7 @@ class WorkerPool:
event_queue: Queue for worker events
graph: The workflow graph
layers: Graph engine layers for node execution hooks
- flask_app: Optional Flask app for context preservation
- context_vars: Optional context variables
+ execution_context: Optional execution context for context preservation
min_workers: Minimum number of workers
max_workers: Maximum number of workers
scale_up_threshold: Queue depth to trigger scale up
@@ -66,8 +61,7 @@ class WorkerPool:
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
- self._flask_app = flask_app
- self._context_vars = context_vars
+ self._execution_context = execution_context
self._layers = layers
# Scaling parameters with defaults
@@ -81,6 +75,7 @@ class WorkerPool:
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False
+ self._stop_event = stop_event
# No longer tracking worker states with callbacks to avoid lock contention
@@ -135,7 +130,7 @@ class WorkerPool:
# Wait for workers to finish
for worker in self._workers:
if worker.is_alive():
- worker.join(timeout=10.0)
+ worker.join(timeout=2.0)
self._workers.clear()
@@ -150,8 +145,8 @@ class WorkerPool:
graph=self._graph,
layers=self._layers,
worker_id=worker_id,
- flask_app=self._flask_app,
- context_vars=self._context_vars,
+ execution_context=self._execution_context,
+ stop_event=self._stop_event,
)
worker.start()
diff --git a/api/core/workflow/graph_events/__init__.py b/api/core/workflow/graph_events/__init__.py
index 7bb6346cb7..56ea642092 100644
--- a/api/core/workflow/graph_events/__init__.py
+++ b/api/core/workflow/graph_events/__init__.py
@@ -46,6 +46,7 @@ from .node import (
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
+ is_node_result_event,
)
__all__ = [
@@ -77,4 +78,5 @@ __all__ = [
"NodeRunStartedEvent",
"NodeRunStreamChunkEvent",
"NodeRunSucceededEvent",
+ "is_node_result_event",
]
diff --git a/api/core/workflow/graph_events/node.py b/api/core/workflow/graph_events/node.py
index 140b4de1da..975d72ad1f 100644
--- a/api/core/workflow/graph_events/node.py
+++ b/api/core/workflow/graph_events/node.py
@@ -72,3 +72,26 @@ class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase):
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
reason: PauseReason = Field(..., description="pause reason")
+
+
+def is_node_result_event(event: GraphNodeEventBase) -> bool:
+ """
+ Check if an event is a final result event from node execution.
+
+ A result event indicates the completion of a node execution and contains
+ runtime information such as inputs, outputs, or error details.
+
+ Args:
+ event: The event to check
+
+ Returns:
+ True if the event is a node result event (succeeded/failed/paused), False otherwise
+ """
+ return isinstance(
+ event,
+ (
+ NodeRunSucceededEvent,
+ NodeRunFailedEvent,
+ NodeRunPauseRequestedEvent,
+ ),
+ )
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index 4be006de11..5a365f769d 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
@@ -167,7 +169,7 @@ class AgentNode(Node[AgentNodeData]):
variable_pool: VariablePool,
node_data: AgentNodeData,
for_log: bool = False,
- strategy: "PluginAgentStrategy",
+ strategy: PluginAgentStrategy,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
@@ -233,7 +235,18 @@ class AgentNode(Node[AgentNodeData]):
0,
):
value_param = param.get("value", {})
- params[key] = value_param.get("value", "") if value_param is not None else None
+ if value_param and value_param.get("type", "") == "variable":
+ variable_selector = value_param.get("value")
+ if not variable_selector:
+ raise ValueError("Variable selector is missing for a variable-type parameter.")
+
+ variable = variable_pool.get(variable_selector)
+ if variable is None:
+ raise AgentVariableNotFoundError(str(variable_selector))
+
+ params[key] = variable.value
+ else:
+ params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
@@ -328,7 +341,7 @@ class AgentNode(Node[AgentNodeData]):
def _generate_credentials(
self,
parameters: dict[str, Any],
- ) -> "InvokeCredentials":
+ ) -> InvokeCredentials:
"""
Generate credentials based on the given agent parameters.
"""
@@ -442,9 +455,7 @@ class AgentNode(Node[AgentNodeData]):
model_schema.features.remove(feature)
return model_schema
- def _filter_mcp_type_tool(
- self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
- ) -> list[dict[str, Any]]:
+ def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Filter MCP type tool
:param strategy: plugin agent strategy
@@ -483,7 +494,7 @@ class AgentNode(Node[AgentNodeData]):
text = ""
files: list[File] = []
- json_list: list[dict] = []
+ json_list: list[dict | list] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
@@ -557,13 +568,18 @@ class AgentNode(Node[AgentNodeData]):
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT:
- msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
- llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
- agent_execution_metadata = {
- WorkflowNodeExecutionMetadataKey(key): value
- for key, value in msg_metadata.items()
- if key in WorkflowNodeExecutionMetadataKey.__members__.values()
- }
+ if isinstance(message.message.json_object, dict):
+ msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
+ llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
+ agent_execution_metadata = {
+ WorkflowNodeExecutionMetadataKey(key): value
+ for key, value in msg_metadata.items()
+ if key in WorkflowNodeExecutionMetadataKey.__members__.values()
+ }
+ else:
+ msg_metadata = {}
+ llm_usage = LLMUsage.empty_usage()
+ agent_execution_metadata = {}
if message.message.json_object:
json_list.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
@@ -672,7 +688,7 @@ class AgentNode(Node[AgentNodeData]):
yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
- json_output: list[dict[str, Any]] = []
+ json_output: list[dict[str, Any] | list[Any]] = []
# Step 1: append each agent log as its own dict.
if agent_logs:
diff --git a/api/core/workflow/nodes/agent/exc.py b/api/core/workflow/nodes/agent/exc.py
index 944f5f0b20..ba2c83d8a6 100644
--- a/api/core/workflow/nodes/agent/exc.py
+++ b/api/core/workflow/nodes/agent/exc.py
@@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError):
self.expected_type = expected_type
self.actual_type = actual_type
super().__init__(message)
+
+
+class AgentMaxIterationError(AgentNodeError):
+ """Exception raised when the agent exceeds the maximum iteration limit."""
+
+ def __init__(self, max_iteration: int):
+ self.max_iteration = max_iteration
+ super().__init__(
+ f"Agent exceeded the maximum iteration limit of {max_iteration}. "
+ f"The agent was unable to complete the task within the allowed number of iterations."
+ )
diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py
index 5aab6bbde4..e5a20c8e91 100644
--- a/api/core/workflow/nodes/base/entities.py
+++ b/api/core/workflow/nodes/base/entities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
from abc import ABC
from builtins import type as type_
@@ -111,7 +113,7 @@ class DefaultValue(BaseModel):
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
@model_validator(mode="after")
- def validate_value_type(self) -> "DefaultValue":
+ def validate_value_type(self) -> DefaultValue:
# Type validation configuration
type_validators = {
DefaultValueType.STRING: {
diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py
index a08c6d3ed8..2b773b537c 100644
--- a/api/core/workflow/nodes/base/node.py
+++ b/api/core/workflow/nodes/base/node.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import importlib
import logging
import operator
@@ -72,7 +74,7 @@ class Node(Generic[NodeDataT]):
in its output.
"""
- node_type: ClassVar["NodeType"]
+ node_type: ClassVar[NodeType]
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
@@ -211,14 +213,14 @@ class Node(Generic[NodeDataT]):
return None
# Global registry populated via __init_subclass__
- _registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
+ _registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}
def __init__(
self,
id: str,
config: Mapping[str, Any],
- graph_init_params: "GraphInitParams",
- graph_runtime_state: "GraphRuntimeState",
+ graph_init_params: GraphInitParams,
+ graph_runtime_state: GraphRuntimeState,
) -> None:
self._graph_init_params = graph_init_params
self.id = id
@@ -254,7 +256,7 @@ class Node(Generic[NodeDataT]):
return
@property
- def graph_init_params(self) -> "GraphInitParams":
+ def graph_init_params(self) -> GraphInitParams:
return self._graph_init_params
@property
@@ -300,6 +302,10 @@ class Node(Generic[NodeDataT]):
"""
raise NotImplementedError
+ def _should_stop(self) -> bool:
+ """Check if execution should be stopped."""
+ return self.graph_runtime_state.stop_event.is_set()
+
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
@@ -368,6 +374,21 @@ class Node(Generic[NodeDataT]):
yield event
else:
yield event
+
+ if self._should_stop():
+ error_message = "Execution cancelled"
+ yield NodeRunFailedEvent(
+ id=self.execution_id,
+ node_id=self._node_id,
+ node_type=self.node_type,
+ start_at=self._start_at,
+ node_run_result=NodeRunResult(
+ status=WorkflowNodeExecutionStatus.FAILED,
+ error=error_message,
+ ),
+ error=error_message,
+ )
+ return
except Exception as e:
logger.exception("Node %s failed to run", self._node_id)
result = NodeRunResult(
@@ -474,7 +495,7 @@ class Node(Generic[NodeDataT]):
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@classmethod
- def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
+ def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]:
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
Import all modules under core.workflow.nodes so subclasses register themselves on import.
@@ -484,12 +505,8 @@ class Node(Generic[NodeDataT]):
import core.workflow.nodes as _nodes_pkg
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
- # Avoid importing modules that depend on the registry to prevent circular imports
- # e.g. node_factory imports node_mapping which builds the mapping here.
- if _modname in {
- "core.workflow.nodes.node_factory",
- "core.workflow.nodes.node_mapping",
- }:
+ # Avoid importing modules that depend on the registry to prevent circular imports.
+ if _modname == "core.workflow.nodes.node_mapping":
continue
importlib.import_module(_modname)
diff --git a/api/core/workflow/nodes/base/template.py b/api/core/workflow/nodes/base/template.py
index ba3e2058cf..81f4b9f6fb 100644
--- a/api/core/workflow/nodes/base/template.py
+++ b/api/core/workflow/nodes/base/template.py
@@ -4,6 +4,8 @@ This module provides a unified template structure for both Answer and End nodes,
similar to SegmentGroup but focused on template representation without values.
"""
+from __future__ import annotations
+
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
@@ -58,7 +60,7 @@ class Template:
segments: list[TemplateSegmentUnion]
@classmethod
- def from_answer_template(cls, template_str: str) -> "Template":
+ def from_answer_template(cls, template_str: str) -> Template:
"""Create a Template from an Answer node template string.
Example:
@@ -107,7 +109,7 @@ class Template:
return cls(segments=segments)
@classmethod
- def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template":
+ def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template:
"""Create a Template from an End node outputs configuration.
End nodes are treated as templates of concatenated variables with newlines.
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index a38e10030a..e3035d3bf0 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -1,8 +1,7 @@
from collections.abc import Mapping, Sequence
from decimal import Decimal
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any, ClassVar, cast
-from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
@@ -13,6 +12,7 @@ from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeNodeData
+from core.workflow.nodes.code.limits import CodeNodeLimits
from .exc import (
CodeNodeError,
@@ -20,9 +20,41 @@ from .exc import (
OutputValidationError,
)
+if TYPE_CHECKING:
+ from core.workflow.entities import GraphInitParams
+ from core.workflow.runtime import GraphRuntimeState
+
class CodeNode(Node[CodeNodeData]):
node_type = NodeType.CODE
+ _DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = (
+ Python3CodeProvider,
+ JavascriptCodeProvider,
+ )
+ _limits: CodeNodeLimits
+
+ def __init__(
+ self,
+ id: str,
+ config: Mapping[str, Any],
+ graph_init_params: "GraphInitParams",
+ graph_runtime_state: "GraphRuntimeState",
+ *,
+ code_executor: type[CodeExecutor] | None = None,
+ code_providers: Sequence[type[CodeNodeProvider]] | None = None,
+ code_limits: CodeNodeLimits,
+ ) -> None:
+ super().__init__(
+ id=id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+ self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
+ self._code_providers: tuple[type[CodeNodeProvider], ...] = (
+ tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS
+ )
+ self._limits = code_limits
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -35,11 +67,16 @@ class CodeNode(Node[CodeNodeData]):
if filters:
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
- providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
- code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
+ code_provider: type[CodeNodeProvider] = next(
+ provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language)
+ )
return code_provider.get_default_config()
+ @classmethod
+ def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]:
+ return cls._DEFAULT_CODE_PROVIDERS
+
@classmethod
def version(cls) -> str:
return "1"
@@ -60,7 +97,8 @@ class CodeNode(Node[CodeNodeData]):
variables[variable_name] = variable.to_object() if variable else None
# Run code
try:
- result = CodeExecutor.execute_workflow_code_template(
+ _ = self._select_code_provider(code_language)
+ result = self._code_executor.execute_workflow_code_template(
language=code_language,
code=code,
inputs=variables,
@@ -75,6 +113,12 @@ class CodeNode(Node[CodeNodeData]):
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
+ def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]:
+ for provider in self._code_providers:
+ if provider.is_accept_language(code_language):
+ return provider
+ raise CodeNodeError(f"Unsupported code language: {code_language}")
+
def _check_string(self, value: str | None, variable: str) -> str | None:
"""
Check string
@@ -85,10 +129,10 @@ class CodeNode(Node[CodeNodeData]):
if value is None:
return None
- if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
+ if len(value) > self._limits.max_string_length:
raise OutputValidationError(
f"The length of output variable `{variable}` must be"
- f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
+ f" less than {self._limits.max_string_length} characters"
)
return value.replace("\x00", "")
@@ -109,20 +153,20 @@ class CodeNode(Node[CodeNodeData]):
if value is None:
return None
- if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
+ if value > self._limits.max_number or value < self._limits.min_number:
raise OutputValidationError(
f"Output variable `{variable}` is out of range,"
- f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
+ f" it must be between {self._limits.min_number} and {self._limits.max_number}."
)
if isinstance(value, float):
decimal_value = Decimal(str(value)).normalize()
precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
# raise error if precision is too high
- if precision > dify_config.CODE_MAX_PRECISION:
+ if precision > self._limits.max_precision:
raise OutputValidationError(
f"Output variable `{variable}` has too high precision,"
- f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
+ f" it must be less than {self._limits.max_precision} digits."
)
return value
@@ -137,8 +181,8 @@ class CodeNode(Node[CodeNodeData]):
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
# Note that `_transform_result` may produce lists containing `None` values,
# which don't conform to the type requirements of `Array*Segment` classes.
- if depth > dify_config.CODE_MAX_DEPTH:
- raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
+ if depth > self._limits.max_depth:
+ raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.")
transformed_result: dict[str, Any] = {}
if output_schema is None:
@@ -272,10 +316,10 @@ class CodeNode(Node[CodeNodeData]):
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
)
else:
- if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
+ if len(value) > self._limits.max_number_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
- f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
+ f" less than {self._limits.max_number_array_length} elements."
)
for i, inner_value in enumerate(value):
@@ -305,10 +349,10 @@ class CodeNode(Node[CodeNodeData]):
f" got {type(result.get(output_name))} instead."
)
else:
- if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
+ if len(result[output_name]) > self._limits.max_string_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
- f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
+ f" less than {self._limits.max_string_array_length} elements."
)
transformed_result[output_name] = [
@@ -326,10 +370,10 @@ class CodeNode(Node[CodeNodeData]):
f" got {type(result.get(output_name))} instead."
)
else:
- if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
+ if len(result[output_name]) > self._limits.max_object_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
- f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
+ f" less than {self._limits.max_object_array_length} elements."
)
for i, value in enumerate(result[output_name]):
diff --git a/api/core/workflow/nodes/code/limits.py b/api/core/workflow/nodes/code/limits.py
new file mode 100644
index 0000000000..a6b9e9e68e
--- /dev/null
+++ b/api/core/workflow/nodes/code/limits.py
@@ -0,0 +1,13 @@
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class CodeNodeLimits:
+ max_string_length: int
+ max_number: int | float
+ min_number: int | float
+ max_precision: int
+ max_depth: int
+ max_number_array_length: int
+ max_string_array_length: int
+ max_object_array_length: int
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
index bb2140f42e..925561cf7c 100644
--- a/api/core/workflow/nodes/datasource/datasource_node.py
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -301,7 +301,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
text = ""
files: list[File] = []
- json: list[dict] = []
+ json: list[dict | list] = []
variables: dict[str, Any] = {}
diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py
index 931c6113a7..429f8411a6 100644
--- a/api/core/workflow/nodes/http_request/executor.py
+++ b/api/core/workflow/nodes/http_request/executor.py
@@ -17,6 +17,7 @@ from core.helper import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.runtime import VariablePool
+from ..protocols import FileManagerProtocol, HttpClientProtocol
from .entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
@@ -78,6 +79,8 @@ class Executor:
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
+ http_client: HttpClientProtocol = ssrf_proxy,
+ file_manager: FileManagerProtocol = file_manager,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
@@ -104,6 +107,8 @@ class Executor:
self.data = None
self.json = None
self.max_retries = max_retries
+ self._http_client = http_client
+ self._file_manager = file_manager
# init template
self.variable_pool = variable_pool
@@ -200,7 +205,7 @@ class Executor:
if file_variable is None:
raise FileFetchError(f"cannot fetch file with selector {file_selector}")
file = file_variable.value
- self.content = file_manager.download(file)
+ self.content = self._file_manager.download(file)
case "x-www-form-urlencoded":
form_data = {
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
@@ -239,7 +244,7 @@ class Executor:
):
file_tuple = (
file.filename,
- file_manager.download(file),
+ self._file_manager.download(file),
file.mime_type or "application/octet-stream",
)
if key not in files:
@@ -332,19 +337,18 @@ class Executor:
do http request depending on api bundle
"""
_METHOD_MAP = {
- "get": ssrf_proxy.get,
- "head": ssrf_proxy.head,
- "post": ssrf_proxy.post,
- "put": ssrf_proxy.put,
- "delete": ssrf_proxy.delete,
- "patch": ssrf_proxy.patch,
+ "get": self._http_client.get,
+ "head": self._http_client.head,
+ "post": self._http_client.post,
+ "put": self._http_client.put,
+ "delete": self._http_client.delete,
+ "patch": self._http_client.patch,
}
method_lc = self.method.lower()
if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = {
- "url": self.url,
"data": self.data,
"files": self.files,
"json": self.json,
@@ -357,8 +361,12 @@ class Executor:
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
- response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries)
- except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
+ response: httpx.Response = _METHOD_MAP[method_lc](
+ url=self.url,
+ **request_args,
+ max_retries=self.max_retries,
+ )
+ except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e:
raise HttpRequestNodeError(str(e)) from e
# FIXME: fix type ignore, this maybe httpx type issue
return response
diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py
index 9bd1cb9761..964e53e03c 100644
--- a/api/core/workflow/nodes/http_request/node.py
+++ b/api/core/workflow/nodes/http_request/node.py
@@ -1,10 +1,11 @@
import logging
import mimetypes
-from collections.abc import Mapping, Sequence
-from typing import Any
+from collections.abc import Callable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any
from configs import dify_config
-from core.file import File, FileTransferMethod
+from core.file import File, FileTransferMethod, file_manager
+from core.helper import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -13,6 +14,7 @@ from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.http_request.executor import Executor
+from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from factories import file_factory
from .entities import (
@@ -30,10 +32,35 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)
+if TYPE_CHECKING:
+ from core.workflow.entities import GraphInitParams
+ from core.workflow.runtime import GraphRuntimeState
+
class HttpRequestNode(Node[HttpRequestNodeData]):
node_type = NodeType.HTTP_REQUEST
+ def __init__(
+ self,
+ id: str,
+ config: Mapping[str, Any],
+ graph_init_params: "GraphInitParams",
+ graph_runtime_state: "GraphRuntimeState",
+ *,
+ http_client: HttpClientProtocol = ssrf_proxy,
+ tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
+ file_manager: FileManagerProtocol = file_manager,
+ ) -> None:
+ super().__init__(
+ id=id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+ self._http_client = http_client
+ self._tool_file_manager_factory = tool_file_manager_factory
+ self._file_manager = file_manager
+
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
@@ -71,6 +98,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
+ http_client=self._http_client,
+ file_manager=self._file_manager,
)
process_data["request"] = http_executor.to_log()
@@ -199,7 +228,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
mime_type = (
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
)
- tool_file_manager = ToolFileManager()
+ tool_file_manager = self._tool_file_manager_factory()
tool_file = tool_file_manager.create_file_by_raw(
user_id=self.user_id,
diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py
index e5d86414c1..ced996e7e0 100644
--- a/api/core/workflow/nodes/iteration/iteration_node.py
+++ b/api/core/workflow/nodes/iteration/iteration_node.py
@@ -1,17 +1,15 @@
-import contextvars
import logging
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, NewType, cast
-from flask import Flask, current_app
from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
-from core.variables.variables import VariableUnion
+from core.variables.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import (
NodeExecutionType,
@@ -39,7 +37,6 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
from libs.datetime_utils import naive_utc_now
-from libs.flask_utils import preserve_flask_contexts
from .exc import (
InvalidIteratorValueError,
@@ -51,6 +48,7 @@ from .exc import (
)
if TYPE_CHECKING:
+ from core.workflow.context import IExecutionContext
from core.workflow.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
@@ -240,7 +238,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
datetime,
list[GraphNodeEventBase],
object | None,
- dict[str, VariableUnion],
+ dict[str, Variable],
LLMUsage,
]
],
@@ -252,8 +250,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
self._execute_single_iteration_parallel,
index=index,
item=item,
- flask_app=current_app._get_current_object(), # type: ignore
- context_vars=contextvars.copy_context(),
+ execution_context=self._capture_execution_context(),
)
future_to_index[future] = index
@@ -306,11 +303,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
self,
index: int,
item: object,
- flask_app: Flask,
- context_vars: contextvars.Context,
- ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]:
+ execution_context: "IExecutionContext",
+ ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
- with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
+ with execution_context:
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
events: list[GraphNodeEventBase] = []
outputs_temp: list[object] = []
@@ -339,6 +335,12 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
graph_engine.graph_runtime_state.llm_usage,
)
+ def _capture_execution_context(self) -> "IExecutionContext":
+ """Capture current execution context for parallel iterations."""
+ from core.workflow.context import capture_current_context
+
+ return capture_current_context()
+
def _handle_iteration_success(
self,
started_at: datetime,
@@ -515,11 +517,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return variable_mapping
- def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
+ def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]:
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
- def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
+ def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None:
parent_pool = self.graph_runtime_state.variable_pool
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
@@ -586,11 +588,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
+ from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
- from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState
# Create GraphInitParams from node attributes
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index adc474bd60..8670a71aa3 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
-from sqlalchemy import and_, func, literal, or_, select
+from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import DatasetRetrieveConfigEntity
@@ -460,7 +460,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
- self._process_metadata_filter_func(
+ DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
@@ -504,7 +504,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
value=expected_value,
)
)
- filters = self._process_metadata_filter_func(
+ filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
@@ -603,87 +603,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
return [], usage
return automatic_metadata_filters, usage
- def _process_metadata_filter_func(
- self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
- ) -> list[Any]:
- if value is None and condition not in ("empty", "not empty"):
- return filters
-
- json_field = Document.doc_metadata[metadata_name].as_string()
-
- match condition:
- case "contains":
- filters.append(json_field.like(f"%{value}%"))
-
- case "not contains":
- filters.append(json_field.notlike(f"%{value}%"))
-
- case "start with":
- filters.append(json_field.like(f"{value}%"))
-
- case "end with":
- filters.append(json_field.like(f"%{value}"))
- case "in":
- if isinstance(value, str):
- value_list = [v.strip() for v in value.split(",") if v.strip()]
- elif isinstance(value, (list, tuple)):
- value_list = [str(v) for v in value if v is not None]
- else:
- value_list = [str(value)] if value is not None else []
-
- if not value_list:
- filters.append(literal(False))
- else:
- filters.append(json_field.in_(value_list))
-
- case "not in":
- if isinstance(value, str):
- value_list = [v.strip() for v in value.split(",") if v.strip()]
- elif isinstance(value, (list, tuple)):
- value_list = [str(v) for v in value if v is not None]
- else:
- value_list = [str(value)] if value is not None else []
-
- if not value_list:
- filters.append(literal(True))
- else:
- filters.append(json_field.notin_(value_list))
-
- case "is" | "=":
- if isinstance(value, str):
- filters.append(json_field == value)
- elif isinstance(value, (int, float)):
- filters.append(Document.doc_metadata[metadata_name].as_float() == value)
-
- case "is not" | "≠":
- if isinstance(value, str):
- filters.append(json_field != value)
- elif isinstance(value, (int, float)):
- filters.append(Document.doc_metadata[metadata_name].as_float() != value)
-
- case "empty":
- filters.append(Document.doc_metadata[metadata_name].is_(None))
-
- case "not empty":
- filters.append(Document.doc_metadata[metadata_name].isnot(None))
-
- case "before" | "<":
- filters.append(Document.doc_metadata[metadata_name].as_float() < value)
-
- case "after" | ">":
- filters.append(Document.doc_metadata[metadata_name].as_float() > value)
-
- case "≤" | "<=":
- filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
-
- case "≥" | ">=":
- filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
-
- case _:
- pass
-
- return filters
-
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py
index 0c545469bc..01e25cbf5c 100644
--- a/api/core/workflow/nodes/llm/llm_utils.py
+++ b/api/core/workflow/nodes/llm/llm_utils.py
@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.entities.provider_entities import QuotaUnit
+from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.file.models import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
@@ -136,21 +136,37 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
- with Session(db.engine) as session:
- stmt = (
- update(Provider)
- .where(
- Provider.tenant_id == tenant_id,
- # TODO: Use provider name with prefix after the data migration.
- Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
- Provider.provider_type == ProviderType.SYSTEM,
- Provider.quota_type == system_configuration.current_quota_type.value,
- Provider.quota_limit > Provider.quota_used,
- )
- .values(
- quota_used=Provider.quota_used + used_quota,
- last_used=naive_utc_now(),
- )
+ if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
+ from services.credit_pool_service import CreditPoolService
+
+ CreditPoolService.check_and_deduct_credits(
+ tenant_id=tenant_id,
+ credits_required=used_quota,
)
- session.execute(stmt)
- session.commit()
+ elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
+ from services.credit_pool_service import CreditPoolService
+
+ CreditPoolService.check_and_deduct_credits(
+ tenant_id=tenant_id,
+ credits_required=used_quota,
+ pool_type="paid",
+ )
+ else:
+ with Session(db.engine) as session:
+ stmt = (
+ update(Provider)
+ .where(
+ Provider.tenant_id == tenant_id,
+ # TODO: Use provider name with prefix after the data migration.
+ Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
+ Provider.provider_type == ProviderType.SYSTEM.value,
+ Provider.quota_type == system_configuration.current_quota_type.value,
+ Provider.quota_limit > Provider.quota_used,
+ )
+ .values(
+ quota_used=Provider.quota_used + used_quota,
+ last_used=naive_utc_now(),
+ )
+ )
+ session.execute(stmt)
+ session.commit()
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 04e2802191..dfb55dcd80 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import base64
import io
import json
@@ -113,7 +115,7 @@ class LLMNode(Node[LLMNodeData]):
# Instance attributes specific to LLMNode.
# Output variable for file
- _file_outputs: list["File"]
+ _file_outputs: list[File]
_llm_file_saver: LLMFileSaver
@@ -121,8 +123,8 @@ class LLMNode(Node[LLMNodeData]):
self,
id: str,
config: Mapping[str, Any],
- graph_init_params: "GraphInitParams",
- graph_runtime_state: "GraphRuntimeState",
+ graph_init_params: GraphInitParams,
+ graph_runtime_state: GraphRuntimeState,
*,
llm_file_saver: LLMFileSaver | None = None,
):
@@ -361,7 +363,7 @@ class LLMNode(Node[LLMNodeData]):
structured_output_enabled: bool,
structured_output: Mapping[str, Any] | None = None,
file_saver: LLMFileSaver,
- file_outputs: list["File"],
+ file_outputs: list[File],
node_id: str,
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
@@ -415,7 +417,7 @@ class LLMNode(Node[LLMNodeData]):
*,
invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
file_saver: LLMFileSaver,
- file_outputs: list["File"],
+ file_outputs: list[File],
node_id: str,
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
@@ -525,7 +527,7 @@ class LLMNode(Node[LLMNodeData]):
)
@staticmethod
- def _image_file_to_markdown(file: "File", /):
+ def _image_file_to_markdown(file: File, /):
text_chunk = f"})"
return text_chunk
@@ -774,7 +776,7 @@ class LLMNode(Node[LLMNodeData]):
def fetch_prompt_messages(
*,
sys_query: str | None = None,
- sys_files: Sequence["File"],
+ sys_files: Sequence[File],
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
@@ -785,7 +787,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
- context_files: list["File"] | None = None,
+ context_files: list[File] | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
@@ -1137,7 +1139,7 @@ class LLMNode(Node[LLMNodeData]):
*,
invoke_result: LLMResult | LLMResultWithStructuredOutput,
saver: LLMFileSaver,
- file_outputs: list["File"],
+ file_outputs: list[File],
reasoning_format: Literal["separated", "tagged"] = "tagged",
request_latency: float | None = None,
) -> ModelInvokeCompletedEvent:
@@ -1179,7 +1181,7 @@ class LLMNode(Node[LLMNodeData]):
*,
content: ImagePromptMessageContent,
file_saver: LLMFileSaver,
- ) -> "File":
+ ) -> File:
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
There are two kinds of multimodal outputs:
@@ -1229,7 +1231,7 @@ class LLMNode(Node[LLMNodeData]):
*,
contents: str | list[PromptMessageContentUnionTypes] | None,
file_saver: LLMFileSaver,
- file_outputs: list["File"],
+ file_outputs: list[File],
) -> Generator[str, None, None]:
"""Convert intermediate prompt messages into strings and yield them to the caller.
diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py
index 1f9fc8a115..07d05966cc 100644
--- a/api/core/workflow/nodes/loop/loop_node.py
+++ b/api/core/workflow/nodes/loop/loop_node.py
@@ -413,11 +413,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
+ from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
- from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState
# Create GraphInitParams from node attributes
diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/workflow/nodes/node_factory.py
deleted file mode 100644
index c55ad346bf..0000000000
--- a/api/core/workflow/nodes/node_factory.py
+++ /dev/null
@@ -1,80 +0,0 @@
-from typing import TYPE_CHECKING, final
-
-from typing_extensions import override
-
-from core.workflow.enums import NodeType
-from core.workflow.graph import NodeFactory
-from core.workflow.nodes.base.node import Node
-from libs.typing import is_str, is_str_dict
-
-from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
-
-if TYPE_CHECKING:
- from core.workflow.entities import GraphInitParams
- from core.workflow.runtime import GraphRuntimeState
-
-
-@final
-class DifyNodeFactory(NodeFactory):
- """
- Default implementation of NodeFactory that uses the traditional node mapping.
-
- This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
- and instantiating the appropriate node class.
- """
-
- def __init__(
- self,
- graph_init_params: "GraphInitParams",
- graph_runtime_state: "GraphRuntimeState",
- ) -> None:
- self.graph_init_params = graph_init_params
- self.graph_runtime_state = graph_runtime_state
-
- @override
- def create_node(self, node_config: dict[str, object]) -> Node:
- """
- Create a Node instance from node configuration data using the traditional mapping.
-
- :param node_config: node configuration dictionary containing type and other data
- :return: initialized Node instance
- :raises ValueError: if node type is unknown or configuration is invalid
- """
- # Get node_id from config
- node_id = node_config.get("id")
- if not is_str(node_id):
- raise ValueError("Node config missing id")
-
- # Get node type from config
- node_data = node_config.get("data", {})
- if not is_str_dict(node_data):
- raise ValueError(f"Node {node_id} missing data information")
-
- node_type_str = node_data.get("type")
- if not is_str(node_type_str):
- raise ValueError(f"Node {node_id} missing or invalid type information")
-
- try:
- node_type = NodeType(node_type_str)
- except ValueError:
- raise ValueError(f"Unknown node type: {node_type_str}")
-
- # Get node class
- node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
- if not node_mapping:
- raise ValueError(f"No class mapping found for node type: {node_type}")
-
- latest_node_class = node_mapping.get(LATEST_VERSION)
- node_version = str(node_data.get("version", "1"))
- matched_node_class = node_mapping.get(node_version)
- node_class = matched_node_class or latest_node_class
- if not node_class:
- raise ValueError(f"No latest version class found for node type: {node_type}")
-
- # Create node instance
- return node_class(
- id=node_id,
- config=node_config,
- graph_init_params=self.graph_init_params,
- graph_runtime_state=self.graph_runtime_state,
- )
diff --git a/api/core/workflow/nodes/protocols.py b/api/core/workflow/nodes/protocols.py
new file mode 100644
index 0000000000..e7dcf62fcf
--- /dev/null
+++ b/api/core/workflow/nodes/protocols.py
@@ -0,0 +1,29 @@
+from typing import Protocol
+
+import httpx
+
+from core.file import File
+
+
+class HttpClientProtocol(Protocol):
+ @property
+ def max_retries_exceeded_error(self) -> type[Exception]: ...
+
+ @property
+ def request_error(self) -> type[Exception]: ...
+
+ def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+ def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+ def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+ def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+ def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+ def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+
+class FileManagerProtocol(Protocol):
+ def download(self, f: File, /) -> bytes: ...
diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py
index 36fc5078c5..53c1b4ee6b 100644
--- a/api/core/workflow/nodes/start/start_node.py
+++ b/api/core/workflow/nodes/start/start_node.py
@@ -1,4 +1,3 @@
-import json
from typing import Any
from jsonschema import Draft7Validator, ValidationError
@@ -43,25 +42,22 @@ class StartNode(Node[StartNodeData]):
if value is None and variable.required:
raise ValueError(f"{key} is required in input form")
+ # If no value provided, skip further processing for this key
+ if not value:
+ continue
+
+ if not isinstance(value, dict):
+ raise ValueError(f"JSON object for '{key}' must be an object")
+
+ # Overwrite with normalized dict to ensure downstream consistency
+ node_inputs[key] = value
+
+ # If schema exists, then validate against it
schema = variable.json_schema
if not schema:
continue
- if not value:
- continue
-
try:
- json_schema = json.loads(schema)
- except json.JSONDecodeError as e:
- raise ValueError(f"{schema} must be a valid JSON object")
-
- try:
- json_value = json.loads(value)
- except json.JSONDecodeError as e:
- raise ValueError(f"{value} must be a valid JSON object")
-
- try:
- Draft7Validator(json_schema).validate(json_value)
+ Draft7Validator(schema).validate(value)
except ValidationError as e:
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
- node_inputs[key] = json_value
diff --git a/api/core/workflow/nodes/template_transform/template_renderer.py b/api/core/workflow/nodes/template_transform/template_renderer.py
new file mode 100644
index 0000000000..a5f06bf2bb
--- /dev/null
+++ b/api/core/workflow/nodes/template_transform/template_renderer.py
@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+from collections.abc import Mapping
+from typing import Any, Protocol
+
+from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
+
+
+class TemplateRenderError(ValueError):
+ """Raised when rendering a Jinja2 template fails."""
+
+
+class Jinja2TemplateRenderer(Protocol):
+ """Render Jinja2 templates for template transform nodes."""
+
+ def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
+ """Render a Jinja2 template with provided variables."""
+ raise NotImplementedError
+
+
+class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
+ """Adapter that renders Jinja2 templates via CodeExecutor."""
+
+ _code_executor: type[CodeExecutor]
+
+ def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
+ self._code_executor = code_executor or CodeExecutor
+
+ def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
+ try:
+ result = self._code_executor.execute_workflow_code_template(
+ language=CodeLanguage.JINJA2, code=template, inputs=variables
+ )
+ except CodeExecutionError as exc:
+ raise TemplateRenderError(str(exc)) from exc
+
+ rendered = result.get("result")
+ if not isinstance(rendered, str):
+ raise TemplateRenderError("Template render result must be a string.")
+ return rendered
diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py
index 2274323960..f7e0bccccf 100644
--- a/api/core/workflow/nodes/template_transform/template_transform_node.py
+++ b/api/core/workflow/nodes/template_transform/template_transform_node.py
@@ -1,18 +1,44 @@
from collections.abc import Mapping, Sequence
-from typing import Any
+from typing import TYPE_CHECKING, Any
from configs import dify_config
-from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
+from core.workflow.nodes.template_transform.template_renderer import (
+ CodeExecutorJinja2TemplateRenderer,
+ Jinja2TemplateRenderer,
+ TemplateRenderError,
+)
+
+if TYPE_CHECKING:
+ from core.workflow.entities import GraphInitParams
+ from core.workflow.runtime import GraphRuntimeState
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM
+ _template_renderer: Jinja2TemplateRenderer
+
+ def __init__(
+ self,
+ id: str,
+ config: Mapping[str, Any],
+ graph_init_params: "GraphInitParams",
+ graph_runtime_state: "GraphRuntimeState",
+ *,
+ template_renderer: Jinja2TemplateRenderer | None = None,
+ ) -> None:
+ super().__init__(
+ id=id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+ self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
variables[variable_name] = value.to_object() if value else None
# Run code
try:
- result = CodeExecutor.execute_workflow_code_template(
- language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
- )
- except CodeExecutionError as e:
+ rendered = self._template_renderer.render_template(self.node_data.template, variables)
+ except TemplateRenderError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
- if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
+ if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
return NodeRunResult(
inputs=variables,
status=WorkflowNodeExecutionStatus.FAILED,
@@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
)
return NodeRunResult(
- status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
+ status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered}
)
@classmethod
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index 2e7ec757b4..68ac60e4f6 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -244,7 +244,7 @@ class ToolNode(Node[ToolNodeData]):
text = ""
files: list[File] = []
- json: list[dict] = []
+ json: list[dict | list] = []
variables: dict[str, Any] = {}
@@ -400,7 +400,7 @@ class ToolNode(Node[ToolNodeData]):
message.message.metadata = dict_metadata
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
- json_output: list[dict[str, Any]] = []
+ json_output: list[dict[str, Any] | list[Any]] = []
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
if json:
diff --git a/api/core/workflow/nodes/variable_assigner/common/impl.py b/api/core/workflow/nodes/variable_assigner/common/impl.py
deleted file mode 100644
index 050e213535..0000000000
--- a/api/core/workflow/nodes/variable_assigner/common/impl.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from sqlalchemy import select
-from sqlalchemy.orm import Session
-
-from core.variables.variables import Variable
-from extensions.ext_database import db
-from models import ConversationVariable
-
-from .exc import VariableOperatorNodeError
-
-
-class ConversationVariableUpdaterImpl:
- def update(self, conversation_id: str, variable: Variable):
- stmt = select(ConversationVariable).where(
- ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
- )
- with Session(db.engine) as session:
- row = session.scalar(stmt)
- if not row:
- raise VariableOperatorNodeError("conversation variable not found in the database")
- row.data = variable.model_dump_json()
- session.commit()
-
- def flush(self):
- pass
-
-
-def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
- return ConversationVariableUpdaterImpl()
diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py
index da23207b62..9f5818f4bb 100644
--- a/api/core/workflow/nodes/variable_assigner/v1/node.py
+++ b/api/core/workflow/nodes/variable_assigner/v1/node.py
@@ -1,9 +1,8 @@
-from collections.abc import Callable, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, TypeAlias
+from collections.abc import Mapping, Sequence
+from typing import TYPE_CHECKING, Any
-from core.variables import SegmentType, Variable
+from core.variables import SegmentType, VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
-from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
@@ -11,19 +10,14 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
-from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode
if TYPE_CHECKING:
from core.workflow.runtime import GraphRuntimeState
-_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
-
-
class VariableAssignerNode(Node[VariableAssignerData]):
node_type = NodeType.VARIABLE_ASSIGNER
- _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
def __init__(
self,
@@ -31,7 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
- conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
):
super().__init__(
id=id,
@@ -39,7 +32,15 @@ class VariableAssignerNode(Node[VariableAssignerData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
- self._conv_var_updater_factory = conv_var_updater_factory
+
+ def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
+ """
+ Check if this Variable Assigner node blocks the output of specific variables.
+
+ Returns True if this node updates any of the requested conversation variables.
+ """
+ assigned_selector = tuple(self.node_data.assigned_variable_selector)
+ return assigned_selector in variable_selectors
@classmethod
def version(cls) -> str:
@@ -72,7 +73,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
- if not isinstance(original_variable, Variable):
+ if not isinstance(original_variable, VariableBase):
raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode:
@@ -96,16 +97,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
- # TODO: Move database operation to the pipeline.
- # Update conversation variable.
- conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
- if not conversation_id:
- raise VariableOperatorNodeError("conversation_id not found")
- conv_var_updater = self._conv_var_updater_factory()
- conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
- conv_var_updater.flush()
updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)]
-
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={
diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py
index 389fb54d35..5857702e72 100644
--- a/api/core/workflow/nodes/variable_assigner/v2/node.py
+++ b/api/core/workflow/nodes/variable_assigner/v2/node.py
@@ -1,24 +1,20 @@
import json
from collections.abc import Mapping, MutableMapping, Sequence
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any
-from core.app.entities.app_invoke_entities import InvokeFrom
-from core.variables import SegmentType, Variable
+from core.variables import SegmentType, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
-from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
-from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from . import helpers
from .entities import VariableAssignerNodeData, VariableOperationItem
from .enums import InputType, Operation
from .exc import (
- ConversationIDNotFoundError,
InputTypeNotSupportedError,
InvalidDataError,
InvalidInputValueError,
@@ -26,6 +22,10 @@ from .exc import (
VariableNotFoundError,
)
+if TYPE_CHECKING:
+ from core.workflow.entities import GraphInitParams
+ from core.workflow.runtime import GraphRuntimeState
+
def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
selector_node_id = item.variable_selector[0]
@@ -53,6 +53,20 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
class VariableAssignerNode(Node[VariableAssignerNodeData]):
node_type = NodeType.VARIABLE_ASSIGNER
+ def __init__(
+ self,
+ id: str,
+ config: Mapping[str, Any],
+ graph_init_params: "GraphInitParams",
+ graph_runtime_state: "GraphRuntimeState",
+ ):
+ super().__init__(
+ id=id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
"""
Check if this Variable Assigner node blocks the output of specific variables.
@@ -70,9 +84,6 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
return False
- def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
- return conversation_variable_updater_factory()
-
@classmethod
def version(cls) -> str:
return "2"
@@ -107,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# ==================== Validation Part
# Check if variable exists
- if not isinstance(variable, Variable):
+ if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=item.variable_selector)
# Check if operation is supported
@@ -179,26 +190,12 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# remove the duplicated items first.
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
- conv_var_updater = self._conv_var_updater_factory()
- # Update variables
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
- if not isinstance(variable, Variable):
+ if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=selector)
process_data[variable.name] = variable.value
- if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
- conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
- if not conversation_id:
- if self.invoke_from != InvokeFrom.DEBUGGER:
- raise ConversationIDNotFoundError
- else:
- conversation_id = conversation_id.value
- conv_var_updater.update(
- conversation_id=cast(str, conversation_id),
- variable=variable,
- )
- conv_var_updater.flush()
updated_variables = [
common_helpers.variable_to_processed_data(selector, seg)
for selector in updated_variable_selectors
@@ -216,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
def _handle_item(
self,
*,
- variable: Variable,
+ variable: VariableBase,
operation: Operation,
value: Any,
):
diff --git a/api/core/workflow/repositories/draft_variable_repository.py b/api/core/workflow/repositories/draft_variable_repository.py
index 97bfcd5666..66ef714c16 100644
--- a/api/core/workflow/repositories/draft_variable_repository.py
+++ b/api/core/workflow/repositories/draft_variable_repository.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import abc
from collections.abc import Mapping
from typing import Any, Protocol
@@ -23,7 +25,7 @@ class DraftVariableSaverFactory(Protocol):
node_type: NodeType,
node_execution_id: str,
enclosing_node_id: str | None = None,
- ) -> "DraftVariableSaver":
+ ) -> DraftVariableSaver:
pass
diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py
index e060cb704d..f79230217c 100644
--- a/api/core/workflow/runtime/graph_runtime_state.py
+++ b/api/core/workflow/runtime/graph_runtime_state.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import importlib
import json
+import threading
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
@@ -211,6 +212,8 @@ class GraphRuntimeState:
self._pending_graph_node_states: dict[str, NodeState] | None = None
self._pending_graph_edge_states: dict[str, NodeState] | None = None
+ self.stop_event: threading.Event = threading.Event()
+
if graph is not None:
self.attach_graph(graph)
diff --git a/api/core/workflow/runtime/graph_runtime_state_protocol.py b/api/core/workflow/runtime/graph_runtime_state_protocol.py
index 5e0878e873..bfbb5ba704 100644
--- a/api/core/workflow/runtime/graph_runtime_state_protocol.py
+++ b/api/core/workflow/runtime/graph_runtime_state_protocol.py
@@ -1,4 +1,4 @@
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
from typing import Any, Protocol
from core.model_runtime.entities.llm_entities import LLMUsage
@@ -9,7 +9,7 @@ from core.workflow.system_variable import SystemVariableReadOnlyView
class ReadOnlyVariablePool(Protocol):
"""Read-only interface for VariablePool."""
- def get(self, node_id: str, variable_key: str) -> Segment | None:
+ def get(self, selector: Sequence[str], /) -> Segment | None:
"""Get a variable value (read-only)."""
...
diff --git a/api/core/workflow/runtime/read_only_wrappers.py b/api/core/workflow/runtime/read_only_wrappers.py
index 8539727fd6..d3e4c60d9b 100644
--- a/api/core/workflow/runtime/read_only_wrappers.py
+++ b/api/core/workflow/runtime/read_only_wrappers.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Any
@@ -18,9 +18,9 @@ class ReadOnlyVariablePoolWrapper:
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
- def get(self, node_id: str, variable_key: str) -> Segment | None:
+ def get(self, selector: Sequence[str], /) -> Segment | None:
"""Return a copy of a variable value if present."""
- value = self._variable_pool.get([node_id, variable_key])
+ value = self._variable_pool.get(selector)
return deepcopy(value) if value is not None else None
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
diff --git a/api/core/workflow/runtime/variable_pool.py b/api/core/workflow/runtime/variable_pool.py
index 7fbaec9e70..c4b077fa69 100644
--- a/api/core/workflow/runtime/variable_pool.py
+++ b/api/core/workflow/runtime/variable_pool.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
@@ -7,10 +9,10 @@ from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
-from core.variables import Segment, SegmentGroup, Variable
+from core.variables import Segment, SegmentGroup, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
-from core.variables.variables import RAGPipelineVariableInput, VariableUnion
+from core.variables.variables import RAGPipelineVariableInput, Variable
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
@@ -30,7 +32,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
- variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
+ variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
@@ -42,15 +44,15 @@ class VariablePool(BaseModel):
)
system_variables: SystemVariable = Field(
description="System variables",
- default_factory=SystemVariable.empty,
+ default_factory=SystemVariable.default,
)
- environment_variables: Sequence[VariableUnion] = Field(
+ environment_variables: Sequence[Variable] = Field(
description="Environment variables.",
- default_factory=list[VariableUnion],
+ default_factory=list[Variable],
)
- conversation_variables: Sequence[VariableUnion] = Field(
+ conversation_variables: Sequence[Variable] = Field(
description="Conversation variables.",
- default_factory=list[VariableUnion],
+ default_factory=list[Variable],
)
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
description="RAG pipeline variables.",
@@ -103,7 +105,7 @@ class VariablePool(BaseModel):
f"got {len(selector)} elements"
)
- if isinstance(value, Variable):
+ if isinstance(value, VariableBase):
variable = value
elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
@@ -112,9 +114,9 @@ class VariablePool(BaseModel):
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
node_id, name = self._selector_to_keys(selector)
- # Based on the definition of `VariableUnion`,
- # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
- self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
+ # Based on the definition of `Variable`,
+ # `VariableBase` instances can be safely used as `Variable` since they are compatible.
+ self.variable_dictionary[node_id][name] = cast(Variable, variable)
@classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
@@ -267,6 +269,6 @@ class VariablePool(BaseModel):
self.add(selector, value)
@classmethod
- def empty(cls) -> "VariablePool":
+ def empty(cls) -> VariablePool:
"""Create an empty variable pool."""
- return cls(system_variables=SystemVariable.empty())
+ return cls(system_variables=SystemVariable.default())
diff --git a/api/core/workflow/system_variable.py b/api/core/workflow/system_variable.py
index ad925912a4..6946e3e6ab 100644
--- a/api/core/workflow/system_variable.py
+++ b/api/core/workflow/system_variable.py
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
from collections.abc import Mapping, Sequence
from types import MappingProxyType
from typing import Any
+from uuid import uuid4
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
@@ -70,8 +73,8 @@ class SystemVariable(BaseModel):
return data
@classmethod
- def empty(cls) -> "SystemVariable":
- return cls()
+ def default(cls) -> SystemVariable:
+ return cls(workflow_execution_id=str(uuid4()))
def to_dict(self) -> dict[SystemVariableKey, Any]:
# NOTE: This method is provided for compatibility with legacy code.
@@ -114,7 +117,7 @@ class SystemVariable(BaseModel):
d[SystemVariableKey.TIMESTAMP] = self.timestamp
return d
- def as_view(self) -> "SystemVariableReadOnlyView":
+ def as_view(self) -> SystemVariableReadOnlyView:
return SystemVariableReadOnlyView(self)
diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py
index ea0bdc3537..7992785fe1 100644
--- a/api/core/workflow/variable_loader.py
+++ b/api/core/workflow/variable_loader.py
@@ -2,7 +2,7 @@ import abc
from collections.abc import Mapping, Sequence
from typing import Any, Protocol
-from core.variables import Variable
+from core.variables import VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.runtime import VariablePool
@@ -26,7 +26,7 @@ class VariableLoader(Protocol):
"""
@abc.abstractmethod
- def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
+ def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
"""Load variables based on the provided selectors. If the selectors are empty,
this method should return an empty list.
@@ -36,7 +36,7 @@ class VariableLoader(Protocol):
:param: selectors: a list of string list, each inner list should have at least two elements:
- the first element is the node ID,
- the second element is the variable name.
- :return: a list of Variable objects that match the provided selectors.
+ :return: a list of VariableBase objects that match the provided selectors.
"""
pass
@@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader):
Serves as a placeholder when no variable loading is needed.
"""
- def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
+ def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
return []
diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py
index ddf545bb34..b645f29d27 100644
--- a/api/core/workflow/workflow_entry.py
+++ b/api/core/workflow/workflow_entry.py
@@ -7,6 +7,8 @@ from typing import Any
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.workflow.layers.observability import ObservabilityLayer
+from core.app.workflow.node_factory import DifyNodeFactory
from core.file.models import File
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
@@ -14,7 +16,7 @@ from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
-from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer, ObservabilityLayer
+from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from core.workflow.nodes import NodeType
@@ -136,13 +138,11 @@ class WorkflowEntry:
:param user_inputs: user inputs
:return:
"""
- node_config = workflow.get_node_config_by_id(node_id)
+ node_config = dict(workflow.get_node_config_by_id(node_id))
node_config_data = node_config.get("data", {})
- # Get node class
+ # Get node type
node_type = NodeType(node_config_data.get("type"))
- node_version = node_config_data.get("version", "1")
- node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init graph init params and runtime state
graph_init_params = GraphInitParams(
@@ -158,12 +158,12 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
- node = node_cls(
- id=str(uuid.uuid4()),
- config=node_config,
+ node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
+ node = node_factory.create_node(node_config)
+ node_cls = type(node)
try:
# variable selector to variable mapping
@@ -190,8 +190,7 @@ class WorkflowEntry:
)
try:
- # run node
- generator = node.run()
+ generator = cls._traced_node_run(node)
except Exception as e:
logger.exception(
"error while running node, workflow_id=%s, node_id=%s, node_type=%s, node_version=%s",
@@ -278,7 +277,7 @@ class WorkflowEntry:
# init variable pool
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
environment_variables=[],
)
@@ -324,8 +323,7 @@ class WorkflowEntry:
tenant_id=tenant_id,
)
- # run node
- generator = node.run()
+ generator = cls._traced_node_run(node)
return node, generator
except Exception as e:
@@ -431,3 +429,26 @@ class WorkflowEntry:
input_value = current_variable.value | input_value
variable_pool.add([variable_node_id] + variable_key_list, input_value)
+
+ @staticmethod
+ def _traced_node_run(node: Node) -> Generator[GraphNodeEventBase, None, None]:
+ """
+ Wraps a node's run method with OpenTelemetry tracing and returns a generator.
+ """
+ # Wrap node.run() with ObservabilityLayer hooks to produce node-level spans
+ layer = ObservabilityLayer()
+ layer.on_graph_start()
+ node.ensure_execution_id()
+
+ def _gen():
+ error: Exception | None = None
+ layer.on_node_run_start(node)
+ try:
+ yield from node.run()
+ except Exception as exc:
+ error = exc
+ raise
+ finally:
+ layer.on_node_run_end(node, error)
+
+ return _gen()
diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh
index 470d75d352..7c2b1ba181 100755
--- a/api/docker/entrypoint.sh
+++ b/api/docker/entrypoint.sh
@@ -3,8 +3,9 @@
set -e
# Set UTF-8 encoding to address potential encoding issues in containerized environments
-export LANG=${LANG:-en_US.UTF-8}
-export LC_ALL=${LC_ALL:-en_US.UTF-8}
+# Use C.UTF-8 which is universally available in all containers
+export LANG=${LANG:-C.UTF-8}
+export LC_ALL=${LC_ALL:-C.UTF-8}
export PYTHONIOENCODING=${PYTHONIOENCODING:-utf-8}
if [[ "${MIGRATION_ENABLED}" == "true" ]]; then
diff --git a/api/enums/hosted_provider.py b/api/enums/hosted_provider.py
new file mode 100644
index 0000000000..c6d3715dc1
--- /dev/null
+++ b/api/enums/hosted_provider.py
@@ -0,0 +1,21 @@
+from enum import StrEnum
+
+
+class HostedTrialProvider(StrEnum):
+ """
+ Enum representing hosted model provider names for trial access.
+ """
+
+ OPENAI = "langgenius/openai/openai"
+ ANTHROPIC = "langgenius/anthropic/anthropic"
+ GEMINI = "langgenius/gemini/google"
+ X = "langgenius/x/x"
+ DEEPSEEK = "langgenius/deepseek/deepseek"
+ TONGYI = "langgenius/tongyi/tongyi"
+
+ @property
+ def config_key(self) -> str:
+ """Return the config key used in dify_config (e.g., HOSTED_{config_key}_PAID_ENABLED)."""
+ if self == HostedTrialProvider.X:
+ return "XAI"
+ return self.name
diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py
index c79764983b..d37217e168 100644
--- a/api/events/event_handlers/__init__.py
+++ b/api/events/event_handlers/__init__.py
@@ -6,6 +6,7 @@ from .create_site_record_when_app_created import handle as handle_create_site_re
from .delete_tool_parameters_cache_when_sync_draft_workflow import (
handle as handle_delete_tool_parameters_cache_when_sync_draft_workflow,
)
+from .queue_credential_sync_when_tenant_created import handle as handle_queue_credential_sync_when_tenant_created
from .sync_plugin_trigger_when_app_created import handle as handle_sync_plugin_trigger_when_app_created
from .sync_webhook_when_app_created import handle as handle_sync_webhook_when_app_created
from .sync_workflow_schedule_when_app_published import handle as handle_sync_workflow_schedule_when_app_published
@@ -30,6 +31,7 @@ __all__ = [
"handle_create_installed_app_when_app_created",
"handle_create_site_record_when_app_created",
"handle_delete_tool_parameters_cache_when_sync_draft_workflow",
+ "handle_queue_credential_sync_when_tenant_created",
"handle_sync_plugin_trigger_when_app_created",
"handle_sync_webhook_when_app_created",
"handle_sync_workflow_schedule_when_app_published",
diff --git a/api/events/event_handlers/queue_credential_sync_when_tenant_created.py b/api/events/event_handlers/queue_credential_sync_when_tenant_created.py
new file mode 100644
index 0000000000..6566c214b0
--- /dev/null
+++ b/api/events/event_handlers/queue_credential_sync_when_tenant_created.py
@@ -0,0 +1,19 @@
+from configs import dify_config
+from events.tenant_event import tenant_was_created
+from services.enterprise.workspace_sync import WorkspaceSyncService
+
+
+@tenant_was_created.connect
+def handle(sender, **kwargs):
+ """Queue credential sync when a tenant/workspace is created."""
+ # Only queue sync tasks if plugin manager (enterprise feature) is enabled
+ if not dify_config.ENTERPRISE_ENABLED:
+ return
+
+ tenant = sender
+
+ # Determine source from kwargs if available, otherwise use generic
+ source = kwargs.get("source", "tenant_created")
+
+ # Queue credential sync task to Redis for enterprise backend to process
+ WorkspaceSyncService.queue_credential_sync(tenant.id, source=source)
diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py
index 84266ab0fa..1ddcc8f792 100644
--- a/api/events/event_handlers/update_provider_when_message_created.py
+++ b/api/events/event_handlers/update_provider_when_message_created.py
@@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
-from core.entities.provider_entities import QuotaUnit, SystemConfiguration
+from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, SystemConfiguration
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
@@ -134,22 +134,38 @@ def handle(sender: Message, **kwargs):
system_configuration=system_configuration,
model_name=model_config.model,
)
-
if used_quota is not None:
- quota_update = _ProviderUpdateOperation(
- filters=_ProviderUpdateFilters(
+ if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
+ from services.credit_pool_service import CreditPoolService
+
+ CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
- provider_name=ModelProviderID(model_config.provider).provider_name,
- provider_type=ProviderType.SYSTEM,
- quota_type=provider_configuration.system_configuration.current_quota_type.value,
- ),
- values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
- additional_filters=_ProviderUpdateAdditionalFilters(
- quota_limit_check=True # Provider.quota_limit > Provider.quota_used
- ),
- description="quota_deduction_update",
- )
- updates_to_perform.append(quota_update)
+ credits_required=used_quota,
+ pool_type="trial",
+ )
+ elif provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.PAID:
+ from services.credit_pool_service import CreditPoolService
+
+ CreditPoolService.check_and_deduct_credits(
+ tenant_id=tenant_id,
+ credits_required=used_quota,
+ pool_type="paid",
+ )
+ else:
+ quota_update = _ProviderUpdateOperation(
+ filters=_ProviderUpdateFilters(
+ tenant_id=tenant_id,
+ provider_name=ModelProviderID(model_config.provider).provider_name,
+ provider_type=ProviderType.SYSTEM.value,
+ quota_type=provider_configuration.system_configuration.current_quota_type.value,
+ ),
+ values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
+ additional_filters=_ProviderUpdateAdditionalFilters(
+ quota_limit_check=True # Provider.quota_limit > Provider.quota_used
+ ),
+ description="quota_deduction_update",
+ )
+ updates_to_perform.append(quota_update)
# Execute all updates
start_time = time_module.perf_counter()
diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py
index cf994c11df..7d13f0c061 100644
--- a/api/extensions/ext_blueprints.py
+++ b/api/extensions/ext_blueprints.py
@@ -6,6 +6,7 @@ BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEAD
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
+EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE)
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
@@ -42,10 +43,28 @@ def init_app(app: DifyApp):
_apply_cors_once(
web_bp,
- resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
- supports_credentials=True,
- allow_headers=list(AUTHENTICATED_HEADERS),
- methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+ resources={
+ # Embedded bot endpoints (unauthenticated, cross-origin safe)
+ r"^/chat-messages$": {
+ "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
+ "supports_credentials": False,
+ "allow_headers": list(EMBED_HEADERS),
+ "methods": ["GET", "POST", "OPTIONS"],
+ },
+ r"^/chat-messages/.*": {
+ "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
+ "supports_credentials": False,
+ "allow_headers": list(EMBED_HEADERS),
+ "methods": ["GET", "POST", "OPTIONS"],
+ },
+ # Default web application endpoints (authenticated)
+ r"/*": {
+ "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
+ "supports_credentials": True,
+ "allow_headers": list(AUTHENTICATED_HEADERS),
+ "methods": ["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+ },
+ },
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(web_bp)
diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py
index 9d042a291a..decf8655da 100644
--- a/api/extensions/ext_celery.py
+++ b/api/extensions/ext_celery.py
@@ -12,9 +12,8 @@ from dify_app import DifyApp
def _get_celery_ssl_options() -> dict[str, Any] | None:
"""Get SSL configuration for Celery broker/backend connections."""
- # Use REDIS_USE_SSL for consistency with the main Redis client
# Only apply SSL if we're using Redis as broker/backend
- if not dify_config.REDIS_USE_SSL:
+ if not dify_config.BROKER_USE_SSL:
return None
# Check if Celery is actually using Redis
@@ -47,7 +46,11 @@ def _get_celery_ssl_options() -> dict[str, Any] | None:
def init_app(app: DifyApp) -> Celery:
class FlaskTask(Task):
def __call__(self, *args: object, **kwargs: object) -> object:
+ from core.logging.context import init_request_context
+
with app.app_context():
+ # Initialize logging context for this task (similar to before_request in Flask)
+ init_request_context()
return self.run(*args, **kwargs)
broker_transport_options = {}
@@ -166,6 +169,13 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise",
"schedule": crontab(minute="0", hour="2"),
}
+ if dify_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK:
+ # for saas only
+ imports.append("schedule.clean_workflow_runs_task")
+ beat_schedule["clean_workflow_runs_task"] = {
+ "task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
+ "schedule": crontab(minute="0", hour="0"),
+ }
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
imports.append("schedule.workflow_schedule_task")
beat_schedule["workflow_schedule_task"] = {
diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py
index 71a63168a5..46885761a1 100644
--- a/api/extensions/ext_commands.py
+++ b/api/extensions/ext_commands.py
@@ -4,13 +4,18 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from commands import (
add_qdrant_index,
+ archive_workflow_runs,
+ clean_expired_messages,
+ clean_workflow_runs,
cleanup_orphaned_draft_variables,
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
convert_to_agent_apps,
create_tenant,
+ delete_archived_workflow_runs,
extract_plugins,
extract_unique_plugins,
+ file_usage,
fix_app_site_missing,
install_plugins,
install_rag_pipeline_plugins,
@@ -21,6 +26,7 @@ def init_app(app: DifyApp):
reset_email,
reset_encrypt_key_pair,
reset_password,
+ restore_workflow_runs,
setup_datasource_oauth_client,
setup_system_tool_oauth_client,
setup_system_trigger_oauth_client,
@@ -47,6 +53,7 @@ def init_app(app: DifyApp):
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
remove_orphaned_files_on_storage,
+ file_usage,
setup_system_tool_oauth_client,
setup_system_trigger_oauth_client,
cleanup_orphaned_draft_variables,
@@ -54,6 +61,11 @@ def init_app(app: DifyApp):
setup_datasource_oauth_client,
transform_datasource_credentials,
install_rag_pipeline_plugins,
+ archive_workflow_runs,
+ delete_archived_workflow_runs,
+ restore_workflow_runs,
+ clean_workflow_runs,
+ clean_expired_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)
diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py
index c90b1d0a9f..2e0d4c889a 100644
--- a/api/extensions/ext_database.py
+++ b/api/extensions/ext_database.py
@@ -53,3 +53,10 @@ def _setup_gevent_compatibility():
def init_app(app: DifyApp):
db.init_app(app)
_setup_gevent_compatibility()
+
+ # Eagerly build the engine so pool_size/max_overflow/etc. come from config
+ try:
+ with app.app_context():
+ _ = db.engine # triggers engine creation with the configured options
+ except Exception:
+ logger.exception("Failed to initialize SQLAlchemy engine during app startup")
diff --git a/api/extensions/ext_fastopenapi.py b/api/extensions/ext_fastopenapi.py
new file mode 100644
index 0000000000..e6c1bc6bee
--- /dev/null
+++ b/api/extensions/ext_fastopenapi.py
@@ -0,0 +1,45 @@
+from fastopenapi.routers import FlaskRouter
+from flask_cors import CORS
+
+from configs import dify_config
+from controllers.fastopenapi import console_router
+from dify_app import DifyApp
+from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS
+
+DOCS_PREFIX = "/fastopenapi"
+
+
+def init_app(app: DifyApp) -> None:
+ docs_enabled = dify_config.SWAGGER_UI_ENABLED
+ docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None
+ redoc_url = f"{DOCS_PREFIX}/redoc" if docs_enabled else None
+ openapi_url = f"{DOCS_PREFIX}/openapi.json" if docs_enabled else None
+
+ router = FlaskRouter(
+ app=app,
+ docs_url=docs_url,
+ redoc_url=redoc_url,
+ openapi_url=openapi_url,
+ openapi_version="3.0.0",
+ title="Dify API (FastOpenAPI PoC)",
+ version="1.0",
+ description="FastOpenAPI proof of concept for Dify API",
+ )
+
+ # Ensure route decorators are evaluated.
+ import controllers.console.ping as ping_module
+ from controllers.console import setup
+
+ _ = ping_module
+ _ = setup
+
+ router.include_router(console_router, prefix="/console/api")
+ CORS(
+ app,
+ resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
+ supports_credentials=True,
+ allow_headers=list(AUTHENTICATED_HEADERS),
+ methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+ expose_headers=list(EXPOSED_HEADERS),
+ )
+ app.extensions["fastopenapi"] = router
diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py
index 000d03ac41..978a40c503 100644
--- a/api/extensions/ext_logging.py
+++ b/api/extensions/ext_logging.py
@@ -1,18 +1,19 @@
+"""Logging extension for Dify Flask application."""
+
import logging
import os
import sys
-import uuid
from logging.handlers import RotatingFileHandler
-import flask
-
from configs import dify_config
-from core.helper.trace_id_helper import get_trace_id_from_otel_context
from dify_app import DifyApp
def init_app(app: DifyApp):
+ """Initialize logging with support for text or JSON format."""
log_handlers: list[logging.Handler] = []
+
+ # File handler
log_file = dify_config.LOG_FILE
if log_file:
log_dir = os.path.dirname(log_file)
@@ -25,27 +26,53 @@ def init_app(app: DifyApp):
)
)
- # Always add StreamHandler to log to console
+ # Console handler
sh = logging.StreamHandler(sys.stdout)
log_handlers.append(sh)
- # Apply RequestIdFilter to all handlers
- for handler in log_handlers:
- handler.addFilter(RequestIdFilter())
+ # Apply filters to all handlers
+ from core.logging.filters import IdentityContextFilter, TraceContextFilter
+ for handler in log_handlers:
+ handler.addFilter(TraceContextFilter())
+ handler.addFilter(IdentityContextFilter())
+
+ # Configure formatter based on format type
+ formatter = _create_formatter()
+ for handler in log_handlers:
+ handler.setFormatter(formatter)
+
+ # Configure root logger
logging.basicConfig(
level=dify_config.LOG_LEVEL,
- format=dify_config.LOG_FORMAT,
- datefmt=dify_config.LOG_DATEFORMAT,
handlers=log_handlers,
force=True,
)
- # Apply RequestIdFormatter to all handlers
- apply_request_id_formatter()
-
# Disable propagation for noisy loggers to avoid duplicate logs
logging.getLogger("sqlalchemy.engine").propagate = False
+
+ # Apply timezone if specified (only for text format)
+ if dify_config.LOG_OUTPUT_FORMAT == "text":
+ _apply_timezone(log_handlers)
+
+
+def _create_formatter() -> logging.Formatter:
+ """Create appropriate formatter based on configuration."""
+ if dify_config.LOG_OUTPUT_FORMAT == "json":
+ from core.logging.structured_formatter import StructuredJSONFormatter
+
+ return StructuredJSONFormatter()
+ else:
+ # Text format - use existing pattern with backward compatible formatter
+ return _TextFormatter(
+ fmt=dify_config.LOG_FORMAT,
+ datefmt=dify_config.LOG_DATEFORMAT,
+ )
+
+
+def _apply_timezone(handlers: list[logging.Handler]):
+ """Apply timezone conversion to text formatters."""
log_tz = dify_config.LOG_TZ
if log_tz:
from datetime import datetime
@@ -57,34 +84,51 @@ def init_app(app: DifyApp):
def time_converter(seconds):
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
- for handler in logging.root.handlers:
+ for handler in handlers:
if handler.formatter:
- handler.formatter.converter = time_converter
+ handler.formatter.converter = time_converter # type: ignore[attr-defined]
-def get_request_id():
- if getattr(flask.g, "request_id", None):
- return flask.g.request_id
+class _TextFormatter(logging.Formatter):
+ """Text formatter that ensures trace_id and req_id are always present."""
- new_uuid = uuid.uuid4().hex[:10]
- flask.g.request_id = new_uuid
-
- return new_uuid
+ def format(self, record: logging.LogRecord) -> str:
+ if not hasattr(record, "req_id"):
+ record.req_id = ""
+ if not hasattr(record, "trace_id"):
+ record.trace_id = ""
+ if not hasattr(record, "span_id"):
+ record.span_id = ""
+ return super().format(record)
+def get_request_id() -> str:
+ """Get request ID for current request context.
+
+ Deprecated: Use core.logging.context.get_request_id() directly.
+ """
+ from core.logging.context import get_request_id as _get_request_id
+
+ return _get_request_id()
+
+
+# Backward compatibility aliases
class RequestIdFilter(logging.Filter):
- # This is a logging filter that makes the request ID available for use in
- # the logging format. Note that we're checking if we're in a request
- # context, as we may want to log things before Flask is fully loaded.
- def filter(self, record):
- trace_id = get_trace_id_from_otel_context() or ""
- record.req_id = get_request_id() if flask.has_request_context() else ""
- record.trace_id = trace_id
+ """Deprecated: Use TraceContextFilter from core.logging.filters instead."""
+
+ def filter(self, record: logging.LogRecord) -> bool:
+ from core.logging.context import get_request_id as _get_request_id
+ from core.logging.context import get_trace_id as _get_trace_id
+
+ record.req_id = _get_request_id()
+ record.trace_id = _get_trace_id()
return True
class RequestIdFormatter(logging.Formatter):
- def format(self, record):
+ """Deprecated: Use _TextFormatter instead."""
+
+ def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
@@ -93,6 +137,7 @@ class RequestIdFormatter(logging.Formatter):
def apply_request_id_formatter():
+ """Deprecated: Formatter is now applied in init_app."""
for handler in logging.root.handlers:
if handler.formatter:
handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)
diff --git a/api/extensions/ext_logstore.py b/api/extensions/ext_logstore.py
index 502f0bb46b..cda2d1ad1e 100644
--- a/api/extensions/ext_logstore.py
+++ b/api/extensions/ext_logstore.py
@@ -10,6 +10,7 @@ import os
from dotenv import load_dotenv
+from configs import dify_config
from dify_app import DifyApp
logger = logging.getLogger(__name__)
@@ -19,12 +20,17 @@ def is_enabled() -> bool:
"""
Check if logstore extension is enabled.
+ Logstore is considered enabled when:
+ 1. All required Aliyun SLS environment variables are set
+ 2. At least one repository configuration points to a logstore implementation
+
Returns:
- True if all required Aliyun SLS environment variables are set, False otherwise
+ True if logstore should be initialized, False otherwise
"""
# Load environment variables from .env file
load_dotenv()
+ # Check if Aliyun SLS connection parameters are configured
required_vars = [
"ALIYUN_SLS_ACCESS_KEY_ID",
"ALIYUN_SLS_ACCESS_KEY_SECRET",
@@ -33,24 +39,32 @@ def is_enabled() -> bool:
"ALIYUN_SLS_PROJECT_NAME",
]
- all_set = all(os.environ.get(var) for var in required_vars)
+ sls_vars_set = all(os.environ.get(var) for var in required_vars)
- if not all_set:
- logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set")
+ if not sls_vars_set:
+ return False
- return all_set
+ # Check if any repository configuration points to logstore implementation
+ repository_configs = [
+ dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY,
+ dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY,
+ dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY,
+ dify_config.API_WORKFLOW_RUN_REPOSITORY,
+ ]
+
+ uses_logstore = any("logstore" in config.lower() for config in repository_configs)
+
+ if not uses_logstore:
+ return False
+
+ logger.info("Logstore extension enabled: SLS variables set and repository configured to use logstore")
+ return True
def init_app(app: DifyApp):
"""
Initialize logstore on application startup.
-
- This function:
- 1. Creates Aliyun SLS project if it doesn't exist
- 2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist
- 3. Creates indexes with field configurations based on PostgreSQL table structures
-
- This operation is idempotent and only executes once during application startup.
+ If initialization fails, the application continues running without logstore features.
Args:
app: The Dify application instance
@@ -58,17 +72,23 @@ def init_app(app: DifyApp):
try:
from extensions.logstore.aliyun_logstore import AliyunLogStore
- logger.info("Initializing logstore...")
+ logger.info("Initializing Aliyun SLS Logstore...")
- # Create logstore client and initialize project/logstores/indexes
+ # Create logstore client and initialize resources
logstore_client = AliyunLogStore()
logstore_client.init_project_logstore()
- # Attach to app for potential later use
app.extensions["logstore"] = logstore_client
logger.info("Logstore initialized successfully")
+
except Exception:
- logger.exception("Failed to initialize logstore")
- # Don't raise - allow application to continue even if logstore init fails
- # This ensures that the application can still run if logstore is misconfigured
+ logger.exception(
+ "Logstore initialization failed. Configuration: endpoint=%s, region=%s, project=%s, timeout=%ss. "
+ "Application will continue but logstore features will NOT work.",
+ os.environ.get("ALIYUN_SLS_ENDPOINT"),
+ os.environ.get("ALIYUN_SLS_REGION"),
+ os.environ.get("ALIYUN_SLS_PROJECT_NAME"),
+ os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", "30"),
+ )
+ # Don't raise - allow application to continue even if logstore setup fails
diff --git a/api/extensions/logstore/aliyun_logstore.py b/api/extensions/logstore/aliyun_logstore.py
index 22d1f473a3..f6a4765f14 100644
--- a/api/extensions/logstore/aliyun_logstore.py
+++ b/api/extensions/logstore/aliyun_logstore.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
import logging
import os
+import socket
import threading
import time
from collections.abc import Sequence
@@ -33,7 +36,7 @@ class AliyunLogStore:
Ensures only one instance exists to prevent multiple PG connection pools.
"""
- _instance: "AliyunLogStore | None" = None
+ _instance: AliyunLogStore | None = None
_initialized: bool = False
# Track delayed PG connection for newly created projects
@@ -66,7 +69,7 @@ class AliyunLogStore:
"\t",
]
- def __new__(cls) -> "AliyunLogStore":
+ def __new__(cls) -> AliyunLogStore:
"""Implement singleton pattern."""
if cls._instance is None:
cls._instance = super().__new__(cls)
@@ -177,9 +180,18 @@ class AliyunLogStore:
self.region: str = os.environ.get("ALIYUN_SLS_REGION", "")
self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "")
self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365))
- self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
+ self.log_enabled: bool = (
+ os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
+ or os.environ.get("LOGSTORE_SQL_ECHO", "false").lower() == "true"
+ )
self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true"
+ # Get timeout configuration
+ check_timeout = int(os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", 30))
+
+ # Pre-check endpoint connectivity to prevent indefinite hangs
+ self._check_endpoint_connectivity(self.endpoint, check_timeout)
+
# Initialize SDK client
self.client = LogClient(
self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region
@@ -197,6 +209,49 @@ class AliyunLogStore:
self.__class__._initialized = True
+ @staticmethod
+ def _check_endpoint_connectivity(endpoint: str, timeout: int) -> None:
+ """
+ Check if the SLS endpoint is reachable before creating LogClient.
+ Prevents indefinite hangs when the endpoint is unreachable.
+
+ Args:
+ endpoint: SLS endpoint URL
+ timeout: Connection timeout in seconds
+
+ Raises:
+ ConnectionError: If endpoint is not reachable
+ """
+ # Parse endpoint URL to extract hostname and port
+ from urllib.parse import urlparse
+
+ parsed_url = urlparse(endpoint if "://" in endpoint else f"http://{endpoint}")
+ hostname = parsed_url.hostname
+ port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80)
+
+ if not hostname:
+ raise ConnectionError(f"Invalid endpoint URL: {endpoint}")
+
+ sock = None
+ try:
+ # Create socket and set timeout
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(timeout)
+ sock.connect((hostname, port))
+ except Exception as e:
+ # Catch all exceptions and provide clear error message
+ error_type = type(e).__name__
+ raise ConnectionError(
+ f"Cannot connect to {hostname}:{port} (timeout={timeout}s): [{error_type}] {e}"
+ ) from e
+ finally:
+ # Ensure socket is properly closed
+ if sock:
+ try:
+ sock.close()
+ except Exception: # noqa: S110
+ pass # Ignore errors during cleanup
+
@property
def supports_pg_protocol(self) -> bool:
"""Check if PG protocol is supported and enabled."""
@@ -218,19 +273,16 @@ class AliyunLogStore:
try:
self._use_pg_protocol = self._pg_client.init_connection()
if self._use_pg_protocol:
- logger.info("Successfully connected to project %s using PG protocol", self.project_name)
+ logger.info("Using PG protocol for project %s", self.project_name)
# Check if scan_index is enabled for all logstores
self._check_and_disable_pg_if_scan_index_disabled()
return True
else:
- logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name)
+ logger.info("Using SDK mode for project %s", self.project_name)
return False
except Exception as e:
- logger.warning(
- "Failed to establish PG connection for project %s: %s. Will use SDK mode.",
- self.project_name,
- str(e),
- )
+ logger.info("Using SDK mode for project %s", self.project_name)
+ logger.debug("PG connection details: %s", str(e))
self._use_pg_protocol = False
return False
@@ -244,10 +296,6 @@ class AliyunLogStore:
if self._use_pg_protocol:
return
- logger.info(
- "Attempting delayed PG connection for newly created project %s ...",
- self.project_name,
- )
self._attempt_pg_connection_init()
self.__class__._pg_connection_timer = None
@@ -282,11 +330,7 @@ class AliyunLogStore:
if project_is_new:
# For newly created projects, schedule delayed PG connection
self._use_pg_protocol = False
- logger.info(
- "Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.",
- self.project_name,
- self.__class__._pg_connection_delay,
- )
+ logger.info("Using SDK mode for project %s (newly created)", self.project_name)
if self.__class__._pg_connection_timer is not None:
self.__class__._pg_connection_timer.cancel()
self.__class__._pg_connection_timer = threading.Timer(
@@ -297,7 +341,6 @@ class AliyunLogStore:
self.__class__._pg_connection_timer.start()
else:
# For existing projects, attempt PG connection immediately
- logger.info("Project %s already exists. Attempting PG connection...", self.project_name)
self._attempt_pg_connection_init()
def _check_and_disable_pg_if_scan_index_disabled(self) -> None:
@@ -316,9 +359,9 @@ class AliyunLogStore:
existing_config = self.get_existing_index_config(logstore_name)
if existing_config and not existing_config.scan_index:
logger.info(
- "Logstore %s has scan_index=false, USE SDK mode for read/write operations. "
- "PG protocol requires scan_index to be enabled.",
+ "Logstore %s requires scan_index enabled, using SDK mode for project %s",
logstore_name,
+ self.project_name,
)
self._use_pg_protocol = False
# Close PG connection if it was initialized
@@ -746,7 +789,6 @@ class AliyunLogStore:
reverse=reverse,
)
- # Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | "
@@ -768,7 +810,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())
- # Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d",
@@ -843,7 +884,6 @@ class AliyunLogStore:
query=full_query,
)
- # Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s",
@@ -851,8 +891,7 @@ class AliyunLogStore:
self.project_name,
from_time,
to_time,
- query,
- sql,
+ full_query,
)
try:
@@ -863,7 +902,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())
- # Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",
diff --git a/api/extensions/logstore/aliyun_logstore_pg.py b/api/extensions/logstore/aliyun_logstore_pg.py
index 35aa51ce53..874c20d144 100644
--- a/api/extensions/logstore/aliyun_logstore_pg.py
+++ b/api/extensions/logstore/aliyun_logstore_pg.py
@@ -7,8 +7,7 @@ from contextlib import contextmanager
from typing import Any
import psycopg2
-import psycopg2.pool
-from psycopg2 import InterfaceError, OperationalError
+from sqlalchemy import create_engine
from configs import dify_config
@@ -16,11 +15,7 @@ logger = logging.getLogger(__name__)
class AliyunLogStorePG:
- """
- PostgreSQL protocol support for Aliyun SLS LogStore.
-
- Handles PG connection pooling and operations for regions that support PG protocol.
- """
+ """PostgreSQL protocol support for Aliyun SLS LogStore using SQLAlchemy connection pool."""
def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str):
"""
@@ -36,24 +31,11 @@ class AliyunLogStorePG:
self._access_key_secret = access_key_secret
self._endpoint = endpoint
self.project_name = project_name
- self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None
+ self._engine: Any = None # SQLAlchemy Engine
self._use_pg_protocol = False
def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool:
- """
- Check if a TCP port is reachable using socket connection.
-
- This provides a fast check before attempting full database connection,
- preventing long waits when connecting to unsupported regions.
-
- Args:
- host: Hostname or IP address
- port: Port number
- timeout: Connection timeout in seconds (default: 2.0)
-
- Returns:
- True if port is reachable, False otherwise
- """
+ """Fast TCP port check to avoid long waits on unsupported regions."""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
@@ -65,166 +47,101 @@ class AliyunLogStorePG:
return False
def init_connection(self) -> bool:
- """
- Initialize PostgreSQL connection pool for SLS PG protocol support.
-
- Attempts to connect to SLS using PostgreSQL protocol. If successful, sets
- _use_pg_protocol to True and creates a connection pool. If connection fails
- (region doesn't support PG protocol or other errors), returns False.
-
- Returns:
- True if PG protocol is supported and initialized, False otherwise
- """
+ """Initialize SQLAlchemy connection pool with pool_recycle and TCP keepalive support."""
try:
- # Extract hostname from endpoint (remove protocol if present)
pg_host = self._endpoint.replace("http://", "").replace("https://", "")
- # Get pool configuration
- pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10))
+ # Pool configuration
+ pool_size = int(os.environ.get("ALIYUN_SLS_PG_POOL_SIZE", 5))
+ max_overflow = int(os.environ.get("ALIYUN_SLS_PG_MAX_OVERFLOW", 5))
+ pool_recycle = int(os.environ.get("ALIYUN_SLS_PG_POOL_RECYCLE", 3600))
+ pool_pre_ping = os.environ.get("ALIYUN_SLS_PG_POOL_PRE_PING", "false").lower() == "true"
- logger.debug(
- "Check PG protocol connection to SLS: host=%s, project=%s",
- pg_host,
- self.project_name,
- )
+ logger.debug("Check PG protocol connection to SLS: host=%s, project=%s", pg_host, self.project_name)
- # Fast port connectivity check before attempting full connection
- # This prevents long waits when connecting to unsupported regions
+ # Fast port check to avoid long waits
if not self._check_port_connectivity(pg_host, 5432, timeout=1.0):
- logger.info(
- "USE SDK mode for read/write operations, host=%s",
- pg_host,
- )
+ logger.debug("Using SDK mode for host=%s", pg_host)
return False
- # Create connection pool
- self._pg_pool = psycopg2.pool.SimpleConnectionPool(
- minconn=1,
- maxconn=pg_max_connections,
- host=pg_host,
- port=5432,
- database=self.project_name,
- user=self._access_key_id,
- password=self._access_key_secret,
- sslmode="require",
- connect_timeout=5,
- application_name=f"Dify-{dify_config.project.version}",
+ # Build connection URL
+ from urllib.parse import quote_plus
+
+ username = quote_plus(self._access_key_id)
+ password = quote_plus(self._access_key_secret)
+ database_url = (
+ f"postgresql+psycopg2://{username}:{password}@{pg_host}:5432/{self.project_name}?sslmode=require"
)
- # Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables
- # Connection pool creation success already indicates connectivity
+ # Create SQLAlchemy engine with connection pool
+ self._engine = create_engine(
+ database_url,
+ pool_size=pool_size,
+ max_overflow=max_overflow,
+ pool_recycle=pool_recycle,
+ pool_pre_ping=pool_pre_ping,
+ pool_timeout=30,
+ connect_args={
+ "connect_timeout": 5,
+ "application_name": f"Dify-{dify_config.project.version}-fixautocommit",
+ "keepalives": 1,
+ "keepalives_idle": 60,
+ "keepalives_interval": 10,
+ "keepalives_count": 5,
+ },
+ )
self._use_pg_protocol = True
logger.info(
- "PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.",
+ "PG protocol initialized for SLS project=%s (pool_size=%d, pool_recycle=%ds)",
self.project_name,
+ pool_size,
+ pool_recycle,
)
return True
except Exception as e:
- # PG connection failed - fallback to SDK mode
self._use_pg_protocol = False
- if self._pg_pool:
+ if self._engine:
try:
- self._pg_pool.closeall()
+ self._engine.dispose()
except Exception:
- logger.debug("Failed to close PG connection pool during cleanup, ignoring")
- self._pg_pool = None
+ logger.debug("Failed to dispose engine during cleanup, ignoring")
+ self._engine = None
- logger.info(
- "PG protocol connection failed (region may not support PG protocol): %s. "
- "Falling back to SDK mode for read/write operations.",
- str(e),
- )
- return False
-
- def _is_connection_valid(self, conn: Any) -> bool:
- """
- Check if a connection is still valid.
-
- Args:
- conn: psycopg2 connection object
-
- Returns:
- True if connection is valid, False otherwise
- """
- try:
- # Check if connection is closed
- if conn.closed:
- return False
-
- # Quick ping test - execute a lightweight query
- # For SLS PG protocol, we can't use SELECT 1 without FROM,
- # so we just check the connection status
- with conn.cursor() as cursor:
- cursor.execute("SELECT 1")
- cursor.fetchone()
- return True
- except Exception:
+ logger.debug("Using SDK mode for region: %s", str(e))
return False
@contextmanager
def _get_connection(self):
- """
- Context manager to get a PostgreSQL connection from the pool.
+ """Get connection from SQLAlchemy pool. Pool handles recycle, invalidation, and keepalive automatically."""
+ if not self._engine:
+ raise RuntimeError("SQLAlchemy engine is not initialized")
- Automatically validates and refreshes stale connections.
-
- Note: Aliyun SLS PG protocol does not support transactions, so we always
- use autocommit mode.
-
- Yields:
- psycopg2 connection object
-
- Raises:
- RuntimeError: If PG pool is not initialized
- """
- if not self._pg_pool:
- raise RuntimeError("PG connection pool is not initialized")
-
- conn = self._pg_pool.getconn()
+ connection = self._engine.raw_connection()
try:
- # Validate connection and get a fresh one if needed
- if not self._is_connection_valid(conn):
- logger.debug("Connection is stale, marking as bad and getting a new one")
- # Mark connection as bad and get a new one
- self._pg_pool.putconn(conn, close=True)
- conn = self._pg_pool.getconn()
-
- # Aliyun SLS PG protocol does not support transactions, always use autocommit
- conn.autocommit = True
- yield conn
+ connection.autocommit = True # SLS PG protocol does not support transactions
+ yield connection
+ except Exception:
+ raise
finally:
- # Return connection to pool (or close if it's bad)
- if self._is_connection_valid(conn):
- self._pg_pool.putconn(conn)
- else:
- self._pg_pool.putconn(conn, close=True)
+ connection.close()
def close(self) -> None:
- """Close the PostgreSQL connection pool."""
- if self._pg_pool:
+ """Dispose SQLAlchemy engine and close all connections."""
+ if self._engine:
try:
- self._pg_pool.closeall()
- logger.info("PG connection pool closed")
+ self._engine.dispose()
+ logger.info("SQLAlchemy engine disposed")
except Exception:
- logger.exception("Failed to close PG connection pool")
+ logger.exception("Failed to dispose engine")
def _is_retriable_error(self, error: Exception) -> bool:
- """
- Check if an error is retriable (connection-related issues).
-
- Args:
- error: Exception to check
-
- Returns:
- True if the error is retriable, False otherwise
- """
- # Retry on connection-related errors
- if isinstance(error, (OperationalError, InterfaceError)):
+ """Check if error is retriable (connection-related issues)."""
+ # Check for psycopg2 connection errors directly
+ if isinstance(error, (psycopg2.OperationalError, psycopg2.InterfaceError)):
return True
- # Check error message for specific connection issues
error_msg = str(error).lower()
retriable_patterns = [
"connection",
@@ -234,34 +151,18 @@ class AliyunLogStorePG:
"reset by peer",
"no route to host",
"network",
+ "operational error",
+ "interface error",
]
return any(pattern in error_msg for pattern in retriable_patterns)
def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None:
- """
- Write log to SLS using PostgreSQL protocol with automatic retry.
-
- Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only
- writes with log_version field for versioning, same as SDK implementation.
-
- Args:
- logstore: Name of the logstore table
- contents: List of (field_name, value) tuples
- log_enabled: Whether to enable logging
-
- Raises:
- psycopg2.Error: If database operation fails after all retries
- """
+ """Write log to SLS using INSERT with automatic retry (3 attempts with exponential backoff)."""
if not contents:
return
- # Extract field names and values from contents
fields = [field_name for field_name, _ in contents]
values = [value for _, value in contents]
-
- # Build INSERT statement with literal values
- # Note: Aliyun SLS PG protocol doesn't support parameterized queries,
- # so we need to use mogrify to safely create literal values
field_list = ", ".join([f'"{field}"' for field in fields])
if log_enabled:
@@ -272,67 +173,40 @@ class AliyunLogStorePG:
len(contents),
)
- # Retry configuration
max_retries = 3
- retry_delay = 0.1 # Start with 100ms
+ retry_delay = 0.1
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
- # Use mogrify to safely convert values to SQL literals
placeholders = ", ".join(["%s"] * len(fields))
values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8")
insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}'
cursor.execute(insert_sql)
- # Success - exit retry loop
return
except psycopg2.Error as e:
- # Check if error is retriable
if not self._is_retriable_error(e):
- # Not a retriable error (e.g., data validation error), fail immediately
- logger.exception(
- "Failed to put logs to logstore %s via PG protocol (non-retriable error)",
- logstore,
- )
+ logger.exception("Failed to put logs to logstore %s (non-retriable error)", logstore)
raise
- # Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
- "Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
+ "Failed to put logs to logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
- retry_delay *= 2 # Exponential backoff
+ retry_delay *= 2
else:
- # Last attempt failed
- logger.exception(
- "Failed to put logs to logstore %s via PG protocol after %d attempts",
- logstore,
- max_retries,
- )
+ logger.exception("Failed to put logs to logstore %s after %d attempts", logstore, max_retries)
raise
def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]:
- """
- Execute SQL query using PostgreSQL protocol with automatic retry.
-
- Args:
- sql: SQL query string
- logstore: Name of the logstore (for logging purposes)
- log_enabled: Whether to enable logging
-
- Returns:
- List of result rows as dictionaries
-
- Raises:
- psycopg2.Error: If database operation fails after all retries
- """
+ """Execute SQL query with automatic retry (3 attempts with exponential backoff)."""
if log_enabled:
logger.info(
"[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s",
@@ -341,20 +215,16 @@ class AliyunLogStorePG:
sql,
)
- # Retry configuration
max_retries = 3
- retry_delay = 0.1 # Start with 100ms
+ retry_delay = 0.1
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(sql)
-
- # Get column names from cursor description
columns = [desc[0] for desc in cursor.description]
- # Fetch all results and convert to list of dicts
result = []
for row in cursor.fetchall():
row_dict = {}
@@ -372,36 +242,31 @@ class AliyunLogStorePG:
return result
except psycopg2.Error as e:
- # Check if error is retriable
if not self._is_retriable_error(e):
- # Not a retriable error (e.g., SQL syntax error), fail immediately
logger.exception(
- "Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s",
+ "Failed to execute SQL on logstore %s (non-retriable error): sql=%s",
logstore,
sql,
)
raise
- # Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
- "Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
+ "Failed to execute SQL on logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
- retry_delay *= 2 # Exponential backoff
+ retry_delay *= 2
else:
- # Last attempt failed
logger.exception(
- "Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s",
+ "Failed to execute SQL on logstore %s after %d attempts: sql=%s",
logstore,
max_retries,
sql,
)
raise
- # This line should never be reached due to raise above, but makes type checker happy
return []
diff --git a/api/extensions/logstore/repositories/__init__.py b/api/extensions/logstore/repositories/__init__.py
index e69de29bb2..b5a4fcf844 100644
--- a/api/extensions/logstore/repositories/__init__.py
+++ b/api/extensions/logstore/repositories/__init__.py
@@ -0,0 +1,29 @@
+"""
+LogStore repository utilities.
+"""
+
+from typing import Any
+
+
+def safe_float(value: Any, default: float = 0.0) -> float:
+ """
+ Safely convert a value to float, handling 'null' strings and None.
+ """
+ if value is None or value in {"null", ""}:
+ return default
+ try:
+ return float(value)
+ except (ValueError, TypeError):
+ return default
+
+
+def safe_int(value: Any, default: int = 0) -> int:
+ """
+ Safely convert a value to int, handling 'null' strings and None.
+ """
+ if value is None or value in {"null", ""}:
+ return default
+ try:
+ return int(float(value))
+ except (ValueError, TypeError):
+ return default
diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py
index 1cae14b726..817c8b0448 100644
--- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py
+++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py
@@ -15,6 +15,8 @@ from sqlalchemy.orm import sessionmaker
from core.workflow.enums import WorkflowNodeExecutionStatus
from extensions.logstore.aliyun_logstore import AliyunLogStore
+from extensions.logstore.repositories import safe_float, safe_int
+from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
@@ -53,9 +55,8 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
- # Numeric fields with defaults
- model.index = int(data.get("index", 0))
- model.elapsed_time = float(data.get("elapsed_time", 0))
+ model.index = safe_int(data.get("index", 0))
+ model.elapsed_time = safe_float(data.get("elapsed_time", 0))
# Optional fields
model.workflow_run_id = data.get("workflow_run_id")
@@ -131,6 +132,12 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
node_id,
)
try:
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_workflow_id = escape_identifier(workflow_id)
+ escaped_node_id = escape_identifier(node_id)
+
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
@@ -139,10 +146,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
- WHERE tenant_id = '{tenant_id}'
- AND app_id = '{app_id}'
- AND workflow_id = '{workflow_id}'
- AND node_id = '{node_id}'
+ WHERE tenant_id = '{escaped_tenant_id}'
+ AND app_id = '{escaped_app_id}'
+ AND workflow_id = '{escaped_workflow_id}'
+ AND node_id = '{escaped_node_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
@@ -154,7 +161,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
else:
# Use SDK with LogStore query syntax
query = (
- f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}"
+ f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
+ f"and workflow_id: {escaped_workflow_id} and node_id: {escaped_node_id}"
)
from_time = 0
to_time = int(time.time()) # now
@@ -230,6 +238,11 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
workflow_run_id,
)
try:
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_workflow_run_id = escape_identifier(workflow_run_id)
+
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
@@ -238,9 +251,9 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
- WHERE tenant_id = '{tenant_id}'
- AND app_id = '{app_id}'
- AND workflow_run_id = '{workflow_run_id}'
+ WHERE tenant_id = '{escaped_tenant_id}'
+ AND app_id = '{escaped_app_id}'
+ AND workflow_run_id = '{escaped_workflow_run_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1000
@@ -251,7 +264,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
- query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}"
+ query = (
+ f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
+ f"and workflow_run_id: {escaped_workflow_run_id}"
+ )
from_time = 0
to_time = int(time.time()) # now
@@ -318,16 +334,24 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
"""
logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id)
try:
+ # Escape parameters to prevent SQL injection
+ escaped_execution_id = escape_identifier(execution_id)
+
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
- tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else ""
+ if tenant_id:
+ escaped_tenant_id = escape_identifier(tenant_id)
+ tenant_filter = f"AND tenant_id = '{escaped_tenant_id}'"
+ else:
+ tenant_filter = ""
+
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
- WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0
+ WHERE id = '{escaped_execution_id}' {tenant_filter} AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1
"""
@@ -337,10 +361,14 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
+ # Note: Values must be quoted in LogStore query syntax to prevent injection
if tenant_id:
- query = f"id: {execution_id} and tenant_id: {tenant_id}"
+ query = (
+ f"id:{escape_logstore_query_value(execution_id)} "
+ f"and tenant_id:{escape_logstore_query_value(tenant_id)}"
+ )
else:
- query = f"id: {execution_id}"
+ query = f"id:{escape_logstore_query_value(execution_id)}"
from_time = 0
to_time = int(time.time()) # now
diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py
index 252cdcc4df..14382ed876 100644
--- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py
+++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py
@@ -10,6 +10,7 @@ Key Features:
- Optimized deduplication using finished_at IS NOT NULL filter
- Window functions only when necessary (running status queries)
- Multi-tenant data isolation and security
+- SQL injection prevention via parameter escaping
"""
import logging
@@ -22,6 +23,8 @@ from typing import Any, cast
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
+from extensions.logstore.repositories import safe_float, safe_int
+from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
@@ -63,10 +66,9 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
- # Numeric fields with defaults
- model.total_tokens = int(data.get("total_tokens", 0))
- model.total_steps = int(data.get("total_steps", 0))
- model.exceptions_count = int(data.get("exceptions_count", 0))
+ model.total_tokens = safe_int(data.get("total_tokens", 0))
+ model.total_steps = safe_int(data.get("total_steps", 0))
+ model.exceptions_count = safe_int(data.get("exceptions_count", 0))
# Optional fields
model.graph = data.get("graph")
@@ -101,7 +103,8 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
if model.finished_at and model.created_at:
model.elapsed_time = (model.finished_at - model.created_at).total_seconds()
else:
- model.elapsed_time = float(data.get("elapsed_time", 0))
+ # Use safe conversion to handle 'null' strings and None values
+ model.elapsed_time = safe_float(data.get("elapsed_time", 0))
return model
@@ -165,16 +168,26 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
status,
)
# Convert triggered_from to list if needed
- if isinstance(triggered_from, WorkflowRunTriggeredFrom):
+ if isinstance(triggered_from, (WorkflowRunTriggeredFrom, str)):
triggered_from_list = [triggered_from]
else:
triggered_from_list = list(triggered_from)
- # Build triggered_from filter
- triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list])
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
- # Build status filter
- status_filter = f"AND status='{status}'" if status else ""
+ # Build triggered_from filter with escaped values
+ # Support both enum and string values for triggered_from
+ triggered_from_filter = " OR ".join(
+ [
+ f"triggered_from='{escape_sql_string(tf.value if isinstance(tf, WorkflowRunTriggeredFrom) else tf)}'"
+ for tf in triggered_from_list
+ ]
+ )
+
+ # Build status filter with escaped value
+ status_filter = f"AND status='{escape_sql_string(status)}'" if status else ""
# Build last_id filter for pagination
# Note: This is simplified. In production, you'd need to track created_at from last record
@@ -188,8 +201,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
AND ({triggered_from_filter})
{status_filter}
{last_id_filter}
@@ -232,6 +245,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id)
try:
+ # Escape parameters to prevent SQL injection
+ escaped_run_id = escape_identifier(run_id)
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
@@ -240,7 +258,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
- WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0
+ WHERE id = '{escaped_run_id}'
+ AND tenant_id = '{escaped_tenant_id}'
+ AND app_id = '{escaped_app_id}'
+ AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
@@ -250,7 +271,12 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
- query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}"
+ # Note: Values must be quoted in LogStore query syntax to prevent injection
+ query = (
+ f"id:{escape_logstore_query_value(run_id)} "
+ f"and tenant_id:{escape_logstore_query_value(tenant_id)} "
+ f"and app_id:{escape_logstore_query_value(app_id)}"
+ )
from_time = 0
to_time = int(time.time()) # now
@@ -323,6 +349,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id)
try:
+ # Escape parameter to prevent SQL injection
+ escaped_run_id = escape_identifier(run_id)
+
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
@@ -331,7 +360,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
- WHERE id = '{run_id}' AND __time__ > 0
+ WHERE id = '{escaped_run_id}' AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
@@ -341,7 +370,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
- query = f"id: {run_id}"
+ # Note: Values must be quoted in LogStore query syntax
+ query = f"id:{escape_logstore_query_value(run_id)}"
from_time = 0
to_time = int(time.time()) # now
@@ -410,6 +440,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
triggered_from,
status,
)
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
# Build time range filter
time_filter = ""
if time_range:
@@ -418,6 +453,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# If status is provided, simple count
if status:
+ escaped_status = escape_sql_string(status)
+
if status == "running":
# Running status requires window function
sql = f"""
@@ -425,9 +462,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND status='running'
{time_filter}
) t
@@ -438,10 +475,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
- AND status='{status}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
+ AND status='{escaped_status}'
AND finished_at IS NOT NULL
{time_filter}
"""
@@ -467,13 +504,14 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# No status filter - get counts grouped by status
# Use optimized query for finished runs, separate query for running
try:
+ # Escape parameters (already escaped above, reuse variables)
# Count finished runs grouped by status
finished_sql = f"""
SELECT status, COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY status
@@ -485,9 +523,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND status='running'
{time_filter}
) t
@@ -546,7 +584,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug(
"get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from
)
- # Build time range filter
+
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
+ # Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@@ -557,9 +601,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@@ -601,7 +645,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
- # Build time range filter
+
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
+ # Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@@ -611,9 +661,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@@ -655,7 +705,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
- # Build time range filter
+
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
+ # Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@@ -665,9 +721,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@@ -709,7 +765,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
- # Build time range filter
+
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
+ # Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@@ -726,9 +788,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
created_by,
COUNT(DISTINCT id) AS interactions
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date, created_by
diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py
index 6e6631cfef..9928879a7b 100644
--- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py
+++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py
@@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.entities import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
+from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.helper import extract_tenant_id
from models import (
@@ -67,7 +68,12 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
- self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
+ self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
+
+ # Control flag for whether to write the `graph` field to LogStore.
+ # If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
+ # otherwise write an empty {} instead. Defaults to writing the `graph` field.
+ self._enable_put_graph_field = os.environ.get("LOGSTORE_ENABLE_PUT_GRAPH_FIELD", "true").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowExecution) -> list[tuple[str, str]]:
"""
@@ -96,6 +102,9 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Generate log_version as nanosecond timestamp for record versioning
log_version = str(time.time_ns())
+ # Use WorkflowRuntimeTypeConverter to handle complex types (Segment, File, etc.)
+ json_converter = WorkflowRuntimeTypeConverter()
+
logstore_model = [
("id", domain_model.id_),
("log_version", log_version), # Add log_version field for append-only writes
@@ -108,9 +117,24 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
),
("type", domain_model.workflow_type.value),
("version", domain_model.workflow_version),
- ("graph", json.dumps(domain_model.graph, ensure_ascii=False) if domain_model.graph else "{}"),
- ("inputs", json.dumps(domain_model.inputs, ensure_ascii=False) if domain_model.inputs else "{}"),
- ("outputs", json.dumps(domain_model.outputs, ensure_ascii=False) if domain_model.outputs else "{}"),
+ (
+ "graph",
+ json.dumps(json_converter.to_json_encodable(domain_model.graph), ensure_ascii=False)
+ if domain_model.graph and self._enable_put_graph_field
+ else "{}",
+ ),
+ (
+ "inputs",
+ json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False)
+ if domain_model.inputs
+ else "{}",
+ ),
+ (
+ "outputs",
+ json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False)
+ if domain_model.outputs
+ else "{}",
+ ),
("status", domain_model.status.value),
("error_message", domain_model.error_message or ""),
("total_tokens", str(domain_model.total_tokens)),
diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py
index 400a089516..4897171b12 100644
--- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py
+++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py
@@ -24,6 +24,8 @@ from core.workflow.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
+from extensions.logstore.repositories import safe_float, safe_int
+from extensions.logstore.sql_escape import escape_identifier
from libs.helper import extract_tenant_id
from models import (
Account,
@@ -73,7 +75,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
node_execution_id=data.get("node_execution_id"),
workflow_id=data.get("workflow_id", ""),
workflow_execution_id=data.get("workflow_run_id"),
- index=int(data.get("index", 0)),
+ index=safe_int(data.get("index", 0)),
predecessor_node_id=data.get("predecessor_node_id"),
node_id=data.get("node_id", ""),
node_type=NodeType(data.get("node_type", "start")),
@@ -83,7 +85,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
outputs=outputs,
status=status,
error=data.get("error"),
- elapsed_time=float(data.get("elapsed_time", 0.0)),
+ elapsed_time=safe_float(data.get("elapsed_time", 0.0)),
metadata=domain_metadata,
created_at=created_at,
finished_at=finished_at,
@@ -147,7 +149,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
- self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
+ self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]:
logger.debug(
@@ -274,16 +276,34 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
Save or update the inputs, process_data, or outputs associated with a specific
node_execution record.
- For LogStore implementation, this is similar to save() since we always write
- complete records. We append a new record with updated data fields.
+ For LogStore implementation, this is a no-op for the LogStore write because save()
+ already writes all fields including inputs, process_data, and outputs. The caller
+ typically calls save() first to persist status/metadata, then calls save_execution_data()
+ to persist data fields. Since LogStore writes complete records atomically, we don't
+ need a separate write here to avoid duplicate records.
+
+ However, if dual-write is enabled, we still need to call the SQL repository's
+ save_execution_data() method to properly update the SQL database.
Args:
execution: The NodeExecution instance with data to save
"""
- logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id)
- # In LogStore, we simply write a new complete record with the data
- # The log_version timestamp will ensure this is treated as the latest version
- self.save(execution)
+ logger.debug(
+ "save_execution_data: no-op for LogStore (data already saved by save()): id=%s, node_execution_id=%s",
+ execution.id,
+ execution.node_execution_id,
+ )
+ # No-op for LogStore: save() already writes all fields including inputs, process_data, and outputs
+ # Calling save() again would create a duplicate record in the append-only LogStore
+
+ # Dual-write to SQL database if enabled (for safe migration)
+ if self._enable_dual_write:
+ try:
+ self.sql_repository.save_execution_data(execution)
+ logger.debug("Dual-write: saved node execution data to SQL database: id=%s", execution.id)
+ except Exception:
+ logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id)
+ # Don't raise - LogStore write succeeded, SQL is just a backup
def get_by_workflow_run(
self,
@@ -292,8 +312,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.
- Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication.
- This ensures we only get the final version of each node execution.
+ Uses LogStore SQL query with window function to get the latest version of each node execution.
+ This ensures we only get the most recent version of each node execution record.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
@@ -304,16 +324,19 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
A list of NodeExecution instances
Note:
- This method filters by finished_at IS NOT NULL to avoid duplicates from
- version updates. For complete history including intermediate states,
- a different query strategy would be needed.
+ This method uses ROW_NUMBER() window function partitioned by node_execution_id
+ to get the latest version (highest log_version) of each node execution.
"""
logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config)
- # Build SQL query with deduplication using finished_at IS NOT NULL
- # This optimization avoids window functions for common case where we only
- # want the final state of each node execution
+ # Build SQL query with deduplication using window function
+ # ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC)
+ # ensures we get the latest version of each node execution
- # Build ORDER BY clause
+ # Escape parameters to prevent SQL injection
+ escaped_workflow_run_id = escape_identifier(workflow_run_id)
+ escaped_tenant_id = escape_identifier(self._tenant_id)
+
+ # Build ORDER BY clause for outer query
order_clause = ""
if order_config and order_config.order_by:
order_fields = []
@@ -327,16 +350,23 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
if order_fields:
order_clause = "ORDER BY " + ", ".join(order_fields)
- sql = f"""
- SELECT *
- FROM {AliyunLogStore.workflow_node_execution_logstore}
- WHERE workflow_run_id='{workflow_run_id}'
- AND tenant_id='{self._tenant_id}'
- AND finished_at IS NOT NULL
- """
-
+ # Build app_id filter for subquery
+ app_id_filter = ""
if self._app_id:
- sql += f" AND app_id='{self._app_id}'"
+ escaped_app_id = escape_identifier(self._app_id)
+ app_id_filter = f" AND app_id='{escaped_app_id}'"
+
+ # Use window function to get latest version of each node execution
+ sql = f"""
+ SELECT * FROM (
+ SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn
+ FROM {AliyunLogStore.workflow_node_execution_logstore}
+ WHERE workflow_run_id='{escaped_workflow_run_id}'
+ AND tenant_id='{escaped_tenant_id}'
+ {app_id_filter}
+ ) t
+ WHERE rn = 1
+ """
if order_clause:
sql += f" {order_clause}"
diff --git a/api/extensions/logstore/sql_escape.py b/api/extensions/logstore/sql_escape.py
new file mode 100644
index 0000000000..d88d6bd959
--- /dev/null
+++ b/api/extensions/logstore/sql_escape.py
@@ -0,0 +1,134 @@
+"""
+SQL Escape Utility for LogStore Queries
+
+This module provides escaping utilities to prevent injection attacks in LogStore queries.
+
+LogStore supports two query modes:
+1. PG Protocol Mode: Uses SQL syntax with single quotes for strings
+2. SDK Mode: Uses LogStore query syntax (key: value) with double quotes
+
+Key Security Concerns:
+- Prevent tenant A from accessing tenant B's data via injection
+- SLS queries are read-only, so we focus on data access control
+- Different escaping strategies for SQL vs LogStore query syntax
+"""
+
+
+def escape_sql_string(value: str) -> str:
+ """
+ Escape a string value for safe use in SQL queries.
+
+ This function escapes single quotes by doubling them, which is the standard
+ SQL escaping method. This prevents SQL injection by ensuring that user input
+ cannot break out of string literals.
+
+ Args:
+ value: The string value to escape
+
+ Returns:
+ Escaped string safe for use in SQL queries
+
+ Examples:
+ >>> escape_sql_string("normal_value")
+ "normal_value"
+ >>> escape_sql_string("value' OR '1'='1")
+ "value'' OR ''1''=''1"
+ >>> escape_sql_string("tenant's_id")
+ "tenant''s_id"
+
+ Security:
+ - Prevents breaking out of string literals
+ - Stops injection attacks like: ' OR '1'='1
+ - Protects against cross-tenant data access
+ """
+ if not value:
+ return value
+
+ # Escape single quotes by doubling them (standard SQL escaping)
+ # This prevents breaking out of string literals in SQL queries
+ return value.replace("'", "''")
+
+
+def escape_identifier(value: str) -> str:
+ """
+ Escape an identifier (tenant_id, app_id, run_id, etc.) for safe SQL use.
+
+ This function is for PG protocol mode (SQL syntax).
+ For SDK mode, use escape_logstore_query_value() instead.
+
+ Args:
+ value: The identifier value to escape
+
+ Returns:
+ Escaped identifier safe for use in SQL queries
+
+ Examples:
+ >>> escape_identifier("550e8400-e29b-41d4-a716-446655440000")
+ "550e8400-e29b-41d4-a716-446655440000"
+ >>> escape_identifier("tenant_id' OR '1'='1")
+ "tenant_id'' OR ''1''=''1"
+
+ Security:
+ - Prevents SQL injection via identifiers
+ - Stops cross-tenant access attempts
+ - Works for UUIDs, alphanumeric IDs, and similar identifiers
+ """
+ # For identifiers, use the same escaping as strings
+ # This is simple and effective for preventing injection
+ return escape_sql_string(value)
+
+
+def escape_logstore_query_value(value: str) -> str:
+ """
+ Escape value for LogStore query syntax (SDK mode).
+
+ LogStore query syntax rules:
+ 1. Keywords (and/or/not) are case-insensitive
+ 2. Single quotes are ordinary characters (no special meaning)
+ 3. Double quotes wrap values: key:"value"
+ 4. Backslash is the escape character:
+ - \" for double quote inside value
+ - \\ for backslash itself
+ 5. Parentheses can change query structure
+
+ To prevent injection:
+ - Wrap value in double quotes to treat special chars as literals
+ - Escape backslashes and double quotes using backslash
+
+ Args:
+ value: The value to escape for LogStore query syntax
+
+ Returns:
+ Quoted and escaped value safe for LogStore query syntax (includes the quotes)
+
+ Examples:
+ >>> escape_logstore_query_value("normal_value")
+ '"normal_value"'
+ >>> escape_logstore_query_value("value or field:evil")
+ '"value or field:evil"' # 'or' and ':' are now literals
+ >>> escape_logstore_query_value('value"test')
+ '"value\\"test"' # Internal double quote escaped
+ >>> escape_logstore_query_value('value\\test')
+ '"value\\\\test"' # Backslash escaped
+
+ Security:
+ - Prevents injection via and/or/not keywords
+ - Prevents injection via colons (:)
+ - Prevents injection via parentheses
+ - Protects against cross-tenant data access
+
+ Note:
+ Escape order is critical: backslash first, then double quotes.
+ Otherwise, we'd double-escape the escape character itself.
+ """
+ if not value:
+ return '""'
+
+ # IMPORTANT: Escape backslashes FIRST, then double quotes
+ # This prevents double-escaping (e.g., " -> \" -> \\" incorrectly)
+ escaped = value.replace("\\", "\\\\") # \ -> \\
+ escaped = escaped.replace('"', '\\"') # " -> \"
+
+ # Wrap in double quotes to treat as literal string
+ # This prevents and/or/not/:/() from being interpreted as operators
+ return f'"{escaped}"'
diff --git a/api/extensions/otel/instrumentation.py b/api/extensions/otel/instrumentation.py
index 3597110cba..6617f69513 100644
--- a/api/extensions/otel/instrumentation.py
+++ b/api/extensions/otel/instrumentation.py
@@ -19,26 +19,43 @@ logger = logging.getLogger(__name__)
class ExceptionLoggingHandler(logging.Handler):
+ """
+ Handler that records exceptions to the current OpenTelemetry span.
+
+ Unlike creating a new span, this records exceptions on the existing span
+ to maintain trace context consistency throughout the request lifecycle.
+ """
+
def emit(self, record: logging.LogRecord):
with contextlib.suppress(Exception):
- if record.exc_info:
- tracer = get_tracer_provider().get_tracer("dify.exception.logging")
- with tracer.start_as_current_span(
- "log.exception",
- attributes={
- "log.level": record.levelname,
- "log.message": record.getMessage(),
- "log.logger": record.name,
- "log.file.path": record.pathname,
- "log.file.line": record.lineno,
- },
- ) as span:
- span.set_status(StatusCode.ERROR)
- if record.exc_info[1]:
- span.record_exception(record.exc_info[1])
- span.set_attribute("exception.message", str(record.exc_info[1]))
- if record.exc_info[0]:
- span.set_attribute("exception.type", record.exc_info[0].__name__)
+ if not record.exc_info:
+ return
+
+ from opentelemetry.trace import get_current_span
+
+ span = get_current_span()
+ if not span or not span.is_recording():
+ return
+
+ # Record exception on the current span instead of creating a new one
+ span.set_status(StatusCode.ERROR, record.getMessage())
+
+ # Add log context as span events/attributes
+ span.add_event(
+ "log.exception",
+ attributes={
+ "log.level": record.levelname,
+ "log.message": record.getMessage(),
+ "log.logger": record.name,
+ "log.file.path": record.pathname,
+ "log.file.line": record.lineno,
+ },
+ )
+
+ if record.exc_info[1]:
+ span.record_exception(record.exc_info[1])
+ if record.exc_info[0]:
+ span.set_attribute("exception.type", record.exc_info[0].__name__)
def instrument_exception_logging() -> None:
diff --git a/api/extensions/otel/parser/__init__.py b/api/extensions/otel/parser/__init__.py
new file mode 100644
index 0000000000..164db7c275
--- /dev/null
+++ b/api/extensions/otel/parser/__init__.py
@@ -0,0 +1,20 @@
+"""
+OpenTelemetry node parsers for workflow nodes.
+
+This module provides parsers that extract node-specific metadata and set
+OpenTelemetry span attributes according to semantic conventions.
+"""
+
+from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps
+from extensions.otel.parser.llm import LLMNodeOTelParser
+from extensions.otel.parser.retrieval import RetrievalNodeOTelParser
+from extensions.otel.parser.tool import ToolNodeOTelParser
+
+__all__ = [
+ "DefaultNodeOTelParser",
+ "LLMNodeOTelParser",
+ "NodeOTelParser",
+ "RetrievalNodeOTelParser",
+ "ToolNodeOTelParser",
+ "safe_json_dumps",
+]
diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py
new file mode 100644
index 0000000000..f4db26e840
--- /dev/null
+++ b/api/extensions/otel/parser/base.py
@@ -0,0 +1,117 @@
+"""
+Base parser interface and utilities for OpenTelemetry node parsers.
+"""
+
+import json
+from typing import Any, Protocol
+
+from opentelemetry.trace import Span
+from opentelemetry.trace.status import Status, StatusCode
+from pydantic import BaseModel
+
+from core.file.models import File
+from core.variables import Segment
+from core.workflow.enums import NodeType
+from core.workflow.graph_events import GraphNodeEventBase
+from core.workflow.nodes.base.node import Node
+from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes
+
+
+def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str:
+ """
+ Safely serialize objects to JSON, handling non-serializable types.
+
+ Handles:
+ - Segment types (ArrayFileSegment, FileSegment, etc.) - converts to their value
+ - File objects - converts to dict using to_dict()
+ - BaseModel objects - converts using model_dump()
+ - Other types - falls back to str() representation
+
+ Args:
+ obj: Object to serialize
+ ensure_ascii: Whether to ensure ASCII encoding
+
+ Returns:
+ JSON string representation of the object
+ """
+
+ def _convert_value(value: Any) -> Any:
+ """Recursively convert non-serializable values."""
+ if value is None:
+ return None
+ if isinstance(value, (bool, int, float, str)):
+ return value
+ if isinstance(value, Segment):
+ # Convert Segment to its underlying value
+ return _convert_value(value.value)
+ if isinstance(value, File):
+ # Convert File to dict
+ return value.to_dict()
+ if isinstance(value, BaseModel):
+ # Convert Pydantic model to dict
+ return _convert_value(value.model_dump(mode="json"))
+ if isinstance(value, dict):
+ return {k: _convert_value(v) for k, v in value.items()}
+ if isinstance(value, (list, tuple)):
+ return [_convert_value(item) for item in value]
+ # Fallback to string representation for unknown types
+ return str(value)
+
+ try:
+ converted = _convert_value(obj)
+ return json.dumps(converted, ensure_ascii=ensure_ascii)
+ except (TypeError, ValueError) as e:
+ # If conversion still fails, return error message as string
+ return json.dumps(
+ {"error": f"Failed to serialize: {type(obj).__name__}", "message": str(e)}, ensure_ascii=ensure_ascii
+ )
+
+
+class NodeOTelParser(Protocol):
+ """Parser interface for node-specific OpenTelemetry enrichment."""
+
+ def parse(
+ self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
+ ) -> None: ...
+
+
+class DefaultNodeOTelParser:
+ """Fallback parser used when no node-specific parser is registered."""
+
+ def parse(
+ self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
+ ) -> None:
+ span.set_attribute("node.id", node.id)
+ if node.execution_id:
+ span.set_attribute("node.execution_id", node.execution_id)
+ if hasattr(node, "node_type") and node.node_type:
+ span.set_attribute("node.type", node.node_type.value)
+
+ span.set_attribute(GenAIAttributes.FRAMEWORK, "dify")
+
+ node_type = getattr(node, "node_type", None)
+ if isinstance(node_type, NodeType):
+ if node_type == NodeType.LLM:
+ span.set_attribute(GenAIAttributes.SPAN_KIND, "LLM")
+ elif node_type == NodeType.KNOWLEDGE_RETRIEVAL:
+ span.set_attribute(GenAIAttributes.SPAN_KIND, "RETRIEVER")
+ elif node_type == NodeType.TOOL:
+ span.set_attribute(GenAIAttributes.SPAN_KIND, "TOOL")
+ else:
+ span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK")
+ else:
+ span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK")
+
+ # Extract inputs and outputs from result_event
+ if result_event and result_event.node_run_result:
+ node_run_result = result_event.node_run_result
+ if node_run_result.inputs:
+ span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs))
+ if node_run_result.outputs:
+ span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs))
+
+ if error:
+ span.record_exception(error)
+ span.set_status(Status(StatusCode.ERROR, str(error)))
+ else:
+ span.set_status(Status(StatusCode.OK))
diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py
new file mode 100644
index 0000000000..8556974080
--- /dev/null
+++ b/api/extensions/otel/parser/llm.py
@@ -0,0 +1,155 @@
+"""
+Parser for LLM nodes that captures LLM-specific metadata.
+"""
+
+import logging
+from collections.abc import Mapping
+from typing import Any
+
+from opentelemetry.trace import Span
+
+from core.workflow.graph_events import GraphNodeEventBase
+from core.workflow.nodes.base.node import Node
+from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
+from extensions.otel.semconv.gen_ai import LLMAttributes
+
+logger = logging.getLogger(__name__)
+
+
+def _format_input_messages(process_data: Mapping[str, Any]) -> str:
+ """
+ Format input messages from process_data for LLM spans.
+
+ Args:
+ process_data: Process data containing prompts
+
+ Returns:
+ JSON string of formatted input messages
+ """
+ try:
+ if not isinstance(process_data, dict):
+ return safe_json_dumps([])
+
+ prompts = process_data.get("prompts", [])
+ if not prompts:
+ return safe_json_dumps([])
+
+ valid_roles = {"system", "user", "assistant", "tool"}
+ input_messages = []
+ for prompt in prompts:
+ if not isinstance(prompt, dict):
+ continue
+
+ role = prompt.get("role", "")
+ text = prompt.get("text", "")
+
+ if not role or role not in valid_roles:
+ continue
+
+ if text:
+ message = {"role": role, "parts": [{"type": "text", "content": text}]}
+ input_messages.append(message)
+
+ return safe_json_dumps(input_messages)
+ except Exception as e:
+ logger.warning("Failed to format input messages: %s", e, exc_info=True)
+ return safe_json_dumps([])
+
+
+def _format_output_messages(outputs: Mapping[str, Any]) -> str:
+ """
+ Format output messages from outputs for LLM spans.
+
+ Args:
+ outputs: Output data containing text and finish_reason
+
+ Returns:
+ JSON string of formatted output messages
+ """
+ try:
+ if not isinstance(outputs, dict):
+ return safe_json_dumps([])
+
+ text = outputs.get("text", "")
+ finish_reason = outputs.get("finish_reason", "")
+
+ if not text:
+ return safe_json_dumps([])
+
+ valid_finish_reasons = {"stop", "length", "content_filter", "tool_call", "error"}
+ if finish_reason not in valid_finish_reasons:
+ finish_reason = "stop"
+
+ output_message = {
+ "role": "assistant",
+ "parts": [{"type": "text", "content": text}],
+ "finish_reason": finish_reason,
+ }
+
+ return safe_json_dumps([output_message])
+ except Exception as e:
+ logger.warning("Failed to format output messages: %s", e, exc_info=True)
+ return safe_json_dumps([])
+
+
+class LLMNodeOTelParser:
+ """Parser for LLM nodes that captures LLM-specific metadata."""
+
+ def __init__(self) -> None:
+ self._delegate = DefaultNodeOTelParser()
+
+ def parse(
+ self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
+ ) -> None:
+ self._delegate.parse(node=node, span=span, error=error, result_event=result_event)
+
+ if not result_event or not result_event.node_run_result:
+ return
+
+ node_run_result = result_event.node_run_result
+ process_data = node_run_result.process_data or {}
+ outputs = node_run_result.outputs or {}
+
+ # Extract usage data (from process_data or outputs)
+ usage_data = process_data.get("usage") or outputs.get("usage") or {}
+
+ # Model and provider information
+ model_name = process_data.get("model_name") or ""
+ model_provider = process_data.get("model_provider") or ""
+
+ if model_name:
+ span.set_attribute(LLMAttributes.REQUEST_MODEL, model_name)
+ if model_provider:
+ span.set_attribute(LLMAttributes.PROVIDER_NAME, model_provider)
+
+ # Token usage
+ if usage_data:
+ prompt_tokens = usage_data.get("prompt_tokens", 0)
+ completion_tokens = usage_data.get("completion_tokens", 0)
+ total_tokens = usage_data.get("total_tokens", 0)
+
+ span.set_attribute(LLMAttributes.USAGE_INPUT_TOKENS, prompt_tokens)
+ span.set_attribute(LLMAttributes.USAGE_OUTPUT_TOKENS, completion_tokens)
+ span.set_attribute(LLMAttributes.USAGE_TOTAL_TOKENS, total_tokens)
+
+ # Prompts and completion
+ prompts = process_data.get("prompts", [])
+ if prompts:
+ prompts_json = safe_json_dumps(prompts)
+ span.set_attribute(LLMAttributes.PROMPT, prompts_json)
+
+ text_output = str(outputs.get("text", ""))
+ if text_output:
+ span.set_attribute(LLMAttributes.COMPLETION, text_output)
+
+ # Finish reason
+ finish_reason = outputs.get("finish_reason") or ""
+ if finish_reason:
+ span.set_attribute(LLMAttributes.RESPONSE_FINISH_REASON, finish_reason)
+
+ # Structured input/output messages
+ gen_ai_input_message = _format_input_messages(process_data)
+ gen_ai_output_message = _format_output_messages(outputs)
+
+ span.set_attribute(LLMAttributes.INPUT_MESSAGE, gen_ai_input_message)
+ span.set_attribute(LLMAttributes.OUTPUT_MESSAGE, gen_ai_output_message)
diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py
new file mode 100644
index 0000000000..fc151af691
--- /dev/null
+++ b/api/extensions/otel/parser/retrieval.py
@@ -0,0 +1,105 @@
+"""
+Parser for knowledge retrieval nodes that captures retrieval-specific metadata.
+"""
+
+import logging
+from collections.abc import Sequence
+from typing import Any
+
+from opentelemetry.trace import Span
+
+from core.variables import Segment
+from core.workflow.graph_events import GraphNodeEventBase
+from core.workflow.nodes.base.node import Node
+from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
+from extensions.otel.semconv.gen_ai import RetrieverAttributes
+
+logger = logging.getLogger(__name__)
+
+
+def _format_retrieval_documents(retrieval_documents: list[Any]) -> list:
+ """
+ Format retrieval documents for semantic conventions.
+
+ Args:
+ retrieval_documents: List of retrieval document dictionaries
+
+ Returns:
+ List of formatted semantic documents
+ """
+ try:
+ if not isinstance(retrieval_documents, list):
+ return []
+
+ semantic_documents = []
+ for doc in retrieval_documents:
+ if not isinstance(doc, dict):
+ continue
+
+ metadata = doc.get("metadata", {})
+ content = doc.get("content", "")
+ title = doc.get("title", "")
+ score = metadata.get("score", 0.0)
+ document_id = metadata.get("document_id", "")
+
+ semantic_metadata = {}
+ if title:
+ semantic_metadata["title"] = title
+ if metadata.get("source"):
+ semantic_metadata["source"] = metadata["source"]
+ elif metadata.get("_source"):
+ semantic_metadata["source"] = metadata["_source"]
+ if metadata.get("doc_metadata"):
+ doc_metadata = metadata["doc_metadata"]
+ if isinstance(doc_metadata, dict):
+ semantic_metadata.update(doc_metadata)
+
+ semantic_doc = {
+ "document": {"content": content, "metadata": semantic_metadata, "score": score, "id": document_id}
+ }
+ semantic_documents.append(semantic_doc)
+
+ return semantic_documents
+ except Exception as e:
+ logger.warning("Failed to format retrieval documents: %s", e, exc_info=True)
+ return []
+
+
+class RetrievalNodeOTelParser:
+ """Parser for knowledge retrieval nodes that captures retrieval-specific metadata."""
+
+ def __init__(self) -> None:
+ self._delegate = DefaultNodeOTelParser()
+
+ def parse(
+ self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
+ ) -> None:
+ self._delegate.parse(node=node, span=span, error=error, result_event=result_event)
+
+ if not result_event or not result_event.node_run_result:
+ return
+
+ node_run_result = result_event.node_run_result
+ inputs = node_run_result.inputs or {}
+ outputs = node_run_result.outputs or {}
+
+ # Extract query from inputs
+ query = str(inputs.get("query", "")) if inputs else ""
+ if query:
+ span.set_attribute(RetrieverAttributes.QUERY, query)
+
+ # Extract and format retrieval documents from outputs
+ result_value = outputs.get("result") if outputs else None
+ retrieval_documents: list[Any] = []
+ if result_value:
+ value_to_check = result_value
+ if isinstance(result_value, Segment):
+ value_to_check = result_value.value
+
+ if isinstance(value_to_check, (list, Sequence)):
+ retrieval_documents = list(value_to_check)
+
+ if retrieval_documents:
+ semantic_retrieval_documents = _format_retrieval_documents(retrieval_documents)
+ semantic_retrieval_documents_json = safe_json_dumps(semantic_retrieval_documents)
+ span.set_attribute(RetrieverAttributes.DOCUMENT, semantic_retrieval_documents_json)
diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py
new file mode 100644
index 0000000000..b99180722b
--- /dev/null
+++ b/api/extensions/otel/parser/tool.py
@@ -0,0 +1,47 @@
+"""
+Parser for tool nodes that captures tool-specific metadata.
+"""
+
+from opentelemetry.trace import Span
+
+from core.workflow.enums import WorkflowNodeExecutionMetadataKey
+from core.workflow.graph_events import GraphNodeEventBase
+from core.workflow.nodes.base.node import Node
+from core.workflow.nodes.tool.entities import ToolNodeData
+from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
+from extensions.otel.semconv.gen_ai import ToolAttributes
+
+
+class ToolNodeOTelParser:
+ """Parser for tool nodes that captures tool-specific metadata."""
+
+ def __init__(self) -> None:
+ self._delegate = DefaultNodeOTelParser()
+
+ def parse(
+ self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
+ ) -> None:
+ self._delegate.parse(node=node, span=span, error=error, result_event=result_event)
+
+ tool_data = getattr(node, "_node_data", None)
+ if not isinstance(tool_data, ToolNodeData):
+ return
+
+ span.set_attribute(ToolAttributes.TOOL_NAME, node.title)
+ span.set_attribute(ToolAttributes.TOOL_TYPE, tool_data.provider_type.value)
+
+ # Extract tool info from metadata (consistent with aliyun_trace)
+ tool_info = {}
+ if result_event and result_event.node_run_result:
+ node_run_result = result_event.node_run_result
+ if node_run_result.metadata:
+ tool_info = node_run_result.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
+
+ if tool_info:
+ span.set_attribute(ToolAttributes.TOOL_DESCRIPTION, safe_json_dumps(tool_info))
+
+ if result_event and result_event.node_run_result and result_event.node_run_result.inputs:
+ span.set_attribute(ToolAttributes.TOOL_CALL_ARGUMENTS, safe_json_dumps(result_event.node_run_result.inputs))
+
+ if result_event and result_event.node_run_result and result_event.node_run_result.outputs:
+ span.set_attribute(ToolAttributes.TOOL_CALL_RESULT, safe_json_dumps(result_event.node_run_result.outputs))
diff --git a/api/extensions/otel/semconv/__init__.py b/api/extensions/otel/semconv/__init__.py
index dc79dee222..0db3075815 100644
--- a/api/extensions/otel/semconv/__init__.py
+++ b/api/extensions/otel/semconv/__init__.py
@@ -1,6 +1,13 @@
"""Semantic convention shortcuts for Dify-specific spans."""
from .dify import DifySpanAttributes
-from .gen_ai import GenAIAttributes
+from .gen_ai import ChainAttributes, GenAIAttributes, LLMAttributes, RetrieverAttributes, ToolAttributes
-__all__ = ["DifySpanAttributes", "GenAIAttributes"]
+__all__ = [
+ "ChainAttributes",
+ "DifySpanAttributes",
+ "GenAIAttributes",
+ "LLMAttributes",
+ "RetrieverAttributes",
+ "ToolAttributes",
+]
diff --git a/api/extensions/otel/semconv/gen_ai.py b/api/extensions/otel/semconv/gen_ai.py
index 83c52ed34f..88c2058c06 100644
--- a/api/extensions/otel/semconv/gen_ai.py
+++ b/api/extensions/otel/semconv/gen_ai.py
@@ -62,3 +62,37 @@ class ToolAttributes:
TOOL_CALL_RESULT = "gen_ai.tool.call.result"
"""Tool invocation result."""
+
+
+class LLMAttributes:
+ """LLM operation attribute keys."""
+
+ REQUEST_MODEL = "gen_ai.request.model"
+ """Model identifier."""
+
+ PROVIDER_NAME = "gen_ai.provider.name"
+ """Provider name."""
+
+ USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
+ """Number of input tokens."""
+
+ USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
+ """Number of output tokens."""
+
+ USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
+ """Total number of tokens."""
+
+ PROMPT = "gen_ai.prompt"
+ """Prompt text."""
+
+ COMPLETION = "gen_ai.completion"
+ """Completion text."""
+
+ RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
+ """Finish reason for the response."""
+
+ INPUT_MESSAGE = "gen_ai.input.messages"
+ """Input messages in structured format."""
+
+ OUTPUT_MESSAGE = "gen_ai.output.messages"
+ """Output messages in structured format."""
diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py
index 51a97b20f8..1d9911465b 100644
--- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py
+++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py
@@ -5,6 +5,8 @@ automatic cleanup, backup and restore.
Supports complete lifecycle management for knowledge base files.
"""
+from __future__ import annotations
+
import json
import logging
import operator
@@ -48,7 +50,7 @@ class FileMetadata:
return data
@classmethod
- def from_dict(cls, data: dict) -> "FileMetadata":
+ def from_dict(cls, data: dict) -> FileMetadata:
"""Create instance from dictionary"""
data = data.copy()
data["created_at"] = datetime.fromisoformat(data["created_at"])
diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py
index ea5d982efc..cf092c6973 100644
--- a/api/extensions/storage/tencent_cos_storage.py
+++ b/api/extensions/storage/tencent_cos_storage.py
@@ -13,12 +13,20 @@ class TencentCosStorage(BaseStorage):
super().__init__()
self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME
- config = CosConfig(
- Region=dify_config.TENCENT_COS_REGION,
- SecretId=dify_config.TENCENT_COS_SECRET_ID,
- SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
- Scheme=dify_config.TENCENT_COS_SCHEME,
- )
+ if dify_config.TENCENT_COS_CUSTOM_DOMAIN:
+ config = CosConfig(
+ Domain=dify_config.TENCENT_COS_CUSTOM_DOMAIN,
+ SecretId=dify_config.TENCENT_COS_SECRET_ID,
+ SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
+ Scheme=dify_config.TENCENT_COS_SCHEME,
+ )
+ else:
+ config = CosConfig(
+ Region=dify_config.TENCENT_COS_REGION,
+ SecretId=dify_config.TENCENT_COS_SECRET_ID,
+ SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
+ Scheme=dify_config.TENCENT_COS_SCHEME,
+ )
self.client = CosS3Client(config)
def save(self, filename, data):
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py
index bd71f18af2..0be836c8f1 100644
--- a/api/factories/file_factory.py
+++ b/api/factories/file_factory.py
@@ -115,7 +115,18 @@ def build_from_mappings(
# TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
# Implement batch processing to reduce database load when handling multiple files.
# Filter out None/empty mappings to avoid errors
- valid_mappings = [m for m in mappings if m and m.get("transfer_method")]
+ def is_valid_mapping(m: Mapping[str, Any]) -> bool:
+ if not m or not m.get("transfer_method"):
+ return False
+ # For REMOTE_URL transfer method, ensure url or remote_url is provided and not None
+ transfer_method = m.get("transfer_method")
+ if transfer_method == FileTransferMethod.REMOTE_URL:
+ url = m.get("url") or m.get("remote_url")
+ if not url:
+ return False
+ return True
+
+ valid_mappings = [m for m in mappings if is_valid_mapping(m)]
files = [
build_from_mapping(
mapping=mapping,
diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py
index 494194369a..3f030ae127 100644
--- a/api/factories/variable_factory.py
+++ b/api/factories/variable_factory.py
@@ -38,7 +38,7 @@ from core.variables.variables import (
ObjectVariable,
SecretVariable,
StringVariable,
- Variable,
+ VariableBase,
)
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
@@ -72,25 +72,25 @@ SEGMENT_TO_VARIABLE_MAP = {
}
-def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
+def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
-def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
+def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
-def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
+def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("variable"):
raise VariableError("missing variable")
return mapping["variable"]
-def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
+def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> VariableBase:
"""
This factory function is used to create the environment variable or the conversation variable,
not support the File type.
@@ -100,7 +100,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
- result: Variable
+ result: VariableBase
match value_type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
@@ -134,7 +134,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
if not result.selector:
result = result.model_copy(update={"selector": selector})
- return cast(Variable, result)
+ return cast(VariableBase, result)
def build_segment(value: Any, /) -> Segment:
@@ -285,8 +285,8 @@ def segment_to_variable(
id: str | None = None,
name: str | None = None,
description: str = "",
-) -> Variable:
- if isinstance(segment, Variable):
+) -> VariableBase:
+ if isinstance(segment, VariableBase):
return segment
name = name or selector[-1]
id = id or str(uuid4())
@@ -297,7 +297,7 @@ def segment_to_variable(
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
return cast(
- Variable,
+ VariableBase,
variable_class(
id=id,
name=name,
diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py
index 38835d5ac7..e69306dcb2 100644
--- a/api/fields/annotation_fields.py
+++ b/api/fields/annotation_fields.py
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields
from libs.helper import TimestampField
@@ -12,7 +12,7 @@ annotation_fields = {
}
-def build_annotation_model(api_or_ns: Api | Namespace):
+def build_annotation_model(api_or_ns: Namespace):
"""Build the annotation model for the API or Namespace."""
return api_or_ns.model("Annotation", annotation_fields)
diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py
index 0bcc797ec7..cda46f2339 100644
--- a/api/fields/conversation_fields.py
+++ b/api/fields/conversation_fields.py
@@ -1,241 +1,339 @@
-from flask_restx import Api, Namespace, fields
+from __future__ import annotations
-from fields.member_fields import simple_account_fields
-from libs.helper import TimestampField
+from datetime import datetime
+from typing import Any, TypeAlias
-from .raws import FilesContainedField
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+
+from core.file import File
+
+JSONValue: TypeAlias = Any
-class MessageTextField(fields.Raw):
- def format(self, value):
- return value[0]["text"] if value else ""
+class ResponseModel(BaseModel):
+ model_config = ConfigDict(
+ from_attributes=True,
+ extra="ignore",
+ populate_by_name=True,
+ serialize_by_alias=True,
+ protected_namespaces=(),
+ )
-feedback_fields = {
- "rating": fields.String,
- "content": fields.String,
- "from_source": fields.String,
- "from_end_user_id": fields.String,
- "from_account": fields.Nested(simple_account_fields, allow_null=True),
-}
+class MessageFile(ResponseModel):
+ id: str
+ filename: str
+ type: str
+ url: str | None = None
+ mime_type: str | None = None
+ size: int | None = None
+ transfer_method: str
+ belongs_to: str | None = None
+ upload_file_id: str | None = None
-annotation_fields = {
- "id": fields.String,
- "question": fields.String,
- "content": fields.String,
- "account": fields.Nested(simple_account_fields, allow_null=True),
- "created_at": TimestampField,
-}
-
-annotation_hit_history_fields = {
- "annotation_id": fields.String(attribute="id"),
- "annotation_create_account": fields.Nested(simple_account_fields, allow_null=True),
- "created_at": TimestampField,
-}
-
-message_file_fields = {
- "id": fields.String,
- "filename": fields.String,
- "type": fields.String,
- "url": fields.String,
- "mime_type": fields.String,
- "size": fields.Integer,
- "transfer_method": fields.String,
- "belongs_to": fields.String(default="user"),
- "upload_file_id": fields.String(default=None),
-}
+ @field_validator("transfer_method", mode="before")
+ @classmethod
+ def _normalize_transfer_method(cls, value: object) -> str:
+ if isinstance(value, str):
+ return value
+ return str(value)
-def build_message_file_model(api_or_ns: Api | Namespace):
- """Build the message file fields for the API or Namespace."""
- return api_or_ns.model("MessageFile", message_file_fields)
+class SimpleConversation(ResponseModel):
+ id: str
+ name: str
+ inputs: dict[str, JSONValue]
+ status: str
+ introduction: str | None = None
+ created_at: int | None = None
+ updated_at: int | None = None
+
+ @field_validator("inputs", mode="before")
+ @classmethod
+ def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
+ return format_files_contained(value)
+
+ @field_validator("created_at", "updated_at", mode="before")
+ @classmethod
+ def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return to_timestamp(value)
+ return value
-agent_thought_fields = {
- "id": fields.String,
- "chain_id": fields.String,
- "message_id": fields.String,
- "position": fields.Integer,
- "thought": fields.String,
- "tool": fields.String,
- "tool_labels": fields.Raw,
- "tool_input": fields.String,
- "created_at": TimestampField,
- "observation": fields.String,
- "files": fields.List(fields.String),
-}
-
-message_detail_fields = {
- "id": fields.String,
- "conversation_id": fields.String,
- "inputs": FilesContainedField,
- "query": fields.String,
- "message": fields.Raw,
- "message_tokens": fields.Integer,
- "answer": fields.String(attribute="re_sign_file_url_answer"),
- "answer_tokens": fields.Integer,
- "provider_response_latency": fields.Float,
- "from_source": fields.String,
- "from_end_user_id": fields.String,
- "from_account_id": fields.String,
- "feedbacks": fields.List(fields.Nested(feedback_fields)),
- "workflow_run_id": fields.String,
- "annotation": fields.Nested(annotation_fields, allow_null=True),
- "annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True),
- "created_at": TimestampField,
- "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
- "message_files": fields.List(fields.Nested(message_file_fields)),
- "metadata": fields.Raw(attribute="message_metadata_dict"),
- "status": fields.String,
- "error": fields.String,
- "parent_message_id": fields.String,
-}
-
-feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}
-status_count_fields = {
- "success": fields.Integer,
- "failed": fields.Integer,
- "partial_success": fields.Integer,
- "paused": fields.Integer,
-}
-model_config_fields = {
- "opening_statement": fields.String,
- "suggested_questions": fields.Raw,
- "model": fields.Raw,
- "user_input_form": fields.Raw,
- "pre_prompt": fields.String,
- "agent_mode": fields.Raw,
-}
-
-simple_model_config_fields = {
- "model": fields.Raw(attribute="model_dict"),
- "pre_prompt": fields.String,
-}
-
-simple_message_detail_fields = {
- "inputs": FilesContainedField,
- "query": fields.String,
- "message": MessageTextField,
- "answer": fields.String,
-}
-
-conversation_fields = {
- "id": fields.String,
- "status": fields.String,
- "from_source": fields.String,
- "from_end_user_id": fields.String,
- "from_end_user_session_id": fields.String(),
- "from_account_id": fields.String,
- "from_account_name": fields.String,
- "read_at": TimestampField,
- "created_at": TimestampField,
- "updated_at": TimestampField,
- "annotation": fields.Nested(annotation_fields, allow_null=True),
- "model_config": fields.Nested(simple_model_config_fields),
- "user_feedback_stats": fields.Nested(feedback_stat_fields),
- "admin_feedback_stats": fields.Nested(feedback_stat_fields),
- "message": fields.Nested(simple_message_detail_fields, attribute="first_message"),
-}
-
-conversation_pagination_fields = {
- "page": fields.Integer,
- "limit": fields.Integer(attribute="per_page"),
- "total": fields.Integer,
- "has_more": fields.Boolean(attribute="has_next"),
- "data": fields.List(fields.Nested(conversation_fields), attribute="items"),
-}
-
-conversation_message_detail_fields = {
- "id": fields.String,
- "status": fields.String,
- "from_source": fields.String,
- "from_end_user_id": fields.String,
- "from_account_id": fields.String,
- "created_at": TimestampField,
- "model_config": fields.Nested(model_config_fields),
- "message": fields.Nested(message_detail_fields, attribute="first_message"),
-}
-
-conversation_with_summary_fields = {
- "id": fields.String,
- "status": fields.String,
- "from_source": fields.String,
- "from_end_user_id": fields.String,
- "from_end_user_session_id": fields.String,
- "from_account_id": fields.String,
- "from_account_name": fields.String,
- "name": fields.String,
- "summary": fields.String(attribute="summary_or_query"),
- "read_at": TimestampField,
- "created_at": TimestampField,
- "updated_at": TimestampField,
- "annotated": fields.Boolean,
- "model_config": fields.Nested(simple_model_config_fields),
- "message_count": fields.Integer,
- "user_feedback_stats": fields.Nested(feedback_stat_fields),
- "admin_feedback_stats": fields.Nested(feedback_stat_fields),
- "status_count": fields.Nested(status_count_fields),
-}
-
-conversation_with_summary_pagination_fields = {
- "page": fields.Integer,
- "limit": fields.Integer(attribute="per_page"),
- "total": fields.Integer,
- "has_more": fields.Boolean(attribute="has_next"),
- "data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"),
-}
-
-conversation_detail_fields = {
- "id": fields.String,
- "status": fields.String,
- "from_source": fields.String,
- "from_end_user_id": fields.String,
- "from_account_id": fields.String,
- "created_at": TimestampField,
- "updated_at": TimestampField,
- "annotated": fields.Boolean,
- "introduction": fields.String,
- "model_config": fields.Nested(model_config_fields),
- "message_count": fields.Integer,
- "user_feedback_stats": fields.Nested(feedback_stat_fields),
- "admin_feedback_stats": fields.Nested(feedback_stat_fields),
-}
-
-simple_conversation_fields = {
- "id": fields.String,
- "name": fields.String,
- "inputs": FilesContainedField,
- "status": fields.String,
- "introduction": fields.String,
- "created_at": TimestampField,
- "updated_at": TimestampField,
-}
-
-conversation_delete_fields = {
- "result": fields.String,
-}
-
-conversation_infinite_scroll_pagination_fields = {
- "limit": fields.Integer,
- "has_more": fields.Boolean,
- "data": fields.List(fields.Nested(simple_conversation_fields)),
-}
+class ConversationInfiniteScrollPagination(ResponseModel):
+ limit: int
+ has_more: bool
+ data: list[SimpleConversation]
-def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
- """Build the conversation infinite scroll pagination model for the API or Namespace."""
- simple_conversation_model = build_simple_conversation_model(api_or_ns)
-
- copied_fields = conversation_infinite_scroll_pagination_fields.copy()
- copied_fields["data"] = fields.List(fields.Nested(simple_conversation_model))
- return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)
+class ConversationDelete(ResponseModel):
+ result: str
-def build_conversation_delete_model(api_or_ns: Api | Namespace):
- """Build the conversation delete model for the API or Namespace."""
- return api_or_ns.model("ConversationDelete", conversation_delete_fields)
+class ResultResponse(ResponseModel):
+ result: str
-def build_simple_conversation_model(api_or_ns: Api | Namespace):
- """Build the simple conversation model for the API or Namespace."""
- return api_or_ns.model("SimpleConversation", simple_conversation_fields)
+class SimpleAccount(ResponseModel):
+ id: str
+ name: str
+ email: str
+
+
+class Feedback(ResponseModel):
+ rating: str
+ content: str | None = None
+ from_source: str
+ from_end_user_id: str | None = None
+ from_account: SimpleAccount | None = None
+
+
+class Annotation(ResponseModel):
+ id: str
+ question: str | None = None
+ content: str
+ account: SimpleAccount | None = None
+ created_at: int | None = None
+
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return to_timestamp(value)
+ return value
+
+
+class AnnotationHitHistory(ResponseModel):
+ annotation_id: str
+ annotation_create_account: SimpleAccount | None = None
+ created_at: int | None = None
+
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return to_timestamp(value)
+ return value
+
+
+class AgentThought(ResponseModel):
+ id: str
+ chain_id: str | None = None
+ message_chain_id: str | None = Field(default=None, exclude=True, validation_alias="message_chain_id")
+ message_id: str
+ position: int
+ thought: str | None = None
+ tool: str | None = None
+ tool_labels: JSONValue
+ tool_input: str | None = None
+ created_at: int | None = None
+ observation: str | None = None
+ files: list[str]
+
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return to_timestamp(value)
+ return value
+
+ @model_validator(mode="after")
+ def _fallback_chain_id(self):
+ if self.chain_id is None and self.message_chain_id:
+ self.chain_id = self.message_chain_id
+ return self
+
+
+class MessageDetail(ResponseModel):
+ id: str
+ conversation_id: str
+ inputs: dict[str, JSONValue]
+ query: str
+ message: JSONValue
+ message_tokens: int
+ answer: str
+ answer_tokens: int
+ provider_response_latency: float
+ from_source: str
+ from_end_user_id: str | None = None
+ from_account_id: str | None = None
+ feedbacks: list[Feedback]
+ workflow_run_id: str | None = None
+ annotation: Annotation | None = None
+ annotation_hit_history: AnnotationHitHistory | None = None
+ created_at: int | None = None
+ agent_thoughts: list[AgentThought]
+ message_files: list[MessageFile]
+ metadata: JSONValue
+ status: str
+ error: str | None = None
+ parent_message_id: str | None = None
+
+ @field_validator("inputs", mode="before")
+ @classmethod
+ def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
+ return format_files_contained(value)
+
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return to_timestamp(value)
+ return value
+
+
+class FeedbackStat(ResponseModel):
+ like: int
+ dislike: int
+
+
+class StatusCount(ResponseModel):
+ success: int
+ failed: int
+ partial_success: int
+ paused: int
+
+
+class ModelConfig(ResponseModel):
+ opening_statement: str | None = None
+ suggested_questions: JSONValue | None = None
+ model: JSONValue | None = None
+ user_input_form: JSONValue | None = None
+ pre_prompt: str | None = None
+ agent_mode: JSONValue | None = None
+
+
+class SimpleModelConfig(ResponseModel):
+ model: JSONValue | None = None
+ pre_prompt: str | None = None
+
+
+class SimpleMessageDetail(ResponseModel):
+ inputs: dict[str, JSONValue]
+ query: str
+ message: str
+ answer: str
+
+ @field_validator("inputs", mode="before")
+ @classmethod
+ def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
+ return format_files_contained(value)
+
+
+class Conversation(ResponseModel):
+ id: str
+ status: str
+ from_source: str
+ from_end_user_id: str | None = None
+ from_end_user_session_id: str | None = None
+ from_account_id: str | None = None
+ from_account_name: str | None = None
+ read_at: int | None = None
+ created_at: int | None = None
+ updated_at: int | None = None
+ annotation: Annotation | None = None
+ model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config")
+ user_feedback_stats: FeedbackStat | None = None
+ admin_feedback_stats: FeedbackStat | None = None
+ message: SimpleMessageDetail | None = None
+
+
+class ConversationPagination(ResponseModel):
+ page: int
+ limit: int
+ total: int
+ has_more: bool
+ data: list[Conversation]
+
+
+class ConversationMessageDetail(ResponseModel):
+ id: str
+ status: str
+ from_source: str
+ from_end_user_id: str | None = None
+ from_account_id: str | None = None
+ created_at: int | None = None
+ model_config_: ModelConfig | None = Field(default=None, alias="model_config")
+ message: MessageDetail | None = None
+
+
+class ConversationWithSummary(ResponseModel):
+ id: str
+ status: str
+ from_source: str
+ from_end_user_id: str | None = None
+ from_end_user_session_id: str | None = None
+ from_account_id: str | None = None
+ from_account_name: str | None = None
+ name: str
+ summary: str
+ read_at: int | None = None
+ created_at: int | None = None
+ updated_at: int | None = None
+ annotated: bool
+ model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config")
+ message_count: int
+ user_feedback_stats: FeedbackStat | None = None
+ admin_feedback_stats: FeedbackStat | None = None
+ status_count: StatusCount | None = None
+
+
+class ConversationWithSummaryPagination(ResponseModel):
+ page: int
+ limit: int
+ total: int
+ has_more: bool
+ data: list[ConversationWithSummary]
+
+
+class ConversationDetail(ResponseModel):
+ id: str
+ status: str
+ from_source: str
+ from_end_user_id: str | None = None
+ from_account_id: str | None = None
+ created_at: int | None = None
+ updated_at: int | None = None
+ annotated: bool
+ introduction: str | None = None
+ model_config_: ModelConfig | None = Field(default=None, alias="model_config")
+ message_count: int
+ user_feedback_stats: FeedbackStat | None = None
+ admin_feedback_stats: FeedbackStat | None = None
+
+
+def to_timestamp(value: datetime | None) -> int | None:
+ if value is None:
+ return None
+ return int(value.timestamp())
+
+
+def format_files_contained(value: JSONValue) -> JSONValue:
+ if isinstance(value, File):
+ return value.model_dump()
+ if isinstance(value, dict):
+ return {k: format_files_contained(v) for k, v in value.items()}
+ if isinstance(value, list):
+ return [format_files_contained(v) for v in value]
+ return value
+
+
+def message_text(value: JSONValue) -> str:
+ if isinstance(value, list) and value:
+ first = value[0]
+ if isinstance(first, dict):
+ text = first.get("text")
+ if isinstance(text, str):
+ return text
+ return ""
+
+
+def extract_model_config(value: object | None) -> dict[str, JSONValue]:
+ if value is None:
+ return {}
+ if isinstance(value, dict):
+ return value
+ if hasattr(value, "to_dict"):
+ return value.to_dict()
+ return {}
diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py
index 7d5e311591..c55014a368 100644
--- a/api/fields/conversation_variable_fields.py
+++ b/api/fields/conversation_variable_fields.py
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields
from libs.helper import TimestampField
@@ -29,12 +29,12 @@ conversation_variable_infinite_scroll_pagination_fields = {
}
-def build_conversation_variable_model(api_or_ns: Api | Namespace):
+def build_conversation_variable_model(api_or_ns: Namespace):
"""Build the conversation variable model for the API or Namespace."""
return api_or_ns.model("ConversationVariable", conversation_variable_fields)
-def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
+def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Namespace):
"""Build the conversation variable infinite scroll pagination model for the API or Namespace."""
# Build the nested variable model first
conversation_variable_model = build_conversation_variable_model(api_or_ns)
diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py
index ea43e3b5fd..5389b0213a 100644
--- a/api/fields/end_user_fields.py
+++ b/api/fields/end_user_fields.py
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields
simple_end_user_fields = {
"id": fields.String,
@@ -8,5 +8,5 @@ simple_end_user_fields = {
}
-def build_simple_end_user_model(api_or_ns: Api | Namespace):
+def build_simple_end_user_model(api_or_ns: Namespace):
return api_or_ns.model("SimpleEndUser", simple_end_user_fields)
diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py
index a707500445..913fb675f9 100644
--- a/api/fields/file_fields.py
+++ b/api/fields/file_fields.py
@@ -1,93 +1,85 @@
-from flask_restx import Api, Namespace, fields
+from __future__ import annotations
-from libs.helper import TimestampField
+from datetime import datetime
-upload_config_fields = {
- "file_size_limit": fields.Integer,
- "batch_count_limit": fields.Integer,
- "image_file_size_limit": fields.Integer,
- "video_file_size_limit": fields.Integer,
- "audio_file_size_limit": fields.Integer,
- "workflow_file_upload_limit": fields.Integer,
- "image_file_batch_limit": fields.Integer,
- "single_chunk_attachment_limit": fields.Integer,
-}
+from pydantic import BaseModel, ConfigDict, field_validator
-def build_upload_config_model(api_or_ns: Api | Namespace):
- """Build the upload config model for the API or Namespace.
-
- Args:
- api_or_ns: Flask-RestX Api or Namespace instance
-
- Returns:
- The registered model
- """
- return api_or_ns.model("UploadConfig", upload_config_fields)
+class ResponseModel(BaseModel):
+ model_config = ConfigDict(
+ from_attributes=True,
+ extra="ignore",
+ populate_by_name=True,
+ serialize_by_alias=True,
+ protected_namespaces=(),
+ )
-file_fields = {
- "id": fields.String,
- "name": fields.String,
- "size": fields.Integer,
- "extension": fields.String,
- "mime_type": fields.String,
- "created_by": fields.String,
- "created_at": TimestampField,
- "preview_url": fields.String,
- "source_url": fields.String,
-}
+def _to_timestamp(value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return int(value.timestamp())
+ return value
-def build_file_model(api_or_ns: Api | Namespace):
- """Build the file model for the API or Namespace.
-
- Args:
- api_or_ns: Flask-RestX Api or Namespace instance
-
- Returns:
- The registered model
- """
- return api_or_ns.model("File", file_fields)
+class UploadConfig(ResponseModel):
+ file_size_limit: int
+ batch_count_limit: int
+ file_upload_limit: int | None = None
+ image_file_size_limit: int
+ video_file_size_limit: int
+ audio_file_size_limit: int
+ workflow_file_upload_limit: int
+ image_file_batch_limit: int
+ single_chunk_attachment_limit: int
+ attachment_image_file_size_limit: int | None = None
-remote_file_info_fields = {
- "file_type": fields.String(attribute="file_type"),
- "file_length": fields.Integer(attribute="file_length"),
-}
+class FileResponse(ResponseModel):
+ id: str
+ name: str
+ size: int
+ extension: str | None = None
+ mime_type: str | None = None
+ created_by: str | None = None
+ created_at: int | None = None
+ preview_url: str | None = None
+ source_url: str | None = None
+ original_url: str | None = None
+ user_id: str | None = None
+ tenant_id: str | None = None
+ conversation_id: str | None = None
+ file_key: str | None = None
+
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
-def build_remote_file_info_model(api_or_ns: Api | Namespace):
- """Build the remote file info model for the API or Namespace.
-
- Args:
- api_or_ns: Flask-RestX Api or Namespace instance
-
- Returns:
- The registered model
- """
- return api_or_ns.model("RemoteFileInfo", remote_file_info_fields)
+class RemoteFileInfo(ResponseModel):
+ file_type: str
+ file_length: int
-file_fields_with_signed_url = {
- "id": fields.String,
- "name": fields.String,
- "size": fields.Integer,
- "extension": fields.String,
- "url": fields.String,
- "mime_type": fields.String,
- "created_by": fields.String,
- "created_at": TimestampField,
-}
+class FileWithSignedUrl(ResponseModel):
+ id: str
+ name: str
+ size: int
+ extension: str | None = None
+ url: str | None = None
+ mime_type: str | None = None
+ created_by: str | None = None
+ created_at: int | None = None
+
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
-def build_file_with_signed_url_model(api_or_ns: Api | Namespace):
- """Build the file with signed URL model for the API or Namespace.
-
- Args:
- api_or_ns: Flask-RestX Api or Namespace instance
-
- Returns:
- The registered model
- """
- return api_or_ns.model("FileWithSignedUrl", file_fields_with_signed_url)
+__all__ = [
+ "FileResponse",
+ "FileWithSignedUrl",
+ "RemoteFileInfo",
+ "UploadConfig",
+]
diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py
index 08e38a6931..25160927e6 100644
--- a/api/fields/member_fields.py
+++ b/api/fields/member_fields.py
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields
from libs.helper import AvatarUrlField, TimestampField
@@ -9,7 +9,7 @@ simple_account_fields = {
}
-def build_simple_account_model(api_or_ns: Api | Namespace):
+def build_simple_account_model(api_or_ns: Namespace):
return api_or_ns.model("SimpleAccount", simple_account_fields)
diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py
index 552f0b598f..53911a18dd 100644
--- a/api/fields/message_fields.py
+++ b/api/fields/message_fields.py
@@ -1,78 +1,140 @@
-from flask_restx import Api, Namespace, fields
+from __future__ import annotations
-from fields.conversation_fields import message_file_fields
-from libs.helper import TimestampField
+from datetime import datetime
+from typing import TypeAlias
+from uuid import uuid4
-from .raws import FilesContainedField
+from pydantic import BaseModel, ConfigDict, Field, field_validator
-feedback_fields = {
- "rating": fields.String,
-}
+from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
+from core.file import File
+from fields.conversation_fields import AgentThought, JSONValue, MessageFile
+
+JSONValueType: TypeAlias = JSONValue
-def build_feedback_model(api_or_ns: Api | Namespace):
- """Build the feedback model for the API or Namespace."""
- return api_or_ns.model("Feedback", feedback_fields)
+class ResponseModel(BaseModel):
+ model_config = ConfigDict(from_attributes=True, extra="ignore")
-agent_thought_fields = {
- "id": fields.String,
- "chain_id": fields.String,
- "message_id": fields.String,
- "position": fields.Integer,
- "thought": fields.String,
- "tool": fields.String,
- "tool_labels": fields.Raw,
- "tool_input": fields.String,
- "created_at": TimestampField,
- "observation": fields.String,
- "files": fields.List(fields.String),
-}
+class SimpleFeedback(ResponseModel):
+ rating: str | None = None
-def build_agent_thought_model(api_or_ns: Api | Namespace):
- """Build the agent thought model for the API or Namespace."""
- return api_or_ns.model("AgentThought", agent_thought_fields)
+class RetrieverResource(ResponseModel):
+ id: str = Field(default_factory=lambda: str(uuid4()))
+ message_id: str = Field(default_factory=lambda: str(uuid4()))
+ position: int
+ dataset_id: str | None = None
+ dataset_name: str | None = None
+ document_id: str | None = None
+ document_name: str | None = None
+ data_source_type: str | None = None
+ segment_id: str | None = None
+ score: float | None = None
+ hit_count: int | None = None
+ word_count: int | None = None
+ segment_position: int | None = None
+ index_node_hash: str | None = None
+ content: str | None = None
+ created_at: int | None = None
+
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return to_timestamp(value)
+ return value
-retriever_resource_fields = {
- "id": fields.String,
- "message_id": fields.String,
- "position": fields.Integer,
- "dataset_id": fields.String,
- "dataset_name": fields.String,
- "document_id": fields.String,
- "document_name": fields.String,
- "data_source_type": fields.String,
- "segment_id": fields.String,
- "score": fields.Float,
- "hit_count": fields.Integer,
- "word_count": fields.Integer,
- "segment_position": fields.Integer,
- "index_node_hash": fields.String,
- "content": fields.String,
- "created_at": TimestampField,
-}
+class MessageListItem(ResponseModel):
+ id: str
+ conversation_id: str
+ parent_message_id: str | None = None
+ inputs: dict[str, JSONValueType]
+ query: str
+ answer: str = Field(validation_alias="re_sign_file_url_answer")
+ feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
+ retriever_resources: list[RetrieverResource]
+ created_at: int | None = None
+ agent_thoughts: list[AgentThought]
+ message_files: list[MessageFile]
+ status: str
+ error: str | None = None
+ extra_contents: list[ExecutionExtraContentDomainModel]
-message_fields = {
- "id": fields.String,
- "conversation_id": fields.String,
- "parent_message_id": fields.String,
- "inputs": FilesContainedField,
- "query": fields.String,
- "answer": fields.String(attribute="re_sign_file_url_answer"),
- "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
- "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
- "extra_contents": fields.List(cls_or_instance=fields.Raw),
- "created_at": TimestampField,
- "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
- "message_files": fields.List(fields.Nested(message_file_fields)),
- "status": fields.String,
- "error": fields.String,
-}
+ @field_validator("inputs", mode="before")
+ @classmethod
+ def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
+ return format_files_contained(value)
-message_infinite_scroll_pagination_fields = {
- "limit": fields.Integer,
- "has_more": fields.Boolean,
- "data": fields.List(fields.Nested(message_fields)),
-}
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return to_timestamp(value)
+ return value
+
+
+class WebMessageListItem(MessageListItem):
+ metadata: JSONValueType | None = Field(default=None, validation_alias="message_metadata_dict")
+
+
+class MessageInfiniteScrollPagination(ResponseModel):
+ limit: int
+ has_more: bool
+ data: list[MessageListItem]
+
+
+class WebMessageInfiniteScrollPagination(ResponseModel):
+ limit: int
+ has_more: bool
+ data: list[WebMessageListItem]
+
+
+class SavedMessageItem(ResponseModel):
+ id: str
+ inputs: dict[str, JSONValueType]
+ query: str
+ answer: str
+ message_files: list[MessageFile]
+ feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
+ created_at: int | None = None
+
+ @field_validator("inputs", mode="before")
+ @classmethod
+ def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
+ return format_files_contained(value)
+
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return to_timestamp(value)
+ return value
+
+
+class SavedMessageInfiniteScrollPagination(ResponseModel):
+ limit: int
+ has_more: bool
+ data: list[SavedMessageItem]
+
+
+class SuggestedQuestionsResponse(ResponseModel):
+ data: list[str]
+
+
+def to_timestamp(value: datetime | None) -> int | None:
+ if value is None:
+ return None
+ return int(value.timestamp())
+
+
+def format_files_contained(value: JSONValueType) -> JSONValueType:
+ if isinstance(value, File):
+ return value.model_dump()
+ if isinstance(value, dict):
+ return {k: format_files_contained(v) for k, v in value.items()}
+ if isinstance(value, list):
+ return [format_files_contained(v) for v in value]
+ return value
diff --git a/api/fields/rag_pipeline_fields.py b/api/fields/rag_pipeline_fields.py
index f9e858c68b..97c02e7085 100644
--- a/api/fields/rag_pipeline_fields.py
+++ b/api/fields/rag_pipeline_fields.py
@@ -1,4 +1,4 @@
-from flask_restx import fields # type: ignore
+from flask_restx import fields
from fields.workflow_fields import workflow_partial_fields
from libs.helper import AppIconUrlField, TimestampField
diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py
index d5b7c86a04..e359a4408c 100644
--- a/api/fields/tag_fields.py
+++ b/api/fields/tag_fields.py
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields
dataset_tag_fields = {
"id": fields.String,
@@ -8,5 +8,5 @@ dataset_tag_fields = {
}
-def build_dataset_tag_fields(api_or_ns: Api | Namespace):
+def build_dataset_tag_fields(api_or_ns: Namespace):
return api_or_ns.model("DataSetTag", dataset_tag_fields)
diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py
index 4cbdf6f0ca..ae70356322 100644
--- a/api/fields/workflow_app_log_fields.py
+++ b/api/fields/workflow_app_log_fields.py
@@ -1,8 +1,13 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields
from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
from fields.member_fields import build_simple_account_model, simple_account_fields
-from fields.workflow_run_fields import build_workflow_run_for_log_model, workflow_run_for_log_fields
+from fields.workflow_run_fields import (
+ build_workflow_run_for_archived_log_model,
+ build_workflow_run_for_log_model,
+ workflow_run_for_archived_log_fields,
+ workflow_run_for_log_fields,
+)
from libs.helper import TimestampField
workflow_app_log_partial_fields = {
@@ -17,7 +22,7 @@ workflow_app_log_partial_fields = {
}
-def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace):
+def build_workflow_app_log_partial_model(api_or_ns: Namespace):
"""Build the workflow app log partial model for the API or Namespace."""
workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
simple_account_model = build_simple_account_model(api_or_ns)
@@ -34,6 +39,33 @@ def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace):
return api_or_ns.model("WorkflowAppLogPartial", copied_fields)
+workflow_archived_log_partial_fields = {
+ "id": fields.String,
+ "workflow_run": fields.Nested(workflow_run_for_archived_log_fields, allow_null=True),
+ "trigger_metadata": fields.Raw,
+ "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
+ "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
+ "created_at": TimestampField,
+}
+
+
+def build_workflow_archived_log_partial_model(api_or_ns: Namespace):
+ """Build the workflow archived log partial model for the API or Namespace."""
+ workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns)
+ simple_account_model = build_simple_account_model(api_or_ns)
+ simple_end_user_model = build_simple_end_user_model(api_or_ns)
+
+ copied_fields = workflow_archived_log_partial_fields.copy()
+ copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True)
+ copied_fields["created_by_account"] = fields.Nested(
+ simple_account_model, attribute="created_by_account", allow_null=True
+ )
+ copied_fields["created_by_end_user"] = fields.Nested(
+ simple_end_user_model, attribute="created_by_end_user", allow_null=True
+ )
+ return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields)
+
+
workflow_app_log_pagination_fields = {
"page": fields.Integer,
"limit": fields.Integer,
@@ -43,7 +75,7 @@ workflow_app_log_pagination_fields = {
}
-def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace):
+def build_workflow_app_log_pagination_model(api_or_ns: Namespace):
"""Build the workflow app log pagination model for the API or Namespace."""
# Build the nested partial model first
workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns)
@@ -51,3 +83,21 @@ def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace):
copied_fields = workflow_app_log_pagination_fields.copy()
copied_fields["data"] = fields.List(fields.Nested(workflow_app_log_partial_model))
return api_or_ns.model("WorkflowAppLogPagination", copied_fields)
+
+
+workflow_archived_log_pagination_fields = {
+ "page": fields.Integer,
+ "limit": fields.Integer,
+ "total": fields.Integer,
+ "has_more": fields.Boolean,
+ "data": fields.List(fields.Nested(workflow_archived_log_partial_fields)),
+}
+
+
+def build_workflow_archived_log_pagination_model(api_or_ns: Namespace):
+ """Build the workflow archived log pagination model for the API or Namespace."""
+ workflow_archived_log_partial_model = build_workflow_archived_log_partial_model(api_or_ns)
+
+ copied_fields = workflow_archived_log_pagination_fields.copy()
+ copied_fields["data"] = fields.List(fields.Nested(workflow_archived_log_partial_model))
+ return api_or_ns.model("WorkflowArchivedLogPagination", copied_fields)
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index d037b0c442..2755f77f61 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -1,7 +1,7 @@
from flask_restx import fields
from core.helper import encrypter
-from core.variables import SecretVariable, SegmentType, Variable
+from core.variables import SecretVariable, SegmentType, VariableBase
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
@@ -21,7 +21,7 @@ class EnvironmentVariableField(fields.Raw):
"value_type": value.value_type.value,
"description": value.description,
}
- if isinstance(value, Variable):
+ if isinstance(value, VariableBase):
return {
"id": value.id,
"name": value.name,
diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py
index 821ce62ecc..35bb442c59 100644
--- a/api/fields/workflow_run_fields.py
+++ b/api/fields/workflow_run_fields.py
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
@@ -19,10 +19,23 @@ workflow_run_for_log_fields = {
}
-def build_workflow_run_for_log_model(api_or_ns: Api | Namespace):
+def build_workflow_run_for_log_model(api_or_ns: Namespace):
return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields)
+workflow_run_for_archived_log_fields = {
+ "id": fields.String,
+ "status": fields.String,
+ "triggered_from": fields.String,
+ "elapsed_time": fields.Float,
+ "total_tokens": fields.Integer,
+}
+
+
+def build_workflow_run_for_archived_log_model(api_or_ns: Namespace):
+ return api_or_ns.model("WorkflowRunForArchivedLog", workflow_run_for_archived_log_fields)
+
+
workflow_run_for_list_fields = {
"id": fields.String,
"version": fields.String,
diff --git a/api/libs/archive_storage.py b/api/libs/archive_storage.py
new file mode 100644
index 0000000000..66b57ac661
--- /dev/null
+++ b/api/libs/archive_storage.py
@@ -0,0 +1,353 @@
+"""
+Archive Storage Client for S3-compatible storage.
+
+This module provides a dedicated storage client for archiving or exporting logs
+to S3-compatible object storage.
+"""
+
+import base64
+import datetime
+import hashlib
+import logging
+from collections.abc import Generator
+from typing import Any, cast
+
+import boto3
+import orjson
+from botocore.client import Config
+from botocore.exceptions import ClientError
+
+from configs import dify_config
+
+logger = logging.getLogger(__name__)
+
+
+class ArchiveStorageError(Exception):
+ """Base exception for archive storage operations."""
+
+ pass
+
+
+class ArchiveStorageNotConfiguredError(ArchiveStorageError):
+ """Raised when archive storage is not properly configured."""
+
+ pass
+
+
+class ArchiveStorage:
+ """
+ S3-compatible storage client for archiving or exporting.
+
+ This client provides methods for storing and retrieving archived data in JSONL format.
+ """
+
+ def __init__(self, bucket: str):
+ if not dify_config.ARCHIVE_STORAGE_ENABLED:
+ raise ArchiveStorageNotConfiguredError("Archive storage is not enabled")
+
+ if not bucket:
+ raise ArchiveStorageNotConfiguredError("Archive storage bucket is not configured")
+ if not all(
+ [
+ dify_config.ARCHIVE_STORAGE_ENDPOINT,
+ bucket,
+ dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
+ dify_config.ARCHIVE_STORAGE_SECRET_KEY,
+ ]
+ ):
+ raise ArchiveStorageNotConfiguredError(
+ "Archive storage configuration is incomplete. "
+ "Required: ARCHIVE_STORAGE_ENDPOINT, ARCHIVE_STORAGE_ACCESS_KEY, "
+ "ARCHIVE_STORAGE_SECRET_KEY, and a bucket name"
+ )
+
+ self.bucket = bucket
+ self.client = boto3.client(
+ "s3",
+ endpoint_url=dify_config.ARCHIVE_STORAGE_ENDPOINT,
+ aws_access_key_id=dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
+ aws_secret_access_key=dify_config.ARCHIVE_STORAGE_SECRET_KEY,
+ region_name=dify_config.ARCHIVE_STORAGE_REGION,
+ config=Config(
+ s3={"addressing_style": "path"},
+ max_pool_connections=64,
+ ),
+ )
+
+ # Verify bucket accessibility
+ try:
+ self.client.head_bucket(Bucket=self.bucket)
+ except ClientError as e:
+ error_code = e.response.get("Error", {}).get("Code")
+ if error_code == "404":
+ raise ArchiveStorageNotConfiguredError(f"Archive bucket '{self.bucket}' does not exist")
+ elif error_code == "403":
+ raise ArchiveStorageNotConfiguredError(f"Access denied to archive bucket '{self.bucket}'")
+ else:
+ raise ArchiveStorageError(f"Failed to access archive bucket: {e}")
+
+ def put_object(self, key: str, data: bytes) -> str:
+ """
+ Upload an object to the archive storage.
+
+ Args:
+ key: Object key (path) within the bucket
+ data: Binary data to upload
+
+ Returns:
+ MD5 checksum of the uploaded data
+
+ Raises:
+ ArchiveStorageError: If upload fails
+ """
+ checksum = hashlib.md5(data).hexdigest()
+ try:
+ response = self.client.put_object(
+ Bucket=self.bucket,
+ Key=key,
+ Body=data,
+ ContentMD5=self._content_md5(data),
+ )
+ etag = response.get("ETag")
+ if not etag:
+ raise ArchiveStorageError(f"Missing ETag for '{key}'")
+ normalized_etag = etag.strip('"')
+ if normalized_etag != checksum:
+ raise ArchiveStorageError(f"ETag mismatch for '{key}': expected={checksum}, actual={normalized_etag}")
+ logger.debug("Uploaded object: %s (size=%d, checksum=%s)", key, len(data), checksum)
+ return checksum
+ except ClientError as e:
+ raise ArchiveStorageError(f"Failed to upload object '{key}': {e}")
+
+ def get_object(self, key: str) -> bytes:
+ """
+ Download an object from the archive storage.
+
+ Args:
+ key: Object key (path) within the bucket
+
+ Returns:
+ Binary data of the object
+
+ Raises:
+ ArchiveStorageError: If download fails
+ FileNotFoundError: If object does not exist
+ """
+ try:
+ response = self.client.get_object(Bucket=self.bucket, Key=key)
+ return response["Body"].read()
+ except ClientError as e:
+ error_code = e.response.get("Error", {}).get("Code")
+ if error_code == "NoSuchKey":
+ raise FileNotFoundError(f"Archive object not found: {key}")
+ raise ArchiveStorageError(f"Failed to download object '{key}': {e}")
+
+ def get_object_stream(self, key: str) -> Generator[bytes, None, None]:
+ """
+ Stream an object from the archive storage.
+
+ Args:
+ key: Object key (path) within the bucket
+
+ Yields:
+ Chunks of binary data
+
+ Raises:
+ ArchiveStorageError: If download fails
+ FileNotFoundError: If object does not exist
+ """
+ try:
+ response = self.client.get_object(Bucket=self.bucket, Key=key)
+ yield from response["Body"].iter_chunks()
+ except ClientError as e:
+ error_code = e.response.get("Error", {}).get("Code")
+ if error_code == "NoSuchKey":
+ raise FileNotFoundError(f"Archive object not found: {key}")
+ raise ArchiveStorageError(f"Failed to stream object '{key}': {e}")
+
+ def object_exists(self, key: str) -> bool:
+ """
+ Check if an object exists in the archive storage.
+
+ Args:
+ key: Object key (path) within the bucket
+
+ Returns:
+ True if object exists, False otherwise
+ """
+ try:
+ self.client.head_object(Bucket=self.bucket, Key=key)
+ return True
+ except ClientError:
+ return False
+
+ def delete_object(self, key: str) -> None:
+ """
+ Delete an object from the archive storage.
+
+ Args:
+ key: Object key (path) within the bucket
+
+ Raises:
+ ArchiveStorageError: If deletion fails
+ """
+ try:
+ self.client.delete_object(Bucket=self.bucket, Key=key)
+ logger.debug("Deleted object: %s", key)
+ except ClientError as e:
+ raise ArchiveStorageError(f"Failed to delete object '{key}': {e}")
+
+ def generate_presigned_url(self, key: str, expires_in: int = 3600) -> str:
+ """
+ Generate a pre-signed URL for downloading an object.
+
+ Args:
+ key: Object key (path) within the bucket
+ expires_in: URL validity duration in seconds (default: 1 hour)
+
+ Returns:
+ Pre-signed URL string.
+
+ Raises:
+ ArchiveStorageError: If generation fails
+ """
+ try:
+ return self.client.generate_presigned_url(
+ ClientMethod="get_object",
+ Params={"Bucket": self.bucket, "Key": key},
+ ExpiresIn=expires_in,
+ )
+ except ClientError as e:
+ raise ArchiveStorageError(f"Failed to generate pre-signed URL for '{key}': {e}")
+
+ def list_objects(self, prefix: str) -> list[str]:
+ """
+ List objects under a given prefix.
+
+ Args:
+ prefix: Object key prefix to filter by
+
+ Returns:
+ List of object keys matching the prefix
+ """
+ keys = []
+ paginator = self.client.get_paginator("list_objects_v2")
+
+ try:
+ for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
+ for obj in page.get("Contents", []):
+ keys.append(obj["Key"])
+ except ClientError as e:
+ raise ArchiveStorageError(f"Failed to list objects with prefix '{prefix}': {e}")
+
+ return keys
+
+ @staticmethod
+ def _content_md5(data: bytes) -> str:
+ """Calculate base64-encoded MD5 for Content-MD5 header."""
+ return base64.b64encode(hashlib.md5(data).digest()).decode()
+
+ @staticmethod
+ def serialize_to_jsonl(records: list[dict[str, Any]]) -> bytes:
+ """
+ Serialize records to JSONL format.
+
+ Args:
+ records: List of dictionaries to serialize
+
+ Returns:
+ JSONL bytes
+ """
+ lines = []
+ for record in records:
+ serialized = ArchiveStorage._serialize_record(record)
+ lines.append(orjson.dumps(serialized))
+
+ jsonl_content = b"\n".join(lines)
+ if jsonl_content:
+ jsonl_content += b"\n"
+
+ return jsonl_content
+
+ @staticmethod
+ def deserialize_from_jsonl(data: bytes) -> list[dict[str, Any]]:
+ """
+ Deserialize JSONL data to records.
+
+ Args:
+ data: JSONL bytes
+
+ Returns:
+ List of dictionaries
+ """
+ records = []
+
+ for line in data.splitlines():
+ if line:
+ records.append(orjson.loads(line))
+
+ return records
+
+ @staticmethod
+ def _serialize_record(record: dict[str, Any]) -> dict[str, Any]:
+ """Serialize a single record, converting special types."""
+
+ def _serialize(item: Any) -> Any:
+ if isinstance(item, datetime.datetime):
+ return item.isoformat()
+ if isinstance(item, dict):
+ return {key: _serialize(value) for key, value in item.items()}
+ if isinstance(item, list):
+ return [_serialize(value) for value in item]
+ return item
+
+ return cast(dict[str, Any], _serialize(record))
+
+ @staticmethod
+ def compute_checksum(data: bytes) -> str:
+ """Compute MD5 checksum of data."""
+ return hashlib.md5(data).hexdigest()
+
+
+# Singleton instance (lazy initialization)
+_archive_storage: ArchiveStorage | None = None
+_export_storage: ArchiveStorage | None = None
+
+
+def get_archive_storage() -> ArchiveStorage:
+ """
+ Get the archive storage singleton instance.
+
+ Returns:
+ ArchiveStorage instance
+
+ Raises:
+ ArchiveStorageNotConfiguredError: If archive storage is not configured
+ """
+ global _archive_storage
+ if _archive_storage is None:
+ archive_bucket = dify_config.ARCHIVE_STORAGE_ARCHIVE_BUCKET
+ if not archive_bucket:
+ raise ArchiveStorageNotConfiguredError(
+ "Archive storage bucket is not configured. Required: ARCHIVE_STORAGE_ARCHIVE_BUCKET"
+ )
+ _archive_storage = ArchiveStorage(bucket=archive_bucket)
+ return _archive_storage
+
+
+def get_export_storage() -> ArchiveStorage:
+ """
+ Get the export storage singleton instance.
+
+ Returns:
+ ArchiveStorage instance
+ """
+ global _export_storage
+ if _export_storage is None:
+ export_bucket = dify_config.ARCHIVE_STORAGE_EXPORT_BUCKET
+ if not export_bucket:
+ raise ArchiveStorageNotConfiguredError(
+ "Archive export bucket is not configured. Required: ARCHIVE_STORAGE_EXPORT_BUCKET"
+ )
+ _export_storage = ArchiveStorage(bucket=export_bucket)
+ return _export_storage
diff --git a/api/libs/broadcast_channel/channel.py b/api/libs/broadcast_channel/channel.py
index 5bbf0c79a3..d4cb3e9971 100644
--- a/api/libs/broadcast_channel/channel.py
+++ b/api/libs/broadcast_channel/channel.py
@@ -2,6 +2,8 @@
Broadcast channel for Pub/Sub messaging.
"""
+from __future__ import annotations
+
import types
from abc import abstractmethod
from collections.abc import Iterator
@@ -129,6 +131,6 @@ class BroadcastChannel(Protocol):
"""
@abstractmethod
- def topic(self, topic: str) -> "Topic":
+ def topic(self, topic: str) -> Topic:
"""topic returns a `Topic` instance for the given topic name."""
...
diff --git a/api/libs/broadcast_channel/redis/channel.py b/api/libs/broadcast_channel/redis/channel.py
index 1fc3db8156..5bb4f579c1 100644
--- a/api/libs/broadcast_channel/redis/channel.py
+++ b/api/libs/broadcast_channel/redis/channel.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis
@@ -20,7 +22,7 @@ class BroadcastChannel:
):
self._client = redis_client
- def topic(self, topic: str) -> "Topic":
+ def topic(self, topic: str) -> Topic:
return Topic(self._client, topic)
diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py
index 991dee64af..9e8ab90e8e 100644
--- a/api/libs/broadcast_channel/redis/sharded_channel.py
+++ b/api/libs/broadcast_channel/redis/sharded_channel.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis
@@ -18,7 +20,7 @@ class ShardedRedisBroadcastChannel:
):
self._client = redis_client
- def topic(self, topic: str) -> "ShardedTopic":
+ def topic(self, topic: str) -> ShardedTopic:
return ShardedTopic(self._client, topic)
diff --git a/api/libs/email_i18n.py b/api/libs/email_i18n.py
index ff74ccbe8e..0828cf80bf 100644
--- a/api/libs/email_i18n.py
+++ b/api/libs/email_i18n.py
@@ -6,6 +6,8 @@ in Dify. It follows Domain-Driven Design principles with proper type hints and
eliminates the need for repetitive language switching logic.
"""
+from __future__ import annotations
+
from dataclasses import dataclass
from enum import StrEnum, auto
from typing import Any, Protocol
@@ -53,7 +55,7 @@ class EmailLanguage(StrEnum):
ZH_HANS = "zh-Hans"
@classmethod
- def from_language_code(cls, language_code: str) -> "EmailLanguage":
+ def from_language_code(cls, language_code: str) -> EmailLanguage:
"""Convert a language code to EmailLanguage with fallback to English."""
if language_code == "zh-Hans":
return cls.ZH_HANS
diff --git a/api/libs/external_api.py b/api/libs/external_api.py
index 61a90ee4a9..e8592407c3 100644
--- a/api/libs/external_api.py
+++ b/api/libs/external_api.py
@@ -1,5 +1,4 @@
import re
-import sys
from collections.abc import Mapping
from typing import Any
@@ -109,11 +108,8 @@ def register_external_error_handlers(api: Api):
data.setdefault("code", "unknown")
data.setdefault("status", status_code)
- # Log stack
- exc_info: Any = sys.exc_info()
- if exc_info[1] is None:
- exc_info = (None, None, None)
- current_app.log_exception(exc_info)
+ # Note: Exception logging is handled by Flask/Flask-RESTX framework automatically
+ # Explicit log_exception call removed to avoid duplicate log entries
return data, status_code
diff --git a/api/libs/helper.py b/api/libs/helper.py
index 26dd6fdfa6..94e1770810 100644
--- a/api/libs/helper.py
+++ b/api/libs/helper.py
@@ -32,6 +32,38 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+def escape_like_pattern(pattern: str) -> str:
+ """
+ Escape special characters in a string for safe use in SQL LIKE patterns.
+
+ This function escapes the special characters used in SQL LIKE patterns:
+ - Backslash (\\) -> \\
+ - Percent (%) -> \\%
+ - Underscore (_) -> \\_
+
+ The escaped pattern can then be safely used in SQL LIKE queries with the
+ ESCAPE '\\' clause to prevent SQL injection via LIKE wildcards.
+
+ Args:
+ pattern: The string pattern to escape
+
+ Returns:
+ Escaped string safe for use in SQL LIKE queries
+
+ Examples:
+ >>> escape_like_pattern("50% discount")
+ '50\\% discount'
+ >>> escape_like_pattern("test_data")
+ 'test\\_data'
+ >>> escape_like_pattern("path\\to\\file")
+ 'path\\\\to\\\\file'
+ """
+ if not pattern:
+ return pattern
+ # Escape backslash first, then percent and underscore
+ return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
+
+
def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
"""
Extract tenant_id from Account or EndUser object.
diff --git a/api/libs/login.py b/api/libs/login.py
index 4b8ee2d1f8..73caa492fe 100644
--- a/api/libs/login.py
+++ b/api/libs/login.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
from collections.abc import Callable
from functools import wraps
-from typing import Any
+from typing import TYPE_CHECKING, Any
from flask import current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS
@@ -9,7 +11,9 @@ from werkzeug.local import LocalProxy
from configs import dify_config
from libs.token import check_csrf_token
from models import Account
-from models.model import EndUser
+
+if TYPE_CHECKING:
+ from models.model import EndUser
def current_account_with_tenant():
diff --git a/api/libs/smtp.py b/api/libs/smtp.py
index 4044c6f7ed..6f82f1440a 100644
--- a/api/libs/smtp.py
+++ b/api/libs/smtp.py
@@ -3,6 +3,8 @@ import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
+from configs import dify_config
+
logger = logging.getLogger(__name__)
@@ -19,20 +21,21 @@ class SMTPClient:
self.opportunistic_tls = opportunistic_tls
def send(self, mail: dict):
- smtp = None
+ smtp: smtplib.SMTP | None = None
+ local_host = dify_config.SMTP_LOCAL_HOSTNAME
try:
- if self.use_tls:
- if self.opportunistic_tls:
- smtp = smtplib.SMTP(self.server, self.port, timeout=10)
- # Send EHLO command with the HELO domain name as the server address
- smtp.ehlo(self.server)
- smtp.starttls()
- # Resend EHLO command to identify the TLS session
- smtp.ehlo(self.server)
- else:
- smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10)
+ if self.use_tls and not self.opportunistic_tls:
+ # SMTP with SSL (implicit TLS)
+ smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10, local_hostname=local_host)
else:
- smtp = smtplib.SMTP(self.server, self.port, timeout=10)
+ # Plain SMTP or SMTP with STARTTLS (explicit TLS)
+ smtp = smtplib.SMTP(self.server, self.port, timeout=10, local_hostname=local_host)
+
+ assert smtp is not None
+ if self.use_tls and self.opportunistic_tls:
+ smtp.ehlo(self.server)
+ smtp.starttls()
+ smtp.ehlo(self.server)
# Only authenticate if both username and password are non-empty
if self.username and self.password and self.username.strip() and self.password.strip():
diff --git a/api/libs/workspace_permission.py b/api/libs/workspace_permission.py
new file mode 100644
index 0000000000..dd42a7facf
--- /dev/null
+++ b/api/libs/workspace_permission.py
@@ -0,0 +1,74 @@
+"""
+Workspace permission helper functions.
+
+These helpers check both billing/plan level and workspace-specific policy level permissions.
+Checks are performed at two levels:
+1. Billing/plan level - via FeatureService (e.g., SANDBOX plan restrictions)
+2. Workspace policy level - via EnterpriseService (admin-configured per workspace)
+"""
+
+import logging
+
+from werkzeug.exceptions import Forbidden
+
+from configs import dify_config
+from services.enterprise.enterprise_service import EnterpriseService
+from services.feature_service import FeatureService
+
+logger = logging.getLogger(__name__)
+
+
+def check_workspace_member_invite_permission(workspace_id: str) -> None:
+ """
+ Check if workspace allows member invitations at both billing and policy levels.
+
+ Checks performed:
+ 1. Billing/plan level - For future expansion (currently no plan-level restriction)
+ 2. Enterprise policy level - Admin-configured workspace permission
+
+ Args:
+ workspace_id: The workspace ID to check permissions for
+
+ Raises:
+ Forbidden: If either billing plan or workspace policy prohibits member invitations
+ """
+ # Check enterprise workspace policy level (only if enterprise enabled)
+ if dify_config.ENTERPRISE_ENABLED:
+ try:
+ permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
+ if not permission.allow_member_invite:
+ raise Forbidden("Workspace policy prohibits member invitations")
+ except Forbidden:
+ raise
+ except Exception:
+ logger.exception("Failed to check workspace invite permission for %s", workspace_id)
+
+
+def check_workspace_owner_transfer_permission(workspace_id: str) -> None:
+ """
+ Check if workspace allows owner transfer at both billing and policy levels.
+
+ Checks performed:
+ 1. Billing/plan level - SANDBOX plan blocks owner transfer
+ 2. Enterprise policy level - Admin-configured workspace permission
+
+ Args:
+ workspace_id: The workspace ID to check permissions for
+
+ Raises:
+ Forbidden: If either billing plan or workspace policy prohibits ownership transfer
+ """
+ features = FeatureService.get_features(workspace_id)
+ if not features.is_allow_transfer_workspace:
+ raise Forbidden("Your current plan does not allow workspace ownership transfer")
+
+ # Check enterprise workspace policy level (only if enterprise enabled)
+ if dify_config.ENTERPRISE_ENABLED:
+ try:
+ permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
+ if not permission.allow_owner_transfer:
+ raise Forbidden("Workspace policy prohibits ownership transfer")
+ except Forbidden:
+ raise
+ except Exception:
+ logger.exception("Failed to check workspace transfer permission for %s", workspace_id)
diff --git a/api/migrations/versions/00bacef91f18_rename_api_provider_description.py b/api/migrations/versions/00bacef91f18_rename_api_provider_description.py
index 17ed067d81..657d28f896 100644
--- a/api/migrations/versions/00bacef91f18_rename_api_provider_description.py
+++ b/api/migrations/versions/00bacef91f18_rename_api_provider_description.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '00bacef91f18'
down_revision = '8ec536f3c800'
@@ -23,31 +20,17 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.add_column(sa.Column('description', sa.Text(), nullable=False))
- batch_op.drop_column('description_str')
- else:
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False))
- batch_op.drop_column('description_str')
+ with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False))
+ batch_op.drop_column('description_str')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.add_column(sa.Column('description_str', sa.TEXT(), autoincrement=False, nullable=False))
- batch_op.drop_column('description')
- else:
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False))
- batch_op.drop_column('description')
+ with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False))
+ batch_op.drop_column('description')
# ### end Alembic commands ###
diff --git a/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py b/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py
index ed70bf5d08..912d9dbfa4 100644
--- a/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py
+++ b/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py
@@ -7,14 +7,10 @@ Create Date: 2024-01-10 04:40:57.257824
"""
import sqlalchemy as sa
from alembic import op
-from sqlalchemy.dialects import postgresql
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '114eed84c228'
down_revision = 'c71211c8f604'
@@ -32,13 +28,7 @@ def upgrade():
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
- batch_op.add_column(sa.Column('tool_id', postgresql.UUID(), autoincrement=False, nullable=False))
- else:
- with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
- batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False))
+ with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py b/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py
index 509bd5d0e8..0ca905129d 100644
--- a/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py
+++ b/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '161cadc1af8d'
down_revision = '7e6a8693e07a'
@@ -23,16 +20,9 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
- # Step 1: Add column without NOT NULL constraint
- op.add_column('dataset_permissions', sa.Column('tenant_id', sa.UUID(), nullable=False))
- else:
- with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
- # Step 1: Add column without NOT NULL constraint
- op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False))
+ with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
+ # Step 1: Add column without NOT NULL constraint
+ op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py
index 0767b725f6..be1b42f883 100644
--- a/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py
+++ b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py
@@ -9,11 +9,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-import sqlalchemy as sa
-from sqlalchemy.dialects import postgresql
-
# revision identifiers, used by Alembic.
revision = '6af6a521a53e'
down_revision = 'd57ba9ebb251'
@@ -23,58 +18,30 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
- batch_op.alter_column('document_id',
- existing_type=sa.UUID(),
- nullable=True)
- batch_op.alter_column('data_source_type',
- existing_type=sa.TEXT(),
- nullable=True)
- batch_op.alter_column('segment_id',
- existing_type=sa.UUID(),
- nullable=True)
- else:
- with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
- batch_op.alter_column('document_id',
- existing_type=models.types.StringUUID(),
- nullable=True)
- batch_op.alter_column('data_source_type',
- existing_type=models.types.LongText(),
- nullable=True)
- batch_op.alter_column('segment_id',
- existing_type=models.types.StringUUID(),
- nullable=True)
+ with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
+ batch_op.alter_column('document_id',
+ existing_type=models.types.StringUUID(),
+ nullable=True)
+ batch_op.alter_column('data_source_type',
+ existing_type=models.types.LongText(),
+ nullable=True)
+ batch_op.alter_column('segment_id',
+ existing_type=models.types.StringUUID(),
+ nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
- batch_op.alter_column('segment_id',
- existing_type=sa.UUID(),
- nullable=False)
- batch_op.alter_column('data_source_type',
- existing_type=sa.TEXT(),
- nullable=False)
- batch_op.alter_column('document_id',
- existing_type=sa.UUID(),
- nullable=False)
- else:
- with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
- batch_op.alter_column('segment_id',
- existing_type=models.types.StringUUID(),
- nullable=False)
- batch_op.alter_column('data_source_type',
- existing_type=models.types.LongText(),
- nullable=False)
- batch_op.alter_column('document_id',
- existing_type=models.types.StringUUID(),
- nullable=False)
+ with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
+ batch_op.alter_column('segment_id',
+ existing_type=models.types.StringUUID(),
+ nullable=False)
+ batch_op.alter_column('data_source_type',
+ existing_type=models.types.LongText(),
+ nullable=False)
+ batch_op.alter_column('document_id',
+ existing_type=models.types.StringUUID(),
+ nullable=False)
# ### end Alembic commands ###
diff --git a/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py
index a749c8bddf..5d12419bf7 100644
--- a/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py
+++ b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py
@@ -8,7 +8,6 @@ Create Date: 2024-11-01 04:34:23.816198
from alembic import op
import models as models
import sqlalchemy as sa
-from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'd3f6769a94a3'
diff --git a/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py
index 45842295ea..a49d6a52f6 100644
--- a/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py
+++ b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py
@@ -28,85 +28,45 @@ def upgrade():
op.execute("UPDATE sites SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL")
op.execute("UPDATE tool_api_providers SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL")
- if _is_pg(conn):
- with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.VARCHAR(length=255),
- type_=sa.TEXT(),
- nullable=False)
+ with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
+ batch_op.alter_column('custom_disclaimer',
+ existing_type=sa.VARCHAR(length=255),
+ type_=models.types.LongText(),
+ nullable=False)
- with op.batch_alter_table('sites', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.VARCHAR(length=255),
- type_=sa.TEXT(),
- nullable=False)
+ with op.batch_alter_table('sites', schema=None) as batch_op:
+ batch_op.alter_column('custom_disclaimer',
+ existing_type=sa.VARCHAR(length=255),
+ type_=models.types.LongText(),
+ nullable=False)
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.VARCHAR(length=255),
- type_=sa.TEXT(),
- nullable=False)
- else:
- with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.VARCHAR(length=255),
- type_=models.types.LongText(),
- nullable=False)
-
- with op.batch_alter_table('sites', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.VARCHAR(length=255),
- type_=models.types.LongText(),
- nullable=False)
-
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.VARCHAR(length=255),
- type_=models.types.LongText(),
- nullable=False)
+ with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
+ batch_op.alter_column('custom_disclaimer',
+ existing_type=sa.VARCHAR(length=255),
+ type_=models.types.LongText(),
+ nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.TEXT(),
- type_=sa.VARCHAR(length=255),
- nullable=True)
+ with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
+ batch_op.alter_column('custom_disclaimer',
+ existing_type=models.types.LongText(),
+ type_=sa.VARCHAR(length=255),
+ nullable=True)
- with op.batch_alter_table('sites', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.TEXT(),
- type_=sa.VARCHAR(length=255),
- nullable=True)
+ with op.batch_alter_table('sites', schema=None) as batch_op:
+ batch_op.alter_column('custom_disclaimer',
+ existing_type=models.types.LongText(),
+ type_=sa.VARCHAR(length=255),
+ nullable=True)
- with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.TEXT(),
- type_=sa.VARCHAR(length=255),
- nullable=True)
- else:
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=models.types.LongText(),
- type_=sa.VARCHAR(length=255),
- nullable=True)
-
- with op.batch_alter_table('sites', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=models.types.LongText(),
- type_=sa.VARCHAR(length=255),
- nullable=True)
-
- with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=models.types.LongText(),
- type_=sa.VARCHAR(length=255),
- nullable=True)
+ with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
+ batch_op.alter_column('custom_disclaimer',
+ existing_type=models.types.LongText(),
+ type_=sa.VARCHAR(length=255),
+ nullable=True)
# ### end Alembic commands ###
diff --git a/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py
index fdd8984029..8a36c9c4a5 100644
--- a/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py
+++ b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py
@@ -49,57 +49,33 @@ def upgrade():
op.execute("UPDATE workflows SET updated_at = created_at WHERE updated_at IS NULL")
op.execute("UPDATE workflows SET graph = '' WHERE graph IS NULL")
op.execute("UPDATE workflows SET features = '' WHERE features IS NULL")
- if _is_pg(conn):
- with op.batch_alter_table('workflows', schema=None) as batch_op:
- batch_op.alter_column('graph',
- existing_type=sa.TEXT(),
- nullable=False)
- batch_op.alter_column('features',
- existing_type=sa.TEXT(),
- nullable=False)
- batch_op.alter_column('updated_at',
- existing_type=postgresql.TIMESTAMP(),
- nullable=False)
- else:
- with op.batch_alter_table('workflows', schema=None) as batch_op:
- batch_op.alter_column('graph',
- existing_type=models.types.LongText(),
- nullable=False)
- batch_op.alter_column('features',
- existing_type=models.types.LongText(),
- nullable=False)
- batch_op.alter_column('updated_at',
- existing_type=sa.TIMESTAMP(),
- nullable=False)
+
+ with op.batch_alter_table('workflows', schema=None) as batch_op:
+ batch_op.alter_column('graph',
+ existing_type=models.types.LongText(),
+ nullable=False)
+ batch_op.alter_column('features',
+ existing_type=models.types.LongText(),
+ nullable=False)
+ batch_op.alter_column('updated_at',
+ existing_type=sa.TIMESTAMP(),
+ nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('workflows', schema=None) as batch_op:
- batch_op.alter_column('updated_at',
- existing_type=postgresql.TIMESTAMP(),
- nullable=True)
- batch_op.alter_column('features',
- existing_type=sa.TEXT(),
- nullable=True)
- batch_op.alter_column('graph',
- existing_type=sa.TEXT(),
- nullable=True)
- else:
- with op.batch_alter_table('workflows', schema=None) as batch_op:
- batch_op.alter_column('updated_at',
- existing_type=sa.TIMESTAMP(),
- nullable=True)
- batch_op.alter_column('features',
- existing_type=models.types.LongText(),
- nullable=True)
- batch_op.alter_column('graph',
- existing_type=models.types.LongText(),
- nullable=True)
+ with op.batch_alter_table('workflows', schema=None) as batch_op:
+ batch_op.alter_column('updated_at',
+ existing_type=sa.TIMESTAMP(),
+ nullable=True)
+ batch_op.alter_column('features',
+ existing_type=models.types.LongText(),
+ nullable=True)
+ batch_op.alter_column('graph',
+ existing_type=models.types.LongText(),
+ nullable=True)
if _is_pg(conn):
with op.batch_alter_table('messages', schema=None) as batch_op:
diff --git a/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py b/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py
index 16ca902726..1fc4a64df1 100644
--- a/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py
+++ b/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py
@@ -86,57 +86,30 @@ def upgrade():
def migrate_existing_provider_models_data():
"""migrate provider_models table data to provider_model_credentials"""
- conn = op.get_bind()
- # Define table structure for data manipulation
- if _is_pg(conn):
- provider_models_table = table('provider_models',
- column('id', models.types.StringUUID()),
- column('tenant_id', models.types.StringUUID()),
- column('provider_name', sa.String()),
- column('model_name', sa.String()),
- column('model_type', sa.String()),
- column('encrypted_config', sa.Text()),
- column('created_at', sa.DateTime()),
- column('updated_at', sa.DateTime()),
- column('credential_id', models.types.StringUUID()),
- )
- else:
- provider_models_table = table('provider_models',
- column('id', models.types.StringUUID()),
- column('tenant_id', models.types.StringUUID()),
- column('provider_name', sa.String()),
- column('model_name', sa.String()),
- column('model_type', sa.String()),
- column('encrypted_config', models.types.LongText()),
- column('created_at', sa.DateTime()),
- column('updated_at', sa.DateTime()),
- column('credential_id', models.types.StringUUID()),
- )
+ # Define table structure for data manipulatio
+ provider_models_table = table('provider_models',
+ column('id', models.types.StringUUID()),
+ column('tenant_id', models.types.StringUUID()),
+ column('provider_name', sa.String()),
+ column('model_name', sa.String()),
+ column('model_type', sa.String()),
+ column('encrypted_config', models.types.LongText()),
+ column('created_at', sa.DateTime()),
+ column('updated_at', sa.DateTime()),
+ column('credential_id', models.types.StringUUID()),
+ )
- if _is_pg(conn):
- provider_model_credentials_table = table('provider_model_credentials',
- column('id', models.types.StringUUID()),
- column('tenant_id', models.types.StringUUID()),
- column('provider_name', sa.String()),
- column('model_name', sa.String()),
- column('model_type', sa.String()),
- column('credential_name', sa.String()),
- column('encrypted_config', sa.Text()),
- column('created_at', sa.DateTime()),
- column('updated_at', sa.DateTime())
- )
- else:
- provider_model_credentials_table = table('provider_model_credentials',
- column('id', models.types.StringUUID()),
- column('tenant_id', models.types.StringUUID()),
- column('provider_name', sa.String()),
- column('model_name', sa.String()),
- column('model_type', sa.String()),
- column('credential_name', sa.String()),
- column('encrypted_config', models.types.LongText()),
- column('created_at', sa.DateTime()),
- column('updated_at', sa.DateTime())
- )
+ provider_model_credentials_table = table('provider_model_credentials',
+ column('id', models.types.StringUUID()),
+ column('tenant_id', models.types.StringUUID()),
+ column('provider_name', sa.String()),
+ column('model_name', sa.String()),
+ column('model_type', sa.String()),
+ column('credential_name', sa.String()),
+ column('encrypted_config', models.types.LongText()),
+ column('created_at', sa.DateTime()),
+ column('updated_at', sa.DateTime())
+ )
# Get database connection
@@ -183,14 +156,8 @@ def migrate_existing_provider_models_data():
def downgrade():
# Re-add encrypted_config column to provider_models table
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('provider_models', schema=None) as batch_op:
- batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('provider_models', schema=None) as batch_op:
- batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('provider_models', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True))
if not context.is_offline_mode():
# Migrate data back from provider_model_credentials to provider_models
diff --git a/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py b/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py
index 75b4d61173..79fe9d9bba 100644
--- a/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py
+++ b/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py
@@ -8,7 +8,6 @@ Create Date: 2025-08-20 17:47:17.015695
from alembic import op
import models as models
import sqlalchemy as sa
-from libs.uuid_utils import uuidv7
def _is_pg(conn):
diff --git a/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py b/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py
index 4f472fe4b4..cf2b973d2d 100644
--- a/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py
+++ b/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py
@@ -9,8 +9,6 @@ from alembic import op
import models as models
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
import sqlalchemy as sa
@@ -23,12 +21,7 @@ depends_on = None
def upgrade():
# Add encrypted_headers column to tool_mcp_providers table
- conn = op.get_bind()
-
- if _is_pg(conn):
- op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', sa.Text(), nullable=True))
- else:
- op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True))
+ op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True))
def downgrade():
diff --git a/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py b/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py
index 8eac0dee10..bad516dcac 100644
--- a/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py
+++ b/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py
@@ -44,6 +44,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='datasource_oauth_config_pkey'),
sa.UniqueConstraint('plugin_id', 'provider', name='datasource_oauth_config_datasource_id_provider_idx')
)
+
if _is_pg(conn):
op.create_table('datasource_oauth_tenant_params',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -70,6 +71,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='datasource_oauth_tenant_config_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='datasource_oauth_tenant_config_unique')
)
+
if _is_pg(conn):
op.create_table('datasource_providers',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -104,6 +106,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='datasource_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', 'name', name='datasource_provider_unique_name')
)
+
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
batch_op.create_index('datasource_provider_auth_type_provider_idx', ['tenant_id', 'plugin_id', 'provider'], unique=False)
@@ -133,6 +136,7 @@ def upgrade():
sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint('id', name='document_pipeline_execution_log_pkey')
)
+
with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op:
batch_op.create_index('document_pipeline_execution_logs_document_id_idx', ['document_id'], unique=False)
@@ -174,6 +178,7 @@ def upgrade():
sa.Column('updated_by', models.types.StringUUID(), nullable=True),
sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey')
)
+
if _is_pg(conn):
op.create_table('pipeline_customized_templates',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -193,7 +198,6 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey')
)
else:
- # MySQL: Use compatible syntax
op.create_table('pipeline_customized_templates',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
@@ -211,6 +215,7 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey')
)
+
with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False)
@@ -236,6 +241,7 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint('id', name='pipeline_recommended_plugin_pkey')
)
+
if _is_pg(conn):
op.create_table('pipelines',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -266,6 +272,7 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint('id', name='pipeline_pkey')
)
+
if _is_pg(conn):
op.create_table('workflow_draft_variable_files',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -292,6 +299,7 @@ def upgrade():
sa.Column('value_type', sa.String(20), nullable=False),
sa.PrimaryKeyConstraint('id', name=op.f('workflow_draft_variable_files_pkey'))
)
+
if _is_pg(conn):
op.create_table('workflow_node_execution_offload',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -316,6 +324,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name=op.f('workflow_node_execution_offload_pkey')),
sa.UniqueConstraint('node_execution_id', 'type', name=op.f('workflow_node_execution_offload_node_execution_id_key'))
)
+
if _is_pg(conn):
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('keyword_number', sa.Integer(), server_default=sa.text('10'), nullable=True))
@@ -342,6 +351,7 @@ def upgrade():
comment='Indicates whether the current value is the default for a conversation variable. Always `FALSE` for other types of variables.',)
)
batch_op.create_index('workflow_draft_variable_file_id_idx', ['file_id'], unique=False)
+
if _is_pg(conn):
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.add_column(sa.Column('rag_pipeline_variables', sa.Text(), server_default='{}', nullable=False))
diff --git a/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py b/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py
index 0776ab0818..ec0cfbd11d 100644
--- a/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py
+++ b/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py
@@ -9,8 +9,6 @@ from alembic import op
import models as models
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
import sqlalchemy as sa
@@ -33,15 +31,9 @@ def upgrade():
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
- batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False))
- batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True))
- else:
- with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
- batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False))
- batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True))
+
+ with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False))
+ batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py b/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py
index 627219cc4b..12905b3674 100644
--- a/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py
+++ b/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py
@@ -9,7 +9,6 @@ Create Date: 2025-10-22 16:11:31.805407
from alembic import op
import models as models
import sqlalchemy as sa
-from libs.uuid_utils import uuidv7
def _is_pg(conn):
return conn.dialect.name == "postgresql"
diff --git a/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py b/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py
index 9641a15c89..c27c1058d1 100644
--- a/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py
+++ b/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py
@@ -105,6 +105,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='trigger_oauth_tenant_client_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_trigger_oauth_tenant_client')
)
+
if _is_pg(conn):
op.create_table('trigger_subscriptions',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
@@ -143,6 +144,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='trigger_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'provider_id', 'name', name='unique_trigger_provider')
)
+
with op.batch_alter_table('trigger_subscriptions', schema=None) as batch_op:
batch_op.create_index('idx_trigger_providers_endpoint', ['endpoint_id'], unique=True)
batch_op.create_index('idx_trigger_providers_tenant_endpoint', ['tenant_id', 'endpoint_id'], unique=False)
@@ -176,6 +178,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='workflow_plugin_trigger_pkey'),
sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node_subscription')
)
+
with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op:
batch_op.create_index('workflow_plugin_trigger_tenant_subscription_idx', ['tenant_id', 'subscription_id', 'event_name'], unique=False)
@@ -207,6 +210,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='workflow_schedule_plan_pkey'),
sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node')
)
+
with op.batch_alter_table('workflow_schedule_plans', schema=None) as batch_op:
batch_op.create_index('workflow_schedule_plan_next_idx', ['next_run_at'], unique=False)
@@ -264,6 +268,7 @@ def upgrade():
sa.Column('finished_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id', name='workflow_trigger_log_pkey')
)
+
with op.batch_alter_table('workflow_trigger_logs', schema=None) as batch_op:
batch_op.create_index('workflow_trigger_log_created_at_idx', ['created_at'], unique=False)
batch_op.create_index('workflow_trigger_log_status_idx', ['status'], unique=False)
@@ -299,6 +304,7 @@ def upgrade():
sa.UniqueConstraint('app_id', 'node_id', name='uniq_node'),
sa.UniqueConstraint('webhook_id', name='uniq_webhook_id')
)
+
with op.batch_alter_table('workflow_webhook_triggers', schema=None) as batch_op:
batch_op.create_index('workflow_webhook_trigger_tenant_idx', ['tenant_id'], unique=False)
diff --git a/api/migrations/versions/2025_11_06_1603-9e6fa5cbcd80_make_message_annotation_question_not_.py b/api/migrations/versions/2025_11_06_1603-9e6fa5cbcd80_make_message_annotation_question_not_.py
new file mode 100644
index 0000000000..624be1d073
--- /dev/null
+++ b/api/migrations/versions/2025_11_06_1603-9e6fa5cbcd80_make_message_annotation_question_not_.py
@@ -0,0 +1,60 @@
+"""make message annotation question not nullable
+
+Revision ID: 9e6fa5cbcd80
+Revises: 03f8dcbc611e
+Create Date: 2025-11-06 16:03:54.549378
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '9e6fa5cbcd80'
+down_revision = '288345cd01d1'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ bind = op.get_bind()
+ message_annotations = sa.table(
+ "message_annotations",
+ sa.column("id", sa.String),
+ sa.column("message_id", sa.String),
+ sa.column("question", sa.Text),
+ )
+ messages = sa.table(
+ "messages",
+ sa.column("id", sa.String),
+ sa.column("query", sa.Text),
+ )
+ update_question_from_message = (
+ sa.update(message_annotations)
+ .where(
+ sa.and_(
+ message_annotations.c.question.is_(None),
+ message_annotations.c.message_id.isnot(None),
+ )
+ )
+ .values(
+ question=sa.select(sa.func.coalesce(messages.c.query, ""))
+ .where(messages.c.id == message_annotations.c.message_id)
+ .scalar_subquery()
+ )
+ )
+ bind.execute(update_question_from_message)
+
+ fill_remaining_questions = (
+ sa.update(message_annotations)
+ .where(message_annotations.c.question.is_(None))
+ .values(question="")
+ )
+ bind.execute(fill_remaining_questions)
+ with op.batch_alter_table('message_annotations', schema=None) as batch_op:
+ batch_op.alter_column('question', existing_type=sa.TEXT(), nullable=False)
+
+
+def downgrade():
+ with op.batch_alter_table('message_annotations', schema=None) as batch_op:
+ batch_op.alter_column('question', existing_type=sa.TEXT(), nullable=True)
diff --git a/api/migrations/versions/2025_12_25_1039-7df29de0f6be_add_credit_pool.py b/api/migrations/versions/2025_12_25_1039-7df29de0f6be_add_credit_pool.py
new file mode 100644
index 0000000000..e89fcee7e5
--- /dev/null
+++ b/api/migrations/versions/2025_12_25_1039-7df29de0f6be_add_credit_pool.py
@@ -0,0 +1,46 @@
+"""add credit pool
+
+Revision ID: 7df29de0f6be
+Revises: 03ea244985ce
+Create Date: 2025-12-25 10:39:15.139304
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '7df29de0f6be'
+down_revision = '03ea244985ce'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('tenant_credit_pools',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+ sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False),
+ sa.Column('quota_limit', sa.BigInteger(), nullable=False),
+ sa.Column('quota_used', sa.BigInteger(), nullable=False),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
+ )
+ with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
+ batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False)
+ batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False)
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+
+ with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
+ batch_op.drop_index('tenant_credit_pool_tenant_id_idx')
+ batch_op.drop_index('tenant_credit_pool_pool_type_idx')
+
+ op.drop_table('tenant_credit_pools')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py b/api/migrations/versions/2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py
new file mode 100644
index 0000000000..7e0cc8ec9d
--- /dev/null
+++ b/api/migrations/versions/2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py
@@ -0,0 +1,30 @@
+"""add workflow_run_created_at_id_idx
+
+Revision ID: 905527cc8fd3
+Revises: 7df29de0f6be
+Create Date: 2025-01-09 16:30:02.462084
+
+"""
+from alembic import op
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '905527cc8fd3'
+down_revision = '7df29de0f6be'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
+ batch_op.create_index('workflow_run_created_at_id_idx', ['created_at', 'id'], unique=False)
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
+ batch_op.drop_index('workflow_run_created_at_id_idx')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py b/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py
new file mode 100644
index 0000000000..758369ba99
--- /dev/null
+++ b/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py
@@ -0,0 +1,33 @@
+"""feat: add created_at id index to messages
+
+Revision ID: 3334862ee907
+Revises: 905527cc8fd3
+Create Date: 2026-01-12 17:29:44.846544
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '3334862ee907'
+down_revision = '905527cc8fd3'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('messages', schema=None) as batch_op:
+ batch_op.create_index('message_created_at_id_idx', ['created_at', 'id'], unique=False)
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('messages', schema=None) as batch_op:
+ batch_op.drop_index('message_created_at_id_idx')
+
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2026_01_16_1715-288345cd01d1_change_workflow_node_execution_run_index.py b/api/migrations/versions/2026_01_16_1715-288345cd01d1_change_workflow_node_execution_run_index.py
new file mode 100644
index 0000000000..2e1af0c83f
--- /dev/null
+++ b/api/migrations/versions/2026_01_16_1715-288345cd01d1_change_workflow_node_execution_run_index.py
@@ -0,0 +1,35 @@
+"""change workflow node execution workflow_run index
+
+Revision ID: 288345cd01d1
+Revises: 3334862ee907
+Create Date: 2026-01-16 17:15:00.000000
+
+"""
+from alembic import op
+
+
+# revision identifiers, used by Alembic.
+revision = "288345cd01d1"
+down_revision = "3334862ee907"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op:
+ batch_op.drop_index("workflow_node_execution_workflow_run_idx")
+ batch_op.create_index(
+ "workflow_node_execution_workflow_run_id_idx",
+ ["workflow_run_id"],
+ unique=False,
+ )
+
+
+def downgrade():
+ with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op:
+ batch_op.drop_index("workflow_node_execution_workflow_run_id_idx")
+ batch_op.create_index(
+ "workflow_node_execution_workflow_run_idx",
+ ["tenant_id", "app_id", "workflow_id", "triggered_from", "workflow_run_id"],
+ unique=False,
+ )
diff --git a/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py b/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py
new file mode 100644
index 0000000000..b99ca04e3f
--- /dev/null
+++ b/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py
@@ -0,0 +1,73 @@
+"""add table explore banner and trial
+
+Revision ID: f9f6d18a37f9
+Revises: 9e6fa5cbcd80
+Create Date: 2026-01-017 11:10:18.079355
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = 'f9f6d18a37f9'
+down_revision = '9e6fa5cbcd80'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('account_trial_app_records',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('account_id', models.types.StringUUID(), nullable=False),
+ sa.Column('app_id', models.types.StringUUID(), nullable=False),
+ sa.Column('count', sa.Integer(), nullable=False),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='user_trial_app_pkey'),
+ sa.UniqueConstraint('account_id', 'app_id', name='unique_account_trial_app_record')
+ )
+ with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
+ batch_op.create_index('account_trial_app_record_account_id_idx', ['account_id'], unique=False)
+ batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)
+
+ op.create_table('exporle_banners',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('content', sa.JSON(), nullable=False),
+ sa.Column('link', sa.String(length=255), nullable=False),
+ sa.Column('sort', sa.Integer(), nullable=False),
+ sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
+ )
+ op.create_table('trial_apps',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('app_id', models.types.StringUUID(), nullable=False),
+ sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.Column('trial_limit', sa.Integer(), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='trial_app_pkey'),
+ sa.UniqueConstraint('app_id', name='unique_trail_app_id')
+ )
+ with op.batch_alter_table('trial_apps', schema=None) as batch_op:
+ batch_op.create_index('trial_app_app_id_idx', ['app_id'], unique=False)
+ batch_op.create_index('trial_app_tenant_id_idx', ['tenant_id'], unique=False)
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('trial_apps', schema=None) as batch_op:
+ batch_op.drop_index('trial_app_tenant_id_idx')
+ batch_op.drop_index('trial_app_app_id_idx')
+
+ op.drop_table('trial_apps')
+ op.drop_table('exporle_banners')
+ with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
+ batch_op.drop_index('account_trial_app_record_app_id_idx')
+ batch_op.drop_index('account_trial_app_record_account_id_idx')
+
+ op.drop_table('account_trial_app_records')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py b/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py
new file mode 100644
index 0000000000..5e7298af54
--- /dev/null
+++ b/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py
@@ -0,0 +1,95 @@
+"""create workflow_archive_logs
+
+Revision ID: 9d77545f524e
+Revises: f9f6d18a37f9
+Create Date: 2026-01-06 17:18:56.292479
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+def _is_pg(conn):
+ return conn.dialect.name == "postgresql"
+
+# revision identifiers, used by Alembic.
+revision = '9d77545f524e'
+down_revision = 'f9f6d18a37f9'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ conn = op.get_bind()
+ if _is_pg(conn):
+ op.create_table('workflow_archive_logs',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
+ sa.Column('log_id', models.types.StringUUID(), nullable=True),
+ sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+ sa.Column('app_id', models.types.StringUUID(), nullable=False),
+ sa.Column('workflow_id', models.types.StringUUID(), nullable=False),
+ sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False),
+ sa.Column('created_by_role', sa.String(length=255), nullable=False),
+ sa.Column('created_by', models.types.StringUUID(), nullable=False),
+ sa.Column('log_created_at', sa.DateTime(), nullable=True),
+ sa.Column('log_created_from', sa.String(length=255), nullable=True),
+ sa.Column('run_version', sa.String(length=255), nullable=False),
+ sa.Column('run_status', sa.String(length=255), nullable=False),
+ sa.Column('run_triggered_from', sa.String(length=255), nullable=False),
+ sa.Column('run_error', models.types.LongText(), nullable=True),
+ sa.Column('run_elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False),
+ sa.Column('run_total_tokens', sa.BigInteger(), server_default=sa.text('0'), nullable=False),
+ sa.Column('run_total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True),
+ sa.Column('run_created_at', sa.DateTime(), nullable=False),
+ sa.Column('run_finished_at', sa.DateTime(), nullable=True),
+ sa.Column('run_exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True),
+ sa.Column('trigger_metadata', models.types.LongText(), nullable=True),
+ sa.Column('archived_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='workflow_archive_log_pkey')
+ )
+ else:
+ op.create_table('workflow_archive_logs',
+ sa.Column('id', models.types.StringUUID(), nullable=False),
+ sa.Column('log_id', models.types.StringUUID(), nullable=True),
+ sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+ sa.Column('app_id', models.types.StringUUID(), nullable=False),
+ sa.Column('workflow_id', models.types.StringUUID(), nullable=False),
+ sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False),
+ sa.Column('created_by_role', sa.String(length=255), nullable=False),
+ sa.Column('created_by', models.types.StringUUID(), nullable=False),
+ sa.Column('log_created_at', sa.DateTime(), nullable=True),
+ sa.Column('log_created_from', sa.String(length=255), nullable=True),
+ sa.Column('run_version', sa.String(length=255), nullable=False),
+ sa.Column('run_status', sa.String(length=255), nullable=False),
+ sa.Column('run_triggered_from', sa.String(length=255), nullable=False),
+ sa.Column('run_error', models.types.LongText(), nullable=True),
+ sa.Column('run_elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False),
+ sa.Column('run_total_tokens', sa.BigInteger(), server_default=sa.text('0'), nullable=False),
+ sa.Column('run_total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True),
+ sa.Column('run_created_at', sa.DateTime(), nullable=False),
+ sa.Column('run_finished_at', sa.DateTime(), nullable=True),
+ sa.Column('run_exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True),
+ sa.Column('trigger_metadata', models.types.LongText(), nullable=True),
+ sa.Column('archived_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='workflow_archive_log_pkey')
+ )
+ with op.batch_alter_table('workflow_archive_logs', schema=None) as batch_op:
+ batch_op.create_index('workflow_archive_log_app_idx', ['tenant_id', 'app_id'], unique=False)
+ batch_op.create_index('workflow_archive_log_run_created_at_idx', ['run_created_at'], unique=False)
+ batch_op.create_index('workflow_archive_log_workflow_run_id_idx', ['workflow_run_id'], unique=False)
+
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('workflow_archive_logs', schema=None) as batch_op:
+ batch_op.drop_index('workflow_archive_log_workflow_run_id_idx')
+ batch_op.drop_index('workflow_archive_log_run_created_at_idx')
+ batch_op.drop_index('workflow_archive_log_app_idx')
+
+ op.drop_table('workflow_archive_logs')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py b/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py
index fae506906b..127ffd5599 100644
--- a/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py
+++ b/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '23db93619b9d'
down_revision = '8ae9bc661daa'
@@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
- batch_op.add_column(sa.Column('message_files', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
- batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py
index 2676ef0b94..31829d8e58 100644
--- a/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py
+++ b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py
@@ -62,14 +62,8 @@ def upgrade():
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('annotation_reply', sa.TEXT(), autoincrement=False, nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True))
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True))
with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op:
batch_op.drop_index('app_annotation_settings_app_idx')
diff --git a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py
index 3362a3a09f..07a8cd86b1 100644
--- a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py
+++ b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py
@@ -11,9 +11,6 @@ from alembic import op
import models as models
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '2a3aebbbf4bb'
down_revision = 'c031d46af369'
@@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('apps', schema=None) as batch_op:
- batch_op.add_column(sa.Column('tracing', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('apps', schema=None) as batch_op:
- batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('apps', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py
index 40bd727f66..211b2d8882 100644
--- a/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py
+++ b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py
@@ -7,14 +7,10 @@ Create Date: 2023-09-22 15:41:01.243183
"""
import sqlalchemy as sa
from alembic import op
-from sqlalchemy.dialects import postgresql
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '2e9819ca5b28'
down_revision = 'ab23c11305d4'
@@ -24,35 +20,19 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('api_tokens', schema=None) as batch_op:
- batch_op.add_column(sa.Column('tenant_id', postgresql.UUID(), nullable=True))
- batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
- batch_op.drop_column('dataset_id')
- else:
- with op.batch_alter_table('api_tokens', schema=None) as batch_op:
- batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True))
- batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
- batch_op.drop_column('dataset_id')
+ with op.batch_alter_table('api_tokens', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True))
+ batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
+ batch_op.drop_column('dataset_id')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('api_tokens', schema=None) as batch_op:
- batch_op.add_column(sa.Column('dataset_id', postgresql.UUID(), autoincrement=False, nullable=True))
- batch_op.drop_index('api_token_tenant_idx')
- batch_op.drop_column('tenant_id')
- else:
- with op.batch_alter_table('api_tokens', schema=None) as batch_op:
- batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True))
- batch_op.drop_index('api_token_tenant_idx')
- batch_op.drop_column('tenant_id')
+ with op.batch_alter_table('api_tokens', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True))
+ batch_op.drop_index('api_token_tenant_idx')
+ batch_op.drop_column('tenant_id')
# ### end Alembic commands ###
diff --git a/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py
index 76056a9460..3491c85e2f 100644
--- a/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py
+++ b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py
@@ -7,14 +7,10 @@ Create Date: 2024-03-07 08:30:29.133614
"""
import sqlalchemy as sa
from alembic import op
-from sqlalchemy.dialects import postgresql
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '42e85ed5564d'
down_revision = 'f9107f83abab'
@@ -24,59 +20,31 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('conversations', schema=None) as batch_op:
- batch_op.alter_column('app_model_config_id',
- existing_type=postgresql.UUID(),
- nullable=True)
- batch_op.alter_column('model_provider',
- existing_type=sa.VARCHAR(length=255),
- nullable=True)
- batch_op.alter_column('model_id',
- existing_type=sa.VARCHAR(length=255),
- nullable=True)
- else:
- with op.batch_alter_table('conversations', schema=None) as batch_op:
- batch_op.alter_column('app_model_config_id',
- existing_type=models.types.StringUUID(),
- nullable=True)
- batch_op.alter_column('model_provider',
- existing_type=sa.VARCHAR(length=255),
- nullable=True)
- batch_op.alter_column('model_id',
- existing_type=sa.VARCHAR(length=255),
- nullable=True)
+ with op.batch_alter_table('conversations', schema=None) as batch_op:
+ batch_op.alter_column('app_model_config_id',
+ existing_type=models.types.StringUUID(),
+ nullable=True)
+ batch_op.alter_column('model_provider',
+ existing_type=sa.VARCHAR(length=255),
+ nullable=True)
+ batch_op.alter_column('model_id',
+ existing_type=sa.VARCHAR(length=255),
+ nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('conversations', schema=None) as batch_op:
- batch_op.alter_column('model_id',
- existing_type=sa.VARCHAR(length=255),
- nullable=False)
- batch_op.alter_column('model_provider',
- existing_type=sa.VARCHAR(length=255),
- nullable=False)
- batch_op.alter_column('app_model_config_id',
- existing_type=postgresql.UUID(),
- nullable=False)
- else:
- with op.batch_alter_table('conversations', schema=None) as batch_op:
- batch_op.alter_column('model_id',
- existing_type=sa.VARCHAR(length=255),
- nullable=False)
- batch_op.alter_column('model_provider',
- existing_type=sa.VARCHAR(length=255),
- nullable=False)
- batch_op.alter_column('app_model_config_id',
- existing_type=models.types.StringUUID(),
- nullable=False)
+ with op.batch_alter_table('conversations', schema=None) as batch_op:
+ batch_op.alter_column('model_id',
+ existing_type=sa.VARCHAR(length=255),
+ nullable=False)
+ batch_op.alter_column('model_provider',
+ existing_type=sa.VARCHAR(length=255),
+ nullable=False)
+ batch_op.alter_column('app_model_config_id',
+ existing_type=models.types.StringUUID(),
+ nullable=False)
# ### end Alembic commands ###
diff --git a/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py
index ef066587b7..8537a87233 100644
--- a/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py
+++ b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py
@@ -6,14 +6,10 @@ Create Date: 2024-01-12 03:42:27.362415
"""
from alembic import op
-from sqlalchemy.dialects import postgresql
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '4829e54d2fee'
down_revision = '114eed84c228'
@@ -23,39 +19,21 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- # PostgreSQL: Keep original syntax
- with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
- batch_op.alter_column('message_chain_id',
- existing_type=postgresql.UUID(),
- nullable=True)
- else:
- # MySQL: Use compatible syntax
- with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
- batch_op.alter_column('message_chain_id',
- existing_type=models.types.StringUUID(),
- nullable=True)
+
+ with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
+ batch_op.alter_column('message_chain_id',
+ existing_type=models.types.StringUUID(),
+ nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- # PostgreSQL: Keep original syntax
- with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
- batch_op.alter_column('message_chain_id',
- existing_type=postgresql.UUID(),
- nullable=False)
- else:
- # MySQL: Use compatible syntax
- with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
- batch_op.alter_column('message_chain_id',
- existing_type=models.types.StringUUID(),
- nullable=False)
+
+ with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
+ batch_op.alter_column('message_chain_id',
+ existing_type=models.types.StringUUID(),
+ nullable=False)
# ### end Alembic commands ###
diff --git a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py
index b080e7680b..22405e3cc8 100644
--- a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py
+++ b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py
@@ -6,14 +6,10 @@ Create Date: 2024-03-14 04:54:56.679506
"""
from alembic import op
-from sqlalchemy.dialects import postgresql
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '563cf8bf777b'
down_revision = 'b5429b71023c'
@@ -23,35 +19,19 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('tool_files', schema=None) as batch_op:
- batch_op.alter_column('conversation_id',
- existing_type=postgresql.UUID(),
- nullable=True)
- else:
- with op.batch_alter_table('tool_files', schema=None) as batch_op:
- batch_op.alter_column('conversation_id',
- existing_type=models.types.StringUUID(),
- nullable=True)
+ with op.batch_alter_table('tool_files', schema=None) as batch_op:
+ batch_op.alter_column('conversation_id',
+ existing_type=models.types.StringUUID(),
+ nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('tool_files', schema=None) as batch_op:
- batch_op.alter_column('conversation_id',
- existing_type=postgresql.UUID(),
- nullable=False)
- else:
- with op.batch_alter_table('tool_files', schema=None) as batch_op:
- batch_op.alter_column('conversation_id',
- existing_type=models.types.StringUUID(),
- nullable=False)
+ with op.batch_alter_table('tool_files', schema=None) as batch_op:
+ batch_op.alter_column('conversation_id',
+ existing_type=models.types.StringUUID(),
+ nullable=False)
# ### end Alembic commands ###
diff --git a/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py
index 1ace8ea5a0..01d7d5ba21 100644
--- a/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py
+++ b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py
@@ -48,12 +48,9 @@ def upgrade():
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False)
- if _is_pg(conn):
- with op.batch_alter_table('datasets', schema=None) as batch_op:
- batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True))
- else:
- with op.batch_alter_table('datasets', schema=None) as batch_op:
- batch_op.add_column(sa.Column('collection_binding_id', models.types.StringUUID(), nullable=True))
+
+ with op.batch_alter_table('datasets', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('collection_binding_id', models.types.StringUUID(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py
index 457338ef42..0faa48f535 100644
--- a/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py
+++ b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '714aafe25d39'
down_revision = 'f2a6fc85e260'
@@ -23,16 +20,9 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
- batch_op.add_column(sa.Column('annotation_question', sa.Text(), nullable=False))
- batch_op.add_column(sa.Column('annotation_content', sa.Text(), nullable=False))
- else:
- with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
- batch_op.add_column(sa.Column('annotation_question', models.types.LongText(), nullable=False))
- batch_op.add_column(sa.Column('annotation_content', models.types.LongText(), nullable=False))
+ with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('annotation_question', models.types.LongText(), nullable=False))
+ batch_op.add_column(sa.Column('annotation_content', models.types.LongText(), nullable=False))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py
index 7bcd1a1be3..aa7b4a21e2 100644
--- a/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py
+++ b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '77e83833755c'
down_revision = '6dcb43972bdc'
@@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('retriever_resource', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('retriever_resource', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('retriever_resource', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py b/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py
index 3c0aa082d5..34a17697d3 100644
--- a/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py
+++ b/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py
@@ -27,7 +27,6 @@ def upgrade():
conn = op.get_bind()
if _is_pg(conn):
- # PostgreSQL: Keep original syntax
op.create_table('tool_providers',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', postgresql.UUID(), nullable=False),
@@ -40,7 +39,6 @@ def upgrade():
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
)
else:
- # MySQL: Use compatible syntax
op.create_table('tool_providers',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
@@ -52,12 +50,9 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
)
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('sensitive_word_avoidance', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('sensitive_word_avoidance', models.types.LongText(), nullable=True))
+
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('sensitive_word_avoidance', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py b/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py
index beea90b384..884839c010 100644
--- a/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py
+++ b/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '88072f0caa04'
down_revision = '246ba09cbbdb'
@@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('tenants', schema=None) as batch_op:
- batch_op.add_column(sa.Column('custom_config', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('tenants', schema=None) as batch_op:
- batch_op.add_column(sa.Column('custom_config', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('tenants', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('custom_config', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/89c7899ca936_.py b/api/migrations/versions/89c7899ca936_.py
index 2420710e74..d26f1e82d6 100644
--- a/api/migrations/versions/89c7899ca936_.py
+++ b/api/migrations/versions/89c7899ca936_.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '89c7899ca936'
down_revision = '187385f442fc'
@@ -23,39 +20,21 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('sites', schema=None) as batch_op:
- batch_op.alter_column('description',
- existing_type=sa.VARCHAR(length=255),
- type_=sa.Text(),
- existing_nullable=True)
- else:
- with op.batch_alter_table('sites', schema=None) as batch_op:
- batch_op.alter_column('description',
- existing_type=sa.VARCHAR(length=255),
- type_=models.types.LongText(),
- existing_nullable=True)
+ with op.batch_alter_table('sites', schema=None) as batch_op:
+ batch_op.alter_column('description',
+ existing_type=sa.VARCHAR(length=255),
+ type_=models.types.LongText(),
+ existing_nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('sites', schema=None) as batch_op:
- batch_op.alter_column('description',
- existing_type=sa.Text(),
- type_=sa.VARCHAR(length=255),
- existing_nullable=True)
- else:
- with op.batch_alter_table('sites', schema=None) as batch_op:
- batch_op.alter_column('description',
- existing_type=models.types.LongText(),
- type_=sa.VARCHAR(length=255),
- existing_nullable=True)
+ with op.batch_alter_table('sites', schema=None) as batch_op:
+ batch_op.alter_column('description',
+ existing_type=models.types.LongText(),
+ type_=sa.VARCHAR(length=255),
+ existing_nullable=True)
# ### end Alembic commands ###
diff --git a/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py b/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py
index 111e81240b..6022ea2c20 100644
--- a/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py
+++ b/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '8ec536f3c800'
down_revision = 'ad472b61a054'
@@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.add_column(sa.Column('credentials_str', sa.Text(), nullable=False))
- else:
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.add_column(sa.Column('credentials_str', models.types.LongText(), nullable=False))
+ with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('credentials_str', models.types.LongText(), nullable=False))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py
index 1c1c6cacbb..9d6d40114d 100644
--- a/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py
+++ b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py
@@ -57,12 +57,9 @@ def upgrade():
batch_op.create_index('message_file_created_by_idx', ['created_by'], unique=False)
batch_op.create_index('message_file_message_idx', ['message_id'], unique=False)
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('file_upload', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('file_upload', models.types.LongText(), nullable=True))
+
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('file_upload', models.types.LongText(), nullable=True))
if _is_pg(conn):
with op.batch_alter_table('upload_files', schema=None) as batch_op:
diff --git a/api/migrations/versions/9f4e3427ea84_add_created_by_role.py b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py
index 5d29d354f3..0b3f92a12e 100644
--- a/api/migrations/versions/9f4e3427ea84_add_created_by_role.py
+++ b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py
@@ -24,7 +24,6 @@ def upgrade():
conn = op.get_bind()
if _is_pg(conn):
- # PostgreSQL: Keep original syntax
with op.batch_alter_table('pinned_conversations', schema=None) as batch_op:
batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'::character varying"), nullable=False))
batch_op.drop_index('pinned_conversation_conversation_idx')
@@ -35,7 +34,6 @@ def upgrade():
batch_op.drop_index('saved_message_message_idx')
batch_op.create_index('saved_message_message_idx', ['app_id', 'message_id', 'created_by_role', 'created_by'], unique=False)
else:
- # MySQL: Use compatible syntax
with op.batch_alter_table('pinned_conversations', schema=None) as batch_op:
batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'"), nullable=False))
batch_op.drop_index('pinned_conversation_conversation_idx')
diff --git a/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py b/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py
index 616cb2f163..c8747a51f7 100644
--- a/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py
+++ b/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = 'a5b56fb053ef'
down_revision = 'd3d503a3471c'
@@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('speech_to_text', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('speech_to_text', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('speech_to_text', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py
index 900ff78036..f56aeb7e66 100644
--- a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py
+++ b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = 'a9836e3baeee'
down_revision = '968fff4c0ab9'
@@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('external_data_tools', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('external_data_tools', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('external_data_tools', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/b24be59fbb04_.py b/api/migrations/versions/b24be59fbb04_.py
index b0a6d10d8c..ae91eaf1bc 100644
--- a/api/migrations/versions/b24be59fbb04_.py
+++ b/api/migrations/versions/b24be59fbb04_.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = 'b24be59fbb04'
down_revision = 'de95f5c77138'
@@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('text_to_speech', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('text_to_speech', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('text_to_speech', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py
index 772395c25b..c02c24c23f 100644
--- a/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py
+++ b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = 'b3a09c049e8e'
down_revision = '2e9819ca5b28'
@@ -23,20 +20,11 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
- batch_op.add_column(sa.Column('chat_prompt_config', sa.Text(), nullable=True))
- batch_op.add_column(sa.Column('completion_prompt_config', sa.Text(), nullable=True))
- batch_op.add_column(sa.Column('dataset_configs', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
- batch_op.add_column(sa.Column('chat_prompt_config', models.types.LongText(), nullable=True))
- batch_op.add_column(sa.Column('completion_prompt_config', models.types.LongText(), nullable=True))
- batch_op.add_column(sa.Column('dataset_configs', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
+ batch_op.add_column(sa.Column('chat_prompt_config', models.types.LongText(), nullable=True))
+ batch_op.add_column(sa.Column('completion_prompt_config', models.types.LongText(), nullable=True))
+ batch_op.add_column(sa.Column('dataset_configs', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py
index 76be794ff4..fe51d1c78d 100644
--- a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py
+++ b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py
@@ -7,7 +7,6 @@ Create Date: 2024-06-17 10:01:00.255189
"""
import sqlalchemy as sa
from alembic import op
-from sqlalchemy.dialects import postgresql
import models.types
diff --git a/api/migrations/versions/e1901f623fd0_add_annotation_reply.py b/api/migrations/versions/e1901f623fd0_add_annotation_reply.py
index 9e02ec5d84..36e934f0fc 100644
--- a/api/migrations/versions/e1901f623fd0_add_annotation_reply.py
+++ b/api/migrations/versions/e1901f623fd0_add_annotation_reply.py
@@ -54,12 +54,9 @@ def upgrade():
batch_op.create_index('app_annotation_hit_histories_annotation_idx', ['annotation_id'], unique=False)
batch_op.create_index('app_annotation_hit_histories_app_idx', ['app_id'], unique=False)
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('annotation_reply', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), nullable=True))
+
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), nullable=True))
if _is_pg(conn):
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
@@ -68,54 +65,31 @@ def upgrade():
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.add_column(sa.Column('type', sa.String(length=40), server_default=sa.text("'dataset'"), nullable=False))
- if _is_pg(conn):
- with op.batch_alter_table('message_annotations', schema=None) as batch_op:
- batch_op.add_column(sa.Column('question', sa.Text(), nullable=True))
- batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False))
- batch_op.alter_column('conversation_id',
- existing_type=postgresql.UUID(),
- nullable=True)
- batch_op.alter_column('message_id',
- existing_type=postgresql.UUID(),
- nullable=True)
- else:
- with op.batch_alter_table('message_annotations', schema=None) as batch_op:
- batch_op.add_column(sa.Column('question', models.types.LongText(), nullable=True))
- batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False))
- batch_op.alter_column('conversation_id',
- existing_type=models.types.StringUUID(),
- nullable=True)
- batch_op.alter_column('message_id',
- existing_type=models.types.StringUUID(),
- nullable=True)
+ with op.batch_alter_table('message_annotations', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('question', models.types.LongText(), nullable=True))
+ batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False))
+ batch_op.alter_column('conversation_id',
+ existing_type=models.types.StringUUID(),
+ nullable=True)
+ batch_op.alter_column('message_id',
+ existing_type=models.types.StringUUID(),
+ nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
- if _is_pg(conn):
- with op.batch_alter_table('message_annotations', schema=None) as batch_op:
- batch_op.alter_column('message_id',
- existing_type=postgresql.UUID(),
- nullable=False)
- batch_op.alter_column('conversation_id',
- existing_type=postgresql.UUID(),
- nullable=False)
- batch_op.drop_column('hit_count')
- batch_op.drop_column('question')
- else:
- with op.batch_alter_table('message_annotations', schema=None) as batch_op:
- batch_op.alter_column('message_id',
- existing_type=models.types.StringUUID(),
- nullable=False)
- batch_op.alter_column('conversation_id',
- existing_type=models.types.StringUUID(),
- nullable=False)
- batch_op.drop_column('hit_count')
- batch_op.drop_column('question')
+ with op.batch_alter_table('message_annotations', schema=None) as batch_op:
+ batch_op.alter_column('message_id',
+ existing_type=models.types.StringUUID(),
+ nullable=False)
+ batch_op.alter_column('conversation_id',
+ existing_type=models.types.StringUUID(),
+ nullable=False)
+ batch_op.drop_column('hit_count')
+ batch_op.drop_column('question')
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.drop_column('type')
diff --git a/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py b/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py
index 02098e91c1..ac1c14e50c 100644
--- a/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py
+++ b/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py
@@ -12,9 +12,6 @@ from sqlalchemy.dialects import postgresql
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = 'f2a6fc85e260'
down_revision = '46976cc39132'
@@ -24,16 +21,9 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
- batch_op.add_column(sa.Column('message_id', postgresql.UUID(), nullable=False))
- batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False)
- else:
- with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
- batch_op.add_column(sa.Column('message_id', models.types.StringUUID(), nullable=False))
- batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False)
+ with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('message_id', models.types.StringUUID(), nullable=False))
+ batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False)
# ### end Alembic commands ###
diff --git a/api/models/__init__.py b/api/models/__init__.py
index f648a60ace..1d5d604ba7 100644
--- a/api/models/__init__.py
+++ b/api/models/__init__.py
@@ -37,6 +37,7 @@ from .enums import (
from .execution_extra_content import ExecutionExtraContent, HumanInputContent
from .human_input import HumanInputForm
from .model import (
+ AccountTrialAppRecord,
ApiRequest,
ApiToken,
App,
@@ -49,6 +50,7 @@ from .model import (
DatasetRetrieverResource,
DifySetup,
EndUser,
+ ExporleBanner,
IconType,
InstalledApp,
Message,
@@ -62,7 +64,9 @@ from .model import (
Site,
Tag,
TagBinding,
+ TenantCreditPool,
TraceAppConfig,
+ TrialApp,
UploadFile,
)
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
@@ -101,6 +105,7 @@ from .workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
+ WorkflowArchiveLog,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionOffload,
WorkflowNodeExecutionTriggeredFrom,
@@ -115,6 +120,7 @@ __all__ = [
"Account",
"AccountIntegrate",
"AccountStatus",
+ "AccountTrialAppRecord",
"ApiRequest",
"ApiToken",
"ApiToolProvider",
@@ -152,6 +158,7 @@ __all__ = [
"Embedding",
"EndUser",
"ExecutionExtraContent",
+ "ExporleBanner",
"ExternalKnowledgeApis",
"ExternalKnowledgeBindings",
"HumanInputContent",
@@ -182,6 +189,7 @@ __all__ = [
"Tenant",
"TenantAccountJoin",
"TenantAccountRole",
+ "TenantCreditPool",
"TenantDefaultModel",
"TenantPreferredModelProvider",
"TenantStatus",
@@ -191,6 +199,7 @@ __all__ = [
"ToolLabelBinding",
"ToolModelInvoke",
"TraceAppConfig",
+ "TrialApp",
"TriggerOAuthSystemClient",
"TriggerOAuthTenantClient",
"TriggerSubscription",
@@ -200,6 +209,7 @@ __all__ = [
"Workflow",
"WorkflowAppLog",
"WorkflowAppLogCreatedFrom",
+ "WorkflowArchiveLog",
"WorkflowNodeExecutionModel",
"WorkflowNodeExecutionOffload",
"WorkflowNodeExecutionTriggeredFrom",
diff --git a/api/models/account.py b/api/models/account.py
index 420e6adc6c..f7a9c20026 100644
--- a/api/models/account.py
+++ b/api/models/account.py
@@ -8,7 +8,7 @@ from uuid import uuid4
import sqlalchemy as sa
from flask_login import UserMixin
from sqlalchemy import DateTime, String, func, select
-from sqlalchemy.orm import Mapped, Session, mapped_column
+from sqlalchemy.orm import Mapped, Session, mapped_column, validates
from typing_extensions import deprecated
from .base import TypeBase
@@ -116,6 +116,12 @@ class Account(UserMixin, TypeBase):
role: TenantAccountRole | None = field(default=None, init=False)
_current_tenant: "Tenant | None" = field(default=None, init=False)
+ @validates("status")
+ def _normalize_status(self, _key: str, value: str | AccountStatus) -> str:
+ if isinstance(value, AccountStatus):
+ return value.value
+ return value
+
@property
def is_password_set(self):
return self.password is not None
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 445ac6086f..62f11b8c72 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -1149,7 +1149,7 @@ class DatasetCollectionBinding(TypeBase):
)
-class TidbAuthBinding(Base):
+class TidbAuthBinding(TypeBase):
__tablename__ = "tidb_auth_bindings"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
@@ -1158,7 +1158,13 @@ class TidbAuthBinding(Base):
sa.Index("tidb_auth_bindings_created_at_idx", "created_at"),
sa.Index("tidb_auth_bindings_status_idx", "status"),
)
- id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()))
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ primary_key=True,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1166,7 +1172,9 @@ class TidbAuthBinding(Base):
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
account: Mapped[str] = mapped_column(String(255), nullable=False)
password: Mapped[str] = mapped_column(String(255), nullable=False)
- created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
+ created_at: Mapped[datetime] = mapped_column(
+ DateTime, nullable=False, server_default=func.current_timestamp(), init=False
+ )
class Whitelist(TypeBase):
diff --git a/api/models/model.py b/api/models/model.py
index 76efa0d989..39bb52d668 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
import re
import uuid
@@ -5,13 +7,13 @@ from collections.abc import Mapping, Sequence
from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
-from typing import TYPE_CHECKING, Any, Literal, Optional, cast
+from typing import TYPE_CHECKING, Any, Literal, cast
from uuid import uuid4
import sqlalchemy as sa
from flask import request
-from flask_login import UserMixin
-from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
+from flask_login import UserMixin # type: ignore[import-untyped]
+from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
@@ -54,7 +56,7 @@ class AppMode(StrEnum):
RAG_PIPELINE = "rag-pipeline"
@classmethod
- def value_of(cls, value: str) -> "AppMode":
+ def value_of(cls, value: str) -> AppMode:
"""
Get value of given mode.
@@ -70,6 +72,7 @@ class AppMode(StrEnum):
class IconType(StrEnum):
IMAGE = auto()
EMOJI = auto()
+ LINK = auto()
class App(Base):
@@ -81,7 +84,7 @@ class App(Base):
name: Mapped[str] = mapped_column(String(255))
description: Mapped[str] = mapped_column(LongText, default=sa.text("''"))
mode: Mapped[str] = mapped_column(String(255))
- icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji
+ icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji, link
icon = mapped_column(String(255))
icon_background: Mapped[str | None] = mapped_column(String(255))
app_model_config_id = mapped_column(StringUUID, nullable=True)
@@ -120,19 +123,19 @@ class App(Base):
return ""
@property
- def site(self) -> Optional["Site"]:
+ def site(self) -> Site | None:
site = db.session.query(Site).where(Site.app_id == self.id).first()
return site
@property
- def app_model_config(self) -> Optional["AppModelConfig"]:
+ def app_model_config(self) -> AppModelConfig | None:
if self.app_model_config_id:
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
return None
@property
- def workflow(self) -> Optional["Workflow"]:
+ def workflow(self) -> Workflow | None:
if self.workflow_id:
from .workflow import Workflow
@@ -287,7 +290,7 @@ class App(Base):
return deleted_tools
@property
- def tags(self) -> list["Tag"]:
+ def tags(self) -> list[Tag]:
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
@@ -312,40 +315,48 @@ class App(Base):
return None
-class AppModelConfig(Base):
+class AppModelConfig(TypeBase):
__tablename__ = "app_model_configs"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id"))
- id = mapped_column(StringUUID, default=lambda: str(uuid4()))
- app_id = mapped_column(StringUUID, nullable=False)
- provider = mapped_column(String(255), nullable=True)
- model_id = mapped_column(String(255), nullable=True)
- configs = mapped_column(sa.JSON, nullable=True)
- created_by = mapped_column(StringUUID, nullable=True)
- created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_by = mapped_column(StringUUID, nullable=True)
- updated_at = mapped_column(
- sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
+ id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ provider: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+ model_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+ configs: Mapped[Any | None] = mapped_column(sa.JSON, nullable=True, default=None)
+ created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
+ created_at: Mapped[datetime] = mapped_column(
+ sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
- opening_statement = mapped_column(LongText)
- suggested_questions = mapped_column(LongText)
- suggested_questions_after_answer = mapped_column(LongText)
- speech_to_text = mapped_column(LongText)
- text_to_speech = mapped_column(LongText)
- more_like_this = mapped_column(LongText)
- model = mapped_column(LongText)
- user_input_form = mapped_column(LongText)
- dataset_query_variable = mapped_column(String(255))
- pre_prompt = mapped_column(LongText)
- agent_mode = mapped_column(LongText)
- sensitive_word_avoidance = mapped_column(LongText)
- retriever_resource = mapped_column(LongText)
- prompt_type = mapped_column(String(255), nullable=False, server_default=sa.text("'simple'"))
- chat_prompt_config = mapped_column(LongText)
- completion_prompt_config = mapped_column(LongText)
- dataset_configs = mapped_column(LongText)
- external_data_tools = mapped_column(LongText)
- file_upload = mapped_column(LongText)
+ updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
+ updated_at: Mapped[datetime] = mapped_column(
+ sa.DateTime,
+ nullable=False,
+ server_default=func.current_timestamp(),
+ onupdate=func.current_timestamp(),
+ init=False,
+ )
+ opening_statement: Mapped[str | None] = mapped_column(LongText, default=None)
+ suggested_questions: Mapped[str | None] = mapped_column(LongText, default=None)
+ suggested_questions_after_answer: Mapped[str | None] = mapped_column(LongText, default=None)
+ speech_to_text: Mapped[str | None] = mapped_column(LongText, default=None)
+ text_to_speech: Mapped[str | None] = mapped_column(LongText, default=None)
+ more_like_this: Mapped[str | None] = mapped_column(LongText, default=None)
+ model: Mapped[str | None] = mapped_column(LongText, default=None)
+ user_input_form: Mapped[str | None] = mapped_column(LongText, default=None)
+ dataset_query_variable: Mapped[str | None] = mapped_column(String(255), default=None)
+ pre_prompt: Mapped[str | None] = mapped_column(LongText, default=None)
+ agent_mode: Mapped[str | None] = mapped_column(LongText, default=None)
+ sensitive_word_avoidance: Mapped[str | None] = mapped_column(LongText, default=None)
+ retriever_resource: Mapped[str | None] = mapped_column(LongText, default=None)
+ prompt_type: Mapped[str] = mapped_column(
+ String(255), nullable=False, server_default=sa.text("'simple'"), default="simple"
+ )
+ chat_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None)
+ completion_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None)
+ dataset_configs: Mapped[str | None] = mapped_column(LongText, default=None)
+ external_data_tools: Mapped[str | None] = mapped_column(LongText, default=None)
+ file_upload: Mapped[str | None] = mapped_column(LongText, default=None)
@property
def app(self) -> App | None:
@@ -600,6 +611,64 @@ class InstalledApp(TypeBase):
return tenant
+class TrialApp(Base):
+ __tablename__ = "trial_apps"
+ __table_args__ = (
+ sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
+ sa.Index("trial_app_app_id_idx", "app_id"),
+ sa.Index("trial_app_tenant_id_idx", "tenant_id"),
+ sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
+ )
+
+ id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
+ app_id = mapped_column(StringUUID, nullable=False)
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
+ trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
+
+ @property
+ def app(self) -> App | None:
+ app = db.session.query(App).where(App.id == self.app_id).first()
+ return app
+
+
+class AccountTrialAppRecord(Base):
+ __tablename__ = "account_trial_app_records"
+ __table_args__ = (
+ sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
+ sa.Index("account_trial_app_record_account_id_idx", "account_id"),
+ sa.Index("account_trial_app_record_app_id_idx", "app_id"),
+ sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
+ )
+ id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
+ account_id = mapped_column(StringUUID, nullable=False)
+ app_id = mapped_column(StringUUID, nullable=False)
+ count = mapped_column(sa.Integer, nullable=False, default=0)
+ created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
+
+ @property
+ def app(self) -> App | None:
+ app = db.session.query(App).where(App.id == self.app_id).first()
+ return app
+
+ @property
+ def user(self) -> Account | None:
+ user = db.session.query(Account).where(Account.id == self.account_id).first()
+ return user
+
+
+class ExporleBanner(Base):
+ __tablename__ = "exporle_banners"
+ __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
+ id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
+ content = mapped_column(sa.JSON, nullable=False)
+ link = mapped_column(String(255), nullable=False)
+ sort = mapped_column(sa.Integer, nullable=False)
+ status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"))
+ created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
+ language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
+
+
class OAuthProviderApp(TypeBase):
"""
Globally shared OAuth provider app information.
@@ -749,8 +818,8 @@ class Conversation(Base):
override_model_configs = json.loads(self.override_model_configs)
if "model" in override_model_configs:
- app_model_config = AppModelConfig()
- app_model_config = app_model_config.from_model_config_dict(override_model_configs)
+ # where is app_id?
+ app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs)
model_config = app_model_config.to_dict()
else:
model_config["configs"] = override_model_configs
@@ -967,6 +1036,7 @@ class Message(Base):
Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"),
Index("message_created_at_idx", "created_at"),
Index("message_app_mode_idx", "app_mode"),
+ Index("message_created_at_id_idx", "created_at", "id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
@@ -1195,7 +1265,7 @@ class Message(Base):
return json.loads(self.message_metadata) if self.message_metadata else {}
@property
- def agent_thoughts(self) -> list["MessageAgentThought"]:
+ def agent_thoughts(self) -> list[MessageAgentThought]:
return (
db.session.query(MessageAgentThought)
.where(MessageAgentThought.message_id == self.id)
@@ -1316,7 +1386,7 @@ class Message(Base):
}
@classmethod
- def from_dict(cls, data: dict[str, Any]) -> "Message":
+ def from_dict(cls, data: dict[str, Any]) -> Message:
return cls(
id=data["id"],
app_id=data["app_id"],
@@ -1429,15 +1499,20 @@ class MessageAnnotation(Base):
app_id: Mapped[str] = mapped_column(StringUUID)
conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
message_id: Mapped[str | None] = mapped_column(StringUUID)
- question = mapped_column(LongText, nullable=True)
- content = mapped_column(LongText, nullable=False)
+ question: Mapped[str] = mapped_column(LongText, nullable=False)
+ content: Mapped[str] = mapped_column(LongText, nullable=False)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
- account_id = mapped_column(StringUUID, nullable=False)
- created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = mapped_column(
+ account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
+ @property
+ def question_text(self) -> str:
+ """Return a non-null question string, falling back to the answer content."""
+ return self.question or self.content
+
@property
def account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()
@@ -1449,7 +1524,7 @@ class MessageAnnotation(Base):
return account
-class AppAnnotationHitHistory(Base):
+class AppAnnotationHitHistory(TypeBase):
__tablename__ = "app_annotation_hit_histories"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
@@ -1459,17 +1534,19 @@ class AppAnnotationHitHistory(Base):
sa.Index("app_annotation_hit_histories_message_idx", "message_id"),
)
- id = mapped_column(StringUUID, default=lambda: str(uuid4()))
- app_id = mapped_column(StringUUID, nullable=False)
+ id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
- source = mapped_column(LongText, nullable=False)
- question = mapped_column(LongText, nullable=False)
- account_id = mapped_column(StringUUID, nullable=False)
- created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
- score = mapped_column(Float, nullable=False, server_default=sa.text("0"))
- message_id = mapped_column(StringUUID, nullable=False)
- annotation_question = mapped_column(LongText, nullable=False)
- annotation_content = mapped_column(LongText, nullable=False)
+ source: Mapped[str] = mapped_column(LongText, nullable=False)
+ question: Mapped[str] = mapped_column(LongText, nullable=False)
+ account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ created_at: Mapped[datetime] = mapped_column(
+ sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
+ )
+ score: Mapped[float] = mapped_column(Float, nullable=False, server_default=sa.text("0"))
+ message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ annotation_question: Mapped[str] = mapped_column(LongText, nullable=False)
+ annotation_content: Mapped[str] = mapped_column(LongText, nullable=False)
@property
def account(self):
@@ -1538,7 +1615,7 @@ class OperationLog(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
action: Mapped[str] = mapped_column(String(255), nullable=False)
- content: Mapped[Any] = mapped_column(sa.JSON)
+ content: Mapped[Any | None] = mapped_column(sa.JSON, nullable=True)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
@@ -1845,7 +1922,7 @@ class MessageChain(TypeBase):
)
-class MessageAgentThought(Base):
+class MessageAgentThought(TypeBase):
__tablename__ = "message_agent_thoughts"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
@@ -1853,34 +1930,42 @@ class MessageAgentThought(Base):
sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
)
- id = mapped_column(StringUUID, default=lambda: str(uuid4()))
- message_id = mapped_column(StringUUID, nullable=False)
- message_chain_id = mapped_column(StringUUID, nullable=True)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
+ message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
- thought = mapped_column(LongText, nullable=True)
- tool = mapped_column(LongText, nullable=True)
- tool_labels_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
- tool_meta_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
- tool_input = mapped_column(LongText, nullable=True)
- observation = mapped_column(LongText, nullable=True)
+ created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
+ created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
+ thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ tool: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ tool_labels_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
+ tool_meta_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
+ tool_input: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ observation: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
# plugin_id = mapped_column(StringUUID, nullable=True) ## for future design
- tool_process_data = mapped_column(LongText, nullable=True)
- message = mapped_column(LongText, nullable=True)
- message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
- message_unit_price = mapped_column(sa.Numeric, nullable=True)
- message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
- message_files = mapped_column(LongText, nullable=True)
- answer = mapped_column(LongText, nullable=True)
- answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
- answer_unit_price = mapped_column(sa.Numeric, nullable=True)
- answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
- tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
- total_price = mapped_column(sa.Numeric, nullable=True)
- currency = mapped_column(String(255), nullable=True)
- latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
- created_by_role = mapped_column(String(255), nullable=False)
- created_by = mapped_column(StringUUID, nullable=False)
- created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
+ tool_process_data: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ message: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
+ message_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
+ message_price_unit: Mapped[Decimal] = mapped_column(
+ sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
+ )
+ message_files: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ answer: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
+ answer_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
+ answer_price_unit: Mapped[Decimal] = mapped_column(
+ sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
+ )
+ tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
+ total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
+ currency: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+ latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None)
+ created_at: Mapped[datetime] = mapped_column(
+ sa.DateTime, nullable=False, init=False, server_default=sa.func.current_timestamp()
+ )
@property
def files(self) -> list[Any]:
@@ -2075,3 +2160,35 @@ class TraceAppConfig(TypeBase):
"created_at": str(self.created_at) if self.created_at else None,
"updated_at": str(self.updated_at) if self.updated_at else None,
}
+
+
+class TenantCreditPool(TypeBase):
+ __tablename__ = "tenant_credit_pools"
+ __table_args__ = (
+ sa.PrimaryKeyConstraint("id", name="tenant_credit_pool_pkey"),
+ sa.Index("tenant_credit_pool_tenant_id_idx", "tenant_id"),
+ sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"),
+ )
+
+ id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"), init=False)
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
+ quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
+ quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
+ created_at: Mapped[datetime] = mapped_column(
+ sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), init=False
+ )
+ updated_at: Mapped[datetime] = mapped_column(
+ sa.DateTime,
+ nullable=False,
+ server_default=func.current_timestamp(),
+ onupdate=func.current_timestamp(),
+ init=False,
+ )
+
+ @property
+ def remaining_credits(self) -> int:
+ return max(0, self.quota_limit - self.quota_used)
+
+ def has_sufficient_credits(self, required_credits: int) -> bool:
+ return self.remaining_credits >= required_credits
diff --git a/api/models/provider.py b/api/models/provider.py
index 2afd8c5329..441b54c797 100644
--- a/api/models/provider.py
+++ b/api/models/provider.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from datetime import datetime
from enum import StrEnum, auto
from functools import cached_property
@@ -19,7 +21,7 @@ class ProviderType(StrEnum):
SYSTEM = auto()
@staticmethod
- def value_of(value: str) -> "ProviderType":
+ def value_of(value: str) -> ProviderType:
for member in ProviderType:
if member.value == value:
return member
@@ -37,7 +39,7 @@ class ProviderQuotaType(StrEnum):
"""hosted trial quota"""
@staticmethod
- def value_of(value: str) -> "ProviderQuotaType":
+ def value_of(value: str) -> ProviderQuotaType:
for member in ProviderQuotaType:
if member.value == value:
return member
@@ -76,7 +78,7 @@ class Provider(TypeBase):
quota_type: Mapped[str | None] = mapped_column(String(40), nullable=True, server_default=text("''"), default="")
quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=None)
- quota_used: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, default=0)
+ quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=0)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
diff --git a/api/models/tools.py b/api/models/tools.py
index e4f9bcb582..e7b98dcf27 100644
--- a/api/models/tools.py
+++ b/api/models/tools.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
from datetime import datetime
from decimal import Decimal
@@ -167,11 +169,11 @@ class ApiToolProvider(TypeBase):
)
@property
- def schema_type(self) -> "ApiProviderSchemaType":
+ def schema_type(self) -> ApiProviderSchemaType:
return ApiProviderSchemaType.value_of(self.schema_type_str)
@property
- def tools(self) -> list["ApiToolBundle"]:
+ def tools(self) -> list[ApiToolBundle]:
return [ApiToolBundle.model_validate(tool) for tool in json.loads(self.tools_str)]
@property
@@ -267,7 +269,7 @@ class WorkflowToolProvider(TypeBase):
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
@property
- def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]:
+ def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
return [
WorkflowToolParameterConfiguration.model_validate(config)
for config in json.loads(self.parameter_configuration)
@@ -359,7 +361,7 @@ class MCPToolProvider(TypeBase):
except (json.JSONDecodeError, TypeError):
return []
- def to_entity(self) -> "MCPProviderEntity":
+ def to_entity(self) -> MCPProviderEntity:
"""Convert to domain entity"""
from core.entities.mcp_provider import MCPProviderEntity
@@ -533,5 +535,5 @@ class DeprecatedPublishedAppTool(TypeBase):
)
@property
- def description_i18n(self) -> "I18nObject":
+ def description_i18n(self) -> I18nObject:
return I18nObject.model_validate(json.loads(self.description))
diff --git a/api/models/trigger.py b/api/models/trigger.py
index 87e2a5ccfc..209345eb84 100644
--- a/api/models/trigger.py
+++ b/api/models/trigger.py
@@ -415,7 +415,7 @@ class AppTrigger(TypeBase):
node_id: Mapped[str | None] = mapped_column(String(64), nullable=False)
trigger_type: Mapped[str] = mapped_column(EnumText(AppTriggerType, length=50), nullable=False)
title: Mapped[str] = mapped_column(String(255), nullable=False)
- provider_name: Mapped[str] = mapped_column(String(255), server_default="", default="") # why it is nullable?
+ provider_name: Mapped[str | None] = mapped_column(String(255), nullable=True, server_default="", default="")
status: Mapped[str] = mapped_column(
EnumText(AppTriggerStatus, length=50), nullable=False, default=AppTriggerStatus.ENABLED
)
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 5522a16ab7..94e0881bd1 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -45,7 +45,7 @@ if TYPE_CHECKING:
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter
-from core.variables import SecretVariable, Segment, SegmentType, Variable
+from core.variables import SecretVariable, Segment, SegmentType, VariableBase
from factories import variable_factory
from libs import helper
@@ -177,8 +177,8 @@ class Workflow(Base): # bug
graph: str,
features: str,
created_by: str,
- environment_variables: Sequence[Variable],
- conversation_variables: Sequence[Variable],
+ environment_variables: Sequence[VariableBase],
+ conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list[dict],
marked_name: str = "",
marked_comment: str = "",
@@ -227,8 +227,7 @@ class Workflow(Base): # bug
#
# Currently, the following functions / methods would mutate the returned dict:
#
- # - `_get_graph_and_variable_pool_of_single_iteration`.
- # - `_get_graph_and_variable_pool_of_single_loop`.
+ # - `_get_graph_and_variable_pool_for_single_node_run`.
return json.loads(self.graph) if self.graph else {}
def get_node_config_by_id(self, node_id: str) -> Mapping[str, Any]:
@@ -451,7 +450,7 @@ class Workflow(Base): # bug
# decrypt secret variables value
def decrypt_func(
- var: Variable,
+ var: VariableBase,
) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
@@ -467,7 +466,7 @@ class Workflow(Base): # bug
return decrypted_results
@environment_variables.setter
- def environment_variables(self, value: Sequence[Variable]):
+ def environment_variables(self, value: Sequence[VariableBase]):
if not value:
self._environment_variables = "{}"
return
@@ -491,7 +490,7 @@ class Workflow(Base): # bug
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
# encrypt secret variables value
- def encrypt_func(var: Variable) -> Variable:
+ def encrypt_func(var: VariableBase) -> VariableBase:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
else:
@@ -521,7 +520,7 @@ class Workflow(Base): # bug
return result
@property
- def conversation_variables(self) -> Sequence[Variable]:
+ def conversation_variables(self) -> Sequence[VariableBase]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
@@ -531,7 +530,7 @@ class Workflow(Base): # bug
return results
@conversation_variables.setter
- def conversation_variables(self, value: Sequence[Variable]):
+ def conversation_variables(self, value: Sequence[VariableBase]):
self._conversation_variables = json.dumps(
{var.name: var.model_dump() for var in value},
ensure_ascii=False,
@@ -601,6 +600,7 @@ class WorkflowRun(Base):
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="workflow_run_pkey"),
sa.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"),
+ sa.Index("workflow_run_created_at_id_idx", "created_at", "id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
@@ -793,11 +793,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
return (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
- "workflow_node_execution_workflow_run_idx",
- "tenant_id",
- "app_id",
- "workflow_id",
- "triggered_from",
+ "workflow_node_execution_workflow_run_id_idx",
"workflow_run_id",
),
Index(
@@ -1179,6 +1175,69 @@ class WorkflowAppLog(TypeBase):
}
+class WorkflowArchiveLog(TypeBase):
+ """
+ Workflow archive log.
+
+ Stores essential workflow run snapshot data for archived app logs.
+
+ Field sources:
+ - Shared fields (tenant/app/workflow/run ids, created_by*): from WorkflowRun for consistency.
+ - log_* fields: from WorkflowAppLog when present; null if the run has no app log.
+ - run_* fields: workflow run snapshot fields from WorkflowRun.
+ - trigger_metadata: snapshot from WorkflowTriggerLog when present.
+ """
+
+ __tablename__ = "workflow_archive_logs"
+ __table_args__ = (
+ sa.PrimaryKeyConstraint("id", name="workflow_archive_log_pkey"),
+ sa.Index("workflow_archive_log_app_idx", "tenant_id", "app_id"),
+ sa.Index("workflow_archive_log_workflow_run_id_idx", "workflow_run_id"),
+ sa.Index("workflow_archive_log_run_created_at_idx", "run_created_at"),
+ )
+
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
+
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
+ created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
+
+ log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
+ log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
+ log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
+
+ run_version: Mapped[str] = mapped_column(String(255), nullable=False)
+ run_status: Mapped[str] = mapped_column(String(255), nullable=False)
+ run_triggered_from: Mapped[str] = mapped_column(String(255), nullable=False)
+ run_error: Mapped[str | None] = mapped_column(LongText, nullable=True)
+ run_elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
+ run_total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
+ run_total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
+ run_created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
+ run_finished_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
+ run_exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
+
+ trigger_metadata: Mapped[str | None] = mapped_column(LongText, nullable=True)
+ archived_at: Mapped[datetime] = mapped_column(
+ DateTime, nullable=False, server_default=func.current_timestamp(), init=False
+ )
+
+ @property
+ def workflow_run_summary(self) -> dict[str, Any]:
+ return {
+ "id": self.workflow_run_id,
+ "status": self.run_status,
+ "triggered_from": self.run_triggered_from,
+ "elapsed_time": self.run_elapsed_time,
+ "total_tokens": self.run_total_tokens,
+ }
+
+
class ConversationVariable(TypeBase):
__tablename__ = "workflow_conversation_variables"
@@ -1194,7 +1253,7 @@ class ConversationVariable(TypeBase):
)
@classmethod
- def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> "ConversationVariable":
+ def from_variable(cls, *, app_id: str, conversation_id: str, variable: VariableBase) -> "ConversationVariable":
obj = cls(
id=variable.id,
app_id=app_id,
@@ -1203,7 +1262,7 @@ class ConversationVariable(TypeBase):
)
return obj
- def to_variable(self) -> Variable:
+ def to_variable(self) -> VariableBase:
mapping = json.loads(self.data)
return variable_factory.build_conversation_variable_from_mapping(mapping)
@@ -1519,6 +1578,7 @@ class WorkflowDraftVariable(Base):
file_id: str | None = None,
) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
+ variable.id = str(uuid4())
variable.created_at = naive_utc_now()
variable.updated_at = naive_utc_now()
variable.description = description
diff --git a/api/pyproject.toml b/api/pyproject.toml
index dbc6a2eb83..575c1434c5 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "dify-api"
-version = "1.11.2"
+version = "1.11.4"
requires-python = ">=3.11,<3.13"
dependencies = [
@@ -31,7 +31,7 @@ dependencies = [
"gunicorn~=23.0.0",
"httpx[socks]~=0.27.0",
"jieba==0.42.1",
- "json-repair>=0.41.1",
+ "json-repair>=0.55.1",
"jsonschema>=4.25.1",
"langfuse~=2.51.3",
"langsmith~=0.1.77",
@@ -64,7 +64,7 @@ dependencies = [
"pandas[excel,output-formatting,performance]~=2.2.2",
"psycogreen~=1.0.2",
"psycopg2-binary~=2.9.6",
- "pycryptodome==3.19.1",
+ "pycryptodome==3.23.0",
"pydantic~=2.11.4",
"pydantic-extra-types~=2.10.3",
"pydantic-settings~=2.11.0",
@@ -93,6 +93,7 @@ dependencies = [
"weaviate-client==4.17.0",
"apscheduler>=3.11.0",
"weave>=0.52.16",
+ "fastopenapi[flask]>=0.7.0",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.
@@ -189,7 +190,7 @@ storage = [
"opendal~=0.46.0",
"oss2==2.18.5",
"supabase~=2.18.1",
- "tos~=2.7.1",
+ "tos~=2.9.0",
]
############################################################
diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json
index 6a689b96df..007c49ddb0 100644
--- a/api/pyrightconfig.json
+++ b/api/pyrightconfig.json
@@ -8,6 +8,7 @@
],
"typeCheckingMode": "strict",
"allowedUntypedLibraries": [
+ "fastopenapi",
"flask_restx",
"flask_login",
"opentelemetry.instrumentation.celery",
diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py
index 641e8a488e..6446eb0d6e 100644
--- a/api/repositories/api_workflow_node_execution_repository.py
+++ b/api/repositories/api_workflow_node_execution_repository.py
@@ -14,8 +14,10 @@ from dataclasses import dataclass
from datetime import datetime
from typing import Protocol
+from sqlalchemy.orm import Session
+
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
-from models.workflow import WorkflowNodeExecutionModel
+from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
@dataclass(frozen=True)
@@ -175,6 +177,18 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
"""
...
+ def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
+ """
+ Count node executions and offloads for the given workflow run ids.
+ """
+ ...
+
+ def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
+ """
+ Delete node executions and offloads for the given workflow run ids.
+ """
+ ...
+
def delete_executions_by_app(
self,
tenant_id: str,
@@ -240,3 +254,23 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
The number of executions deleted
"""
...
+
+ def get_offloads_by_execution_ids(
+ self,
+ session: Session,
+ node_execution_ids: Sequence[str],
+ ) -> Sequence[WorkflowNodeExecutionOffload]:
+ """
+ Get offload records by node execution IDs.
+
+ This method retrieves workflow node execution offload records
+ that belong to the given node execution IDs.
+
+ Args:
+ session: The database session to use
+ node_execution_ids: List of node execution IDs to filter by
+
+ Returns:
+ A sequence of WorkflowNodeExecutionOffload instances
+ """
+ ...
diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py
index c1ed45812e..17e01a6e18 100644
--- a/api/repositories/api_workflow_run_repository.py
+++ b/api/repositories/api_workflow_run_repository.py
@@ -34,15 +34,18 @@ Example:
```
"""
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
from datetime import datetime
from typing import Protocol
+from sqlalchemy.orm import Session
+
from core.workflow.entities.pause_reason import PauseReason
+from core.workflow.enums import WorkflowType
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
-from models.workflow import WorkflowRun
+from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.types import (
AverageInteractionStats,
@@ -253,6 +256,151 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
"""
...
+ def get_runs_batch_by_time_range(
+ self,
+ start_from: datetime | None,
+ end_before: datetime,
+ last_seen: tuple[datetime, str] | None,
+ batch_size: int,
+ run_types: Sequence[WorkflowType] | None = None,
+ tenant_ids: Sequence[str] | None = None,
+ ) -> Sequence[WorkflowRun]:
+ """
+ Fetch ended workflow runs in a time window for archival and clean batching.
+ """
+ ...
+
+ def get_archived_run_ids(
+ self,
+ session: Session,
+ run_ids: Sequence[str],
+ ) -> set[str]:
+ """
+ Fetch workflow run IDs that already have archive log records.
+ """
+ ...
+
+ def get_archived_logs_by_time_range(
+ self,
+ session: Session,
+ tenant_ids: Sequence[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int,
+ ) -> Sequence[WorkflowArchiveLog]:
+ """
+ Fetch archived workflow logs by time range for restore.
+ """
+ ...
+
+ def get_archived_log_by_run_id(
+ self,
+ run_id: str,
+ ) -> WorkflowArchiveLog | None:
+ """
+ Fetch a workflow archive log by workflow run ID.
+ """
+ ...
+
+ def delete_archive_log_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> int:
+ """
+ Delete archive log by workflow run ID.
+
+ Used after restoring a workflow run to remove the archive log record,
+ allowing the run to be archived again if needed.
+
+ Args:
+ session: Database session
+ run_id: Workflow run ID
+
+ Returns:
+ Number of records deleted (0 or 1)
+ """
+ ...
+
+ def delete_runs_with_related(
+ self,
+ runs: Sequence[WorkflowRun],
+ delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
+ delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
+ ) -> dict[str, int]:
+ """
+ Delete workflow runs and their related records (node executions, offloads, app logs,
+ trigger logs, pauses, pause reasons).
+ """
+ ...
+
+ def get_pause_records_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> Sequence[WorkflowPause]:
+ """
+ Fetch workflow pause records by workflow run ID.
+ """
+ ...
+
+ def get_pause_reason_records_by_run_id(
+ self,
+ session: Session,
+ pause_ids: Sequence[str],
+ ) -> Sequence[WorkflowPauseReason]:
+ """
+ Fetch workflow pause reason records by pause IDs.
+ """
+ ...
+
+ def get_app_logs_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> Sequence[WorkflowAppLog]:
+ """
+ Fetch workflow app logs by workflow run ID.
+ """
+ ...
+
+ def create_archive_logs(
+ self,
+ session: Session,
+ run: WorkflowRun,
+ app_logs: Sequence[WorkflowAppLog],
+ trigger_metadata: str | None,
+ ) -> int:
+ """
+ Create archive log records for a workflow run.
+ """
+ ...
+
+ def get_archived_runs_by_time_range(
+ self,
+ session: Session,
+ tenant_ids: Sequence[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int,
+ ) -> Sequence[WorkflowRun]:
+ """
+ Return workflow runs that already have archive logs, for cleanup of `workflow_runs`.
+ """
+ ...
+
+ def count_runs_with_related(
+ self,
+ runs: Sequence[WorkflowRun],
+ count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
+ count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
+ ) -> dict[str, int]:
+ """
+ Count workflow runs and their related records (node executions, offloads, app logs,
+ trigger logs, pauses, pause reasons) without deleting data.
+ """
+ ...
+
def create_workflow_pause(
self,
workflow_run_id: str,
diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
index 1512c162d7..e74e6e26f1 100644
--- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
+++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
@@ -9,12 +9,12 @@ import json
from collections.abc import Sequence
from datetime import datetime
-from sqlalchemy import asc, delete, desc, select
+from sqlalchemy import asc, delete, desc, func, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
-from models.workflow import WorkflowNodeExecutionModel
+from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
from repositories.api_workflow_node_execution_repository import (
DifyAPIWorkflowNodeExecutionRepository,
WorkflowNodeExecutionSnapshot,
@@ -369,3 +369,85 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
result = cast(CursorResult, session.execute(stmt))
session.commit()
return result.rowcount
+
+ def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
+ """
+ Delete node executions (and offloads) for the given workflow runs using workflow_run_id.
+ """
+ if not run_ids:
+ return 0, 0
+
+ run_ids = list(run_ids)
+ run_id_filter = WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)
+ node_execution_ids = select(WorkflowNodeExecutionModel.id).where(run_id_filter)
+
+ offloads_deleted = (
+ cast(
+ CursorResult,
+ session.execute(
+ delete(WorkflowNodeExecutionOffload).where(
+ WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids)
+ )
+ ),
+ ).rowcount
+ or 0
+ )
+
+ node_executions_deleted = (
+ cast(
+ CursorResult,
+ session.execute(delete(WorkflowNodeExecutionModel).where(run_id_filter)),
+ ).rowcount
+ or 0
+ )
+
+ return node_executions_deleted, offloads_deleted
+
+ def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
+ """
+ Count node executions (and offloads) for the given workflow runs using workflow_run_id.
+ """
+ if not run_ids:
+ return 0, 0
+
+ run_ids = list(run_ids)
+ run_id_filter = WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)
+
+ node_executions_count = (
+ session.scalar(select(func.count()).select_from(WorkflowNodeExecutionModel).where(run_id_filter)) or 0
+ )
+ node_execution_ids = select(WorkflowNodeExecutionModel.id).where(run_id_filter)
+ offloads_count = (
+ session.scalar(
+ select(func.count())
+ .select_from(WorkflowNodeExecutionOffload)
+ .where(WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids))
+ )
+ or 0
+ )
+
+ return int(node_executions_count), int(offloads_count)
+
+ @staticmethod
+ def get_by_run(
+ session: Session,
+ run_id: str,
+ ) -> Sequence[WorkflowNodeExecutionModel]:
+ """
+ Fetch node executions for a run using workflow_run_id.
+ """
+ stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.workflow_run_id == run_id)
+ return list(session.scalars(stmt))
+
+ def get_offloads_by_execution_ids(
+ self,
+ session: Session,
+ node_execution_ids: Sequence[str],
+ ) -> Sequence[WorkflowNodeExecutionOffload]:
+ if not node_execution_ids:
+ return []
+
+ stmt = select(WorkflowNodeExecutionOffload).where(
+ WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids)
+ )
+ return list(session.scalars(stmt))
diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py
index cbf495dce1..00cb979e17 100644
--- a/api/repositories/sqlalchemy_api_workflow_run_repository.py
+++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py
@@ -22,7 +22,7 @@ Implementation Notes:
import json
import logging
import uuid
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
from datetime import datetime
from decimal import Decimal
from typing import Any, cast
@@ -34,7 +34,7 @@ from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
-from core.workflow.enums import WorkflowExecutionStatus
+from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.nodes.human_input.entities import FormDefinition
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
@@ -44,8 +44,7 @@ from libs.time_parser import get_time_threshold
from libs.uuid_utils import uuidv7
from models.enums import WorkflowRunTriggeredFrom
from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType
-from models.workflow import WorkflowPause as WorkflowPauseModel
-from models.workflow import WorkflowPauseReason, WorkflowRun
+from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.types import (
@@ -379,6 +378,335 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
return total_deleted
+ def get_runs_batch_by_time_range(
+ self,
+ start_from: datetime | None,
+ end_before: datetime,
+ last_seen: tuple[datetime, str] | None,
+ batch_size: int,
+ run_types: Sequence[WorkflowType] | None = None,
+ tenant_ids: Sequence[str] | None = None,
+ ) -> Sequence[WorkflowRun]:
+ """
+ Fetch ended workflow runs in a time window for archival and clean batching.
+
+ Query scope:
+ - created_at in [start_from, end_before)
+ - type in run_types (when provided)
+ - status is an ended state
+ - optional tenant_id filter and cursor (last_seen) for pagination
+ """
+ with self._session_maker() as session:
+ stmt = (
+ select(WorkflowRun)
+ .where(
+ WorkflowRun.created_at < end_before,
+ WorkflowRun.status.in_(WorkflowExecutionStatus.ended_values()),
+ )
+ .order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc())
+ .limit(batch_size)
+ )
+ if run_types is not None:
+ if not run_types:
+ return []
+ stmt = stmt.where(WorkflowRun.type.in_(run_types))
+
+ if start_from:
+ stmt = stmt.where(WorkflowRun.created_at >= start_from)
+
+ if tenant_ids:
+ stmt = stmt.where(WorkflowRun.tenant_id.in_(tenant_ids))
+
+ if last_seen:
+ stmt = stmt.where(
+ or_(
+ WorkflowRun.created_at > last_seen[0],
+ and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
+ )
+ )
+
+ return session.scalars(stmt).all()
+
+ def get_archived_run_ids(
+ self,
+ session: Session,
+ run_ids: Sequence[str],
+ ) -> set[str]:
+ if not run_ids:
+ return set()
+
+ stmt = select(WorkflowArchiveLog.workflow_run_id).where(WorkflowArchiveLog.workflow_run_id.in_(run_ids))
+ return set(session.scalars(stmt).all())
+
+ def get_archived_log_by_run_id(
+ self,
+ run_id: str,
+ ) -> WorkflowArchiveLog | None:
+ with self._session_maker() as session:
+ stmt = select(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id).limit(1)
+ return session.scalar(stmt)
+
+ def delete_archive_log_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> int:
+ stmt = delete(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id)
+ result = session.execute(stmt)
+ return cast(CursorResult, result).rowcount or 0
+
+ def get_pause_records_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> Sequence[WorkflowPause]:
+ stmt = select(WorkflowPause).where(WorkflowPause.workflow_run_id == run_id)
+ return list(session.scalars(stmt))
+
+ def get_pause_reason_records_by_run_id(
+ self,
+ session: Session,
+ pause_ids: Sequence[str],
+ ) -> Sequence[WorkflowPauseReason]:
+ if not pause_ids:
+ return []
+
+ stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids))
+ return list(session.scalars(stmt))
+
+ def delete_runs_with_related(
+ self,
+ runs: Sequence[WorkflowRun],
+ delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
+ delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
+ ) -> dict[str, int]:
+ if not runs:
+ return {
+ "runs": 0,
+ "node_executions": 0,
+ "offloads": 0,
+ "app_logs": 0,
+ "trigger_logs": 0,
+ "pauses": 0,
+ "pause_reasons": 0,
+ }
+
+ with self._session_maker() as session:
+ run_ids = [run.id for run in runs]
+ if delete_node_executions:
+ node_executions_deleted, offloads_deleted = delete_node_executions(session, runs)
+ else:
+ node_executions_deleted, offloads_deleted = 0, 0
+
+ app_logs_result = session.execute(delete(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids)))
+ app_logs_deleted = cast(CursorResult, app_logs_result).rowcount or 0
+
+ pause_stmt = select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(run_ids))
+ pause_ids = session.scalars(pause_stmt).all()
+ pause_reasons_deleted = 0
+ pauses_deleted = 0
+
+ if pause_ids:
+ pause_reasons_result = session.execute(
+ delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids))
+ )
+ pause_reasons_deleted = cast(CursorResult, pause_reasons_result).rowcount or 0
+ pauses_result = session.execute(delete(WorkflowPause).where(WorkflowPause.id.in_(pause_ids)))
+ pauses_deleted = cast(CursorResult, pauses_result).rowcount or 0
+
+ trigger_logs_deleted = delete_trigger_logs(session, run_ids) if delete_trigger_logs else 0
+
+ runs_result = session.execute(delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids)))
+ runs_deleted = cast(CursorResult, runs_result).rowcount or 0
+
+ session.commit()
+
+ return {
+ "runs": runs_deleted,
+ "node_executions": node_executions_deleted,
+ "offloads": offloads_deleted,
+ "app_logs": app_logs_deleted,
+ "trigger_logs": trigger_logs_deleted,
+ "pauses": pauses_deleted,
+ "pause_reasons": pause_reasons_deleted,
+ }
+
+ def get_app_logs_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> Sequence[WorkflowAppLog]:
+ stmt = select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == run_id)
+ return list(session.scalars(stmt))
+
+ def create_archive_logs(
+ self,
+ session: Session,
+ run: WorkflowRun,
+ app_logs: Sequence[WorkflowAppLog],
+ trigger_metadata: str | None,
+ ) -> int:
+ if not app_logs:
+ archive_log = WorkflowArchiveLog(
+ log_id=None,
+ log_created_at=None,
+ log_created_from=None,
+ tenant_id=run.tenant_id,
+ app_id=run.app_id,
+ workflow_id=run.workflow_id,
+ workflow_run_id=run.id,
+ created_by_role=run.created_by_role,
+ created_by=run.created_by,
+ run_version=run.version,
+ run_status=run.status,
+ run_triggered_from=run.triggered_from,
+ run_error=run.error,
+ run_elapsed_time=run.elapsed_time,
+ run_total_tokens=run.total_tokens,
+ run_total_steps=run.total_steps,
+ run_created_at=run.created_at,
+ run_finished_at=run.finished_at,
+ run_exceptions_count=run.exceptions_count,
+ trigger_metadata=trigger_metadata,
+ )
+ session.add(archive_log)
+ return 1
+
+ archive_logs = [
+ WorkflowArchiveLog(
+ log_id=app_log.id,
+ log_created_at=app_log.created_at,
+ log_created_from=app_log.created_from,
+ tenant_id=run.tenant_id,
+ app_id=run.app_id,
+ workflow_id=run.workflow_id,
+ workflow_run_id=run.id,
+ created_by_role=run.created_by_role,
+ created_by=run.created_by,
+ run_version=run.version,
+ run_status=run.status,
+ run_triggered_from=run.triggered_from,
+ run_error=run.error,
+ run_elapsed_time=run.elapsed_time,
+ run_total_tokens=run.total_tokens,
+ run_total_steps=run.total_steps,
+ run_created_at=run.created_at,
+ run_finished_at=run.finished_at,
+ run_exceptions_count=run.exceptions_count,
+ trigger_metadata=trigger_metadata,
+ )
+ for app_log in app_logs
+ ]
+ session.add_all(archive_logs)
+ return len(archive_logs)
+
+ def get_archived_runs_by_time_range(
+ self,
+ session: Session,
+ tenant_ids: Sequence[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int,
+ ) -> Sequence[WorkflowRun]:
+ """
+ Retrieves WorkflowRun records by joining workflow_archive_logs.
+
+ Used to identify runs that are already archived and ready for deletion.
+ """
+ stmt = (
+ select(WorkflowRun)
+ .join(WorkflowArchiveLog, WorkflowArchiveLog.workflow_run_id == WorkflowRun.id)
+ .where(
+ WorkflowArchiveLog.run_created_at >= start_date,
+ WorkflowArchiveLog.run_created_at < end_date,
+ )
+ .order_by(WorkflowArchiveLog.run_created_at.asc(), WorkflowArchiveLog.workflow_run_id.asc())
+ .limit(limit)
+ )
+ if tenant_ids:
+ stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids))
+ return list(session.scalars(stmt))
+
+ def get_archived_logs_by_time_range(
+ self,
+ session: Session,
+ tenant_ids: Sequence[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int,
+ ) -> Sequence[WorkflowArchiveLog]:
+ # Returns WorkflowArchiveLog rows directly; use this when workflow_runs may be deleted.
+ stmt = (
+ select(WorkflowArchiveLog)
+ .where(
+ WorkflowArchiveLog.run_created_at >= start_date,
+ WorkflowArchiveLog.run_created_at < end_date,
+ )
+ .order_by(WorkflowArchiveLog.run_created_at.asc(), WorkflowArchiveLog.workflow_run_id.asc())
+ .limit(limit)
+ )
+ if tenant_ids:
+ stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids))
+ return list(session.scalars(stmt))
+
+ def count_runs_with_related(
+ self,
+ runs: Sequence[WorkflowRun],
+ count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
+ count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
+ ) -> dict[str, int]:
+ if not runs:
+ return {
+ "runs": 0,
+ "node_executions": 0,
+ "offloads": 0,
+ "app_logs": 0,
+ "trigger_logs": 0,
+ "pauses": 0,
+ "pause_reasons": 0,
+ }
+
+ with self._session_maker() as session:
+ run_ids = [run.id for run in runs]
+ if count_node_executions:
+ node_executions_count, offloads_count = count_node_executions(session, runs)
+ else:
+ node_executions_count, offloads_count = 0, 0
+
+ app_logs_count = (
+ session.scalar(
+ select(func.count()).select_from(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids))
+ )
+ or 0
+ )
+
+ pause_ids = session.scalars(
+ select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(run_ids))
+ ).all()
+ pauses_count = len(pause_ids)
+ pause_reasons_count = 0
+ if pause_ids:
+ pause_reasons_count = (
+ session.scalar(
+ select(func.count())
+ .select_from(WorkflowPauseReason)
+ .where(WorkflowPauseReason.pause_id.in_(pause_ids))
+ )
+ or 0
+ )
+
+ trigger_logs_count = count_trigger_logs(session, run_ids) if count_trigger_logs else 0
+
+ return {
+ "runs": len(runs),
+ "node_executions": node_executions_count,
+ "offloads": offloads_count,
+ "app_logs": int(app_logs_count),
+ "trigger_logs": trigger_logs_count,
+ "pauses": pauses_count,
+ "pause_reasons": int(pause_reasons_count),
+ }
+
def create_workflow_pause(
self,
workflow_run_id: str,
@@ -405,9 +733,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
ValueError: If workflow_run_id is invalid or workflow run doesn't exist
RuntimeError: If workflow is already paused or in invalid state
"""
- previous_pause_model_query = select(WorkflowPauseModel).where(
- WorkflowPauseModel.workflow_run_id == workflow_run_id
- )
+ previous_pause_model_query = select(WorkflowPause).where(WorkflowPause.workflow_run_id == workflow_run_id)
with self._session_maker() as session, session.begin():
# Get the workflow run
workflow_run = session.get(WorkflowRun, workflow_run_id)
@@ -434,7 +760,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
# Upload the state file
# Create the pause record
- pause_model = WorkflowPauseModel()
+ pause_model = WorkflowPause()
pause_model.id = str(uuidv7())
pause_model.workflow_id = workflow_run.workflow_id
pause_model.workflow_run_id = workflow_run.id
@@ -643,13 +969,13 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
"""
with self._session_maker() as session, session.begin():
# Get the pause model by ID
- pause_model = session.get(WorkflowPauseModel, pause_entity.id)
+ pause_model = session.get(WorkflowPause, pause_entity.id)
if pause_model is None:
raise _WorkflowRunError(f"WorkflowPause not found: {pause_entity.id}")
self._delete_pause_model(session, pause_model)
@staticmethod
- def _delete_pause_model(session: Session, pause_model: WorkflowPauseModel):
+ def _delete_pause_model(session: Session, pause_model: WorkflowPause):
storage.delete(pause_model.state_object_key)
# Delete the pause record
@@ -684,15 +1010,15 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
_limit: int = limit or 1000
pruned_record_ids: list[str] = []
cond = or_(
- WorkflowPauseModel.created_at < expiration,
+ WorkflowPause.created_at < expiration,
and_(
- WorkflowPauseModel.resumed_at.is_not(null()),
- WorkflowPauseModel.resumed_at < resumption_expiration,
+ WorkflowPause.resumed_at.is_not(null()),
+ WorkflowPause.resumed_at < resumption_expiration,
),
)
# First, collect pause records to delete with their state files
# Expired pauses (created before expiration time)
- stmt = select(WorkflowPauseModel).where(cond).limit(_limit)
+ stmt = select(WorkflowPause).where(cond).limit(_limit)
with self._session_maker(expire_on_commit=False) as session:
# Old resumed pauses (resumed more than resumption_duration ago)
@@ -703,7 +1029,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
# Delete state files from storage
for pause in pauses_to_delete:
with self._session_maker(expire_on_commit=False) as session, session.begin():
- # todo: this issues a separate query for each WorkflowPauseModel record.
+ # todo: this issues a separate query for each WorkflowPause record.
# consider batching this lookup.
try:
storage.delete(pause.state_object_key)
@@ -964,7 +1290,7 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
def __init__(
self,
*,
- pause_model: WorkflowPauseModel,
+ pause_model: WorkflowPause,
reason_models: Sequence[WorkflowPauseReason],
pause_reasons: Sequence[PauseReason] | None = None,
human_input_form: Sequence = (),
diff --git a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py
index c828cc60c2..1f6740b066 100644
--- a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py
+++ b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py
@@ -4,8 +4,10 @@ SQLAlchemy implementation of WorkflowTriggerLogRepository.
from collections.abc import Sequence
from datetime import UTC, datetime, timedelta
+from typing import cast
-from sqlalchemy import and_, select
+from sqlalchemy import and_, delete, func, select
+from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from models.enums import WorkflowTriggerStatus
@@ -44,6 +46,11 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
return self.session.scalar(query)
+ def list_by_run_id(self, run_id: str) -> Sequence[WorkflowTriggerLog]:
+ """List trigger logs for a workflow run."""
+ query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id == run_id)
+ return list(self.session.scalars(query).all())
+
def get_failed_for_retry(
self, tenant_id: str, max_retry_count: int = 3, limit: int = 100
) -> Sequence[WorkflowTriggerLog]:
@@ -94,3 +101,37 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
.limit(1)
)
return self.session.scalar(query)
+
+ def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
+ """
+ Delete trigger logs associated with the given workflow run ids.
+
+ Args:
+ run_ids: Collection of workflow run identifiers.
+
+ Returns:
+ Number of rows deleted.
+ """
+ if not run_ids:
+ return 0
+
+ result = self.session.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids)))
+ return cast(CursorResult, result).rowcount or 0
+
+ def count_by_run_ids(self, run_ids: Sequence[str]) -> int:
+ """
+ Count trigger logs associated with the given workflow run ids.
+
+ Args:
+ run_ids: Collection of workflow run identifiers.
+
+ Returns:
+ Number of rows matched.
+ """
+ if not run_ids:
+ return 0
+
+ count = self.session.scalar(
+ select(func.count()).select_from(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids))
+ )
+ return int(count or 0)
diff --git a/api/repositories/workflow_trigger_log_repository.py b/api/repositories/workflow_trigger_log_repository.py
index e78f1db532..7f9e6b7b68 100644
--- a/api/repositories/workflow_trigger_log_repository.py
+++ b/api/repositories/workflow_trigger_log_repository.py
@@ -121,3 +121,15 @@ class WorkflowTriggerLogRepository(Protocol):
The matching WorkflowTriggerLog if present, None otherwise
"""
...
+
+ def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
+ """
+ Delete trigger logs for workflow run IDs.
+
+ Args:
+ run_ids: Workflow run IDs to delete
+
+ Returns:
+ Number of rows deleted
+ """
+ ...
diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py
index 352a84b592..be5f483b95 100644
--- a/api/schedule/clean_messages.py
+++ b/api/schedule/clean_messages.py
@@ -1,90 +1,78 @@
-import datetime
import logging
import time
import click
-from sqlalchemy.exc import SQLAlchemyError
+from redis.exceptions import LockError
import app
from configs import dify_config
-from enums.cloud_plan import CloudPlan
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
-from models.model import (
- App,
- Message,
- MessageAgentThought,
- MessageAnnotation,
- MessageChain,
- MessageFeedback,
- MessageFile,
-)
-from models.web import SavedMessage
-from services.feature_service import FeatureService
+from services.retention.conversation.messages_clean_policy import create_message_clean_policy
+from services.retention.conversation.messages_clean_service import MessagesCleanService
logger = logging.getLogger(__name__)
-@app.celery.task(queue="dataset")
+@app.celery.task(queue="retention")
def clean_messages():
- click.echo(click.style("Start clean messages.", fg="green"))
- start_at = time.perf_counter()
- plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta(
- days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING
- )
- while True:
- try:
- # Main query with join and filter
- messages = (
- db.session.query(Message)
- .where(Message.created_at < plan_sandbox_clean_message_day)
- .order_by(Message.created_at.desc())
- .limit(100)
- .all()
- )
+ """
+ Clean expired messages based on clean policy.
- except SQLAlchemyError:
- raise
- if not messages:
- break
- for message in messages:
- app = db.session.query(App).filter_by(id=message.app_id).first()
- if not app:
- logger.warning(
- "Expected App record to exist, but none was found, app_id=%s, message_id=%s",
- message.app_id,
- message.id,
- )
- continue
- features_cache_key = f"features:{app.tenant_id}"
- plan_cache = redis_client.get(features_cache_key)
- if plan_cache is None:
- features = FeatureService.get_features(app.tenant_id)
- redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
- plan = features.billing.subscription.plan
- else:
- plan = plan_cache.decode()
- if plan == CloudPlan.SANDBOX:
- # clean related message
- db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(Message).where(Message.id == message.id).delete()
- db.session.commit()
- end_at = time.perf_counter()
- click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green"))
+ This task uses MessagesCleanService to efficiently clean messages in batches.
+ The behavior depends on BILLING_ENABLED configuration:
+ - BILLING_ENABLED=True: only delete messages from sandbox tenants (with whitelist/grace period)
+ - BILLING_ENABLED=False: delete all messages within the time range
+ """
+ click.echo(click.style("clean_messages: start clean messages.", fg="green"))
+ start_at = time.perf_counter()
+
+ try:
+ # Create policy based on billing configuration
+ policy = create_message_clean_policy(
+ graceful_period_days=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD,
+ )
+
+ # Create and run the cleanup service
+ # lock the task to avoid concurrent execution in case of the future data volume growth
+ with redis_client.lock(
+ "retention:clean_messages", timeout=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL, blocking=False
+ ):
+ service = MessagesCleanService.from_days(
+ policy=policy,
+ days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
+ batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
+ )
+ stats = service.run()
+
+ end_at = time.perf_counter()
+ click.echo(
+ click.style(
+ f"clean_messages: completed successfully\n"
+ f" - Latency: {end_at - start_at:.2f}s\n"
+ f" - Batches processed: {stats['batches']}\n"
+ f" - Total messages scanned: {stats['total_messages']}\n"
+ f" - Messages filtered: {stats['filtered_messages']}\n"
+ f" - Messages deleted: {stats['total_deleted']}",
+ fg="green",
+ )
+ )
+ except LockError:
+ end_at = time.perf_counter()
+ logger.exception("clean_messages: acquire task lock failed, skip current execution")
+ click.echo(
+ click.style(
+ f"clean_messages: skipped (lock already held) - latency: {end_at - start_at:.2f}s",
+ fg="yellow",
+ )
+ )
+ raise
+ except Exception as e:
+ end_at = time.perf_counter()
+ logger.exception("clean_messages failed")
+ click.echo(
+ click.style(
+ f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
+ fg="red",
+ )
+ )
+ raise
diff --git a/api/schedule/clean_workflow_runs_task.py b/api/schedule/clean_workflow_runs_task.py
new file mode 100644
index 0000000000..ff45a3ddf2
--- /dev/null
+++ b/api/schedule/clean_workflow_runs_task.py
@@ -0,0 +1,79 @@
+import logging
+from datetime import UTC, datetime
+
+import click
+from redis.exceptions import LockError
+
+import app
+from configs import dify_config
+from extensions.ext_redis import redis_client
+from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
+
+logger = logging.getLogger(__name__)
+
+
+@app.celery.task(queue="retention")
+def clean_workflow_runs_task() -> None:
+ """
+ Scheduled cleanup for workflow runs and related records (sandbox tenants only).
+ """
+ click.echo(
+ click.style(
+ (
+ "Scheduled workflow run cleanup starting: "
+ f"cutoff={dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS} days, "
+ f"batch={dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE}"
+ ),
+ fg="green",
+ )
+ )
+
+ start_time = datetime.now(UTC)
+
+ try:
+ # lock the task to avoid concurrent execution in case of the future data volume growth
+ with redis_client.lock(
+ "retention:clean_workflow_runs_task",
+ timeout=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL,
+ blocking=False,
+ ):
+ WorkflowRunCleanup(
+ days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
+ batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
+ start_from=None,
+ end_before=None,
+ ).run()
+
+ end_time = datetime.now(UTC)
+ elapsed = end_time - start_time
+ click.echo(
+ click.style(
+ f"Scheduled workflow run cleanup finished. start={start_time.isoformat()} "
+ f"end={end_time.isoformat()} duration={elapsed}",
+ fg="green",
+ )
+ )
+ except LockError:
+ end_time = datetime.now(UTC)
+ elapsed = end_time - start_time
+ logger.exception("clean_workflow_runs_task: acquire task lock failed, skip current execution")
+ click.echo(
+ click.style(
+ f"Scheduled workflow run cleanup skipped (lock already held). "
+ f"start={start_time.isoformat()} end={end_time.isoformat()} duration={elapsed}",
+ fg="yellow",
+ )
+ )
+ raise
+ except Exception as e:
+ end_time = datetime.now(UTC)
+ elapsed = end_time - start_time
+ logger.exception("clean_workflow_runs_task failed")
+ click.echo(
+ click.style(
+ f"Scheduled workflow run cleanup failed. start={start_time.isoformat()} "
+ f"end={end_time.isoformat()} duration={elapsed} - {str(e)}",
+ fg="red",
+ )
+ )
+ raise
diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py
index c343063fae..ed46c1c70a 100644
--- a/api/schedule/create_tidb_serverless_task.py
+++ b/api/schedule/create_tidb_serverless_task.py
@@ -50,10 +50,13 @@ def create_clusters(batch_size):
)
for new_cluster in new_clusters:
tidb_auth_binding = TidbAuthBinding(
+ tenant_id=None,
cluster_id=new_cluster["cluster_id"],
cluster_name=new_cluster["cluster_name"],
account=new_cluster["account"],
password=new_cluster["password"],
+ active=False,
+ status="CREATING",
)
db.session.add(tidb_auth_binding)
db.session.commit()
diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py
index db610df290..77d6b5a138 100644
--- a/api/schedule/queue_monitor_task.py
+++ b/api/schedule/queue_monitor_task.py
@@ -16,6 +16,11 @@ celery_redis = Redis(
port=redis_config.get("port") or 6379,
password=redis_config.get("password") or None,
db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1,
+ ssl=bool(dify_config.BROKER_USE_SSL),
+ ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS if dify_config.BROKER_USE_SSL else None,
+ ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
+ ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
+ ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None,
)
logger = logging.getLogger(__name__)
diff --git a/api/services/account_service.py b/api/services/account_service.py
index 5a549dc318..35e4a505af 100644
--- a/api/services/account_service.py
+++ b/api/services/account_service.py
@@ -8,7 +8,7 @@ from hashlib import sha256
from typing import Any, cast
from pydantic import BaseModel
-from sqlalchemy import func
+from sqlalchemy import func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@@ -748,6 +748,21 @@ class AccountService:
cls.email_code_login_rate_limiter.increment_rate_limit(email)
return token
+ @staticmethod
+ def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
+ """
+ Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
+
+ This keeps backward compatibility for older records that stored uppercase emails while the
+ rest of the system gradually normalizes new inputs.
+ """
+ query_session = session or db.session
+ account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+ if account or email == email.lower():
+ return account
+
+ return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
+
@classmethod
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_code_login")
@@ -999,6 +1014,11 @@ class TenantService:
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.commit()
+
+ from services.credit_pool_service import CreditPoolService
+
+ CreditPoolService.create_default_pool(tenant.id)
+
return tenant
@staticmethod
@@ -1358,16 +1378,27 @@ class RegisterService:
if not inviter:
raise ValueError("Inviter is required")
+ normalized_email = email.lower()
+
"""Invite new member"""
+ # Check workspace permission for member invitations
+ from libs.workspace_permission import check_workspace_member_invite_permission
+
+ check_workspace_member_invite_permission(tenant.id)
+
with Session(db.engine) as session:
- account = session.query(Account).filter_by(email=email).first()
+ account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if not account:
TenantService.check_member_permission(tenant, inviter, None, "add")
- name = email.split("@")[0]
+ name = normalized_email.split("@")[0]
account = cls.register(
- email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True
+ email=normalized_email,
+ name=name,
+ language=language,
+ status=AccountStatus.PENDING,
+ is_setup=True,
)
# Create new tenant member for invited tenant
TenantService.create_tenant_member(tenant, account, role)
@@ -1389,7 +1420,7 @@ class RegisterService:
# send email
send_invite_member_mail_task.delay(
language=language,
- to=email,
+ to=account.email,
token=token,
inviter_name=inviter.name if inviter else "Dify",
workspace_name=tenant.name,
@@ -1488,6 +1519,16 @@ class RegisterService:
invitation: dict = json.loads(data)
return invitation
+ @classmethod
+ def get_invitation_with_case_fallback(
+ cls, workspace_id: str | None, email: str | None, token: str
+ ) -> dict[str, Any] | None:
+ invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
+ if invitation or not email or email == email.lower():
+ return invitation
+ normalized_email = email.lower()
+ return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token)
+
def _generate_refresh_token(length: int = 64):
token = secrets.token_hex(length)
diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py
index d03cbddceb..56e9cc6a00 100644
--- a/api/services/annotation_service.py
+++ b/api/services/annotation_service.py
@@ -77,7 +77,7 @@ class AppAnnotationService:
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
- annotation.question,
+ question,
current_tenant_id,
app_id,
annotation_setting.collection_binding_id,
@@ -137,13 +137,16 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
if keyword:
+ from libs.helper import escape_like_pattern
+
+ escaped_keyword = escape_like_pattern(keyword)
stmt = (
select(MessageAnnotation)
.where(MessageAnnotation.app_id == app_id)
.where(
or_(
- MessageAnnotation.question.ilike(f"%{keyword}%"),
- MessageAnnotation.content.ilike(f"%{keyword}%"),
+ MessageAnnotation.question.ilike(f"%{escaped_keyword}%", escape="\\"),
+ MessageAnnotation.content.ilike(f"%{escaped_keyword}%", escape="\\"),
)
)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
@@ -206,8 +209,12 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
+ question = args.get("question")
+ if question is None:
+ raise ValueError("'question' is required")
+
annotation = MessageAnnotation(
- app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
+ app_id=app.id, content=args["answer"], question=question, account_id=current_user.id
)
db.session.add(annotation)
db.session.commit()
@@ -216,7 +223,7 @@ class AppAnnotationService:
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
- args["question"],
+ question,
current_tenant_id,
app_id,
annotation_setting.collection_binding_id,
@@ -241,8 +248,12 @@ class AppAnnotationService:
if not annotation:
raise NotFound("Annotation not found")
+ question = args.get("question")
+ if question is None:
+ raise ValueError("'question' is required")
+
annotation.content = args["answer"]
- annotation.question = args["question"]
+ annotation.question = question
db.session.commit()
# if annotation reply is enabled , add annotation to index
@@ -253,7 +264,7 @@ class AppAnnotationService:
if app_annotation_setting:
update_annotation_to_index_task.delay(
annotation.id,
- annotation.question,
+ annotation.question_text,
current_tenant_id,
app_id,
app_annotation_setting.collection_binding_id,
diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py
index deba0b79e8..0f42c99246 100644
--- a/api/services/app_dsl_service.py
+++ b/api/services/app_dsl_service.py
@@ -32,7 +32,7 @@ from extensions.ext_redis import redis_client
from factories import variable_factory
from libs.datetime_utils import naive_utc_now
from models import Account, App, AppMode
-from models.model import AppModelConfig
+from models.model import AppModelConfig, IconType
from models.workflow import Workflow
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.workflow_draft_variable_service import WorkflowDraftVariableService
@@ -428,10 +428,10 @@ class AppDslService:
# Set icon type
icon_type_value = icon_type or app_data.get("icon_type")
- if icon_type_value in ["emoji", "link", "image"]:
+ if icon_type_value in [IconType.EMOJI, IconType.IMAGE, IconType.LINK]:
icon_type = icon_type_value
else:
- icon_type = "emoji"
+ icon_type = IconType.EMOJI
icon = icon or str(app_data.get("icon", ""))
if app:
@@ -521,12 +521,10 @@ class AppDslService:
raise ValueError("Missing model_config for chat/agent-chat/completion app")
# Initialize or update model config
if not app.app_model_config:
- app_model_config = AppModelConfig().from_model_config_dict(model_config)
+ app_model_config = AppModelConfig(
+ app_id=app.id, created_by=account.id, updated_by=account.id
+ ).from_model_config_dict(model_config)
app_model_config.id = str(uuid4())
- app_model_config.app_id = app.id
- app_model_config.created_by = account.id
- app_model_config.updated_by = account.id
-
app.app_model_config_id = app_model_config.id
self._session.add(app_model_config)
@@ -783,15 +781,16 @@ class AppDslService:
return dependencies
@classmethod
- def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]:
+ def get_leaked_dependencies(
+ cls, tenant_id: str, dsl_dependencies: list[PluginDependency]
+ ) -> list[PluginDependency]:
"""
Returns the leaked dependencies in current workspace
"""
- dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
- if not dependencies:
+ if not dsl_dependencies:
return []
- return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
+ return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dsl_dependencies)
@staticmethod
def _generate_aes_key(tenant_id: str) -> bytes:
diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py
index ce64f6ac84..a3c3471982 100644
--- a/api/services/app_generate_service.py
+++ b/api/services/app_generate_service.py
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
import logging
import threading
import uuid
from collections.abc import Callable, Generator, Mapping
-from typing import Any, Union
+from typing import TYPE_CHECKING, Any, Union
from configs import dify_config
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@@ -18,7 +20,8 @@ from enums.quota_type import QuotaType, unlimited
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow, WorkflowRun
-from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
+from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
+from services.errors.llm import InvokeRateLimitError
from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import AppExecutionParams, chatflow_execute_task
@@ -26,6 +29,9 @@ logger = logging.getLogger(__name__)
SSE_TASK_START_FALLBACK_MS = 200
+if TYPE_CHECKING:
+ from controllers.console.app.workflow import LoopNodeRunPayload
+
class AppGenerateService:
@staticmethod
@@ -243,7 +249,9 @@ class AppGenerateService:
raise ValueError(f"Invalid app mode {app_model.mode}")
@classmethod
- def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
+ def generate_single_loop(
+ cls, app_model: App, user: Account, node_id: str, args: LoopNodeRunPayload, streaming: bool = True
+ ):
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
diff --git a/api/services/app_service.py b/api/services/app_service.py
index ef89a4fd10..af458ff618 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -55,8 +55,11 @@ class AppService:
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
if args.get("name"):
+ from libs.helper import escape_like_pattern
+
name = args["name"][:30]
- filters.append(App.name.ilike(f"%{name}%"))
+ escaped_name = escape_like_pattern(name)
+ filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\"))
# Check if tag_ids is not empty to avoid WHERE false condition
if args.get("tag_ids") and len(args["tag_ids"]) > 0:
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"])
@@ -147,10 +150,9 @@ class AppService:
db.session.flush()
if default_model_config:
- app_model_config = AppModelConfig(**default_model_config)
- app_model_config.app_id = app.id
- app_model_config.created_by = account.id
- app_model_config.updated_by = account.id
+ app_model_config = AppModelConfig(
+ **default_model_config, app_id=app.id, created_by=account.id, updated_by=account.id
+ )
db.session.add(app_model_config)
db.session.flush()
diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py
index e100582511..bc73b7c8c2 100644
--- a/api/services/async_workflow_service.py
+++ b/api/services/async_workflow_service.py
@@ -21,7 +21,7 @@ from models.model import App, EndUser
from models.trigger import WorkflowTriggerLog
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
-from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowNotFoundError
+from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
@@ -141,7 +141,7 @@ class AsyncWorkflowService:
trigger_log_repo.update(trigger_log)
session.commit()
- raise InvokeRateLimitError(
+ raise WorkflowQuotaLimitError(
f"Workflow execution quota limit reached for tenant {trigger_data.tenant_id}"
) from e
diff --git a/api/services/billing_service.py b/api/services/billing_service.py
index 3d7cb6cc8d..946b8cdfdb 100644
--- a/api/services/billing_service.py
+++ b/api/services/billing_service.py
@@ -1,3 +1,4 @@
+import json
import logging
import os
from collections.abc import Sequence
@@ -31,6 +32,11 @@ class BillingService:
compliance_download_rate_limiter = RateLimiter("compliance_download_rate_limiter", 4, 60)
+ # Redis key prefix for tenant plan cache
+ _PLAN_CACHE_KEY_PREFIX = "tenant_plan:"
+ # Cache TTL: 10 minutes
+ _PLAN_CACHE_TTL = 600
+
@classmethod
def get_info(cls, tenant_id: str):
params = {"tenant_id": tenant_id}
@@ -125,7 +131,7 @@ class BillingService:
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}"
- response = httpx.request(method, url, json=json, params=params, headers=headers)
+ response = httpx.request(method, url, json=json, params=params, headers=headers, follow_redirects=True)
if method == "GET" and response.status_code != httpx.codes.OK:
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
if method == "PUT":
@@ -137,6 +143,9 @@ class BillingService:
raise ValueError("Invalid arguments.")
if method == "POST" and response.status_code != httpx.codes.OK:
raise ValueError(f"Unable to send request to {url}. Please try again later or contact support.")
+ if method == "DELETE" and response.status_code != httpx.codes.OK:
+ logger.error("billing_service: DELETE response: %s %s", response.status_code, response.text)
+ raise ValueError(f"Unable to process delete request {url}. Please try again later or contact support.")
return response.json()
@staticmethod
@@ -159,7 +168,7 @@ class BillingService:
def delete_account(cls, account_id: str):
"""Delete account."""
params = {"account_id": account_id}
- return cls._send_request("DELETE", "/account/", params=params)
+ return cls._send_request("DELETE", "/account", params=params)
@classmethod
def is_email_in_freeze(cls, email: str) -> bool:
@@ -272,14 +281,110 @@ class BillingService:
data = resp.get("data", {})
for tenant_id, plan in data.items():
- subscription_plan = subscription_adapter.validate_python(plan)
- results[tenant_id] = subscription_plan
+ try:
+ subscription_plan = subscription_adapter.validate_python(plan)
+ results[tenant_id] = subscription_plan
+ except Exception:
+ logger.exception(
+ "get_plan_bulk: failed to validate subscription plan for tenant(%s)", tenant_id
+ )
+ continue
except Exception:
- logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
+ logger.exception("get_plan_bulk: failed to fetch billing info batch for tenants: %s", chunk)
continue
return results
+ @classmethod
+ def _make_plan_cache_key(cls, tenant_id: str) -> str:
+ return f"{cls._PLAN_CACHE_KEY_PREFIX}{tenant_id}"
+
+ @classmethod
+ def get_plan_bulk_with_cache(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
+ """
+ Bulk fetch billing subscription plan with cache to reduce billing API loads in batch job scenarios.
+
+ NOTE: if you want to high data consistency, use get_plan_bulk instead.
+
+ Returns:
+ Mapping of tenant_id -> {plan: str, expiration_date: int}
+ """
+ tenant_plans: dict[str, SubscriptionPlan] = {}
+
+ if not tenant_ids:
+ return tenant_plans
+
+ subscription_adapter = TypeAdapter(SubscriptionPlan)
+
+ # Step 1: Batch fetch from Redis cache using mget
+ redis_keys = [cls._make_plan_cache_key(tenant_id) for tenant_id in tenant_ids]
+ try:
+ cached_values = redis_client.mget(redis_keys)
+
+ if len(cached_values) != len(tenant_ids):
+ raise Exception(
+ "get_plan_bulk_with_cache: unexpected error: redis mget failed: cached values length mismatch"
+ )
+
+ # Map cached values back to tenant_ids
+ cache_misses: list[str] = []
+
+ for tenant_id, cached_value in zip(tenant_ids, cached_values):
+ if cached_value:
+ try:
+ # Redis returns bytes, decode to string and parse JSON
+ json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
+ plan_dict = json.loads(json_str)
+ subscription_plan = subscription_adapter.validate_python(plan_dict)
+ tenant_plans[tenant_id] = subscription_plan
+ except Exception:
+ logger.exception(
+ "get_plan_bulk_with_cache: process tenant(%s) failed, add to cache misses", tenant_id
+ )
+ cache_misses.append(tenant_id)
+ else:
+ cache_misses.append(tenant_id)
+
+ logger.info(
+ "get_plan_bulk_with_cache: cache hits=%s, cache misses=%s",
+ len(tenant_plans),
+ len(cache_misses),
+ )
+ except Exception:
+ logger.exception("get_plan_bulk_with_cache: redis mget failed, falling back to API")
+ cache_misses = list(tenant_ids)
+
+ # Step 2: Fetch missing plans from billing API
+ if cache_misses:
+ bulk_plans = BillingService.get_plan_bulk(cache_misses)
+
+ if bulk_plans:
+ plans_to_cache: dict[str, SubscriptionPlan] = {}
+
+ for tenant_id, subscription_plan in bulk_plans.items():
+ tenant_plans[tenant_id] = subscription_plan
+ plans_to_cache[tenant_id] = subscription_plan
+
+ # Step 3: Batch update Redis cache using pipeline
+ if plans_to_cache:
+ try:
+ pipe = redis_client.pipeline()
+ for tenant_id, subscription_plan in plans_to_cache.items():
+ redis_key = cls._make_plan_cache_key(tenant_id)
+ # Serialize dict to JSON string
+ json_str = json.dumps(subscription_plan)
+ pipe.setex(redis_key, cls._PLAN_CACHE_TTL, json_str)
+ pipe.execute()
+
+ logger.info(
+ "get_plan_bulk_with_cache: cached %s new tenant plans to Redis",
+ len(plans_to_cache),
+ )
+ except Exception:
+ logger.exception("get_plan_bulk_with_cache: redis pipeline failed")
+
+ return tenant_plans
+
@classmethod
def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]:
resp = cls._send_request("GET", "/subscription/cleanup/whitelist")
diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py
index 659e7406fb..295d48d8a1 100644
--- a/api/services/conversation_service.py
+++ b/api/services/conversation_service.py
@@ -11,13 +11,13 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.db.session_factory import session_factory
from core.llm_generator.llm_generator import LLMGenerator
from core.variables.types import SegmentType
-from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from extensions.ext_database import db
from factories import variable_factory
from libs.datetime_utils import naive_utc_now
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account, ConversationVariable
from models.model import App, Conversation, EndUser, Message
+from services.conversation_variable_updater import ConversationVariableUpdater
from services.errors.conversation import (
ConversationNotExistsError,
ConversationVariableNotExistsError,
@@ -218,7 +218,9 @@ class ConversationService:
# Apply variable_name filter if provided
if variable_name:
# Filter using JSON extraction to match variable names case-insensitively
- escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
+ from libs.helper import escape_like_pattern
+
+ escaped_variable_name = escape_like_pattern(variable_name)
# Filter using JSON extraction to match variable names case-insensitively
if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]:
stmt = stmt.where(
@@ -335,7 +337,7 @@ class ConversationService:
updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict)
# Use the conversation variable updater to persist the changes
- updater = conversation_variable_updater_factory()
+ updater = ConversationVariableUpdater(session_factory.get_session_maker())
updater.update(conversation_id, updated_variable)
updater.flush()
diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py
new file mode 100644
index 0000000000..92008d5ff1
--- /dev/null
+++ b/api/services/conversation_variable_updater.py
@@ -0,0 +1,28 @@
+from sqlalchemy import select
+from sqlalchemy.orm import Session, sessionmaker
+
+from core.variables.variables import VariableBase
+from models import ConversationVariable
+
+
+class ConversationVariableNotFoundError(Exception):
+ pass
+
+
+class ConversationVariableUpdater:
+ def __init__(self, session_maker: sessionmaker[Session]) -> None:
+ self._session_maker: sessionmaker[Session] = session_maker
+
+ def update(self, conversation_id: str, variable: VariableBase) -> None:
+ stmt = select(ConversationVariable).where(
+ ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
+ )
+ with self._session_maker() as session:
+ row = session.scalar(stmt)
+ if not row:
+ raise ConversationVariableNotFoundError("conversation variable not found in the database")
+ row.data = variable.model_dump_json()
+ session.commit()
+
+ def flush(self) -> None:
+ pass
diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py
new file mode 100644
index 0000000000..1954602571
--- /dev/null
+++ b/api/services/credit_pool_service.py
@@ -0,0 +1,85 @@
+import logging
+
+from sqlalchemy import update
+from sqlalchemy.orm import Session
+
+from configs import dify_config
+from core.errors.error import QuotaExceededError
+from extensions.ext_database import db
+from models import TenantCreditPool
+
+logger = logging.getLogger(__name__)
+
+
+class CreditPoolService:
+ @classmethod
+ def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
+ """create default credit pool for new tenant"""
+ credit_pool = TenantCreditPool(
+ tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
+ )
+ db.session.add(credit_pool)
+ db.session.commit()
+ return credit_pool
+
+ @classmethod
+ def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None:
+ """get tenant credit pool"""
+ return (
+ db.session.query(TenantCreditPool)
+ .filter_by(
+ tenant_id=tenant_id,
+ pool_type=pool_type,
+ )
+ .first()
+ )
+
+ @classmethod
+ def check_credits_available(
+ cls,
+ tenant_id: str,
+ credits_required: int,
+ pool_type: str = "trial",
+ ) -> bool:
+ """check if credits are available without deducting"""
+ pool = cls.get_pool(tenant_id, pool_type)
+ if not pool:
+ return False
+ return pool.remaining_credits >= credits_required
+
+ @classmethod
+ def check_and_deduct_credits(
+ cls,
+ tenant_id: str,
+ credits_required: int,
+ pool_type: str = "trial",
+ ) -> int:
+ """check and deduct credits, returns actual credits deducted"""
+
+ pool = cls.get_pool(tenant_id, pool_type)
+ if not pool:
+ raise QuotaExceededError("Credit pool not found")
+
+ if pool.remaining_credits <= 0:
+ raise QuotaExceededError("No credits remaining")
+
+ # deduct all remaining credits if less than required
+ actual_credits = min(credits_required, pool.remaining_credits)
+
+ try:
+ with Session(db.engine) as session:
+ stmt = (
+ update(TenantCreditPool)
+ .where(
+ TenantCreditPool.tenant_id == tenant_id,
+ TenantCreditPool.pool_type == pool_type,
+ )
+ .values(quota_used=TenantCreditPool.quota_used + actual_credits)
+ )
+ session.execute(stmt)
+ session.commit()
+ except Exception:
+ logger.exception("Failed to deduct credits for tenant %s", tenant_id)
+ raise QuotaExceededError("Failed to deduct credits")
+
+ return actual_credits
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index ac4b25c5dc..be9a0e9279 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -13,10 +13,11 @@ import sqlalchemy as sa
from redis.exceptions import LockNotOwnedError
from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session
-from werkzeug.exceptions import NotFound
+from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
+from core.file import helpers as file_helpers
from core.helper.name_generator import generate_incremental_name
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
@@ -73,6 +74,7 @@ from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureModel, FeatureService
+from services.file_service import FileService
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.tag_service import TagService
from services.vector_service import VectorService
@@ -144,7 +146,8 @@ class DatasetService:
query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)
if search:
- query = query.where(Dataset.name.ilike(f"%{search}%"))
+ escaped_search = helper.escape_like_pattern(search)
+ query = query.where(Dataset.name.ilike(f"%{escaped_search}%", escape="\\"))
# Check if tag_ids is not empty to avoid WHERE false condition
if tag_ids and len(tag_ids) > 0:
@@ -1161,6 +1164,7 @@ class DocumentService:
Document.archived.is_(True),
),
}
+ DOCUMENT_BATCH_DOWNLOAD_ZIP_FILENAME_EXTENSION = ".zip"
@classmethod
def normalize_display_status(cls, status: str | None) -> str | None:
@@ -1287,6 +1291,143 @@ class DocumentService:
else:
return None
+ @staticmethod
+ def get_documents_by_ids(dataset_id: str, document_ids: Sequence[str]) -> Sequence[Document]:
+ """Fetch documents for a dataset in a single batch query."""
+ if not document_ids:
+ return []
+ document_id_list: list[str] = [str(document_id) for document_id in document_ids]
+ # Fetch all requested documents in one query to avoid N+1 lookups.
+ documents: Sequence[Document] = db.session.scalars(
+ select(Document).where(
+ Document.dataset_id == dataset_id,
+ Document.id.in_(document_id_list),
+ )
+ ).all()
+ return documents
+
+ @staticmethod
+ def get_document_download_url(document: Document) -> str:
+ """
+ Return a signed download URL for an upload-file document.
+ """
+ upload_file = DocumentService._get_upload_file_for_upload_file_document(document)
+ return file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
+
+ @staticmethod
+ def prepare_document_batch_download_zip(
+ *,
+ dataset_id: str,
+ document_ids: Sequence[str],
+ tenant_id: str,
+ current_user: Account,
+ ) -> tuple[list[UploadFile], str]:
+ """
+ Resolve upload files for batch ZIP downloads and generate a client-visible filename.
+ """
+ dataset = DatasetService.get_dataset(dataset_id)
+ if not dataset:
+ raise NotFound("Dataset not found.")
+ try:
+ DatasetService.check_dataset_permission(dataset, current_user)
+ except NoPermissionError as e:
+ raise Forbidden(str(e))
+
+ upload_files_by_document_id = DocumentService._get_upload_files_by_document_id_for_zip_download(
+ dataset_id=dataset_id,
+ document_ids=document_ids,
+ tenant_id=tenant_id,
+ )
+ upload_files = [upload_files_by_document_id[document_id] for document_id in document_ids]
+ download_name = DocumentService._generate_document_batch_download_zip_filename()
+ return upload_files, download_name
+
+ @staticmethod
+ def _generate_document_batch_download_zip_filename() -> str:
+ """
+ Generate a random attachment filename for the batch download ZIP.
+ """
+ return f"{uuid.uuid4().hex}{DocumentService.DOCUMENT_BATCH_DOWNLOAD_ZIP_FILENAME_EXTENSION}"
+
+ @staticmethod
+ def _get_upload_file_id_for_upload_file_document(
+ document: Document,
+ *,
+ invalid_source_message: str,
+ missing_file_message: str,
+ ) -> str:
+ """
+ Normalize and validate `Document -> UploadFile` linkage for download flows.
+ """
+ if document.data_source_type != "upload_file":
+ raise NotFound(invalid_source_message)
+
+ data_source_info: dict[str, Any] = document.data_source_info_dict or {}
+ upload_file_id: str | None = data_source_info.get("upload_file_id")
+ if not upload_file_id:
+ raise NotFound(missing_file_message)
+
+ return str(upload_file_id)
+
+ @staticmethod
+ def _get_upload_file_for_upload_file_document(document: Document) -> UploadFile:
+ """
+ Load the `UploadFile` row for an upload-file document.
+ """
+ upload_file_id = DocumentService._get_upload_file_id_for_upload_file_document(
+ document,
+ invalid_source_message="Document does not have an uploaded file to download.",
+ missing_file_message="Uploaded file not found.",
+ )
+ upload_files_by_id = FileService.get_upload_files_by_ids(document.tenant_id, [upload_file_id])
+ upload_file = upload_files_by_id.get(upload_file_id)
+ if not upload_file:
+ raise NotFound("Uploaded file not found.")
+ return upload_file
+
+ @staticmethod
+ def _get_upload_files_by_document_id_for_zip_download(
+ *,
+ dataset_id: str,
+ document_ids: Sequence[str],
+ tenant_id: str,
+ ) -> dict[str, UploadFile]:
+ """
+ Batch load upload files keyed by document id for ZIP downloads.
+ """
+ document_id_list: list[str] = [str(document_id) for document_id in document_ids]
+
+ documents = DocumentService.get_documents_by_ids(dataset_id, document_id_list)
+ documents_by_id: dict[str, Document] = {str(document.id): document for document in documents}
+
+ missing_document_ids: set[str] = set(document_id_list) - set(documents_by_id.keys())
+ if missing_document_ids:
+ raise NotFound("Document not found.")
+
+ upload_file_ids: list[str] = []
+ upload_file_ids_by_document_id: dict[str, str] = {}
+ for document_id, document in documents_by_id.items():
+ if document.tenant_id != tenant_id:
+ raise Forbidden("No permission.")
+
+ upload_file_id = DocumentService._get_upload_file_id_for_upload_file_document(
+ document,
+ invalid_source_message="Only uploaded-file documents can be downloaded as ZIP.",
+ missing_file_message="Only uploaded-file documents can be downloaded as ZIP.",
+ )
+ upload_file_ids.append(upload_file_id)
+ upload_file_ids_by_document_id[document_id] = upload_file_id
+
+ upload_files_by_id = FileService.get_upload_files_by_ids(tenant_id, upload_file_ids)
+ missing_upload_file_ids: set[str] = set(upload_file_ids) - set(upload_files_by_id.keys())
+ if missing_upload_file_ids:
+ raise NotFound("Only uploaded-file documents can be downloaded as ZIP.")
+
+ return {
+ document_id: upload_files_by_id[upload_file_id]
+ for document_id, upload_file_id in upload_file_ids_by_document_id.items()
+ }
+
@staticmethod
def get_document_by_id(document_id: str) -> Document | None:
document = db.session.query(Document).where(Document.id == document_id).first()
@@ -3423,7 +3564,8 @@ class SegmentService:
.order_by(ChildChunk.position.asc())
)
if keyword:
- query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
+ escaped_keyword = helper.escape_like_pattern(keyword)
+ query = query.where(ChildChunk.content.ilike(f"%{escaped_keyword}%", escape="\\"))
return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@classmethod
@@ -3456,7 +3598,8 @@ class SegmentService:
query = query.where(DocumentSegment.status.in_(status_list))
if keyword:
- query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
+ escaped_keyword = helper.escape_like_pattern(keyword)
+ query = query.where(DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"))
query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py
index bdc960aa2d..e3832475aa 100644
--- a/api/services/enterprise/base.py
+++ b/api/services/enterprise/base.py
@@ -1,9 +1,14 @@
+import logging
import os
from collections.abc import Mapping
from typing import Any
import httpx
+from core.helper.trace_id_helper import generate_traceparent_header
+
+logger = logging.getLogger(__name__)
+
class BaseRequest:
proxies: Mapping[str, str] | None = {
@@ -38,6 +43,15 @@ class BaseRequest:
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}"
mounts = cls._build_mounts()
+
+ try:
+ # ensure traceparent even when OTEL is disabled
+ traceparent = generate_traceparent_header()
+ if traceparent:
+ headers["traceparent"] = traceparent
+ except Exception:
+ logger.debug("Failed to generate traceparent header", exc_info=True)
+
with httpx.Client(mounts=mounts) as client:
response = client.request(method, url, json=json, params=params, headers=headers)
return response.json()
diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py
index 83d0fcf296..a5133dfcb4 100644
--- a/api/services/enterprise/enterprise_service.py
+++ b/api/services/enterprise/enterprise_service.py
@@ -13,6 +13,23 @@ class WebAppSettings(BaseModel):
)
+class WorkspacePermission(BaseModel):
+ workspace_id: str = Field(
+ description="The ID of the workspace.",
+ alias="workspaceId",
+ )
+ allow_member_invite: bool = Field(
+ description="Whether to allow members to invite new members to the workspace.",
+ default=False,
+ alias="allowMemberInvite",
+ )
+ allow_owner_transfer: bool = Field(
+ description="Whether to allow owners to transfer ownership of the workspace.",
+ default=False,
+ alias="allowOwnerTransfer",
+ )
+
+
class EnterpriseService:
@classmethod
def get_info(cls):
@@ -44,6 +61,16 @@ class EnterpriseService:
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e
+ class WorkspacePermissionService:
+ @classmethod
+ def get_permission(cls, workspace_id: str):
+ if not workspace_id:
+ raise ValueError("workspace_id must be provided.")
+ data = EnterpriseRequest.send_request("GET", f"/workspaces/{workspace_id}/permission")
+ if not data or "permission" not in data:
+ raise ValueError("No data found.")
+ return WorkspacePermission.model_validate(data["permission"])
+
class WebAppAuth:
@classmethod
def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str):
@@ -110,5 +137,5 @@ class EnterpriseService:
if not app_id:
raise ValueError("app_id must be provided.")
- body = {"appId": app_id}
- EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)
+ params = {"appId": app_id}
+ EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
diff --git a/api/services/enterprise/workspace_sync.py b/api/services/enterprise/workspace_sync.py
new file mode 100644
index 0000000000..acfe325397
--- /dev/null
+++ b/api/services/enterprise/workspace_sync.py
@@ -0,0 +1,58 @@
+import json
+import logging
+import uuid
+from datetime import UTC, datetime
+
+from redis import RedisError
+
+from extensions.ext_redis import redis_client
+
+logger = logging.getLogger(__name__)
+
+WORKSPACE_SYNC_QUEUE = "enterprise:workspace:sync:queue"
+WORKSPACE_SYNC_PROCESSING = "enterprise:workspace:sync:processing"
+
+
+class WorkspaceSyncService:
+ """Service to publish workspace sync tasks to Redis queue for enterprise backend consumption"""
+
+ @staticmethod
+ def queue_credential_sync(workspace_id: str, *, source: str) -> bool:
+ """
+ Queue a credential sync task for a newly created workspace.
+
+ This publishes a task to Redis that will be consumed by the enterprise backend
+ worker to sync credentials with the plugin-manager.
+
+ Args:
+ workspace_id: The workspace/tenant ID to sync credentials for
+ source: Source of the sync request (for debugging/tracking)
+
+ Returns:
+ bool: True if task was queued successfully, False otherwise
+ """
+ try:
+ task = {
+ "task_id": str(uuid.uuid4()),
+ "workspace_id": workspace_id,
+ "retry_count": 0,
+ "created_at": datetime.now(UTC).isoformat(),
+ "source": source,
+ }
+
+ # Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP
+ redis_client.lpush(WORKSPACE_SYNC_QUEUE, json.dumps(task))
+
+ logger.info(
+ "Queued credential sync task for workspace %s, task_id: %s, source: %s",
+ workspace_id,
+ task["task_id"],
+ source,
+ )
+ return True
+
+ except (RedisError, TypeError) as e:
+ logger.error("Failed to queue credential sync for workspace %s: %s", workspace_id, str(e), exc_info=True)
+ # Don't raise - we don't want to fail workspace creation if queueing fails
+ # The scheduled task will catch it later
+ return False
diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py
index f405546909..a29d848ac5 100644
--- a/api/services/entities/model_provider_entities.py
+++ b/api/services/entities/model_provider_entities.py
@@ -70,7 +70,6 @@ class ProviderResponse(BaseModel):
description: I18nObject | None = None
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
- icon_large: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
supported_model_types: Sequence[ModelType]
@@ -98,11 +97,6 @@ class ProviderResponse(BaseModel):
en_US=f"{url_prefix}/icon_small_dark/en_US",
zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans",
)
-
- if self.icon_large is not None:
- self.icon_large = I18nObject(
- en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
- )
return self
@@ -116,7 +110,6 @@ class ProviderWithModelsResponse(BaseModel):
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
- icon_large: I18nObject | None = None
status: CustomConfigurationStatus
models: list[ProviderModelWithStatusEntity]
@@ -134,11 +127,6 @@ class ProviderWithModelsResponse(BaseModel):
self.icon_small_dark = I18nObject(
en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans"
)
-
- if self.icon_large is not None:
- self.icon_large = I18nObject(
- en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
- )
return self
@@ -163,11 +151,6 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
self.icon_small_dark = I18nObject(
en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans"
)
-
- if self.icon_large is not None:
- self.icon_large = I18nObject(
- en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
- )
return self
diff --git a/api/services/errors/app.py b/api/services/errors/app.py
index 24e4760acc..60e59e97dc 100644
--- a/api/services/errors/app.py
+++ b/api/services/errors/app.py
@@ -18,8 +18,8 @@ class WorkflowIdFormatError(Exception):
pass
-class InvokeRateLimitError(Exception):
- """Raised when rate limit is exceeded for workflow invocations."""
+class WorkflowQuotaLimitError(Exception):
+ """Raised when workflow execution quota is exceeded (for async/background workflows)."""
pass
diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py
index 40faa85b9a..65dd41af43 100644
--- a/api/services/external_knowledge_service.py
+++ b/api/services/external_knowledge_service.py
@@ -35,7 +35,10 @@ class ExternalDatasetService:
.order_by(ExternalKnowledgeApis.created_at.desc())
)
if search:
- query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
+ from libs.helper import escape_like_pattern
+
+ escaped_search = escape_like_pattern(search)
+ query = query.where(ExternalKnowledgeApis.name.ilike(f"%{escaped_search}%", escape="\\"))
external_knowledge_apis = db.paginate(
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False
diff --git a/api/services/feature_service.py b/api/services/feature_service.py
index 112855a748..fda3a15144 100644
--- a/api/services/feature_service.py
+++ b/api/services/feature_service.py
@@ -4,6 +4,7 @@ from pydantic import BaseModel, ConfigDict, Field
from configs import dify_config
from enums.cloud_plan import CloudPlan
+from enums.hosted_provider import HostedTrialProvider
from services.billing_service import BillingService
from services.enterprise.enterprise_service import EnterpriseService
@@ -142,6 +143,7 @@ class FeatureModel(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
knowledge_pipeline: KnowledgePipeline = KnowledgePipeline()
+ next_credit_reset_date: int = 0
class KnowledgeRateLimitModel(BaseModel):
@@ -171,6 +173,9 @@ class SystemFeatureModel(BaseModel):
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()
+ trial_models: list[str] = []
+ enable_trial_app: bool = False
+ enable_explore_banner: bool = False
class FeatureService:
@@ -217,7 +222,7 @@ class FeatureService:
)
@classmethod
- def get_system_features(cls) -> SystemFeatureModel:
+ def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel:
system_features = SystemFeatureModel()
cls._fulfill_system_params_from_env(system_features)
@@ -227,7 +232,7 @@ class FeatureService:
system_features.webapp_auth.enabled = True
system_features.enable_change_email = False
system_features.plugin_manager.enabled = True
- cls._fulfill_params_from_enterprise(system_features)
+ cls._fulfill_params_from_enterprise(system_features, is_authenticated)
if dify_config.MARKETPLACE_ENABLED:
system_features.enable_marketplace = True
@@ -242,6 +247,20 @@ class FeatureService:
system_features.is_allow_register = dify_config.ALLOW_REGISTER
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
+ system_features.trial_models = cls._fulfill_trial_models_from_env()
+ system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP
+ system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER
+
+ @classmethod
+ def _fulfill_trial_models_from_env(cls) -> list[str]:
+ return [
+ provider.value
+ for provider in HostedTrialProvider
+ if (
+ getattr(dify_config, f"HOSTED_{provider.config_key}_PAID_ENABLED", False)
+ and getattr(dify_config, f"HOSTED_{provider.config_key}_TRIAL_ENABLED", False)
+ )
+ ]
@classmethod
def _fulfill_params_from_env(cls, features: FeatureModel):
@@ -319,8 +338,11 @@ class FeatureService:
if "knowledge_pipeline_publish_enabled" in billing_info:
features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"]
+ if "next_credit_reset_date" in billing_info:
+ features.next_credit_reset_date = billing_info["next_credit_reset_date"]
+
@classmethod
- def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel):
+ def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel, is_authenticated: bool = False):
enterprise_info = EnterpriseService.get_info()
if "SSOEnforcedForSignin" in enterprise_info:
@@ -357,19 +379,14 @@ class FeatureService:
)
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")
- if "License" in enterprise_info:
- license_info = enterprise_info["License"]
+ if is_authenticated and (license_info := enterprise_info.get("License")):
+ features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
+ features.license.expired_at = license_info.get("expiredAt", "")
- if "status" in license_info:
- features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
-
- if "expiredAt" in license_info:
- features.license.expired_at = license_info["expiredAt"]
-
- if "workspaces" in license_info:
- features.license.workspaces.enabled = license_info["workspaces"]["enabled"]
- features.license.workspaces.limit = license_info["workspaces"]["limit"]
- features.license.workspaces.size = license_info["workspaces"]["used"]
+ if workspaces_info := license_info.get("workspaces"):
+ features.license.workspaces.enabled = workspaces_info.get("enabled", False)
+ features.license.workspaces.limit = workspaces_info.get("limit", 0)
+ features.license.workspaces.size = workspaces_info.get("used", 0)
if "PluginInstallationPermission" in enterprise_info:
plugin_installation_info = enterprise_info["PluginInstallationPermission"]
diff --git a/api/services/file_service.py b/api/services/file_service.py
index 0911cf38c4..a0a99f3f82 100644
--- a/api/services/file_service.py
+++ b/api/services/file_service.py
@@ -2,7 +2,11 @@ import base64
import hashlib
import os
import uuid
+from collections.abc import Iterator, Sequence
+from contextlib import contextmanager, suppress
+from tempfile import NamedTemporaryFile
from typing import Literal, Union
+from zipfile import ZIP_DEFLATED, ZipFile
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker
@@ -17,6 +21,7 @@ from constants import (
)
from core.file import helpers as file_helpers
from core.rag.extractor.extract_processor import ExtractProcessor
+from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
@@ -167,6 +172,9 @@ class FileService:
return upload_file
def get_file_preview(self, file_id: str):
+ """
+ Return a short text preview extracted from a document file.
+ """
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
@@ -253,3 +261,101 @@ class FileService:
return
storage.delete(upload_file.key)
session.delete(upload_file)
+
+ @staticmethod
+ def get_upload_files_by_ids(tenant_id: str, upload_file_ids: Sequence[str]) -> dict[str, UploadFile]:
+ """
+ Fetch `UploadFile` rows for a tenant in a single batch query.
+
+ This is a generic `UploadFile` lookup helper (not dataset/document specific), so it lives in `FileService`.
+ """
+ if not upload_file_ids:
+ return {}
+
+ # Normalize and deduplicate ids before using them in the IN clause.
+ upload_file_id_list: list[str] = [str(upload_file_id) for upload_file_id in upload_file_ids]
+ unique_upload_file_ids: list[str] = list(set(upload_file_id_list))
+
+ # Fetch upload files in one query for efficient batch access.
+ upload_files: Sequence[UploadFile] = db.session.scalars(
+ select(UploadFile).where(
+ UploadFile.tenant_id == tenant_id,
+ UploadFile.id.in_(unique_upload_file_ids),
+ )
+ ).all()
+ return {str(upload_file.id): upload_file for upload_file in upload_files}
+
+ @staticmethod
+ def _sanitize_zip_entry_name(name: str) -> str:
+ """
+ Sanitize a ZIP entry name to avoid path traversal and weird separators.
+
+ We keep this conservative: the upload flow already rejects `/` and `\\`, but older rows (or imported data)
+ could still contain unsafe names.
+ """
+ # Drop any directory components and prevent empty names.
+ base = os.path.basename(name).strip() or "file"
+
+ # ZIP uses forward slashes as separators; remove any residual separator characters.
+ return base.replace("/", "_").replace("\\", "_")
+
+ @staticmethod
+ def _dedupe_zip_entry_name(original_name: str, used_names: set[str]) -> str:
+ """
+ Return a unique ZIP entry name, inserting suffixes before the extension.
+ """
+ # Keep the original name when it's not already used.
+ if original_name not in used_names:
+ return original_name
+
+ # Insert suffixes before the extension (e.g., "doc.txt" -> "doc (1).txt").
+ stem, extension = os.path.splitext(original_name)
+ suffix = 1
+ while True:
+ candidate = f"{stem} ({suffix}){extension}"
+ if candidate not in used_names:
+ return candidate
+ suffix += 1
+
+ @staticmethod
+ @contextmanager
+ def build_upload_files_zip_tempfile(
+ *,
+ upload_files: Sequence[UploadFile],
+ ) -> Iterator[str]:
+ """
+ Build a ZIP from `UploadFile`s and yield a tempfile path.
+
+ We yield a path (rather than an open file handle) to avoid "read of closed file" issues when Flask/Werkzeug
+ streams responses. The caller is expected to keep this context open until the response is fully sent, then
+ close it (e.g., via `response.call_on_close(...)`) to delete the tempfile.
+ """
+ used_names: set[str] = set()
+
+ # Build a ZIP in a temp file and keep it on disk until the caller finishes streaming it.
+ tmp_path: str | None = None
+ try:
+ with NamedTemporaryFile(mode="w+b", suffix=".zip", delete=False) as tmp:
+ tmp_path = tmp.name
+ with ZipFile(tmp, mode="w", compression=ZIP_DEFLATED) as zf:
+ for upload_file in upload_files:
+ # Ensure the entry name is safe and unique.
+ safe_name = FileService._sanitize_zip_entry_name(upload_file.name)
+ arcname = FileService._dedupe_zip_entry_name(safe_name, used_names)
+ used_names.add(arcname)
+
+ # Stream file bytes from storage into the ZIP entry.
+ with zf.open(arcname, "w") as entry:
+ for chunk in storage.load(upload_file.key, stream=True):
+ entry.write(chunk)
+
+ # Flush so `send_file(path, ...)` can re-open it safely on all platforms.
+ tmp.flush()
+
+ assert tmp_path is not None
+ yield tmp_path
+ finally:
+ # Remove the temp file when the context is closed (typically after the response finishes streaming).
+ if tmp_path is not None:
+ with suppress(FileNotFoundError):
+ os.remove(tmp_path)
diff --git a/api/services/message_service.py b/api/services/message_service.py
index 8c9e820e80..8cfa6512be 100644
--- a/api/services/message_service.py
+++ b/api/services/message_service.py
@@ -286,10 +286,9 @@ class MessageService:
else:
conversation_override_model_configs = json.loads(conversation.override_model_configs)
app_model_config = AppModelConfig(
- id=conversation.app_model_config_id,
app_id=app_model.id,
)
-
+ app_model_config.id = conversation.app_model_config_id
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
if not app_model_config:
raise ValueError("did not find app model config")
diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py
index eea382febe..edd1004b82 100644
--- a/api/services/model_provider_service.py
+++ b/api/services/model_provider_service.py
@@ -99,7 +99,6 @@ class ModelProviderService:
description=provider_configuration.provider.description,
icon_small=provider_configuration.provider.icon_small,
icon_small_dark=provider_configuration.provider.icon_small_dark,
- icon_large=provider_configuration.provider.icon_large,
background=provider_configuration.provider.background,
help=provider_configuration.provider.help,
supported_model_types=provider_configuration.provider.supported_model_types,
@@ -423,7 +422,6 @@ class ModelProviderService:
label=first_model.provider.label,
icon_small=first_model.provider.icon_small,
icon_small_dark=first_model.provider.icon_small_dark,
- icon_large=first_model.provider.icon_large,
status=CustomConfigurationStatus.ACTIVE,
models=[
ProviderModelWithStatusEntity(
@@ -488,7 +486,6 @@ class ModelProviderService:
provider=result.provider.provider,
label=result.provider.label,
icon_small=result.provider.icon_small,
- icon_large=result.provider.icon_large,
supported_model_types=result.provider.supported_model_types,
),
)
@@ -522,7 +519,7 @@ class ModelProviderService:
:param tenant_id: workspace id
:param provider: provider name
- :param icon_type: icon type (icon_small or icon_large)
+ :param icon_type: icon type (icon_small or icon_small_dark)
:param lang: language (zh_Hans or en_US)
:return:
"""
diff --git a/api/services/plugin/plugin_parameter_service.py b/api/services/plugin/plugin_parameter_service.py
index 5dcbf5fec5..40565c56ed 100644
--- a/api/services/plugin/plugin_parameter_service.py
+++ b/api/services/plugin/plugin_parameter_service.py
@@ -146,7 +146,7 @@ class PluginParameterService:
provider,
action,
resolved_credentials,
- CredentialType.API_KEY.value,
+ original_subscription.credential_type or CredentialType.UNAUTHORIZED.value,
parameter,
)
.options
diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py
index b8303eb724..411c335c17 100644
--- a/api/services/plugin/plugin_service.py
+++ b/api/services/plugin/plugin_service.py
@@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence
from mimetypes import guess_type
from pydantic import BaseModel
+from sqlalchemy import select
from yarl import URL
from configs import dify_config
@@ -25,7 +26,9 @@ from core.plugin.entities.plugin_daemon import (
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller
+from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from models.provider import ProviderCredential
from models.provider_ids import GenericProviderID
from services.errors.plugin import PluginInstallationForbiddenError
from services.feature_service import FeatureService, PluginInstallationScope
@@ -506,6 +509,33 @@ class PluginService:
@staticmethod
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
manager = PluginInstaller()
+
+ # Get plugin info before uninstalling to delete associated credentials
+ try:
+ plugins = manager.list_plugins(tenant_id)
+ plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None)
+
+ if plugin:
+ plugin_id = plugin.plugin_id
+ logger.info("Deleting credentials for plugin: %s", plugin_id)
+
+ # Delete provider credentials that match this plugin
+ credentials = db.session.scalars(
+ select(ProviderCredential).where(
+ ProviderCredential.tenant_id == tenant_id,
+ ProviderCredential.provider_name.like(f"{plugin_id}/%"),
+ )
+ ).all()
+
+ for cred in credentials:
+ db.session.delete(cred)
+
+ db.session.commit()
+ logger.info("Deleted %d credentials for plugin: %s", len(credentials), plugin_id)
+ except Exception as e:
+ logger.warning("Failed to delete credentials: %s", e)
+ # Continue with uninstall even if credential deletion fails
+
return manager.uninstall(tenant_id, plugin_installation_id)
@staticmethod
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index f53448e7fe..ccc6abcc06 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -36,7 +36,7 @@ from core.rag.entities.event import (
)
from core.repositories.factory import DifyCoreRepositoryFactory
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
-from core.variables.variables import Variable
+from core.variables.variables import VariableBase
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
@@ -270,8 +270,8 @@ class RagPipelineService:
graph: dict,
unique_hash: str | None,
account: Account,
- environment_variables: Sequence[Variable],
- conversation_variables: Sequence[Variable],
+ environment_variables: Sequence[VariableBase],
+ conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list,
) -> Workflow:
"""
@@ -436,7 +436,7 @@ class RagPipelineService:
user_inputs=user_inputs,
user_id=account.id,
variable_pool=VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs=user_inputs,
environment_variables=[],
conversation_variables=[],
@@ -874,7 +874,7 @@ class RagPipelineService:
variable_pool = node_instance.graph_runtime_state.variable_pool
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
if invoke_from:
- if invoke_from.value == InvokeFrom.PUBLISHED:
+ if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE:
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if document_id:
document = db.session.query(Document).where(Document.id == document_id.value).first()
@@ -1318,7 +1318,7 @@ class RagPipelineService:
"datasource_info_list": [json.loads(document_pipeline_execution_log.datasource_info)],
"original_document_id": document.id,
},
- invoke_from=InvokeFrom.PUBLISHED,
+ invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
streaming=False,
call_depth=0,
workflow_thread_pool_id=None,
diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
index 06f294863d..c1c6e204fb 100644
--- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
@@ -870,15 +870,16 @@ class RagPipelineDslService:
return dependencies
@classmethod
- def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]:
+ def get_leaked_dependencies(
+ cls, tenant_id: str, dsl_dependencies: list[PluginDependency]
+ ) -> list[PluginDependency]:
"""
Returns the leaked dependencies in current workspace
"""
- dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
- if not dependencies:
+ if not dsl_dependencies:
return []
- return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
+ return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dsl_dependencies)
def _generate_aes_key(self, tenant_id: str) -> bytes:
"""Generate AES key based on tenant_id"""
diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py
index 84f97907c0..8ea365e907 100644
--- a/api/services/rag_pipeline/rag_pipeline_transform_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py
@@ -44,7 +44,7 @@ class RagPipelineTransformService:
doc_form = dataset.doc_form
if not doc_form:
return self._transform_to_empty_pipeline(dataset)
- retrieval_model = dataset.retrieval_model
+ retrieval_model = RetrievalSetting.model_validate(dataset.retrieval_model) if dataset.retrieval_model else None
pipeline_yaml = self._get_transform_yaml(doc_form, datasource_type, indexing_technique)
# deal dependencies
self._deal_dependencies(pipeline_yaml, dataset.tenant_id)
@@ -154,7 +154,12 @@ class RagPipelineTransformService:
return node
def _deal_knowledge_index(
- self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict
+ self,
+ dataset: Dataset,
+ doc_form: str,
+ indexing_technique: str | None,
+ retrieval_model: RetrievalSetting | None,
+ node: dict,
):
knowledge_configuration_dict = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
@@ -163,10 +168,9 @@ class RagPipelineTransformService:
knowledge_configuration.embedding_model = dataset.embedding_model
knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider
if retrieval_model:
- retrieval_setting = RetrievalSetting.model_validate(retrieval_model)
if indexing_technique == "economy":
- retrieval_setting.search_method = RetrievalMethod.KEYWORD_SEARCH
- knowledge_configuration.retrieval_model = retrieval_setting
+ retrieval_model.search_method = RetrievalMethod.KEYWORD_SEARCH
+ knowledge_configuration.retrieval_model = retrieval_model
else:
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py
index 544383a106..6b211a5632 100644
--- a/api/services/recommended_app_service.py
+++ b/api/services/recommended_app_service.py
@@ -1,4 +1,7 @@
from configs import dify_config
+from extensions.ext_database import db
+from models.model import AccountTrialAppRecord, TrialApp
+from services.feature_service import FeatureService
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
@@ -20,6 +23,15 @@ class RecommendedAppService:
)
)
+ if FeatureService.get_system_features().enable_trial_app:
+ apps = result["recommended_apps"]
+ for app in apps:
+ app_id = app["app_id"]
+ trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
+ if trial_app_model:
+ app["can_trial"] = True
+ else:
+ app["can_trial"] = False
return result
@classmethod
@@ -32,4 +44,30 @@ class RecommendedAppService:
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
+ if FeatureService.get_system_features().enable_trial_app:
+ app_id = result["id"]
+ trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
+ if trial_app_model:
+ result["can_trial"] = True
+ else:
+ result["can_trial"] = False
return result
+
+ @classmethod
+ def add_trial_app_record(cls, app_id: str, account_id: str):
+ """
+ Add trial app record.
+ :param app_id: app id
+ :return:
+ """
+ account_trial_app_record = (
+ db.session.query(AccountTrialAppRecord)
+ .where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id)
+ .first()
+ )
+ if account_trial_app_record:
+ account_trial_app_record.count += 1
+ db.session.commit()
+ else:
+ db.session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id))
+ db.session.commit()
diff --git a/api/services/retention/__init__.py b/api/services/retention/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/services/retention/conversation/messages_clean_policy.py b/api/services/retention/conversation/messages_clean_policy.py
new file mode 100644
index 0000000000..6e647b983b
--- /dev/null
+++ b/api/services/retention/conversation/messages_clean_policy.py
@@ -0,0 +1,216 @@
+import datetime
+import logging
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass
+
+from configs import dify_config
+from enums.cloud_plan import CloudPlan
+from services.billing_service import BillingService, SubscriptionPlan
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SimpleMessage:
+ id: str
+ app_id: str
+ created_at: datetime.datetime
+
+
+class MessagesCleanPolicy(ABC):
+ """
+ Abstract base class for message cleanup policies.
+
+ A policy determines which messages from a batch should be deleted.
+ """
+
+ @abstractmethod
+ def filter_message_ids(
+ self,
+ messages: Sequence[SimpleMessage],
+ app_to_tenant: dict[str, str],
+ ) -> Sequence[str]:
+ """
+ Filter messages and return IDs of messages that should be deleted.
+
+ Args:
+ messages: Batch of messages to evaluate
+ app_to_tenant: Mapping from app_id to tenant_id
+
+ Returns:
+ List of message IDs that should be deleted
+ """
+ ...
+
+
+class BillingDisabledPolicy(MessagesCleanPolicy):
+ """
+ Policy for community or enterpriseedition (billing disabled).
+
+ No special filter logic, just return all message ids.
+ """
+
+ def filter_message_ids(
+ self,
+ messages: Sequence[SimpleMessage],
+ app_to_tenant: dict[str, str],
+ ) -> Sequence[str]:
+ return [msg.id for msg in messages]
+
+
+class BillingSandboxPolicy(MessagesCleanPolicy):
+ """
+ Policy for sandbox plan tenants in cloud edition (billing enabled).
+
+ Filters messages based on sandbox plan expiration rules:
+ - Skip tenants in the whitelist
+ - Only delete messages from sandbox plan tenants
+ - Respect grace period after subscription expiration
+ - Safe default: if tenant mapping or plan is missing, do NOT delete
+ """
+
+ def __init__(
+ self,
+ plan_provider: Callable[[Sequence[str]], dict[str, SubscriptionPlan]],
+ graceful_period_days: int = 21,
+ tenant_whitelist: Sequence[str] | None = None,
+ current_timestamp: int | None = None,
+ ) -> None:
+ self._graceful_period_days = graceful_period_days
+ self._tenant_whitelist: Sequence[str] = tenant_whitelist or []
+ self._plan_provider = plan_provider
+ self._current_timestamp = current_timestamp
+
+ def filter_message_ids(
+ self,
+ messages: Sequence[SimpleMessage],
+ app_to_tenant: dict[str, str],
+ ) -> Sequence[str]:
+ """
+ Filter messages based on sandbox plan expiration rules.
+
+ Args:
+ messages: Batch of messages to evaluate
+ app_to_tenant: Mapping from app_id to tenant_id
+
+ Returns:
+ List of message IDs that should be deleted
+ """
+ if not messages or not app_to_tenant:
+ return []
+
+ # Get unique tenant_ids and fetch subscription plans
+ tenant_ids = list(set(app_to_tenant.values()))
+ tenant_plans = self._plan_provider(tenant_ids)
+
+ if not tenant_plans:
+ return []
+
+ # Apply sandbox deletion rules
+ return self._filter_expired_sandbox_messages(
+ messages=messages,
+ app_to_tenant=app_to_tenant,
+ tenant_plans=tenant_plans,
+ )
+
+ def _filter_expired_sandbox_messages(
+ self,
+ messages: Sequence[SimpleMessage],
+ app_to_tenant: dict[str, str],
+ tenant_plans: dict[str, SubscriptionPlan],
+ ) -> list[str]:
+ """
+ Filter messages that should be deleted based on sandbox plan expiration.
+
+ A message should be deleted if:
+ 1. It belongs to a sandbox tenant AND
+ 2. Either:
+ a) The tenant has no previous subscription (expiration_date == -1), OR
+ b) The subscription expired more than graceful_period_days ago
+
+ Args:
+ messages: List of message objects with id and app_id attributes
+ app_to_tenant: Mapping from app_id to tenant_id
+ tenant_plans: Mapping from tenant_id to subscription plan info
+
+ Returns:
+ List of message IDs that should be deleted
+ """
+ current_timestamp = self._current_timestamp
+ if current_timestamp is None:
+ current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
+
+ sandbox_message_ids: list[str] = []
+ graceful_period_seconds = self._graceful_period_days * 24 * 60 * 60
+
+ for msg in messages:
+ # Get tenant_id for this message's app
+ tenant_id = app_to_tenant.get(msg.app_id)
+ if not tenant_id:
+ continue
+
+ # Skip tenant messages in whitelist
+ if tenant_id in self._tenant_whitelist:
+ continue
+
+ # Get subscription plan for this tenant
+ tenant_plan = tenant_plans.get(tenant_id)
+ if not tenant_plan:
+ continue
+
+ plan = str(tenant_plan["plan"])
+ expiration_date = int(tenant_plan["expiration_date"])
+
+ # Only process sandbox plans
+ if plan != CloudPlan.SANDBOX:
+ continue
+
+ # Case 1: No previous subscription (-1 means never had a paid subscription)
+ if expiration_date == -1:
+ sandbox_message_ids.append(msg.id)
+ continue
+
+ # Case 2: Subscription expired beyond grace period
+ if current_timestamp - expiration_date > graceful_period_seconds:
+ sandbox_message_ids.append(msg.id)
+
+ return sandbox_message_ids
+
+
+def create_message_clean_policy(
+ graceful_period_days: int = 21,
+ current_timestamp: int | None = None,
+) -> MessagesCleanPolicy:
+ """
+ Factory function to create the appropriate message clean policy.
+
+ Determines which policy to use based on BILLING_ENABLED configuration:
+ - If BILLING_ENABLED is True: returns BillingSandboxPolicy
+ - If BILLING_ENABLED is False: returns BillingDisabledPolicy
+
+ Args:
+ graceful_period_days: Grace period in days after subscription expiration (default: 21)
+ current_timestamp: Current Unix timestamp for testing (default: None, uses current time)
+ """
+ if not dify_config.BILLING_ENABLED:
+ logger.info("create_message_clean_policy: billing disabled, using BillingDisabledPolicy")
+ return BillingDisabledPolicy()
+
+ # Billing enabled - fetch whitelist from BillingService
+ tenant_whitelist = BillingService.get_expired_subscription_cleanup_whitelist()
+ plan_provider = BillingService.get_plan_bulk_with_cache
+
+ logger.info(
+ "create_message_clean_policy: billing enabled, using BillingSandboxPolicy "
+ "(graceful_period_days=%s, whitelist=%s)",
+ graceful_period_days,
+ tenant_whitelist,
+ )
+
+ return BillingSandboxPolicy(
+ plan_provider=plan_provider,
+ graceful_period_days=graceful_period_days,
+ tenant_whitelist=tenant_whitelist,
+ current_timestamp=current_timestamp,
+ )
diff --git a/api/services/retention/conversation/messages_clean_service.py b/api/services/retention/conversation/messages_clean_service.py
new file mode 100644
index 0000000000..3ca5d82860
--- /dev/null
+++ b/api/services/retention/conversation/messages_clean_service.py
@@ -0,0 +1,334 @@
+import datetime
+import logging
+import random
+from collections.abc import Sequence
+from typing import cast
+
+from sqlalchemy import delete, select
+from sqlalchemy.engine import CursorResult
+from sqlalchemy.orm import Session
+
+from extensions.ext_database import db
+from models.model import (
+ App,
+ AppAnnotationHitHistory,
+ DatasetRetrieverResource,
+ Message,
+ MessageAgentThought,
+ MessageAnnotation,
+ MessageChain,
+ MessageFeedback,
+ MessageFile,
+)
+from models.web import SavedMessage
+from services.retention.conversation.messages_clean_policy import (
+ MessagesCleanPolicy,
+ SimpleMessage,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class MessagesCleanService:
+ """
+ Service for cleaning expired messages based on retention policies.
+
+ Compatible with non cloud edition (billing disabled): all messages in the time range will be deleted.
+ If billing is enabled: only sandbox plan tenant messages are deleted (with whitelist and grace period support).
+ """
+
+ def __init__(
+ self,
+ policy: MessagesCleanPolicy,
+ end_before: datetime.datetime,
+ start_from: datetime.datetime | None = None,
+ batch_size: int = 1000,
+ dry_run: bool = False,
+ ) -> None:
+ """
+ Initialize the service with cleanup parameters.
+
+ Args:
+ policy: The policy that determines which messages to delete
+ end_before: End time (exclusive) of the range
+ start_from: Optional start time (inclusive) of the range
+ batch_size: Number of messages to process per batch
+ dry_run: Whether to perform a dry run (no actual deletion)
+ """
+ self._policy = policy
+ self._end_before = end_before
+ self._start_from = start_from
+ self._batch_size = batch_size
+ self._dry_run = dry_run
+
+ @classmethod
+ def from_time_range(
+ cls,
+ policy: MessagesCleanPolicy,
+ start_from: datetime.datetime,
+ end_before: datetime.datetime,
+ batch_size: int = 1000,
+ dry_run: bool = False,
+ ) -> "MessagesCleanService":
+ """
+ Create a service instance for cleaning messages within a specific time range.
+
+ Time range is [start_from, end_before).
+
+ Args:
+ policy: The policy that determines which messages to delete
+ start_from: Start time (inclusive) of the range
+ end_before: End time (exclusive) of the range
+ batch_size: Number of messages to process per batch
+ dry_run: Whether to perform a dry run (no actual deletion)
+
+ Returns:
+ MessagesCleanService instance
+
+ Raises:
+ ValueError: If start_from >= end_before or invalid parameters
+ """
+ if start_from >= end_before:
+ raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})")
+
+ if batch_size <= 0:
+ raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
+
+ logger.info(
+ "clean_messages: start_from=%s, end_before=%s, batch_size=%s, policy=%s",
+ start_from,
+ end_before,
+ batch_size,
+ policy.__class__.__name__,
+ )
+
+ return cls(
+ policy=policy,
+ end_before=end_before,
+ start_from=start_from,
+ batch_size=batch_size,
+ dry_run=dry_run,
+ )
+
+ @classmethod
+ def from_days(
+ cls,
+ policy: MessagesCleanPolicy,
+ days: int = 30,
+ batch_size: int = 1000,
+ dry_run: bool = False,
+ ) -> "MessagesCleanService":
+ """
+ Create a service instance for cleaning messages older than specified days.
+
+ Args:
+ policy: The policy that determines which messages to delete
+ days: Number of days to look back from now
+ batch_size: Number of messages to process per batch
+ dry_run: Whether to perform a dry run (no actual deletion)
+
+ Returns:
+ MessagesCleanService instance
+
+ Raises:
+ ValueError: If invalid parameters
+ """
+ if days < 0:
+ raise ValueError(f"days ({days}) must be greater than or equal to 0")
+
+ if batch_size <= 0:
+ raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
+
+ end_before = datetime.datetime.now() - datetime.timedelta(days=days)
+
+ logger.info(
+ "clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s",
+ days,
+ end_before,
+ batch_size,
+ policy.__class__.__name__,
+ )
+
+ return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run)
+
+ def run(self) -> dict[str, int]:
+ """
+ Execute the message cleanup operation.
+
+ Returns:
+ Dict with statistics: batches, filtered_messages, total_deleted
+ """
+ return self._clean_messages_by_time_range()
+
+ def _clean_messages_by_time_range(self) -> dict[str, int]:
+ """
+ Clean messages within a time range using cursor-based pagination.
+
+ Time range is [start_from, end_before)
+
+ Steps:
+ 1. Iterate messages using cursor pagination (by created_at, id)
+ 2. Query app_id -> tenant_id mapping
+ 3. Delegate to policy to determine which messages to delete
+ 4. Batch delete messages and their relations
+
+ Returns:
+ Dict with statistics: batches, filtered_messages, total_deleted
+ """
+ stats = {
+ "batches": 0,
+ "total_messages": 0,
+ "filtered_messages": 0,
+ "total_deleted": 0,
+ }
+
+ # Cursor-based pagination using (created_at, id) to avoid infinite loops
+ # and ensure proper ordering with time-based filtering
+ _cursor: tuple[datetime.datetime, str] | None = None
+
+ logger.info(
+ "clean_messages: start cleaning messages (dry_run=%s), start_from=%s, end_before=%s",
+ self._dry_run,
+ self._start_from,
+ self._end_before,
+ )
+
+ while True:
+ stats["batches"] += 1
+
+ # Step 1: Fetch a batch of messages using cursor
+ with Session(db.engine, expire_on_commit=False) as session:
+ msg_stmt = (
+ select(Message.id, Message.app_id, Message.created_at)
+ .where(Message.created_at < self._end_before)
+ .order_by(Message.created_at, Message.id)
+ .limit(self._batch_size)
+ )
+
+ if self._start_from:
+ msg_stmt = msg_stmt.where(Message.created_at >= self._start_from)
+
+ # Apply cursor condition: (created_at, id) > (last_created_at, last_message_id)
+ # This translates to:
+ # created_at > last_created_at OR (created_at = last_created_at AND id > last_message_id)
+ if _cursor:
+ # Continuing from previous batch
+ msg_stmt = msg_stmt.where(
+ (Message.created_at > _cursor[0])
+ | ((Message.created_at == _cursor[0]) & (Message.id > _cursor[1]))
+ )
+
+ raw_messages = list(session.execute(msg_stmt).all())
+ messages = [
+ SimpleMessage(id=msg_id, app_id=app_id, created_at=msg_created_at)
+ for msg_id, app_id, msg_created_at in raw_messages
+ ]
+
+ # Track total messages fetched across all batches
+ stats["total_messages"] += len(messages)
+
+ if not messages:
+ logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
+ break
+
+ # Update cursor to the last message's (created_at, id)
+ _cursor = (messages[-1].created_at, messages[-1].id)
+
+ # Step 2: Extract app_ids and query tenant_ids
+ app_ids = list({msg.app_id for msg in messages})
+
+ if not app_ids:
+ logger.info("clean_messages (batch %s): no app_ids found, skip", stats["batches"])
+ continue
+
+ app_stmt = select(App.id, App.tenant_id).where(App.id.in_(app_ids))
+ apps = list(session.execute(app_stmt).all())
+
+ if not apps:
+ logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
+ continue
+
+ # Build app_id -> tenant_id mapping
+ app_to_tenant: dict[str, str] = {app.id: app.tenant_id for app in apps}
+
+ # Step 3: Delegate to policy to determine which messages to delete
+ message_ids_to_delete = self._policy.filter_message_ids(messages, app_to_tenant)
+
+ if not message_ids_to_delete:
+ logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"])
+ continue
+
+ stats["filtered_messages"] += len(message_ids_to_delete)
+
+ # Step 4: Batch delete messages and their relations
+ if not self._dry_run:
+ with Session(db.engine, expire_on_commit=False) as session:
+ # Delete related records first
+ self._batch_delete_message_relations(session, message_ids_to_delete)
+
+ # Delete messages
+ delete_stmt = delete(Message).where(Message.id.in_(message_ids_to_delete))
+ delete_result = cast(CursorResult, session.execute(delete_stmt))
+ messages_deleted = delete_result.rowcount
+ session.commit()
+
+ stats["total_deleted"] += messages_deleted
+
+ logger.info(
+ "clean_messages (batch %s): processed %s messages, deleted %s messages",
+ stats["batches"],
+ len(messages),
+ messages_deleted,
+ )
+ else:
+ # Log random sample of message IDs that would be deleted (up to 10)
+ sample_size = min(10, len(message_ids_to_delete))
+ sampled_ids = random.sample(list(message_ids_to_delete), sample_size)
+
+ logger.info(
+ "clean_messages (batch %s, dry_run): would delete %s messages, sampling %s ids:",
+ stats["batches"],
+ len(message_ids_to_delete),
+ sample_size,
+ )
+ for msg_id in sampled_ids:
+ logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id)
+
+ logger.info(
+ "clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s",
+ stats["batches"],
+ stats["total_messages"],
+ stats["filtered_messages"],
+ stats["total_deleted"],
+ )
+
+ return stats
+
+ @staticmethod
+ def _batch_delete_message_relations(session: Session, message_ids: Sequence[str]) -> None:
+ """
+ Batch delete all related records for given message IDs.
+
+ Args:
+ session: Database session
+ message_ids: List of message IDs to delete relations for
+ """
+ if not message_ids:
+ return
+
+ # Delete all related records in batch
+ session.execute(delete(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)))
+
+ session.execute(delete(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_ids)))
+
+ session.execute(delete(MessageChain).where(MessageChain.message_id.in_(message_ids)))
+
+ session.execute(delete(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_ids)))
+
+ session.execute(delete(MessageFile).where(MessageFile.message_id.in_(message_ids)))
+
+ session.execute(delete(SavedMessage).where(SavedMessage.message_id.in_(message_ids)))
+
+ session.execute(delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id.in_(message_ids)))
+
+ session.execute(delete(DatasetRetrieverResource).where(DatasetRetrieverResource.message_id.in_(message_ids)))
diff --git a/api/services/retention/workflow_run/__init__.py b/api/services/retention/workflow_run/__init__.py
new file mode 100644
index 0000000000..18dd42c91e
--- /dev/null
+++ b/api/services/retention/workflow_run/__init__.py
@@ -0,0 +1 @@
+"""Workflow run retention services."""
diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py
new file mode 100644
index 0000000000..ea5cbb7740
--- /dev/null
+++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py
@@ -0,0 +1,531 @@
+"""
+Archive Paid Plan Workflow Run Logs Service.
+
+This service archives workflow run logs for paid plan users older than the configured
+retention period (default: 90 days) to S3-compatible storage.
+
+Archived tables:
+- workflow_runs
+- workflow_app_logs
+- workflow_node_executions
+- workflow_node_execution_offload
+- workflow_pauses
+- workflow_pause_reasons
+- workflow_trigger_logs
+
+"""
+
+import datetime
+import io
+import json
+import logging
+import time
+import zipfile
+from collections.abc import Sequence
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass, field
+from typing import Any
+
+import click
+from sqlalchemy import inspect
+from sqlalchemy.orm import Session, sessionmaker
+
+from configs import dify_config
+from core.workflow.enums import WorkflowType
+from enums.cloud_plan import CloudPlan
+from extensions.ext_database import db
+from libs.archive_storage import (
+ ArchiveStorage,
+ ArchiveStorageNotConfiguredError,
+ get_archive_storage,
+)
+from models.workflow import WorkflowAppLog, WorkflowRun
+from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
+from services.billing_service import BillingService
+from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME, ARCHIVE_SCHEMA_VERSION
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class TableStats:
+ """Statistics for a single archived table."""
+
+ table_name: str
+ row_count: int
+ checksum: str
+ size_bytes: int
+
+
+@dataclass
+class ArchiveResult:
+ """Result of archiving a single workflow run."""
+
+ run_id: str
+ tenant_id: str
+ success: bool
+ tables: list[TableStats] = field(default_factory=list)
+ error: str | None = None
+ elapsed_time: float = 0.0
+
+
+@dataclass
+class ArchiveSummary:
+ """Summary of the entire archive operation."""
+
+ total_runs_processed: int = 0
+ runs_archived: int = 0
+ runs_skipped: int = 0
+ runs_failed: int = 0
+ total_elapsed_time: float = 0.0
+
+
+class WorkflowRunArchiver:
+ """
+ Archive workflow run logs for paid plan users.
+
+ Storage Layout:
+ {tenant_id}/app_id={app_id}/year={YYYY}/month={MM}/workflow_run_id={run_id}/
+ └── archive.v1.0.zip
+ ├── manifest.json
+ ├── workflow_runs.jsonl
+ ├── workflow_app_logs.jsonl
+ ├── workflow_node_executions.jsonl
+ ├── workflow_node_execution_offload.jsonl
+ ├── workflow_pauses.jsonl
+ ├── workflow_pause_reasons.jsonl
+ └── workflow_trigger_logs.jsonl
+ """
+
+ ARCHIVED_TYPE = [
+ WorkflowType.WORKFLOW,
+ WorkflowType.RAG_PIPELINE,
+ ]
+ ARCHIVED_TABLES = [
+ "workflow_runs",
+ "workflow_app_logs",
+ "workflow_node_executions",
+ "workflow_node_execution_offload",
+ "workflow_pauses",
+ "workflow_pause_reasons",
+ "workflow_trigger_logs",
+ ]
+
+ start_from: datetime.datetime | None
+ end_before: datetime.datetime
+
+ def __init__(
+ self,
+ days: int = 90,
+ batch_size: int = 100,
+ start_from: datetime.datetime | None = None,
+ end_before: datetime.datetime | None = None,
+ workers: int = 1,
+ tenant_ids: Sequence[str] | None = None,
+ limit: int | None = None,
+ dry_run: bool = False,
+ delete_after_archive: bool = False,
+ workflow_run_repo: APIWorkflowRunRepository | None = None,
+ ):
+ """
+ Initialize the archiver.
+
+ Args:
+ days: Archive runs older than this many days
+ batch_size: Number of runs to process per batch
+ start_from: Optional start time (inclusive) for archiving
+ end_before: Optional end time (exclusive) for archiving
+ workers: Number of concurrent workflow runs to archive
+ tenant_ids: Optional tenant IDs for grayscale rollout
+ limit: Maximum number of runs to archive (None for unlimited)
+ dry_run: If True, only preview without making changes
+ delete_after_archive: If True, delete runs and related data after archiving
+ """
+ self.days = days
+ self.batch_size = batch_size
+ if start_from or end_before:
+ if start_from is None or end_before is None:
+ raise ValueError("start_from and end_before must be provided together")
+ if start_from >= end_before:
+ raise ValueError("start_from must be earlier than end_before")
+ self.start_from = start_from.replace(tzinfo=datetime.UTC)
+ self.end_before = end_before.replace(tzinfo=datetime.UTC)
+ else:
+ self.start_from = None
+ self.end_before = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days)
+ if workers < 1:
+ raise ValueError("workers must be at least 1")
+ self.workers = workers
+ self.tenant_ids = sorted(set(tenant_ids)) if tenant_ids else []
+ self.limit = limit
+ self.dry_run = dry_run
+ self.delete_after_archive = delete_after_archive
+ self.workflow_run_repo = workflow_run_repo
+
+ def run(self) -> ArchiveSummary:
+ """
+ Main archiving loop.
+
+ Returns:
+ ArchiveSummary with statistics about the operation
+ """
+ summary = ArchiveSummary()
+ start_time = time.time()
+
+ click.echo(
+ click.style(
+ self._build_start_message(),
+ fg="white",
+ )
+ )
+
+ # Initialize archive storage (will raise if not configured)
+ try:
+ if not self.dry_run:
+ storage = get_archive_storage()
+ else:
+ storage = None
+ except ArchiveStorageNotConfiguredError as e:
+ click.echo(click.style(f"Archive storage not configured: {e}", fg="red"))
+ return summary
+
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ repo = self._get_workflow_run_repo()
+
+ def _archive_with_session(run: WorkflowRun) -> ArchiveResult:
+ with session_maker() as session:
+ return self._archive_run(session, storage, run)
+
+ last_seen: tuple[datetime.datetime, str] | None = None
+ archived_count = 0
+
+ with ThreadPoolExecutor(max_workers=self.workers) as executor:
+ while True:
+ # Check limit
+ if self.limit and archived_count >= self.limit:
+ click.echo(click.style(f"Reached limit of {self.limit} runs", fg="yellow"))
+ break
+
+ # Fetch batch of runs
+ runs = self._get_runs_batch(last_seen)
+
+ if not runs:
+ break
+
+ run_ids = [run.id for run in runs]
+ with session_maker() as session:
+ archived_run_ids = repo.get_archived_run_ids(session, run_ids)
+
+ last_seen = (runs[-1].created_at, runs[-1].id)
+
+ # Filter to paid tenants only
+ tenant_ids = {run.tenant_id for run in runs}
+ paid_tenants = self._filter_paid_tenants(tenant_ids)
+
+ runs_to_process: list[WorkflowRun] = []
+ for run in runs:
+ summary.total_runs_processed += 1
+
+ # Skip non-paid tenants
+ if run.tenant_id not in paid_tenants:
+ summary.runs_skipped += 1
+ continue
+
+ # Skip already archived runs
+ if run.id in archived_run_ids:
+ summary.runs_skipped += 1
+ continue
+
+ # Check limit
+ if self.limit and archived_count + len(runs_to_process) >= self.limit:
+ break
+
+ runs_to_process.append(run)
+
+ if not runs_to_process:
+ continue
+
+ results = list(executor.map(_archive_with_session, runs_to_process))
+
+ for run, result in zip(runs_to_process, results):
+ if result.success:
+ summary.runs_archived += 1
+ archived_count += 1
+ click.echo(
+ click.style(
+ f"{'[DRY RUN] Would archive' if self.dry_run else 'Archived'} "
+ f"run {run.id} (tenant={run.tenant_id}, "
+ f"tables={len(result.tables)}, time={result.elapsed_time:.2f}s)",
+ fg="green",
+ )
+ )
+ else:
+ summary.runs_failed += 1
+ click.echo(
+ click.style(
+ f"Failed to archive run {run.id}: {result.error}",
+ fg="red",
+ )
+ )
+
+ summary.total_elapsed_time = time.time() - start_time
+ click.echo(
+ click.style(
+ f"{'[DRY RUN] ' if self.dry_run else ''}Archive complete: "
+ f"processed={summary.total_runs_processed}, archived={summary.runs_archived}, "
+ f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, "
+ f"time={summary.total_elapsed_time:.2f}s",
+ fg="white",
+ )
+ )
+
+ return summary
+
+ def _get_runs_batch(
+ self,
+ last_seen: tuple[datetime.datetime, str] | None,
+ ) -> Sequence[WorkflowRun]:
+ """Fetch a batch of workflow runs to archive."""
+ repo = self._get_workflow_run_repo()
+ return repo.get_runs_batch_by_time_range(
+ start_from=self.start_from,
+ end_before=self.end_before,
+ last_seen=last_seen,
+ batch_size=self.batch_size,
+ run_types=self.ARCHIVED_TYPE,
+ tenant_ids=self.tenant_ids or None,
+ )
+
+ def _build_start_message(self) -> str:
+ range_desc = f"before {self.end_before.isoformat()}"
+ if self.start_from:
+ range_desc = f"between {self.start_from.isoformat()} and {self.end_before.isoformat()}"
+ return (
+ f"{'[DRY RUN] ' if self.dry_run else ''}Starting workflow run archiving "
+ f"for runs {range_desc} "
+ f"(batch_size={self.batch_size}, tenant_ids={','.join(self.tenant_ids) or 'all'})"
+ )
+
+ def _filter_paid_tenants(self, tenant_ids: set[str]) -> set[str]:
+ """Filter tenant IDs to only include paid tenants."""
+ if not dify_config.BILLING_ENABLED:
+ # If billing is not enabled, treat all tenants as paid
+ return tenant_ids
+
+ if not tenant_ids:
+ return set()
+
+ try:
+ bulk_info = BillingService.get_plan_bulk_with_cache(list(tenant_ids))
+ except Exception:
+ logger.exception("Failed to fetch billing plans for tenants")
+ # On error, skip all tenants in this batch
+ return set()
+
+ # Filter to paid tenants (any plan except SANDBOX)
+ paid = set()
+ for tid, info in bulk_info.items():
+ if info and info.get("plan") in (CloudPlan.PROFESSIONAL, CloudPlan.TEAM):
+ paid.add(tid)
+
+ return paid
+
+ def _archive_run(
+ self,
+ session: Session,
+ storage: ArchiveStorage | None,
+ run: WorkflowRun,
+ ) -> ArchiveResult:
+ """Archive a single workflow run."""
+ start_time = time.time()
+ result = ArchiveResult(run_id=run.id, tenant_id=run.tenant_id, success=False)
+
+ try:
+ # Extract data from all tables
+ table_data, app_logs, trigger_metadata = self._extract_data(session, run)
+
+ if self.dry_run:
+ # In dry run, just report what would be archived
+ for table_name in self.ARCHIVED_TABLES:
+ records = table_data.get(table_name, [])
+ result.tables.append(
+ TableStats(
+ table_name=table_name,
+ row_count=len(records),
+ checksum="",
+ size_bytes=0,
+ )
+ )
+ result.success = True
+ else:
+ if storage is None:
+ raise ArchiveStorageNotConfiguredError("Archive storage not configured")
+ archive_key = self._get_archive_key(run)
+
+ # Serialize tables for the archive bundle
+ table_stats: list[TableStats] = []
+ table_payloads: dict[str, bytes] = {}
+ for table_name in self.ARCHIVED_TABLES:
+ records = table_data.get(table_name, [])
+ data = ArchiveStorage.serialize_to_jsonl(records)
+ table_payloads[table_name] = data
+ checksum = ArchiveStorage.compute_checksum(data)
+
+ table_stats.append(
+ TableStats(
+ table_name=table_name,
+ row_count=len(records),
+ checksum=checksum,
+ size_bytes=len(data),
+ )
+ )
+
+ # Generate and upload archive bundle
+ manifest = self._generate_manifest(run, table_stats)
+ manifest_data = json.dumps(manifest, indent=2, default=str).encode("utf-8")
+ archive_data = self._build_archive_bundle(manifest_data, table_payloads)
+ storage.put_object(archive_key, archive_data)
+
+ repo = self._get_workflow_run_repo()
+ archived_log_count = repo.create_archive_logs(session, run, app_logs, trigger_metadata)
+ session.commit()
+
+ deleted_counts = None
+ if self.delete_after_archive:
+ deleted_counts = repo.delete_runs_with_related(
+ [run],
+ delete_node_executions=self._delete_node_executions,
+ delete_trigger_logs=self._delete_trigger_logs,
+ )
+
+ logger.info(
+ "Archived workflow run %s: tables=%s, archived_logs=%s, deleted=%s",
+ run.id,
+ {s.table_name: s.row_count for s in table_stats},
+ archived_log_count,
+ deleted_counts,
+ )
+
+ result.tables = table_stats
+ result.success = True
+
+ except Exception as e:
+ logger.exception("Failed to archive workflow run %s", run.id)
+ result.error = str(e)
+ session.rollback()
+
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ def _extract_data(
+ self,
+ session: Session,
+ run: WorkflowRun,
+ ) -> tuple[dict[str, list[dict[str, Any]]], Sequence[WorkflowAppLog], str | None]:
+ table_data: dict[str, list[dict[str, Any]]] = {}
+ table_data["workflow_runs"] = [self._row_to_dict(run)]
+ repo = self._get_workflow_run_repo()
+ app_logs = repo.get_app_logs_by_run_id(session, run.id)
+ table_data["workflow_app_logs"] = [self._row_to_dict(row) for row in app_logs]
+ node_exec_repo = self._get_workflow_node_execution_repo(session)
+ node_exec_records = node_exec_repo.get_executions_by_workflow_run(
+ tenant_id=run.tenant_id,
+ app_id=run.app_id,
+ workflow_run_id=run.id,
+ )
+ node_exec_ids = [record.id for record in node_exec_records]
+ offload_records = node_exec_repo.get_offloads_by_execution_ids(session, node_exec_ids)
+ table_data["workflow_node_executions"] = [self._row_to_dict(row) for row in node_exec_records]
+ table_data["workflow_node_execution_offload"] = [self._row_to_dict(row) for row in offload_records]
+ repo = self._get_workflow_run_repo()
+ pause_records = repo.get_pause_records_by_run_id(session, run.id)
+ pause_ids = [pause.id for pause in pause_records]
+ pause_reason_records = repo.get_pause_reason_records_by_run_id(
+ session,
+ pause_ids,
+ )
+ table_data["workflow_pauses"] = [self._row_to_dict(row) for row in pause_records]
+ table_data["workflow_pause_reasons"] = [self._row_to_dict(row) for row in pause_reason_records]
+ trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+ trigger_records = trigger_repo.list_by_run_id(run.id)
+ table_data["workflow_trigger_logs"] = [self._row_to_dict(row) for row in trigger_records]
+ trigger_metadata = trigger_records[0].trigger_metadata if trigger_records else None
+ return table_data, app_logs, trigger_metadata
+
+ @staticmethod
+ def _row_to_dict(row: Any) -> dict[str, Any]:
+ mapper = inspect(row).mapper
+ return {str(column.name): getattr(row, mapper.get_property_by_column(column).key) for column in mapper.columns}
+
+ def _get_archive_key(self, run: WorkflowRun) -> str:
+ """Get the storage key for the archive bundle."""
+ created_at = run.created_at
+ prefix = (
+ f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/"
+ f"month={created_at.strftime('%m')}/workflow_run_id={run.id}"
+ )
+ return f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
+
+ def _generate_manifest(
+ self,
+ run: WorkflowRun,
+ table_stats: list[TableStats],
+ ) -> dict[str, Any]:
+ """Generate a manifest for the archived workflow run."""
+ return {
+ "schema_version": ARCHIVE_SCHEMA_VERSION,
+ "workflow_run_id": run.id,
+ "tenant_id": run.tenant_id,
+ "app_id": run.app_id,
+ "workflow_id": run.workflow_id,
+ "created_at": run.created_at.isoformat(),
+ "archived_at": datetime.datetime.now(datetime.UTC).isoformat(),
+ "tables": {
+ stat.table_name: {
+ "row_count": stat.row_count,
+ "checksum": stat.checksum,
+ "size_bytes": stat.size_bytes,
+ }
+ for stat in table_stats
+ },
+ }
+
+ def _build_archive_bundle(self, manifest_data: bytes, table_payloads: dict[str, bytes]) -> bytes:
+ buffer = io.BytesIO()
+ with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
+ archive.writestr("manifest.json", manifest_data)
+ for table_name in self.ARCHIVED_TABLES:
+ data = table_payloads.get(table_name)
+ if data is None:
+ raise ValueError(f"Missing archive payload for {table_name}")
+ archive.writestr(f"{table_name}.jsonl", data)
+ return buffer.getvalue()
+
+ def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
+ trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+ return trigger_repo.delete_by_run_ids(run_ids)
+
+ def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
+ run_ids = [run.id for run in runs]
+ return self._get_workflow_node_execution_repo(session).delete_by_runs(session, run_ids)
+
+ def _get_workflow_node_execution_repo(
+ self,
+ session: Session,
+ ) -> DifyAPIWorkflowNodeExecutionRepository:
+ from repositories.factory import DifyAPIRepositoryFactory
+
+ session_maker = sessionmaker(bind=session.get_bind(), expire_on_commit=False)
+ return DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
+
+ def _get_workflow_run_repo(self) -> APIWorkflowRunRepository:
+ if self.workflow_run_repo is not None:
+ return self.workflow_run_repo
+
+ from repositories.factory import DifyAPIRepositoryFactory
+
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+ return self.workflow_run_repo
diff --git a/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py
new file mode 100644
index 0000000000..c3e0dce399
--- /dev/null
+++ b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py
@@ -0,0 +1,293 @@
+import datetime
+import logging
+from collections.abc import Iterable, Sequence
+
+import click
+from sqlalchemy.orm import Session, sessionmaker
+
+from configs import dify_config
+from enums.cloud_plan import CloudPlan
+from extensions.ext_database import db
+from models.workflow import WorkflowRun
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.factory import DifyAPIRepositoryFactory
+from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
+from services.billing_service import BillingService, SubscriptionPlan
+
+logger = logging.getLogger(__name__)
+
+
+class WorkflowRunCleanup:
+ def __init__(
+ self,
+ days: int,
+ batch_size: int,
+ start_from: datetime.datetime | None = None,
+ end_before: datetime.datetime | None = None,
+ workflow_run_repo: APIWorkflowRunRepository | None = None,
+ dry_run: bool = False,
+ ):
+ if (start_from is None) ^ (end_before is None):
+ raise ValueError("start_from and end_before must be both set or both omitted.")
+
+ computed_cutoff = datetime.datetime.now() - datetime.timedelta(days=days)
+ self.window_start = start_from
+ self.window_end = end_before or computed_cutoff
+
+ if self.window_start and self.window_end <= self.window_start:
+ raise ValueError("end_before must be greater than start_from.")
+
+ if batch_size <= 0:
+ raise ValueError("batch_size must be greater than 0.")
+
+ self.batch_size = batch_size
+ self._cleanup_whitelist: set[str] | None = None
+ self.dry_run = dry_run
+ self.free_plan_grace_period_days = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD
+ self.workflow_run_repo: APIWorkflowRunRepository
+ if workflow_run_repo:
+ self.workflow_run_repo = workflow_run_repo
+ else:
+ # Lazy import to avoid circular dependencies during module import
+ from repositories.factory import DifyAPIRepositoryFactory
+
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+
+ def run(self) -> None:
+ click.echo(
+ click.style(
+ f"{'Inspecting' if self.dry_run else 'Cleaning'} workflow runs "
+ f"{'between ' + self.window_start.isoformat() + ' and ' if self.window_start else 'before '}"
+ f"{self.window_end.isoformat()} (batch={self.batch_size})",
+ fg="white",
+ )
+ )
+ if self.dry_run:
+ click.echo(click.style("Dry run mode enabled. No data will be deleted.", fg="yellow"))
+
+ total_runs_deleted = 0
+ total_runs_targeted = 0
+ related_totals = self._empty_related_counts() if self.dry_run else None
+ batch_index = 0
+ last_seen: tuple[datetime.datetime, str] | None = None
+
+ while True:
+ run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
+ start_from=self.window_start,
+ end_before=self.window_end,
+ last_seen=last_seen,
+ batch_size=self.batch_size,
+ )
+ if not run_rows:
+ break
+
+ batch_index += 1
+ last_seen = (run_rows[-1].created_at, run_rows[-1].id)
+ tenant_ids = {row.tenant_id for row in run_rows}
+ free_tenants = self._filter_free_tenants(tenant_ids)
+ free_runs = [row for row in run_rows if row.tenant_id in free_tenants]
+ paid_or_skipped = len(run_rows) - len(free_runs)
+
+ if not free_runs:
+ skipped_message = (
+ f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
+ )
+ click.echo(
+ click.style(
+ skipped_message,
+ fg="yellow",
+ )
+ )
+ continue
+
+ total_runs_targeted += len(free_runs)
+
+ if self.dry_run:
+ batch_counts = self.workflow_run_repo.count_runs_with_related(
+ free_runs,
+ count_node_executions=self._count_node_executions,
+ count_trigger_logs=self._count_trigger_logs,
+ )
+ if related_totals is not None:
+ for key in related_totals:
+ related_totals[key] += batch_counts.get(key, 0)
+ sample_ids = ", ".join(run.id for run in free_runs[:5])
+ click.echo(
+ click.style(
+ f"[batch #{batch_index}] would delete {len(free_runs)} runs "
+ f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown",
+ fg="yellow",
+ )
+ )
+ continue
+
+ try:
+ counts = self.workflow_run_repo.delete_runs_with_related(
+ free_runs,
+ delete_node_executions=self._delete_node_executions,
+ delete_trigger_logs=self._delete_trigger_logs,
+ )
+ except Exception:
+ logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
+ raise
+
+ total_runs_deleted += counts["runs"]
+ click.echo(
+ click.style(
+ f"[batch #{batch_index}] deleted runs: {counts['runs']} "
+ f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
+ f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
+ f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
+ f"skipped {paid_or_skipped} paid/unknown",
+ fg="green",
+ )
+ )
+
+ if self.dry_run:
+ if self.window_start:
+ summary_message = (
+ f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
+ f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
+ )
+ else:
+ summary_message = (
+ f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
+ f"before {self.window_end.isoformat()}"
+ )
+ if related_totals is not None:
+ summary_message = f"{summary_message}; related records: {self._format_related_counts(related_totals)}"
+ summary_color = "yellow"
+ else:
+ if self.window_start:
+ summary_message = (
+ f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
+ f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
+ )
+ else:
+ summary_message = (
+ f"Cleanup complete. Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}"
+ )
+ summary_color = "white"
+
+ click.echo(click.style(summary_message, fg=summary_color))
+
+ def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]:
+ tenant_id_list = list(tenant_ids)
+
+ if not dify_config.BILLING_ENABLED:
+ return set(tenant_id_list)
+
+ if not tenant_id_list:
+ return set()
+
+ cleanup_whitelist = self._get_cleanup_whitelist()
+
+ try:
+ bulk_info = BillingService.get_plan_bulk_with_cache(tenant_id_list)
+ except Exception:
+ bulk_info = {}
+ logger.exception("Failed to fetch billing plans in bulk for tenants: %s", tenant_id_list)
+
+ eligible_free_tenants: set[str] = set()
+ for tenant_id in tenant_id_list:
+ if tenant_id in cleanup_whitelist:
+ continue
+
+ info = bulk_info.get(tenant_id)
+ if info is None:
+ logger.warning("Missing billing info for tenant %s in bulk resp; treating as non-free", tenant_id)
+ continue
+
+ if info.get("plan") != CloudPlan.SANDBOX:
+ continue
+
+ if self._is_within_grace_period(tenant_id, info):
+ continue
+
+ eligible_free_tenants.add(tenant_id)
+
+ return eligible_free_tenants
+
+ def _expiration_datetime(self, tenant_id: str, expiration_value: int) -> datetime.datetime | None:
+ if expiration_value < 0:
+ return None
+
+ try:
+ return datetime.datetime.fromtimestamp(expiration_value, datetime.UTC)
+ except (OverflowError, OSError, ValueError):
+ logger.exception("Failed to parse expiration timestamp for tenant %s", tenant_id)
+ return None
+
+ def _is_within_grace_period(self, tenant_id: str, info: SubscriptionPlan) -> bool:
+ if self.free_plan_grace_period_days <= 0:
+ return False
+
+ expiration_value = info.get("expiration_date", -1)
+ expiration_at = self._expiration_datetime(tenant_id, expiration_value)
+ if expiration_at is None:
+ return False
+
+ grace_deadline = expiration_at + datetime.timedelta(days=self.free_plan_grace_period_days)
+ return datetime.datetime.now(datetime.UTC) < grace_deadline
+
+ def _get_cleanup_whitelist(self) -> set[str]:
+ if self._cleanup_whitelist is not None:
+ return self._cleanup_whitelist
+
+ if not dify_config.BILLING_ENABLED:
+ self._cleanup_whitelist = set()
+ return self._cleanup_whitelist
+
+ try:
+ whitelist_ids = BillingService.get_expired_subscription_cleanup_whitelist()
+ except Exception:
+ logger.exception("Failed to fetch cleanup whitelist from billing service")
+ whitelist_ids = []
+
+ self._cleanup_whitelist = set(whitelist_ids)
+ return self._cleanup_whitelist
+
+ def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
+ trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+ return trigger_repo.delete_by_run_ids(run_ids)
+
+ def _count_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
+ trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+ return trigger_repo.count_by_run_ids(run_ids)
+
+ @staticmethod
+ def _empty_related_counts() -> dict[str, int]:
+ return {
+ "node_executions": 0,
+ "offloads": 0,
+ "app_logs": 0,
+ "trigger_logs": 0,
+ "pauses": 0,
+ "pause_reasons": 0,
+ }
+
+ @staticmethod
+ def _format_related_counts(counts: dict[str, int]) -> str:
+ return (
+ f"node_executions {counts['node_executions']}, "
+ f"offloads {counts['offloads']}, "
+ f"app_logs {counts['app_logs']}, "
+ f"trigger_logs {counts['trigger_logs']}, "
+ f"pauses {counts['pauses']}, "
+ f"pause_reasons {counts['pause_reasons']}"
+ )
+
+ def _count_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
+ run_ids = [run.id for run in runs]
+ repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+ session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
+ )
+ return repo.count_by_runs(session, run_ids)
+
+ def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
+ run_ids = [run.id for run in runs]
+ repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+ session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
+ )
+ return repo.delete_by_runs(session, run_ids)
diff --git a/api/services/retention/workflow_run/constants.py b/api/services/retention/workflow_run/constants.py
new file mode 100644
index 0000000000..162bb4947d
--- /dev/null
+++ b/api/services/retention/workflow_run/constants.py
@@ -0,0 +1,2 @@
+ARCHIVE_SCHEMA_VERSION = "1.0"
+ARCHIVE_BUNDLE_NAME = f"archive.v{ARCHIVE_SCHEMA_VERSION}.zip"
diff --git a/api/services/retention/workflow_run/delete_archived_workflow_run.py b/api/services/retention/workflow_run/delete_archived_workflow_run.py
new file mode 100644
index 0000000000..11873bf1b9
--- /dev/null
+++ b/api/services/retention/workflow_run/delete_archived_workflow_run.py
@@ -0,0 +1,134 @@
+"""
+Delete Archived Workflow Run Service.
+
+This service deletes archived workflow run data from the database while keeping
+archive logs intact.
+"""
+
+import time
+from collections.abc import Sequence
+from dataclasses import dataclass, field
+from datetime import datetime
+
+from sqlalchemy.orm import Session, sessionmaker
+
+from extensions.ext_database import db
+from models.workflow import WorkflowRun
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
+
+
+@dataclass
+class DeleteResult:
+ run_id: str
+ tenant_id: str
+ success: bool
+ deleted_counts: dict[str, int] = field(default_factory=dict)
+ error: str | None = None
+ elapsed_time: float = 0.0
+
+
+class ArchivedWorkflowRunDeletion:
+ def __init__(self, dry_run: bool = False):
+ self.dry_run = dry_run
+ self.workflow_run_repo: APIWorkflowRunRepository | None = None
+
+ def delete_by_run_id(self, run_id: str) -> DeleteResult:
+ start_time = time.time()
+ result = DeleteResult(run_id=run_id, tenant_id="", success=False)
+
+ repo = self._get_workflow_run_repo()
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ with session_maker() as session:
+ run = session.get(WorkflowRun, run_id)
+ if not run:
+ result.error = f"Workflow run {run_id} not found"
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ result.tenant_id = run.tenant_id
+ if not repo.get_archived_run_ids(session, [run.id]):
+ result.error = f"Workflow run {run_id} is not archived"
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ result = self._delete_run(run)
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ def delete_batch(
+ self,
+ tenant_ids: list[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int = 100,
+ ) -> list[DeleteResult]:
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ results: list[DeleteResult] = []
+
+ repo = self._get_workflow_run_repo()
+ with session_maker() as session:
+ runs = list(
+ repo.get_archived_runs_by_time_range(
+ session=session,
+ tenant_ids=tenant_ids,
+ start_date=start_date,
+ end_date=end_date,
+ limit=limit,
+ )
+ )
+ for run in runs:
+ results.append(self._delete_run(run))
+
+ return results
+
+ def _delete_run(self, run: WorkflowRun) -> DeleteResult:
+ start_time = time.time()
+ result = DeleteResult(run_id=run.id, tenant_id=run.tenant_id, success=False)
+ if self.dry_run:
+ result.success = True
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ repo = self._get_workflow_run_repo()
+ try:
+ deleted_counts = repo.delete_runs_with_related(
+ [run],
+ delete_node_executions=self._delete_node_executions,
+ delete_trigger_logs=self._delete_trigger_logs,
+ )
+ result.deleted_counts = deleted_counts
+ result.success = True
+ except Exception as e:
+ result.error = str(e)
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ @staticmethod
+ def _delete_trigger_logs(session: Session, run_ids: Sequence[str]) -> int:
+ trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+ return trigger_repo.delete_by_run_ids(run_ids)
+
+ @staticmethod
+ def _delete_node_executions(
+ session: Session,
+ runs: Sequence[WorkflowRun],
+ ) -> tuple[int, int]:
+ from repositories.factory import DifyAPIRepositoryFactory
+
+ run_ids = [run.id for run in runs]
+ repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+ session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
+ )
+ return repo.delete_by_runs(session, run_ids)
+
+ def _get_workflow_run_repo(self) -> APIWorkflowRunRepository:
+ if self.workflow_run_repo is not None:
+ return self.workflow_run_repo
+
+ from repositories.factory import DifyAPIRepositoryFactory
+
+ self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(
+ sessionmaker(bind=db.engine, expire_on_commit=False)
+ )
+ return self.workflow_run_repo
diff --git a/api/services/retention/workflow_run/restore_archived_workflow_run.py b/api/services/retention/workflow_run/restore_archived_workflow_run.py
new file mode 100644
index 0000000000..d4a6e87585
--- /dev/null
+++ b/api/services/retention/workflow_run/restore_archived_workflow_run.py
@@ -0,0 +1,481 @@
+"""
+Restore Archived Workflow Run Service.
+
+This service restores archived workflow run data from S3-compatible storage
+back to the database.
+"""
+
+import io
+import json
+import logging
+import time
+import zipfile
+from collections.abc import Callable
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Any, cast
+
+import click
+from sqlalchemy.dialects.postgresql import insert as pg_insert
+from sqlalchemy.engine import CursorResult
+from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
+
+from extensions.ext_database import db
+from libs.archive_storage import (
+ ArchiveStorage,
+ ArchiveStorageNotConfiguredError,
+ get_archive_storage,
+)
+from models.trigger import WorkflowTriggerLog
+from models.workflow import (
+ WorkflowAppLog,
+ WorkflowArchiveLog,
+ WorkflowNodeExecutionModel,
+ WorkflowNodeExecutionOffload,
+ WorkflowPause,
+ WorkflowPauseReason,
+ WorkflowRun,
+)
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.factory import DifyAPIRepositoryFactory
+from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
+
+logger = logging.getLogger(__name__)
+
+
+# Mapping of table names to SQLAlchemy models
+TABLE_MODELS = {
+ "workflow_runs": WorkflowRun,
+ "workflow_app_logs": WorkflowAppLog,
+ "workflow_node_executions": WorkflowNodeExecutionModel,
+ "workflow_node_execution_offload": WorkflowNodeExecutionOffload,
+ "workflow_pauses": WorkflowPause,
+ "workflow_pause_reasons": WorkflowPauseReason,
+ "workflow_trigger_logs": WorkflowTriggerLog,
+}
+
+SchemaMapper = Callable[[dict[str, Any]], dict[str, Any]]
+
+SCHEMA_MAPPERS: dict[str, dict[str, SchemaMapper]] = {
+ "1.0": {},
+}
+
+
+@dataclass
+class RestoreResult:
+ """Result of restoring a single workflow run."""
+
+ run_id: str
+ tenant_id: str
+ success: bool
+ restored_counts: dict[str, int]
+ error: str | None = None
+ elapsed_time: float = 0.0
+
+
+class WorkflowRunRestore:
+ """
+ Restore archived workflow run data from storage to database.
+
+ This service reads archived data from storage and restores it to the
+ database tables. It handles idempotency by skipping records that already
+ exist in the database.
+ """
+
+ def __init__(self, dry_run: bool = False, workers: int = 1):
+ """
+ Initialize the restore service.
+
+ Args:
+ dry_run: If True, only preview without making changes
+ workers: Number of concurrent workflow runs to restore
+ """
+ self.dry_run = dry_run
+ if workers < 1:
+ raise ValueError("workers must be at least 1")
+ self.workers = workers
+ self.workflow_run_repo: APIWorkflowRunRepository | None = None
+
+ def _restore_from_run(
+ self,
+ run: WorkflowRun | WorkflowArchiveLog,
+ *,
+ session_maker: sessionmaker,
+ ) -> RestoreResult:
+ start_time = time.time()
+ run_id = run.workflow_run_id if isinstance(run, WorkflowArchiveLog) else run.id
+ created_at = run.run_created_at if isinstance(run, WorkflowArchiveLog) else run.created_at
+ result = RestoreResult(
+ run_id=run_id,
+ tenant_id=run.tenant_id,
+ success=False,
+ restored_counts={},
+ )
+
+ if not self.dry_run:
+ click.echo(
+ click.style(
+ f"Starting restore for workflow run {run_id} (tenant={run.tenant_id})",
+ fg="white",
+ )
+ )
+
+ try:
+ storage = get_archive_storage()
+ except ArchiveStorageNotConfiguredError as e:
+ result.error = str(e)
+ click.echo(click.style(f"Archive storage not configured: {e}", fg="red"))
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ prefix = (
+ f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/"
+ f"month={created_at.strftime('%m')}/workflow_run_id={run_id}"
+ )
+ archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
+ try:
+ archive_data = storage.get_object(archive_key)
+ except FileNotFoundError:
+ result.error = f"Archive bundle not found: {archive_key}"
+ click.echo(click.style(result.error, fg="red"))
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ with session_maker() as session:
+ try:
+ with zipfile.ZipFile(io.BytesIO(archive_data), mode="r") as archive:
+ try:
+ manifest = self._load_manifest_from_zip(archive)
+ except ValueError as e:
+ result.error = f"Archive bundle invalid: {e}"
+ click.echo(click.style(result.error, fg="red"))
+ return result
+
+ tables = manifest.get("tables", {})
+ schema_version = self._get_schema_version(manifest)
+ for table_name, info in tables.items():
+ row_count = info.get("row_count", 0)
+ if row_count == 0:
+ result.restored_counts[table_name] = 0
+ continue
+
+ if self.dry_run:
+ result.restored_counts[table_name] = row_count
+ continue
+
+ member_path = f"{table_name}.jsonl"
+ try:
+ data = archive.read(member_path)
+ except KeyError:
+ click.echo(
+ click.style(
+ f" Warning: Table data not found in archive: {member_path}",
+ fg="yellow",
+ )
+ )
+ result.restored_counts[table_name] = 0
+ continue
+
+ records = ArchiveStorage.deserialize_from_jsonl(data)
+ restored = self._restore_table_records(
+ session,
+ table_name,
+ records,
+ schema_version=schema_version,
+ )
+ result.restored_counts[table_name] = restored
+ if not self.dry_run:
+ click.echo(
+ click.style(
+ f" Restored {restored}/{len(records)} records to {table_name}",
+ fg="white",
+ )
+ )
+
+ # Verify row counts match manifest
+ manifest_total = sum(info.get("row_count", 0) for info in tables.values())
+ restored_total = sum(result.restored_counts.values())
+
+ if not self.dry_run:
+ # Note: restored count might be less than manifest count if records already exist
+ logger.info(
+ "Restore verification: manifest_total=%d, restored_total=%d",
+ manifest_total,
+ restored_total,
+ )
+
+ # Delete the archive log record after successful restore
+ repo = self._get_workflow_run_repo()
+ repo.delete_archive_log_by_run_id(session, run_id)
+
+ session.commit()
+
+ result.success = True
+ if not self.dry_run:
+ click.echo(
+ click.style(
+ f"Completed restore for workflow run {run_id}: restored={result.restored_counts}",
+ fg="green",
+ )
+ )
+
+ except Exception as e:
+ logger.exception("Failed to restore workflow run %s", run_id)
+ result.error = str(e)
+ session.rollback()
+ click.echo(click.style(f"Restore failed: {e}", fg="red"))
+
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ def _get_workflow_run_repo(self) -> APIWorkflowRunRepository:
+ if self.workflow_run_repo is not None:
+ return self.workflow_run_repo
+
+ self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(
+ sessionmaker(bind=db.engine, expire_on_commit=False)
+ )
+ return self.workflow_run_repo
+
+ @staticmethod
+ def _load_manifest_from_zip(archive: zipfile.ZipFile) -> dict[str, Any]:
+ try:
+ data = archive.read("manifest.json")
+ except KeyError as e:
+ raise ValueError("manifest.json missing from archive bundle") from e
+ return json.loads(data.decode("utf-8"))
+
+ def _restore_table_records(
+ self,
+ session: Session,
+ table_name: str,
+ records: list[dict[str, Any]],
+ *,
+ schema_version: str,
+ ) -> int:
+ """
+ Restore records to a table.
+
+ Uses INSERT ... ON CONFLICT DO NOTHING for idempotency.
+
+ Args:
+ session: Database session
+ table_name: Name of the table
+ records: List of record dictionaries
+ schema_version: Archived schema version from manifest
+
+ Returns:
+ Number of records actually inserted
+ """
+ if not records:
+ return 0
+
+ model = TABLE_MODELS.get(table_name)
+ if not model:
+ logger.warning("Unknown table: %s", table_name)
+ return 0
+
+ column_names, required_columns, non_nullable_with_default = self._get_model_column_info(model)
+ unknown_fields: set[str] = set()
+
+ # Apply schema mapping, filter to current columns, then convert datetimes
+ converted_records = []
+ for record in records:
+ mapped = self._apply_schema_mapping(table_name, schema_version, record)
+ unknown_fields.update(set(mapped.keys()) - column_names)
+ filtered = {key: value for key, value in mapped.items() if key in column_names}
+ for key in non_nullable_with_default:
+ if key in filtered and filtered[key] is None:
+ filtered.pop(key)
+ missing_required = [key for key in required_columns if key not in filtered or filtered.get(key) is None]
+ if missing_required:
+ missing_cols = ", ".join(sorted(missing_required))
+ raise ValueError(
+ f"Missing required columns for {table_name} (schema_version={schema_version}): {missing_cols}"
+ )
+ converted = self._convert_datetime_fields(filtered, model)
+ converted_records.append(converted)
+ if unknown_fields:
+ logger.warning(
+ "Dropped unknown columns for %s (schema_version=%s): %s",
+ table_name,
+ schema_version,
+ ", ".join(sorted(unknown_fields)),
+ )
+
+ # Use INSERT ... ON CONFLICT DO NOTHING for idempotency
+ stmt = pg_insert(model).values(converted_records)
+ stmt = stmt.on_conflict_do_nothing(index_elements=["id"])
+
+ result = session.execute(stmt)
+ return cast(CursorResult, result).rowcount or 0
+
+ def _convert_datetime_fields(
+ self,
+ record: dict[str, Any],
+ model: type[DeclarativeBase] | Any,
+ ) -> dict[str, Any]:
+ """Convert ISO datetime strings to datetime objects."""
+ from sqlalchemy import DateTime
+
+ result = dict(record)
+
+ for column in model.__table__.columns:
+ if isinstance(column.type, DateTime):
+ value = result.get(column.key)
+ if isinstance(value, str):
+ try:
+ result[column.key] = datetime.fromisoformat(value)
+ except ValueError:
+ pass
+
+ return result
+
+ def _get_schema_version(self, manifest: dict[str, Any]) -> str:
+ schema_version = manifest.get("schema_version")
+ if not schema_version:
+ logger.warning("Manifest missing schema_version; defaulting to 1.0")
+ schema_version = "1.0"
+ schema_version = str(schema_version)
+ if schema_version not in SCHEMA_MAPPERS:
+ raise ValueError(f"Unsupported schema_version {schema_version}. Add a mapping before restoring.")
+ return schema_version
+
+ def _apply_schema_mapping(
+ self,
+ table_name: str,
+ schema_version: str,
+ record: dict[str, Any],
+ ) -> dict[str, Any]:
+ # Keep hook for forward/backward compatibility when schema evolves.
+ mapper = SCHEMA_MAPPERS.get(schema_version, {}).get(table_name)
+ if mapper is None:
+ return dict(record)
+ return mapper(record)
+
+ def _get_model_column_info(
+ self,
+ model: type[DeclarativeBase] | Any,
+ ) -> tuple[set[str], set[str], set[str]]:
+ columns = list(model.__table__.columns)
+ column_names = {column.key for column in columns}
+ required_columns = {
+ column.key
+ for column in columns
+ if not column.nullable
+ and column.default is None
+ and column.server_default is None
+ and not column.autoincrement
+ }
+ non_nullable_with_default = {
+ column.key
+ for column in columns
+ if not column.nullable
+ and (column.default is not None or column.server_default is not None or column.autoincrement)
+ }
+ return column_names, required_columns, non_nullable_with_default
+
+ def restore_batch(
+ self,
+ tenant_ids: list[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int = 100,
+ ) -> list[RestoreResult]:
+ """
+ Restore multiple workflow runs by time range.
+
+ Args:
+ tenant_ids: Optional tenant IDs
+ start_date: Start date filter
+ end_date: End date filter
+ limit: Maximum number of runs to restore (default: 100)
+
+ Returns:
+ List of RestoreResult objects
+ """
+ results: list[RestoreResult] = []
+ if tenant_ids is not None and not tenant_ids:
+ return results
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ repo = self._get_workflow_run_repo()
+
+ with session_maker() as session:
+ archive_logs = repo.get_archived_logs_by_time_range(
+ session=session,
+ tenant_ids=tenant_ids,
+ start_date=start_date,
+ end_date=end_date,
+ limit=limit,
+ )
+
+ click.echo(
+ click.style(
+ f"Found {len(archive_logs)} archived workflow runs to restore",
+ fg="white",
+ )
+ )
+
+ def _restore_with_session(archive_log: WorkflowArchiveLog) -> RestoreResult:
+ return self._restore_from_run(
+ archive_log,
+ session_maker=session_maker,
+ )
+
+ with ThreadPoolExecutor(max_workers=self.workers) as executor:
+ results = list(executor.map(_restore_with_session, archive_logs))
+
+ total_counts: dict[str, int] = {}
+ for result in results:
+ for table_name, count in result.restored_counts.items():
+ total_counts[table_name] = total_counts.get(table_name, 0) + count
+ success_count = sum(1 for result in results if result.success)
+
+ if self.dry_run:
+ click.echo(
+ click.style(
+ f"[DRY RUN] Would restore {len(results)} workflow runs: totals={total_counts}",
+ fg="yellow",
+ )
+ )
+ else:
+ click.echo(
+ click.style(
+ f"Restored {success_count}/{len(results)} workflow runs: totals={total_counts}",
+ fg="green",
+ )
+ )
+
+ return results
+
+ def restore_by_run_id(
+ self,
+ run_id: str,
+ ) -> RestoreResult:
+ """
+ Restore a single workflow run by run ID.
+ """
+ repo = self._get_workflow_run_repo()
+ archive_log = repo.get_archived_log_by_run_id(run_id)
+
+ if not archive_log:
+ click.echo(click.style(f"Workflow run archive {run_id} not found", fg="red"))
+ return RestoreResult(
+ run_id=run_id,
+ tenant_id="",
+ success=False,
+ restored_counts={},
+ error=f"Workflow run archive {run_id} not found",
+ )
+
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ result = self._restore_from_run(archive_log, session_maker=session_maker)
+ if self.dry_run and result.success:
+ click.echo(
+ click.style(
+ f"[DRY RUN] Would restore workflow run {run_id}: totals={result.restored_counts}",
+ fg="yellow",
+ )
+ )
+ return result
diff --git a/api/services/tag_service.py b/api/services/tag_service.py
index 937e6593fe..bd3585acf4 100644
--- a/api/services/tag_service.py
+++ b/api/services/tag_service.py
@@ -19,7 +19,10 @@ class TagService:
.where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
)
if keyword:
- query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%")))
+ from libs.helper import escape_like_pattern
+
+ escaped_keyword = escape_like_pattern(keyword)
+ query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results: list = query.order_by(Tag.created_at.desc()).all()
return results
diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py
index b3b6e36346..c32157919b 100644
--- a/api/services/tools/api_tools_manage_service.py
+++ b/api/services/tools/api_tools_manage_service.py
@@ -7,7 +7,6 @@ from httpx import get
from sqlalchemy import select
from core.entities.provider_entities import ProviderConfig
-from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.custom_tool.provider import ApiToolProviderController
@@ -86,7 +85,9 @@ class ApiToolManageService:
raise ValueError(f"invalid schema: {str(e)}")
@staticmethod
- def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
+ def convert_schema_to_tool_bundles(
+ schema: str, extra_info: dict | None = None
+ ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
convert schema to tool bundles
@@ -104,7 +105,7 @@ class ApiToolManageService:
provider_name: str,
icon: dict,
credentials: dict,
- schema_type: str,
+ schema_type: ApiProviderSchemaType,
schema: str,
privacy_policy: str,
custom_disclaimer: str,
@@ -113,9 +114,6 @@ class ApiToolManageService:
"""
create api tool provider
"""
- if schema_type not in [member.value for member in ApiProviderSchemaType]:
- raise ValueError(f"invalid schema type {schema}")
-
provider_name = provider_name.strip()
# check if the provider exists
@@ -178,9 +176,6 @@ class ApiToolManageService:
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
-
return {"result": "success"}
@staticmethod
@@ -245,18 +240,15 @@ class ApiToolManageService:
original_provider: str,
icon: dict,
credentials: dict,
- schema_type: str,
+ _schema_type: ApiProviderSchemaType,
schema: str,
- privacy_policy: str,
+ privacy_policy: str | None,
custom_disclaimer: str,
labels: list[str],
):
"""
update api tool provider
"""
- if schema_type not in [member.value for member in ApiProviderSchemaType]:
- raise ValueError(f"invalid schema type {schema}")
-
provider_name = provider_name.strip()
# check if the provider exists
@@ -281,7 +273,7 @@ class ApiToolManageService:
provider.icon = json.dumps(icon)
provider.schema = schema
provider.description = extra_info.get("description", "")
- provider.schema_type_str = ApiProviderSchemaType.OPENAPI
+ provider.schema_type_str = schema_type
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
provider.privacy_policy = privacy_policy
provider.custom_disclaimer = custom_disclaimer
@@ -322,9 +314,6 @@ class ApiToolManageService:
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
-
return {"result": "success"}
@staticmethod
@@ -347,9 +336,6 @@ class ApiToolManageService:
db.session.delete(provider)
db.session.commit()
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
-
return {"result": "success"}
@staticmethod
@@ -366,7 +352,7 @@ class ApiToolManageService:
tool_name: str,
credentials: dict,
parameters: dict,
- schema_type: str,
+ schema_type: ApiProviderSchemaType,
schema: str,
):
"""
diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py
index 87951d53e6..6797a67dde 100644
--- a/api/services/tools/builtin_tools_manage_service.py
+++ b/api/services/tools/builtin_tools_manage_service.py
@@ -12,7 +12,6 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.name_generator import generate_incremental_name
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
-from core.helper.tool_provider_cache import ToolProviderListCache
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
@@ -205,9 +204,6 @@ class BuiltinToolManageService:
db_provider.name = name
session.commit()
-
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
except Exception as e:
session.rollback()
raise ValueError(str(e))
@@ -290,8 +286,6 @@ class BuiltinToolManageService:
session.rollback()
raise ValueError(str(e))
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id, "builtin")
return {"result": "success"}
@staticmethod
@@ -409,9 +403,6 @@ class BuiltinToolManageService:
)
cache.delete()
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
-
return {"result": "success"}
@staticmethod
@@ -434,8 +425,6 @@ class BuiltinToolManageService:
target_provider.is_default = True
session.commit()
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@staticmethod
diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py
index 252be77b27..0be106f597 100644
--- a/api/services/tools/mcp_tools_manage_service.py
+++ b/api/services/tools/mcp_tools_manage_service.py
@@ -319,8 +319,14 @@ class MCPToolManageService:
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
- # Update database with retrieved tools
- db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
+ # Update database with retrieved tools (ensure description is a non-null string)
+ tools_payload = []
+ for tool in tools:
+ data = tool.model_dump()
+ if data.get("description") is None:
+ data["description"] = ""
+ tools_payload.append(data)
+ db_provider.tools = json.dumps(tools_payload)
db_provider.authed = True
db_provider.updated_at = datetime.now()
self._session.flush()
@@ -620,6 +626,21 @@ class MCPToolManageService:
server_url_hash=new_server_url_hash,
)
+ @staticmethod
+ def reconnect_with_url(
+ *,
+ server_url: str,
+ headers: dict[str, str],
+ timeout: float | None,
+ sse_read_timeout: float | None,
+ ) -> ReconnectResult:
+ return MCPToolManageService._reconnect_with_url(
+ server_url=server_url,
+ headers=headers,
+ timeout=timeout,
+ sse_read_timeout=sse_read_timeout,
+ )
+
@staticmethod
def _reconnect_with_url(
*,
@@ -642,9 +663,16 @@ class MCPToolManageService:
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
+ # Ensure tool descriptions are non-null in payload
+ tools_payload = []
+ for t in tools:
+ d = t.model_dump()
+ if d.get("description") is None:
+ d["description"] = ""
+ tools_payload.append(d)
return ReconnectResult(
authed=True,
- tools=json.dumps([tool.model_dump() for tool in tools]),
+ tools=json.dumps(tools_payload),
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
)
except MCPAuthError:
diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py
index 038c462f15..51e9120b8d 100644
--- a/api/services/tools/tools_manage_service.py
+++ b/api/services/tools/tools_manage_service.py
@@ -1,6 +1,5 @@
import logging
-from core.helper.tool_provider_cache import ToolProviderListCache
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from core.tools.tool_manager import ToolManager
from services.tools.tools_transform_service import ToolTransformService
@@ -16,14 +15,6 @@ class ToolCommonService:
:return: the list of tool providers
"""
- # Try to get from cache first
- cached_result = ToolProviderListCache.get_cached_providers(tenant_id, typ)
- if cached_result is not None:
- logger.debug("Returning cached tool providers for tenant %s, type %s", tenant_id, typ)
- return cached_result
-
- # Cache miss - fetch from database
- logger.debug("Cache miss for tool providers, fetching from database for tenant %s, type %s", tenant_id, typ)
providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ)
# add icon
@@ -32,7 +23,4 @@ class ToolCommonService:
result = [provider.to_dict() for provider in providers]
- # Cache the result
- ToolProviderListCache.set_cached_providers(tenant_id, typ, result)
-
return result
diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py
index 20cb258277..0ae40199ab 100644
--- a/api/services/tools/workflow_tools_manage_service.py
+++ b/api/services/tools/workflow_tools_manage_service.py
@@ -5,9 +5,8 @@ from datetime import datetime
from typing import Any
from sqlalchemy import or_, select
+from sqlalchemy.orm import Session
-from core.db.session_factory import session_factory
-from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
@@ -88,17 +87,13 @@ class WorkflowToolManageService:
except Exception as e:
raise ValueError(str(e))
- with session_factory.create_session() as session, session.begin():
+ with Session(db.engine, expire_on_commit=False) as session, session.begin():
session.add(workflow_tool_provider)
if labels is not None:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
-
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
-
return {"result": "success"}
@classmethod
@@ -188,9 +183,6 @@ class WorkflowToolManageService:
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
-
return {"result": "success"}
@classmethod
@@ -253,9 +245,6 @@ class WorkflowToolManageService:
db.session.commit()
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
-
return {"result": "success"}
@classmethod
diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py
index 57de9b3cee..688993c798 100644
--- a/api/services/trigger/trigger_provider_service.py
+++ b/api/services/trigger/trigger_provider_service.py
@@ -799,7 +799,7 @@ class TriggerProviderService:
user_id: str,
provider_id: TriggerProviderID,
subscription_id: str,
- credentials: Mapping[str, Any],
+ credentials: dict[str, Any],
) -> dict[str, Any]:
"""
Verify credentials for an existing subscription without updating it.
@@ -853,7 +853,7 @@ class TriggerProviderService:
"""
Create a subscription builder for rebuilding an existing subscription.
- This method creates a builder pre-filled with data from the rebuild request,
+ This method rebuild the subscription by call DELETE and CREATE API of the third party provider(e.g. GitHub)
keeping the same subscription_id and endpoint_id so the webhook URL remains unchanged.
:param tenant_id: Tenant ID
@@ -876,16 +876,12 @@ class TriggerProviderService:
raise ValueError(f"Subscription {subscription_id} not found")
credential_type = CredentialType.of(subscription.credential_type)
- if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
- raise ValueError("Credential type not supported for rebuild")
-
- # TODO: Trying to invoke update api of the plugin trigger provider
-
- # FALLBACK: If the update api is not implemented, delete the previous subscription and create a new one
+ if credential_type not in {CredentialType.OAUTH2, CredentialType.API_KEY}:
+ raise ValueError(f"Credential type {credential_type} not supported for auto creation")
# Delete the previous subscription
user_id = subscription.user_id
- TriggerManager.unsubscribe_trigger(
+ unsubscribe_result = TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
@@ -893,15 +889,21 @@ class TriggerProviderService:
credentials=subscription.credentials,
credential_type=credential_type,
)
+ if not unsubscribe_result.success:
+ raise ValueError(f"Failed to delete previous subscription: {unsubscribe_result.message}")
# Create a new subscription with the same subscription_id and endpoint_id
+ new_credentials: dict[str, Any] = {
+ key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
+ for key, value in credentials.items()
+ }
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
parameters=parameters,
- credentials=credentials,
+ credentials=new_credentials,
credential_type=credential_type,
)
TriggerProviderService.update_trigger_subscription(
@@ -909,7 +911,7 @@ class TriggerProviderService:
subscription_id=subscription.id,
name=name,
parameters=parameters,
- credentials=credentials,
+ credentials=new_credentials,
properties=new_subscription.properties,
expires_at=new_subscription.expires_at,
)
diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py
index 5c4607d400..4159f5f8f4 100644
--- a/api/services/trigger/webhook_service.py
+++ b/api/services/trigger/webhook_service.py
@@ -863,10 +863,18 @@ class WebhookService:
not_found_in_cache.append(node_id)
continue
- with Session(db.engine) as session:
- try:
- # lock the concurrent webhook trigger creation
- redis_client.lock(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
+ lock_key = f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock"
+ lock = redis_client.lock(lock_key, timeout=10)
+ lock_acquired = False
+
+ try:
+ # acquire the lock with blocking and timeout
+ lock_acquired = lock.acquire(blocking=True, blocking_timeout=10)
+ if not lock_acquired:
+ logger.warning("Failed to acquire lock for webhook sync, app %s", app.id)
+ raise RuntimeError("Failed to acquire lock for webhook trigger synchronization")
+
+ with Session(db.engine) as session:
# fetch the non-cached nodes from DB
all_records = session.scalars(
select(WorkflowWebhookTrigger).where(
@@ -903,11 +911,16 @@ class WebhookService:
session.delete(nodes_id_in_db[node_id])
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}")
session.commit()
- except Exception:
- logger.exception("Failed to sync webhook relationships for app %s", app.id)
- raise
- finally:
- redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock")
+ except Exception:
+ logger.exception("Failed to sync webhook relationships for app %s", app.id)
+ raise
+ finally:
+ # release the lock only if it was acquired
+ if lock_acquired:
+ try:
+ lock.release()
+ except Exception:
+ logger.exception("Failed to release lock for webhook sync, app %s", app.id)
@classmethod
def generate_webhook_id(cls) -> str:
diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py
index 0f969207cf..f973361341 100644
--- a/api/services/variable_truncator.py
+++ b/api/services/variable_truncator.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import dataclasses
from abc import ABC, abstractmethod
from collections.abc import Mapping
@@ -106,7 +108,7 @@ class VariableTruncator(BaseTruncator):
self._max_size_bytes = max_size_bytes
@classmethod
- def default(cls) -> "VariableTruncator":
+ def default(cls) -> VariableTruncator:
return VariableTruncator(
max_size_bytes=dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE,
array_element_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH,
diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py
index 9bd797a45f..5ca0b63001 100644
--- a/api/services/webapp_auth_service.py
+++ b/api/services/webapp_auth_service.py
@@ -12,6 +12,7 @@ from libs.passport import PassportService
from libs.password import compare_password
from models import Account, AccountStatus
from models.model import App, EndUser, Site
+from services.account_service import AccountService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
@@ -32,7 +33,7 @@ class WebAppAuthService:
@staticmethod
def authenticate(email: str, password: str) -> Account:
"""authenticate account with email and password"""
- account = db.session.query(Account).filter_by(email=email).first()
+ account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
raise AccountNotFoundError()
@@ -52,7 +53,7 @@ class WebAppAuthService:
@classmethod
def get_user_through_email(cls, email: str):
- account = db.session.query(Account).where(Account.email == email).first()
+ account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
return None
diff --git a/api/services/website_service.py b/api/services/website_service.py
index a23f01ec71..fe48c3b08e 100644
--- a/api/services/website_service.py
+++ b/api/services/website_service.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import datetime
import json
from dataclasses import dataclass
@@ -78,7 +80,7 @@ class WebsiteCrawlApiRequest:
return CrawlRequest(url=self.url, provider=self.provider, options=options)
@classmethod
- def from_args(cls, args: dict) -> "WebsiteCrawlApiRequest":
+ def from_args(cls, args: dict) -> WebsiteCrawlApiRequest:
"""Create from Flask-RESTful parsed arguments."""
provider = args.get("provider")
url = args.get("url")
@@ -102,7 +104,7 @@ class WebsiteCrawlStatusApiRequest:
job_id: str
@classmethod
- def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest":
+ def from_args(cls, args: dict, job_id: str) -> WebsiteCrawlStatusApiRequest:
"""Create from Flask-RESTful parsed arguments."""
provider = args.get("provider")
if not provider:
diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py
index 01f0c7a55a..efc76c33bc 100644
--- a/api/services/workflow_app_service.py
+++ b/api/services/workflow_app_service.py
@@ -7,7 +7,7 @@ from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session
from core.workflow.enums import WorkflowExecutionStatus
-from models import Account, App, EndUser, WorkflowAppLog, WorkflowRun
+from models import Account, App, EndUser, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
from models.enums import AppTriggerType, CreatorUserRole
from models.trigger import WorkflowTriggerLog
from services.plugin.plugin_service import PluginService
@@ -86,12 +86,19 @@ class WorkflowAppService:
# Join to workflow run for filtering when needed.
if keyword:
- keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
+ from libs.helper import escape_like_pattern
+
+ # Escape special characters in keyword to prevent SQL injection via LIKE wildcards
+ escaped_keyword = escape_like_pattern(keyword[:30])
+ keyword_like_val = f"%{escaped_keyword}%"
keyword_conditions = [
- WorkflowRun.inputs.ilike(keyword_like_val),
- WorkflowRun.outputs.ilike(keyword_like_val),
+ WorkflowRun.inputs.ilike(keyword_like_val, escape="\\"),
+ WorkflowRun.outputs.ilike(keyword_like_val, escape="\\"),
# filter keyword by end user session id if created by end user role
- and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)),
+ and_(
+ WorkflowRun.created_by_role == "end_user",
+ EndUser.session_id.ilike(keyword_like_val, escape="\\"),
+ ),
]
# filter keyword by workflow run id
@@ -166,7 +173,80 @@ class WorkflowAppService:
"data": items,
}
- def handle_trigger_metadata(self, tenant_id: str, meta_val: str) -> dict[str, Any]:
+ def get_paginate_workflow_archive_logs(
+ self,
+ *,
+ session: Session,
+ app_model: App,
+ page: int = 1,
+ limit: int = 20,
+ ):
+ """
+ Get paginate workflow archive logs using SQLAlchemy 2.0 style.
+ """
+ stmt = select(WorkflowArchiveLog).where(
+ WorkflowArchiveLog.tenant_id == app_model.tenant_id,
+ WorkflowArchiveLog.app_id == app_model.id,
+ WorkflowArchiveLog.log_id.isnot(None),
+ )
+
+ stmt = stmt.order_by(WorkflowArchiveLog.run_created_at.desc())
+
+ count_stmt = select(func.count()).select_from(stmt.subquery())
+ total = session.scalar(count_stmt) or 0
+
+ offset_stmt = stmt.offset((page - 1) * limit).limit(limit)
+
+ logs = list(session.scalars(offset_stmt).all())
+ account_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.ACCOUNT}
+ end_user_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.END_USER}
+
+ accounts_by_id = {}
+ if account_ids:
+ accounts_by_id = {
+ account.id: account
+ for account in session.scalars(select(Account).where(Account.id.in_(account_ids))).all()
+ }
+
+ end_users_by_id = {}
+ if end_user_ids:
+ end_users_by_id = {
+ end_user.id: end_user
+ for end_user in session.scalars(select(EndUser).where(EndUser.id.in_(end_user_ids))).all()
+ }
+
+ items = []
+ for log in logs:
+ if log.created_by_role == CreatorUserRole.ACCOUNT:
+ created_by_account = accounts_by_id.get(log.created_by)
+ created_by_end_user = None
+ elif log.created_by_role == CreatorUserRole.END_USER:
+ created_by_account = None
+ created_by_end_user = end_users_by_id.get(log.created_by)
+ else:
+ created_by_account = None
+ created_by_end_user = None
+
+ items.append(
+ {
+ "id": log.id,
+ "workflow_run": log.workflow_run_summary,
+ "trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, log.trigger_metadata),
+ "created_by_account": created_by_account,
+ "created_by_end_user": created_by_end_user,
+ "created_at": log.log_created_at,
+ }
+ )
+
+ return {
+ "page": page,
+ "limit": limit,
+ "total": total,
+ "has_more": total > page * limit,
+ "data": items,
+ }
+
+ def handle_trigger_metadata(self, tenant_id: str, meta_val: str | None) -> dict[str, Any]:
metadata: dict[str, Any] | None = self._safe_json_loads(meta_val)
if not metadata:
return {}
diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py
index f299ce3baa..70b0190231 100644
--- a/api/services/workflow_draft_variable_service.py
+++ b/api/services/workflow_draft_variable_service.py
@@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import and_, or_
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
-from core.variables import Segment, StringSegment, Variable
+from core.variables import Segment, StringSegment, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import (
ArrayFileSegment,
@@ -77,14 +77,14 @@ class DraftVarLoader(VariableLoader):
# Application ID for which variables are being loaded.
_app_id: str
_tenant_id: str
- _fallback_variables: Sequence[Variable]
+ _fallback_variables: Sequence[VariableBase]
def __init__(
self,
engine: Engine,
app_id: str,
tenant_id: str,
- fallback_variables: Sequence[Variable] | None = None,
+ fallback_variables: Sequence[VariableBase] | None = None,
):
self._engine = engine
self._app_id = app_id
@@ -94,12 +94,12 @@ class DraftVarLoader(VariableLoader):
def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]:
return (selector[0], selector[1])
- def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
+ def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
if not selectors:
return []
- # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance.
- variable_by_selector: dict[tuple[str, str], Variable] = {}
+ # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding variable instance.
+ variable_by_selector: dict[tuple[str, str], VariableBase] = {}
with Session(bind=self._engine, expire_on_commit=False) as session:
srv = WorkflowDraftVariableService(session)
@@ -145,7 +145,7 @@ class DraftVarLoader(VariableLoader):
return list(variable_by_selector.values())
- def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], Variable]:
+ def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], VariableBase]:
# This logic is closely tied to `WorkflowDraftVaribleService._try_offload_large_variable`
# and must remain synchronized with it.
# Ideally, these should be co-located for better maintainability.
@@ -679,6 +679,7 @@ def _batch_upsert_draft_variable(
def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]:
d: dict[str, Any] = {
+ "id": model.id,
"app_id": model.app_id,
"last_edited_at": None,
"node_id": model.node_id,
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 634dd3e2c3..d02f55b504 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -1,4 +1,5 @@
import json
+import logging
import time
import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
@@ -15,8 +16,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_reposotiry import HumanInputFormRepositoryImpl
-from core.variables import Variable
-from core.variables.variables import VariableUnion
+from core.variables import VariableBase
+from core.variables.variables import Variable
from core.workflow.entities import GraphInitParams, WorkflowNodeExecution
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -220,8 +221,8 @@ class WorkflowService:
features: dict,
unique_hash: str | None,
account: Account,
- environment_variables: Sequence[Variable],
- conversation_variables: Sequence[Variable],
+ environment_variables: Sequence[VariableBase],
+ conversation_variables: Sequence[VariableBase],
) -> Workflow:
"""
Sync draft workflow
@@ -697,7 +698,7 @@ class WorkflowService:
else:
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs=user_inputs,
environment_variables=draft_workflow.environment_variables,
conversation_variables=[],
@@ -1007,6 +1008,8 @@ class WorkflowService:
@staticmethod
def _load_email_recipients(form_id: str) -> list[DeliveryTestEmailRecipient]:
+ logger = logging.getLogger(__name__)
+
with Session(bind=db.engine) as session:
recipients = session.scalars(
select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_id)
@@ -1426,7 +1429,7 @@ def _setup_variable_pool(
workflow: Workflow,
node_type: NodeType,
conversation_id: str,
- conversation_variables: list[Variable],
+ conversation_variables: list[VariableBase],
):
# Only inject system variables for START node type.
if node_type == NodeType.START or node_type.is_trigger_node:
@@ -1445,16 +1448,16 @@ def _setup_variable_pool(
system_variable.conversation_id = conversation_id
system_variable.dialogue_count = 1
else:
- system_variable = SystemVariable.empty()
+ system_variable = SystemVariable.default()
# init variable pool
variable_pool = VariablePool(
system_variables=system_variable,
user_inputs=user_inputs,
environment_variables=workflow.environment_variables,
- # Based on the definition of `VariableUnion`,
- # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
- conversation_variables=cast(list[VariableUnion], conversation_variables), #
+ # Based on the definition of `Variable`,
+ # `VariableBase` instances can be safely used as `Variable` since they are compatible.
+ conversation_variables=cast(list[Variable], conversation_variables), #
)
return variable_pool
diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py
index 292ac6e008..3ee41c2e8d 100644
--- a/api/services/workspace_service.py
+++ b/api/services/workspace_service.py
@@ -31,7 +31,8 @@ class WorkspaceService:
assert tenant_account_join is not None, "TenantAccountJoin not found"
tenant_info["role"] = tenant_account_join.role
- can_replace_logo = FeatureService.get_features(tenant.id).can_replace_logo
+ feature = FeatureService.get_features(tenant.id)
+ can_replace_logo = feature.can_replace_logo
if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]):
base_url = dify_config.FILES_URL
@@ -46,5 +47,19 @@ class WorkspaceService:
"remove_webapp_brand": remove_webapp_brand,
"replace_webapp_logo": replace_webapp_logo,
}
+ if dify_config.EDITION == "CLOUD":
+ tenant_info["next_credit_reset_date"] = feature.next_credit_reset_date
+
+ from services.credit_pool_service import CreditPoolService
+
+ paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid")
+ if paid_pool:
+ tenant_info["trial_credits"] = paid_pool.quota_limit
+ tenant_info["trial_credits_used"] = paid_pool.quota_used
+ else:
+ trial_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="trial")
+ if trial_pool:
+ tenant_info["trial_credits"] = trial_pool.quota_limit
+ tenant_info["trial_credits_used"] = trial_pool.quota_used
return tenant_info
diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py
index e7dead8a56..62e6497e9d 100644
--- a/api/tasks/add_document_to_index_task.py
+++ b/api/tasks/add_document_to_index_task.py
@@ -4,11 +4,11 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DatasetAutoDisableLog, DocumentSegment
@@ -28,106 +28,106 @@ def add_document_to_index_task(dataset_document_id: str):
logger.info(click.style(f"Start add document to index: {dataset_document_id}", fg="green"))
start_at = time.perf_counter()
- dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
- if not dataset_document:
- logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
- db.session.close()
- return
+ with session_factory.create_session() as session:
+ dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
+ if not dataset_document:
+ logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
+ return
- if dataset_document.indexing_status != "completed":
- db.session.close()
- return
+ if dataset_document.indexing_status != "completed":
+ return
- indexing_cache_key = f"document_{dataset_document.id}_indexing"
+ indexing_cache_key = f"document_{dataset_document.id}_indexing"
- try:
- dataset = dataset_document.dataset
- if not dataset:
- raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")
+ try:
+ dataset = dataset_document.dataset
+ if not dataset:
+ raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")
- segments = (
- db.session.query(DocumentSegment)
- .where(
- DocumentSegment.document_id == dataset_document.id,
- DocumentSegment.status == "completed",
+ segments = (
+ session.query(DocumentSegment)
+ .where(
+ DocumentSegment.document_id == dataset_document.id,
+ DocumentSegment.status == "completed",
+ )
+ .order_by(DocumentSegment.position.asc())
+ .all()
)
- .order_by(DocumentSegment.position.asc())
- .all()
- )
- documents = []
- multimodal_documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
+ documents = []
+ multimodal_documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ child_chunks = segment.get_child_chunks()
+ if child_chunks:
+ child_documents = []
+ for child_chunk in child_chunks:
+ child_document = ChildDocument(
+ page_content=child_chunk.content,
+ metadata={
+ "doc_id": child_chunk.index_node_id,
+ "doc_hash": child_chunk.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ child_documents.append(child_document)
+ document.children = child_documents
+ if dataset.is_multimodal:
+ for attachment in segment.attachments:
+ multimodal_documents.append(
+ AttachmentDocument(
+ page_content=attachment["name"],
+ metadata={
+ "doc_id": attachment["id"],
+ "doc_hash": "",
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ "doc_type": DocType.IMAGE,
+ },
+ )
+ )
+ documents.append(document)
+
+ index_type = dataset.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
+
+ # delete auto disable log
+ session.query(DatasetAutoDisableLog).where(
+ DatasetAutoDisableLog.document_id == dataset_document.id
+ ).delete()
+
+ # update segment to enable
+ session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
+ {
+ DocumentSegment.enabled: True,
+ DocumentSegment.disabled_at: None,
+ DocumentSegment.disabled_by: None,
+ DocumentSegment.updated_at: naive_utc_now(),
+ }
)
- if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
- child_chunks = segment.get_child_chunks()
- if child_chunks:
- child_documents = []
- for child_chunk in child_chunks:
- child_document = ChildDocument(
- page_content=child_chunk.content,
- metadata={
- "doc_id": child_chunk.index_node_id,
- "doc_hash": child_chunk.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
- child_documents.append(child_document)
- document.children = child_documents
- if dataset.is_multimodal:
- for attachment in segment.attachments:
- multimodal_documents.append(
- AttachmentDocument(
- page_content=attachment["name"],
- metadata={
- "doc_id": attachment["id"],
- "doc_hash": "",
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- "doc_type": DocType.IMAGE,
- },
- )
- )
- documents.append(document)
+ session.commit()
- index_type = dataset.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
-
- # delete auto disable log
- db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete()
-
- # update segment to enable
- db.session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
- {
- DocumentSegment.enabled: True,
- DocumentSegment.disabled_at: None,
- DocumentSegment.disabled_by: None,
- DocumentSegment.updated_at: naive_utc_now(),
- }
- )
- db.session.commit()
-
- end_at = time.perf_counter()
- logger.info(
- click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")
- )
- except Exception as e:
- logger.exception("add document to index failed")
- dataset_document.enabled = False
- dataset_document.disabled_at = naive_utc_now()
- dataset_document.indexing_status = "error"
- dataset_document.error = str(e)
- db.session.commit()
- finally:
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")
+ )
+ except Exception as e:
+ logger.exception("add document to index failed")
+ dataset_document.enabled = False
+ dataset_document.disabled_at = naive_utc_now()
+ dataset_document.indexing_status = "error"
+ dataset_document.error = str(e)
+ session.commit()
+ finally:
+ redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py
index 775814318b..fc6bf03454 100644
--- a/api/tasks/annotation/batch_import_annotations_task.py
+++ b/api/tasks/annotation/batch_import_annotations_task.py
@@ -5,9 +5,9 @@ import click
from celery import shared_task
from werkzeug.exceptions import NotFound
+from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
@@ -32,74 +32,72 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
active_jobs_key = f"annotation_import_active:{tenant_id}"
- # get app info
- app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
+ with session_factory.create_session() as session:
+ # get app info
+ app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
- if app:
- try:
- documents = []
- for content in content_list:
- annotation = MessageAnnotation(
- app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id
+ if app:
+ try:
+ documents = []
+ for content in content_list:
+ annotation = MessageAnnotation(
+ app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id
+ )
+ session.add(annotation)
+ session.flush()
+
+ document = Document(
+ page_content=content["question"],
+ metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
+ )
+ documents.append(document)
+ # if annotation reply is enabled , batch add annotations' index
+ app_annotation_setting = (
+ session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
- db.session.add(annotation)
- db.session.flush()
- document = Document(
- page_content=content["question"],
- metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
- )
- documents.append(document)
- # if annotation reply is enabled , batch add annotations' index
- app_annotation_setting = (
- db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
- )
+ if app_annotation_setting:
+ dataset_collection_binding = (
+ DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
+ app_annotation_setting.collection_binding_id, "annotation"
+ )
+ )
+ if not dataset_collection_binding:
+ raise NotFound("App annotation setting not found")
+ dataset = Dataset(
+ id=app_id,
+ tenant_id=tenant_id,
+ indexing_technique="high_quality",
+ embedding_model_provider=dataset_collection_binding.provider_name,
+ embedding_model=dataset_collection_binding.model_name,
+ collection_binding_id=dataset_collection_binding.id,
+ )
- if app_annotation_setting:
- dataset_collection_binding = (
- DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
- app_annotation_setting.collection_binding_id, "annotation"
+ vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
+ vector.create(documents, duplicate_check=True)
+
+ session.commit()
+ redis_client.setex(indexing_cache_key, 600, "completed")
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ "Build index successful for batch import annotation: {} latency: {}".format(
+ job_id, end_at - start_at
+ ),
+ fg="green",
)
)
- if not dataset_collection_binding:
- raise NotFound("App annotation setting not found")
- dataset = Dataset(
- id=app_id,
- tenant_id=tenant_id,
- indexing_technique="high_quality",
- embedding_model_provider=dataset_collection_binding.provider_name,
- embedding_model=dataset_collection_binding.model_name,
- collection_binding_id=dataset_collection_binding.id,
- )
-
- vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
- vector.create(documents, duplicate_check=True)
-
- db.session.commit()
- redis_client.setex(indexing_cache_key, 600, "completed")
- end_at = time.perf_counter()
- logger.info(
- click.style(
- "Build index successful for batch import annotation: {} latency: {}".format(
- job_id, end_at - start_at
- ),
- fg="green",
- )
- )
- except Exception as e:
- db.session.rollback()
- redis_client.setex(indexing_cache_key, 600, "error")
- indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}"
- redis_client.setex(indexing_error_msg_key, 600, str(e))
- logger.exception("Build index for batch import annotations failed")
- finally:
- # Clean up active job tracking to release concurrency slot
- try:
- redis_client.zrem(active_jobs_key, job_id)
- logger.debug("Released concurrency slot for job: %s", job_id)
- except Exception as cleanup_error:
- # Log but don't fail if cleanup fails - the job will be auto-expired
- logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error)
-
- # Close database session
- db.session.close()
+ except Exception as e:
+ session.rollback()
+ redis_client.setex(indexing_cache_key, 600, "error")
+ indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}"
+ redis_client.setex(indexing_error_msg_key, 600, str(e))
+ logger.exception("Build index for batch import annotations failed")
+ finally:
+ # Clean up active job tracking to release concurrency slot
+ try:
+ redis_client.zrem(active_jobs_key, job_id)
+ logger.debug("Released concurrency slot for job: %s", job_id)
+ except Exception as cleanup_error:
+ # Log but don't fail if cleanup fails - the job will be auto-expired
+ logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error)
diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py
index c0020b29ed..7b5cd46b00 100644
--- a/api/tasks/annotation/disable_annotation_reply_task.py
+++ b/api/tasks/annotation/disable_annotation_reply_task.py
@@ -5,8 +5,8 @@ import click
from celery import shared_task
from sqlalchemy import exists, select
+from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
@@ -22,50 +22,55 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
logger.info(click.style(f"Start delete app annotations index: {app_id}", fg="green"))
start_at = time.perf_counter()
# get app info
- app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
- annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
- if not app:
- logger.info(click.style(f"App not found: {app_id}", fg="red"))
- db.session.close()
- return
+ with session_factory.create_session() as session:
+ app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
+ annotations_exists = session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
+ if not app:
+ logger.info(click.style(f"App not found: {app_id}", fg="red"))
+ return
- app_annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
-
- if not app_annotation_setting:
- logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red"))
- db.session.close()
- return
-
- disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
- disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
-
- try:
- dataset = Dataset(
- id=app_id,
- tenant_id=tenant_id,
- indexing_technique="high_quality",
- collection_binding_id=app_annotation_setting.collection_binding_id,
+ app_annotation_setting = (
+ session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
+ if not app_annotation_setting:
+ logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red"))
+ return
+
+ disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
+ disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
+
try:
- if annotations_exists:
- vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
- vector.delete()
- except Exception:
- logger.exception("Delete annotation index failed when annotation deleted.")
- redis_client.setex(disable_app_annotation_job_key, 600, "completed")
+ dataset = Dataset(
+ id=app_id,
+ tenant_id=tenant_id,
+ indexing_technique="high_quality",
+ collection_binding_id=app_annotation_setting.collection_binding_id,
+ )
- # delete annotation setting
- db.session.delete(app_annotation_setting)
- db.session.commit()
+ try:
+ if annotations_exists:
+ vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
+ vector.delete()
+ except Exception:
+ logger.exception("Delete annotation index failed when annotation deleted.")
+ redis_client.setex(disable_app_annotation_job_key, 600, "completed")
- end_at = time.perf_counter()
- logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception("Annotation batch deleted index failed")
- redis_client.setex(disable_app_annotation_job_key, 600, "error")
- disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}"
- redis_client.setex(disable_app_annotation_error_key, 600, str(e))
- finally:
- redis_client.delete(disable_app_annotation_key)
- db.session.close()
+ # delete annotation setting
+ session.delete(app_annotation_setting)
+ session.commit()
+
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"App annotations index deleted : {app_id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
+ except Exception as e:
+ logger.exception("Annotation batch deleted index failed")
+ redis_client.setex(disable_app_annotation_job_key, 600, "error")
+ disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}"
+ redis_client.setex(disable_app_annotation_error_key, 600, str(e))
+ finally:
+ redis_client.delete(disable_app_annotation_key)
diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py
index cdc07c77a8..4f8e2fec7a 100644
--- a/api/tasks/annotation/enable_annotation_reply_task.py
+++ b/api/tasks/annotation/enable_annotation_reply_task.py
@@ -5,9 +5,9 @@ import click
from celery import shared_task
from sqlalchemy import select
+from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset
@@ -33,92 +33,98 @@ def enable_annotation_reply_task(
logger.info(click.style(f"Start add app annotation to index: {app_id}", fg="green"))
start_at = time.perf_counter()
# get app info
- app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
+ with session_factory.create_session() as session:
+ app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
- if not app:
- logger.info(click.style(f"App not found: {app_id}", fg="red"))
- db.session.close()
- return
+ if not app:
+ logger.info(click.style(f"App not found: {app_id}", fg="red"))
+ return
- annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
- enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
- enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
+ annotations = session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
+ enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
+ enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
- try:
- documents = []
- dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
- embedding_provider_name, embedding_model_name, "annotation"
- )
- annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
- if annotation_setting:
- if dataset_collection_binding.id != annotation_setting.collection_binding_id:
- old_dataset_collection_binding = (
- DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
- annotation_setting.collection_binding_id, "annotation"
- )
- )
- if old_dataset_collection_binding and annotations:
- old_dataset = Dataset(
- id=app_id,
- tenant_id=tenant_id,
- indexing_technique="high_quality",
- embedding_model_provider=old_dataset_collection_binding.provider_name,
- embedding_model=old_dataset_collection_binding.model_name,
- collection_binding_id=old_dataset_collection_binding.id,
- )
-
- old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"])
- try:
- old_vector.delete()
- except Exception as e:
- logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
- annotation_setting.score_threshold = score_threshold
- annotation_setting.collection_binding_id = dataset_collection_binding.id
- annotation_setting.updated_user_id = user_id
- annotation_setting.updated_at = naive_utc_now()
- db.session.add(annotation_setting)
- else:
- new_app_annotation_setting = AppAnnotationSetting(
- app_id=app_id,
- score_threshold=score_threshold,
- collection_binding_id=dataset_collection_binding.id,
- created_user_id=user_id,
- updated_user_id=user_id,
+ try:
+ documents = []
+ dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+ embedding_provider_name, embedding_model_name, "annotation"
)
- db.session.add(new_app_annotation_setting)
+ annotation_setting = (
+ session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
+ )
+ if annotation_setting:
+ if dataset_collection_binding.id != annotation_setting.collection_binding_id:
+ old_dataset_collection_binding = (
+ DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
+ annotation_setting.collection_binding_id, "annotation"
+ )
+ )
+ if old_dataset_collection_binding and annotations:
+ old_dataset = Dataset(
+ id=app_id,
+ tenant_id=tenant_id,
+ indexing_technique="high_quality",
+ embedding_model_provider=old_dataset_collection_binding.provider_name,
+ embedding_model=old_dataset_collection_binding.model_name,
+ collection_binding_id=old_dataset_collection_binding.id,
+ )
- dataset = Dataset(
- id=app_id,
- tenant_id=tenant_id,
- indexing_technique="high_quality",
- embedding_model_provider=embedding_provider_name,
- embedding_model=embedding_model_name,
- collection_binding_id=dataset_collection_binding.id,
- )
- if annotations:
- for annotation in annotations:
- document = Document(
- page_content=annotation.question,
- metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
+ old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"])
+ try:
+ old_vector.delete()
+ except Exception as e:
+ logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
+ annotation_setting.score_threshold = score_threshold
+ annotation_setting.collection_binding_id = dataset_collection_binding.id
+ annotation_setting.updated_user_id = user_id
+ annotation_setting.updated_at = naive_utc_now()
+ session.add(annotation_setting)
+ else:
+ new_app_annotation_setting = AppAnnotationSetting(
+ app_id=app_id,
+ score_threshold=score_threshold,
+ collection_binding_id=dataset_collection_binding.id,
+ created_user_id=user_id,
+ updated_user_id=user_id,
)
- documents.append(document)
+ session.add(new_app_annotation_setting)
- vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
- try:
- vector.delete_by_metadata_field("app_id", app_id)
- except Exception as e:
- logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
- vector.create(documents)
- db.session.commit()
- redis_client.setex(enable_app_annotation_job_key, 600, "completed")
- end_at = time.perf_counter()
- logger.info(click.style(f"App annotations added to index: {app_id} latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception("Annotation batch created index failed")
- redis_client.setex(enable_app_annotation_job_key, 600, "error")
- enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}"
- redis_client.setex(enable_app_annotation_error_key, 600, str(e))
- db.session.rollback()
- finally:
- redis_client.delete(enable_app_annotation_key)
- db.session.close()
+ dataset = Dataset(
+ id=app_id,
+ tenant_id=tenant_id,
+ indexing_technique="high_quality",
+ embedding_model_provider=embedding_provider_name,
+ embedding_model=embedding_model_name,
+ collection_binding_id=dataset_collection_binding.id,
+ )
+ if annotations:
+ for annotation in annotations:
+ document = Document(
+ page_content=annotation.question_text,
+ metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
+ )
+ documents.append(document)
+
+ vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
+ try:
+ vector.delete_by_metadata_field("app_id", app_id)
+ except Exception as e:
+ logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
+ vector.create(documents)
+ session.commit()
+ redis_client.setex(enable_app_annotation_job_key, 600, "completed")
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"App annotations added to index: {app_id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
+ except Exception as e:
+ logger.exception("Annotation batch created index failed")
+ redis_client.setex(enable_app_annotation_job_key, 600, "error")
+ enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}"
+ redis_client.setex(enable_app_annotation_error_key, 600, str(e))
+ session.rollback()
+ finally:
+ redis_client.delete(enable_app_annotation_key)
diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py
index dd7faae7a3..cc96542d4b 100644
--- a/api/tasks/async_workflow_tasks.py
+++ b/api/tasks/async_workflow_tasks.py
@@ -19,11 +19,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
from core.app.layers.timeslice_layer import TimeSliceLayer
from core.app.layers.trigger_post_layer import TriggerPostLayer
+from core.db.session_factory import session_factory
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.runtime import GraphRuntimeState
from extensions.ext_database import db
from models.account import Account
-from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus
+from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus
from models.model import App, EndUser, Tenant
from models.trigger import WorkflowTriggerLog
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
@@ -107,10 +108,7 @@ def _execute_workflow_common(
):
"""Execute workflow with common logic and trigger log updates."""
- # Create a new session for this task
- session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
- with session_factory() as session:
+ with session_factory.create_session() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
# Get trigger log
@@ -154,7 +152,7 @@ def _execute_workflow_common(
args["workflow_id"] = str(trigger_data.workflow_id)
pause_config = PauseStateLayerConfig(
- session_factory=session_factory,
+ session_factory=session_factory.get_session_maker(),
state_owner_user_id=workflow.created_by,
)
@@ -171,7 +169,7 @@ def _execute_workflow_common(
root_node_id=trigger_data.root_node_id,
graph_engine_layers=[
# TODO: Re-enable TimeSliceLayer after the HITL release.
- TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
+ TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id),
],
pause_state_config=pause_config,
)
@@ -271,7 +269,7 @@ def resume_workflow_execution(task_data_dict: dict[str, Any]) -> None:
graph_engine_layers.extend(
[
TimeSliceLayer(cfs_plan_scheduler),
- TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
+ TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id),
]
)
@@ -291,7 +289,7 @@ def resume_workflow_execution(task_data_dict: dict[str, Any]) -> None:
workflow_run_repo.delete_workflow_pause(pause_entity)
-def _get_user(session: Session, workflow_run: WorkflowRun) -> Account | EndUser:
+def _get_user(session: Session, workflow_run: WorkflowRun | WorkflowTriggerLog) -> Account | EndUser:
"""Compose user from trigger log"""
tenant = session.scalar(select(Tenant).where(Tenant.id == workflow_run.tenant_id))
if not tenant:
@@ -326,19 +324,9 @@ def _query_trigger_log_info(session_factory: sessionmaker[Session], workflow_run
trigger_log = trigger_log_repo.get_by_workflow_run_id(workflow_run_id)
if not trigger_log:
logger.debug("Trigger log not found for workflow_run: workflow_run_id=%s", workflow_run_id)
- return
+ return None
- cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity(
- queue=trigger_log.queue_name,
- schedule_strategy=AsyncWorkflowSystemStrategy,
- granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY,
- )
- cfs_plan_scheduler = AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity)
-
- try:
- trigger_type = AppTriggerType(trigger_log.trigger_type)
- except ValueError:
- trigger_type = AppTriggerType.UNKNOWN
+ return trigger_log
class _TenantNotFoundError(Exception):
diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py
index 3e1bd16cc7..74b939e84d 100644
--- a/api/tasks/batch_clean_document_task.py
+++ b/api/tasks/batch_clean_document_task.py
@@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
-from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
from models.model import UploadFile
@@ -28,65 +28,64 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
"""
logger.info(click.style("Start batch clean documents when documents deleted", fg="green"))
start_at = time.perf_counter()
+ if not doc_form:
+ raise ValueError("doc_form is required")
- try:
- if not doc_form:
- raise ValueError("doc_form is required")
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- raise Exception("Document has no dataset")
+ if not dataset:
+ raise Exception("Document has no dataset")
- db.session.query(DatasetMetadataBinding).where(
- DatasetMetadataBinding.dataset_id == dataset_id,
- DatasetMetadataBinding.document_id.in_(document_ids),
- ).delete(synchronize_session=False)
+ session.query(DatasetMetadataBinding).where(
+ DatasetMetadataBinding.dataset_id == dataset_id,
+ DatasetMetadataBinding.document_id.in_(document_ids),
+ ).delete(synchronize_session=False)
- segments = db.session.scalars(
- select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
- ).all()
- # check segment is exist
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
- index_processor = IndexProcessorFactory(doc_form).init_index_processor()
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+ segments = session.scalars(
+ select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
+ ).all()
+ # check segment is exist
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
+ index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
- for segment in segments:
- image_upload_file_ids = get_image_upload_file_ids(segment.content)
- for upload_file_id in image_upload_file_ids:
- image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
+ for segment in segments:
+ image_upload_file_ids = get_image_upload_file_ids(segment.content)
+ image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
+ for image_file in image_files:
+ try:
+ if image_file and image_file.key:
+ storage.delete(image_file.key)
+ except Exception:
+ logger.exception(
+ "Delete image_files failed when storage deleted, \
+ image_upload_file_is: %s",
+ image_file.id,
+ )
+ stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
+ session.execute(stmt)
+ session.delete(segment)
+ if file_ids:
+ files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
+ for file in files:
try:
- if image_file and image_file.key:
- storage.delete(image_file.key)
+ storage.delete(file.key)
except Exception:
- logger.exception(
- "Delete image_files failed when storage deleted, \
- image_upload_file_is: %s",
- upload_file_id,
- )
- db.session.delete(image_file)
- db.session.delete(segment)
+ logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
+ stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
+ session.execute(stmt)
- db.session.commit()
- if file_ids:
- files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
- for file in files:
- try:
- storage.delete(file.key)
- except Exception:
- logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
- db.session.delete(file)
+ session.commit()
- db.session.commit()
-
- end_at = time.perf_counter()
- logger.info(
- click.style(
- f"Cleaned documents when documents deleted latency: {end_at - start_at}",
- fg="green",
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Cleaned documents when documents deleted latency: {end_at - start_at}",
+ fg="green",
+ )
)
- )
- except Exception:
- logger.exception("Cleaned documents when documents deleted failed")
- finally:
- db.session.close()
+ except Exception:
+ logger.exception("Cleaned documents when documents deleted failed")
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index bd95af2614..8ee09d5738 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -9,9 +9,9 @@ import pandas as pd
from celery import shared_task
from sqlalchemy import func
+from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
@@ -48,104 +48,107 @@ def batch_create_segment_to_index_task(
indexing_cache_key = f"segment_batch_import_{job_id}"
- try:
- dataset = db.session.get(Dataset, dataset_id)
- if not dataset:
- raise ValueError("Dataset not exist.")
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.get(Dataset, dataset_id)
+ if not dataset:
+ raise ValueError("Dataset not exist.")
- dataset_document = db.session.get(Document, document_id)
- if not dataset_document:
- raise ValueError("Document not exist.")
+ dataset_document = session.get(Document, document_id)
+ if not dataset_document:
+ raise ValueError("Document not exist.")
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- raise ValueError("Document is not available.")
+ if (
+ not dataset_document.enabled
+ or dataset_document.archived
+ or dataset_document.indexing_status != "completed"
+ ):
+ raise ValueError("Document is not available.")
- upload_file = db.session.get(UploadFile, upload_file_id)
- if not upload_file:
- raise ValueError("UploadFile not found.")
+ upload_file = session.get(UploadFile, upload_file_id)
+ if not upload_file:
+ raise ValueError("UploadFile not found.")
- with tempfile.TemporaryDirectory() as temp_dir:
- suffix = Path(upload_file.key).suffix
- file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
- storage.download(upload_file.key, file_path)
+ with tempfile.TemporaryDirectory() as temp_dir:
+ suffix = Path(upload_file.key).suffix
+ file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
+ storage.download(upload_file.key, file_path)
- df = pd.read_csv(file_path)
- content = []
- for _, row in df.iterrows():
+ df = pd.read_csv(file_path)
+ content = []
+ for _, row in df.iterrows():
+ if dataset_document.doc_form == "qa_model":
+ data = {"content": row.iloc[0], "answer": row.iloc[1]}
+ else:
+ data = {"content": row.iloc[0]}
+ content.append(data)
+ if len(content) == 0:
+ raise ValueError("The CSV file is empty.")
+
+ document_segments = []
+ embedding_model = None
+ if dataset.indexing_technique == "high_quality":
+ model_manager = ModelManager()
+ embedding_model = model_manager.get_model_instance(
+ tenant_id=dataset.tenant_id,
+ provider=dataset.embedding_model_provider,
+ model_type=ModelType.TEXT_EMBEDDING,
+ model=dataset.embedding_model,
+ )
+
+ word_count_change = 0
+ if embedding_model:
+ tokens_list = embedding_model.get_text_embedding_num_tokens(
+ texts=[segment["content"] for segment in content]
+ )
+ else:
+ tokens_list = [0] * len(content)
+
+ for segment, tokens in zip(content, tokens_list):
+ content = segment["content"]
+ doc_id = str(uuid.uuid4())
+ segment_hash = helper.generate_text_hash(content)
+ max_position = (
+ session.query(func.max(DocumentSegment.position))
+ .where(DocumentSegment.document_id == dataset_document.id)
+ .scalar()
+ )
+ segment_document = DocumentSegment(
+ tenant_id=tenant_id,
+ dataset_id=dataset_id,
+ document_id=document_id,
+ index_node_id=doc_id,
+ index_node_hash=segment_hash,
+ position=max_position + 1 if max_position else 1,
+ content=content,
+ word_count=len(content),
+ tokens=tokens,
+ created_by=user_id,
+ indexing_at=naive_utc_now(),
+ status="completed",
+ completed_at=naive_utc_now(),
+ )
if dataset_document.doc_form == "qa_model":
- data = {"content": row.iloc[0], "answer": row.iloc[1]}
- else:
- data = {"content": row.iloc[0]}
- content.append(data)
- if len(content) == 0:
- raise ValueError("The CSV file is empty.")
+ segment_document.answer = segment["answer"]
+ segment_document.word_count += len(segment["answer"])
+ word_count_change += segment_document.word_count
+ session.add(segment_document)
+ document_segments.append(segment_document)
- document_segments = []
- embedding_model = None
- if dataset.indexing_technique == "high_quality":
- model_manager = ModelManager()
- embedding_model = model_manager.get_model_instance(
- tenant_id=dataset.tenant_id,
- provider=dataset.embedding_model_provider,
- model_type=ModelType.TEXT_EMBEDDING,
- model=dataset.embedding_model,
- )
+ assert dataset_document.word_count is not None
+ dataset_document.word_count += word_count_change
+ session.add(dataset_document)
- word_count_change = 0
- if embedding_model:
- tokens_list = embedding_model.get_text_embedding_num_tokens(
- texts=[segment["content"] for segment in content]
+ VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
+ session.commit()
+ redis_client.setex(indexing_cache_key, 600, "completed")
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Segment batch created job: {job_id} latency: {end_at - start_at}",
+ fg="green",
+ )
)
- else:
- tokens_list = [0] * len(content)
-
- for segment, tokens in zip(content, tokens_list):
- content = segment["content"]
- doc_id = str(uuid.uuid4())
- segment_hash = helper.generate_text_hash(content)
- max_position = (
- db.session.query(func.max(DocumentSegment.position))
- .where(DocumentSegment.document_id == dataset_document.id)
- .scalar()
- )
- segment_document = DocumentSegment(
- tenant_id=tenant_id,
- dataset_id=dataset_id,
- document_id=document_id,
- index_node_id=doc_id,
- index_node_hash=segment_hash,
- position=max_position + 1 if max_position else 1,
- content=content,
- word_count=len(content),
- tokens=tokens,
- created_by=user_id,
- indexing_at=naive_utc_now(),
- status="completed",
- completed_at=naive_utc_now(),
- )
- if dataset_document.doc_form == "qa_model":
- segment_document.answer = segment["answer"]
- segment_document.word_count += len(segment["answer"])
- word_count_change += segment_document.word_count
- db.session.add(segment_document)
- document_segments.append(segment_document)
-
- assert dataset_document.word_count is not None
- dataset_document.word_count += word_count_change
- db.session.add(dataset_document)
-
- VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
- db.session.commit()
- redis_client.setex(indexing_cache_key, 600, "completed")
- end_at = time.perf_counter()
- logger.info(
- click.style(
- f"Segment batch created job: {job_id} latency: {end_at - start_at}",
- fg="green",
- )
- )
- except Exception:
- logger.exception("Segments batch created index failed")
- redis_client.setex(indexing_cache_key, 600, "error")
- finally:
- db.session.close()
+ except Exception:
+ logger.exception("Segments batch created index failed")
+ redis_client.setex(indexing_cache_key, 600, "error")
diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py
index b4d82a150d..0d51a743ad 100644
--- a/api/tasks/clean_dataset_task.py
+++ b/api/tasks/clean_dataset_task.py
@@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
-from extensions.ext_database import db
from extensions.ext_storage import storage
from models import WorkflowType
from models.dataset import (
@@ -53,135 +53,155 @@ def clean_dataset_task(
logger.info(click.style(f"Start clean dataset when dataset deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
- try:
- dataset = Dataset(
- id=dataset_id,
- tenant_id=tenant_id,
- indexing_technique=indexing_technique,
- index_struct=index_struct,
- collection_binding_id=collection_binding_id,
- )
- documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
- segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
- # Use JOIN to fetch attachments with bindings in a single query
- attachments_with_bindings = db.session.execute(
- select(SegmentAttachmentBinding, UploadFile)
- .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
- .where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id)
- ).all()
-
- # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
- # This ensures all invalid doc_form values are properly handled
- if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
- # Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
- from core.rag.index_processor.constant.index_type import IndexStructureType
-
- doc_form = IndexStructureType.PARAGRAPH_INDEX
- logger.info(
- click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow")
- )
-
- # Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure
- # This ensures Document/Segment deletion can continue even if vector database cleanup fails
+ with session_factory.create_session() as session:
try:
- index_processor = IndexProcessorFactory(doc_form).init_index_processor()
- index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
- logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
- except Exception:
- logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
- # Continue with document and segment deletion even if vector cleanup fails
- logger.info(
- click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow")
+ dataset = Dataset(
+ id=dataset_id,
+ tenant_id=tenant_id,
+ indexing_technique=indexing_technique,
+ index_struct=index_struct,
+ collection_binding_id=collection_binding_id,
)
+ documents = session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
+ segments = session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
+ # Use JOIN to fetch attachments with bindings in a single query
+ attachments_with_bindings = session.execute(
+ select(SegmentAttachmentBinding, UploadFile)
+ .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
+ .where(
+ SegmentAttachmentBinding.tenant_id == tenant_id,
+ SegmentAttachmentBinding.dataset_id == dataset_id,
+ )
+ ).all()
- if documents is None or len(documents) == 0:
- logger.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
- else:
- logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
+ # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
+ # This ensures all invalid doc_form values are properly handled
+ if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
+ # Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
+ from core.rag.index_processor.constant.index_type import IndexStructureType
- for document in documents:
- db.session.delete(document)
- # delete document file
+ doc_form = IndexStructureType.PARAGRAPH_INDEX
+ logger.info(
+ click.style(
+ f"Invalid doc_form detected, using default index type for cleanup: {doc_form}",
+ fg="yellow",
+ )
+ )
- for segment in segments:
- image_upload_file_ids = get_image_upload_file_ids(segment.content)
- for upload_file_id in image_upload_file_ids:
- image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
- if image_file is None:
- continue
+ # Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure
+ # This ensures Document/Segment deletion can continue even if vector database cleanup fails
+ try:
+ index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+ index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
+ logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
+ except Exception:
+ logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
+ # Continue with document and segment deletion even if vector cleanup fails
+ logger.info(
+ click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow")
+ )
+
+ if documents is None or len(documents) == 0:
+ logger.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
+ else:
+ logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
+
+ for document in documents:
+ session.delete(document)
+
+ segment_ids = [segment.id for segment in segments]
+ for segment in segments:
+ image_upload_file_ids = get_image_upload_file_ids(segment.content)
+ image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
+ for image_file in image_files:
+ if image_file is None:
+ continue
+ try:
+ storage.delete(image_file.key)
+ except Exception:
+ logger.exception(
+ "Delete image_files failed when storage deleted, \
+ image_upload_file_is: %s",
+ image_file.id,
+ )
+ stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
+ session.execute(stmt)
+
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ # delete segment attachments
+ if attachments_with_bindings:
+ attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
+ binding_ids = [binding.id for binding, _ in attachments_with_bindings]
+ for binding, attachment_file in attachments_with_bindings:
try:
- storage.delete(image_file.key)
+ storage.delete(attachment_file.key)
except Exception:
logger.exception(
- "Delete image_files failed when storage deleted, \
- image_upload_file_is: %s",
- upload_file_id,
+ "Delete attachment_file failed when storage deleted, \
+ attachment_file_id: %s",
+ binding.attachment_id,
)
- db.session.delete(image_file)
- db.session.delete(segment)
- # delete segment attachments
- if attachments_with_bindings:
- for binding, attachment_file in attachments_with_bindings:
- try:
- storage.delete(attachment_file.key)
- except Exception:
- logger.exception(
- "Delete attachment_file failed when storage deleted, \
- attachment_file_id: %s",
- binding.attachment_id,
- )
- db.session.delete(attachment_file)
- db.session.delete(binding)
+ attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
+ session.execute(attachment_file_delete_stmt)
- db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
- db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
- db.session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
- # delete dataset metadata
- db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
- db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
- # delete pipeline and workflow
- if pipeline_id:
- db.session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
- db.session.query(Workflow).where(
- Workflow.tenant_id == tenant_id,
- Workflow.app_id == pipeline_id,
- Workflow.type == WorkflowType.RAG_PIPELINE,
- ).delete()
- # delete files
- if documents:
- for document in documents:
- try:
+ binding_delete_stmt = delete(SegmentAttachmentBinding).where(
+ SegmentAttachmentBinding.id.in_(binding_ids)
+ )
+ session.execute(binding_delete_stmt)
+
+ session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
+ session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
+ session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
+ # delete dataset metadata
+ session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
+ session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
+ # delete pipeline and workflow
+ if pipeline_id:
+ session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
+ session.query(Workflow).where(
+ Workflow.tenant_id == tenant_id,
+ Workflow.app_id == pipeline_id,
+ Workflow.type == WorkflowType.RAG_PIPELINE,
+ ).delete()
+ # delete files
+ if documents:
+ file_ids = []
+ for document in documents:
if document.data_source_type == "upload_file":
if document.data_source_info:
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
- file = (
- db.session.query(UploadFile)
- .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
- .first()
- )
- if not file:
- continue
- storage.delete(file.key)
- db.session.delete(file)
- except Exception:
- continue
+ file_ids.append(file_id)
+ files = session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
+ for file in files:
+ storage.delete(file.key)
- db.session.commit()
- end_at = time.perf_counter()
- logger.info(
- click.style(f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}", fg="green")
- )
- except Exception:
- # Add rollback to prevent dirty session state in case of exceptions
- # This ensures the database session is properly cleaned up
- try:
- db.session.rollback()
- logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
+ file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
+ session.execute(file_delete_stmt)
+
+ session.commit()
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
except Exception:
- logger.exception("Failed to rollback database session")
+ # Add rollback to prevent dirty session state in case of exceptions
+ # This ensures the database session is properly cleaned up
+ try:
+ session.rollback()
+ logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
+ except Exception:
+ logger.exception("Failed to rollback database session")
- logger.exception("Cleaned dataset when dataset deleted failed")
- finally:
- db.session.close()
+ logger.exception("Cleaned dataset when dataset deleted failed")
+ finally:
+ # Explicitly close the session for test expectations and safety
+ try:
+ session.close()
+ except Exception:
+ logger.exception("Failed to close database session")
diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py
index 6d2feb1da3..86e7cc7160 100644
--- a/api/tasks/clean_document_task.py
+++ b/api/tasks/clean_document_task.py
@@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
-from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding
from models.model import UploadFile
@@ -29,85 +29,94 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- raise Exception("Document has no dataset")
+ if not dataset:
+ raise Exception("Document has no dataset")
- segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
- # Use JOIN to fetch attachments with bindings in a single query
- attachments_with_bindings = db.session.execute(
- select(SegmentAttachmentBinding, UploadFile)
- .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
- .where(
- SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
- SegmentAttachmentBinding.dataset_id == dataset_id,
- SegmentAttachmentBinding.document_id == document_id,
- )
- ).all()
- # check segment is exist
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
- index_processor = IndexProcessorFactory(doc_form).init_index_processor()
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+ segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
+ # Use JOIN to fetch attachments with bindings in a single query
+ attachments_with_bindings = session.execute(
+ select(SegmentAttachmentBinding, UploadFile)
+ .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
+ .where(
+ SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
+ SegmentAttachmentBinding.dataset_id == dataset_id,
+ SegmentAttachmentBinding.document_id == document_id,
+ )
+ ).all()
+ # check segment is exist
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
+ index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
- for segment in segments:
- image_upload_file_ids = get_image_upload_file_ids(segment.content)
- for upload_file_id in image_upload_file_ids:
- image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
- if image_file is None:
- continue
+ for segment in segments:
+ image_upload_file_ids = get_image_upload_file_ids(segment.content)
+ image_files = session.scalars(
+ select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
+ ).all()
+ for image_file in image_files:
+ if image_file is None:
+ continue
+ try:
+ storage.delete(image_file.key)
+ except Exception:
+ logger.exception(
+ "Delete image_files failed when storage deleted, \
+ image_upload_file_is: %s",
+ image_file.id,
+ )
+
+ image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
+ session.execute(image_file_delete_stmt)
+ session.delete(segment)
+
+ session.commit()
+ if file_id:
+ file = session.query(UploadFile).where(UploadFile.id == file_id).first()
+ if file:
try:
- storage.delete(image_file.key)
+ storage.delete(file.key)
+ except Exception:
+ logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
+ session.delete(file)
+ # delete segment attachments
+ if attachments_with_bindings:
+ attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
+ binding_ids = [binding.id for binding, _ in attachments_with_bindings]
+ for binding, attachment_file in attachments_with_bindings:
+ try:
+ storage.delete(attachment_file.key)
except Exception:
logger.exception(
- "Delete image_files failed when storage deleted, \
- image_upload_file_is: %s",
- upload_file_id,
+ "Delete attachment_file failed when storage deleted, \
+ attachment_file_id: %s",
+ binding.attachment_id,
)
- db.session.delete(image_file)
- db.session.delete(segment)
+ attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
+ session.execute(attachment_file_delete_stmt)
- db.session.commit()
- if file_id:
- file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
- if file:
- try:
- storage.delete(file.key)
- except Exception:
- logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
- db.session.delete(file)
- db.session.commit()
- # delete segment attachments
- if attachments_with_bindings:
- for binding, attachment_file in attachments_with_bindings:
- try:
- storage.delete(attachment_file.key)
- except Exception:
- logger.exception(
- "Delete attachment_file failed when storage deleted, \
- attachment_file_id: %s",
- binding.attachment_id,
- )
- db.session.delete(attachment_file)
- db.session.delete(binding)
+ binding_delete_stmt = delete(SegmentAttachmentBinding).where(
+ SegmentAttachmentBinding.id.in_(binding_ids)
+ )
+ session.execute(binding_delete_stmt)
- # delete dataset metadata binding
- db.session.query(DatasetMetadataBinding).where(
- DatasetMetadataBinding.dataset_id == dataset_id,
- DatasetMetadataBinding.document_id == document_id,
- ).delete()
- db.session.commit()
+ # delete dataset metadata binding
+ session.query(DatasetMetadataBinding).where(
+ DatasetMetadataBinding.dataset_id == dataset_id,
+ DatasetMetadataBinding.document_id == document_id,
+ ).delete()
+ session.commit()
- end_at = time.perf_counter()
- logger.info(
- click.style(
- f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
- fg="green",
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
+ fg="green",
+ )
)
- )
- except Exception:
- logger.exception("Cleaned document when document deleted failed")
- finally:
- db.session.close()
+ except Exception:
+ logger.exception("Cleaned document when document deleted failed")
diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py
index 771b43f9b0..bcca1bf49f 100644
--- a/api/tasks/clean_notion_document_task.py
+++ b/api/tasks/clean_notion_document_task.py
@@ -3,10 +3,10 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
logger = logging.getLogger(__name__)
@@ -24,37 +24,37 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- raise Exception("Document has no dataset")
- index_type = dataset.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- for document_id in document_ids:
- document = db.session.query(Document).where(Document.id == document_id).first()
- db.session.delete(document)
+ if not dataset:
+ raise Exception("Document has no dataset")
+ index_type = dataset.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
- segments = db.session.scalars(
- select(DocumentSegment).where(DocumentSegment.document_id == document_id)
- ).all()
- index_node_ids = [segment.index_node_id for segment in segments]
+ document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
+ session.execute(document_delete_stmt)
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+ for document_id in document_ids:
+ segments = session.scalars(
+ select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+ ).all()
+ index_node_ids = [segment.index_node_id for segment in segments]
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
- end_at = time.perf_counter()
- logger.info(
- click.style(
- "Clean document when import form notion document deleted end :: {} latency: {}".format(
- dataset_id, end_at - start_at
- ),
- fg="green",
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ session.commit()
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ "Clean document when import form notion document deleted end :: {} latency: {}".format(
+ dataset_id, end_at - start_at
+ ),
+ fg="green",
+ )
)
- )
- except Exception:
- logger.exception("Cleaned document when import form notion document deleted failed")
- finally:
- db.session.close()
+ except Exception:
+ logger.exception("Cleaned document when import form notion document deleted failed")
diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py
index 6b2907cffd..b5e472d71e 100644
--- a/api/tasks/create_segment_to_index_task.py
+++ b/api/tasks/create_segment_to_index_task.py
@@ -4,9 +4,9 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
@@ -25,75 +25,77 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
logger.info(click.style(f"Start create segment to index: {segment_id}", fg="green"))
start_at = time.perf_counter()
- segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
- if not segment:
- logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
- db.session.close()
- return
-
- if segment.status != "waiting":
- db.session.close()
- return
-
- indexing_cache_key = f"segment_{segment.id}_indexing"
-
- try:
- # update segment status to indexing
- db.session.query(DocumentSegment).filter_by(id=segment.id).update(
- {
- DocumentSegment.status: "indexing",
- DocumentSegment.indexing_at: naive_utc_now(),
- }
- )
- db.session.commit()
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
-
- dataset = segment.dataset
-
- if not dataset:
- logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ with session_factory.create_session() as session:
+ segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
+ if not segment:
+ logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
- dataset_document = segment.document
-
- if not dataset_document:
- logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ if segment.status != "waiting":
return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
- return
+ indexing_cache_key = f"segment_{segment.id}_indexing"
- index_type = dataset.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- index_processor.load(dataset, [document])
+ try:
+ # update segment status to indexing
+ session.query(DocumentSegment).filter_by(id=segment.id).update(
+ {
+ DocumentSegment.status: "indexing",
+ DocumentSegment.indexing_at: naive_utc_now(),
+ }
+ )
+ session.commit()
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
- # update segment to completed
- db.session.query(DocumentSegment).filter_by(id=segment.id).update(
- {
- DocumentSegment.status: "completed",
- DocumentSegment.completed_at: naive_utc_now(),
- }
- )
- db.session.commit()
+ dataset = segment.dataset
- end_at = time.perf_counter()
- logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception("create segment to index failed")
- segment.enabled = False
- segment.disabled_at = naive_utc_now()
- segment.status = "error"
- segment.error = str(e)
- db.session.commit()
- finally:
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ if not dataset:
+ logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ return
+
+ dataset_document = segment.document
+
+ if not dataset_document:
+ logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ return
+
+ if (
+ not dataset_document.enabled
+ or dataset_document.archived
+ or dataset_document.indexing_status != "completed"
+ ):
+ logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
+ return
+
+ index_type = dataset.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ index_processor.load(dataset, [document])
+
+ # update segment to completed
+ session.query(DocumentSegment).filter_by(id=segment.id).update(
+ {
+ DocumentSegment.status: "completed",
+ DocumentSegment.completed_at: naive_utc_now(),
+ }
+ )
+ session.commit()
+
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green"))
+ except Exception as e:
+ logger.exception("create segment to index failed")
+ segment.enabled = False
+ segment.disabled_at = naive_utc_now()
+ segment.status = "error"
+ segment.error = str(e)
+ session.commit()
+ finally:
+ redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py
index 3d13afdec0..fa844a8647 100644
--- a/api/tasks/deal_dataset_index_update_task.py
+++ b/api/tasks/deal_dataset_index_update_task.py
@@ -4,11 +4,11 @@ import time
import click
from celery import shared_task # type: ignore
+from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@@ -24,166 +24,174 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green"))
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).filter_by(id=dataset_id).first()
- if not dataset:
- raise Exception("Dataset not found")
- index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- if action == "upgrade":
- dataset_documents = (
- db.session.query(DatasetDocument)
- .where(
- DatasetDocument.dataset_id == dataset_id,
- DatasetDocument.indexing_status == "completed",
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
+ if not dataset:
+ raise Exception("Dataset not found")
+ index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ if action == "upgrade":
+ dataset_documents = (
+ session.query(DatasetDocument)
+ .where(
+ DatasetDocument.dataset_id == dataset_id,
+ DatasetDocument.indexing_status == "completed",
+ DatasetDocument.enabled == True,
+ DatasetDocument.archived == False,
+ )
+ .all()
)
- .all()
- )
- if dataset_documents:
- dataset_documents_ids = [doc.id for doc in dataset_documents]
- db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
- {"indexing_status": "indexing"}, synchronize_session=False
- )
- db.session.commit()
+ if dataset_documents:
+ dataset_documents_ids = [doc.id for doc in dataset_documents]
+ session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
+ {"indexing_status": "indexing"}, synchronize_session=False
+ )
+ session.commit()
- for dataset_document in dataset_documents:
- try:
- # add from vector index
- segments = (
- db.session.query(DocumentSegment)
- .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
- .order_by(DocumentSegment.position.asc())
- .all()
- )
- if segments:
- documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
+ for dataset_document in dataset_documents:
+ try:
+ # add from vector index
+ segments = (
+ session.query(DocumentSegment)
+ .where(
+ DocumentSegment.document_id == dataset_document.id,
+ DocumentSegment.enabled == True,
)
-
- documents.append(document)
- # save vector index
- # clean keywords
- index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
- index_processor.load(dataset, documents, with_keywords=False)
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "completed"}, synchronize_session=False
- )
- db.session.commit()
- except Exception as e:
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "error", "error": str(e)}, synchronize_session=False
- )
- db.session.commit()
- elif action == "update":
- dataset_documents = (
- db.session.query(DatasetDocument)
- .where(
- DatasetDocument.dataset_id == dataset_id,
- DatasetDocument.indexing_status == "completed",
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
- )
- .all()
- )
- # add new index
- if dataset_documents:
- # update document status
- dataset_documents_ids = [doc.id for doc in dataset_documents]
- db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
- {"indexing_status": "indexing"}, synchronize_session=False
- )
- db.session.commit()
-
- # clean index
- index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
-
- for dataset_document in dataset_documents:
- # update from vector index
- try:
- segments = (
- db.session.query(DocumentSegment)
- .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
- .order_by(DocumentSegment.position.asc())
- .all()
- )
- if segments:
- documents = []
- multimodal_documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
- if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
- child_chunks = segment.get_child_chunks()
- if child_chunks:
- child_documents = []
- for child_chunk in child_chunks:
- child_document = ChildDocument(
- page_content=child_chunk.content,
- metadata={
- "doc_id": child_chunk.index_node_id,
- "doc_hash": child_chunk.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
- child_documents.append(child_document)
- document.children = child_documents
- if dataset.is_multimodal:
- for attachment in segment.attachments:
- multimodal_documents.append(
- AttachmentDocument(
- page_content=attachment["name"],
- metadata={
- "doc_id": attachment["id"],
- "doc_hash": "",
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- "doc_type": DocType.IMAGE,
- },
- )
- )
- documents.append(document)
- # save vector index
- index_processor.load(
- dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+ .order_by(DocumentSegment.position.asc())
+ .all()
)
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "completed"}, synchronize_session=False
- )
- db.session.commit()
- except Exception as e:
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "error", "error": str(e)}, synchronize_session=False
- )
- db.session.commit()
- else:
- # clean collection
- index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+ if segments:
+ documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
- end_at = time.perf_counter()
- logging.info(
- click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
- )
- except Exception:
- logging.exception("Deal dataset vector index failed")
- finally:
- db.session.close()
+ documents.append(document)
+ # save vector index
+ # clean keywords
+ index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
+ index_processor.load(dataset, documents, with_keywords=False)
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "completed"}, synchronize_session=False
+ )
+ session.commit()
+ except Exception as e:
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+ )
+ session.commit()
+ elif action == "update":
+ dataset_documents = (
+ session.query(DatasetDocument)
+ .where(
+ DatasetDocument.dataset_id == dataset_id,
+ DatasetDocument.indexing_status == "completed",
+ DatasetDocument.enabled == True,
+ DatasetDocument.archived == False,
+ )
+ .all()
+ )
+ # add new index
+ if dataset_documents:
+ # update document status
+ dataset_documents_ids = [doc.id for doc in dataset_documents]
+ session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
+ {"indexing_status": "indexing"}, synchronize_session=False
+ )
+ session.commit()
+
+ # clean index
+ index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+ for dataset_document in dataset_documents:
+ # update from vector index
+ try:
+ segments = (
+ session.query(DocumentSegment)
+ .where(
+ DocumentSegment.document_id == dataset_document.id,
+ DocumentSegment.enabled == True,
+ )
+ .order_by(DocumentSegment.position.asc())
+ .all()
+ )
+ if segments:
+ documents = []
+ multimodal_documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ child_chunks = segment.get_child_chunks()
+ if child_chunks:
+ child_documents = []
+ for child_chunk in child_chunks:
+ child_document = ChildDocument(
+ page_content=child_chunk.content,
+ metadata={
+ "doc_id": child_chunk.index_node_id,
+ "doc_hash": child_chunk.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ child_documents.append(child_document)
+ document.children = child_documents
+ if dataset.is_multimodal:
+ for attachment in segment.attachments:
+ multimodal_documents.append(
+ AttachmentDocument(
+ page_content=attachment["name"],
+ metadata={
+ "doc_id": attachment["id"],
+ "doc_hash": "",
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ "doc_type": DocType.IMAGE,
+ },
+ )
+ )
+ documents.append(document)
+ # save vector index
+ index_processor.load(
+ dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+ )
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "completed"}, synchronize_session=False
+ )
+ session.commit()
+ except Exception as e:
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+ )
+ session.commit()
+ else:
+ # clean collection
+ index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+ end_at = time.perf_counter()
+ logging.info(
+ click.style(
+ "Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at),
+ fg="green",
+ )
+ )
+ except Exception:
+ logging.exception("Deal dataset vector index failed")
diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py
index 1c7de3b1ce..0047e04a17 100644
--- a/api/tasks/deal_dataset_vector_index_task.py
+++ b/api/tasks/deal_dataset_vector_index_task.py
@@ -5,11 +5,11 @@ import click
from celery import shared_task
from sqlalchemy import select
+from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@@ -27,160 +27,170 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
logger.info(click.style(f"Start deal dataset vector index: {dataset_id}", fg="green"))
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).filter_by(id=dataset_id).first()
- if not dataset:
- raise Exception("Dataset not found")
- index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- if action == "remove":
- index_processor.clean(dataset, None, with_keywords=False)
- elif action == "add":
- dataset_documents = db.session.scalars(
- select(DatasetDocument).where(
- DatasetDocument.dataset_id == dataset_id,
- DatasetDocument.indexing_status == "completed",
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
- )
- ).all()
+ if not dataset:
+ raise Exception("Dataset not found")
+ index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ if action == "remove":
+ index_processor.clean(dataset, None, with_keywords=False)
+ elif action == "add":
+ dataset_documents = session.scalars(
+ select(DatasetDocument).where(
+ DatasetDocument.dataset_id == dataset_id,
+ DatasetDocument.indexing_status == "completed",
+ DatasetDocument.enabled == True,
+ DatasetDocument.archived == False,
+ )
+ ).all()
- if dataset_documents:
- dataset_documents_ids = [doc.id for doc in dataset_documents]
- db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
- {"indexing_status": "indexing"}, synchronize_session=False
- )
- db.session.commit()
+ if dataset_documents:
+ dataset_documents_ids = [doc.id for doc in dataset_documents]
+ session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
+ {"indexing_status": "indexing"}, synchronize_session=False
+ )
+ session.commit()
- for dataset_document in dataset_documents:
- try:
- # add from vector index
- segments = (
- db.session.query(DocumentSegment)
- .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
- .order_by(DocumentSegment.position.asc())
- .all()
- )
- if segments:
- documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
+ for dataset_document in dataset_documents:
+ try:
+ # add from vector index
+ segments = (
+ session.query(DocumentSegment)
+ .where(
+ DocumentSegment.document_id == dataset_document.id,
+ DocumentSegment.enabled == True,
)
-
- documents.append(document)
- # save vector index
- index_processor.load(dataset, documents, with_keywords=False)
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "completed"}, synchronize_session=False
- )
- db.session.commit()
- except Exception as e:
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "error", "error": str(e)}, synchronize_session=False
- )
- db.session.commit()
- elif action == "update":
- dataset_documents = db.session.scalars(
- select(DatasetDocument).where(
- DatasetDocument.dataset_id == dataset_id,
- DatasetDocument.indexing_status == "completed",
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
- )
- ).all()
- # add new index
- if dataset_documents:
- # update document status
- dataset_documents_ids = [doc.id for doc in dataset_documents]
- db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
- {"indexing_status": "indexing"}, synchronize_session=False
- )
- db.session.commit()
-
- # clean index
- index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
-
- for dataset_document in dataset_documents:
- # update from vector index
- try:
- segments = (
- db.session.query(DocumentSegment)
- .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
- .order_by(DocumentSegment.position.asc())
- .all()
- )
- if segments:
- documents = []
- multimodal_documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
- if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
- child_chunks = segment.get_child_chunks()
- if child_chunks:
- child_documents = []
- for child_chunk in child_chunks:
- child_document = ChildDocument(
- page_content=child_chunk.content,
- metadata={
- "doc_id": child_chunk.index_node_id,
- "doc_hash": child_chunk.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
- child_documents.append(child_document)
- document.children = child_documents
- if dataset.is_multimodal:
- for attachment in segment.attachments:
- multimodal_documents.append(
- AttachmentDocument(
- page_content=attachment["name"],
- metadata={
- "doc_id": attachment["id"],
- "doc_hash": "",
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- "doc_type": DocType.IMAGE,
- },
- )
- )
- documents.append(document)
- # save vector index
- index_processor.load(
- dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+ .order_by(DocumentSegment.position.asc())
+ .all()
)
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "completed"}, synchronize_session=False
- )
- db.session.commit()
- except Exception as e:
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "error", "error": str(e)}, synchronize_session=False
- )
- db.session.commit()
- else:
- # clean collection
- index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+ if segments:
+ documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
- end_at = time.perf_counter()
- logger.info(click.style(f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}", fg="green"))
- except Exception:
- logger.exception("Deal dataset vector index failed")
- finally:
- db.session.close()
+ documents.append(document)
+ # save vector index
+ index_processor.load(dataset, documents, with_keywords=False)
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "completed"}, synchronize_session=False
+ )
+ session.commit()
+ except Exception as e:
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+ )
+ session.commit()
+ elif action == "update":
+ dataset_documents = session.scalars(
+ select(DatasetDocument).where(
+ DatasetDocument.dataset_id == dataset_id,
+ DatasetDocument.indexing_status == "completed",
+ DatasetDocument.enabled == True,
+ DatasetDocument.archived == False,
+ )
+ ).all()
+ # add new index
+ if dataset_documents:
+ # update document status
+ dataset_documents_ids = [doc.id for doc in dataset_documents]
+ session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
+ {"indexing_status": "indexing"}, synchronize_session=False
+ )
+ session.commit()
+
+ # clean index
+ index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+ for dataset_document in dataset_documents:
+ # update from vector index
+ try:
+ segments = (
+ session.query(DocumentSegment)
+ .where(
+ DocumentSegment.document_id == dataset_document.id,
+ DocumentSegment.enabled == True,
+ )
+ .order_by(DocumentSegment.position.asc())
+ .all()
+ )
+ if segments:
+ documents = []
+ multimodal_documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ child_chunks = segment.get_child_chunks()
+ if child_chunks:
+ child_documents = []
+ for child_chunk in child_chunks:
+ child_document = ChildDocument(
+ page_content=child_chunk.content,
+ metadata={
+ "doc_id": child_chunk.index_node_id,
+ "doc_hash": child_chunk.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ child_documents.append(child_document)
+ document.children = child_documents
+ if dataset.is_multimodal:
+ for attachment in segment.attachments:
+ multimodal_documents.append(
+ AttachmentDocument(
+ page_content=attachment["name"],
+ metadata={
+ "doc_id": attachment["id"],
+ "doc_hash": "",
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ "doc_type": DocType.IMAGE,
+ },
+ )
+ )
+ documents.append(document)
+ # save vector index
+ index_processor.load(
+ dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+ )
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "completed"}, synchronize_session=False
+ )
+ session.commit()
+ except Exception as e:
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+ )
+ session.commit()
+ else:
+ # clean collection
+ index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
+ except Exception:
+ logger.exception("Deal dataset vector index failed")
diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py
index cb703cc263..ecf6f9cb39 100644
--- a/api/tasks/delete_account_task.py
+++ b/api/tasks/delete_account_task.py
@@ -3,7 +3,7 @@ import logging
from celery import shared_task
from configs import dify_config
-from extensions.ext_database import db
+from core.db.session_factory import session_factory
from models import Account
from services.billing_service import BillingService
from tasks.mail_account_deletion_task import send_deletion_success_task
@@ -13,16 +13,17 @@ logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def delete_account_task(account_id):
- account = db.session.query(Account).where(Account.id == account_id).first()
- try:
- if dify_config.BILLING_ENABLED:
- BillingService.delete_account(account_id)
- except Exception:
- logger.exception("Failed to delete account %s from billing service.", account_id)
- raise
+ with session_factory.create_session() as session:
+ account = session.query(Account).where(Account.id == account_id).first()
+ try:
+ if dify_config.BILLING_ENABLED:
+ BillingService.delete_account(account_id)
+ except Exception:
+ logger.exception("Failed to delete account %s from billing service.", account_id)
+ raise
- if not account:
- logger.error("Account %s not found.", account_id)
- return
- # send success email
- send_deletion_success_task.delay(account.email)
+ if not account:
+ logger.error("Account %s not found.", account_id)
+ return
+ # send success email
+ send_deletion_success_task.delay(account.email)
diff --git a/api/tasks/delete_conversation_task.py b/api/tasks/delete_conversation_task.py
index 756b67c93e..9664b8ac73 100644
--- a/api/tasks/delete_conversation_task.py
+++ b/api/tasks/delete_conversation_task.py
@@ -4,7 +4,7 @@ import time
import click
from celery import shared_task
-from extensions.ext_database import db
+from core.db.session_factory import session_factory
from models import ConversationVariable
from models.model import Message, MessageAnnotation, MessageFeedback
from models.tools import ToolConversationVariables, ToolFile
@@ -27,44 +27,46 @@ def delete_conversation_related_data(conversation_id: str):
)
start_at = time.perf_counter()
- try:
- db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
- synchronize_session=False
- )
-
- db.session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
- synchronize_session=False
- )
-
- db.session.query(ToolConversationVariables).where(
- ToolConversationVariables.conversation_id == conversation_id
- ).delete(synchronize_session=False)
-
- db.session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)
-
- db.session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
- synchronize_session=False
- )
-
- db.session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)
-
- db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
- synchronize_session=False
- )
-
- db.session.commit()
-
- end_at = time.perf_counter()
- logger.info(
- click.style(
- f"Succeeded cleaning data from db for conversation_id {conversation_id} latency: {end_at - start_at}",
- fg="green",
+ with session_factory.create_session() as session:
+ try:
+ session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
+ synchronize_session=False
)
- )
- except Exception as e:
- logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id)
- db.session.rollback()
- raise e
- finally:
- db.session.close()
+ session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
+ synchronize_session=False
+ )
+
+ session.query(ToolConversationVariables).where(
+ ToolConversationVariables.conversation_id == conversation_id
+ ).delete(synchronize_session=False)
+
+ session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)
+
+ session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
+ synchronize_session=False
+ )
+
+ session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)
+
+ session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
+ synchronize_session=False
+ )
+
+ session.commit()
+
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ (
+ f"Succeeded cleaning data from db for conversation_id {conversation_id} "
+ f"latency: {end_at - start_at}"
+ ),
+ fg="green",
+ )
+ )
+
+ except Exception:
+ logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id)
+ session.rollback()
+ raise
diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py
index bea5c952cf..bfa709502c 100644
--- a/api/tasks/delete_segment_from_index_task.py
+++ b/api/tasks/delete_segment_from_index_task.py
@@ -4,8 +4,8 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from models.dataset import Dataset, Document, SegmentAttachmentBinding
from models.model import UploadFile
@@ -26,49 +26,52 @@ def delete_segment_from_index_task(
"""
logger.info(click.style("Start delete segment from index", fg="green"))
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
- return
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
+ return
- dataset_document = db.session.query(Document).where(Document.id == document_id).first()
- if not dataset_document:
- return
+ dataset_document = session.query(Document).where(Document.id == document_id).first()
+ if not dataset_document:
+ return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logging.info("Document not in valid state for index operations, skipping")
- return
- doc_form = dataset_document.doc_form
+ if (
+ not dataset_document.enabled
+ or dataset_document.archived
+ or dataset_document.indexing_status != "completed"
+ ):
+ logging.info("Document not in valid state for index operations, skipping")
+ return
+ doc_form = dataset_document.doc_form
- # Proceed with index cleanup using the index_node_ids directly
- index_processor = IndexProcessorFactory(doc_form).init_index_processor()
- index_processor.clean(
- dataset,
- index_node_ids,
- with_keywords=True,
- delete_child_chunks=True,
- precomputed_child_node_ids=child_node_ids,
- )
- if dataset.is_multimodal:
- # delete segment attachment binding
- segment_attachment_bindings = (
- db.session.query(SegmentAttachmentBinding)
- .where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
- .all()
+ # Proceed with index cleanup using the index_node_ids directly
+ index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+ index_processor.clean(
+ dataset,
+ index_node_ids,
+ with_keywords=True,
+ delete_child_chunks=True,
+ precomputed_child_node_ids=child_node_ids,
)
- if segment_attachment_bindings:
- attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
- index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
- for binding in segment_attachment_bindings:
- db.session.delete(binding)
- # delete upload file
- db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
- db.session.commit()
+ if dataset.is_multimodal:
+ # delete segment attachment binding
+ segment_attachment_bindings = (
+ session.query(SegmentAttachmentBinding)
+ .where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
+ .all()
+ )
+ if segment_attachment_bindings:
+ attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
+ index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
+ for binding in segment_attachment_bindings:
+ session.delete(binding)
+ # delete upload file
+ session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
+ session.commit()
- end_at = time.perf_counter()
- logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
- except Exception:
- logger.exception("delete segment from index failed")
- finally:
- db.session.close()
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
+ except Exception:
+ logger.exception("delete segment from index failed")
diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py
index 6b5f01b416..0ce6429a94 100644
--- a/api/tasks/disable_segment_from_index_task.py
+++ b/api/tasks/disable_segment_from_index_task.py
@@ -4,8 +4,8 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
@@ -23,46 +23,53 @@ def disable_segment_from_index_task(segment_id: str):
logger.info(click.style(f"Start disable segment from index: {segment_id}", fg="green"))
start_at = time.perf_counter()
- segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
- if not segment:
- logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
- db.session.close()
- return
-
- if segment.status != "completed":
- logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
- db.session.close()
- return
-
- indexing_cache_key = f"segment_{segment.id}_indexing"
-
- try:
- dataset = segment.dataset
-
- if not dataset:
- logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ with session_factory.create_session() as session:
+ segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
+ if not segment:
+ logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
- dataset_document = segment.document
-
- if not dataset_document:
- logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ if segment.status != "completed":
+ logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
- return
+ indexing_cache_key = f"segment_{segment.id}_indexing"
- index_type = dataset_document.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- index_processor.clean(dataset, [segment.index_node_id])
+ try:
+ dataset = segment.dataset
- end_at = time.perf_counter()
- logger.info(click.style(f"Segment removed from index: {segment.id} latency: {end_at - start_at}", fg="green"))
- except Exception:
- logger.exception("remove segment from index failed")
- segment.enabled = True
- db.session.commit()
- finally:
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ if not dataset:
+ logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ return
+
+ dataset_document = segment.document
+
+ if not dataset_document:
+ logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ return
+
+ if (
+ not dataset_document.enabled
+ or dataset_document.archived
+ or dataset_document.indexing_status != "completed"
+ ):
+ logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
+ return
+
+ index_type = dataset_document.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ index_processor.clean(dataset, [segment.index_node_id])
+
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Segment removed from index: {segment.id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
+ except Exception:
+ logger.exception("remove segment from index failed")
+ segment.enabled = True
+ session.commit()
+ finally:
+ redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py
index c2a3de29f4..03635902d1 100644
--- a/api/tasks/disable_segments_from_index_task.py
+++ b/api/tasks/disable_segments_from_index_task.py
@@ -5,8 +5,8 @@ import click
from celery import shared_task
from sqlalchemy import select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
@@ -26,69 +26,65 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
"""
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
- db.session.close()
- return
+ with session_factory.create_session() as session:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
+ return
- dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
+ dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
- if not dataset_document:
- logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
- db.session.close()
- return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
- db.session.close()
- return
- # sync index processor
- index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
+ if not dataset_document:
+ logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
+ return
+ if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
+ logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
+ return
+ # sync index processor
+ index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
- segments = db.session.scalars(
- select(DocumentSegment).where(
- DocumentSegment.id.in_(segment_ids),
- DocumentSegment.dataset_id == dataset_id,
- DocumentSegment.document_id == document_id,
- )
- ).all()
-
- if not segments:
- db.session.close()
- return
-
- try:
- index_node_ids = [segment.index_node_id for segment in segments]
- if dataset.is_multimodal:
- segment_ids = [segment.id for segment in segments]
- segment_attachment_bindings = (
- db.session.query(SegmentAttachmentBinding)
- .where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
- .all()
+ segments = session.scalars(
+ select(DocumentSegment).where(
+ DocumentSegment.id.in_(segment_ids),
+ DocumentSegment.dataset_id == dataset_id,
+ DocumentSegment.document_id == document_id,
)
- if segment_attachment_bindings:
- attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
- index_node_ids.extend(attachment_ids)
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
+ ).all()
- end_at = time.perf_counter()
- logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
- except Exception:
- # update segment error msg
- db.session.query(DocumentSegment).where(
- DocumentSegment.id.in_(segment_ids),
- DocumentSegment.dataset_id == dataset_id,
- DocumentSegment.document_id == document_id,
- ).update(
- {
- "disabled_at": None,
- "disabled_by": None,
- "enabled": True,
- }
- )
- db.session.commit()
- finally:
- for segment in segments:
- indexing_cache_key = f"segment_{segment.id}_indexing"
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ if not segments:
+ return
+
+ try:
+ index_node_ids = [segment.index_node_id for segment in segments]
+ if dataset.is_multimodal:
+ segment_ids = [segment.id for segment in segments]
+ segment_attachment_bindings = (
+ session.query(SegmentAttachmentBinding)
+ .where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
+ .all()
+ )
+ if segment_attachment_bindings:
+ attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
+ index_node_ids.extend(attachment_ids)
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
+
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
+ except Exception:
+ # update segment error msg
+ session.query(DocumentSegment).where(
+ DocumentSegment.id.in_(segment_ids),
+ DocumentSegment.dataset_id == dataset_id,
+ DocumentSegment.document_id == document_id,
+ ).update(
+ {
+ "disabled_at": None,
+ "disabled_by": None,
+ "enabled": True,
+ }
+ )
+ session.commit()
+ finally:
+ for segment in segments:
+ indexing_cache_key = f"segment_{segment.id}_indexing"
+ redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py
index 5fc2597c92..149185f6e2 100644
--- a/api/tasks/document_indexing_sync_task.py
+++ b/api/tasks/document_indexing_sync_task.py
@@ -3,12 +3,12 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.datasource_provider_service import DatasourceProviderService
@@ -28,105 +28,103 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
start_at = time.perf_counter()
- document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ with session_factory.create_session() as session:
+ document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="red"))
- db.session.close()
- return
-
- data_source_info = document.data_source_info_dict
- if document.data_source_type == "notion_import":
- if (
- not data_source_info
- or "notion_page_id" not in data_source_info
- or "notion_workspace_id" not in data_source_info
- ):
- raise ValueError("no notion page found")
- workspace_id = data_source_info["notion_workspace_id"]
- page_id = data_source_info["notion_page_id"]
- page_type = data_source_info["type"]
- page_edited_time = data_source_info["last_edited_time"]
- credential_id = data_source_info.get("credential_id")
-
- # Get credentials from datasource provider
- datasource_provider_service = DatasourceProviderService()
- credential = datasource_provider_service.get_datasource_credentials(
- tenant_id=document.tenant_id,
- credential_id=credential_id,
- provider="notion_datasource",
- plugin_id="langgenius/notion_datasource",
- )
-
- if not credential:
- logger.error(
- "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
- document_id,
- document.tenant_id,
- credential_id,
- )
- document.indexing_status = "error"
- document.error = "Datasource credential not found. Please reconnect your Notion workspace."
- document.stopped_at = naive_utc_now()
- db.session.commit()
- db.session.close()
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
- loader = NotionExtractor(
- notion_workspace_id=workspace_id,
- notion_obj_id=page_id,
- notion_page_type=page_type,
- notion_access_token=credential.get("integration_secret"),
- tenant_id=document.tenant_id,
- )
+ data_source_info = document.data_source_info_dict
+ if document.data_source_type == "notion_import":
+ if (
+ not data_source_info
+ or "notion_page_id" not in data_source_info
+ or "notion_workspace_id" not in data_source_info
+ ):
+ raise ValueError("no notion page found")
+ workspace_id = data_source_info["notion_workspace_id"]
+ page_id = data_source_info["notion_page_id"]
+ page_type = data_source_info["type"]
+ page_edited_time = data_source_info["last_edited_time"]
+ credential_id = data_source_info.get("credential_id")
- last_edited_time = loader.get_notion_last_edited_time()
+ # Get credentials from datasource provider
+ datasource_provider_service = DatasourceProviderService()
+ credential = datasource_provider_service.get_datasource_credentials(
+ tenant_id=document.tenant_id,
+ credential_id=credential_id,
+ provider="notion_datasource",
+ plugin_id="langgenius/notion_datasource",
+ )
- # check the page is updated
- if last_edited_time != page_edited_time:
- document.indexing_status = "parsing"
- document.processing_started_at = naive_utc_now()
- db.session.commit()
-
- # delete all document segment and index
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- raise Exception("Dataset not found")
- index_type = document.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
-
- segments = db.session.scalars(
- select(DocumentSegment).where(DocumentSegment.document_id == document_id)
- ).all()
- index_node_ids = [segment.index_node_id for segment in segments]
-
- # delete from vector index
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
-
- for segment in segments:
- db.session.delete(segment)
-
- end_at = time.perf_counter()
- logger.info(
- click.style(
- "Cleaned document when document update data source or process rule: {} latency: {}".format(
- document_id, end_at - start_at
- ),
- fg="green",
- )
+ if not credential:
+ logger.error(
+ "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
+ document_id,
+ document.tenant_id,
+ credential_id,
)
- except Exception:
- logger.exception("Cleaned document when document update data source or process rule failed")
+ document.indexing_status = "error"
+ document.error = "Datasource credential not found. Please reconnect your Notion workspace."
+ document.stopped_at = naive_utc_now()
+ session.commit()
+ return
- try:
- indexing_runner = IndexingRunner()
- indexing_runner.run([document])
- end_at = time.perf_counter()
- logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
- except DocumentIsPausedError as ex:
- logger.info(click.style(str(ex), fg="yellow"))
- except Exception:
- logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
- finally:
- db.session.close()
+ loader = NotionExtractor(
+ notion_workspace_id=workspace_id,
+ notion_obj_id=page_id,
+ notion_page_type=page_type,
+ notion_access_token=credential.get("integration_secret"),
+ tenant_id=document.tenant_id,
+ )
+
+ last_edited_time = loader.get_notion_last_edited_time()
+
+ # check the page is updated
+ if last_edited_time != page_edited_time:
+ document.indexing_status = "parsing"
+ document.processing_started_at = naive_utc_now()
+ session.commit()
+
+ # delete all document segment and index
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ raise Exception("Dataset not found")
+ index_type = document.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+
+ segments = session.scalars(
+ select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+ ).all()
+ index_node_ids = [segment.index_node_id for segment in segments]
+
+ # delete from vector index
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ "Cleaned document when document update data source or process rule: {} latency: {}".format(
+ document_id, end_at - start_at
+ ),
+ fg="green",
+ )
+ )
+ except Exception:
+ logger.exception("Cleaned document when document update data source or process rule failed")
+
+ try:
+ indexing_runner = IndexingRunner()
+ indexing_runner.run([document])
+ end_at = time.perf_counter()
+ logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
+ except DocumentIsPausedError as ex:
+ logger.info(click.style(str(ex), fg="yellow"))
+ except Exception:
+ logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py
index acbdab631b..3bdff60196 100644
--- a/api/tasks/document_indexing_task.py
+++ b/api/tasks/document_indexing_task.py
@@ -6,11 +6,11 @@ import click
from celery import shared_task
from configs import dify_config
+from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
-from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from services.feature_service import FeatureService
@@ -46,66 +46,63 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
documents = []
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
- db.session.close()
- return
- # check document limit
- features = FeatureService.get_features(dataset.tenant_id)
- try:
- if features.billing.enabled:
- vector_space = features.vector_space
- count = len(document_ids)
- batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
- if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
- raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
- if count > batch_upload_limit:
- raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
- if 0 < vector_space.limit <= vector_space.size:
- raise ValueError(
- "Your total number of documents plus the number of uploads have over the limit of "
- "your subscription."
+ with session_factory.create_session() as session:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
+ return
+ # check document limit
+ features = FeatureService.get_features(dataset.tenant_id)
+ try:
+ if features.billing.enabled:
+ vector_space = features.vector_space
+ count = len(document_ids)
+ batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
+ if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
+ raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
+ if count > batch_upload_limit:
+ raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
+ if 0 < vector_space.limit <= vector_space.size:
+ raise ValueError(
+ "Your total number of documents plus the number of uploads have over the limit of "
+ "your subscription."
+ )
+ except Exception as e:
+ for document_id in document_ids:
+ document = (
+ session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
- except Exception as e:
+ if document:
+ document.indexing_status = "error"
+ document.error = str(e)
+ document.stopped_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+ return
+
for document_id in document_ids:
+ logger.info(click.style(f"Start process document: {document_id}", fg="green"))
+
document = (
- db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
+
if document:
- document.indexing_status = "error"
- document.error = str(e)
- document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- db.session.close()
- return
+ document.indexing_status = "parsing"
+ document.processing_started_at = naive_utc_now()
+ documents.append(document)
+ session.add(document)
+ session.commit()
- for document_id in document_ids:
- logger.info(click.style(f"Start process document: {document_id}", fg="green"))
-
- document = (
- db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- )
-
- if document:
- document.indexing_status = "parsing"
- document.processing_started_at = naive_utc_now()
- documents.append(document)
- db.session.add(document)
- db.session.commit()
-
- try:
- indexing_runner = IndexingRunner()
- indexing_runner.run(documents)
- end_at = time.perf_counter()
- logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
- except DocumentIsPausedError as ex:
- logger.info(click.style(str(ex), fg="yellow"))
- except Exception:
- logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
- finally:
- db.session.close()
+ try:
+ indexing_runner = IndexingRunner()
+ indexing_runner.run(documents)
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
+ except DocumentIsPausedError as ex:
+ logger.info(click.style(str(ex), fg="yellow"))
+ except Exception:
+ logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
def _document_indexing_with_tenant_queue(
diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py
index 161502a228..67a23be952 100644
--- a/api/tasks/document_indexing_update_task.py
+++ b/api/tasks/document_indexing_update_task.py
@@ -3,8 +3,9 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@@ -26,56 +27,54 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Start update document: {document_id}", fg="green"))
start_at = time.perf_counter()
- document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ with session_factory.create_session() as session:
+ document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="red"))
- db.session.close()
- return
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="red"))
+ return
- document.indexing_status = "parsing"
- document.processing_started_at = naive_utc_now()
- db.session.commit()
+ document.indexing_status = "parsing"
+ document.processing_started_at = naive_utc_now()
+ session.commit()
- # delete all document segment and index
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- raise Exception("Dataset not found")
+ # delete all document segment and index
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ raise Exception("Dataset not found")
- index_type = document.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ index_type = document.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
- segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
+ segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
- # delete from vector index
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
-
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
- end_at = time.perf_counter()
- logger.info(
- click.style(
- "Cleaned document when document update data source or process rule: {} latency: {}".format(
- document_id, end_at - start_at
- ),
- fg="green",
+ # delete from vector index
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ db.session.commit()
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ "Cleaned document when document update data source or process rule: {} latency: {}".format(
+ document_id, end_at - start_at
+ ),
+ fg="green",
+ )
)
- )
- except Exception:
- logger.exception("Cleaned document when document update data source or process rule failed")
+ except Exception:
+ logger.exception("Cleaned document when document update data source or process rule failed")
- try:
- indexing_runner = IndexingRunner()
- indexing_runner.run([document])
- end_at = time.perf_counter()
- logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
- except DocumentIsPausedError as ex:
- logger.info(click.style(str(ex), fg="yellow"))
- except Exception:
- logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
- finally:
- db.session.close()
+ try:
+ indexing_runner = IndexingRunner()
+ indexing_runner.run([document])
+ end_at = time.perf_counter()
+ logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
+ except DocumentIsPausedError as ex:
+ logger.info(click.style(str(ex), fg="yellow"))
+ except Exception:
+ logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py
index 4078c8910e..00a963255b 100644
--- a/api/tasks/duplicate_document_indexing_task.py
+++ b/api/tasks/duplicate_document_indexing_task.py
@@ -4,15 +4,15 @@ from collections.abc import Callable, Sequence
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
from configs import dify_config
+from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
-from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
@@ -76,63 +76,64 @@ def _duplicate_document_indexing_task_with_tenant_queue(
def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[str]):
- documents = []
+ documents: list[Document] = []
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if dataset is None:
- logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
- db.session.close()
- return
-
- # check document limit
- features = FeatureService.get_features(dataset.tenant_id)
+ with session_factory.create_session() as session:
try:
- if features.billing.enabled:
- vector_space = features.vector_space
- count = len(document_ids)
- if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
- raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
- batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
- if count > batch_upload_limit:
- raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
- current = int(getattr(vector_space, "size", 0) or 0)
- limit = int(getattr(vector_space, "limit", 0) or 0)
- if limit > 0 and (current + count) > limit:
- raise ValueError(
- "Your total number of documents plus the number of uploads have exceeded the limit of "
- "your subscription."
- )
- except Exception as e:
- for document_id in document_ids:
- document = (
- db.session.query(Document)
- .where(Document.id == document_id, Document.dataset_id == dataset_id)
- .first()
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if dataset is None:
+ logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
+ return
+
+ # check document limit
+ features = FeatureService.get_features(dataset.tenant_id)
+ try:
+ if features.billing.enabled:
+ vector_space = features.vector_space
+ count = len(document_ids)
+ if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
+ raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
+ batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
+ if count > batch_upload_limit:
+ raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
+ current = int(getattr(vector_space, "size", 0) or 0)
+ limit = int(getattr(vector_space, "limit", 0) or 0)
+ if limit > 0 and (current + count) > limit:
+ raise ValueError(
+ "Your total number of documents plus the number of uploads have exceeded the limit of "
+ "your subscription."
+ )
+ except Exception as e:
+ documents = list(
+ session.scalars(
+ select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
+ ).all()
)
- if document:
- document.indexing_status = "error"
- document.error = str(e)
- document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- return
+ for document in documents:
+ if document:
+ document.indexing_status = "error"
+ document.error = str(e)
+ document.stopped_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+ return
- for document_id in document_ids:
- logger.info(click.style(f"Start process document: {document_id}", fg="green"))
-
- document = (
- db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ documents = list(
+ session.scalars(
+ select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
+ ).all()
)
- if document:
+ for document in documents:
+ logger.info(click.style(f"Start process document: {document.id}", fg="green"))
+
# clean old data
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
- segments = db.session.scalars(
- select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+ segments = session.scalars(
+ select(DocumentSegment).where(DocumentSegment.document_id == document.id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
@@ -140,26 +141,24 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
- documents.append(document)
- db.session.add(document)
- db.session.commit()
+ session.add(document)
+ session.commit()
- indexing_runner = IndexingRunner()
- indexing_runner.run(documents)
- end_at = time.perf_counter()
- logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
- except DocumentIsPausedError as ex:
- logger.info(click.style(str(ex), fg="yellow"))
- except Exception:
- logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
- finally:
- db.session.close()
+ indexing_runner = IndexingRunner()
+ indexing_runner.run(list(documents))
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
+ except DocumentIsPausedError as ex:
+ logger.info(click.style(str(ex), fg="yellow"))
+ except Exception:
+ logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
@shared_task(queue="dataset")
diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py
index 7615469ed0..1f9f21aa7e 100644
--- a/api/tasks/enable_segment_to_index_task.py
+++ b/api/tasks/enable_segment_to_index_task.py
@@ -4,11 +4,11 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
@@ -27,91 +27,93 @@ def enable_segment_to_index_task(segment_id: str):
logger.info(click.style(f"Start enable segment to index: {segment_id}", fg="green"))
start_at = time.perf_counter()
- segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
- if not segment:
- logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
- db.session.close()
- return
-
- if segment.status != "completed":
- logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
- db.session.close()
- return
-
- indexing_cache_key = f"segment_{segment.id}_indexing"
-
- try:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
-
- dataset = segment.dataset
-
- if not dataset:
- logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ with session_factory.create_session() as session:
+ segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
+ if not segment:
+ logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
- dataset_document = segment.document
-
- if not dataset_document:
- logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ if segment.status != "completed":
+ logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
- return
+ indexing_cache_key = f"segment_{segment.id}_indexing"
- index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
- if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
- child_chunks = segment.get_child_chunks()
- if child_chunks:
- child_documents = []
- for child_chunk in child_chunks:
- child_document = ChildDocument(
- page_content=child_chunk.content,
- metadata={
- "doc_id": child_chunk.index_node_id,
- "doc_hash": child_chunk.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
+ try:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+
+ dataset = segment.dataset
+
+ if not dataset:
+ logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ return
+
+ dataset_document = segment.document
+
+ if not dataset_document:
+ logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ return
+
+ if (
+ not dataset_document.enabled
+ or dataset_document.archived
+ or dataset_document.indexing_status != "completed"
+ ):
+ logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
+ return
+
+ index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
+ if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ child_chunks = segment.get_child_chunks()
+ if child_chunks:
+ child_documents = []
+ for child_chunk in child_chunks:
+ child_document = ChildDocument(
+ page_content=child_chunk.content,
+ metadata={
+ "doc_id": child_chunk.index_node_id,
+ "doc_hash": child_chunk.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ child_documents.append(child_document)
+ document.children = child_documents
+ multimodel_documents = []
+ if dataset.is_multimodal:
+ for attachment in segment.attachments:
+ multimodel_documents.append(
+ AttachmentDocument(
+ page_content=attachment["name"],
+ metadata={
+ "doc_id": attachment["id"],
+ "doc_hash": "",
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ "doc_type": DocType.IMAGE,
+ },
+ )
)
- child_documents.append(child_document)
- document.children = child_documents
- multimodel_documents = []
- if dataset.is_multimodal:
- for attachment in segment.attachments:
- multimodel_documents.append(
- AttachmentDocument(
- page_content=attachment["name"],
- metadata={
- "doc_id": attachment["id"],
- "doc_hash": "",
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- "doc_type": DocType.IMAGE,
- },
- )
- )
- # save vector index
- index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)
+ # save vector index
+ index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)
- end_at = time.perf_counter()
- logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception("enable segment to index failed")
- segment.enabled = False
- segment.disabled_at = naive_utc_now()
- segment.status = "error"
- segment.error = str(e)
- db.session.commit()
- finally:
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
+ except Exception as e:
+ logger.exception("enable segment to index failed")
+ segment.enabled = False
+ segment.disabled_at = naive_utc_now()
+ segment.status = "error"
+ segment.error = str(e)
+ session.commit()
+ finally:
+ redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py
index 9f17d09e18..48d3c8e178 100644
--- a/api/tasks/enable_segments_to_index_task.py
+++ b/api/tasks/enable_segments_to_index_task.py
@@ -5,11 +5,11 @@ import click
from celery import shared_task
from sqlalchemy import select
+from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, DocumentSegment
@@ -29,105 +29,102 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id)
"""
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
- return
+ with session_factory.create_session() as session:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
+ return
- dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
+ dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
- if not dataset_document:
- logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
- db.session.close()
- return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
- db.session.close()
- return
- # sync index processor
- index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
+ if not dataset_document:
+ logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
+ return
+ if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
+ logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
+ return
+ # sync index processor
+ index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
- segments = db.session.scalars(
- select(DocumentSegment).where(
- DocumentSegment.id.in_(segment_ids),
- DocumentSegment.dataset_id == dataset_id,
- DocumentSegment.document_id == document_id,
- )
- ).all()
- if not segments:
- logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
- db.session.close()
- return
-
- try:
- documents = []
- multimodal_documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": document_id,
- "dataset_id": dataset_id,
- },
+ segments = session.scalars(
+ select(DocumentSegment).where(
+ DocumentSegment.id.in_(segment_ids),
+ DocumentSegment.dataset_id == dataset_id,
+ DocumentSegment.document_id == document_id,
)
+ ).all()
+ if not segments:
+ logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
+ return
- if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
- child_chunks = segment.get_child_chunks()
- if child_chunks:
- child_documents = []
- for child_chunk in child_chunks:
- child_document = ChildDocument(
- page_content=child_chunk.content,
- metadata={
- "doc_id": child_chunk.index_node_id,
- "doc_hash": child_chunk.index_node_hash,
- "document_id": document_id,
- "dataset_id": dataset_id,
- },
+ try:
+ documents = []
+ multimodal_documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": document_id,
+ "dataset_id": dataset_id,
+ },
+ )
+
+ if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ child_chunks = segment.get_child_chunks()
+ if child_chunks:
+ child_documents = []
+ for child_chunk in child_chunks:
+ child_document = ChildDocument(
+ page_content=child_chunk.content,
+ metadata={
+ "doc_id": child_chunk.index_node_id,
+ "doc_hash": child_chunk.index_node_hash,
+ "document_id": document_id,
+ "dataset_id": dataset_id,
+ },
+ )
+ child_documents.append(child_document)
+ document.children = child_documents
+
+ if dataset.is_multimodal:
+ for attachment in segment.attachments:
+ multimodal_documents.append(
+ AttachmentDocument(
+ page_content=attachment["name"],
+ metadata={
+ "doc_id": attachment["id"],
+ "doc_hash": "",
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ "doc_type": DocType.IMAGE,
+ },
+ )
)
- child_documents.append(child_document)
- document.children = child_documents
+ documents.append(document)
+ # save vector index
+ index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
- if dataset.is_multimodal:
- for attachment in segment.attachments:
- multimodal_documents.append(
- AttachmentDocument(
- page_content=attachment["name"],
- metadata={
- "doc_id": attachment["id"],
- "doc_hash": "",
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- "doc_type": DocType.IMAGE,
- },
- )
- )
- documents.append(document)
- # save vector index
- index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
-
- end_at = time.perf_counter()
- logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception("enable segments to index failed")
- # update segment error msg
- db.session.query(DocumentSegment).where(
- DocumentSegment.id.in_(segment_ids),
- DocumentSegment.dataset_id == dataset_id,
- DocumentSegment.document_id == document_id,
- ).update(
- {
- "error": str(e),
- "status": "error",
- "disabled_at": naive_utc_now(),
- "enabled": False,
- }
- )
- db.session.commit()
- finally:
- for segment in segments:
- indexing_cache_key = f"segment_{segment.id}_indexing"
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
+ except Exception as e:
+ logger.exception("enable segments to index failed")
+ # update segment error msg
+ session.query(DocumentSegment).where(
+ DocumentSegment.id.in_(segment_ids),
+ DocumentSegment.dataset_id == dataset_id,
+ DocumentSegment.document_id == document_id,
+ ).update(
+ {
+ "error": str(e),
+ "status": "error",
+ "disabled_at": naive_utc_now(),
+ "enabled": False,
+ }
+ )
+ session.commit()
+ finally:
+ for segment in segments:
+ indexing_cache_key = f"segment_{segment.id}_indexing"
+ redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py
index e6492c230d..b5e6508006 100644
--- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py
+++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3
CACHE_REDIS_KEY_PREFIX = "plugin_autoupgrade_check_task:cached_plugin_manifests:"
-CACHE_REDIS_TTL = 60 * 15 # 15 minutes
+CACHE_REDIS_TTL = 60 * 60 # 1 hour
def _get_redis_cache_key(plugin_id: str) -> str:
diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py
index 1eef361a92..3c5e152520 100644
--- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py
+++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py
@@ -178,7 +178,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
workflow_id=workflow_id,
user=account,
application_generate_entity=entity,
- invoke_from=InvokeFrom.PUBLISHED,
+ invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py
index 275f5abe6e..093342d1a3 100644
--- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py
+++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py
@@ -178,7 +178,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
workflow_id=workflow_id,
user=account,
application_generate_entity=entity,
- invoke_from=InvokeFrom.PUBLISHED,
+ invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py
index 1b2a653c01..af72023da1 100644
--- a/api/tasks/recover_document_indexing_task.py
+++ b/api/tasks/recover_document_indexing_task.py
@@ -4,8 +4,8 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
-from extensions.ext_database import db
from models.dataset import Document
logger = logging.getLogger(__name__)
@@ -23,26 +23,24 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Recover document: {document_id}", fg="green"))
start_at = time.perf_counter()
- document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ with session_factory.create_session() as session:
+ document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="red"))
- db.session.close()
- return
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="red"))
+ return
- try:
- indexing_runner = IndexingRunner()
- if document.indexing_status in {"waiting", "parsing", "cleaning"}:
- indexing_runner.run([document])
- elif document.indexing_status == "splitting":
- indexing_runner.run_in_splitting_status(document)
- elif document.indexing_status == "indexing":
- indexing_runner.run_in_indexing_status(document)
- end_at = time.perf_counter()
- logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green"))
- except DocumentIsPausedError as ex:
- logger.info(click.style(str(ex), fg="yellow"))
- except Exception:
- logger.exception("recover_document_indexing_task failed, document_id: %s", document_id)
- finally:
- db.session.close()
+ try:
+ indexing_runner = IndexingRunner()
+ if document.indexing_status in {"waiting", "parsing", "cleaning"}:
+ indexing_runner.run([document])
+ elif document.indexing_status == "splitting":
+ indexing_runner.run_in_splitting_status(document)
+ elif document.indexing_status == "indexing":
+ indexing_runner.run_in_indexing_status(document)
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green"))
+ except DocumentIsPausedError as ex:
+ logger.info(click.style(str(ex), fg="yellow"))
+ except Exception:
+ logger.exception("recover_document_indexing_task failed, document_id: %s", document_id)
diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py
index 3227f6da96..817249845a 100644
--- a/api/tasks/remove_app_and_related_data_task.py
+++ b/api/tasks/remove_app_and_related_data_task.py
@@ -1,15 +1,20 @@
import logging
import time
from collections.abc import Callable
+from typing import Any, cast
import click
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import delete
+from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
+from configs import dify_config
+from core.db.session_factory import session_factory
from extensions.ext_database import db
+from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
from models import (
ApiToken,
AppAnnotationHitHistory,
@@ -40,6 +45,7 @@ from models.workflow import (
ConversationVariable,
Workflow,
WorkflowAppLog,
+ WorkflowArchiveLog,
)
from repositories.factory import DifyAPIRepositoryFactory
@@ -64,6 +70,9 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
_delete_app_workflow_runs(tenant_id, app_id)
_delete_app_workflow_node_executions(tenant_id, app_id)
_delete_app_workflow_app_logs(tenant_id, app_id)
+ if dify_config.BILLING_ENABLED and dify_config.ARCHIVE_STORAGE_ENABLED:
+ _delete_app_workflow_archive_logs(tenant_id, app_id)
+ _delete_archived_workflow_run_files(tenant_id, app_id)
_delete_app_conversations(tenant_id, app_id)
_delete_app_messages(tenant_id, app_id)
_delete_workflow_tool_providers(tenant_id, app_id)
@@ -77,7 +86,6 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
_delete_workflow_webhook_triggers(tenant_id, app_id)
_delete_workflow_schedule_plans(tenant_id, app_id)
_delete_workflow_trigger_logs(tenant_id, app_id)
-
end_at = time.perf_counter()
logger.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green"))
except SQLAlchemyError as e:
@@ -89,8 +97,8 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
def _delete_app_model_configs(tenant_id: str, app_id: str):
- def del_model_config(model_config_id: str):
- db.session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
+ def del_model_config(session, model_config_id: str):
+ session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_model_configs where app_id=:app_id limit 1000""",
@@ -101,8 +109,8 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):
def _delete_app_site(tenant_id: str, app_id: str):
- def del_site(site_id: str):
- db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
+ def del_site(session, site_id: str):
+ session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
_delete_records(
"""select id from sites where app_id=:app_id limit 1000""",
@@ -113,8 +121,8 @@ def _delete_app_site(tenant_id: str, app_id: str):
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
- def del_mcp_server(mcp_server_id: str):
- db.session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
+ def del_mcp_server(session, mcp_server_id: str):
+ session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_mcp_servers where app_id=:app_id limit 1000""",
@@ -125,8 +133,8 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def _delete_app_api_tokens(tenant_id: str, app_id: str):
- def del_api_token(api_token_id: str):
- db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
+ def del_api_token(session, api_token_id: str):
+ session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
_delete_records(
"""select id from api_tokens where app_id=:app_id limit 1000""",
@@ -137,8 +145,8 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
def _delete_installed_apps(tenant_id: str, app_id: str):
- def del_installed_app(installed_app_id: str):
- db.session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
+ def del_installed_app(session, installed_app_id: str):
+ session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
_delete_records(
"""select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -149,10 +157,8 @@ def _delete_installed_apps(tenant_id: str, app_id: str):
def _delete_recommended_apps(tenant_id: str, app_id: str):
- def del_recommended_app(recommended_app_id: str):
- db.session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(
- synchronize_session=False
- )
+ def del_recommended_app(session, recommended_app_id: str):
+ session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False)
_delete_records(
"""select id from recommended_apps where app_id=:app_id limit 1000""",
@@ -163,8 +169,8 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):
def _delete_app_annotation_data(tenant_id: str, app_id: str):
- def del_annotation_hit_history(annotation_hit_history_id: str):
- db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
+ def del_annotation_hit_history(session, annotation_hit_history_id: str):
+ session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
synchronize_session=False
)
@@ -175,8 +181,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
"annotation hit history",
)
- def del_annotation_setting(annotation_setting_id: str):
- db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
+ def del_annotation_setting(session, annotation_setting_id: str):
+ session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
synchronize_session=False
)
@@ -189,8 +195,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
def _delete_app_dataset_joins(tenant_id: str, app_id: str):
- def del_dataset_join(dataset_join_id: str):
- db.session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
+ def del_dataset_join(session, dataset_join_id: str):
+ session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_dataset_joins where app_id=:app_id limit 1000""",
@@ -201,8 +207,8 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def _delete_app_workflows(tenant_id: str, app_id: str):
- def del_workflow(workflow_id: str):
- db.session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
+ def del_workflow(session, workflow_id: str):
+ session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -241,10 +247,8 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
- def del_workflow_app_log(workflow_app_log_id: str):
- db.session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(
- synchronize_session=False
- )
+ def del_workflow_app_log(session, workflow_app_log_id: str):
+ session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -254,12 +258,51 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
)
-def _delete_app_conversations(tenant_id: str, app_id: str):
- def del_conversation(conversation_id: str):
- db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
+def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
+ def del_workflow_archive_log(workflow_archive_log_id: str):
+ db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
synchronize_session=False
)
- db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
+
+ _delete_records(
+ """select id from workflow_archive_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
+ {"tenant_id": tenant_id, "app_id": app_id},
+ del_workflow_archive_log,
+ "workflow archive log",
+ )
+
+
+def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):
+ prefix = f"{tenant_id}/app_id={app_id}/"
+ try:
+ archive_storage = get_archive_storage()
+ except ArchiveStorageNotConfiguredError as e:
+ logger.info("Archive storage not configured, skipping archive file cleanup: %s", e)
+ return
+
+ try:
+ keys = archive_storage.list_objects(prefix)
+ except Exception:
+ logger.exception("Failed to list archive files for app %s", app_id)
+ return
+
+ deleted = 0
+ for key in keys:
+ try:
+ archive_storage.delete_object(key)
+ deleted += 1
+ except Exception:
+ logger.exception("Failed to delete archive object %s", key)
+
+ logger.info("Deleted %s archive objects for app %s", deleted, app_id)
+
+
+def _delete_app_conversations(tenant_id: str, app_id: str):
+ def del_conversation(session, conversation_id: str):
+ session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
+ synchronize_session=False
+ )
+ session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
_delete_records(
"""select id from conversations where app_id=:app_id limit 1000""",
@@ -270,28 +313,26 @@ def _delete_app_conversations(tenant_id: str, app_id: str):
def _delete_conversation_variables(*, app_id: str):
- stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
- with db.engine.connect() as conn:
- conn.execute(stmt)
- conn.commit()
+ with session_factory.create_session() as session:
+ stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
+ session.execute(stmt)
+ session.commit()
logger.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green"))
def _delete_app_messages(tenant_id: str, app_id: str):
- def del_message(message_id: str):
- db.session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(
+ def del_message(session, message_id: str):
+ session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False)
+ session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
synchronize_session=False
)
- db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
+ session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
+ session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
synchronize_session=False
)
- db.session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
- db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
- synchronize_session=False
- )
- db.session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
- db.session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
- db.session.query(Message).where(Message.id == message_id).delete()
+ session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
+ session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
+ session.query(Message).where(Message.id == message_id).delete()
_delete_records(
"""select id from messages where app_id=:app_id limit 1000""",
@@ -302,8 +343,8 @@ def _delete_app_messages(tenant_id: str, app_id: str):
def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
- def del_tool_provider(tool_provider_id: str):
- db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
+ def del_tool_provider(session, tool_provider_id: str):
+ session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
synchronize_session=False
)
@@ -316,8 +357,8 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def _delete_app_tag_bindings(tenant_id: str, app_id: str):
- def del_tag_binding(tag_binding_id: str):
- db.session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
+ def del_tag_binding(session, tag_binding_id: str):
+ session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
_delete_records(
"""select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""",
@@ -328,8 +369,8 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def _delete_end_users(tenant_id: str, app_id: str):
- def del_end_user(end_user_id: str):
- db.session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
+ def del_end_user(session, end_user_id: str):
+ session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
_delete_records(
"""select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -340,10 +381,8 @@ def _delete_end_users(tenant_id: str, app_id: str):
def _delete_trace_app_configs(tenant_id: str, app_id: str):
- def del_trace_app_config(trace_app_config_id: str):
- db.session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(
- synchronize_session=False
- )
+ def del_trace_app_config(session, trace_app_config_id: str):
+ session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False)
_delete_records(
"""select id from trace_app_config where app_id=:app_id limit 1000""",
@@ -381,14 +420,14 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
total_files_deleted = 0
while True:
- with db.engine.begin() as conn:
+ with session_factory.create_session() as session:
# Get a batch of draft variable IDs along with their file_ids
query_sql = """
SELECT id, file_id FROM workflow_draft_variables
WHERE app_id = :app_id
LIMIT :batch_size
"""
- result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
+ result = session.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
rows = list(result)
if not rows:
@@ -399,7 +438,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
# Clean up associated Offload data first
if file_ids:
- files_deleted = _delete_draft_variable_offload_data(conn, file_ids)
+ files_deleted = _delete_draft_variable_offload_data(session, file_ids)
total_files_deleted += files_deleted
# Delete the draft variables
@@ -407,8 +446,11 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
DELETE FROM workflow_draft_variables
WHERE id IN :ids
"""
- deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)})
- batch_deleted = deleted_result.rowcount
+ deleted_result = cast(
+ CursorResult[Any],
+ session.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)}),
+ )
+ batch_deleted: int = int(getattr(deleted_result, "rowcount", 0) or 0)
total_deleted += batch_deleted
logger.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green"))
@@ -423,7 +465,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
return total_deleted
-def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
+def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int:
"""
Delete Offload data associated with WorkflowDraftVariable file_ids.
@@ -434,7 +476,7 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
4. Deletes WorkflowDraftVariableFile records
Args:
- conn: Database connection
+ session: Database connection
file_ids: List of WorkflowDraftVariableFile IDs
Returns:
@@ -450,12 +492,12 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
try:
# Get WorkflowDraftVariableFile records and their associated UploadFile keys
query_sql = """
- SELECT wdvf.id, uf.key, uf.id as upload_file_id
- FROM workflow_draft_variable_files wdvf
- JOIN upload_files uf ON wdvf.upload_file_id = uf.id
- WHERE wdvf.id IN :file_ids
- """
- result = conn.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)})
+ SELECT wdvf.id, uf.key, uf.id as upload_file_id
+ FROM workflow_draft_variable_files wdvf
+ JOIN upload_files uf ON wdvf.upload_file_id = uf.id
+ WHERE wdvf.id IN :file_ids \
+ """
+ result = session.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)})
file_records = list(result)
# Delete from object storage and collect upload file IDs
@@ -473,17 +515,19 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
# Delete UploadFile records
if upload_file_ids:
delete_upload_files_sql = """
- DELETE FROM upload_files
- WHERE id IN :upload_file_ids
- """
- conn.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)})
+ DELETE \
+ FROM upload_files
+ WHERE id IN :upload_file_ids \
+ """
+ session.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)})
# Delete WorkflowDraftVariableFile records
delete_variable_files_sql = """
- DELETE FROM workflow_draft_variable_files
- WHERE id IN :file_ids
- """
- conn.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)})
+ DELETE \
+ FROM workflow_draft_variable_files
+ WHERE id IN :file_ids \
+ """
+ session.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)})
except Exception:
logging.exception("Error deleting draft variable offload data:")
@@ -493,8 +537,8 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
def _delete_app_triggers(tenant_id: str, app_id: str):
- def del_app_trigger(trigger_id: str):
- db.session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
+ def del_app_trigger(session, trigger_id: str):
+ session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -505,8 +549,8 @@ def _delete_app_triggers(tenant_id: str, app_id: str):
def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
- def del_plugin_trigger(trigger_id: str):
- db.session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
+ def del_plugin_trigger(session, trigger_id: str):
+ session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
synchronize_session=False
)
@@ -519,8 +563,8 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
- def del_webhook_trigger(trigger_id: str):
- db.session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
+ def del_webhook_trigger(session, trigger_id: str):
+ session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
synchronize_session=False
)
@@ -533,10 +577,8 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
- def del_schedule_plan(plan_id: str):
- db.session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(
- synchronize_session=False
- )
+ def del_schedule_plan(session, plan_id: str):
+ session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_schedule_plans where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -547,8 +589,8 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
- def del_trigger_log(log_id: str):
- db.session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
+ def del_trigger_log(session, log_id: str):
+ session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -560,18 +602,22 @@ def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
while True:
- with db.engine.begin() as conn:
- rs = conn.execute(sa.text(query_sql), params)
- if rs.rowcount == 0:
+ with session_factory.create_session() as session:
+ rs = session.execute(sa.text(query_sql), params)
+ rows = rs.fetchall()
+ if not rows:
break
- for i in rs:
+ for i in rows:
record_id = str(i.id)
try:
- delete_func(record_id)
- db.session.commit()
+ delete_func(session, record_id)
logger.info(click.style(f"Deleted {name} {record_id}", fg="green"))
except Exception:
logger.exception("Error occurred while deleting %s %s", name, record_id)
- continue
+ # continue with next record even if one deletion fails
+ session.rollback()
+ break
+ session.commit()
+
rs.close()
diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py
index c0ab2d0b41..c3c255fb17 100644
--- a/api/tasks/remove_document_from_index_task.py
+++ b/api/tasks/remove_document_from_index_task.py
@@ -5,8 +5,8 @@ import click
from celery import shared_task
from sqlalchemy import select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Document, DocumentSegment
@@ -25,52 +25,55 @@ def remove_document_from_index_task(document_id: str):
logger.info(click.style(f"Start remove document segments from index: {document_id}", fg="green"))
start_at = time.perf_counter()
- document = db.session.query(Document).where(Document.id == document_id).first()
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="red"))
- db.session.close()
- return
+ with session_factory.create_session() as session:
+ document = session.query(Document).where(Document.id == document_id).first()
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="red"))
+ return
- if document.indexing_status != "completed":
- logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red"))
- db.session.close()
- return
+ if document.indexing_status != "completed":
+ logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red"))
+ return
- indexing_cache_key = f"document_{document.id}_indexing"
+ indexing_cache_key = f"document_{document.id}_indexing"
- try:
- dataset = document.dataset
+ try:
+ dataset = document.dataset
- if not dataset:
- raise Exception("Document has no dataset")
+ if not dataset:
+ raise Exception("Document has no dataset")
- index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
+ index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
- segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
- index_node_ids = [segment.index_node_id for segment in segments]
- if index_node_ids:
- try:
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
- except Exception:
- logger.exception("clean dataset %s from index failed", dataset.id)
- # update segment to disable
- db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
- {
- DocumentSegment.enabled: False,
- DocumentSegment.disabled_at: naive_utc_now(),
- DocumentSegment.disabled_by: document.disabled_by,
- DocumentSegment.updated_at: naive_utc_now(),
- }
- )
- db.session.commit()
+ segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
+ index_node_ids = [segment.index_node_id for segment in segments]
+ if index_node_ids:
+ try:
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
+ except Exception:
+ logger.exception("clean dataset %s from index failed", dataset.id)
+ # update segment to disable
+ session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
+ {
+ DocumentSegment.enabled: False,
+ DocumentSegment.disabled_at: naive_utc_now(),
+ DocumentSegment.disabled_by: document.disabled_by,
+ DocumentSegment.updated_at: naive_utc_now(),
+ }
+ )
+ session.commit()
- end_at = time.perf_counter()
- logger.info(click.style(f"Document removed from index: {document.id} latency: {end_at - start_at}", fg="green"))
- except Exception:
- logger.exception("remove document from index failed")
- if not document.archived:
- document.enabled = True
- db.session.commit()
- finally:
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Document removed from index: {document.id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
+ except Exception:
+ logger.exception("remove document from index failed")
+ if not document.archived:
+ document.enabled = True
+ session.commit()
+ finally:
+ redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py
index 9d208647e6..f20b15ac83 100644
--- a/api/tasks/retry_document_indexing_task.py
+++ b/api/tasks/retry_document_indexing_task.py
@@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant
@@ -29,97 +29,97 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
Usage: retry_document_indexing_task.delay(dataset_id, document_ids, user_id)
"""
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
- return
- user = db.session.query(Account).where(Account.id == user_id).first()
- if not user:
- logger.info(click.style(f"User not found: {user_id}", fg="red"))
- return
- tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
- if not tenant:
- raise ValueError("Tenant not found")
- user.current_tenant = tenant
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
+ return
+ user = session.query(Account).where(Account.id == user_id).first()
+ if not user:
+ logger.info(click.style(f"User not found: {user_id}", fg="red"))
+ return
+ tenant = session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
+ if not tenant:
+ raise ValueError("Tenant not found")
+ user.current_tenant = tenant
- for document_id in document_ids:
- retry_indexing_cache_key = f"document_{document_id}_is_retried"
- # check document limit
- features = FeatureService.get_features(tenant.id)
- try:
- if features.billing.enabled:
- vector_space = features.vector_space
- if 0 < vector_space.limit <= vector_space.size:
- raise ValueError(
- "Your total number of documents plus the number of uploads have over the limit of "
- "your subscription."
- )
- except Exception as e:
+ for document_id in document_ids:
+ retry_indexing_cache_key = f"document_{document_id}_is_retried"
+ # check document limit
+ features = FeatureService.get_features(tenant.id)
+ try:
+ if features.billing.enabled:
+ vector_space = features.vector_space
+ if 0 < vector_space.limit <= vector_space.size:
+ raise ValueError(
+ "Your total number of documents plus the number of uploads have over the limit of "
+ "your subscription."
+ )
+ except Exception as e:
+ document = (
+ session.query(Document)
+ .where(Document.id == document_id, Document.dataset_id == dataset_id)
+ .first()
+ )
+ if document:
+ document.indexing_status = "error"
+ document.error = str(e)
+ document.stopped_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+ redis_client.delete(retry_indexing_cache_key)
+ return
+
+ logger.info(click.style(f"Start retry document: {document_id}", fg="green"))
document = (
- db.session.query(Document)
- .where(Document.id == document_id, Document.dataset_id == dataset_id)
- .first()
+ session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
- if document:
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
+ return
+ try:
+ # clean old data
+ index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
+
+ segments = session.scalars(
+ select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+ ).all()
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
+ # delete from vector index
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ session.commit()
+
+ document.indexing_status = "parsing"
+ document.processing_started_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+
+ if dataset.runtime_mode == "rag_pipeline":
+ rag_pipeline_service = RagPipelineService()
+ rag_pipeline_service.retry_error_document(dataset, document, user)
+ else:
+ indexing_runner = IndexingRunner()
+ indexing_runner.run([document])
+ redis_client.delete(retry_indexing_cache_key)
+ except Exception as ex:
document.indexing_status = "error"
- document.error = str(e)
+ document.error = str(ex)
document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- redis_client.delete(retry_indexing_cache_key)
- return
-
- logger.info(click.style(f"Start retry document: {document_id}", fg="green"))
- document = (
- db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ session.add(document)
+ session.commit()
+ logger.info(click.style(str(ex), fg="yellow"))
+ redis_client.delete(retry_indexing_cache_key)
+ logger.exception("retry_document_indexing_task failed, document_id: %s", document_id)
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
+ except Exception as e:
+ logger.exception(
+ "retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids
)
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
- return
- try:
- # clean old data
- index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
-
- segments = db.session.scalars(
- select(DocumentSegment).where(DocumentSegment.document_id == document_id)
- ).all()
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
- # delete from vector index
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
-
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
-
- document.indexing_status = "parsing"
- document.processing_started_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
-
- if dataset.runtime_mode == "rag_pipeline":
- rag_pipeline_service = RagPipelineService()
- rag_pipeline_service.retry_error_document(dataset, document, user)
- else:
- indexing_runner = IndexingRunner()
- indexing_runner.run([document])
- redis_client.delete(retry_indexing_cache_key)
- except Exception as ex:
- document.indexing_status = "error"
- document.error = str(ex)
- document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- logger.info(click.style(str(ex), fg="yellow"))
- redis_client.delete(retry_indexing_cache_key)
- logger.exception("retry_document_indexing_task failed, document_id: %s", document_id)
- end_at = time.perf_counter()
- logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception(
- "retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids
- )
- raise e
- finally:
- db.session.close()
+ raise e
diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py
index 0dc1d841f4..f1c8c56995 100644
--- a/api/tasks/sync_website_document_indexing_task.py
+++ b/api/tasks/sync_website_document_indexing_task.py
@@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
@@ -27,69 +27,71 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
"""
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if dataset is None:
- raise ValueError("Dataset not found")
+ with session_factory.create_session() as session:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if dataset is None:
+ raise ValueError("Dataset not found")
- sync_indexing_cache_key = f"document_{document_id}_is_sync"
- # check document limit
- features = FeatureService.get_features(dataset.tenant_id)
- try:
- if features.billing.enabled:
- vector_space = features.vector_space
- if 0 < vector_space.limit <= vector_space.size:
- raise ValueError(
- "Your total number of documents plus the number of uploads have over the limit of "
- "your subscription."
- )
- except Exception as e:
- document = (
- db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- )
- if document:
+ sync_indexing_cache_key = f"document_{document_id}_is_sync"
+ # check document limit
+ features = FeatureService.get_features(dataset.tenant_id)
+ try:
+ if features.billing.enabled:
+ vector_space = features.vector_space
+ if 0 < vector_space.limit <= vector_space.size:
+ raise ValueError(
+ "Your total number of documents plus the number of uploads have over the limit of "
+ "your subscription."
+ )
+ except Exception as e:
+ document = (
+ session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ )
+ if document:
+ document.indexing_status = "error"
+ document.error = str(e)
+ document.stopped_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+ redis_client.delete(sync_indexing_cache_key)
+ return
+
+ logger.info(click.style(f"Start sync website document: {document_id}", fg="green"))
+ document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
+ return
+ try:
+ # clean old data
+ index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
+
+ segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
+ # delete from vector index
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ session.commit()
+
+ document.indexing_status = "parsing"
+ document.processing_started_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+
+ indexing_runner = IndexingRunner()
+ indexing_runner.run([document])
+ redis_client.delete(sync_indexing_cache_key)
+ except Exception as ex:
document.indexing_status = "error"
- document.error = str(e)
+ document.error = str(ex)
document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- redis_client.delete(sync_indexing_cache_key)
- return
-
- logger.info(click.style(f"Start sync website document: {document_id}", fg="green"))
- document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
- return
- try:
- # clean old data
- index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
-
- segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
- # delete from vector index
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
-
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
-
- document.indexing_status = "parsing"
- document.processing_started_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
-
- indexing_runner = IndexingRunner()
- indexing_runner.run([document])
- redis_client.delete(sync_indexing_cache_key)
- except Exception as ex:
- document.indexing_status = "error"
- document.error = str(ex)
- document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- logger.info(click.style(str(ex), fg="yellow"))
- redis_client.delete(sync_indexing_cache_key)
- logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id)
- end_at = time.perf_counter()
- logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green"))
+ session.add(document)
+ session.commit()
+ logger.info(click.style(str(ex), fg="yellow"))
+ redis_client.delete(sync_indexing_cache_key)
+ logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id)
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green"))
diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py
index ee1d31aa91..d18ea2c23c 100644
--- a/api/tasks/trigger_processing_tasks.py
+++ b/api/tasks/trigger_processing_tasks.py
@@ -16,6 +16,7 @@ from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
+from core.db.session_factory import session_factory
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.plugin.impl.exc import PluginInvokeError
@@ -27,7 +28,6 @@ from core.trigger.trigger_manager import TriggerManager
from core.workflow.enums import NodeType, WorkflowExecutionStatus
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from enums.quota_type import QuotaType, unlimited
-from extensions.ext_database import db
from models.enums import (
AppTriggerType,
CreatorUserRole,
@@ -257,7 +257,7 @@ def dispatch_triggered_workflow(
tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
)
trigger_entity: TriggerProviderEntity = provider_controller.entity
- with Session(db.engine) as session:
+ with session_factory.create_session() as session:
workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers)
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
diff --git a/api/tasks/trigger_subscription_refresh_tasks.py b/api/tasks/trigger_subscription_refresh_tasks.py
index ed92f3f3c5..7698a1a6b8 100644
--- a/api/tasks/trigger_subscription_refresh_tasks.py
+++ b/api/tasks/trigger_subscription_refresh_tasks.py
@@ -7,9 +7,9 @@ from celery import shared_task
from sqlalchemy.orm import Session
from configs import dify_config
+from core.db.session_factory import session_factory
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.utils.locks import build_trigger_refresh_lock_key
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.trigger import TriggerSubscription
from services.trigger.trigger_provider_service import TriggerProviderService
@@ -92,7 +92,7 @@ def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None:
logger.info("Begin subscription refresh: tenant=%s id=%s", tenant_id, subscription_id)
try:
now: int = _now_ts()
- with Session(db.engine) as session:
+ with session_factory.create_session() as session:
subscription: TriggerSubscription | None = _load_subscription(session, tenant_id, subscription_id)
if not subscription:
diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py
index 7d145fb50c..3b3c6e5313 100644
--- a/api/tasks/workflow_execution_tasks.py
+++ b/api/tasks/workflow_execution_tasks.py
@@ -10,11 +10,10 @@ import logging
from celery import shared_task
from sqlalchemy import select
-from sqlalchemy.orm import sessionmaker
+from core.db.session_factory import session_factory
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
-from extensions.ext_database import db
from models import CreatorUserRole, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
@@ -46,10 +45,7 @@ def save_workflow_execution_task(
True if successful, False otherwise
"""
try:
- # Create a new session for this task
- session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
- with session_factory() as session:
+ with session_factory.create_session() as session:
# Deserialize execution data
execution = WorkflowExecution.model_validate(execution_data)
diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py
index 8f5127670f..b30a4ff15b 100644
--- a/api/tasks/workflow_node_execution_tasks.py
+++ b/api/tasks/workflow_node_execution_tasks.py
@@ -10,13 +10,12 @@ import logging
from celery import shared_task
from sqlalchemy import select
-from sqlalchemy.orm import sessionmaker
+from core.db.session_factory import session_factory
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
-from extensions.ext_database import db
from models import CreatorUserRole, WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@@ -48,10 +47,7 @@ def save_workflow_node_execution_task(
True if successful, False otherwise
"""
try:
- # Create a new session for this task
- session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
- with session_factory() as session:
+ with session_factory.create_session() as session:
# Deserialize execution data
execution = WorkflowNodeExecution.model_validate(execution_data)
diff --git a/api/tasks/workflow_schedule_tasks.py b/api/tasks/workflow_schedule_tasks.py
index f54e02a219..8c64d3ab27 100644
--- a/api/tasks/workflow_schedule_tasks.py
+++ b/api/tasks/workflow_schedule_tasks.py
@@ -1,15 +1,14 @@
import logging
from celery import shared_task
-from sqlalchemy.orm import sessionmaker
+from core.db.session_factory import session_factory
from core.workflow.nodes.trigger_schedule.exc import (
ScheduleExecutionError,
ScheduleNotFoundError,
TenantOwnerNotFoundError,
)
from enums.quota_type import QuotaType, unlimited
-from extensions.ext_database import db
from models.trigger import WorkflowSchedulePlan
from services.async_workflow_service import AsyncWorkflowService
from services.errors.app import QuotaExceededError
@@ -33,10 +32,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
TenantOwnerNotFoundError: If no owner/admin for tenant
ScheduleExecutionError: If workflow trigger fails
"""
-
- session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
- with session_factory() as session:
+ with session_factory.create_session() as session:
schedule = session.get(WorkflowSchedulePlan, schedule_id)
if not schedule:
raise ScheduleNotFoundError(f"Schedule {schedule_id} not found")
diff --git a/api/templates/invite_member_mail_template_en-US.html b/api/templates/invite_member_mail_template_en-US.html
index a07c5f4b16..7b296519f0 100644
--- a/api/templates/invite_member_mail_template_en-US.html
+++ b/api/templates/invite_member_mail_template_en-US.html
@@ -83,7 +83,30 @@
Dear {{ to }},
{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.
Click the button below to log in to Dify and join the workspace.
- Login Here
+
Best regards,
Dify Team
diff --git a/api/templates/invite_member_mail_template_zh-CN.html b/api/templates/invite_member_mail_template_zh-CN.html
index 27709a3c6d..c05b3ddb67 100644
--- a/api/templates/invite_member_mail_template_zh-CN.html
+++ b/api/templates/invite_member_mail_template_zh-CN.html
@@ -83,7 +83,30 @@
尊敬的 {{ to }},
{{ inviter_name }} 现邀请您加入我们在 Dify 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 Dify 上,您可以探索、创造和合作,构建和运营 AI 应用。
点击下方按钮即可登录 Dify 并且加入空间。
- 在此登录
+
此致,
Dify 团队
diff --git a/api/templates/register_email_when_account_exist_template_en-US.html b/api/templates/register_email_when_account_exist_template_en-US.html
index ac5042c274..e2bb99c989 100644
--- a/api/templates/register_email_when_account_exist_template_en-US.html
+++ b/api/templates/register_email_when_account_exist_template_en-US.html
@@ -115,7 +115,30 @@
We noticed you tried to sign up, but this email is already registered with an existing account.
Please log in here:
- Log In
+
If you forgot your password, you can reset it here: Reset Password
diff --git a/api/templates/register_email_when_account_exist_template_zh-CN.html b/api/templates/register_email_when_account_exist_template_zh-CN.html
index 326b58343a..6a5bbd135b 100644
--- a/api/templates/register_email_when_account_exist_template_zh-CN.html
+++ b/api/templates/register_email_when_account_exist_template_zh-CN.html
@@ -115,7 +115,30 @@
我们注意到您尝试注册,但此电子邮件已注册。
请在此登录:
- 登录
+
如果您忘记了密码,可以在此重置: 重置密码
diff --git a/api/templates/without-brand/invite_member_mail_template_en-US.html b/api/templates/without-brand/invite_member_mail_template_en-US.html
index f9157284fa..687ece617a 100644
--- a/api/templates/without-brand/invite_member_mail_template_en-US.html
+++ b/api/templates/without-brand/invite_member_mail_template_en-US.html
@@ -92,12 +92,34 @@
platform specifically designed for LLM application development. On {{application_title}}, you can explore,
create, and collaborate to build and operate AI applications.
Click the button below to log in to {{application_title}} and join the workspace.
- Login Here
+
Best regards,
{{application_title}} Team