mirror of
https://github.com/langgenius/dify.git
synced 2026-06-15 20:37:35 +08:00
Compare commits
77 Commits
copilot/di
...
codex/fix-
| Author | SHA1 | Date | |
|---|---|---|---|
| 5f51107260 | |||
| 44eda16261 | |||
| fe8b87d460 | |||
| 0a051b598f | |||
| 534dd50d14 | |||
| 0d8f7c41de | |||
| fb70ebb8f8 | |||
| e3cfc4d40f | |||
| 9ac71329a4 | |||
| 4fb3210f9a | |||
| 09bfbf386e | |||
| f1ef7379dd | |||
| 4c347f198e | |||
| 366e58bbbb | |||
| 8430255931 | |||
| d849d60822 | |||
| dad2e64a62 | |||
| ba9975a083 | |||
| 629e046303 | |||
| c9bb740a6b | |||
| 50e23f40a4 | |||
| 212b819f1c | |||
| 3fb1d3055e | |||
| a823649934 | |||
| 19d2a4d7a0 | |||
| 28cc3fc10d | |||
| 34f3591d4c | |||
| c88a38b8b5 | |||
| 0019e6a6f3 | |||
| 1502a57381 | |||
| 686e643632 | |||
| 8e37d95760 | |||
| 11db079428 | |||
| eb3b12fa70 | |||
| 5bec8eb33a | |||
| d11e4eeaf7 | |||
| bbdf3d7634 | |||
| a80bba2c35 | |||
| 789698cddd | |||
| a8977be999 | |||
| 22e67b4673 | |||
| f948e442e0 | |||
| 8a1c0cf5ab | |||
| 47b58a34ef | |||
| d80bd2a135 | |||
| 5d814ca8c1 | |||
| 0239b81cca | |||
| a15ecf6bec | |||
| d0b376d31a | |||
| 9c24b7bac5 | |||
| 6291452020 | |||
| d46a4c05b1 | |||
| f15a8f02ef | |||
| 0c4b36b3f5 | |||
| 37e1d452b8 | |||
| db1aa683bc | |||
| a88c15c906 | |||
| 12bd8d2aa8 | |||
| 813bfea730 | |||
| 759b4cbad3 | |||
| 72c92fa60a | |||
| 1ae98b3ea4 | |||
| 196c040c99 | |||
| fad5656b2e | |||
| 76fb1b6ea8 | |||
| 157ba6f5a0 | |||
| 1c0080be6f | |||
| 6b12152ce8 | |||
| 1231c2f976 | |||
| 00ac937934 | |||
| 2c323104eb | |||
| edeaac5d4e | |||
| d16a012575 | |||
| 23cd129802 | |||
| e40b30d746 | |||
| a1d9340a62 | |||
| 3addc1e386 |
@ -1,73 +1,94 @@
|
||||
---
|
||||
name: frontend-code-review
|
||||
description: "Trigger when the user requests a review of frontend files (e.g., `.tsx`, `.ts`, `.js`). Support both pending-change reviews and focused file reviews while applying the checklist rules."
|
||||
description: Review Dify frontend code for correctness, accessibility, component design, dify-ui usage, data/query boundaries, performance, and tests. Trigger for `.tsx`, `.ts`, `.js`, UI, React, Next.js, pending-change, or focused frontend review requests.
|
||||
---
|
||||
|
||||
# Frontend Code Review
|
||||
|
||||
## Intent
|
||||
Use this skill whenever the user asks to review frontend code (especially `.tsx`, `.ts`, or `.js` files). Support two review modes:
|
||||
## When To Use
|
||||
|
||||
1. **Pending-change review** – inspect staged/working-tree files slated for commit and flag checklist violations before submission.
|
||||
2. **File-targeted review** – review the specific file(s) the user names and report the relevant checklist findings.
|
||||
Use this skill when the user asks to review, audit, analyze, or sanity-check frontend code under `web/`, `packages/dify-ui/`, or frontend-adjacent TypeScript files.
|
||||
|
||||
Stick to the checklist below for every applicable file and mode.
|
||||
Supported modes:
|
||||
|
||||
## Checklist
|
||||
See [references/code-quality.md](references/code-quality.md), [references/performance.md](references/performance.md), [references/business-logic.md](references/business-logic.md) for the living checklist split by category—treat it as the canonical set of rules to follow.
|
||||
- **Pending-change review**: inspect staged and working-tree changes.
|
||||
- **File-focused review**: inspect explicitly named files or paths.
|
||||
- **Diff/snippet review**: review pasted diffs or snippets using best-effort references.
|
||||
|
||||
Flag each rule violation with urgency metadata so future reviewers can prioritize fixes.
|
||||
Do not use this skill for backend-only code under `api/`; use `backend-code-review` instead.
|
||||
|
||||
## Required Context
|
||||
|
||||
Before reviewing, read the relevant local contracts:
|
||||
|
||||
- `web/AGENTS.md` for Dify frontend workflow, overlays, design tokens, state, and tests.
|
||||
- `packages/dify-ui/README.md` and `packages/dify-ui/AGENTS.md` when code uses or changes `@langgenius/dify-ui/*`.
|
||||
- `web/docs/overlay.md` when reviewing dialogs, drawers, popovers, tooltips, menus, selects, comboboxes, or other floating UI.
|
||||
- `web/docs/test.md` and the `frontend-testing` skill when reviewing tests or testability.
|
||||
- `karpathy-guidelines` for scope control and focused, verifiable changes.
|
||||
- `how-to-write-component` when reviewing React component structure, ownership, effects, query/mutation contracts, or memoization.
|
||||
|
||||
For any UI, UX, or accessibility review, fetch the latest Web Interface Guidelines before finalizing findings. Treat them as a required baseline, not the complete source of accessibility truth:
|
||||
|
||||
```text
|
||||
https://raw.githubusercontent.com/vercel-labs/web-interface-guidelines/main/command.md
|
||||
```
|
||||
|
||||
If the review depends on a current framework, SDK, browser API, or accessibility behavior and local code does not settle it, check the current official docs first. For browser compatibility, deprecation, or behavior-sensitive frontend APIs, verify MDN or the relevant standard.
|
||||
|
||||
## Rule Packs
|
||||
|
||||
Apply every relevant rule pack:
|
||||
|
||||
- [references/accessibility-ui.md](references/accessibility-ui.md) — accessibility, semantic HTML, focus, forms, keyboard, disabled states, copy, and long-content behavior. Combines Web Interface Guidelines with Dify UI, Base UI, MDN, and local primitive contracts.
|
||||
- [references/dify-ui.md](references/dify-ui.md) — Dify UI primitive usage, Base UI semantics, overlays, forms, tokens, radius mapping, and primitive boundaries.
|
||||
- [references/component-architecture.md](references/component-architecture.md) — component ownership, props, state, effects, exports, wrappers, and feature organization.
|
||||
- [references/data-query-contracts.md](references/data-query-contracts.md) — generated contracts, TanStack Query, mutations, workspace/auth/SSR boundaries, URL/local storage state.
|
||||
- [references/performance.md](references/performance.md) — React/Next performance review rules from Vercel guidance, scoped to real risk.
|
||||
- [references/testing.md](references/testing.md) — frontend test review rules.
|
||||
- [references/dify-invariants.md](references/dify-invariants.md) — stable Dify-specific runtime invariants that generic React/a11y rules will not catch.
|
||||
- [references/code-quality.md](references/code-quality.md) — general TypeScript, styling, naming, and maintainability rules.
|
||||
|
||||
## Review Process
|
||||
1. Open the relevant component/module. Gather lines that relate to class names, React Flow hooks, prop memoization, and styling.
|
||||
2. For each rule in the review point, note where the code deviates and capture a representative snippet.
|
||||
3. Compose the review section per the template below. Group violations first by **Urgent** flag, then by category order (Code Quality, Performance, Business Logic).
|
||||
|
||||
## Required output
|
||||
When invoked, the response must exactly follow one of the two templates:
|
||||
1. Identify the review scope. For pending changes, inspect `git diff --stat`, `git diff`, and staged diff if relevant. For file-focused reviews, stay within the named files unless a referenced owner/contract must be read.
|
||||
2. Read code around the changed lines and the owning module. Do not review by isolated snippets when nearby ownership, labels, query inputs, or overlay structure decide correctness.
|
||||
3. Check user-visible regressions first: accessibility, broken interaction, auth/permission leaks, query/hydration errors, data loss, navigation mistakes, and impossible states.
|
||||
4. Then check maintainability and performance: ownership, effects, wrappers, memoization, bundle/waterfall risks, tests, and design-system drift.
|
||||
5. Report only actionable findings. Do not list speculative risks, style preferences, or broad refactors unless they are directly tied to a reproducible issue in scope.
|
||||
|
||||
### Template A (any findings)
|
||||
```
|
||||
# Code review
|
||||
Found <N> urgent issues need to be fixed:
|
||||
## Severity
|
||||
|
||||
## 1 <brief description of bug>
|
||||
FilePath: <path> line <line>
|
||||
<relevant code snippet or pointer>
|
||||
- **P0**: security/privacy/auth leak, data loss, production crash, inaccessible critical flow, or broken primary workflow.
|
||||
- **P1**: user-visible regression, hydration/SSR failure, invalid API/query contract, broken keyboard/focus behavior, or serious design-system/a11y violation.
|
||||
- **P2**: maintainability or performance issue likely to cause bugs, duplicated state, incorrect ownership, missing tests for risky behavior, or non-critical a11y issue.
|
||||
- **P3**: minor cleanup with clear value. Omit unless the user asked for a thorough audit.
|
||||
|
||||
## Output Format
|
||||
|
||||
### Suggested fix
|
||||
<brief description of suggested fix>
|
||||
Lead with findings, ordered by severity. Use this structure:
|
||||
|
||||
---
|
||||
... (repeat for each urgent issue) ...
|
||||
```markdown
|
||||
## Findings
|
||||
|
||||
Found <M> suggestions for improvement:
|
||||
- [P1] Short issue title
|
||||
File: `path/to/file.tsx:123`
|
||||
Why it matters and how to reproduce or reason about it.
|
||||
Suggested fix: concrete fix direction.
|
||||
|
||||
## 1 <brief description of suggestion>
|
||||
FilePath: <path> line <line>
|
||||
<relevant code snippet or pointer>
|
||||
## Open Questions
|
||||
|
||||
- Question or assumption, if any.
|
||||
|
||||
### Suggested fix
|
||||
<brief description of suggested fix>
|
||||
## Summary
|
||||
|
||||
---
|
||||
|
||||
... (repeat for each suggestion) ...
|
||||
Brief secondary context. Mention tests not run or residual risk.
|
||||
```
|
||||
|
||||
If there are no urgent issues, omit that section. If there are no suggestions, omit that section.
|
||||
|
||||
If the issue number is more than 10, summarize as "10+ urgent issues" or "10+ suggestions" and just output the first 10 issues.
|
||||
|
||||
Don't compress the blank lines between sections; keep them as-is for readability.
|
||||
|
||||
If you use Template A (i.e., there are issues to fix) and at least one issue requires code changes, append a brief follow-up question after the structured output asking whether the user wants you to apply the suggested fix(es). For example: "Would you like me to use the Suggested fix section to address these issues?"
|
||||
|
||||
### Template B (no issues)
|
||||
```
|
||||
## Code review
|
||||
No issues found.
|
||||
```
|
||||
Rules:
|
||||
|
||||
- If there are no findings, say `No issues found.` and mention any test gaps or residual risk.
|
||||
- Always include file and line when available.
|
||||
- Keep findings concrete and reproducible.
|
||||
- Do not include praise sections by default.
|
||||
- Do not ask to apply fixes unless the user explicitly wants review plus implementation.
|
||||
|
||||
@ -0,0 +1,109 @@
|
||||
# Accessibility And UI Rules
|
||||
|
||||
Accessibility findings are first-class review findings. Treat broken keyboard access, missing accessible names, focus loss, and unreachable popup content as correctness bugs, not polish.
|
||||
|
||||
Before finalizing UI or accessibility findings, fetch the latest Web Interface Guidelines as a required baseline:
|
||||
|
||||
```text
|
||||
https://raw.githubusercontent.com/vercel-labs/web-interface-guidelines/main/command.md
|
||||
```
|
||||
|
||||
Do not treat that document as the complete accessibility rule set. Combine it with:
|
||||
|
||||
- `packages/dify-ui/README.md`, `packages/dify-ui/AGENTS.md`, and the relevant primitive implementation when code uses `@langgenius/dify-ui/*`.
|
||||
- Base UI docs and local `.d.ts` contracts when primitive semantics, focus target, labels, or popup reachability are unclear.
|
||||
- MDN or relevant WAI-ARIA/browser standards when behavior, compatibility, or deprecation status matters.
|
||||
- The current feature's product semantics, because an accessible primitive can still be used in an inaccessible workflow.
|
||||
|
||||
## Semantic HTML
|
||||
|
||||
Flag:
|
||||
|
||||
- Clickable `div` or `span` used for actions.
|
||||
- Router navigation implemented with button or `onClick` when a `Link` / `<a>` is the real semantic element.
|
||||
- Icon-only buttons without `aria-label` or `aria-labelledby`.
|
||||
- Decorative icons missing `aria-hidden="true"`.
|
||||
- Images without `alt`; use `alt=""` only when truly decorative.
|
||||
- Heading levels that skip hierarchy in page-level content.
|
||||
|
||||
Prefer semantic HTML before ARIA.
|
||||
|
||||
## Keyboard And Focus
|
||||
|
||||
Flag:
|
||||
|
||||
- Interactive elements without visible `focus-visible` treatment.
|
||||
- `outline-none` / `outline-hidden` without an equivalent focus-visible ring or state.
|
||||
- Custom interactive elements missing keyboard handling.
|
||||
- Focus trapped, lost, or sent to the wrong surface after dialog/popover/menu close.
|
||||
- Focus ring applied to the wrong DOM node. Verify the actual focus target, especially with Base UI controls such as Slider.
|
||||
|
||||
Use `focus-visible` for keyboard focus. Use `focus-within` or `has-[:focus-visible]` when the visual wrapper is not the focused element.
|
||||
|
||||
## Forms
|
||||
|
||||
Flag:
|
||||
|
||||
- Inputs, selects, switches, checkboxes, radios, comboboxes, or sliders without a label relationship.
|
||||
- Missing stable `name` on form fields that submit or validate.
|
||||
- Incorrect input `type`, `inputMode`, `autoComplete`, or `spellCheck` for email, token, URL, number, search, code, or username fields.
|
||||
- Labels that are not clickable.
|
||||
- Submit buttons disabled before a request starts, preventing normal submit behavior.
|
||||
- Non-submit buttons inside forms missing `type="button"`.
|
||||
- Errors not associated with fields or not reachable by screen readers.
|
||||
- Error recovery that does not focus or expose the first invalid field.
|
||||
- `onPaste` blocking paste.
|
||||
- Placeholder text used as the only label.
|
||||
- Password managers accidentally triggered on non-auth fields because autocomplete is missing or wrong.
|
||||
|
||||
Prefer visible labels. If visible surrounding text already labels the control, use a visually hidden label or a precise `aria-label`.
|
||||
|
||||
## Disabled, Loading, And Async States
|
||||
|
||||
Flag:
|
||||
|
||||
- Loading state without `aria-busy`, `role="status"`, or another accessible update path when it changes user interaction.
|
||||
- Spinner or decorative loading icon exposed to screen readers.
|
||||
- Disabled controls that hide the reason users cannot proceed.
|
||||
- `aria-disabled` used without manually blocking click, Space, and Enter.
|
||||
- Toasts, inline validation, or async status changes that are not announced when users need the update to continue.
|
||||
- Icon-only loading/error affordances without text or accessible status where the state matters.
|
||||
|
||||
Use native `disabled` when the control must not be interactive. Use `aria-disabled` only when the element must remain focusable and the code handles all blocked interactions.
|
||||
|
||||
For repeated shared disabled reasons, prefer a visible group message or badge plus native disabled controls. Use per-control popover/info only when the reason is item-specific.
|
||||
|
||||
## Overlays And Popup Reachability
|
||||
|
||||
Flag:
|
||||
|
||||
- Tooltip used for long, structured, interactive, or unique information.
|
||||
- Tooltip content required to understand or complete a flow.
|
||||
- PreviewCard content that touch or screen-reader users cannot reach through the trigger's click destination.
|
||||
- Popover/dialog/menu triggers without accessible names.
|
||||
- Popup content without title/description where the primitive requires them.
|
||||
|
||||
Use Popover for explanatory content, rich help, and infotips. Use Tooltip only as a short visual label for a trigger that already has an accessible name.
|
||||
|
||||
## Long Content And Layout
|
||||
|
||||
Flag:
|
||||
|
||||
- Text in flex/grid children without `min-w-0` when it can overflow.
|
||||
- Names, labels, file names, model names, workspace names, or user content lacking `truncate`, `line-clamp`, or `break-words`.
|
||||
- Right-side icons, badges, checks, or actions that shrink before the text area.
|
||||
- Empty arrays or empty strings rendering broken layout instead of an empty state.
|
||||
- Button, tab, badge, chip, menu item, or card text that can overlap sibling controls at common viewport widths.
|
||||
|
||||
The usual Dify layout chain is: container has width constraints, text region uses `min-w-0 flex-1 truncate`, adornments use `shrink-0`.
|
||||
|
||||
## Motion, Images, And Copy
|
||||
|
||||
Flag:
|
||||
|
||||
- `transition-all`.
|
||||
- Animations that do not respect reduced motion.
|
||||
- Layout-affecting animation where transform/opacity would work.
|
||||
- Images without dimensions.
|
||||
- Loading copy using `...` instead of `…`.
|
||||
- Hardcoded dates, times, numbers, or currency formats instead of `Intl.*`.
|
||||
@ -1,15 +0,0 @@
|
||||
# Rule Catalog — Business Logic
|
||||
|
||||
## Can't use workflowStore in Node components
|
||||
|
||||
IsUrgent: True
|
||||
|
||||
### Description
|
||||
|
||||
File path pattern of node components: `web/app/components/workflow/nodes/[nodeName]/node.tsx`
|
||||
|
||||
Node components are also used when creating a RAG Pipe from a template, but in that context there is no workflowStore Provider, which results in a blank screen. [This Issue](https://github.com/langgenius/dify/issues/29168) was caused by exactly this reason.
|
||||
|
||||
### Suggested Fix
|
||||
|
||||
Use `import { useNodes } from 'reactflow'` instead of `import useNodes from '@/app/components/workflow/store/workflow/use-nodes'`.
|
||||
@ -1,44 +1,68 @@
|
||||
# Rule Catalog — Code Quality
|
||||
# Code Quality Rules
|
||||
|
||||
## Conditional class names use utility function
|
||||
## Scope Control
|
||||
|
||||
IsUrgent: True
|
||||
Category: Code Quality
|
||||
Flag changes that expand beyond the requested feature or review scope:
|
||||
|
||||
### Description
|
||||
- Repo-wide cleanup mixed into a targeted fix.
|
||||
- Compatibility exports, aliases, shims, or wrapper layers added without an explicit migration requirement.
|
||||
- Shared abstractions created before there is stable cross-feature reuse.
|
||||
- Business components moved into generic shared locations without a clear ownership boundary.
|
||||
|
||||
Ensure conditional CSS is handled via the shared `classNames` instead of custom ternaries, string concatenation, or template strings. Centralizing class logic keeps components consistent and easier to maintain.
|
||||
## TypeScript
|
||||
|
||||
### Suggested Fix
|
||||
Flag:
|
||||
|
||||
```ts
|
||||
import { cn } from '@/utils/classnames'
|
||||
const classNames = cn(isActive ? 'text-primary-600' : 'text-gray-500')
|
||||
```
|
||||
- `any` or broad `Record<string, any>` where generated/API types or local domain types exist.
|
||||
- Re-declared API shapes instead of importing generated or returned types.
|
||||
- Weak route/query param typing that leaks `string | string[] | undefined` deep into components.
|
||||
- Runtime wrappers added only to satisfy TypeScript when a narrower type boundary would preserve the existing runtime shape.
|
||||
|
||||
## Tailwind-first styling
|
||||
Prefer:
|
||||
|
||||
IsUrgent: True
|
||||
Category: Code Quality
|
||||
- Explicit domain names that match the API contract.
|
||||
- Type narrowing at route/API boundaries.
|
||||
- Small conversion helpers colocated with the component that needs them.
|
||||
|
||||
### Description
|
||||
## Styling
|
||||
|
||||
Favor Tailwind CSS utility classes instead of adding new `.module.css` files unless a Tailwind combination cannot achieve the required styling. Keeping styles in Tailwind improves consistency and reduces maintenance overhead.
|
||||
Flag:
|
||||
|
||||
Update this file when adding, editing, or removing Code Quality rules so the catalog remains accurate.
|
||||
- New CSS modules or ad hoc CSS when Tailwind utilities and Dify tokens cover the need.
|
||||
- Component-level plain `.css` files or component CSS imported through `globals.css`; use scoped `*.module.css` only when Tailwind and component variants cannot express the style.
|
||||
- Generic color utilities where Dify semantic tokens exist.
|
||||
- Hardcoded magic class values for colors, spacing, radius, shadow, z-index, or typography when Dify tokens, component variants, or documented radius mappings exist.
|
||||
- `!` important modifiers or important CSS overrides without a narrow, documented reason.
|
||||
- Manual string concatenation, template strings, array `.join(' ')`, or custom ternaries for conditional or multi-line classes.
|
||||
- JS conditional class branches for primitive visual states already exposed by Dify UI/Base UI `data-*` selectors.
|
||||
- Incoming `className` placed before default classes in `cn(...)`, preventing call-site overrides.
|
||||
- Arbitrary z-index or one-off layering fixes on overlays.
|
||||
|
||||
## Classname ordering for easy overrides
|
||||
Use:
|
||||
|
||||
### Description
|
||||
- `cn(...)` from the local package or utility already used by the file.
|
||||
- Dify semantic tokens and Tailwind v4 utilities.
|
||||
- Existing component variants before one-off class forks.
|
||||
- Primitive selectors such as `data-disabled:*`, `data-checked:*`, `data-highlighted:*`, `group-data-*`, `peer-data-*`, and `has-[:focus-visible]` before adding React state or boolean props solely for styling.
|
||||
- Component-level variants, semantic tokens, and normal cascade/order before `!` overrides. Use `!` only for a contained compatibility override that cannot be expressed through the component API or local selector structure.
|
||||
|
||||
When writing components, always place the incoming `className` prop after the component’s own class values so that downstream consumers can override or extend the styling. This keeps your component’s defaults but still lets external callers change or remove specific styles.
|
||||
## Imports
|
||||
|
||||
Example:
|
||||
Flag:
|
||||
|
||||
```tsx
|
||||
import { cn } from '@/utils/classnames'
|
||||
- Barrel imports from `@langgenius/dify-ui`; consumers must use subpath exports.
|
||||
- New overlay imports from legacy `@/app/components/base/modal`, `dialog`, or `drawer`.
|
||||
- Cross-feature imports that bypass explicit top-level public files.
|
||||
- Direct imports from generated/internal implementation files when a feature contract already exposes the intended surface.
|
||||
|
||||
const Button = ({ className }) => {
|
||||
return <div className={cn('bg-primary-600', className)}></div>
|
||||
}
|
||||
```
|
||||
## Copy And i18n
|
||||
|
||||
Flag:
|
||||
|
||||
- User-facing hardcoded strings in `web/`.
|
||||
- Added or renamed i18n keys that are not present in every supported locale file for the touched namespace.
|
||||
- Translation namespace drift, especially using unrelated module namespaces for local feature copy.
|
||||
- Generic button labels like `Continue` where the action is specific.
|
||||
- Error messages that state only the failure and not the next step.
|
||||
|
||||
Use feature-local translation keys by default. Alias only when crossing namespaces. `pnpm i18n:check --file <name>` should pass for any touched translation namespace.
|
||||
|
||||
@ -0,0 +1,89 @@
|
||||
# Component Architecture Rules
|
||||
|
||||
Use these rules for React component structure, ownership, state, props, effects, and module organization.
|
||||
|
||||
## Ownership
|
||||
|
||||
Flag:
|
||||
|
||||
- State, query, mutation, or handlers hoisted above the lowest component that actually uses them.
|
||||
- Parent components owning row/item actions that do not coordinate a workflow.
|
||||
- Prop drilling through multiple pass-through layers.
|
||||
- A page/tab-level section component becoming the data owner without needing a shared snapshot or shared loading/error/empty UI.
|
||||
- Feature code promoted to shared only because it appears once or might be reused later.
|
||||
|
||||
Accept repeated TanStack Query calls in siblings when each component independently consumes the data. Cache deduplication is not a reason to hoist by itself.
|
||||
|
||||
## Component Boundaries
|
||||
|
||||
Flag:
|
||||
|
||||
- React component files over 300 lines when the file mixes multiple responsibilities that can be split into focused colocated components, hooks, or utilities.
|
||||
- Shallow wrappers that only rename props or hide the real primitive.
|
||||
- Extra DOM wrappers that do not provide layout, semantics, accessibility, state ownership, or library integration.
|
||||
- Dialog/dropdown/popover hidden surfaces that obscure the parent flow when they should be extracted into a small local component.
|
||||
- Business forms, menu bodies, or one-off helpers moved away from their owner without reuse or semantic value.
|
||||
|
||||
Prefer colocated components split by actual data and state needs.
|
||||
|
||||
## Bad Component Design Patterns
|
||||
|
||||
Flag:
|
||||
|
||||
- Refactors of existing navigation, sidebar, dropdown, webapp list, or app-switching UI that do not preserve behavior-sensitive interactions such as expand/collapse arrows, hover persistence, pin/delete controls, routing, keyboard/focus handling, or open-state ownership.
|
||||
- Components that mix data fetching, mutation side effects, popup state, form validation, layout, and row rendering without a clear owner.
|
||||
- Generic components with many boolean props that encode one feature's workflow.
|
||||
- A shared component that imports feature-specific copy, routes, or API contracts.
|
||||
- A feature component that accepts pre-rendered fragments only to avoid placing ownership correctly.
|
||||
- A child component that receives both raw server data and separately derived flags for the same concept.
|
||||
- A wrapper that changes accessible semantics of the primitive it wraps.
|
||||
- A component that exposes controlled props but still keeps a competing private state for the same value.
|
||||
- A component that cannot render empty, loading, or missing optional API fields without caller-side preprocessing.
|
||||
|
||||
When existing components already own interaction logic, prefer reusing or extending them. If a refactor is necessary, preserve the old interaction contract and add or update focused tests for changed behavior.
|
||||
|
||||
## Props And Types
|
||||
|
||||
Flag:
|
||||
|
||||
- `React.FC` / `FC`.
|
||||
- Default exports outside framework-required files.
|
||||
- Named `Props` types for trivial one-off props where inline typing is clearer.
|
||||
- Props named by UI implementation instead of domain/API role.
|
||||
- API data converted too early or under a generic name that breaks traceability.
|
||||
- Callers duplicating fallback checks that the lowest rendering component already handles.
|
||||
|
||||
Prefer top-level `function` declarations for components and module helpers. Use arrow functions for callbacks and local lambdas.
|
||||
|
||||
## Effects
|
||||
|
||||
Flag effects that:
|
||||
|
||||
- Transform props/state for rendering.
|
||||
- Copy one state value into another representing the same concept.
|
||||
- Handle user actions that belong in event handlers.
|
||||
- Reset state from props when a keyed reset, stable ID, or render-time derivation would work.
|
||||
- Fetch data that belongs in framework APIs or TanStack Query.
|
||||
|
||||
If an effect remains, it must synchronize with a named external system: browser API, subscription, timer, analytics-on-visibility, non-React widget, or imperative DOM integration.
|
||||
|
||||
## State Modeling
|
||||
|
||||
Flag:
|
||||
|
||||
- Storing derived booleans, disabled flags, default tabs, or loading labels that can be calculated from current query/feature state.
|
||||
- Local state used to fake server data or generated contract fields.
|
||||
- UI state persisted to localStorage when it is live app state.
|
||||
- Feature-local mock shells wired to unrelated existing APIs before the real API is confirmed.
|
||||
|
||||
Prefer render-time derivation. Keep true local state for user choices, transient input, controlled popups, and feature UI state that has no server source.
|
||||
|
||||
## Navigation
|
||||
|
||||
Flag:
|
||||
|
||||
- Imperative router navigation for ordinary links.
|
||||
- Button semantics used for navigation.
|
||||
- Navigation state hidden in component state when URL state is required for shareable filters, tabs, or pagination.
|
||||
|
||||
Use `Link` for normal navigation. Use router APIs for mutation success, guarded redirects, command flows, or form submission side effects.
|
||||
@ -0,0 +1,74 @@
|
||||
# Data, Query, And Contract Rules
|
||||
|
||||
Use these rules for generated contracts, TanStack Query, mutations, auth/SSR boundaries, URL state, and client persistence.
|
||||
|
||||
## Generated Contracts
|
||||
|
||||
Flag:
|
||||
|
||||
- New legacy service/helper wrappers around generated `queryOptions()` or `mutationOptions()`.
|
||||
- Continuing to use deprecated contract operations when a ready generated contract exists.
|
||||
- Assuming a generated file means an operation is ready without checking deprecated markers, schema shape, and the actual UI consumer.
|
||||
- Re-declaring API DTOs in components.
|
||||
- Adding compatibility layers instead of migrating the pointed line and deleting the old layer.
|
||||
|
||||
Use `web/contract/*` as the API shape source of truth. Follow existing `{ params, query?, body? }` input shape.
|
||||
|
||||
## Queries
|
||||
|
||||
Flag:
|
||||
|
||||
- `enabled` used to hide missing required input instead of `input: skipToken`.
|
||||
- Fake fallback IDs or placeholder inputs used to force a query to run.
|
||||
- Query results copied into local state for rendering.
|
||||
- Shared query behavior such as invalidation, stale defaults, or retry rules reimplemented at call sites.
|
||||
- `prefetchQuery` treated as a hard gate or as returning data/errors to the caller.
|
||||
|
||||
Use `useQuery(consoleQuery.xxx.queryOptions(...))` or `useQuery(marketplaceQuery.xxx.queryOptions(...))` directly unless a feature hook performs real orchestration.
|
||||
|
||||
## Mutations
|
||||
|
||||
Flag:
|
||||
|
||||
- Deprecated `useInvalid` or `useReset`.
|
||||
- `mutateAsync` used without a need for Promise semantics.
|
||||
- Awaited mutations without `try/catch`.
|
||||
- Components owning shared cache invalidation that belongs in query defaults.
|
||||
- Optimistic updates that do not match current list/detail ownership.
|
||||
|
||||
Use generated `mutationOptions()` directly when possible. Put shared cache behavior in `createTanstackQueryUtils(...experimental_defaults...)`.
|
||||
|
||||
## SSR, Auth, And Route Boundaries
|
||||
|
||||
Flag:
|
||||
|
||||
- Request-time auth, setup, workspace role, or tenant decisions moved into static `next.config redirects()`.
|
||||
- Dynamic role gates depending on `workspaces.current` implemented as static path redirects.
|
||||
- Authorization logic depending on soft `prefetchQuery`.
|
||||
- Removing a client fallback before server API unavailable behavior is defined.
|
||||
- Global placeholder query contracts introduced to solve a route-local Suspense issue.
|
||||
- Branding-sensitive UI reading placeholder defaults without checking pending/placeholder state.
|
||||
|
||||
Separate hard gates from soft prefetches. `fetchQuery` can be a server decision boundary; `prefetchQuery` is cache warmup.
|
||||
|
||||
## Workspace And Tenant
|
||||
|
||||
Flag:
|
||||
|
||||
- Treating workspace switch as ordinary CRUD invalidation when the current app flow performs server switch plus full reload.
|
||||
- Query keys that omit workspace/tenant identity when the query truly varies by workspace and no full reload boundary applies.
|
||||
- Mixing `workspace_id` and `tenant_id` without tracing the current backend/API contract.
|
||||
|
||||
Current Dify workspace switch should be reviewed as a tenant cache boundary first.
|
||||
|
||||
## URL State And Local Storage
|
||||
|
||||
Flag:
|
||||
|
||||
- Shareable filters, tabs, pagination, selected panels, or search state hidden only in component state.
|
||||
- One-shot navigation signals modeled as subscribed persistent state.
|
||||
- Live app state stored in localStorage.
|
||||
- Direct `window.localStorage`, `globalThis.localStorage`, or raw storage calls in app code.
|
||||
- High-frequency interaction state persisted on every change instead of on commit/settle.
|
||||
|
||||
Use URL state for shareable UI state, feature/Jotai/store state for live UI state, and `@/hooks/use-local-storage` only for low-frequency client-only preferences, dismissed notices, and UI defaults.
|
||||
@ -0,0 +1,22 @@
|
||||
# Dify Invariants
|
||||
|
||||
Use these stable Dify-specific runtime rules in addition to the generic review packs.
|
||||
|
||||
This file is not a place for active feature notes. Do not add rules for one branch, one PR, or a short-lived product decision such as a specific agent-v2, plugin, model-provider, or onboarding task. Keep a rule here only when all of these are true:
|
||||
|
||||
- It is a stable Dify runtime invariant.
|
||||
- Generic React, TypeScript, accessibility, dify-ui, query, or performance rules would not catch it.
|
||||
- The failure mode is concrete enough to produce a file-line review finding.
|
||||
- The rule is likely to remain valid across normal feature work.
|
||||
|
||||
## Workflow Nodes And RAG Pipe
|
||||
|
||||
Flag:
|
||||
|
||||
- Node components under `web/app/components/workflow/nodes/[nodeName]/node.tsx` importing workflow store hooks that are unavailable in RAG Pipe template rendering.
|
||||
- Node UI relying on provider context that is not mounted in every rendering surface.
|
||||
- Store reads in render where React Flow `useNodes` / `useEdges` provide the actual node/edge source.
|
||||
|
||||
Known failure mode: workflow node components can also render while creating a RAG Pipe from a template. In that context there may be no workflowStore provider, causing a blank screen.
|
||||
|
||||
Prefer React Flow hooks for node/edge UI consumption. Use store APIs only where the provider is guaranteed and the code path is workflow-only.
|
||||
134
.agents/skills/frontend-code-review/references/dify-ui.md
Normal file
134
.agents/skills/frontend-code-review/references/dify-ui.md
Normal file
@ -0,0 +1,134 @@
|
||||
# Dify UI Rules
|
||||
|
||||
Use these rules whenever a review touches `packages/dify-ui/` or code consuming `@langgenius/dify-ui/*`.
|
||||
|
||||
Before finalizing findings for those files, read the current local docs that apply:
|
||||
|
||||
- `packages/dify-ui/README.md`
|
||||
- `packages/dify-ui/AGENTS.md`
|
||||
- `web/docs/overlay.md` for floating UI
|
||||
- `packages/dify-ui/src/<primitive>/index.tsx` for the primitive being changed or consumed
|
||||
|
||||
## Package Boundary
|
||||
|
||||
Flag in `packages/dify-ui`:
|
||||
|
||||
- Imports from `web/`.
|
||||
- Dependencies on Next.js, i18n, ky, Jotai, Zustand, TanStack Query, oRPC, or business APIs.
|
||||
- Business-specific component behavior that belongs in `web/`.
|
||||
- Multiple unrelated primitives in one component folder.
|
||||
|
||||
`packages/dify-ui` is a primitive layer: Base UI headless components + `cva` + `cn` + Dify design tokens.
|
||||
|
||||
## Imports And Exports
|
||||
|
||||
Flag:
|
||||
|
||||
- Consumer imports from `@langgenius/dify-ui` without a subpath.
|
||||
- Missing `package.json#exports` entry for a new primitive.
|
||||
- Internal package imports using workspace subpaths instead of relative paths.
|
||||
- Exported props using internal-only types that consumers cannot import from the component subpath.
|
||||
|
||||
Consumers use subpath exports such as `@langgenius/dify-ui/button`.
|
||||
|
||||
## Props And State
|
||||
|
||||
Flag:
|
||||
|
||||
- Flattened props where related values need a discriminated union, such as `value` / `defaultValue`, `multiple` / `value`, or `clearable` / `onChange`.
|
||||
- React state used only to mirror Base UI state for class names.
|
||||
- JavaScript conditional class logic for visual states that the Dify UI/Base UI primitive already exposes through `data-*` attributes or CSS variables.
|
||||
- Controlled props added when uncontrolled DOM state or CSS variables would be enough.
|
||||
- Thin wrappers that rename Base UI parts without adding semantics.
|
||||
|
||||
Prefer Base UI/Dify UI data attributes and CSS variables for visual state: `data-open`, `data-checked`, `data-disabled`, `data-highlighted`, `data-popup-open`, `group-data-*`, `peer-data-*`, `has-[:focus-visible]`, and primitive CSS variables such as anchor width or transform origin. Use JS conditional classes for product/business state that the primitive does not expose.
|
||||
|
||||
## Forms
|
||||
|
||||
Flag:
|
||||
|
||||
- Form-like UI using unrelated `Input` and `Button` pieces without a submit boundary.
|
||||
- Text-like fields not composed through `FieldRoot`, `FieldLabel`, and `FieldControl` when using Dify UI form semantics.
|
||||
- Select fields using `FieldLabel` instead of `SelectLabel`.
|
||||
- Slider fields using a generic label instead of `SliderLabel`.
|
||||
- Checkbox/radio groups missing `FieldsetRoot` and `FieldsetLegend`.
|
||||
- Field errors or descriptions rendered without `FieldDescription` / `FieldError` relationships.
|
||||
|
||||
`Form` is the submit boundary. Dify UI form primitives are not a form state-management framework; business validation and schema-driven behavior belong in `web/`.
|
||||
|
||||
## Overlay Contract
|
||||
|
||||
Flag:
|
||||
|
||||
- Legacy web overlay imports in new or modified code.
|
||||
- Manual portals around Dify UI overlay primitives.
|
||||
- Call-site `z-*` overrides on overlays.
|
||||
- Missing root `isolation: isolate` assumptions when debugging overlay stacking.
|
||||
- Repeated backdrop, z-index, or portal chrome at call sites.
|
||||
- Tooltip used for infotips, long text, or interactive content.
|
||||
|
||||
All Dify UI body-portalled overlays use `z-50`. Toast uses `z-60`. DOM order handles stacking between overlays.
|
||||
|
||||
## Primitive Selection
|
||||
|
||||
Flag:
|
||||
|
||||
- `Tabs` used for simple mode/filter/view selection where `SegmentedControl` is the semantic primitive.
|
||||
- `SegmentedControl` used where `tablist` / `tabpanel` semantics are required.
|
||||
- `Select` used for searchable or free-form input.
|
||||
- `Combobox` used for unrestricted search text where no selected option is remembered.
|
||||
- `Autocomplete` used for closed-list selection.
|
||||
- Tooltip or PreviewCard used for content that must be reachable on touch or by screen readers.
|
||||
|
||||
Use:
|
||||
|
||||
- `Autocomplete` for free-form text with optional suggestions.
|
||||
- `Combobox` for searchable selected values from a collection.
|
||||
- `Select` for closed, scannable option sets.
|
||||
- `Popover` for infotips, help text, rich content, or interactions.
|
||||
|
||||
## Bad Usage Patterns To Flag
|
||||
|
||||
Flag:
|
||||
|
||||
- Manually recreating UI behavior or chrome already owned by `@langgenius/dify-ui/*` or `web/app/components/base/*`, such as buttons, inputs, toggle groups, popovers, dropdown menus, alert dialogs, switches, avatars, scroll areas, toasts, borders, focus states, disabled states, segmented controls, or existing feature components.
|
||||
- Styling a raw Base UI primitive directly in `web/` when a Dify UI primitive exists.
|
||||
- Wrapping a Dify UI primitive in a feature component that hides its label, error, disabled, or focus contract.
|
||||
- Replacing a semantic primitive with a generic `div` plus classes to match a screenshot.
|
||||
- Using `Tooltip` because it is visually convenient when the content is actually help text or needs touch access.
|
||||
- Adding a `z-*` override to make a child popup appear over a parent dialog.
|
||||
- Adding a new app-level wrapper around Dialog, Drawer, Popover, Select, or Combobox that repeats portal/backdrop/positioner logic.
|
||||
- Using dify-ui `Input` as a drop-in replacement for legacy inputs that include search, clear, copy, unit, localized placeholder, or number normalization behavior.
|
||||
- Building a form row from loose text and controls instead of the matching Field/Form primitives.
|
||||
- Adding component state only to style `data-open`, `data-checked`, `data-disabled`, or highlighted states that Base UI already exposes.
|
||||
- Passing booleans down only so children can toggle classes already expressible with primitive `data-*` selectors.
|
||||
|
||||
## Tokens, Radius, And Styling
|
||||
|
||||
Flag:
|
||||
|
||||
- `radius-*` class names.
|
||||
- Custom Tailwind `borderRadius` extension for Figma radius values.
|
||||
- Generic colors where semantic Dify tokens exist.
|
||||
- Hardcoded design values where Dify tokens, component variants, or documented Figma radius mappings exist.
|
||||
- `!` important modifiers used to fight primitive styles instead of fixing the variant, selector, or component composition.
|
||||
- Manual class strings that duplicate primitive variants.
|
||||
- `min-w-(--anchor-width)` on picker popups when it defeats viewport clamping.
|
||||
|
||||
Use the Figma radius mapping from `packages/dify-ui/AGENTS.md`; for example `--radius/sm` maps to `rounded-md`, and `--radius/md` maps to `rounded-lg`.
|
||||
|
||||
Use `!` only for a tightly scoped compatibility override after confirming the primitive API, data attributes, and selector structure cannot express the state.
|
||||
|
||||
## Focus Details
|
||||
|
||||
Flag focus rings attached to the wrong element. For example, Base UI `Slider.Thumb` focuses an internal `input[type=range]`, so the visible thumb wrapper needs `has-[:focus-visible]` rather than direct wrapper `focus-visible`.
|
||||
|
||||
## Custom SVG Icons
|
||||
|
||||
Flag:
|
||||
|
||||
- New generated React icon components or JSON files under `web/app/components/base/icons/src/...` for custom SVG icons.
|
||||
- Custom SVG icons consumed outside the Tailwind `i-custom-*` icon class pipeline.
|
||||
- Generated `packages/iconify-collections/custom-*/icons.json` diffs where unrelated existing icons lost or changed intrinsic `width` or `height`.
|
||||
|
||||
New custom SVG icons belong in `packages/iconify-collections/assets/...`. Regenerate with `pnpm --filter @dify/iconify-collections generate`, validate with `pnpm --filter @dify/iconify-collections check:dimensions`, and consume the generated icon with Tailwind `i-custom-*` classes.
|
||||
@ -1,45 +1,78 @@
|
||||
# Rule Catalog — Performance
|
||||
# Performance Rules
|
||||
|
||||
## React Flow data usage
|
||||
Review performance only where there is realistic impact. Do not request `memo`, `useMemo`, `useCallback`, virtualization, or caching as style preferences.
|
||||
|
||||
IsUrgent: True
|
||||
Category: Performance
|
||||
## Async Waterfalls
|
||||
|
||||
### Description
|
||||
Flag:
|
||||
|
||||
When rendering React Flow, prefer `useNodes`/`useEdges` for UI consumption and rely on `useStoreApi` inside callbacks that mutate or read node/edge state. Avoid manually pulling Flow data outside of these hooks.
|
||||
- Awaiting remote feature flags or fetches before checking cheap synchronous conditions.
|
||||
- Sequential awaits for independent operations.
|
||||
- API routes or server components starting requests late when they could start early.
|
||||
- Nested per-item fetches running serially when each item can fetch in parallel.
|
||||
- Suspense boundaries that force the whole page to wait when a lower boundary could stream or isolate loading.
|
||||
|
||||
## Complex prop stability
|
||||
Prefer `Promise.all` for independent work and branch-local awaits for conditionally needed data.
|
||||
|
||||
IsUrgent: False
|
||||
Category: Performance
|
||||
## Bundle Size
|
||||
|
||||
### Description
|
||||
Flag:
|
||||
|
||||
Only require stable object, array, or map props when there is a clear reason: the child is memoized, the value participates in effect/query dependencies, the value is part of a stable-reference API contract, or profiling/local behavior shows avoidable re-renders. Do not request `useMemo` for every inline object by default; `how-to-write-component` treats memoization as a targeted optimization.
|
||||
- Barrel imports from heavy libraries or `@langgenius/dify-ui`.
|
||||
- Dynamic paths that prevent static trace analysis.
|
||||
- Heavy components loaded eagerly when hidden behind a dialog, tab, command, or feature activation.
|
||||
- Analytics, logging, editor, visualization, or third-party SDK code loaded before it is needed.
|
||||
- Feature-local optional modules imported at top level only for rare flows.
|
||||
|
||||
Update this file when adding, editing, or removing Performance rules so the catalog remains accurate.
|
||||
Use direct imports and `next/dynamic` where the user-visible path benefits.
|
||||
|
||||
Risky:
|
||||
## Server Rendering
|
||||
|
||||
```tsx
|
||||
<HeavyComp
|
||||
config={{
|
||||
provider: ...,
|
||||
detail: ...
|
||||
}}
|
||||
/>
|
||||
```
|
||||
Flag:
|
||||
|
||||
Better when stable identity matters:
|
||||
- Request-specific mutable state stored at module scope in SSR/RSC paths.
|
||||
- Large duplicate data serialized across RSC/client boundaries.
|
||||
- Static I/O repeated per request when it could be hoisted safely.
|
||||
- Cross-request cache without a bounded invalidation strategy.
|
||||
- Server actions lacking API-route-equivalent auth checks.
|
||||
|
||||
```tsx
|
||||
const config = useMemo(() => ({
|
||||
provider: ...,
|
||||
detail: ...
|
||||
}), [provider, detail]);
|
||||
Use request-scoped deduplication such as `React.cache()` when repeated server reads in one request are the problem.
|
||||
|
||||
<HeavyComp
|
||||
config={config}
|
||||
/>
|
||||
```
|
||||
## Re-rendering
|
||||
|
||||
Flag:
|
||||
|
||||
- Effects or subscriptions reading broad state when a derived boolean or narrower selector is enough.
|
||||
- Components defined inside components.
|
||||
- Derived rendering state stored in state/effects.
|
||||
- Non-primitive default props recreated for memoized children.
|
||||
- Expensive work recalculated on every render where it affects real interaction cost.
|
||||
- High-frequency transient values stored in state when refs or CSS variables would avoid render loops.
|
||||
|
||||
Do not flag simple primitive expressions wrapped or not wrapped in `useMemo`; prefer no memo for simple work.
|
||||
|
||||
Require stable object/array/function identity only when:
|
||||
|
||||
- The child is memoized and identity affects renders.
|
||||
- The value is an effect/query dependency.
|
||||
- A library API requires stable references.
|
||||
- Profiling or local behavior shows avoidable re-rendering.
|
||||
|
||||
## DOM, Lists, And Rendering
|
||||
|
||||
Flag:
|
||||
|
||||
- Layout reads in render (`getBoundingClientRect`, `offset*`, `scrollTop`).
|
||||
- Interleaved DOM reads/writes that can cause layout thrashing.
|
||||
- Large lists rendering without virtualization, pagination, or `content-visibility`.
|
||||
- SVG/animation code animating expensive properties when transform/opacity would work.
|
||||
- `transition-all`.
|
||||
- Long-running non-critical browser work performed immediately instead of idle/deferred scheduling.
|
||||
|
||||
## React Flow
|
||||
|
||||
For workflow React Flow components, keep this Dify-specific rule:
|
||||
|
||||
- UI consumption should use React Flow hooks such as `useNodes` / `useEdges`.
|
||||
- Callback-only reads or mutations can use `useStoreApi`.
|
||||
- Node components under `web/app/components/workflow/nodes/[nodeName]/node.tsx` must not depend on workflow stores that are absent in RAG Pipe template rendering.
|
||||
|
||||
72
.agents/skills/frontend-code-review/references/testing.md
Normal file
72
.agents/skills/frontend-code-review/references/testing.md
Normal file
@ -0,0 +1,72 @@
|
||||
# Testing Review Rules
|
||||
|
||||
Use these rules when reviewing test files, testability of changed code, or risky frontend changes that should have tests.
|
||||
|
||||
## Missing Coverage
|
||||
|
||||
Flag missing tests when the change affects:
|
||||
|
||||
- User-visible behavior, navigation, form submission, validation, permissions, or loading/error/empty states.
|
||||
- Query/mutation cache behavior.
|
||||
- Accessibility-critical behavior such as labels, keyboard flow, focus, disabled state, or popup reachability.
|
||||
- URL state parsing/serialization.
|
||||
- Storage persistence or one-shot signals.
|
||||
- Regression-prone workflow or generated contract migration paths.
|
||||
|
||||
Do not request tests for purely mechanical renames or styling-only changes unless the styling affects layout, focus, or interaction.
|
||||
|
||||
## Selectors
|
||||
|
||||
Flag:
|
||||
|
||||
- `getByTestId` used where role, label, text, placeholder, landmark, or scoped dialog/menu queries are available.
|
||||
- Production `data-testid` added only to satisfy tests.
|
||||
- Assertions against decorative icons rather than the named control.
|
||||
- Tests that cannot find controls semantically but leave broken markup unchanged.
|
||||
|
||||
Prefer `getByRole` with accessible name, then `getByLabelText`, `getByPlaceholderText`, `getByText`, and `within(...)`.
|
||||
|
||||
## Mocking
|
||||
|
||||
Flag:
|
||||
|
||||
- Mocking `@langgenius/dify-ui/*` primitives.
|
||||
- Mocking `@/app/components/base/*` components when the real component is practical.
|
||||
- Mocking sibling or child components in the same directory for integration behavior.
|
||||
- Mocks that do not match the real component's conditional rendering.
|
||||
- Module-level mock state not reset in `beforeEach`.
|
||||
- `vi.clearAllMocks()` in `afterEach` instead of `beforeEach`.
|
||||
|
||||
Use real project components for integration behavior. Mock APIs, `next/navigation`, browser shims, or complex providers only when setup would dominate the test.
|
||||
|
||||
## Behavior
|
||||
|
||||
Flag:
|
||||
|
||||
- Tests inspecting implementation details instead of user-observable behavior.
|
||||
- Assertions that hardcode brittle copy when pattern matching or semantic roles would express behavior better.
|
||||
- Fake timers used without real timing behavior.
|
||||
- Async assertions missing `await`, `findBy*`, or `waitFor`.
|
||||
- Test data missing required fields because inline partial objects bypass real types.
|
||||
|
||||
Use typed factory functions with complete defaults and partial overrides.
|
||||
|
||||
## URL State
|
||||
|
||||
For `nuqs` or query-state hooks, flag tests that:
|
||||
|
||||
- Mock URL state when URL synchronization is the behavior under review.
|
||||
- Do not test parser serialize/parse round trips for custom parsers.
|
||||
- Do not assert default-clearing behavior when defaults should be removed from the URL.
|
||||
|
||||
Prefer shared `NuqsTestingAdapter` helpers when available.
|
||||
|
||||
## Organization
|
||||
|
||||
Flag:
|
||||
|
||||
- Component/hook/util tests outside sibling `__tests__/` directories.
|
||||
- Directory-level reviews that test only `index.tsx` while other files in scope contain behavior.
|
||||
- Large test files with repeated setup that should use local builders.
|
||||
|
||||
When a component is very complex, prefer a refactor finding before asking for exhaustive tests.
|
||||
33
.agents/skills/karpathy-guidelines/SKILL.md
Normal file
33
.agents/skills/karpathy-guidelines/SKILL.md
Normal file
@ -0,0 +1,33 @@
|
||||
---
|
||||
name: karpathy-guidelines
|
||||
description: Lightweight coding guardrails for making focused, simple, and verifiable changes in this repo. Use for all coding work.
|
||||
---
|
||||
|
||||
# Karpathy Guidelines
|
||||
|
||||
Use this skill whenever you touch code in this repository.
|
||||
|
||||
## Principles
|
||||
|
||||
- Keep the change small and directly tied to the user request.
|
||||
- Prefer the simplest implementation that fits the existing codebase.
|
||||
- Read the nearby code first, then match its patterns.
|
||||
- Avoid unrelated refactors, broad rewrites, or style churn.
|
||||
- Preserve existing behavior unless the user explicitly asked to change it.
|
||||
- Treat regressions as a signal to narrow the change, not to add workaround layers.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Inspect the current implementation and tests around the change.
|
||||
2. Make the smallest coherent edit.
|
||||
3. Add or update focused tests when the behavior changes or the risk is non-trivial.
|
||||
4. Run the narrowest relevant verification first.
|
||||
5. Report exactly what was verified and anything left unverified.
|
||||
|
||||
## Review Checklist
|
||||
|
||||
- Does this change solve the stated problem without expanding scope?
|
||||
- Did it preserve existing route/component/data-flow semantics?
|
||||
- Are new abstractions justified by real complexity?
|
||||
- Are tests focused on the behavior that could regress?
|
||||
- Are unrelated files and generated artifacts left alone?
|
||||
1
.claude/skills/karpathy-guidelines
Symbolic link
1
.claude/skills/karpathy-guidelines
Symbolic link
@ -0,0 +1 @@
|
||||
../../.agents/skills/karpathy-guidelines
|
||||
13
.github/CODEOWNERS
vendored
13
.github/CODEOWNERS
vendored
@ -15,6 +15,10 @@
|
||||
# Agents
|
||||
/.agents/skills/ @hyoban
|
||||
|
||||
# Packages
|
||||
/packages/ @lyzno1
|
||||
/packages/contracts/ @crazywoola @laipz8200
|
||||
|
||||
# Docs
|
||||
/docs/ @crazywoola
|
||||
|
||||
@ -143,6 +147,14 @@
|
||||
# Frontend
|
||||
/web/ @iamjoel
|
||||
|
||||
# Frontend - Platform and Features
|
||||
/web/config/ @lyzno1
|
||||
/web/contract/ @lyzno1
|
||||
/web/env.ts @lyzno1
|
||||
/web/features/ @lyzno1
|
||||
/web/hooks/ @lyzno1
|
||||
/web/scripts/gen-icons.mjs @lyzno1
|
||||
|
||||
# Frontend - Web Tests
|
||||
/.github/workflows/web-tests.yml @iamjoel
|
||||
|
||||
@ -253,7 +265,6 @@
|
||||
/web/utils/time.ts @iamjoel @zxhlyh
|
||||
/web/utils/format.ts @iamjoel @zxhlyh
|
||||
/web/utils/clipboard.ts @iamjoel @zxhlyh
|
||||
/web/hooks/use-document-title.ts @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Billing and Education
|
||||
/web/app/components/billing/ @iamjoel @zxhlyh
|
||||
|
||||
415
.github/workflows/cli-e2e.yml
vendored
Normal file
415
.github/workflows/cli-e2e.yml
vendored
Normal file
@ -0,0 +1,415 @@
|
||||
name: CLI E2E Tests
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
cli_ref:
|
||||
description: "Git ref (default: current branch)"
|
||||
type: string
|
||||
required: false
|
||||
|
||||
edition:
|
||||
description: "Dify edition"
|
||||
type: choice
|
||||
required: false
|
||||
default: ee
|
||||
options: [ee, ce]
|
||||
|
||||
test_scope:
|
||||
description: "smoke = [P0] only / full = all cases"
|
||||
type: choice
|
||||
required: false
|
||||
default: full
|
||||
options: [smoke, full]
|
||||
|
||||
# ── Suite on/off ────────────────────────────────────────────────────────
|
||||
suite_framework_output_error:
|
||||
description: "framework + output + error-handling suites"
|
||||
type: boolean
|
||||
default: true
|
||||
suite_discovery:
|
||||
description: "discovery suite (get app / describe app)"
|
||||
type: boolean
|
||||
default: true
|
||||
suite_run:
|
||||
description: "run suite (basic / streaming / conversation / file / hitl)"
|
||||
type: boolean
|
||||
default: true
|
||||
suite_auth:
|
||||
description: "auth suite (login / status / whoami / use / devices / logout)"
|
||||
type: boolean
|
||||
default: true
|
||||
suite_agent:
|
||||
description: "agent suite"
|
||||
type: boolean
|
||||
default: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
# ── Shared env injected into every E2E job ───────────────────────────────────
|
||||
# Each job reads DIFY_E2E_TOKEN + app IDs from the provision job outputs,
|
||||
# so global-setup skips minting and finds existing apps in < 10 s.
|
||||
env:
|
||||
DIFY_E2E_NO_KEYRING: "1" # Linux CI has no keychain; skip probe
|
||||
VITEST_RETRY: "2" # Retry flaky staging responses
|
||||
|
||||
jobs:
|
||||
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
# 0. PROVISION — mint token + import DSL fixtures (runs once, outputs IDs)
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
provision:
|
||||
name: "Provision: mint token + DSL apps"
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
outputs:
|
||||
token: ${{ steps.out.outputs.DIFY_E2E_TOKEN }}
|
||||
workspace_id: ${{ steps.out.outputs.DIFY_E2E_WORKSPACE_ID }}
|
||||
workspace_name: ${{ steps.out.outputs.DIFY_E2E_WORKSPACE_NAME }}
|
||||
ws2_id: ${{ steps.out.outputs.DIFY_E2E_WS2_ID }}
|
||||
chat_app_id: ${{ steps.out.outputs.DIFY_E2E_CHAT_APP_ID }}
|
||||
workflow_app_id: ${{ steps.out.outputs.DIFY_E2E_WORKFLOW_APP_ID }}
|
||||
file_app_id: ${{ steps.out.outputs.DIFY_E2E_FILE_APP_ID }}
|
||||
file_chat_app_id: ${{ steps.out.outputs.DIFY_E2E_FILE_CHAT_APP_ID }}
|
||||
hitl_app_id: ${{ steps.out.outputs.DIFY_E2E_HITL_APP_ID }}
|
||||
hitl_external_app_id: ${{ steps.out.outputs.DIFY_E2E_HITL_EXTERNAL_APP_ID }}
|
||||
hitl_single_action_app_id: ${{ steps.out.outputs.DIFY_E2E_HITL_SINGLE_ACTION_APP_ID }}
|
||||
hitl_multi_node_app_id: ${{ steps.out.outputs.DIFY_E2E_HITL_MULTI_NODE_APP_ID }}
|
||||
ws2_app_id: ${{ steps.out.outputs.DIFY_E2E_WS2_APP_ID }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
|
||||
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
|
||||
with:
|
||||
package_json_field: packageManager
|
||||
run_install: false
|
||||
|
||||
- name: Install CLI dependencies
|
||||
working-directory: cli
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Mint token & provision apps
|
||||
id: out
|
||||
working-directory: cli
|
||||
env:
|
||||
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
|
||||
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
|
||||
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
|
||||
DIFY_E2E_TOKEN: ${{ secrets.DIFY_E2E_TOKEN }}
|
||||
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
|
||||
run: bun scripts/e2e-provision.ts
|
||||
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
# 1-B. framework + output + error-handling (parallel with run/discovery)
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
suite-framework-output-error:
|
||||
name: "Suite: framework + output + error-handling"
|
||||
if: ${{ inputs.suite_framework_output_error != 'false' }}
|
||||
needs: provision
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
defaults:
|
||||
run:
|
||||
working-directory: cli
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
- uses: ./.github/actions/setup-web
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with: { bun-version: latest }
|
||||
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
|
||||
with: { package_json_field: packageManager, run_install: false }
|
||||
- run: pnpm install --frozen-lockfile
|
||||
- run: pnpm tree:gen
|
||||
|
||||
- name: Run framework + output + error-handling
|
||||
env:
|
||||
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
|
||||
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
|
||||
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
|
||||
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
|
||||
DIFY_E2E_TOKEN: ${{ needs.provision.outputs.token }}
|
||||
DIFY_E2E_WORKSPACE_ID: ${{ needs.provision.outputs.workspace_id }}
|
||||
DIFY_E2E_WORKSPACE_NAME: ${{ needs.provision.outputs.workspace_name }}
|
||||
DIFY_E2E_CHAT_APP_ID: ${{ needs.provision.outputs.chat_app_id }}
|
||||
DIFY_E2E_WORKFLOW_APP_ID: ${{ needs.provision.outputs.workflow_app_id }}
|
||||
DIFY_E2E_INCLUDE: "test/e2e/suites/framework/**/*.e2e.ts,test/e2e/suites/output/**/*.e2e.ts,test/e2e/suites/error-handling/**/*.e2e.ts"
|
||||
run: |
|
||||
if [ "${{ inputs.test_scope }}" = "smoke" ]; then
|
||||
pnpm test:e2e -- -t "\[P0\]"
|
||||
else
|
||||
pnpm test:e2e
|
||||
fi
|
||||
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
# 1-C. Discovery (parallel)
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
suite-discovery:
|
||||
name: "Suite: discovery"
|
||||
if: ${{ inputs.suite_discovery != 'false' }}
|
||||
needs: provision
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
defaults:
|
||||
run:
|
||||
working-directory: cli
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
- uses: ./.github/actions/setup-web
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with: { bun-version: latest }
|
||||
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
|
||||
with: { package_json_field: packageManager, run_install: false }
|
||||
- run: pnpm install --frozen-lockfile
|
||||
- run: pnpm tree:gen
|
||||
|
||||
- name: Run discovery suite
|
||||
env:
|
||||
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
|
||||
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
|
||||
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
|
||||
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
|
||||
DIFY_E2E_TOKEN: ${{ needs.provision.outputs.token }}
|
||||
DIFY_E2E_WORKSPACE_ID: ${{ needs.provision.outputs.workspace_id }}
|
||||
DIFY_E2E_WORKSPACE_NAME: ${{ needs.provision.outputs.workspace_name }}
|
||||
DIFY_E2E_WS2_ID: ${{ needs.provision.outputs.ws2_id }}
|
||||
DIFY_E2E_CHAT_APP_ID: ${{ needs.provision.outputs.chat_app_id }}
|
||||
DIFY_E2E_WORKFLOW_APP_ID: ${{ needs.provision.outputs.workflow_app_id }}
|
||||
DIFY_E2E_INCLUDE: "test/e2e/suites/discovery/**/*.e2e.ts"
|
||||
run: |
|
||||
if [ "${{ inputs.test_scope }}" = "smoke" ]; then
|
||||
pnpm test:e2e -- -t "\[P0\]"
|
||||
else
|
||||
pnpm test:e2e
|
||||
fi
|
||||
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
# 1-D. Run suite — 5 files in matrix (parallel)
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
suite-run:
|
||||
name: "Suite: run / ${{ matrix.name }}"
|
||||
if: ${{ inputs.suite_run != 'false' }}
|
||||
needs: provision
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- name: basic
|
||||
file: run-app-basic.e2e.ts
|
||||
- name: streaming
|
||||
file: run-app-streaming.e2e.ts
|
||||
- name: conversation
|
||||
file: run-app-conversation.e2e.ts
|
||||
- name: file
|
||||
file: run-app-file.e2e.ts
|
||||
- name: hitl
|
||||
file: run-app-hitl.e2e.ts
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: cli
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
- uses: ./.github/actions/setup-web
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with: { bun-version: latest }
|
||||
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
|
||||
with: { package_json_field: packageManager, run_install: false }
|
||||
- run: pnpm install --frozen-lockfile
|
||||
- run: pnpm tree:gen
|
||||
|
||||
- name: "Run run/${{ matrix.name }}"
|
||||
env:
|
||||
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
|
||||
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
|
||||
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
|
||||
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
|
||||
DIFY_E2E_SSO_TOKEN: ${{ secrets.DIFY_E2E_SSO_TOKEN }}
|
||||
DIFY_E2E_TOKEN: ${{ needs.provision.outputs.token }}
|
||||
DIFY_E2E_WORKSPACE_ID: ${{ needs.provision.outputs.workspace_id }}
|
||||
DIFY_E2E_WORKSPACE_NAME: ${{ needs.provision.outputs.workspace_name }}
|
||||
DIFY_E2E_CHAT_APP_ID: ${{ needs.provision.outputs.chat_app_id }}
|
||||
DIFY_E2E_WORKFLOW_APP_ID: ${{ needs.provision.outputs.workflow_app_id }}
|
||||
DIFY_E2E_FILE_APP_ID: ${{ needs.provision.outputs.file_app_id }}
|
||||
DIFY_E2E_FILE_CHAT_APP_ID: ${{ needs.provision.outputs.file_chat_app_id }}
|
||||
DIFY_E2E_HITL_APP_ID: ${{ needs.provision.outputs.hitl_app_id }}
|
||||
DIFY_E2E_HITL_EXTERNAL_APP_ID: ${{ needs.provision.outputs.hitl_external_app_id }}
|
||||
DIFY_E2E_HITL_SINGLE_ACTION_APP_ID: ${{ needs.provision.outputs.hitl_single_action_app_id }}
|
||||
DIFY_E2E_HITL_MULTI_NODE_APP_ID: ${{ needs.provision.outputs.hitl_multi_node_app_id }}
|
||||
DIFY_E2E_INCLUDE: "test/e2e/suites/run/${{ matrix.file }}"
|
||||
run: |
|
||||
if [ "${{ inputs.test_scope }}" = "smoke" ]; then
|
||||
pnpm test:e2e -- -t "\[P0\]"
|
||||
else
|
||||
pnpm test:e2e
|
||||
fi
|
||||
|
||||
- name: Upload results on failure
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: e2e-run-${{ matrix.name }}-${{ github.run_id }}
|
||||
path: cli/test-results/
|
||||
retention-days: 3
|
||||
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
# 1-E. auth/login + status + whoami (parallel, read-only, safe)
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
suite-auth-safe:
|
||||
name: "Suite: auth (login / status / whoami)"
|
||||
if: ${{ inputs.suite_auth != 'false' }}
|
||||
needs: provision
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
defaults:
|
||||
run:
|
||||
working-directory: cli
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
- uses: ./.github/actions/setup-web
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with: { bun-version: latest }
|
||||
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
|
||||
with: { package_json_field: packageManager, run_install: false }
|
||||
- run: pnpm install --frozen-lockfile
|
||||
- run: pnpm tree:gen
|
||||
|
||||
- name: Run auth/login + status + whoami
|
||||
env:
|
||||
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
|
||||
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
|
||||
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
|
||||
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
|
||||
DIFY_E2E_TOKEN: ${{ needs.provision.outputs.token }}
|
||||
DIFY_E2E_WORKSPACE_ID: ${{ needs.provision.outputs.workspace_id }}
|
||||
DIFY_E2E_WORKSPACE_NAME: ${{ needs.provision.outputs.workspace_name }}
|
||||
DIFY_E2E_WS2_ID: ${{ needs.provision.outputs.ws2_id }}
|
||||
DIFY_E2E_INCLUDE: "test/e2e/suites/auth/login.e2e.ts,test/e2e/suites/auth/status.e2e.ts,test/e2e/suites/auth/whoami.e2e.ts"
|
||||
run: |
|
||||
if [ "${{ inputs.test_scope }}" = "smoke" ]; then
|
||||
pnpm test:e2e -- -t "\[P0\]"
|
||||
else
|
||||
pnpm test:e2e
|
||||
fi
|
||||
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
# 2. DESTRUCTIVE — auth/use + devices + logout + agent (serial, runs LAST)
|
||||
# Must wait for ALL parallel suites to finish to avoid token revocation
|
||||
# invalidating other in-flight requests.
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
suite-last:
|
||||
name: "Suite: auth-use + devices + logout + agent (last, serial)"
|
||||
# Runs when auth is selected; also runs after all parallel jobs finish
|
||||
if: ${{ inputs.suite_auth != 'false' || inputs.suite_agent != 'false' }}
|
||||
needs:
|
||||
- provision
|
||||
- suite-framework-output-error
|
||||
- suite-discovery
|
||||
- suite-run
|
||||
- suite-auth-safe
|
||||
# `needs` on a skipped job is treated as success — safe to proceed even if
|
||||
# some suites were disabled via toggle.
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 25
|
||||
defaults:
|
||||
run:
|
||||
working-directory: cli
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
- uses: ./.github/actions/setup-web
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with: { bun-version: latest }
|
||||
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
|
||||
with: { package_json_field: packageManager, run_install: false }
|
||||
- run: pnpm install --frozen-lockfile
|
||||
- run: pnpm tree:gen
|
||||
|
||||
- name: Run use / devices / logout / agent (serial)
|
||||
env:
|
||||
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
|
||||
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
|
||||
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
|
||||
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
|
||||
DIFY_E2E_TOKEN: ${{ needs.provision.outputs.token }}
|
||||
DIFY_E2E_WORKSPACE_ID: ${{ needs.provision.outputs.workspace_id }}
|
||||
DIFY_E2E_WORKSPACE_NAME: ${{ needs.provision.outputs.workspace_name }}
|
||||
DIFY_E2E_WS2_ID: ${{ needs.provision.outputs.ws2_id }}
|
||||
DIFY_E2E_CHAT_APP_ID: ${{ needs.provision.outputs.chat_app_id }}
|
||||
DIFY_E2E_WORKFLOW_APP_ID: ${{ needs.provision.outputs.workflow_app_id }}
|
||||
DIFY_E2E_HITL_APP_ID: ${{ needs.provision.outputs.hitl_app_id }}
|
||||
DIFY_E2E_HITL_EXTERNAL_APP_ID: ${{ needs.provision.outputs.hitl_external_app_id }}
|
||||
DIFY_E2E_HITL_SINGLE_ACTION_APP_ID: ${{ needs.provision.outputs.hitl_single_action_app_id }}
|
||||
DIFY_E2E_HITL_MULTI_NODE_APP_ID: ${{ needs.provision.outputs.hitl_multi_node_app_id }}
|
||||
run: |
|
||||
# Collect files in safe order: use → devices → logout (revokes last) → agent
|
||||
FILES=()
|
||||
if [ "${{ inputs.suite_auth }}" = "true" ]; then
|
||||
FILES+=(
|
||||
test/e2e/suites/auth/use.e2e.ts
|
||||
test/e2e/suites/auth/devices.e2e.ts
|
||||
test/e2e/suites/auth/logout.e2e.ts
|
||||
)
|
||||
fi
|
||||
if [ "${{ inputs.suite_agent }}" = "true" ]; then
|
||||
while IFS= read -r f; do FILES+=("$f"); done \
|
||||
< <(find test/e2e/suites/agent -name '*.e2e.ts' | sort)
|
||||
fi
|
||||
|
||||
[ ${#FILES[@]} -eq 0 ] && { echo "Nothing to run."; exit 0; }
|
||||
|
||||
# Pass files via DIFY_E2E_INCLUDE (comma-separated) so vitest
|
||||
# config's include list is overridden instead of ANDed.
|
||||
INCLUDE=$(IFS=,; echo "${FILES[*]}")
|
||||
if [ "${{ inputs.test_scope }}" = "smoke" ]; then
|
||||
DIFY_E2E_INCLUDE="$INCLUDE" pnpm test:e2e -- -t "\[P0\]"
|
||||
else
|
||||
DIFY_E2E_INCLUDE="$INCLUDE" pnpm test:e2e
|
||||
fi
|
||||
|
||||
- name: Upload results on failure
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: e2e-last-${{ github.run_id }}
|
||||
path: cli/test-results/
|
||||
retention-days: 3
|
||||
178
.github/workflows/cli-release.yml
vendored
178
.github/workflows/cli-release.yml
vendored
@ -2,87 +2,165 @@ name: CLI Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
tags:
|
||||
- 'difyctl-v*'
|
||||
inputs:
|
||||
release_tag:
|
||||
description: Dify release tag to attach difyctl assets to (blank = latest stable)
|
||||
required: false
|
||||
type: string
|
||||
workflow_call:
|
||||
inputs:
|
||||
release_tag:
|
||||
description: Dify release tag to attach difyctl assets to (blank = latest stable)
|
||||
required: false
|
||||
type: string
|
||||
release:
|
||||
types: [released]
|
||||
|
||||
concurrency:
|
||||
group: cli-release-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
group: difyctl-release
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
release:
|
||||
name: build standalone binaries (all targets)
|
||||
validate:
|
||||
name: validate manifest + resolve target Dify release
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: github.repository == 'langgenius/dify'
|
||||
permissions:
|
||||
contents: read
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./cli
|
||||
outputs:
|
||||
dify_tag: ${{ steps.resolve.outputs.dify_tag }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Export manifest to env
|
||||
run: node scripts/release-naming.mjs github-env >> "$GITHUB_ENV"
|
||||
|
||||
- name: Validate manifest
|
||||
run: scripts/release-validate-manifest.sh
|
||||
|
||||
- name: Resolve target Dify release
|
||||
id: resolve
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
EVENT_TAG: ${{ github.event.release.tag_name }}
|
||||
INPUT_TAG: ${{ inputs.release_tag }}
|
||||
run: |
|
||||
if [ -n "$EVENT_TAG" ]; then
|
||||
tag="$EVENT_TAG"
|
||||
elif [ -n "$INPUT_TAG" ]; then
|
||||
tag="$INPUT_TAG"
|
||||
else
|
||||
tag="$(gh api "repos/${GITHUB_REPOSITORY}/releases/latest" --jq .tag_name)"
|
||||
fi
|
||||
if [ -z "$tag" ]; then
|
||||
echo "::error::could not resolve a target Dify release tag"
|
||||
exit 1
|
||||
fi
|
||||
if ! gh release view "$tag" --repo "$GITHUB_REPOSITORY" >/dev/null 2>&1; then
|
||||
echo "::error::target Dify release ${tag} not found"
|
||||
exit 1
|
||||
fi
|
||||
echo "dify_tag=${tag}" >> "$GITHUB_OUTPUT"
|
||||
echo "::notice::target Dify release ${tag}"
|
||||
|
||||
- name: Compatibility check
|
||||
env:
|
||||
DIFY_TAG: ${{ steps.resolve.outputs.dify_tag }}
|
||||
run: node scripts/release-naming.mjs compat-check "$DIFY_TAG"
|
||||
|
||||
- name: Reject duplicate difyctl version
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
if gh api "repos/${GITHUB_REPOSITORY}/git/ref/tags/${difyctlTag}" >/dev/null 2>&1; then
|
||||
echo "::error::difyctl ${version} already released (tag ${difyctlTag} exists); bump cli/package.json version"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
release:
|
||||
name: build + attach standalone binaries (all targets)
|
||||
needs: validate
|
||||
runs-on: depot-ubuntu-24.04
|
||||
permissions:
|
||||
contents: write
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./cli
|
||||
|
||||
env:
|
||||
DIFY_TAG: ${{ needs.validate.outputs.dify_tag }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
fetch-depth: 1
|
||||
|
||||
- name: Enable cross-arch native prebuilds
|
||||
working-directory: ./
|
||||
run: cat cli/scripts/cross-arch.pnpm.yaml >> pnpm-workspace.yaml
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Export manifest to env
|
||||
run: node scripts/release-naming.mjs github-env >> "$GITHUB_ENV"
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@4bc047ad259df6fc24a6c9b0f9a0cb08cf17fbe5 # v2.0.2
|
||||
with:
|
||||
bun-version: latest
|
||||
|
||||
- name: Read cli/package.json
|
||||
id: manifest
|
||||
run: |
|
||||
version=$(node -p "require('./package.json').version")
|
||||
channel=$(node -p "require('./package.json').difyctl.channel")
|
||||
minDify=$(node -p "require('./package.json').difyctl.compat.minDify")
|
||||
maxDify=$(node -p "require('./package.json').difyctl.compat.maxDify")
|
||||
{
|
||||
echo "version=$version"
|
||||
echo "channel=$channel"
|
||||
echo "minDify=$minDify"
|
||||
echo "maxDify=$maxDify"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Validate manifest
|
||||
run: scripts/release-validate-manifest.sh
|
||||
|
||||
- name: Install cross-arch native prebuilds
|
||||
# Re-installs node_modules with every @napi-rs/keyring platform variant
|
||||
# so `bun build --compile` can embed the right .node into each target.
|
||||
working-directory: ./
|
||||
run: NPM_CONFIG_USERCONFIG="$PWD/cli/scripts/cross-arch.npmrc" pnpm install --frozen-lockfile
|
||||
bun-version-file: cli/.bun-version
|
||||
|
||||
- name: Compile standalone binaries (all targets)
|
||||
env:
|
||||
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
DIFYCTL_CHANNEL: ${{ steps.manifest.outputs.channel }}
|
||||
DIFYCTL_MIN_DIFY: ${{ steps.manifest.outputs.minDify }}
|
||||
DIFYCTL_MAX_DIFY: ${{ steps.manifest.outputs.maxDify }}
|
||||
run: |
|
||||
DIFYCTL_COMMIT="$(git rev-parse HEAD)" \
|
||||
DIFYCTL_BUILD_DATE="$(git log -1 --format=%cI HEAD)" \
|
||||
pnpm build:bin
|
||||
|
||||
- name: Generate sha256 checksum file
|
||||
env:
|
||||
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
run: scripts/release-write-checksums.sh
|
||||
|
||||
- name: Publish GitHub Release
|
||||
uses: softprops/action-gh-release@72f2c25fcb47643c292f7107632f7a47c1df5cd8 # v2.3.2
|
||||
with:
|
||||
tag_name: difyctl-v${{ steps.manifest.outputs.version }}
|
||||
name: difyctl ${{ steps.manifest.outputs.version }}
|
||||
prerelease: ${{ steps.manifest.outputs.channel != 'stable' }}
|
||||
generate_release_notes: true
|
||||
fail_on_unmatched_files: true
|
||||
files: |
|
||||
cli/dist/bin/difyctl-v*
|
||||
- name: Attach difyctl assets to Dify release
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
gh release upload "$DIFY_TAG" dist/bin/${tagPrefix}* \
|
||||
--repo "$GITHUB_REPOSITORY" --clobber
|
||||
|
||||
- name: Prune stale difyctl assets
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
new_set="$(cd dist/bin && ls ${tagPrefix}*)"
|
||||
gh release view "$DIFY_TAG" --repo "$GITHUB_REPOSITORY" \
|
||||
--json assets --jq '.assets[].name' \
|
||||
| { grep -E "^${tagPrefix}" || true; } \
|
||||
| while IFS= read -r name; do
|
||||
if ! printf '%s\n' "$new_set" | grep -qxF -- "$name"; then
|
||||
echo "::notice::pruning stale asset ${name}"
|
||||
gh release delete-asset "$DIFY_TAG" "$name" \
|
||||
--repo "$GITHUB_REPOSITORY" --yes
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Create provenance tag
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
ref="refs/tags/${difyctlTag}"
|
||||
sha="$(git rev-parse HEAD)"
|
||||
status="$(gh api -X POST "repos/${GITHUB_REPOSITORY}/git/refs" \
|
||||
-f ref="$ref" -f sha="$sha" --silent --include 2>/dev/null \
|
||||
| awk 'NR==1 {print $2; exit}' || true)"
|
||||
case "$status" in
|
||||
201) echo "::notice::created ${ref}" ;;
|
||||
422) echo "::notice::tag ${ref} already exists; skipping (immutable)" ;;
|
||||
*) echo "::error::provenance tag ${ref} not created (HTTP ${status:-unknown})"; exit 1 ;;
|
||||
esac
|
||||
|
||||
11
.github/workflows/cli-tests.yml
vendored
11
.github/workflows/cli-tests.yml
vendored
@ -37,8 +37,17 @@ jobs:
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
|
||||
- name: Validate release manifest
|
||||
if: matrix.os == 'depot-ubuntu-24.04'
|
||||
run: scripts/release-validate-manifest.sh
|
||||
|
||||
- name: CI pipeline (typecheck, lint, coverage, build)
|
||||
run: pnpm ci
|
||||
run: pnpm run ci
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' && matrix.os == 'depot-ubuntu-24.04' }}
|
||||
|
||||
@ -7,7 +7,6 @@ consumes injected context managers when it needs to preserve thread-local state.
|
||||
|
||||
import contextvars
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from typing import Any, Protocol, final, override, runtime_checkable
|
||||
@ -15,28 +14,25 @@ from typing import Any, Protocol, final, override, runtime_checkable
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AppContext(ABC):
|
||||
class AppContext(Protocol):
|
||||
"""
|
||||
Abstract application context interface.
|
||||
Application context interface.
|
||||
|
||||
Application adapters can implement this to restore framework-specific state
|
||||
such as Flask app context around worker execution.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get configuration value by key."""
|
||||
raise NotImplementedError
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_extension(self, name: str) -> Any:
|
||||
"""Get application extension by name."""
|
||||
raise NotImplementedError
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def enter(self) -> AbstractContextManager[None]:
|
||||
"""Enter the application context."""
|
||||
raise NotImplementedError
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
||||
@ -41,3 +41,13 @@ class NoFileUploadedError(BaseHTTPException):
|
||||
error_code = "no_file_uploaded"
|
||||
description = "Please upload your file."
|
||||
code = 400
|
||||
|
||||
|
||||
class NotFoundError(BaseHTTPException):
|
||||
error_code = "not_found"
|
||||
code = 404
|
||||
|
||||
|
||||
class InvalidArgumentError(BaseHTTPException):
|
||||
error_code = "invalid_param"
|
||||
code = 400
|
||||
|
||||
@ -122,6 +122,7 @@ from .explore import (
|
||||
saved_message,
|
||||
trial,
|
||||
)
|
||||
from .snippets import snippet_workflow, snippet_workflow_draft_variable
|
||||
from .socketio import workflow as socketio_workflow
|
||||
|
||||
# Import tag controllers
|
||||
@ -137,6 +138,7 @@ from .workspace import (
|
||||
model_providers,
|
||||
models,
|
||||
plugin,
|
||||
snippets,
|
||||
tool_providers,
|
||||
trigger_providers,
|
||||
workspace,
|
||||
@ -212,6 +214,9 @@ __all__ = [
|
||||
"saved_message",
|
||||
"setup",
|
||||
"site",
|
||||
"snippet_workflow",
|
||||
"snippet_workflow_draft_variable",
|
||||
"snippets",
|
||||
"socketio_workflow",
|
||||
"spec",
|
||||
"statistic",
|
||||
|
||||
@ -5,11 +5,17 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.model import App, AppMode
|
||||
from services.agent.skill_package_service import SkillPackageError, SkillPackageService
|
||||
from services.agent.skill_standardize_service import SkillStandardizeService
|
||||
from services.agent_drive_service import AgentDriveError
|
||||
from services.agent_service import AgentService
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class AgentLogQuery(BaseModel):
|
||||
@ -44,3 +50,80 @@ class AgentLogApi(Resource):
|
||||
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/skills/upload")
|
||||
class AgentSkillUploadApi(Resource):
|
||||
@console_ns.doc("upload_agent_skill")
|
||||
@console_ns.doc(description="Upload + validate a Skill package (.zip/.skill) and extract its manifest")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(201, "Skill validated")
|
||||
@console_ns.response(400, "Invalid skill package")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""Validate an uploaded Skill package and persist the archive.
|
||||
|
||||
Returns a validated skill ref (to bind into the Agent soul config on save)
|
||||
plus its manifest. Standardizing into the agent drive is ENG-594.
|
||||
"""
|
||||
if "file" not in request.files:
|
||||
return {"code": "no_file", "message": "no skill file uploaded"}, 400
|
||||
if len(request.files) > 1:
|
||||
return {"code": "too_many_files", "message": "only one skill file is allowed"}, 400
|
||||
|
||||
upload = request.files["file"]
|
||||
content = upload.stream.read()
|
||||
try:
|
||||
manifest = SkillPackageService().validate_and_extract(content=content, filename=upload.filename or "")
|
||||
except SkillPackageError as exc:
|
||||
return {"code": exc.code, "message": exc.message}, exc.status_code
|
||||
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=upload.filename or "skill.zip",
|
||||
content=content,
|
||||
mimetype=upload.mimetype or "application/zip",
|
||||
user=current_user,
|
||||
)
|
||||
skill_ref = manifest.to_skill_ref(file_id=upload_file.id)
|
||||
return {"skill": skill_ref.model_dump(exclude_none=True), "manifest": manifest.model_dump()}, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/skills/standardize")
|
||||
class AgentSkillStandardizeApi(Resource):
|
||||
@console_ns.doc("standardize_agent_skill")
|
||||
@console_ns.doc(description="Validate + standardize a Skill into the agent drive (ENG-594)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(201, "Skill standardized into drive")
|
||||
@console_ns.response(400, "Invalid skill package or no bound agent")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""Upload a Skill, validate it, and standardize it into the app agent's drive."""
|
||||
agent_id = app_model.bound_agent_id
|
||||
if not agent_id:
|
||||
return {"code": "no_bound_agent", "message": "app has no bound agent"}, 400
|
||||
if "file" not in request.files:
|
||||
return {"code": "no_file", "message": "no skill file uploaded"}, 400
|
||||
if len(request.files) > 1:
|
||||
return {"code": "too_many_files", "message": "only one skill file is allowed"}, 400
|
||||
|
||||
upload = request.files["file"]
|
||||
content = upload.stream.read()
|
||||
try:
|
||||
result = SkillStandardizeService().standardize(
|
||||
content=content,
|
||||
filename=upload.filename or "",
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=current_user.id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
except (SkillPackageError, AgentDriveError) as exc:
|
||||
return {"code": exc.code, "message": exc.message}, exc.status_code
|
||||
return result, 201
|
||||
|
||||
@ -24,9 +24,9 @@ from controllers.common.schema import (
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode
|
||||
from services.agent_app_workspace_service import (
|
||||
AgentAppWorkspaceService,
|
||||
@ -142,8 +142,8 @@ class AgentAppWorkspaceListResource(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App):
|
||||
query = query_params_from_request(AgentWorkspaceListQuery)
|
||||
try:
|
||||
result = AgentAppWorkspaceService().list_files(
|
||||
@ -167,8 +167,8 @@ class AgentAppWorkspacePreviewResource(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App):
|
||||
query = query_params_from_request(AgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = AgentAppWorkspaceService().preview(
|
||||
@ -194,8 +194,8 @@ class AgentAppWorkspaceDownloadResource(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App):
|
||||
query = query_params_from_request(AgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = AgentAppWorkspaceService().download(
|
||||
@ -228,8 +228,8 @@ class WorkflowAgentWorkspaceListResource(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
query = query_params_from_request(WorkflowAgentWorkspaceListQuery)
|
||||
try:
|
||||
result = WorkflowAgentWorkspaceService().list_files(
|
||||
@ -264,8 +264,8 @@ class WorkflowAgentWorkspacePreviewResource(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
query = query_params_from_request(WorkflowAgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = WorkflowAgentWorkspaceService().preview(
|
||||
@ -302,8 +302,8 @@ class WorkflowAgentWorkspaceDownloadResource(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
query = query_params_from_request(WorkflowAgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = WorkflowAgentWorkspaceService().download(
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -64,16 +64,17 @@ register_enum_models(console_ns, IconType)
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
|
||||
_CREATOR_IDS_BRACKET_PATTERN = re.compile(r"^creator_ids\[(\d+)\]$")
|
||||
AppListMode = Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"]
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field(
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
mode: AppListMode = Field(default=cast(AppListMode, "all"), description="App mode filter")
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
tag_ids: list[str] | None = Field(default=None, description="Filter by tag IDs")
|
||||
creator_ids: list[str] | None = Field(default=None, description="Filter by creator account IDs")
|
||||
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
|
||||
|
||||
@field_validator("tag_ids", mode="before")
|
||||
@ -94,10 +95,29 @@ class AppListQuery(BaseModel):
|
||||
except ValueError as exc:
|
||||
raise ValueError("Invalid UUID format in tag_ids.") from exc
|
||||
|
||||
@field_validator("creator_ids", mode="before")
|
||||
@classmethod
|
||||
def validate_creator_ids(cls, value: list[str] | None) -> list[str] | None:
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("Unsupported creator_ids type.")
|
||||
|
||||
items = [str(item).strip() for item in value if item and str(item).strip()]
|
||||
if not items:
|
||||
return None
|
||||
|
||||
try:
|
||||
return [str(uuid.UUID(item)) for item in items]
|
||||
except ValueError as exc:
|
||||
raise ValueError("Invalid UUID format in creator_ids.") from exc
|
||||
|
||||
|
||||
def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str, str | list[str]]:
|
||||
normalized: dict[str, str | list[str]] = {}
|
||||
indexed_tag_ids: list[tuple[int, str]] = []
|
||||
indexed_creator_ids: list[tuple[int, str]] = []
|
||||
|
||||
for key in query_args:
|
||||
match = _TAG_IDS_BRACKET_PATTERN.fullmatch(key)
|
||||
@ -105,12 +125,19 @@ def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str,
|
||||
indexed_tag_ids.extend((int(match.group(1)), value) for value in query_args.getlist(key))
|
||||
continue
|
||||
|
||||
match = _CREATOR_IDS_BRACKET_PATTERN.fullmatch(key)
|
||||
if match:
|
||||
indexed_creator_ids.extend((int(match.group(1)), value) for value in query_args.getlist(key))
|
||||
continue
|
||||
|
||||
value = query_args.get(key)
|
||||
if value is not None:
|
||||
normalized[key] = value
|
||||
|
||||
if indexed_tag_ids:
|
||||
normalized["tag_ids"] = [value for _, value in sorted(indexed_tag_ids)]
|
||||
if indexed_creator_ids:
|
||||
normalized["creator_ids"] = [value for _, value in sorted(indexed_creator_ids)]
|
||||
|
||||
return normalized
|
||||
|
||||
@ -486,6 +513,7 @@ class AppListApi(Resource):
|
||||
mode=args.mode,
|
||||
name=args.name,
|
||||
tag_ids=args.tag_ids,
|
||||
creator_ids=args.creator_ids,
|
||||
is_created_by_me=args.is_created_by_me,
|
||||
)
|
||||
|
||||
|
||||
@ -23,6 +23,7 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
with_current_user_id,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
@ -36,7 +37,7 @@ from core.helper.trace_id_helper import get_external_trace_id
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
@ -104,7 +105,8 @@ class CompletionMessageApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
args_model = CompletionMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
@ -112,8 +114,6 @@ class CompletionMessageApi(Resource):
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account or EndUser instance")
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||
)
|
||||
@ -178,7 +178,8 @@ class ChatMessageApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
raw_payload = console_ns.payload or {}
|
||||
args_model = ChatMessagePayload.model_validate(raw_payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
@ -197,8 +198,6 @@ class ChatMessageApi(Resource):
|
||||
args["external_trace_id"] = external_trace_id
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account or EndUser instance")
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||
)
|
||||
|
||||
@ -12,6 +12,7 @@ from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotF
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
|
||||
from controllers.common.errors import InvalidArgumentError
|
||||
from controllers.common.fields import NewAppResponse, SimpleResultResponse
|
||||
from controllers.common.schema import (
|
||||
register_response_schema_model,
|
||||
@ -19,9 +20,19 @@ from controllers.common.schema import (
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.error import (
|
||||
ConversationCompletedError,
|
||||
DraftWorkflowNotExist,
|
||||
DraftWorkflowNotSync,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -50,11 +61,12 @@ from graphon.file import helpers as file_helpers
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.variables import SecretVariable, SegmentType, VariableBase
|
||||
from graphon.variables.exc import VariableError
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import TimestampField, dump_response, to_timestamp, uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from libs.login import login_required
|
||||
from models import Account, App
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow
|
||||
from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX
|
||||
@ -401,13 +413,12 @@ class DraftWorkflowApi(Resource):
|
||||
)
|
||||
@console_ns.response(400, "Invalid workflow configuration")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Sync draft workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
if "application/json" in content_type:
|
||||
@ -447,6 +458,8 @@ class DraftWorkflowApi(Resource):
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
except VariableError as e:
|
||||
raise InvalidArgumentError(description=str(e))
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
@ -468,13 +481,12 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Run draft workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args_model = AdvancedChatWorkflowRunPayload.model_validate(console_ns.payload or {})
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
@ -514,12 +526,12 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
@ -552,12 +564,12 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
@ -590,12 +602,12 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
@ -628,12 +640,12 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
@ -695,12 +707,12 @@ class AdvancedChatDraftHumanInputFormPreviewApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Preview human input form content and placeholders
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
|
||||
inputs = args.inputs
|
||||
|
||||
@ -724,12 +736,12 @@ class AdvancedChatDraftHumanInputFormRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Submit human input form preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
|
||||
workflow_service = WorkflowService()
|
||||
result = workflow_service.submit_human_input_form_preview(
|
||||
@ -753,12 +765,12 @@ class WorkflowDraftHumanInputFormPreviewApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Preview human input form content and placeholders
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
|
||||
inputs = args.inputs
|
||||
|
||||
@ -782,12 +794,12 @@ class WorkflowDraftHumanInputFormRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Submit human input form preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
|
||||
result = workflow_service.submit_human_input_form_preview(
|
||||
@ -811,12 +823,12 @@ class WorkflowDraftHumanInputDeliveryTestApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Test human input delivery
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
args = HumanInputDeliveryTestPayload.model_validate(console_ns.payload or {})
|
||||
workflow_service.test_human_input_delivery(
|
||||
@ -841,12 +853,12 @@ class DraftWorkflowRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Run draft workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
@ -911,12 +923,12 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args_model = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
@ -981,12 +993,12 @@ class PublishedWorkflowApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Publish workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
@ -1083,14 +1095,14 @@ class ConvertToWorkflowApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Convert basic mode of chatbot app to workflow mode
|
||||
Convert expert mode of chatbot app to workflow mode
|
||||
Convert Completion App to Workflow App
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = console_ns.payload or {}
|
||||
args = ConvertToWorkflowPayload.model_validate(payload).model_dump(exclude_none=True)
|
||||
@ -1122,9 +1134,9 @@ class WorkflowFeaturesApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
|
||||
args = WorkflowFeaturesPayload.model_validate(console_ns.payload or {})
|
||||
features = args.features
|
||||
@ -1150,12 +1162,12 @@ class PublishedAllWorkflowApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Get published workflows
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
page = args.page
|
||||
@ -1199,9 +1211,9 @@ class DraftWorkflowRestoreApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, workflow_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
def post(self, current_user: Account, app_model: App, workflow_id: str):
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
try:
|
||||
@ -1237,12 +1249,12 @@ class WorkflowByIdApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def patch(self, app_model: App, workflow_id: str):
|
||||
def patch(self, current_user: Account, app_model: App, workflow_id: str):
|
||||
"""
|
||||
Update workflow attributes
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# Prepare update data
|
||||
@ -1355,12 +1367,12 @@ class DraftWorkflowTriggerRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Poll for trigger events and execute full workflow when event arrives
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = DraftWorkflowTriggerRunPayload.model_validate(console_ns.payload or {})
|
||||
node_id = args.node_id
|
||||
workflow_service = WorkflowService()
|
||||
@ -1419,12 +1431,12 @@ class DraftWorkflowTriggerNodeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Poll for trigger events and execute single node when event arrives
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||
@ -1499,12 +1511,12 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Full workflow debug when the start node is a trigger
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = DraftWorkflowTriggerRunAllPayload.model_validate(console_ns.payload or {})
|
||||
node_ids = args.node_ids
|
||||
@ -1565,7 +1577,8 @@ class WorkflowOnlineUsersApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
args = WorkflowOnlineUsersPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
app_ids = args.app_ids
|
||||
@ -1575,7 +1588,6 @@ class WorkflowOnlineUsersApi(Resource):
|
||||
if not app_ids:
|
||||
return {"data": []}
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
accessible_app_ids = workflow_service.get_accessible_app_ids(app_ids, current_tenant_id)
|
||||
ordered_accessible_app_ids = [app_id for app_id in app_ids if app_id in accessible_app_ids]
|
||||
|
||||
@ -7,12 +7,18 @@ from pydantic import BaseModel, Field, TypeAdapter, computed_field, field_valida
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from fields.base import ResponseModel
|
||||
from fields.member_fields import AccountWithRole
|
||||
from libs.helper import build_avatar_url, dump_response, to_timestamp
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from libs.login import login_required
|
||||
from models import Account, App
|
||||
from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
@ -213,9 +219,10 @@ class WorkflowCommentListApi(Resource):
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return WorkflowCommentBasicList.model_validate({"data": comments}).model_dump(mode="json")
|
||||
|
||||
@ -229,12 +236,14 @@ class WorkflowCommentListApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
content=payload.content,
|
||||
@ -258,10 +267,11 @@ class WorkflowCommentDetailApi(Resource):
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
tenant_id=current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return dump_response(WorkflowCommentDetail, comment)
|
||||
@ -276,12 +286,14 @@ class WorkflowCommentDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.update_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
@ -302,10 +314,12 @@ class WorkflowCommentDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, comment_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str):
|
||||
"""Delete a workflow comment."""
|
||||
WorkflowCommentService.delete_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
@ -327,10 +341,12 @@ class WorkflowCommentResolveApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
comment = WorkflowCommentService.resolve_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
@ -353,11 +369,13 @@ class WorkflowCommentReplyApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
comment_id=comment_id, tenant_id=current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {})
|
||||
@ -386,17 +404,19 @@ class WorkflowCommentReplyDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
comment_id=comment_id, tenant_id=current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
reply = WorkflowCommentService.update_reply(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
reply_id=reply_id,
|
||||
@ -416,15 +436,17 @@ class WorkflowCommentReplyDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, comment_id: str, reply_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Delete a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
comment_id=comment_id, tenant_id=current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
WorkflowCommentService.delete_reply(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
reply_id=reply_id,
|
||||
@ -448,9 +470,13 @@ class WorkflowCommentMentionUsersApi(Resource):
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
"""Get all users in current tenant for mentions."""
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
current_tenant = current_user.current_tenant # need the tenant object here
|
||||
if current_tenant is None:
|
||||
raise ValueError("current tenant is required")
|
||||
members = TenantService.get_tenant_members(current_tenant)
|
||||
users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = WorkflowCommentMentionUsersPayload(users=users)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
@ -9,14 +9,19 @@ from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.errors import InvalidArgumentError, NotFoundError
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
@ -27,8 +32,8 @@ from graphon.file import helpers as file_helpers
|
||||
from graphon.variables.segment_group import SegmentGroup
|
||||
from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from graphon.variables.types import SegmentType
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode
|
||||
from libs.login import login_required
|
||||
from models import Account, App, AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
@ -123,14 +128,15 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict
|
||||
return result
|
||||
|
||||
|
||||
def _ensure_variable_access(
|
||||
def ensure_variable_access(
|
||||
variable: WorkflowDraftVariable | None,
|
||||
app_id: str,
|
||||
variable_id: str,
|
||||
current_user_id: str,
|
||||
) -> WorkflowDraftVariable:
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_id or variable.user_id != current_user.id:
|
||||
if variable.app_id != app_id or variable.user_id != current_user_id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
@ -215,7 +221,7 @@ workflow_draft_variable_list_model = console_ns.model(
|
||||
|
||||
|
||||
def _api_prerequisite[T, **P, R](
|
||||
f: Callable[Concatenate[T, P], R],
|
||||
f: Callable[Concatenate[T, Account, P], R],
|
||||
) -> Callable[Concatenate[T, P], R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
@ -232,9 +238,10 @@ def _api_prerequisite[T, **P, R](
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@wraps(f)
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(self, *args, **kwargs)
|
||||
def wrapper(self: T, current_user: Account, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(self, current_user, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -251,7 +258,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
)
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_without_value_model)
|
||||
def get(self, app_model: App):
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
@ -281,7 +288,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
@console_ns.doc(description="Delete all draft workflow variables")
|
||||
@console_ns.response(204, "Workflow variables deleted successfully")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App):
|
||||
def delete(self, current_user: Account, app_model: App):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -315,7 +322,7 @@ class NodeVariableCollectionApi(Resource):
|
||||
@console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
def get(self, current_user: Account, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
@ -329,7 +336,7 @@ class NodeVariableCollectionApi(Resource):
|
||||
@console_ns.doc(description="Delete all variables for a specific node")
|
||||
@console_ns.response(204, "Node variables deleted successfully")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, node_id: str):
|
||||
def delete(self, current_user: Account, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(app_model.id, node_id, user_id=current_user.id)
|
||||
@ -349,15 +356,16 @@ class VariableApi(Resource):
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, app_model: App, variable_id: UUID):
|
||||
def get(self, current_user: Account, app_model: App, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
return variable
|
||||
|
||||
@ -368,7 +376,7 @@ class VariableApi(Resource):
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def patch(self, app_model: App, variable_id: UUID):
|
||||
def patch(self, current_user: Account, app_model: App, variable_id: UUID):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
@ -396,10 +404,11 @@ class VariableApi(Resource):
|
||||
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
|
||||
new_name = args_model.name
|
||||
@ -440,15 +449,16 @@ class VariableApi(Resource):
|
||||
@console_ns.response(204, "Variable deleted successfully")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, variable_id: UUID):
|
||||
def delete(self, current_user: Account, app_model: App, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
@ -464,7 +474,7 @@ class VariableResetApi(Resource):
|
||||
@console_ns.response(204, "Variable reset (no content)")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def put(self, app_model: App, variable_id: UUID):
|
||||
def put(self, current_user: Account, app_model: App, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -476,10 +486,11 @@ class VariableResetApi(Resource):
|
||||
f"Draft workflow not found, app_id={app_model.id}",
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
@ -490,20 +501,20 @@ class VariableResetApi(Resource):
|
||||
return marshal(resetted, workflow_draft_variable_model)
|
||||
|
||||
|
||||
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||
def _get_variable_list(app_model: App, node_id: str, current_user_id: str) -> WorkflowDraftVariableList:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(app_model.id, user_id=current_user.id)
|
||||
draft_vars = draft_var_srv.list_conversation_variables(app_model.id, user_id=current_user_id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(app_model.id, user_id=current_user.id)
|
||||
draft_vars = draft_var_srv.list_system_variables(app_model.id, user_id=current_user_id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
)
|
||||
return draft_vars
|
||||
|
||||
@ -517,7 +528,7 @@ class ConversationVariableCollectionApi(Resource):
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, app_model: App):
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
|
||||
# so their IDs can be returned to the caller.
|
||||
workflow_srv = WorkflowService()
|
||||
@ -527,7 +538,7 @@ class ConversationVariableCollectionApi(Resource):
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID, current_user.id)
|
||||
|
||||
@console_ns.expect(console_ns.models[ConversationVariableUpdatePayload.__name__])
|
||||
@console_ns.doc("update_conversation_variables")
|
||||
@ -539,7 +550,8 @@ class ConversationVariableCollectionApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
payload = ConversationVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
@ -566,8 +578,8 @@ class SystemVariableCollectionApi(Resource):
|
||||
@console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, app_model: App):
|
||||
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID, current_user.id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/environment-variables")
|
||||
@ -578,7 +590,7 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
@console_ns.response(200, "Environment variables retrieved successfully")
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@_api_prerequisite
|
||||
def get(self, app_model: App):
|
||||
def get(self, _current_user: Account, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
@ -619,7 +631,8 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
payload = EnvironmentVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
@ -30,6 +30,7 @@ from uuid import UUID
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
@ -38,8 +39,13 @@ from libs.login import login_required
|
||||
from models import App, AppMode
|
||||
from services.workflow import inspector_events
|
||||
from services.workflow.node_output_inspector_service import (
|
||||
CheckResultView,
|
||||
NodeOutputInspectorError,
|
||||
NodeOutputInspectorService,
|
||||
NodeOutputsView,
|
||||
NodeOutputView,
|
||||
OutputPreviewView,
|
||||
WorkflowRunSnapshotView,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -54,6 +60,15 @@ _HEARTBEAT_EVERY_TICKS = 15
|
||||
# many ticks (= seconds).
|
||||
_STREAM_HARD_TIMEOUT_TICKS = 1800 # 30 min
|
||||
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
CheckResultView,
|
||||
NodeOutputView,
|
||||
NodeOutputsView,
|
||||
WorkflowRunSnapshotView,
|
||||
OutputPreviewView,
|
||||
)
|
||||
|
||||
|
||||
def _service() -> NodeOutputInspectorService:
|
||||
"""One-line factory so tests can monkeypatch a stub if needed."""
|
||||
@ -124,6 +139,7 @@ class WorkflowDraftRunNodeOutputsApi(Resource):
|
||||
@console_ns.doc("get_workflow_draft_run_node_outputs")
|
||||
@console_ns.doc(description="Snapshot of every node's declared outputs for a draft workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(200, "Workflow run node outputs", console_ns.models[WorkflowRunSnapshotView.__name__])
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -146,6 +162,7 @@ class WorkflowDraftRunNodeOutputDetailApi(Resource):
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Workflow run node output detail", console_ns.models[NodeOutputsView.__name__])
|
||||
@console_ns.response(404, "Workflow run / node not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -171,6 +188,7 @@ class WorkflowDraftRunNodeOutputPreviewApi(Resource):
|
||||
"output_name": "Declared output name as exposed by Composer",
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Workflow run node output preview", console_ns.models[OutputPreviewView.__name__])
|
||||
@console_ns.response(404, "Workflow run / node / output not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -309,6 +327,7 @@ class WorkflowDraftRunNodeOutputEventsApi(Resource):
|
||||
@console_ns.doc("stream_workflow_draft_run_node_output_events")
|
||||
@console_ns.doc(description="Server-Sent Events stream of inspector deltas for a draft workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(200, "Workflow run node output event stream")
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -338,6 +357,7 @@ class WorkflowPublishedRunNodeOutputsApi(Resource):
|
||||
@console_ns.doc("get_workflow_published_run_node_outputs")
|
||||
@console_ns.doc(description="Snapshot of every node's declared outputs for a published workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(200, "Workflow run node outputs", console_ns.models[WorkflowRunSnapshotView.__name__])
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -360,6 +380,7 @@ class WorkflowPublishedRunNodeOutputDetailApi(Resource):
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Workflow run node output detail", console_ns.models[NodeOutputsView.__name__])
|
||||
@console_ns.response(404, "Workflow run / node not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -386,6 +407,7 @@ class WorkflowPublishedRunNodeOutputPreviewApi(Resource):
|
||||
"output_name": "Declared output name as exposed by Composer",
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Workflow run node output preview", console_ns.models[OutputPreviewView.__name__])
|
||||
@console_ns.response(404, "Workflow run / node / output not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -402,6 +424,7 @@ class WorkflowPublishedRunNodeOutputEventsApi(Resource):
|
||||
@console_ns.doc("stream_workflow_published_run_node_output_events")
|
||||
@console_ns.doc(description="Server-Sent Events stream of inspector deltas for a published workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(200, "Workflow run node output event stream")
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal, cast
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
@ -9,11 +9,16 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.errors import NotFoundError
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import NotFoundError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from core.workflow.human_input_forms import load_form_tokens_by_form_id as _load_form_tokens_by_form_id
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
@ -30,8 +35,8 @@ from graphon.enums import WorkflowExecutionStatus
|
||||
from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
|
||||
from libs.custom_inputs import time_duration
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
|
||||
from libs.login import login_required
|
||||
from models import Account, App, AppMode, WorkflowArchiveLog, WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
|
||||
@ -190,8 +195,8 @@ class WorkflowRunExportApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
tenant_id = str(app_model.tenant_id)
|
||||
app_id = str(app_model.id)
|
||||
tenant_id = app_model.tenant_id
|
||||
app_id = app_model.id
|
||||
run_id_str = str(run_id)
|
||||
|
||||
run_created_at = db.session.scalar(
|
||||
@ -397,18 +402,18 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App, run_id: UUID):
|
||||
"""
|
||||
Get workflow run node execution list
|
||||
"""
|
||||
run_id_str = str(run_id)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
user = cast("Account | EndUser", current_user)
|
||||
node_executions = workflow_run_service.get_workflow_run_node_executions(
|
||||
app_model=app_model,
|
||||
run_id=run_id_str,
|
||||
user=user,
|
||||
user=current_user,
|
||||
)
|
||||
|
||||
return WorkflowRunNodeExecutionListResponse.model_validate(
|
||||
@ -432,7 +437,8 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, workflow_run_id: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, workflow_run_id: str):
|
||||
"""
|
||||
Get workflow pause details.
|
||||
|
||||
@ -449,7 +455,7 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
if not workflow_run:
|
||||
raise NotFoundError("Workflow run not found")
|
||||
|
||||
if workflow_run.tenant_id != current_user.current_tenant_id:
|
||||
if workflow_run.tenant_id != current_tenant_id:
|
||||
raise NotFoundError("Workflow run not found")
|
||||
|
||||
# Check if workflow is suspended
|
||||
|
||||
@ -12,14 +12,14 @@ from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import login_required
|
||||
from models.enums import AppTriggerStatus
|
||||
from models.model import Account, App, AppMode
|
||||
from models.model import App, AppMode
|
||||
from models.trigger import AppTrigger, WorkflowWebhookTrigger
|
||||
|
||||
from .. import console_ns
|
||||
from ..app.wraps import get_app_model
|
||||
from ..wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from ..wraps import account_initialization_required, edit_permission_required, setup_required, with_current_tenant_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -124,18 +124,16 @@ class AppTriggersApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerListResponse.__name__])
|
||||
def get(self, app_model: App):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, app_model: App):
|
||||
"""Get app triggers list"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
# Get all triggers for this app using select API
|
||||
triggers = (
|
||||
session.execute(
|
||||
select(AppTrigger)
|
||||
.where(
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.tenant_id == current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
.order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
|
||||
@ -166,19 +164,18 @@ class AppTriggerEnableApi(Resource):
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerResponse.__name__])
|
||||
def post(self, app_model: App):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, app_model: App):
|
||||
"""Update app trigger (enable/disable)"""
|
||||
args = ParserEnable.model_validate(console_ns.payload)
|
||||
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
trigger_id = args.trigger_id
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
# Find the trigger using select
|
||||
trigger = session.execute(
|
||||
select(AppTrigger).where(
|
||||
AppTrigger.id == trigger_id,
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.tenant_id == current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
@ -76,7 +76,7 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
language = "en-US"
|
||||
if args.language in languages:
|
||||
if args.language is not None and args.language in languages:
|
||||
language = args.language
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
|
||||
@ -22,6 +22,8 @@ from controllers.console.wraps import (
|
||||
enterprise_license_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.indexing_runner import IndexingRunner
|
||||
@ -36,9 +38,9 @@ from fields.base import ResponseModel
|
||||
from fields.dataset_fields import DatasetDetailResponse
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import build_icon_url, dump_response, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from libs.url_utils import normalize_api_base_url
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models import Account, ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermission, DatasetPermissionEnum
|
||||
from models.enums import ApiTokenType, SegmentStatus
|
||||
from models.provider_ids import ModelProviderID
|
||||
@ -389,8 +391,9 @@ class DatasetListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
# Convert query parameters to dict, handling list parameters correctly
|
||||
query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
|
||||
# Handle ids and tag_ids as lists (Flask request.args.getlist returns list even for single value)
|
||||
@ -471,9 +474,10 @@ class DatasetListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
payload = DatasetCreatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
@ -512,8 +516,9 @@ class DatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -566,14 +571,15 @@ class DatasetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# check embedding model setting
|
||||
if (
|
||||
payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY
|
||||
@ -614,9 +620,9 @@ class DatasetApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Dataset deleted successfully")
|
||||
def delete(self, dataset_id: UUID):
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||
raise Forbidden()
|
||||
@ -664,8 +670,8 @@ class DatasetQueryApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -704,10 +710,10 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[IndexingEstimatePayload.__name__])
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
payload = IndexingEstimatePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
extract_settings = []
|
||||
@ -804,8 +810,8 @@ class DatasetRelatedAppListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -840,8 +846,8 @@ class DatasetIndexingStatusApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(Document.dataset_id == dataset_id_str, Document.tenant_id == current_tenant_id)
|
||||
@ -898,8 +904,8 @@ class DatasetApiKeyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
keys = db.session.scalars(
|
||||
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
|
||||
).all()
|
||||
@ -911,9 +917,8 @@ class DatasetApiKeyApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
current_key_count = (
|
||||
db.session.scalar(
|
||||
select(func.count(ApiToken.id)).where(
|
||||
@ -952,8 +957,8 @@ class DatasetApiDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, api_key_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, api_key_id: UUID):
|
||||
api_key_id_str = str(api_key_id)
|
||||
key = db.session.scalar(
|
||||
select(ApiToken)
|
||||
@ -1079,8 +1084,8 @@ class DatasetPermissionUserListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
|
||||
@ -2,13 +2,12 @@ from flask_restx import ( # type: ignore
|
||||
Resource, # type: ignore
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
@ -30,13 +29,11 @@ class DataSourceContentPreviewApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run datasource content preview
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
args = Parser.model_validate(console_ns.payload)
|
||||
|
||||
inputs = args.inputs
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, Concatenate, NoReturn
|
||||
from uuid import UUID
|
||||
|
||||
@ -9,6 +10,7 @@ from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.errors import InvalidArgumentError, NotFoundError
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
@ -21,15 +23,14 @@ from controllers.console.app.workflow_draft_variable import (
|
||||
workflow_draft_variable_model,
|
||||
)
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from graphon.variables.types import SegmentType
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
@ -58,7 +59,7 @@ register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
||||
|
||||
|
||||
def _api_prerequisite[T, **P, R](
|
||||
f: Callable[Concatenate[T, P], R],
|
||||
f: Callable[Concatenate[T, Account, P], R],
|
||||
) -> Callable[Concatenate[T, P], R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
@ -74,10 +75,12 @@ def _api_prerequisite[T, **P, R](
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
@with_current_user
|
||||
@wraps(f)
|
||||
def wrapper(self: T, current_user: Account, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
return f(self, *args, **kwargs)
|
||||
return f(self, current_user, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -86,7 +89,7 @@ def _api_prerequisite[T, **P, R](
|
||||
class RagPipelineVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_without_value_model)
|
||||
def get(self, pipeline: Pipeline):
|
||||
def get(self, current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
@ -114,7 +117,7 @@ class RagPipelineVariableCollectionApi(Resource):
|
||||
return workflow_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline):
|
||||
def delete(self, current_user: Account, pipeline: Pipeline):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -145,7 +148,7 @@ def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, pipeline: Pipeline, node_id: str):
|
||||
def get(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
@ -156,7 +159,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
return node_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline, node_id: str):
|
||||
def delete(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(pipeline.id, node_id, user_id=current_user.id)
|
||||
@ -171,7 +174,7 @@ class RagPipelineVariableApi(Resource):
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def get(self, _current_user: Account, pipeline: Pipeline, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -186,7 +189,7 @@ class RagPipelineVariableApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
|
||||
def patch(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def patch(self, _current_user: Account, pipeline: Pipeline, variable_id: UUID):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
@ -255,7 +258,7 @@ class RagPipelineVariableApi(Resource):
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def delete(self, _current_user: Account, pipeline: Pipeline, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -273,7 +276,7 @@ class RagPipelineVariableApi(Resource):
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
class RagPipelineVariableResetApi(Resource):
|
||||
@_api_prerequisite
|
||||
def put(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def put(self, _current_user: Account, pipeline: Pipeline, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -299,17 +302,17 @@ class RagPipelineVariableResetApi(Resource):
|
||||
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
|
||||
|
||||
def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList:
|
||||
def _get_variable_list(pipeline: Pipeline, node_id: str, current_user_id: str) -> WorkflowDraftVariableList:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id, user_id=current_user.id)
|
||||
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id, user_id=current_user_id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(pipeline.id, user_id=current_user.id)
|
||||
draft_vars = draft_var_srv.list_system_variables(pipeline.id, user_id=current_user_id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id, user_id=current_user.id)
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id, user_id=current_user_id)
|
||||
return draft_vars
|
||||
|
||||
|
||||
@ -317,14 +320,14 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
|
||||
class RagPipelineSystemVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, pipeline: Pipeline):
|
||||
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
|
||||
def get(self, current_user: Account, pipeline: Pipeline):
|
||||
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID, current_user.id)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables")
|
||||
class RagPipelineEnvironmentVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
def get(self, pipeline: Pipeline):
|
||||
def get(self, _current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
|
||||
@ -29,6 +29,8 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -46,7 +48,7 @@ from fields.workflow_run_fields import (
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField, UUIDStrOrEmpty, dump_response
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.model import EndUser
|
||||
@ -187,16 +189,14 @@ class DraftRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
@edit_permission_required
|
||||
@console_ns.response(200, "Success", console_ns.models[RagPipelineWorkflowSyncResponse.__name__])
|
||||
def post(self, pipeline: Pipeline):
|
||||
def post(self, current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Sync draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
if "application/json" in content_type:
|
||||
@ -247,15 +247,13 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
@edit_permission_required
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = NodeRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
@ -283,14 +281,12 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = NodeRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
@ -318,14 +314,12 @@ class DraftRagPipelineRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline):
|
||||
def post(self, current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Run draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = DraftWorkflowRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump()
|
||||
|
||||
@ -350,14 +344,12 @@ class PublishedRagPipelineRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline):
|
||||
def post(self, current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Run published workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = PublishedWorkflowRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
streaming = payload.response_mode == "streaming"
|
||||
@ -383,14 +375,12 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run rag pipeline datasource
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
@ -416,14 +406,12 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run rag pipeline datasource
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
@ -454,14 +442,12 @@ class RagPipelineDraftNodeRunApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = NodeRunRequiredPayload.model_validate(console_ns.payload or {})
|
||||
inputs = payload.inputs
|
||||
|
||||
@ -485,14 +471,12 @@ class RagPipelineTaskStopApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, task_id: str):
|
||||
def post(self, current_user: Account, pipeline: Pipeline, task_id: str):
|
||||
"""
|
||||
Stop workflow task
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
return {"result": "success"}
|
||||
@ -532,13 +516,12 @@ class PublishedRagPipelineApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline):
|
||||
def post(self, current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Publish workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.publish_workflow(
|
||||
session=db.session, # type: ignore[reportArgumentType,arg-type]
|
||||
@ -609,13 +592,12 @@ class PublishedAllRagPipelineApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def get(self, pipeline: Pipeline):
|
||||
def get(self, current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Get published workflows
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
query = WorkflowListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
page = query.page
|
||||
@ -655,9 +637,9 @@ class RagPipelineDraftWorkflowRestoreApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, workflow_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
def post(self, current_user: Account, pipeline: Pipeline, workflow_id: str):
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
||||
try:
|
||||
@ -689,14 +671,12 @@ class RagPipelineByIdApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def patch(self, pipeline: Pipeline, workflow_id: str):
|
||||
def patch(self, current_user: Account, pipeline: Pipeline, workflow_id: str):
|
||||
"""
|
||||
Update workflow attributes
|
||||
"""
|
||||
# Check permission
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
|
||||
@ -925,8 +905,8 @@ class DatasourceListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(current_tenant_id))
|
||||
|
||||
|
||||
@ -961,9 +941,8 @@ class RagPipelineTransformApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||
raise Forbidden()
|
||||
|
||||
@ -984,13 +963,13 @@ class RagPipelineDatasourceVariableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
@edit_permission_required
|
||||
def post(self, pipeline: Pipeline):
|
||||
def post(self, current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Set datasource variables
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = DatasourceVariablesPayload.model_validate(console_ns.payload or {}).model_dump()
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import Any
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy import and_, exists, or_, select
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from controllers.common.fields import SimpleMessageResponse, SimpleResultMessageResponse
|
||||
@ -24,8 +24,8 @@ from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import Account, App, InstalledApp, RecommendedApp
|
||||
from models.model import IconType
|
||||
from models import Account, App, AppModelConfig, InstalledApp, RecommendedApp, Workflow
|
||||
from models.model import AppMode, IconType
|
||||
from services.account_service import TenantService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
@ -61,6 +61,24 @@ def _safe_primitive(value: Any) -> Any:
|
||||
return None
|
||||
|
||||
|
||||
def _published_app_filter():
|
||||
"""Return the SQL predicate for installed-app web API availability.
|
||||
|
||||
The installed-app parameters endpoint reads the published workflow for
|
||||
workflow-style apps and the published app model config for easy UI apps.
|
||||
Keep the list endpoint aligned in SQL so it does not return entries that
|
||||
will immediately fail with app_unavailable when opened.
|
||||
"""
|
||||
workflow_app_modes = (AppMode.ADVANCED_CHAT, AppMode.WORKFLOW)
|
||||
has_published_workflow = exists(select(Workflow.id).where(Workflow.id == App.workflow_id))
|
||||
has_published_model_config = exists(select(AppModelConfig.id).where(AppModelConfig.id == App.app_model_config_id))
|
||||
|
||||
return or_(
|
||||
and_(App.mode.in_(workflow_app_modes), App.workflow_id.isnot(None), has_published_workflow),
|
||||
and_(~App.mode.in_(workflow_app_modes), App.app_model_config_id.isnot(None), has_published_model_config),
|
||||
)
|
||||
|
||||
|
||||
class InstalledAppInfoResponse(ResponseModel):
|
||||
id: str
|
||||
name: str | None = None
|
||||
@ -141,33 +159,32 @@ class InstalledAppsListApi(Resource):
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
stmt = (
|
||||
select(InstalledApp, App)
|
||||
.join(App, App.id == InstalledApp.app_id)
|
||||
.where(InstalledApp.tenant_id == current_tenant_id, _published_app_filter())
|
||||
)
|
||||
if query.app_id:
|
||||
installed_apps = db.session.scalars(
|
||||
select(InstalledApp).where(
|
||||
and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == query.app_id)
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
installed_apps = db.session.scalars(
|
||||
select(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id)
|
||||
).all()
|
||||
stmt = stmt.where(InstalledApp.app_id == query.app_id)
|
||||
|
||||
installed_apps = db.session.execute(stmt).all()
|
||||
|
||||
if current_user.current_tenant is None:
|
||||
raise ValueError("current_user.current_tenant must not be None")
|
||||
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
|
||||
installed_app_list: list[dict[str, Any]] = [
|
||||
{
|
||||
"id": installed_app.id,
|
||||
"app": installed_app.app,
|
||||
"app_owner_tenant_id": installed_app.app_owner_tenant_id,
|
||||
"is_pinned": installed_app.is_pinned,
|
||||
"last_used_at": installed_app.last_used_at,
|
||||
"editable": current_user.role in {"owner", "admin"},
|
||||
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
|
||||
}
|
||||
for installed_app in installed_apps
|
||||
if installed_app.app is not None
|
||||
]
|
||||
installed_app_list: list[dict[str, Any]] = []
|
||||
for installed_app, app_model in installed_apps:
|
||||
installed_app_list.append(
|
||||
{
|
||||
"id": installed_app.id,
|
||||
"app": app_model,
|
||||
"app_owner_tenant_id": installed_app.app_owner_tenant_id,
|
||||
"is_pinned": installed_app.is_pinned,
|
||||
"last_used_at": installed_app.last_used_at,
|
||||
"editable": current_user.role in {"owner", "admin"},
|
||||
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
|
||||
}
|
||||
)
|
||||
|
||||
# filter out apps that user doesn't have access to
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
|
||||
@ -8,10 +8,11 @@ from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
from constants.languages import languages
|
||||
from controllers.common.schema import query_params_from_model, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, with_current_user
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import build_icon_url
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
|
||||
@ -79,13 +80,14 @@ class RecommendedAppListApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__])
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
# language args
|
||||
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True))
|
||||
language = args.language
|
||||
if language and language in languages:
|
||||
language_prefix = language
|
||||
elif current_user and current_user.interface_language:
|
||||
elif current_user.interface_language:
|
||||
language_prefix = current_user.interface_language
|
||||
else:
|
||||
language_prefix = languages[0]
|
||||
|
||||
@ -11,6 +11,7 @@ from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.common.errors import InvalidArgumentError, NotFoundError
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
@ -21,7 +22,6 @@ from controllers.console.wraps import (
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
|
||||
164
api/controllers/console/snippets/payloads.py
Normal file
164
api/controllers/console/snippets/payloads.py
Normal file
@ -0,0 +1,164 @@
|
||||
import uuid
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import AliasChoices, BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class SnippetListQuery(BaseModel):
|
||||
"""Query parameters for listing snippets."""
|
||||
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
keyword: str | None = None
|
||||
is_published: bool | None = Field(default=None, description="Filter by published status")
|
||||
creators: list[str] | None = Field(
|
||||
default=None,
|
||||
description="Filter by creator account IDs",
|
||||
validation_alias=AliasChoices("creators", "creator_id"),
|
||||
)
|
||||
tag_ids: list[str] | None = Field(default=None, description="Filter by tag IDs")
|
||||
|
||||
@field_validator("creators", mode="before")
|
||||
@classmethod
|
||||
def parse_creators(cls, value: object) -> list[str] | None:
|
||||
"""Normalize creators filter from query string or list input."""
|
||||
return cls._normalize_string_list(value)
|
||||
|
||||
@field_validator("tag_ids", mode="before")
|
||||
@classmethod
|
||||
def parse_tag_ids(cls, value: object) -> list[str] | None:
|
||||
"""Normalize and validate tag IDs from query string or list input."""
|
||||
items = cls._normalize_string_list(value)
|
||||
if not items:
|
||||
return None
|
||||
try:
|
||||
return [str(uuid.UUID(item)) for item in items]
|
||||
except ValueError as exc:
|
||||
raise ValueError("Invalid UUID format in tag_ids.") from exc
|
||||
|
||||
@staticmethod
|
||||
def _normalize_string_list(value: object) -> list[str] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return [item.strip() for item in value.split(",") if item.strip()] or None
|
||||
if isinstance(value, list):
|
||||
return [str(item).strip() for item in value if str(item).strip()] or None
|
||||
return None
|
||||
|
||||
|
||||
class IconInfo(BaseModel):
|
||||
"""Icon information model."""
|
||||
|
||||
icon: str | None = None
|
||||
icon_type: Literal["emoji", "image"] | None = None
|
||||
icon_background: str | None = None
|
||||
icon_url: str | None = None
|
||||
|
||||
|
||||
class InputFieldDefinition(BaseModel):
|
||||
"""Input field definition for snippet parameters."""
|
||||
|
||||
default: str | None = None
|
||||
hint: bool | None = None
|
||||
label: str | None = None
|
||||
max_length: int | None = None
|
||||
options: list[str] | None = None
|
||||
placeholder: str | None = None
|
||||
required: bool | None = None
|
||||
type: str | None = None # e.g., "text-input"
|
||||
|
||||
|
||||
class CreateSnippetPayload(BaseModel):
|
||||
"""Payload for creating a new snippet."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
description: str | None = Field(default=None, max_length=2000)
|
||||
type: Literal["node", "group"] = "node"
|
||||
icon_info: IconInfo | None = None
|
||||
graph: dict[str, Any] | None = None
|
||||
input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)
|
||||
|
||||
|
||||
class UpdateSnippetPayload(BaseModel):
|
||||
"""Payload for updating a snippet."""
|
||||
|
||||
name: str | None = Field(default=None, min_length=1, max_length=255)
|
||||
description: str | None = Field(default=None, max_length=2000)
|
||||
icon_info: IconInfo | None = None
|
||||
|
||||
|
||||
class SnippetDraftSyncPayload(BaseModel):
|
||||
"""Payload for syncing snippet draft workflow."""
|
||||
|
||||
graph: dict[str, Any]
|
||||
hash: str | None = None
|
||||
conversation_variables: list[dict[str, Any]] | None = Field(
|
||||
default=None,
|
||||
description="Ignored. Snippet workflows do not persist conversation variables.",
|
||||
)
|
||||
input_fields: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetWorkflowListQuery(BaseModel):
|
||||
"""Query parameters for listing snippet published workflows."""
|
||||
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
|
||||
|
||||
class WorkflowRunQuery(BaseModel):
|
||||
"""Query parameters for workflow runs."""
|
||||
|
||||
last_id: str | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SnippetDraftRunPayload(BaseModel):
|
||||
"""Payload for running snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetDraftNodeRunPayload(BaseModel):
|
||||
"""Payload for running a single node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any]
|
||||
query: str = ""
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetIterationNodeRunPayload(BaseModel):
|
||||
"""Payload for running an iteration node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SnippetLoopNodeRunPayload(BaseModel):
|
||||
"""Payload for running a loop node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class PublishWorkflowPayload(BaseModel):
|
||||
"""Payload for publishing snippet workflow."""
|
||||
|
||||
knowledge_base_setting: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SnippetImportPayload(BaseModel):
|
||||
"""Payload for importing snippet from DSL."""
|
||||
|
||||
mode: str = Field(..., description="Import mode: yaml-content or yaml-url")
|
||||
yaml_content: str | None = Field(default=None, description="YAML content (required for yaml-content mode)")
|
||||
yaml_url: str | None = Field(default=None, description="YAML URL (required for yaml-url mode)")
|
||||
name: str | None = Field(default=None, description="Override snippet name")
|
||||
description: str | None = Field(default=None, description="Override snippet description")
|
||||
snippet_id: str | None = Field(default=None, description="Snippet ID to update (optional)")
|
||||
|
||||
|
||||
class IncludeSecretQuery(BaseModel):
|
||||
"""Query parameter for including secret variables in export."""
|
||||
|
||||
include_secret: str = Field(default="false", description="Whether to include secret variables")
|
||||
678
api/controllers/console/snippets/snippet_workflow.py
Normal file
678
api/controllers/console/snippets/snippet_workflow.py
Normal file
@ -0,0 +1,678 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.workflow import (
|
||||
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE,
|
||||
WorkflowPaginationResponse,
|
||||
WorkflowResponse,
|
||||
)
|
||||
from controllers.console.snippets.payloads import (
|
||||
PublishWorkflowPayload,
|
||||
SnippetDraftNodeRunPayload,
|
||||
SnippetDraftRunPayload,
|
||||
SnippetDraftSyncPayload,
|
||||
SnippetIterationNodeRunPayload,
|
||||
SnippetLoopNodeRunPayload,
|
||||
SnippetWorkflowListQuery,
|
||||
WorkflowRunQuery,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.workflow_run_fields import (
|
||||
WorkflowRunDetailResponse,
|
||||
WorkflowRunNodeExecutionListResponse,
|
||||
WorkflowRunNodeExecutionResponse,
|
||||
WorkflowRunPaginationResponse,
|
||||
)
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Account
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
from services.snippet_generate_service import SnippetGenerateService
|
||||
from services.snippet_service import SnippetService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Register Pydantic models with Swagger
|
||||
|
||||
|
||||
def _snippet_session_maker() -> sessionmaker[Session]:
|
||||
return sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
|
||||
def _snippet_service() -> SnippetService:
|
||||
return SnippetService(_snippet_session_maker())
|
||||
|
||||
|
||||
class SnippetWorkflowResponse(WorkflowResponse):
|
||||
input_fields: list[dict] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SnippetDraftSyncPayload,
|
||||
SnippetDraftNodeRunPayload,
|
||||
SnippetDraftRunPayload,
|
||||
SnippetIterationNodeRunPayload,
|
||||
SnippetLoopNodeRunPayload,
|
||||
SnippetWorkflowListQuery,
|
||||
WorkflowRunQuery,
|
||||
PublishWorkflowPayload,
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
SnippetWorkflowResponse,
|
||||
WorkflowPaginationResponse,
|
||||
WorkflowRunPaginationResponse,
|
||||
WorkflowRunDetailResponse,
|
||||
WorkflowRunNodeExecutionListResponse,
|
||||
WorkflowRunNodeExecutionResponse,
|
||||
)
|
||||
|
||||
|
||||
class SnippetNotFoundError(Exception):
|
||||
"""Snippet not found error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def get_snippet[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Decorator to fetch and validate snippet access."""
|
||||
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not kwargs.get("snippet_id"):
|
||||
raise ValueError("missing snippet_id in path parameters")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet_id = str(kwargs.get("snippet_id"))
|
||||
del kwargs["snippet_id"]
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
snippet = snippet_service.get_snippet_by_id(
|
||||
snippet_id=snippet_id,
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
kwargs["snippet"] = snippet
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
|
||||
class SnippetDraftWorkflowApi(Resource):
|
||||
@console_ns.doc("get_snippet_draft_workflow")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Draft workflow retrieved successfully",
|
||||
console_ns.models[SnippetWorkflowResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get draft workflow for snippet."""
|
||||
snippet_service = _snippet_service()
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
|
||||
if not workflow:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
workflow.conversation_variables = []
|
||||
response = SnippetWorkflowResponse.model_validate(workflow, from_attributes=True).model_dump(mode="json")
|
||||
response["input_fields"] = snippet.input_fields_list
|
||||
return response
|
||||
|
||||
@console_ns.doc("sync_snippet_draft_workflow")
|
||||
@console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
|
||||
@console_ns.response(200, "Draft workflow synced successfully")
|
||||
@console_ns.response(400, "Hash mismatch")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, current_user: Account, snippet: CustomizedSnippet):
|
||||
"""Sync draft workflow for snippet."""
|
||||
payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
snippet_service = _snippet_service()
|
||||
workflow = snippet_service.sync_draft_workflow(
|
||||
snippet=snippet,
|
||||
graph=payload.graph,
|
||||
unique_hash=payload.hash,
|
||||
account=current_user,
|
||||
input_fields=payload.input_fields,
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"hash": workflow.unique_hash,
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
|
||||
class SnippetDraftConfigApi(Resource):
|
||||
@console_ns.doc("get_snippet_draft_config")
|
||||
@console_ns.response(200, "Draft config retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get snippet draft workflow configuration limits."""
|
||||
return {
|
||||
"parallel_depth_limit": 3,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
|
||||
class SnippetPublishedWorkflowApi(Resource):
|
||||
@console_ns.doc("get_snippet_published_workflow")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Published workflow retrieved successfully",
|
||||
console_ns.models[SnippetWorkflowResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get published workflow for snippet."""
|
||||
if not snippet.is_published:
|
||||
return None
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
workflow = snippet_service.get_published_workflow(snippet=snippet)
|
||||
|
||||
if not workflow:
|
||||
return None
|
||||
|
||||
response = SnippetWorkflowResponse.model_validate(workflow, from_attributes=True).model_dump(mode="json")
|
||||
response["input_fields"] = snippet.input_fields_list
|
||||
return response
|
||||
|
||||
@console_ns.doc("publish_snippet_workflow")
|
||||
@console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
|
||||
@console_ns.response(200, "Workflow published successfully")
|
||||
@console_ns.response(400, "No draft workflow found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, current_user: Account, snippet: CustomizedSnippet):
|
||||
"""Publish snippet workflow."""
|
||||
snippet_service = _snippet_service()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
snippet = session.merge(snippet)
|
||||
try:
|
||||
workflow = snippet_service.publish_workflow(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
account=current_user,
|
||||
)
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
session.commit()
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"created_at": workflow_created_at,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
|
||||
class SnippetDefaultBlockConfigsApi(Resource):
|
||||
@console_ns.doc("get_snippet_default_block_configs")
|
||||
@console_ns.response(200, "Default block configs retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get default block configurations for snippet workflow."""
|
||||
snippet_service = _snippet_service()
|
||||
return snippet_service.get_default_block_configs()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows")
|
||||
class SnippetPublishedAllWorkflowApi(Resource):
|
||||
@console_ns.expect(console_ns.models[SnippetWorkflowListQuery.__name__])
|
||||
@console_ns.doc("get_all_snippet_published_workflows")
|
||||
@console_ns.doc(description="Get all published workflows for a snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Published workflows retrieved successfully",
|
||||
console_ns.models[WorkflowPaginationResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get all published workflow versions for snippet."""
|
||||
args = SnippetWorkflowListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
with Session(db.engine) as session:
|
||||
workflows, has_more = snippet_service.get_all_published_workflows(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return WorkflowPaginationResponse.model_validate(
|
||||
{
|
||||
"items": workflows,
|
||||
"page": args.page,
|
||||
"limit": args.limit,
|
||||
"has_more": has_more,
|
||||
},
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/<string:workflow_id>/restore")
|
||||
class SnippetDraftWorkflowRestoreApi(Resource):
|
||||
@console_ns.doc("restore_snippet_workflow_to_draft")
|
||||
@console_ns.doc(description="Restore a published snippet workflow version into the draft workflow")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "workflow_id": "Published workflow ID"})
|
||||
@console_ns.response(200, "Workflow restored successfully")
|
||||
@console_ns.response(400, "Source workflow must be published")
|
||||
@console_ns.response(404, "Workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, current_user: Account, snippet: CustomizedSnippet, workflow_id: str):
|
||||
"""Restore a published snippet workflow version into the draft workflow."""
|
||||
snippet_service = _snippet_service()
|
||||
|
||||
try:
|
||||
workflow = snippet_service.restore_published_workflow_to_draft(
|
||||
snippet=snippet,
|
||||
workflow_id=workflow_id,
|
||||
account=current_user,
|
||||
)
|
||||
except IsDraftWorkflowError as exc:
|
||||
raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc
|
||||
except WorkflowNotFoundError as exc:
|
||||
raise NotFound(str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise BadRequest(str(exc)) from exc
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"hash": workflow.unique_hash,
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
|
||||
class SnippetWorkflowRunsApi(Resource):
|
||||
@console_ns.doc("list_snippet_workflow_runs")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow runs retrieved successfully",
|
||||
console_ns.models[WorkflowRunPaginationResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""List workflow runs for snippet."""
|
||||
query = WorkflowRunQuery.model_validate(
|
||||
{
|
||||
"last_id": request.args.get("last_id"),
|
||||
"limit": request.args.get("limit", type=int, default=20),
|
||||
}
|
||||
)
|
||||
args = {
|
||||
"last_id": query.last_id,
|
||||
"limit": query.limit,
|
||||
}
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)
|
||||
|
||||
return WorkflowRunPaginationResponse.model_validate(result, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
|
||||
class SnippetWorkflowRunDetailApi(Resource):
|
||||
@console_ns.doc("get_snippet_workflow_run_detail")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow run detail retrieved successfully",
|
||||
console_ns.models[WorkflowRunDetailResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
def get(self, snippet: CustomizedSnippet, run_id):
|
||||
"""Get workflow run detail for snippet."""
|
||||
run_id = str(run_id)
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
|
||||
|
||||
if not workflow_run:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
return WorkflowRunDetailResponse.model_validate(workflow_run, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||
class SnippetWorkflowRunNodeExecutionsApi(Resource):
|
||||
@console_ns.doc("list_snippet_workflow_run_node_executions")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Node executions retrieved successfully",
|
||||
console_ns.models[WorkflowRunNodeExecutionListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
def get(self, snippet: CustomizedSnippet, run_id):
|
||||
"""List node executions for a workflow run."""
|
||||
run_id = str(run_id)
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
node_executions = snippet_service.get_snippet_workflow_run_node_executions(
|
||||
snippet=snippet,
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
return WorkflowRunNodeExecutionListResponse.model_validate(
|
||||
{"data": node_executions}, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
class SnippetDraftNodeRunApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_node")
|
||||
@console_ns.doc(description="Run a single node in snippet draft workflow (single-step debugging)")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models.get(SnippetDraftNodeRunPayload.__name__))
|
||||
@console_ns.response(
|
||||
200, "Node run completed successfully", console_ns.models[WorkflowRunNodeExecutionResponse.__name__]
|
||||
)
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, current_user: Account, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a single node in snippet draft workflow.
|
||||
|
||||
Executes a specific node with provided inputs for single-step debugging.
|
||||
Returns the node execution result including status, outputs, and timing.
|
||||
"""
|
||||
payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
user_inputs = payload.inputs
|
||||
|
||||
# Get draft workflow for file parsing
|
||||
snippet_service = _snippet_service()
|
||||
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not draft_workflow:
|
||||
raise NotFound("Draft workflow not found")
|
||||
|
||||
files = SnippetGenerateService.parse_files(draft_workflow, payload.files)
|
||||
|
||||
workflow_node_execution = SnippetGenerateService.run_draft_node(
|
||||
snippet=snippet,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
account=current_user,
|
||||
query=payload.query,
|
||||
files=files,
|
||||
session_maker=_snippet_session_maker(),
|
||||
)
|
||||
|
||||
return WorkflowRunNodeExecutionResponse.model_validate(
|
||||
workflow_node_execution, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/last-run")
|
||||
class SnippetDraftNodeLastRunApi(Resource):
|
||||
@console_ns.doc("get_snippet_draft_node_last_run")
|
||||
@console_ns.doc(description="Get last run result for a node in snippet draft workflow")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.response(
|
||||
200, "Node last run retrieved successfully", console_ns.models[WorkflowRunNodeExecutionResponse.__name__]
|
||||
)
|
||||
@console_ns.response(404, "Snippet, draft workflow, or node last run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
def get(self, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Get the last run result for a specific node in snippet draft workflow.
|
||||
|
||||
Returns the most recent execution record for the given node,
|
||||
including status, inputs, outputs, and timing information.
|
||||
"""
|
||||
snippet_service = _snippet_service()
|
||||
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not draft_workflow:
|
||||
raise NotFound("Draft workflow not found")
|
||||
|
||||
node_exec = snippet_service.get_snippet_node_last_run(
|
||||
snippet=snippet,
|
||||
workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
)
|
||||
if node_exec is None:
|
||||
raise NotFound("Node last run not found")
|
||||
|
||||
return WorkflowRunNodeExecutionResponse.model_validate(node_exec, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||
class SnippetDraftRunIterationNodeApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_iteration_node")
|
||||
@console_ns.doc(description="Run draft workflow iteration node for snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models.get(SnippetIterationNodeRunPayload.__name__))
|
||||
@console_ns.response(200, "Iteration node run started successfully (SSE stream)")
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, current_user: Account, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a draft workflow iteration node for snippet.
|
||||
|
||||
Iteration nodes execute their internal sub-graph multiple times over an input list.
|
||||
Returns an SSE event stream with iteration progress and results.
|
||||
"""
|
||||
args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = SnippetGenerateService.generate_single_iteration(
|
||||
snippet=snippet,
|
||||
user=current_user,
|
||||
node_id=node_id,
|
||||
args=args,
|
||||
streaming=True,
|
||||
session_maker=_snippet_session_maker(),
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
class SnippetDraftRunLoopNodeApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_loop_node")
|
||||
@console_ns.doc(description="Run draft workflow loop node for snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models.get(SnippetLoopNodeRunPayload.__name__))
|
||||
@console_ns.response(200, "Loop node run started successfully (SSE stream)")
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, current_user: Account, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a draft workflow loop node for snippet.
|
||||
|
||||
Loop nodes execute their internal sub-graph repeatedly until a condition is met.
|
||||
Returns an SSE event stream with loop progress and results.
|
||||
"""
|
||||
args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
response = SnippetGenerateService.generate_single_loop(
|
||||
snippet=snippet,
|
||||
user=current_user,
|
||||
node_id=node_id,
|
||||
args=args,
|
||||
streaming=True,
|
||||
session_maker=_snippet_session_maker(),
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/run")
|
||||
class SnippetDraftWorkflowRunApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_workflow")
|
||||
@console_ns.expect(console_ns.models.get(SnippetDraftRunPayload.__name__))
|
||||
@console_ns.response(200, "Draft workflow run started successfully (SSE stream)")
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, current_user: Account, snippet: CustomizedSnippet):
|
||||
"""
|
||||
Run draft workflow for snippet.
|
||||
|
||||
Executes the snippet's draft workflow with the provided inputs
|
||||
and returns an SSE event stream with execution progress and results.
|
||||
"""
|
||||
payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = SnippetGenerateService.generate(
|
||||
snippet=snippet,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
session_maker=_snippet_session_maker(),
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||
class SnippetWorkflowTaskStopApi(Resource):
|
||||
@console_ns.doc("stop_snippet_workflow_task")
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, task_id: str):
|
||||
"""
|
||||
Stop a running snippet workflow task.
|
||||
|
||||
Uses both the legacy stop flag mechanism and the graph engine
|
||||
command channel for backward compatibility.
|
||||
"""
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
@ -0,0 +1,334 @@
|
||||
"""
|
||||
Snippet draft workflow variable APIs.
|
||||
|
||||
Mirrors console app routes under /apps/.../workflows/draft/variables for snippet scope,
|
||||
using CustomizedSnippet.id as WorkflowDraftVariable.app_id (same invariant as snippet execution).
|
||||
|
||||
Snippet workflows do not expose system variables (`node_id == sys`) or conversation variables
|
||||
(`node_id == conversation`): paginated list queries exclude those rows; single-variable GET/PATCH/DELETE/reset
|
||||
reject them; `GET .../system-variables` and `GET .../conversation-variables` return empty lists for API parity.
|
||||
Other routes mirror `workflow_draft_variable` app APIs under `/snippets/...`.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, Concatenate
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, marshal, marshal_with
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.common.errors import InvalidArgumentError, NotFoundError
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
WorkflowDraftVariableListQuery,
|
||||
WorkflowDraftVariableUpdatePayload,
|
||||
ensure_variable_access,
|
||||
validate_node_id,
|
||||
workflow_draft_variable_list_model,
|
||||
workflow_draft_variable_list_without_value_model,
|
||||
workflow_draft_variable_model,
|
||||
)
|
||||
from controllers.console.snippets.snippet_workflow import get_snippet
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from graphon.variables.types import SegmentType
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.snippet import CustomizedSnippet
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.snippet_service import SnippetService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
||||
_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS: frozenset[str] = frozenset(
|
||||
{SYSTEM_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID}
|
||||
)
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
def _snippet_service() -> SnippetService:
|
||||
return SnippetService(sessionmaker(bind=db.engine, expire_on_commit=False))
|
||||
|
||||
|
||||
def _ensure_snippet_draft_variable_row_allowed(
|
||||
*,
|
||||
variable: WorkflowDraftVariable,
|
||||
variable_id: str,
|
||||
) -> None:
|
||||
"""Snippet scope only supports canvas-node draft variables; treat sys/conversation rows as not found."""
|
||||
if variable.node_id in _SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
|
||||
def _snippet_draft_var_prerequisite[T, **P, R](
|
||||
f: Callable[Concatenate[T, Account, P], R],
|
||||
) -> Callable[Concatenate[T, P], R | Response]:
|
||||
"""Setup, auth, snippet resolution, and tenant edit permission (same stack as snippet workflow APIs)."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@wraps(f)
|
||||
def wrapper(self: T, current_user: Account, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(self, current_user, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables")
|
||||
class SnippetWorkflowVariableCollectionApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
|
||||
@console_ns.doc("get_snippet_workflow_variables")
|
||||
@console_ns.doc(description="List draft workflow variables without values (paginated, snippet scope)")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow variables retrieved successfully",
|
||||
workflow_draft_variable_list_without_value_model,
|
||||
)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_without_value_model)
|
||||
def get(self, current_user: Account, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
if snippet_service.get_draft_workflow(snippet=snippet) is None:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=session)
|
||||
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||
app_id=snippet.id,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
user_id=current_user.id,
|
||||
exclude_node_ids=_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
|
||||
@console_ns.doc("delete_snippet_workflow_variables")
|
||||
@console_ns.doc(description="Delete all draft workflow variables for the current user (snippet scope)")
|
||||
@console_ns.response(204, "Workflow variables deleted successfully")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def delete(self, current_user: Account, snippet: CustomizedSnippet) -> Response:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
draft_var_srv.delete_user_workflow_variables(snippet.id, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
class SnippetNodeVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_snippet_node_variables")
|
||||
@console_ns.doc(description="Get variables for a specific node (snippet draft workflow)")
|
||||
@console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, current_user: Account, snippet: CustomizedSnippet, node_id: str) -> WorkflowDraftVariableList:
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=session)
|
||||
node_vars = draft_var_srv.list_node_variables(snippet.id, node_id, user_id=current_user.id)
|
||||
|
||||
return node_vars
|
||||
|
||||
@console_ns.doc("delete_snippet_node_variables")
|
||||
@console_ns.doc(description="Delete all variables for a specific node (snippet draft workflow)")
|
||||
@console_ns.response(204, "Node variables deleted successfully")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def delete(self, current_user: Account, snippet: CustomizedSnippet, node_id: str) -> Response:
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(snippet.id, node_id, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||
class SnippetVariableApi(Resource):
|
||||
@console_ns.doc("get_snippet_workflow_variable")
|
||||
@console_ns.doc(description="Get a specific draft workflow variable (snippet scope)")
|
||||
@console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
return variable
|
||||
|
||||
@console_ns.doc("update_snippet_workflow_variable")
|
||||
@console_ns.doc(description="Update a draft workflow variable (snippet scope)")
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def patch(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
|
||||
new_name = args_model.name
|
||||
raw_value = args_model.value
|
||||
if new_name is None and raw_value is None:
|
||||
return variable
|
||||
|
||||
new_value = None
|
||||
if raw_value is not None:
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=snippet.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=snippet.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
|
||||
@console_ns.doc("delete_snippet_workflow_variable")
|
||||
@console_ns.doc(description="Delete a draft workflow variable (snippet scope)")
|
||||
@console_ns.response(204, "Variable deleted successfully")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def delete(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> Response:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
class SnippetVariableResetApi(Resource):
|
||||
@console_ns.doc("reset_snippet_workflow_variable")
|
||||
@console_ns.doc(description="Reset a draft workflow variable to its default value (snippet scope)")
|
||||
@console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(204, "Variable reset (no content)")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def put(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> Response | Any:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
snippet_service = _snippet_service()
|
||||
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if draft_workflow is None:
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, snippet_id={snippet.id}",
|
||||
)
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
db.session.commit()
|
||||
if resetted is None:
|
||||
return Response("", 204)
|
||||
return marshal(resetted, workflow_draft_variable_model)
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/conversation-variables")
|
||||
class SnippetConversationVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_snippet_conversation_variables")
|
||||
@console_ns.doc(
|
||||
description="Conversation variables are not used in snippet workflows; returns an empty list for API parity"
|
||||
)
|
||||
@console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, _current_user: Account, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
return WorkflowDraftVariableList(variables=[])
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/system-variables")
|
||||
class SnippetSystemVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_snippet_system_variables")
|
||||
@console_ns.doc(
|
||||
description="System variables are not used in snippet workflows; returns an empty list for API parity"
|
||||
)
|
||||
@console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, _current_user: Account, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
return WorkflowDraftVariableList(variables=[])
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/environment-variables")
|
||||
class SnippetEnvironmentVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_snippet_environment_variables")
|
||||
@console_ns.doc(description="Get environment variables from snippet draft workflow graph")
|
||||
@console_ns.response(200, "Environment variables retrieved successfully")
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def get(self, _current_user: Account, snippet: CustomizedSnippet) -> dict[str, list[dict[str, Any]]]:
|
||||
snippet_service = _snippet_service()
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if workflow is None:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
env_vars_list: list[dict[str, Any]] = []
|
||||
for v in workflow.environment_variables:
|
||||
env_vars_list.append(
|
||||
{
|
||||
"id": v.id,
|
||||
"type": "env",
|
||||
"name": v.name,
|
||||
"description": v.description,
|
||||
"selector": v.selector,
|
||||
"value_type": v.value_type.exposed_type().value,
|
||||
"value": v.value,
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
"editable": True,
|
||||
}
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
@ -51,7 +51,7 @@ class TagBindingRemovePayload(BaseModel):
|
||||
|
||||
|
||||
class TagListQueryParam(BaseModel):
|
||||
type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
|
||||
type: Literal["knowledge", "app", "snippet", ""] = Field("", description="Tag type filter")
|
||||
keyword: str | None = Field(None, description="Search keyword")
|
||||
|
||||
|
||||
@ -96,7 +96,10 @@ class TagListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.doc(
|
||||
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
|
||||
params={
|
||||
"type": 'Tag type filter. Can be "knowledge", "app", or "snippet".',
|
||||
"keyword": "Search keyword for tag name.",
|
||||
}
|
||||
)
|
||||
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
|
||||
@with_current_tenant_id
|
||||
|
||||
@ -13,13 +13,20 @@ from controllers.common.fields import SuccessResponse
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
with_current_user_id,
|
||||
)
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from fields.base import ResponseModel
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
|
||||
from libs.login import login_required
|
||||
from models.account import Account, TenantPluginAutoUpgradeStrategy, TenantPluginPermission
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
from services.plugin.plugin_parameter_service import PluginParameterService
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
@ -200,9 +207,8 @@ class PluginDebuggingKeyApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
try:
|
||||
return {
|
||||
"key": PluginService.get_debugging_key(tenant_id),
|
||||
@ -219,11 +225,12 @@ class PluginListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user_id: str):
|
||||
args = ParserList.model_validate(request.args.to_dict(flat=True))
|
||||
try:
|
||||
plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
|
||||
plugins_with_total = PluginService.list_with_total(tenant_id, user_id, args.page, args.page_size)
|
||||
except PluginDaemonClientSideError as e:
|
||||
return {"code": "plugin_error", "message": e.description}, 400
|
||||
|
||||
@ -253,9 +260,8 @@ class PluginListInstallationsFromIdsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserLatest.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
@ -288,10 +294,10 @@ class PluginAssetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
args = ParserAsset.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
try:
|
||||
binary = PluginService.extract_asset(tenant_id, args.plugin_unique_identifier, args.file_name)
|
||||
return send_file(io.BytesIO(binary), mimetype="application/octet-stream")
|
||||
@ -305,9 +311,8 @@ class PluginUploadFromPkgApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
file = request.files["pkg"]
|
||||
content = _read_upload_content(file, dify_config.PLUGIN_MAX_PACKAGE_SIZE)
|
||||
try:
|
||||
@ -325,9 +330,8 @@ class PluginUploadFromGithubApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserGithubUpload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
@ -344,9 +348,8 @@ class PluginUploadFromBundleApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
file = request.files["bundle"]
|
||||
content = _read_upload_content(file, dify_config.PLUGIN_MAX_BUNDLE_SIZE)
|
||||
try:
|
||||
@ -364,8 +367,8 @@ class PluginInstallFromPkgApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserPluginIdentifiers.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
@ -383,9 +386,8 @@ class PluginInstallFromGithubApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserGithubInstall.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
@ -409,9 +411,8 @@ class PluginInstallFromMarketplaceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserPluginIdentifiers.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
@ -429,8 +430,8 @@ class PluginFetchMarketplacePkgApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
@ -453,9 +454,8 @@ class PluginFetchManifestApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
@ -473,9 +473,8 @@ class PluginFetchInstallTasksApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
args = ParserTasks.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
@ -490,9 +489,8 @@ class PluginFetchInstallTaskApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self, task_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, task_id: str):
|
||||
try:
|
||||
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
|
||||
except PluginDaemonClientSideError as e:
|
||||
@ -506,9 +504,8 @@ class PluginDeleteInstallTaskApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self, task_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, task_id: str):
|
||||
try:
|
||||
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
|
||||
except PluginDaemonClientSideError as e:
|
||||
@ -522,9 +519,8 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
try:
|
||||
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
|
||||
except PluginDaemonClientSideError as e:
|
||||
@ -538,9 +534,8 @@ class PluginDeleteInstallTaskItemApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self, task_id: str, identifier: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, task_id: str, identifier: str):
|
||||
try:
|
||||
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
|
||||
except PluginDaemonClientSideError as e:
|
||||
@ -554,9 +549,8 @@ class PluginUpgradeFromMarketplaceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserMarketplaceUpgrade.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
@ -576,9 +570,8 @@ class PluginUpgradeFromGithubApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserGithubUpgrade.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
@ -604,11 +597,10 @@ class PluginUninstallApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserUninstall.model_validate(console_ns.payload)
|
||||
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return {"success": PluginService.uninstall(tenant_id, args.plugin_installation_id)}
|
||||
except PluginDaemonClientSideError as e:
|
||||
@ -622,16 +614,14 @@ class PluginChangePermissionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
user = current_user
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
args = ParserPermissionChange.model_validate(console_ns.payload)
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
return {
|
||||
"success": PluginPermissionService.change_permission(
|
||||
tenant_id, args.install_permission, args.debug_permission
|
||||
@ -644,9 +634,8 @@ class PluginFetchPermissionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
permission = PluginPermissionService.get_permission(tenant_id)
|
||||
if not permission:
|
||||
return jsonable_encoder(
|
||||
@ -671,16 +660,15 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
user_id = current_user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, current_user: Account):
|
||||
args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
options = PluginParameterService.get_dynamic_select_options(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
user_id=current_user.id,
|
||||
plugin_id=args.plugin_id,
|
||||
provider=args.provider,
|
||||
action=args.action,
|
||||
@ -701,17 +689,16 @@ class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, current_user: Account):
|
||||
"""Fetch dynamic options using credentials directly (for edit mode)."""
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
user_id = current_user.id
|
||||
|
||||
args = ParserDynamicOptionsWithCredentials.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
options = PluginParameterService.get_dynamic_select_options_with_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
user_id=current_user.id,
|
||||
plugin_id=args.plugin_id,
|
||||
provider=args.provider,
|
||||
action=args.action,
|
||||
@ -731,8 +718,9 @@ class PluginChangePreferencesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
@ -780,9 +768,8 @@ class PluginFetchPreferencesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
permission = PluginPermissionService.get_permission(tenant_id)
|
||||
permission_dict = {
|
||||
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
@ -820,10 +807,9 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
# exclude one single plugin
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = ParserExcludePlugin.model_validate(console_ns.payload)
|
||||
|
||||
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id)})
|
||||
@ -835,8 +821,8 @@ class PluginReadmeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
args = ParserReadme.model_validate(request.args.to_dict(flat=True))
|
||||
return jsonable_encoder(
|
||||
{"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)}
|
||||
|
||||
424
api/controllers/console/workspace/snippets.py
Normal file
424
api/controllers/console/workspace/snippets.py
Normal file
@ -0,0 +1,424 @@
|
||||
import logging
|
||||
import re
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, marshal
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.datastructures import MultiDict
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.snippets.payloads import (
|
||||
CreateSnippetPayload,
|
||||
IncludeSecretQuery,
|
||||
SnippetImportPayload,
|
||||
SnippetListQuery,
|
||||
UpdateSnippetPayload,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.snippet import SnippetType
|
||||
from services.app_dsl_service import ImportStatus
|
||||
from services.snippet_dsl_service import SnippetDslService
|
||||
from services.snippet_service import SnippetService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
|
||||
_CREATOR_IDS_BRACKET_PATTERN = re.compile(r"^creator_ids\[(\d+)\]$")
|
||||
|
||||
|
||||
def _snippet_service() -> SnippetService:
|
||||
return SnippetService(sessionmaker(bind=db.engine, expire_on_commit=False))
|
||||
|
||||
|
||||
def _normalize_snippet_list_query_args(query_args: MultiDict[str, str]) -> dict[str, str | list[str]]:
|
||||
normalized: dict[str, str | list[str]] = {}
|
||||
indexed_tag_ids: list[tuple[int, str]] = []
|
||||
indexed_creator_ids: list[tuple[int, str]] = []
|
||||
|
||||
for key in query_args:
|
||||
match = _TAG_IDS_BRACKET_PATTERN.fullmatch(key)
|
||||
if match:
|
||||
indexed_tag_ids.extend((int(match.group(1)), value) for value in query_args.getlist(key))
|
||||
continue
|
||||
|
||||
match = _CREATOR_IDS_BRACKET_PATTERN.fullmatch(key)
|
||||
if match:
|
||||
indexed_creator_ids.extend((int(match.group(1)), value) for value in query_args.getlist(key))
|
||||
continue
|
||||
|
||||
value = query_args.get(key)
|
||||
if value is not None:
|
||||
normalized[key] = value
|
||||
|
||||
if indexed_tag_ids:
|
||||
normalized["tag_ids"] = [value for _, value in sorted(indexed_tag_ids)]
|
||||
if indexed_creator_ids:
|
||||
normalized["creators"] = [value for _, value in sorted(indexed_creator_ids)]
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
# Register Pydantic models with Swagger
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SnippetListQuery,
|
||||
CreateSnippetPayload,
|
||||
UpdateSnippetPayload,
|
||||
SnippetImportPayload,
|
||||
IncludeSecretQuery,
|
||||
)
|
||||
|
||||
# Create namespace models for marshaling
|
||||
snippet_model = console_ns.model("Snippet", snippet_fields)
|
||||
snippet_list_model = console_ns.model("SnippetList", snippet_list_fields)
|
||||
snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets")
|
||||
class CustomizedSnippetsApi(Resource):
|
||||
@console_ns.doc("list_customized_snippets")
|
||||
@console_ns.expect(console_ns.models.get(SnippetListQuery.__name__))
|
||||
@console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
"""List customized snippets with pagination and search."""
|
||||
query = SnippetListQuery.model_validate(_normalize_snippet_list_query_args(request.args))
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
snippets, total, has_more = snippet_service.get_snippets(
|
||||
tenant_id=current_tenant_id,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
keyword=query.keyword,
|
||||
is_published=query.is_published,
|
||||
creators=query.creators,
|
||||
tag_ids=query.tag_ids,
|
||||
)
|
||||
|
||||
return {
|
||||
"data": marshal(snippets, snippet_list_fields),
|
||||
"page": query.page,
|
||||
"limit": query.limit,
|
||||
"total": total,
|
||||
"has_more": has_more,
|
||||
}, 200
|
||||
|
||||
@console_ns.doc("create_customized_snippet")
|
||||
@console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__))
|
||||
@console_ns.response(201, "Snippet created successfully", snippet_model)
|
||||
@console_ns.response(400, "Invalid request")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
"""Create a new customized snippet."""
|
||||
payload = CreateSnippetPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
snippet_type = SnippetType(payload.type)
|
||||
except ValueError:
|
||||
snippet_type = SnippetType.NODE
|
||||
|
||||
try:
|
||||
if payload.graph is not None:
|
||||
SnippetService.validate_snippet_graph_forbidden_nodes(payload.graph)
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
snippet = snippet_service.create_snippet(
|
||||
tenant_id=current_tenant_id,
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
snippet_type=snippet_type,
|
||||
icon_info=payload.icon_info.model_dump() if payload.icon_info else None,
|
||||
input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None,
|
||||
account=current_user,
|
||||
)
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return marshal(snippet, snippet_fields), 201
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>")
|
||||
class CustomizedSnippetDetailApi(Resource):
|
||||
@console_ns.doc("get_customized_snippet")
|
||||
@console_ns.response(200, "Snippet retrieved successfully", snippet_model)
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, snippet_id: str):
|
||||
"""Get customized snippet details."""
|
||||
snippet_service = _snippet_service()
|
||||
snippet = snippet_service.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
return marshal(snippet, snippet_fields), 200
|
||||
|
||||
@console_ns.doc("update_customized_snippet")
|
||||
@console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__))
|
||||
@console_ns.response(200, "Snippet updated successfully", snippet_model)
|
||||
@console_ns.response(400, "Invalid request")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(self, current_tenant_id: str, current_user: Account, snippet_id: str):
|
||||
"""Update customized snippet."""
|
||||
snippet_service = _snippet_service()
|
||||
snippet = snippet_service.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
payload = UpdateSnippetPayload.model_validate(console_ns.payload or {})
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
|
||||
if "icon_info" in update_data and update_data["icon_info"] is not None:
|
||||
update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None
|
||||
|
||||
if not update_data:
|
||||
return {"message": "No valid fields to update"}, 400
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
snippet = session.merge(snippet)
|
||||
snippet = SnippetService.update_snippet(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
account_id=current_user.id,
|
||||
data=update_data,
|
||||
)
|
||||
session.commit()
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return marshal(snippet, snippet_fields), 200
|
||||
|
||||
@console_ns.doc("delete_customized_snippet")
|
||||
@console_ns.response(204, "Snippet deleted successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, snippet_id: str):
|
||||
"""Delete customized snippet."""
|
||||
snippet_service = _snippet_service()
|
||||
snippet = snippet_service.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
snippet = session.merge(snippet)
|
||||
SnippetService.delete_snippet(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/export")
|
||||
class CustomizedSnippetExportApi(Resource):
|
||||
@console_ns.doc("export_customized_snippet")
|
||||
@console_ns.doc(description="Export snippet configuration as DSL")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID to export"})
|
||||
@console_ns.response(200, "Snippet exported successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, snippet_id: str):
|
||||
"""Export snippet as DSL."""
|
||||
snippet_service = _snippet_service()
|
||||
snippet = snippet_service.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
# Get include_secret parameter
|
||||
query = IncludeSecretQuery.model_validate(request.args.to_dict())
|
||||
|
||||
with Session(db.engine) as session:
|
||||
export_service = SnippetDslService(session)
|
||||
result = export_service.export_snippet_dsl(snippet=snippet, include_secret=query.include_secret == "true")
|
||||
|
||||
# Set filename with .snippet extension
|
||||
filename = f"{snippet.name}.snippet"
|
||||
encoded_filename = quote(filename)
|
||||
|
||||
response = Response(
|
||||
result,
|
||||
mimetype="application/x-yaml",
|
||||
)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Type"] = "application/x-yaml"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/imports")
|
||||
class CustomizedSnippetImportApi(Resource):
|
||||
@console_ns.doc("import_customized_snippet")
|
||||
@console_ns.doc(description="Import snippet from DSL")
|
||||
@console_ns.expect(console_ns.models.get(SnippetImportPayload.__name__))
|
||||
@console_ns.response(200, "Snippet imported successfully")
|
||||
@console_ns.response(202, "Import pending confirmation")
|
||||
@console_ns.response(400, "Import failed")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
"""Import snippet from DSL."""
|
||||
payload = SnippetImportPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = SnippetDslService(session)
|
||||
result = import_service.import_snippet(
|
||||
account=current_user,
|
||||
import_mode=payload.mode,
|
||||
yaml_content=payload.yaml_content,
|
||||
yaml_url=payload.yaml_url,
|
||||
snippet_id=payload.snippet_id,
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/imports/<string:import_id>/confirm")
|
||||
class CustomizedSnippetImportConfirmApi(Resource):
|
||||
@console_ns.doc("confirm_snippet_import")
|
||||
@console_ns.doc(description="Confirm a pending snippet import")
|
||||
@console_ns.doc(params={"import_id": "Import ID to confirm"})
|
||||
@console_ns.response(200, "Import confirmed successfully")
|
||||
@console_ns.response(400, "Import failed")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, import_id: str):
|
||||
"""Confirm a pending snippet import."""
|
||||
with Session(db.engine) as session:
|
||||
import_service = SnippetDslService(session)
|
||||
result = import_service.confirm_import(import_id=import_id, account=current_user)
|
||||
session.commit()
|
||||
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/check-dependencies")
|
||||
class CustomizedSnippetCheckDependenciesApi(Resource):
|
||||
@console_ns.doc("check_snippet_dependencies")
|
||||
@console_ns.doc(description="Check dependencies for a snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID"})
|
||||
@console_ns.response(200, "Dependencies checked successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, snippet_id: str):
|
||||
"""Check dependencies for a snippet."""
|
||||
snippet_service = _snippet_service()
|
||||
snippet = snippet_service.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = SnippetDslService(session)
|
||||
result = import_service.check_dependencies(snippet=snippet)
|
||||
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/use-count/increment")
|
||||
class CustomizedSnippetUseCountIncrementApi(Resource):
|
||||
@console_ns.doc("increment_snippet_use_count")
|
||||
@console_ns.doc(description="Increment snippet use count by 1")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID"})
|
||||
@console_ns.response(200, "Use count incremented successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, snippet_id: str):
|
||||
"""Increment snippet use count when it is inserted into a workflow."""
|
||||
snippet_service = _snippet_service()
|
||||
snippet = snippet_service.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
snippet = session.merge(snippet)
|
||||
SnippetService.increment_use_count(session=session, snippet=snippet)
|
||||
session.commit()
|
||||
session.refresh(snippet)
|
||||
|
||||
return {"result": "success", "use_count": snippet.use_count}, 200
|
||||
@ -18,9 +18,11 @@ from controllers.console.wraps import (
|
||||
enterprise_license_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.entities.mcp_provider import IdentityMode, MCPAuthentication, MCPConfiguration
|
||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
@ -30,7 +32,8 @@ from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToo
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import alphanumeric, uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.provider_ids import ToolProviderID
|
||||
|
||||
# from models.provider_ids import ToolProviderID
|
||||
@ -210,6 +213,30 @@ class MCPProviderBasePayload(BaseModel):
|
||||
configuration: dict[str, Any] | None = Field(default_factory=dict)
|
||||
headers: dict[str, Any] | None = Field(default_factory=dict)
|
||||
authentication: dict[str, Any] | None = Field(default_factory=dict)
|
||||
# None means "leave unchanged" on update; the controller resolves it to a
|
||||
# concrete IdentityMode before calling the service (see _resolve_identity_mode).
|
||||
identity_mode: IdentityMode | None = None
|
||||
|
||||
|
||||
def _resolve_identity_mode(requested: IdentityMode | None, *, current: IdentityMode) -> IdentityMode:
|
||||
"""Resolve the effective MCP identity_mode for a create/update request.
|
||||
|
||||
Keeps two API-layer concerns out of the service so the service always
|
||||
receives a concrete value:
|
||||
|
||||
* ``None`` means "leave unchanged" (update semantics) — fall back to
|
||||
``current`` (``IdentityMode.OFF`` for a brand-new provider).
|
||||
* Identity forwarding is an enterprise-only capability. On non-enterprise
|
||||
deployments any non-OFF value is coerced back to OFF so a persisted row
|
||||
can never imply forwarding that the runtime won't perform. This gates the
|
||||
API surface to match the backend gate in
|
||||
``MCPTool._forwarding_requested`` — both the API and the backend
|
||||
invocation must be gated on ``dify_config.ENTERPRISE_ENABLED``.
|
||||
"""
|
||||
mode = current if requested is None else requested
|
||||
if mode != IdentityMode.OFF and not dify_config.ENTERPRISE_ENABLED:
|
||||
return IdentityMode.OFF
|
||||
return mode
|
||||
|
||||
|
||||
class MCPProviderCreatePayload(MCPProviderBasePayload):
|
||||
@ -262,15 +289,13 @@ class ToolProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account):
|
||||
raw_args = request.args.to_dict()
|
||||
query = ToolProviderListQuery.model_validate(raw_args)
|
||||
|
||||
return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) # type: ignore
|
||||
return ToolCommonService.list_tool_providers(user.id, tenant_id, query.type) # type: ignore
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/tools")
|
||||
@ -278,9 +303,8 @@ class ToolBuiltinProviderListToolsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.list_builtin_tool_provider_tools(
|
||||
tenant_id,
|
||||
@ -294,9 +318,8 @@ class ToolBuiltinProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
|
||||
|
||||
|
||||
@ -307,9 +330,8 @@ class ToolBuiltinProviderDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
return BuiltinToolManageService.delete_builtin_tool_provider(
|
||||
@ -325,15 +347,13 @@ class ToolBuiltinProviderAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account, provider: str):
|
||||
payload = BuiltinToolAddPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
return BuiltinToolManageService.add_builtin_tool_provider(
|
||||
user_id=user_id,
|
||||
user_id=user.id,
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
credentials=payload.credentials,
|
||||
@ -350,14 +370,13 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account, provider: str):
|
||||
payload = BuiltinToolUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = BuiltinToolManageService.update_builtin_tool_provider(
|
||||
user_id=user_id,
|
||||
user_id=user.id,
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
credential_id=payload.credential_id,
|
||||
@ -372,8 +391,9 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account, provider: str):
|
||||
# Optional list of credential IDs to include even if visibility would hide them
|
||||
# (used when a workflow/agent node still references another member's only_me credential).
|
||||
include_credential_ids = request.args.getlist("include_credential_ids") or [
|
||||
@ -406,15 +426,13 @@ class ToolApiProviderAddApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
payload = ApiToolProviderAddPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
return ApiToolManageService.create_api_tool_provider(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
payload.provider,
|
||||
payload.icon,
|
||||
@ -432,16 +450,14 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account):
|
||||
raw_args = request.args.to_dict()
|
||||
query = UrlQuery.model_validate(raw_args)
|
||||
|
||||
return ApiToolManageService.get_api_tool_provider_remote_schema(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
str(query.url),
|
||||
)
|
||||
@ -452,17 +468,15 @@ class ToolApiProviderListToolsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account):
|
||||
raw_args = request.args.to_dict()
|
||||
query = ProviderQuery.model_validate(raw_args)
|
||||
|
||||
return jsonable_encoder(
|
||||
ApiToolManageService.list_api_tool_provider_tools(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
query.provider,
|
||||
)
|
||||
@ -476,15 +490,13 @@ class ToolApiProviderUpdateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
payload = ApiToolProviderUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
return ApiToolManageService.update_api_tool_provider(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
payload.provider,
|
||||
payload.original_provider,
|
||||
@ -505,15 +517,13 @@ class ToolApiProviderDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
payload = ApiToolProviderDeletePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
return ApiToolManageService.delete_api_tool_provider(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
payload.provider,
|
||||
)
|
||||
@ -524,16 +534,14 @@ class ToolApiProviderGetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account):
|
||||
raw_args = request.args.to_dict()
|
||||
query = ProviderQuery.model_validate(raw_args)
|
||||
|
||||
return ApiToolManageService.get_api_tool_provider(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
query.provider,
|
||||
)
|
||||
@ -544,9 +552,8 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, credential_type):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider, credential_type):
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.list_builtin_provider_credentials_schema(
|
||||
provider, CredentialType.of(credential_type), tenant_id
|
||||
@ -574,9 +581,9 @@ class ToolApiProviderPreviousTestApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
payload = ApiToolTestPayload.model_validate(console_ns.payload or {})
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return ApiToolManageService.test_api_tool_preview(
|
||||
current_tenant_id,
|
||||
payload.provider_name or "",
|
||||
@ -595,15 +602,13 @@ class ToolWorkflowProviderCreateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
payload = WorkflowToolCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
return WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=user_id,
|
||||
user_id=user.id,
|
||||
tenant_id=tenant_id,
|
||||
workflow_app_id=payload.workflow_app_id,
|
||||
name=payload.name,
|
||||
@ -623,14 +628,13 @@ class ToolWorkflowProviderUpdateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
payload = WorkflowToolUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
return WorkflowToolManageService.update_workflow_tool(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
payload.workflow_tool_id,
|
||||
payload.name,
|
||||
@ -650,15 +654,13 @@ class ToolWorkflowProviderDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
payload = WorkflowToolDeletePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
return WorkflowToolManageService.delete_workflow_tool(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
payload.workflow_tool_id,
|
||||
)
|
||||
@ -669,23 +671,21 @@ class ToolWorkflowProviderGetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account):
|
||||
raw_args = request.args.to_dict()
|
||||
query = WorkflowToolGetQuery.model_validate(raw_args)
|
||||
|
||||
if query.workflow_tool_id:
|
||||
tool = WorkflowToolManageService.get_workflow_tool_by_tool_id(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
query.workflow_tool_id,
|
||||
)
|
||||
elif query.workflow_app_id:
|
||||
tool = WorkflowToolManageService.get_workflow_tool_by_app_id(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
query.workflow_app_id,
|
||||
)
|
||||
@ -700,17 +700,15 @@ class ToolWorkflowProviderListToolApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account):
|
||||
raw_args = request.args.to_dict()
|
||||
query = WorkflowToolListQuery.model_validate(raw_args)
|
||||
|
||||
return jsonable_encoder(
|
||||
WorkflowToolManageService.list_single_workflow_tools(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
query.workflow_tool_id,
|
||||
)
|
||||
@ -722,16 +720,14 @@ class ToolBuiltinListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account):
|
||||
return jsonable_encoder(
|
||||
[
|
||||
provider.to_dict()
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
)
|
||||
]
|
||||
@ -743,9 +739,8 @@ class ToolApiListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
return jsonable_encoder(
|
||||
[
|
||||
provider.to_dict()
|
||||
@ -761,16 +756,14 @@ class ToolWorkflowListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account):
|
||||
return jsonable_encoder(
|
||||
[
|
||||
provider.to_dict()
|
||||
for provider in WorkflowToolManageService.list_tenant_workflow_tools(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
)
|
||||
]
|
||||
@ -793,13 +786,13 @@ class ToolPluginOAuthApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account, provider: str):
|
||||
tool_provider = ToolProviderID(provider)
|
||||
plugin_id = tool_provider.plugin_id
|
||||
provider_name = tool_provider.provider_name
|
||||
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
|
||||
if oauth_client_params is None:
|
||||
raise Forbidden("no oauth available client config found for this tool provider")
|
||||
@ -888,8 +881,8 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
|
||||
return BuiltinToolManageService.set_default_provider(
|
||||
tenant_id=current_tenant_id, provider=provider, id=payload.id
|
||||
@ -903,11 +896,10 @@ class ToolOAuthCustomClient(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
payload = ToolOAuthCustomClientPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return BuiltinToolManageService.save_custom_oauth_client_params(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
@ -920,8 +912,8 @@ class ToolOAuthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, provider: str):
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||
)
|
||||
@ -929,8 +921,8 @@ class ToolOAuthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, provider: str):
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||
)
|
||||
@ -941,8 +933,8 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, provider: str):
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
|
||||
tenant_id=current_tenant_id, provider_name=provider
|
||||
@ -955,8 +947,9 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account, provider: str):
|
||||
include_credential_ids = request.args.getlist("include_credential_ids") or [
|
||||
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
|
||||
]
|
||||
@ -977,9 +970,10 @@ class ToolProviderMCPApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
payload = MCPProviderCreatePayload.model_validate(console_ns.payload or {})
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
# Parse and validate models
|
||||
configuration = MCPConfiguration.model_validate(payload.configuration or {})
|
||||
@ -1000,6 +994,7 @@ class ToolProviderMCPApi(Resource):
|
||||
headers=payload.headers or {},
|
||||
configuration=configuration,
|
||||
authentication=authentication,
|
||||
identity_mode=_resolve_identity_mode(payload.identity_mode, current=IdentityMode.OFF),
|
||||
)
|
||||
|
||||
# 2) Try to fetch tools immediately after creation so they appear without a second save.
|
||||
@ -1029,11 +1024,11 @@ class ToolProviderMCPApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self):
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str):
|
||||
payload = MCPProviderUpdatePayload.model_validate(console_ns.payload or {})
|
||||
configuration = MCPConfiguration.model_validate(payload.configuration or {})
|
||||
authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# Step 1: Get provider data for URL validation (short-lived session, no network I/O)
|
||||
validation_data = None
|
||||
@ -1054,6 +1049,11 @@ class ToolProviderMCPApi(Resource):
|
||||
# Step 3: Perform database update in a transaction
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
# Resolve "leave unchanged" (None) against the stored value, and gate
|
||||
# the result on ENTERPRISE_ENABLED — both are API-layer concerns, so
|
||||
# the service receives a concrete IdentityMode.
|
||||
existing = service.get_provider(provider_id=payload.provider_id, tenant_id=current_tenant_id)
|
||||
identity_mode = _resolve_identity_mode(payload.identity_mode, current=IdentityMode(existing.identity_mode))
|
||||
service.update_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
provider_id=payload.provider_id,
|
||||
@ -1067,6 +1067,7 @@ class ToolProviderMCPApi(Resource):
|
||||
configuration=configuration,
|
||||
authentication=authentication,
|
||||
validation_result=validation_result,
|
||||
identity_mode=identity_mode,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
@ -1076,9 +1077,9 @@ class ToolProviderMCPApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self):
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str):
|
||||
payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {})
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
@ -1093,10 +1094,10 @@ class ToolMCPAuthApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
payload = MCPAuthPayload.model_validate(console_ns.payload or {})
|
||||
provider_id = payload.provider_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
@ -1166,8 +1167,8 @@ class ToolMCPDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider_id: str):
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
@ -1179,9 +1180,8 @@ class ToolMCPListAllApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
# Skip sensitive data decryption for list view to improve performance
|
||||
@ -1195,8 +1195,8 @@ class ToolMCPUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider_id: str):
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
tools = service.list_provider_tools(
|
||||
|
||||
@ -8,9 +8,9 @@ from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.errors import NotFoundError
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.trigger.entities.entities import SubscriptionBuilderUpdater
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
from fastopenapi.routers import FlaskRouter
|
||||
from fastopenapi.routers.flask import FlaskRouter
|
||||
|
||||
console_router = FlaskRouter()
|
||||
|
||||
@ -9,6 +9,7 @@ from werkzeug.exceptions import Forbidden
|
||||
import services
|
||||
from core.tools.signature import verify_plugin_file_signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.file_reference import build_file_reference
|
||||
from fields.file_fields import FileResponse
|
||||
|
||||
from ..common.errors import (
|
||||
@ -58,7 +59,8 @@ class PluginUploadFileApi(Resource):
|
||||
The file must be accompanied by valid timestamp, nonce, and signature parameters.
|
||||
|
||||
Returns:
|
||||
dict: File metadata including ID, URLs, and properties
|
||||
dict: File metadata including ID, canonical ``reference`` for
|
||||
output-file reconstruction, URLs, and properties
|
||||
int: HTTP status code (201 for success)
|
||||
|
||||
Raises:
|
||||
@ -112,6 +114,7 @@ class PluginUploadFileApi(Resource):
|
||||
# Create a dictionary with all the necessary attributes
|
||||
result = FileResponse(
|
||||
id=tool_file.id,
|
||||
reference=build_file_reference(record_id=tool_file.id),
|
||||
name=tool_file.name,
|
||||
size=tool_file.size,
|
||||
extension=extension,
|
||||
|
||||
@ -17,12 +17,14 @@ inner_api_ns = Namespace("inner_api", description="Internal API operations", pat
|
||||
|
||||
from . import mail as _mail
|
||||
from .app import dsl as _app_dsl
|
||||
from .plugin import agent_drive as _agent_drive
|
||||
from .plugin import plugin as _plugin
|
||||
from .workspace import workspace as _workspace
|
||||
|
||||
api.add_namespace(inner_api_ns)
|
||||
|
||||
__all__ = [
|
||||
"_agent_drive",
|
||||
"_app_dsl",
|
||||
"_mail",
|
||||
"_plugin",
|
||||
|
||||
80
api/controllers/inner_api/plugin/agent_drive.py
Normal file
80
api/controllers/inner_api/plugin/agent_drive.py
Normal file
@ -0,0 +1,80 @@
|
||||
"""Inner API for the agent drive (agent 网盘) control plane — ENG-591.
|
||||
|
||||
Two endpoints, called by the dify-agent server (not the sandbox) with the inner
|
||||
API key. The drive ref is the URL segment ``agent-<agent_id>``; the path-like
|
||||
file key travels in the query/body, never as a URL path segment (so its ``/``
|
||||
characters do not collide with routing). Drive-owned semantics: tenant scoped,
|
||||
no user-level FileAccessScope.
|
||||
"""
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import inner_api_ns
|
||||
from controllers.inner_api.wraps import plugin_inner_api_only
|
||||
from services.agent_drive_service import (
|
||||
AgentDriveError,
|
||||
AgentDriveService,
|
||||
DriveCommitItem,
|
||||
parse_agent_drive_ref,
|
||||
)
|
||||
|
||||
|
||||
class _CommitRequest(BaseModel):
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
items: list[DriveCommitItem]
|
||||
|
||||
|
||||
def _error_response(exc: AgentDriveError) -> tuple[dict[str, str], int]:
|
||||
return {"code": exc.code, "message": exc.message}, exc.status_code
|
||||
|
||||
|
||||
@inner_api_ns.route("/drive/<string:drive_ref>/manifest")
|
||||
class AgentDriveManifestApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@inner_api_ns.doc("agent_drive_manifest")
|
||||
@inner_api_ns.doc(description="List an agent drive (optionally with download URLs)")
|
||||
def get(self, drive_ref: str):
|
||||
try:
|
||||
agent_id = parse_agent_drive_ref(drive_ref)
|
||||
tenant_id = (request.args.get("tenant_id") or "").strip()
|
||||
if not tenant_id:
|
||||
raise AgentDriveError("missing_tenant_id", "tenant_id is required", status_code=400)
|
||||
include_download_url = (request.args.get("include_download_url") or "").lower() in ("1", "true", "yes")
|
||||
items = AgentDriveService().manifest(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent_id,
|
||||
prefix=request.args.get("prefix", ""),
|
||||
include_download_url=include_download_url,
|
||||
)
|
||||
except AgentDriveError as exc:
|
||||
return _error_response(exc)
|
||||
return {"items": items}
|
||||
|
||||
|
||||
@inner_api_ns.route("/drive/<string:drive_ref>/commit")
|
||||
class AgentDriveCommitApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@inner_api_ns.doc("agent_drive_commit")
|
||||
@inner_api_ns.doc(description="Commit a batch of file refs into an agent drive")
|
||||
def post(self, drive_ref: str):
|
||||
try:
|
||||
agent_id = parse_agent_drive_ref(drive_ref)
|
||||
try:
|
||||
body = _CommitRequest.model_validate(request.get_json(silent=True) or {})
|
||||
except ValidationError as exc:
|
||||
raise AgentDriveError("invalid_request", str(exc), status_code=400) from exc
|
||||
items = AgentDriveService().commit(
|
||||
tenant_id=body.tenant_id,
|
||||
user_id=body.user_id,
|
||||
agent_id=agent_id,
|
||||
items=body.items,
|
||||
)
|
||||
except AgentDriveError as exc:
|
||||
return _error_response(exc)
|
||||
return {"items": items}
|
||||
@ -25,14 +25,17 @@ from core.plugin.entities.request import (
|
||||
RequestInvokeTextEmbedding,
|
||||
RequestInvokeTool,
|
||||
RequestInvokeTTS,
|
||||
RequestRequestDownloadFile,
|
||||
RequestRequestUploadFile,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.signature import get_signed_file_url_for_plugin
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import length_prefixed_response
|
||||
from models import Account, Tenant
|
||||
from models.model import EndUser
|
||||
from services.file_request_service import FileRequestService
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/llm")
|
||||
@ -429,6 +432,54 @@ class PluginUploadFileRequestApi(Resource):
|
||||
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
|
||||
|
||||
|
||||
@inner_api_ns.route("/download/file/request")
|
||||
class PluginDownloadFileRequestApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@plugin_data(payload_type=RequestRequestDownloadFile)
|
||||
@inner_api_ns.doc("plugin_download_file_request")
|
||||
@inner_api_ns.doc(description="Request signed URL for file download through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Signed URL generated successfully",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, payload: RequestRequestDownloadFile):
|
||||
"""Resolve signed download metadata for trusted external runtimes.
|
||||
|
||||
Unlike end-user-facing upload/download APIs, this inner endpoint serves
|
||||
trusted callers such as the ``dify-agent`` back proxy. The caller sends
|
||||
flattened ``tenant_id`` / ``user_id`` / ``user_from`` / ``invoke_from``
|
||||
context explicitly in the body, and ``FileRequestService`` rebuilds the
|
||||
corresponding ``FileAccessScope`` before resolving the signed URL.
|
||||
|
||||
The response is control-plane metadata only: filename, mime type, size,
|
||||
and the signed download URL. File bytes still flow through the existing
|
||||
signed file endpoints rather than through this inner API.
|
||||
"""
|
||||
tenant_model = db.session.get(Tenant, payload.tenant_id)
|
||||
if tenant_model is None:
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
result = FileRequestService().request_download_url(
|
||||
tenant_id=tenant_model.id,
|
||||
user_id=payload.user_id,
|
||||
user_from=payload.user_from,
|
||||
invoke_from=payload.invoke_from,
|
||||
file_mapping=payload.file.model_dump(mode="python", exclude_none=True),
|
||||
)
|
||||
return BaseBackwardsInvocationResponse(
|
||||
data={
|
||||
"filename": result.filename,
|
||||
"mime_type": result.mime_type,
|
||||
"size": result.size,
|
||||
"download_url": result.download_url,
|
||||
}
|
||||
).model_dump()
|
||||
|
||||
|
||||
@inner_api_ns.route("/fetch/app/info")
|
||||
class PluginFetchAppInfoApi(Resource):
|
||||
@get_user_tenant
|
||||
|
||||
@ -25,6 +25,9 @@ from controllers.openapi._models import (
|
||||
AppDescribeInfo,
|
||||
AppDescribeQuery,
|
||||
AppDescribeResponse,
|
||||
AppDslExportQuery,
|
||||
AppDslExportResponse,
|
||||
AppDslImportPayload,
|
||||
AppInfoResponse,
|
||||
AppListQuery,
|
||||
AppListResponse,
|
||||
@ -37,6 +40,8 @@ from controllers.openapi._models import (
|
||||
DeviceMutateRequest,
|
||||
DeviceMutateResponse,
|
||||
DevicePollRequest,
|
||||
FormSubmitResponse,
|
||||
HealthResponse,
|
||||
MemberActionResponse,
|
||||
MemberInvitePayload,
|
||||
MemberInviteResponse,
|
||||
@ -49,9 +54,11 @@ from controllers.openapi._models import (
|
||||
PermittedExternalAppsListResponse,
|
||||
RevokeResponse,
|
||||
ServerVersionResponse,
|
||||
SessionListQuery,
|
||||
SessionListResponse,
|
||||
SessionRow,
|
||||
TagItem,
|
||||
TaskStopResponse,
|
||||
UsageInfo,
|
||||
WorkflowRunData,
|
||||
WorkspaceDetailResponse,
|
||||
@ -60,10 +67,14 @@ from controllers.openapi._models import (
|
||||
WorkspaceSummaryResponse,
|
||||
)
|
||||
from fields.file_fields import FileResponse
|
||||
from services.app_dsl_service import Import
|
||||
from services.entities.dsl_entities import CheckDependenciesResult
|
||||
|
||||
register_schema_models(
|
||||
openapi_ns,
|
||||
AppDescribeQuery,
|
||||
AppDslImportPayload,
|
||||
AppDslExportQuery,
|
||||
AppListQuery,
|
||||
AppRunRequest,
|
||||
DeviceCodeRequest,
|
||||
@ -74,6 +85,7 @@ register_schema_models(
|
||||
MemberListQuery,
|
||||
MemberRoleUpdatePayload,
|
||||
PermittedExternalAppsListQuery,
|
||||
SessionListQuery,
|
||||
)
|
||||
register_response_schema_models(
|
||||
openapi_ns,
|
||||
@ -85,6 +97,9 @@ register_response_schema_models(
|
||||
AppInfoResponse,
|
||||
AppDescribeInfo,
|
||||
AppDescribeResponse,
|
||||
AppDslExportResponse,
|
||||
Import,
|
||||
CheckDependenciesResult,
|
||||
WorkflowRunData,
|
||||
AccountPayload,
|
||||
WorkspacePayload,
|
||||
@ -100,16 +115,20 @@ register_response_schema_models(
|
||||
MemberListResponse,
|
||||
MemberInviteResponse,
|
||||
MemberActionResponse,
|
||||
TaskStopResponse,
|
||||
FormSubmitResponse,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateResponse,
|
||||
FileResponse,
|
||||
ServerVersionResponse,
|
||||
HealthResponse,
|
||||
)
|
||||
|
||||
from . import (
|
||||
_meta,
|
||||
account,
|
||||
app_dsl,
|
||||
app_run,
|
||||
apps,
|
||||
apps_permitted_external,
|
||||
@ -127,6 +146,7 @@ from . import (
|
||||
__all__ = [
|
||||
"_meta",
|
||||
"account",
|
||||
"app_dsl",
|
||||
"app_run",
|
||||
"apps",
|
||||
"apps_permitted_external",
|
||||
|
||||
81
api/controllers/openapi/_contract.py
Normal file
81
api/controllers/openapi/_contract.py
Normal file
@ -0,0 +1,81 @@
|
||||
"""Request/response contract decorators for the openapi controllers.
|
||||
|
||||
``@accepts`` and ``@returns`` own one slice of the contract from a single model
|
||||
reference — emitting the Swagger schema AND doing the runtime validation/
|
||||
serialisation — so the advertised and enforced contracts can't drift. Validation
|
||||
failures map to a single shape: 422.
|
||||
|
||||
They must sit BELOW ``@auth_router.guard`` so auth runs before validation and the
|
||||
``view.__wrapped__`` unit-test seam unwraps exactly the guard layer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import abort
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from controllers.common.schema import query_params_from_model, query_params_from_request
|
||||
from controllers.openapi import openapi_ns
|
||||
|
||||
|
||||
def accepts(*, query: type[BaseModel] | None = None, body: type[BaseModel] | None = None) -> Callable:
|
||||
"""Validate ``query``/``body`` against the models and inject them as keyword-only kwargs.
|
||||
|
||||
Emits the matching Swagger schema from the same models, so doc and enforcement
|
||||
stay in lockstep.
|
||||
"""
|
||||
|
||||
def decorator(view: Callable) -> Callable:
|
||||
@wraps(view)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
if query is not None:
|
||||
kwargs["query"] = query_params_from_request(query)
|
||||
if body is not None:
|
||||
kwargs["body"] = body.model_validate(request.get_json(silent=True) or {})
|
||||
except ValidationError as exc:
|
||||
# Sanitized 422 — no pydantic `url` (version) or `input` (user payload) leak.
|
||||
abort(
|
||||
422,
|
||||
message="Request validation failed",
|
||||
errors=exc.errors(include_url=False, include_input=False, include_context=False),
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
if query is not None:
|
||||
openapi_ns.doc(params=query_params_from_model(query))(wrapper)
|
||||
if body is not None:
|
||||
openapi_ns.expect(openapi_ns.models[body.__name__])(wrapper)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def returns(code: int, model: type[BaseModel], description: str | None = None) -> Callable:
|
||||
"""Serialise the handler's returned model and emit the response schema.
|
||||
|
||||
Accepts a ``BaseModel`` (serialised with ``code``) or a ``(model, status[, headers])``
|
||||
tuple (status/headers honoured). Other returns — a bare ``(dict, status)``, an SSE
|
||||
``Response`` — pass through untouched.
|
||||
"""
|
||||
|
||||
def decorator(view: Callable) -> Callable:
|
||||
@wraps(view)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
result = view(*args, **kwargs)
|
||||
if isinstance(result, BaseModel):
|
||||
return result.model_dump(mode="json"), code
|
||||
if isinstance(result, tuple) and result and isinstance(result[0], BaseModel):
|
||||
payload, *rest = result
|
||||
return (payload.model_dump(mode="json"), *rest)
|
||||
return result
|
||||
|
||||
openapi_ns.response(code, description or model.__name__, openapi_ns.models[model.__name__])(wrapper)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@ -9,15 +9,16 @@ from flask_restx import Resource
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import returns
|
||||
from controllers.openapi._models import ServerVersionResponse
|
||||
|
||||
|
||||
@openapi_ns.route("/_version")
|
||||
class VersionApi(Resource):
|
||||
@openapi_ns.response(200, "Server version", openapi_ns.models[ServerVersionResponse.__name__])
|
||||
@returns(200, ServerVersionResponse, description="Server version")
|
||||
def get(self):
|
||||
edition = dify_config.EDITION if dify_config.EDITION in ("SELF_HOSTED", "CLOUD") else "SELF_HOSTED"
|
||||
return ServerVersionResponse(
|
||||
version=dify_config.project.version,
|
||||
edition=edition,
|
||||
).model_dump(mode="json")
|
||||
)
|
||||
|
||||
@ -4,9 +4,9 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from libs.helper import EmailStr, UUIDStrOrEmpty, uuid_value
|
||||
from libs.helper import EmailStr, UUIDStr, UUIDStrOrEmpty, uuid_value
|
||||
from models.model import AppMode
|
||||
|
||||
# Server-side cap on `limit` query param for /openapi/v1/* list endpoints.
|
||||
@ -87,8 +87,12 @@ class AppDescribeInfo(AppInfoResponse):
|
||||
|
||||
class AppDescribeResponse(BaseModel):
|
||||
info: AppDescribeInfo | None = None
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
# `parameters` (the app-config blob) and `input_schema` (a Draft 2020-12 JSON Schema derived
|
||||
# per-app) are deliberately open JSON, not under-annotated. The `x-dify-opaque` marker tells the
|
||||
# contract generator's readiness detector to treat them as intentional, so the route is not
|
||||
# flagged "annotations incomplete". CLI/web consume them as opaque objects either way.
|
||||
parameters: dict[str, Any] | None = Field(default=None, json_schema_extra={"x-dify-opaque": True})
|
||||
input_schema: dict[str, Any] | None = Field(default=None, json_schema_extra={"x-dify-opaque": True})
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel):
|
||||
@ -173,6 +177,15 @@ class SessionListResponse(BaseModel):
|
||||
data: list[SessionRow]
|
||||
|
||||
|
||||
class SessionListQuery(BaseModel):
|
||||
"""Pagination for GET /account/sessions. Strict (extra='forbid')."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(100, ge=1, le=MAX_PAGE_LIMIT)
|
||||
|
||||
|
||||
class RevokeResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
@ -223,6 +236,23 @@ class ServerVersionResponse(BaseModel):
|
||||
edition: Literal["SELF_HOSTED", "CLOUD"]
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""Liveness payload for `GET /openapi/v1/_health` — no auth required."""
|
||||
|
||||
ok: bool
|
||||
|
||||
|
||||
def _csv_string_query_schema(schema: dict[str, Any]) -> None:
|
||||
"""Re-shape a set/list field's query schema to a comma-separated string — the wire form the
|
||||
handler actually accepts (`request.args` is flat + the validator splits on ','). Without this
|
||||
the generated contract would type it as an array and serialize `fields[0]=…&fields[1]=…`,
|
||||
which `extra='forbid'` rejects. Runtime `set[str]` validation is unaffected."""
|
||||
schema.pop("anyOf", None)
|
||||
schema.pop("items", None)
|
||||
schema.pop("uniqueItems", None)
|
||||
schema["type"] = "string"
|
||||
|
||||
|
||||
class AppDescribeQuery(BaseModel):
|
||||
"""`?fields=` allow-list for GET /apps/<id>/describe.
|
||||
|
||||
@ -231,23 +261,7 @@ class AppDescribeQuery(BaseModel):
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
fields: set[str] | None = None
|
||||
workspace_id: str | None = None
|
||||
|
||||
@field_validator("workspace_id", mode="before")
|
||||
@classmethod
|
||||
def _validate_workspace_id(cls, v: object) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("workspace_id must be a string")
|
||||
try:
|
||||
import uuid as _uuid
|
||||
|
||||
_uuid.UUID(v)
|
||||
except ValueError:
|
||||
raise ValueError("workspace_id must be a valid UUID")
|
||||
return v
|
||||
fields: set[str] | None = Field(default=None, json_schema_extra=_csv_string_query_schema)
|
||||
|
||||
@field_validator("fields", mode="before")
|
||||
@classmethod
|
||||
@ -267,7 +281,7 @@ class AppDescribeQuery(BaseModel):
|
||||
class AppListQuery(BaseModel):
|
||||
"""mode is a closed enum."""
|
||||
|
||||
workspace_id: str
|
||||
workspace_id: UUIDStr
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: AppMode | None = None
|
||||
@ -400,3 +414,58 @@ class MemberInviteResponse(BaseModel):
|
||||
|
||||
class MemberActionResponse(BaseModel):
|
||||
result: Literal["success"] = "success"
|
||||
|
||||
|
||||
class TaskStopResponse(BaseModel):
|
||||
"""200 body for POST /apps/<id>/tasks/<task_id>/stop. The handler always returns
|
||||
{"result": "success"}, so `result` is required (no default) — the generated contract
|
||||
types it as a required `'success'` rather than an optional field."""
|
||||
|
||||
result: Literal["success"]
|
||||
|
||||
|
||||
class AppDslImportPayload(BaseModel):
|
||||
"""Request body for POST /workspaces/<workspace_id>/apps/imports."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
mode: Literal["yaml-content", "yaml-url"] = Field(..., description="Import mode: yaml-content or yaml-url")
|
||||
yaml_content: str | None = Field(None, description="Inline YAML DSL string (required when mode is yaml-content)")
|
||||
yaml_url: str | None = Field(None, description="Remote URL to fetch YAML from (required when mode is yaml-url)")
|
||||
name: str | None = Field(None, description="Override the app name from the DSL")
|
||||
description: str | None = Field(None, description="Override the app description from the DSL")
|
||||
icon_type: str | None = Field(None)
|
||||
icon: str | None = Field(None)
|
||||
icon_background: str | None = Field(None)
|
||||
app_id: str | None = Field(None, description="Existing app ID to overwrite (workflow/advanced-chat apps only)")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_source_by_mode(self) -> AppDslImportPayload:
|
||||
if self.mode == "yaml-content" and not self.yaml_content:
|
||||
raise ValueError("yaml_content is required when mode is 'yaml-content'")
|
||||
if self.mode == "yaml-url" and not self.yaml_url:
|
||||
raise ValueError("yaml_url is required when mode is 'yaml-url'")
|
||||
return self
|
||||
|
||||
|
||||
class AppDslExportQuery(BaseModel):
|
||||
"""Query parameters for GET /apps/<app_id>/export."""
|
||||
|
||||
include_secret: bool = Field(False, description="Include encrypted secret values in the exported DSL")
|
||||
workflow_id: UUIDStr | None = Field(
|
||||
None, description="Export a specific workflow version instead of the current draft"
|
||||
)
|
||||
|
||||
|
||||
class AppDslExportResponse(BaseModel):
|
||||
"""Export DSL response."""
|
||||
|
||||
data: str = Field(..., description="DSL YAML string")
|
||||
|
||||
|
||||
class FormSubmitResponse(BaseModel):
|
||||
"""Empty 200 body for POST /apps/<id>/form/human_input/<token>. `extra='forbid'`
|
||||
pins `additionalProperties: false` so the generated contract is an exact `{}` rather
|
||||
than an under-annotated open object."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@ -2,17 +2,16 @@ from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import (
|
||||
MAX_PAGE_LIMIT,
|
||||
AccountPayload,
|
||||
AccountResponse,
|
||||
PaginationEnvelope,
|
||||
RevokeResponse,
|
||||
SessionListQuery,
|
||||
SessionListResponse,
|
||||
SessionRow,
|
||||
WorkspacePayload,
|
||||
@ -40,8 +39,8 @@ from services.oauth_device_flow import (
|
||||
|
||||
@openapi_ns.route("/account")
|
||||
class AccountApi(Resource):
|
||||
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@returns(200, AccountResponse, description="Account info")
|
||||
def get(self, *, auth_data: AuthData):
|
||||
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{auth_data.account_id}")
|
||||
|
||||
@ -56,27 +55,31 @@ class AccountApi(Resource):
|
||||
account=_account_payload(account) if account else None,
|
||||
workspaces=[_workspace_payload(m) for m in memberships],
|
||||
default_workspace_id=default_ws_id,
|
||||
).model_dump(mode="json")
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/self")
|
||||
class AccountSessionsSelfApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@returns(200, RevokeResponse, description="Session revoked")
|
||||
def delete(self, *, auth_data: AuthData):
|
||||
revoke_oauth_token(db.session, redis_client, str(auth_data.token_id))
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
return RevokeResponse(status="revoked")
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions")
|
||||
class AccountSessionsApi(Resource):
|
||||
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
@returns(200, SessionListResponse, description="Session list")
|
||||
@accepts(query=SessionListQuery)
|
||||
def get(self, *, auth_data: AuthData, query: SessionListQuery):
|
||||
# SessionListQuery enforces the advertised bounds (extra='forbid', page>=1,
|
||||
# 1<=limit<=MAX_PAGE_LIMIT) so the server rejects out-of-range paging rather
|
||||
# than silently coercing (e.g. page=0 -> empty slice).
|
||||
ctx = get_auth_ctx()
|
||||
now = datetime.now(UTC)
|
||||
page = int(request.args.get("page", "1"))
|
||||
limit = min(int(request.args.get("limit", "100")), MAX_PAGE_LIMIT)
|
||||
page = query.page
|
||||
limit = query.limit
|
||||
|
||||
all_rows = list_active_sessions(db.session, ctx, now)
|
||||
|
||||
@ -96,16 +99,19 @@ class AccountSessionsApi(Resource):
|
||||
for r in sliced
|
||||
]
|
||||
|
||||
return (
|
||||
PaginationEnvelope.build(page=page, limit=limit, total=total, items=items).model_dump(mode="json"),
|
||||
200,
|
||||
return SessionListResponse(
|
||||
page=page,
|
||||
limit=limit,
|
||||
total=total,
|
||||
has_more=page * limit < total,
|
||||
data=items,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/<string:session_id>")
|
||||
class AccountSessionByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@returns(200, RevokeResponse, description="Session revoked")
|
||||
def delete(self, session_id: str, *, auth_data: AuthData):
|
||||
ctx = get_auth_ctx()
|
||||
|
||||
@ -115,7 +121,7 @@ class AccountSessionByIdApi(Resource):
|
||||
raise NotFound("session not found")
|
||||
|
||||
revoke_oauth_token(db.session, redis_client, session_id)
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
return RevokeResponse(status="revoked")
|
||||
|
||||
|
||||
def _iso(dt: datetime | None) -> str | None:
|
||||
|
||||
167
api/controllers/openapi/app_dsl.py
Normal file
167
api/controllers/openapi/app_dsl.py
Normal file
@ -0,0 +1,167 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import AppDslExportQuery, AppDslExportResponse, AppDslImportPayload
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import Account, App
|
||||
from models.account import TenantAccountRole
|
||||
from services.app_dsl_service import AppDslService, Import
|
||||
from services.entities.dsl_entities import CheckDependenciesResult, ImportStatus
|
||||
from services.errors.app import WorkflowNotFoundError
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/apps/imports")
|
||||
class AppDslImportApi(Resource):
|
||||
"""Import a DSL YAML string into the specified workspace.
|
||||
|
||||
Use ``mode=yaml-content`` with ``yaml_content`` for inline YAML, or
|
||||
``mode=yaml-url`` with ``yaml_url`` for a remote URL. Provide ``app_id``
|
||||
to overwrite an existing workflow or advanced-chat app; omit it to create
|
||||
a new app.
|
||||
|
||||
Returns 202 when the DSL version requires an explicit confirmation step
|
||||
(major version mismatch). Callers must then POST to the confirm endpoint.
|
||||
Returns 400 when the import failed due to invalid DSL or a business error.
|
||||
"""
|
||||
|
||||
@auth_router.guard_workspace(
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.EDITOR, TenantAccountRole.ADMIN, TenantAccountRole.OWNER}),
|
||||
)
|
||||
@returns(200, Import, "Import completed")
|
||||
@returns(202, Import, "Import pending confirmation")
|
||||
@returns(400, Import, "Import failed")
|
||||
@accepts(body=AppDslImportPayload)
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData, body: AppDslImportPayload):
|
||||
account = cast(Account, auth_data.caller)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
service = AppDslService(session)
|
||||
result = service.import_app(
|
||||
account=account,
|
||||
import_mode=body.mode,
|
||||
yaml_content=body.yaml_content,
|
||||
yaml_url=body.yaml_url,
|
||||
name=body.name,
|
||||
description=body.description,
|
||||
icon_type=body.icon_type,
|
||||
icon=body.icon,
|
||||
icon_background=body.icon_background,
|
||||
app_id=body.app_id,
|
||||
)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
else:
|
||||
session.commit()
|
||||
|
||||
match result.status:
|
||||
case ImportStatus.FAILED:
|
||||
return result, 400
|
||||
case ImportStatus.PENDING:
|
||||
return result, 202
|
||||
case _:
|
||||
return result, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/apps/imports/<string:import_id>/confirm")
|
||||
class AppDslImportConfirmApi(Resource):
|
||||
"""Confirm a pending DSL import identified by ``import_id``.
|
||||
|
||||
Required only when the initial import returned 202 (major DSL version
|
||||
mismatch that requires explicit acknowledgement). The pending state is
|
||||
stored in Redis for 10 minutes; this call retrieves it and completes the
|
||||
import under the given workspace.
|
||||
|
||||
Returns 400 when the pending data has expired or the import fails.
|
||||
"""
|
||||
|
||||
@auth_router.guard_workspace(
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.EDITOR, TenantAccountRole.ADMIN, TenantAccountRole.OWNER}),
|
||||
)
|
||||
@returns(200, Import, "Import confirmed")
|
||||
@returns(400, Import, "Import failed")
|
||||
def post(self, workspace_id: str, import_id: str, *, auth_data: AuthData):
|
||||
account = cast(Account, auth_data.caller)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
service = AppDslService(session)
|
||||
result = service.confirm_import(import_id=import_id, account=account)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
else:
|
||||
session.commit()
|
||||
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result, 400
|
||||
return result, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/export")
|
||||
class AppDslExportApi(Resource):
|
||||
"""Export an app's current draft configuration as a DSL YAML string.
|
||||
|
||||
The auth pipeline resolves the app and its tenant from ``app_id``. Pass
|
||||
``include_secret=true`` to embed encrypted credential values (e.g. tool
|
||||
node secrets); omit it to produce a portable, sharable DSL safe to share.
|
||||
|
||||
Note: the pipeline enforces ``app.enable_api`` for all ``/apps/<app_id>``
|
||||
routes in the openapi group. Apps with the service API disabled will
|
||||
receive a 403; enable the API in the console first if needed.
|
||||
"""
|
||||
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_READ,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.EDITOR, TenantAccountRole.ADMIN, TenantAccountRole.OWNER}),
|
||||
)
|
||||
@accepts(query=AppDslExportQuery)
|
||||
@returns(200, AppDslExportResponse, "Export successful")
|
||||
def get(self, app_id: str, *, auth_data: AuthData, query: AppDslExportQuery):
|
||||
app = cast(App, auth_data.app)
|
||||
try:
|
||||
data = AppDslService.export_dsl(
|
||||
app_model=app,
|
||||
include_secret=query.include_secret,
|
||||
workflow_id=query.workflow_id,
|
||||
)
|
||||
except WorkflowNotFoundError as exc:
|
||||
return str(exc), 404
|
||||
return AppDslExportResponse(data=data), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/check-dependencies")
|
||||
class AppDslCheckDependenciesApi(Resource):
|
||||
"""Check for leaked plugin dependencies after a DSL import.
|
||||
|
||||
Call this after an import that reported ``COMPLETED_WITH_WARNINGS`` to
|
||||
find which plugin dependencies referenced in the DSL are not yet installed
|
||||
in the workspace. Returns an empty ``leaked_dependencies`` list when all
|
||||
dependencies are satisfied.
|
||||
"""
|
||||
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_READ,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.EDITOR, TenantAccountRole.ADMIN, TenantAccountRole.OWNER}),
|
||||
)
|
||||
@returns(200, CheckDependenciesResult, "Dependencies checked")
|
||||
def get(self, app_id: str, *, auth_data: AuthData):
|
||||
app = cast(App, auth_data.app)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
service = AppDslService(session)
|
||||
result = service.check_dependencies(app_model=app)
|
||||
|
||||
return result, 200
|
||||
@ -7,15 +7,14 @@ from collections.abc import Callable, Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity
|
||||
|
||||
import services
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._models import AppRunRequest
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import AppRunRequest, TaskStopResponse
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.service_api.app.error import (
|
||||
@ -123,23 +122,18 @@ _DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/run")
|
||||
class AppRunApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
|
||||
@openapi_ns.response(200, "Run result (SSE stream)")
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, *, auth_data: AuthData):
|
||||
@openapi_ns.response(200, "Run result (SSE stream)")
|
||||
@accepts(body=AppRunRequest)
|
||||
def post(self, app_id: str, *, auth_data: AuthData, body: AppRunRequest):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
payload = AppRunRequest.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
handler = _DISPATCH.get(app_model.mode)
|
||||
if handler is None:
|
||||
raise UnprocessableEntity("mode_not_runnable")
|
||||
|
||||
try:
|
||||
stream_obj = handler(app_model, caller, payload)
|
||||
stream_obj = handler(app_model, caller, body)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
@ -159,10 +153,10 @@ class AppRunApi(Resource):
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
|
||||
class AppRunTaskStopApi(Resource):
|
||||
@openapi_ns.response(200, "Task stopped")
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
@returns(200, TaskStopResponse, description="Task stopped")
|
||||
def post(self, app_id: str, task_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
return {"result": "success"}
|
||||
return TaskStopResponse(result="success")
|
||||
|
||||
@ -5,14 +5,12 @@ from __future__ import annotations
|
||||
import uuid as _uuid
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.common.fields import Parameters
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
|
||||
from controllers.openapi._models import (
|
||||
AppDescribeInfo,
|
||||
@ -88,16 +86,12 @@ def parameters_payload(app: App) -> dict:
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/describe")
|
||||
class AppDescribeApi(AppReadResource):
|
||||
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
|
||||
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, app_id: str, *, auth_data: AuthData):
|
||||
try:
|
||||
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
app = self._load(app_id, workspace_id=query.workspace_id)
|
||||
@returns(200, AppDescribeResponse, description="App description")
|
||||
@accepts(query=AppDescribeQuery)
|
||||
def get(self, app_id: str, *, auth_data: AuthData, query: AppDescribeQuery):
|
||||
# describe is UUID-only (workspace_id query param dropped in #37212).
|
||||
app = self._load(app_id)
|
||||
|
||||
requested = query.fields
|
||||
want_info = requested is None or "info" in requested
|
||||
@ -133,35 +127,22 @@ class AppDescribeApi(AppReadResource):
|
||||
except AppUnavailableError:
|
||||
input_schema = dict(EMPTY_INPUT_SCHEMA)
|
||||
|
||||
return (
|
||||
AppDescribeResponse(
|
||||
info=info,
|
||||
parameters=parameters,
|
||||
input_schema=input_schema,
|
||||
).model_dump(mode="json", exclude_none=False),
|
||||
200,
|
||||
return AppDescribeResponse(
|
||||
info=info,
|
||||
parameters=parameters,
|
||||
input_schema=input_schema,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
|
||||
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
|
||||
@auth_router.guard_workspace(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
try:
|
||||
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
@returns(200, AppListResponse, description="App list")
|
||||
@accepts(query=AppListQuery)
|
||||
def get(self, *, auth_data: AuthData, query: AppListQuery):
|
||||
workspace_id = query.workspace_id
|
||||
|
||||
empty = (
|
||||
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
200,
|
||||
)
|
||||
empty = AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[])
|
||||
|
||||
if query.name:
|
||||
try:
|
||||
@ -189,7 +170,7 @@ class AppListApi(Resource):
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
env = AppListResponse(page=1, limit=1, total=1, has_more=False, data=[item])
|
||||
return env.model_dump(mode="json"), 200
|
||||
return env
|
||||
|
||||
tag_ids: list[str] | None = None
|
||||
if query.tag:
|
||||
@ -240,4 +221,4 @@ class AppListApi(Resource):
|
||||
has_more=query.page * query.limit < cast(int, pagination.total),
|
||||
data=items,
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
return env
|
||||
|
||||
@ -7,12 +7,10 @@ EE blueprint chain so this module is unreachable there.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import UnprocessableEntity
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import (
|
||||
AppListRow,
|
||||
PermittedExternalAppsListQuery,
|
||||
@ -30,20 +28,14 @@ from services.enterprise.app_permitted_service import list_permitted_apps
|
||||
|
||||
@openapi_ns.route("/permitted-external-apps")
|
||||
class PermittedExternalAppsListApi(Resource):
|
||||
@openapi_ns.response(
|
||||
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
|
||||
)
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_READ_PERMITTED_EXTERNAL,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}),
|
||||
edition=frozenset({Edition.EE}),
|
||||
)
|
||||
def get(self, *, auth_data: AuthData):
|
||||
try:
|
||||
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
@returns(200, PermittedExternalAppsListResponse, description="Permitted external apps list")
|
||||
@accepts(query=PermittedExternalAppsListQuery)
|
||||
def get(self, *, auth_data: AuthData, query: PermittedExternalAppsListQuery):
|
||||
page_result = list_permitted_apps(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
@ -55,7 +47,7 @@ class PermittedExternalAppsListApi(Resource):
|
||||
env = PermittedExternalAppsListResponse(
|
||||
page=query.page, limit=query.limit, total=page_result.total, has_more=False, data=[]
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
return env
|
||||
|
||||
apps_by_id: dict[str, App] = {
|
||||
str(a.id): a for a in AppService.find_visible_apps_by_ids(db.session, page_result.app_ids)
|
||||
@ -89,4 +81,4 @@ class PermittedExternalAppsListApi(Resource):
|
||||
has_more=query.page * query.limit < page_result.total,
|
||||
data=items,
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
return env
|
||||
|
||||
@ -19,6 +19,10 @@ def load_app(data: AuthData) -> None:
|
||||
if data.app is not None:
|
||||
return
|
||||
app_id = data.path_params["app_id"]
|
||||
try:
|
||||
uuid.UUID(app_id)
|
||||
except ValueError:
|
||||
raise NotFound("app not found")
|
||||
app = AppService.get_app_by_id(db.session, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
|
||||
@ -17,6 +17,7 @@ from controllers.common.errors import (
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import returns
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
@ -38,8 +39,8 @@ class AppFileUploadApi(Resource):
|
||||
415: "Unsupported file type or blocked extension",
|
||||
}
|
||||
)
|
||||
@openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
@returns(HTTPStatus.CREATED, FileResponse, description="File uploaded")
|
||||
def post(self, app_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, _ = auth_data.require_app_context()
|
||||
if "file" not in request.files:
|
||||
@ -69,5 +70,4 @@ class AppFileUploadApi(Resource):
|
||||
except services.errors.file.BlockedFileExtensionError as exc:
|
||||
raise BlockedFileExtensionError(exc.description)
|
||||
|
||||
response = FileResponse.model_validate(upload_file, from_attributes=True)
|
||||
return response.model_dump(mode="json"), 201
|
||||
return FileResponse.model_validate(upload_file, from_attributes=True)
|
||||
|
||||
@ -10,13 +10,15 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
|
||||
from flask import Response, request
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import FormSubmitResponse
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
@ -69,12 +71,11 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
service.ensure_form_active(form)
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@openapi_ns.response(200, "Form submitted")
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, form_token: str, *, auth_data: AuthData):
|
||||
@returns(200, FormSubmitResponse, description="Form submitted")
|
||||
@accepts(body=HumanInputFormSubmitPayload)
|
||||
def post(self, app_id: str, form_token: str, *, auth_data: AuthData, body: HumanInputFormSubmitPayload):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {})
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
@ -99,12 +100,12 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
service.submit_form_by_token(
|
||||
recipient_type=form.recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=payload.action,
|
||||
form_data=payload.inputs,
|
||||
selected_action_id=body.action,
|
||||
form_data=body.inputs,
|
||||
submission_user_id=submission_user_id,
|
||||
submission_end_user_id=submission_end_user_id,
|
||||
)
|
||||
except FormNotFoundError:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
return {}, 200
|
||||
return FormSubmitResponse()
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import returns
|
||||
from controllers.openapi._models import HealthResponse
|
||||
|
||||
|
||||
@openapi_ns.route("/_health")
|
||||
class HealthApi(Resource):
|
||||
@returns(200, HealthResponse, description="Health check")
|
||||
def get(self):
|
||||
return {"ok": True}
|
||||
return HealthResponse(ok=True)
|
||||
|
||||
@ -14,14 +14,13 @@ from __future__ import annotations
|
||||
from itertools import starmap
|
||||
from urllib import parse
|
||||
|
||||
from flask import jsonify, make_response, request
|
||||
from flask import jsonify, make_response
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import (
|
||||
MemberActionResponse,
|
||||
MemberInvitePayload,
|
||||
@ -53,14 +52,6 @@ from services.errors.account import (
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def _validate_body[M: BaseModel](model: type[M]) -> M:
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
return model.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
def _member_response(account: Account) -> MemberResponse:
|
||||
return MemberResponse(
|
||||
id=str(account.id),
|
||||
@ -118,18 +109,18 @@ def _check_member_invite_quota(tenant_id: str) -> None:
|
||||
|
||||
@openapi_ns.route("/workspaces")
|
||||
class WorkspacesApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@returns(200, WorkspaceListResponse, description="Workspace list")
|
||||
def get(self, *, auth_data: AuthData):
|
||||
rows = TenantService.get_workspaces_for_account(db.session, str(auth_data.account_id))
|
||||
|
||||
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200
|
||||
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows)))
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>")
|
||||
class WorkspaceByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@returns(200, WorkspaceDetailResponse, description="Workspace detail")
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData):
|
||||
row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id)
|
||||
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
|
||||
@ -137,7 +128,7 @@ class WorkspaceByIdApi(Resource):
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
tenant, membership = row
|
||||
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
|
||||
return _workspace_detail(tenant, membership)
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/switch")
|
||||
@ -149,8 +140,8 @@ class WorkspaceSwitchApi(Resource):
|
||||
that ``hosts.yml`` never diverges from the server's ``current`` state.
|
||||
"""
|
||||
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@returns(200, WorkspaceDetailResponse, description="Workspace detail")
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData):
|
||||
account = _load_account(auth_data.account_id)
|
||||
|
||||
@ -163,7 +154,7 @@ class WorkspaceSwitchApi(Resource):
|
||||
if row is None:
|
||||
raise NotFound("workspace not found")
|
||||
tenant, membership = row
|
||||
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
|
||||
return _workspace_detail(tenant, membership)
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/members")
|
||||
@ -174,15 +165,10 @@ class WorkspaceMembersApi(Resource):
|
||||
assigned through invite (ownership transfer is console-only).
|
||||
"""
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(MemberListQuery))
|
||||
@openapi_ns.response(200, "Member list", openapi_ns.models[MemberListResponse.__name__])
|
||||
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData):
|
||||
try:
|
||||
query = MemberListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
@returns(200, MemberListResponse, description="Member list")
|
||||
@accepts(query=MemberListQuery)
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData, query: MemberListQuery):
|
||||
tenant = _load_tenant(workspace_id)
|
||||
members = TenantService.get_tenant_members(tenant)
|
||||
total = len(members)
|
||||
@ -194,17 +180,16 @@ class WorkspaceMembersApi(Resource):
|
||||
total=total,
|
||||
has_more=query.page * query.limit < total,
|
||||
data=[_member_response(m) for m in page_items],
|
||||
).model_dump(mode="json"), 200
|
||||
)
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[MemberInvitePayload.__name__])
|
||||
@openapi_ns.response(201, "Member invited", openapi_ns.models[MemberInviteResponse.__name__])
|
||||
@auth_router.guard_workspace(
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
|
||||
)
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData):
|
||||
payload = _validate_body(MemberInvitePayload)
|
||||
@returns(201, MemberInviteResponse, description="Member invited")
|
||||
@accepts(body=MemberInvitePayload)
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData, body: MemberInvitePayload):
|
||||
inviter = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
|
||||
@ -213,9 +198,9 @@ class WorkspaceMembersApi(Resource):
|
||||
try:
|
||||
token = RegisterService.invite_new_member(
|
||||
tenant=tenant,
|
||||
email=payload.email,
|
||||
email=body.email,
|
||||
language=None,
|
||||
role=payload.role,
|
||||
role=body.role,
|
||||
inviter=inviter,
|
||||
)
|
||||
except AccountAlreadyInTenantError as exc:
|
||||
@ -225,7 +210,7 @@ class WorkspaceMembersApi(Resource):
|
||||
except AccountRegisterError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
normalized_email = payload.email.lower()
|
||||
normalized_email = body.email.lower()
|
||||
member = AccountService.get_account_by_email_with_case_fallback(normalized_email)
|
||||
if member is None:
|
||||
# invite_new_member just created or fetched this account.
|
||||
@ -235,11 +220,11 @@ class WorkspaceMembersApi(Resource):
|
||||
invite_url = f"{dify_config.CONSOLE_WEB_URL}/activate?email={encoded_email}&token={token}"
|
||||
return MemberInviteResponse(
|
||||
email=normalized_email,
|
||||
role=payload.role,
|
||||
role=body.role,
|
||||
member_id=str(member.id),
|
||||
invite_url=invite_url,
|
||||
tenant_id=str(tenant.id),
|
||||
).model_dump(mode="json"), 201
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/members/<string:member_id>")
|
||||
@ -251,12 +236,12 @@ class WorkspaceMemberApi(Resource):
|
||||
400 per the spec, with the service's message preserved.
|
||||
"""
|
||||
|
||||
@openapi_ns.response(200, "Member removed", openapi_ns.models[MemberActionResponse.__name__])
|
||||
@auth_router.guard_workspace(
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
|
||||
)
|
||||
@returns(200, MemberActionResponse, description="Member removed")
|
||||
def delete(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
|
||||
operator = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
@ -273,7 +258,7 @@ class WorkspaceMemberApi(Resource):
|
||||
except MemberNotInTenantError as exc:
|
||||
raise NotFound(str(exc))
|
||||
|
||||
return MemberActionResponse().model_dump(mode="json"), 200
|
||||
return MemberActionResponse()
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/members/<string:member_id>/role")
|
||||
@ -284,15 +269,14 @@ class WorkspaceMemberRoleApi(Resource):
|
||||
standing owner (service NoPermissionError → 400, per spec).
|
||||
"""
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[MemberRoleUpdatePayload.__name__])
|
||||
@openapi_ns.response(200, "Role updated", openapi_ns.models[MemberActionResponse.__name__])
|
||||
@auth_router.guard_workspace(
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
|
||||
)
|
||||
def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
|
||||
payload = _validate_body(MemberRoleUpdatePayload)
|
||||
@returns(200, MemberActionResponse, description="Role updated")
|
||||
@accepts(body=MemberRoleUpdatePayload)
|
||||
def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData, body: MemberRoleUpdatePayload):
|
||||
operator = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
member = AccountService.get_account_by_id(db.session, member_id)
|
||||
@ -300,7 +284,7 @@ class WorkspaceMemberRoleApi(Resource):
|
||||
raise NotFound("member not found")
|
||||
|
||||
try:
|
||||
TenantService.update_member_role(tenant, member, payload.role, operator)
|
||||
TenantService.update_member_role(tenant, member, body.role, operator)
|
||||
except CannotOperateSelfError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except NoPermissionError as exc:
|
||||
@ -310,7 +294,7 @@ class WorkspaceMemberRoleApi(Resource):
|
||||
except RoleAlreadyAssignedError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
return MemberActionResponse().model_dump(mode="json"), 200
|
||||
return MemberActionResponse()
|
||||
|
||||
|
||||
def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceSummaryResponse:
|
||||
|
||||
@ -121,13 +121,3 @@ class WebFormRateLimitExceededError(BaseHTTPException):
|
||||
error_code = "web_form_rate_limit_exceeded"
|
||||
description = "Too many form requests. Please try again later."
|
||||
code = 429
|
||||
|
||||
|
||||
class NotFoundError(BaseHTTPException):
|
||||
error_code = "not_found"
|
||||
code = 404
|
||||
|
||||
|
||||
class InvalidArgumentError(BaseHTTPException):
|
||||
error_code = "invalid_param"
|
||||
code = 400
|
||||
|
||||
@ -15,10 +15,11 @@ from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.errors import NotFoundError
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
|
||||
from controllers.web.error import WebFormRateLimitExceededError
|
||||
from controllers.web.site import serialize_app_site_payload
|
||||
from extensions.ext_database import db
|
||||
from graphon.nodes.human_input.entities import FormInputConfig
|
||||
|
||||
@ -8,8 +8,8 @@ from collections.abc import Generator
|
||||
from flask import Response, request
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.errors import InvalidArgumentError, NotFoundError
|
||||
from controllers.web import api
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
|
||||
@ -13,7 +13,11 @@ from dataclasses import dataclass
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
|
||||
from dify_agent.layers.execution_context import (
|
||||
DifyExecutionContextInvokeFrom,
|
||||
DifyExecutionContextLayerConfig,
|
||||
DifyExecutionContextUserFrom,
|
||||
)
|
||||
from dify_agent.protocol import CreateRunRequest
|
||||
|
||||
from clients.agent_backend import (
|
||||
@ -126,7 +130,10 @@ class AgentAppRuntimeRequestBuilder:
|
||||
conversation_id=context.conversation_id,
|
||||
agent_id=context.agent_id,
|
||||
agent_config_version_id=context.agent_config_snapshot_id,
|
||||
invoke_from="agent_app",
|
||||
# Agent Files §1.3: real Dify access context + agent run mode.
|
||||
user_from=cast(DifyExecutionContextUserFrom, context.dify_context.user_from.value),
|
||||
invoke_from=cast(DifyExecutionContextInvokeFrom, context.dify_context.invoke_from.value),
|
||||
agent_mode="agent_app",
|
||||
),
|
||||
agent_soul_prompt=agent_soul.prompt.system_prompt or None,
|
||||
user_prompt=context.user_query,
|
||||
|
||||
@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Literal, overload
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
@ -68,6 +68,25 @@ def _extract_trace_session_id_from_debug_args(args: Mapping[str, Any] | Any) ->
|
||||
|
||||
|
||||
class WorkflowAppGenerator(BaseAppGenerator):
|
||||
@staticmethod
|
||||
def _ensure_snippet_start_node_in_worker(*, session: Session, workflow: Workflow) -> Workflow:
|
||||
"""Re-apply snippet virtual Start injection after worker reloads workflow from DB."""
|
||||
if workflow.kind_or_standard != "snippet":
|
||||
return workflow
|
||||
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.snippet_generate_service import SnippetGenerateService
|
||||
|
||||
snippet = session.scalar(
|
||||
select(CustomizedSnippet).where(
|
||||
CustomizedSnippet.id == workflow.app_id,
|
||||
CustomizedSnippet.tenant_id == workflow.tenant_id,
|
||||
)
|
||||
)
|
||||
if snippet is None:
|
||||
return workflow
|
||||
return SnippetGenerateService.ensure_start_node_for_worker(workflow, snippet)
|
||||
|
||||
@staticmethod
|
||||
def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
|
||||
return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
|
||||
@ -592,6 +611,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
if workflow is None:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
workflow = self._ensure_snippet_start_node_in_worker(session=session, workflow=workflow)
|
||||
|
||||
# Determine system_user_id based on invocation source
|
||||
is_external_api_call = application_generate_entity.invoke_from in {
|
||||
InvokeFrom.WEB_APP,
|
||||
|
||||
@ -11,6 +11,7 @@ from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, Workfl
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
from core.workflow.node_factory import get_default_root_node_id
|
||||
from core.workflow.nodes.agent_v2.session_cleanup_layer import build_workflow_agent_session_cleanup_layer
|
||||
from core.workflow.snippet_start import get_compatible_start_aliases
|
||||
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
|
||||
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
@ -118,7 +119,15 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
),
|
||||
)
|
||||
root_node_id = self._root_node_id or get_default_root_node_id(self._workflow.graph_dict)
|
||||
add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs)
|
||||
add_node_inputs_to_pool(
|
||||
variable_pool,
|
||||
node_id=root_node_id,
|
||||
inputs=inputs,
|
||||
aliases=get_compatible_start_aliases(
|
||||
workflow_kind=getattr(self._workflow, "kind_or_standard", None),
|
||||
root_node_id=root_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
graph = self._init_graph(
|
||||
|
||||
@ -37,6 +37,13 @@ class MCPSupportGrantType(StrEnum):
|
||||
REFRESH_TOKEN = "refresh_token"
|
||||
|
||||
|
||||
class IdentityMode(StrEnum):
|
||||
"""How Dify forwards the end-user's identity to an MCP server."""
|
||||
|
||||
OFF = "off"
|
||||
IDP_TOKEN = "idp_token"
|
||||
|
||||
|
||||
class MCPAuthentication(BaseModel):
|
||||
client_id: str
|
||||
client_secret: str | None = None
|
||||
@ -76,6 +83,8 @@ class MCPProviderEntity(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
identity_mode: IdentityMode = IdentityMode.OFF
|
||||
|
||||
@classmethod
|
||||
def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity:
|
||||
"""Create entity from database model with decryption"""
|
||||
@ -96,6 +105,7 @@ class MCPProviderEntity(BaseModel):
|
||||
icon=db_provider.icon or "",
|
||||
created_at=db_provider.created_at,
|
||||
updated_at=db_provider.updated_at,
|
||||
identity_mode=IdentityMode(db_provider.identity_mode),
|
||||
)
|
||||
|
||||
@property
|
||||
@ -170,6 +180,7 @@ class MCPProviderEntity(BaseModel):
|
||||
"updated_at": int(self.updated_at.timestamp()),
|
||||
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
|
||||
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
|
||||
"identity_mode": self.identity_mode,
|
||||
}
|
||||
|
||||
# Add configuration
|
||||
|
||||
@ -316,6 +316,7 @@ class IndexingRunner:
|
||||
qa_preview_texts: list[QAPreviewDetail] = []
|
||||
|
||||
total_segments = 0
|
||||
deleted_preview_images = False
|
||||
# doc_form represents the segmentation method (general, parent-child, QA)
|
||||
index_type = doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
@ -368,6 +369,10 @@ class IndexingRunner:
|
||||
upload_file_id,
|
||||
)
|
||||
db.session.delete(image_file)
|
||||
deleted_preview_images = True
|
||||
|
||||
if deleted_preview_images:
|
||||
db.session.commit()
|
||||
|
||||
if doc_form and doc_form == "qa_model":
|
||||
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
|
||||
|
||||
@ -40,6 +40,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
provider_entity: MCPProviderEntity | None = None,
|
||||
authorization_code: str | None = None,
|
||||
by_server_id: bool = False,
|
||||
forward_identity_active: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the MCP client with auth retry capability.
|
||||
@ -52,12 +53,15 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
provider_entity: Provider entity for authentication
|
||||
authorization_code: Optional authorization code for initial auth
|
||||
by_server_id: Whether to look up provider by server ID
|
||||
forward_identity_active: If True, suppress the static-OAuth retry
|
||||
on 401 — the forwarded identity must propagate as-is.
|
||||
"""
|
||||
super().__init__(server_url, headers, timeout, sse_read_timeout)
|
||||
|
||||
self.provider_entity = provider_entity
|
||||
self.authorization_code = authorization_code
|
||||
self.by_server_id = by_server_id
|
||||
self.forward_identity_active = forward_identity_active
|
||||
self._has_retried = False
|
||||
|
||||
def _handle_auth_error(self, error: MCPAuthError) -> None:
|
||||
@ -73,6 +77,8 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
Raises:
|
||||
MCPAuthError: If authentication fails or max retries reached
|
||||
"""
|
||||
if self.forward_identity_active:
|
||||
raise error
|
||||
if not self.provider_entity:
|
||||
raise error
|
||||
if self._has_retried:
|
||||
|
||||
@ -7,7 +7,7 @@ import threading
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, override
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from cachetools import LRUCache
|
||||
@ -221,9 +221,10 @@ class TracingProviderConfigEntry(TypedDict):
|
||||
|
||||
|
||||
class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]):
|
||||
def __getitem__(self, provider: str) -> TracingProviderConfigEntry:
|
||||
@override
|
||||
def __getitem__(self, key: str) -> TracingProviderConfigEntry:
|
||||
try:
|
||||
match provider:
|
||||
match key:
|
||||
case TracingProviderEnum.LANGFUSE:
|
||||
from dify_trace_langfuse.config import LangfuseConfig
|
||||
from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
|
||||
@ -330,9 +331,9 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigE
|
||||
}
|
||||
|
||||
case _:
|
||||
raise KeyError(f"Unsupported tracing provider: {provider}")
|
||||
raise KeyError(f"Unsupported tracing provider: {key}")
|
||||
except ImportError:
|
||||
raise ImportError(f"Provider {provider} is not installed.")
|
||||
raise ImportError(f"Provider {key} is not installed.")
|
||||
|
||||
|
||||
provider_config_map = OpsTraceProviderConfigMap()
|
||||
|
||||
@ -4,10 +4,11 @@ from collections.abc import Mapping
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import Response
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.plugin.utils.http_parser import deserialize_response
|
||||
from core.workflow.file_reference import is_canonical_file_reference
|
||||
from graphon.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
@ -231,6 +232,53 @@ class RequestRequestUploadFile(BaseModel):
|
||||
mimetype: str
|
||||
|
||||
|
||||
class RequestDownloadFileMapping(BaseModel):
|
||||
"""File mapping accepted by trusted download-request control-plane APIs."""
|
||||
|
||||
transfer_method: Literal["local_file", "tool_file", "datasource_file", "remote_url"]
|
||||
reference: str | None = None
|
||||
url: str | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_locator(self) -> "RequestDownloadFileMapping":
|
||||
if self.transfer_method == "remote_url":
|
||||
if not self.url:
|
||||
raise ValueError("url is required when transfer_method is remote_url")
|
||||
if self.reference is not None:
|
||||
raise ValueError("reference is not allowed when transfer_method is remote_url")
|
||||
return self
|
||||
if not self.reference:
|
||||
raise ValueError("reference is required for non-remote file mappings")
|
||||
if not is_canonical_file_reference(self.reference):
|
||||
raise ValueError("reference must be a canonical Dify file reference")
|
||||
if self.url is not None:
|
||||
raise ValueError("url is not allowed for non-remote file mappings")
|
||||
return self
|
||||
|
||||
|
||||
class RequestRequestDownloadFile(BaseModel):
|
||||
"""Request to resolve a signed download URL for one runtime file mapping."""
|
||||
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
user_from: Literal["account", "end-user"]
|
||||
invoke_from: Literal[
|
||||
"service-api",
|
||||
"openapi",
|
||||
"web-app",
|
||||
"trigger",
|
||||
"explore",
|
||||
"debugger",
|
||||
"published",
|
||||
"validation",
|
||||
]
|
||||
file: RequestDownloadFileMapping
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class RequestFetchAppInfo(BaseModel):
|
||||
"""
|
||||
Request to fetch app info
|
||||
|
||||
@ -36,7 +36,10 @@ class PluginEndpointClient(BasePluginClient):
|
||||
|
||||
def list_endpoints(self, tenant_id: str, user_id: str, page: int, page_size: int):
|
||||
"""
|
||||
List all endpoints for the given tenant and user.
|
||||
List all endpoints for the given tenant.
|
||||
|
||||
The daemon list route binds only tenant and pagination fields; user_id is
|
||||
retained in this client signature for consistency with endpoint services.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
@ -47,7 +50,10 @@ class PluginEndpointClient(BasePluginClient):
|
||||
|
||||
def list_endpoints_for_single_plugin(self, tenant_id: str, user_id: str, plugin_id: str, page: int, page_size: int):
|
||||
"""
|
||||
List all endpoints for the given tenant, user and plugin.
|
||||
List all endpoints for the given tenant and plugin.
|
||||
|
||||
The daemon list route binds tenant, plugin and pagination fields; user_id
|
||||
is retained in this client signature for consistency with endpoint services.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
|
||||
@ -4,6 +4,12 @@ This module owns plugin daemon management calls that are shared by API services
|
||||
and core runtimes. Plugin model provider discovery is cached here, alongside
|
||||
plugin install, uninstall, and upgrade invalidation, so all cache mutations for
|
||||
plugin-owned provider metadata stay tenant-scoped and in one place.
|
||||
|
||||
The console plugin list also normalizes endpoint setup counters against live
|
||||
endpoint records. Some plugin daemon builds return stale ``endpoints_*``
|
||||
aggregates in ``management/list`` even while plugin-scoped endpoint queries are
|
||||
current, so the API reconciles those counts before serving workspace plugin
|
||||
metadata.
|
||||
"""
|
||||
|
||||
import logging
|
||||
@ -38,6 +44,7 @@ from core.plugin.entities.plugin_daemon import (
|
||||
)
|
||||
from core.plugin.impl.asset import PluginAssetManager
|
||||
from core.plugin.impl.debugging import PluginDebuggingClient
|
||||
from core.plugin.impl.endpoint import PluginEndpointClient
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from extensions.ext_database import db
|
||||
@ -69,6 +76,9 @@ class PluginService:
|
||||
REDIS_TTL = 60 * 5 # 5 minutes
|
||||
PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX = "plugin_model_providers:tenant_id:"
|
||||
PLUGIN_INSTALL_TASK_TERMINAL_STATUSES = (PluginInstallTaskStatus.Success, PluginInstallTaskStatus.Failed)
|
||||
# Mirror the detail-panel endpoint query size so list reconciliation and
|
||||
# the visible endpoint drawer exercise the same daemon pagination path.
|
||||
ENDPOINT_RECONCILIATION_PAGE_SIZE = 100
|
||||
|
||||
@classmethod
|
||||
def _get_plugin_model_providers_cache_key(cls, tenant_id: str) -> str:
|
||||
@ -287,14 +297,104 @@ class PluginService:
|
||||
return plugins
|
||||
|
||||
@staticmethod
|
||||
def list_with_total(tenant_id: str, page: int, page_size: int) -> PluginListResponse:
|
||||
"""
|
||||
list all plugins of the tenant
|
||||
def list_with_total(tenant_id: str, user_id: str, page: int, page_size: int) -> PluginListResponse:
|
||||
"""List tenant plugins with endpoint counts reconciled from live records.
|
||||
|
||||
The plugin daemon's ``management/list`` payload is tenant-scoped, but
|
||||
some daemon builds undercount or stale-cache plugin endpoint aggregates.
|
||||
The list response therefore refreshes counters from the daemon's
|
||||
tenant-scoped endpoint records before returning workspace plugin metadata.
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins_with_total(tenant_id, page, page_size)
|
||||
PluginService._reconcile_endpoint_counts(tenant_id, user_id, plugins.list)
|
||||
return plugins
|
||||
|
||||
@staticmethod
|
||||
def _normalize_endpoint_count(value: object) -> int:
|
||||
"""Convert daemon endpoint counters to safe non-negative integers.
|
||||
|
||||
Some daemon builds use ``-1`` as an "unknown / not synced yet" sentinel
|
||||
for endpoint counters. That value is acceptable internally as a daemon
|
||||
transport detail, but it must never leak through the console API because
|
||||
the UI displays these counters directly.
|
||||
"""
|
||||
if value is None:
|
||||
return 0
|
||||
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
|
||||
if isinstance(value, int):
|
||||
return max(0, value)
|
||||
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return max(0, int(value))
|
||||
except ValueError:
|
||||
return 0
|
||||
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def _normalize_plugin_endpoint_counts(cls, plugin: PluginEntity) -> None:
|
||||
"""Clamp endpoint counters on plugin entities before returning them."""
|
||||
plugin.endpoints_setups = cls._normalize_endpoint_count(plugin.endpoints_setups)
|
||||
plugin.endpoints_active = cls._normalize_endpoint_count(plugin.endpoints_active)
|
||||
|
||||
@classmethod
|
||||
def _reconcile_endpoint_counts(cls, tenant_id: str, user_id: str, plugins: Sequence[PluginEntity]) -> None:
|
||||
"""Refresh endpoint counters from live plugin endpoint records.
|
||||
|
||||
``management/list`` is the source of truth for plugin installations, but
|
||||
some daemon versions lag when populating ``endpoints_setups`` and
|
||||
``endpoints_active``. The plugin-scoped endpoint listing is the same
|
||||
tenant-scoped source the console detail panel uses after reinstall flows,
|
||||
so the list view recomputes counts per plugin instead of trusting stale
|
||||
daemon aggregates.
|
||||
"""
|
||||
endpoint_client = PluginEndpointClient()
|
||||
|
||||
for plugin in plugins:
|
||||
cls._normalize_plugin_endpoint_counts(plugin)
|
||||
|
||||
if plugin.declaration.endpoint is None:
|
||||
continue
|
||||
|
||||
page = 1
|
||||
endpoints_setups = 0
|
||||
endpoints_active = 0
|
||||
|
||||
try:
|
||||
while True:
|
||||
endpoints = endpoint_client.list_endpoints_for_single_plugin(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin.plugin_id,
|
||||
page=page,
|
||||
page_size=cls.ENDPOINT_RECONCILIATION_PAGE_SIZE,
|
||||
)
|
||||
endpoints_setups += len(endpoints)
|
||||
endpoints_active += sum(int(endpoint.enabled) for endpoint in endpoints)
|
||||
|
||||
if len(endpoints) < cls.ENDPOINT_RECONCILIATION_PAGE_SIZE:
|
||||
break
|
||||
page += 1
|
||||
except Exception:
|
||||
logger.warning(
|
||||
(
|
||||
"Failed to reconcile live endpoint counters for tenant %s plugin %s; "
|
||||
"falling back to daemon plugin stats."
|
||||
),
|
||||
tenant_id,
|
||||
plugin.plugin_id,
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
|
||||
plugin.endpoints_setups = cls._normalize_endpoint_count(endpoints_setups)
|
||||
plugin.endpoints_active = cls._normalize_endpoint_count(endpoints_active)
|
||||
|
||||
@staticmethod
|
||||
def list_installations_from_ids(tenant_id: str, ids: Sequence[str]) -> Sequence[PluginInstallation]:
|
||||
"""
|
||||
|
||||
@ -862,15 +862,20 @@ class RetrievalService:
|
||||
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
|
||||
)
|
||||
|
||||
query = query or attachment_id
|
||||
if not query:
|
||||
if query:
|
||||
rerank_query = query
|
||||
query_type = QueryType.TEXT_QUERY
|
||||
elif attachment_id:
|
||||
rerank_query = attachment_id
|
||||
query_type = QueryType.IMAGE_QUERY
|
||||
else:
|
||||
return
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
query=query,
|
||||
query=rerank_query,
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY,
|
||||
query_type=query_type,
|
||||
)
|
||||
if not data_post_processor.rerank_runner and score_threshold:
|
||||
all_documents_item = self._filter_documents_by_vector_score_threshold(
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import base64
|
||||
import logging
|
||||
import pickle
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
import numpy as np
|
||||
from sqlalchemy import select
|
||||
@ -25,6 +25,7 @@ class CacheEmbedding(Embeddings):
|
||||
def __init__(self, model_instance: ModelInstance):
|
||||
self._model_instance = model_instance
|
||||
|
||||
@override
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed search docs in batches of 10."""
|
||||
# use doc embedding cache or store if not exists
|
||||
@ -106,6 +107,7 @@ class CacheEmbedding(Embeddings):
|
||||
|
||||
return text_embeddings
|
||||
|
||||
@override
|
||||
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
|
||||
"""Embed file documents."""
|
||||
# use doc embedding cache or store if not exists
|
||||
@ -189,6 +191,7 @@ class CacheEmbedding(Embeddings):
|
||||
|
||||
return multimodel_embeddings
|
||||
|
||||
@override
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed query text."""
|
||||
# use doc embedding cache or store if not exists
|
||||
@ -232,6 +235,7 @@ class CacheEmbedding(Embeddings):
|
||||
|
||||
return embedding_results # type: ignore
|
||||
|
||||
@override
|
||||
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
|
||||
"""Embed multimodal documents."""
|
||||
# use doc embedding cache or store if not exists
|
||||
|
||||
@ -1,13 +1,32 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
"""Excel document extractor used for RAG ingestion.
|
||||
|
||||
Supports cell hyperlinks for both `.xls` and `.xlsx`, and embedded worksheet images
|
||||
for `.xlsx` files by converting them into markdown image links. Embedded images are
|
||||
stored with deterministic keys derived from the source upload file and anchor cell so
|
||||
retries can safely reuse the same assets.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import TypedDict, override
|
||||
|
||||
import pandas as pd
|
||||
from openpyxl import load_workbook
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Candidate(TypedDict):
|
||||
@ -16,17 +35,42 @@ class Candidate(TypedDict):
|
||||
map: dict[int, str]
|
||||
|
||||
|
||||
class SheetImageCandidate(TypedDict):
|
||||
anchor: tuple[int, int]
|
||||
content_hash: str
|
||||
file_key: str
|
||||
image_bytes: bytes
|
||||
image_ext: str
|
||||
|
||||
|
||||
class ExcelExtractor(BaseExtractor):
|
||||
"""Load Excel files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, encoding: str | None = None, autodetect_encoding: bool = False):
|
||||
_file_path: str
|
||||
_encoding: str | None
|
||||
_autodetect_encoding: bool
|
||||
_tenant_id: str | None
|
||||
_user_id: str | None
|
||||
_source_file_id: str | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
tenant_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
source_file_id: str | None = None,
|
||||
encoding: str | None = None,
|
||||
autodetect_encoding: bool = False,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._tenant_id = tenant_id
|
||||
self._user_id = user_id
|
||||
self._source_file_id = source_file_id
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
|
||||
@ -37,7 +81,8 @@ class ExcelExtractor(BaseExtractor):
|
||||
file_extension = os.path.splitext(self._file_path)[-1].lower()
|
||||
|
||||
if file_extension == ".xlsx":
|
||||
wb = load_workbook(self._file_path, read_only=True, data_only=True)
|
||||
# Worksheet drawing objects, including embedded images, are not available in read-only mode.
|
||||
wb = load_workbook(self._file_path, data_only=True)
|
||||
try:
|
||||
for sheet_name in wb.sheetnames:
|
||||
sheet = wb[sheet_name]
|
||||
@ -45,10 +90,15 @@ class ExcelExtractor(BaseExtractor):
|
||||
if not column_map:
|
||||
continue
|
||||
start_row = header_row_idx + 1
|
||||
sheet_image_map = self._extract_images_from_sheet(
|
||||
sheet_name=sheet_name,
|
||||
sheet=sheet,
|
||||
valid_columns={column_idx + 1 for column_idx in column_map},
|
||||
min_row=start_row,
|
||||
)
|
||||
for row in sheet.iter_rows(min_row=start_row, max_col=max_col_idx, values_only=False):
|
||||
if all(cell.value is None for cell in row):
|
||||
continue
|
||||
page_content = []
|
||||
row_has_content = False
|
||||
for col_idx, cell in enumerate(row):
|
||||
value = cell.value
|
||||
if col_idx in column_map:
|
||||
@ -56,14 +106,27 @@ class ExcelExtractor(BaseExtractor):
|
||||
if hasattr(cell, "hyperlink") and cell.hyperlink:
|
||||
target = getattr(cell.hyperlink, "target", None)
|
||||
if target:
|
||||
value = f"[{value}]({target})"
|
||||
display_value = value if value is not None and str(value).strip() else target
|
||||
value = f"[{display_value}]({target})"
|
||||
cell_row = getattr(cell, "row", None)
|
||||
cell_column = getattr(cell, "column", None)
|
||||
image_links = (
|
||||
sheet_image_map.get((cell_row, cell_column), [])
|
||||
if isinstance(cell_row, int) and isinstance(cell_column, int)
|
||||
else []
|
||||
)
|
||||
if value is None:
|
||||
value = ""
|
||||
elif not isinstance(value, str):
|
||||
value = str(value)
|
||||
value = value.strip().replace('"', '\\"')
|
||||
if image_links:
|
||||
value = " ".join(filter(None, [value, " ".join(image_links)]))
|
||||
value = value.strip()
|
||||
if value:
|
||||
row_has_content = True
|
||||
value = value.replace('"', '\\"')
|
||||
page_content.append(f'"{col_name}":"{value}"')
|
||||
if page_content:
|
||||
if row_has_content and page_content:
|
||||
documents.append(
|
||||
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
|
||||
)
|
||||
@ -89,6 +152,166 @@ class ExcelExtractor(BaseExtractor):
|
||||
|
||||
return documents
|
||||
|
||||
def _extract_images_from_sheet(
|
||||
self, sheet_name: str, sheet, valid_columns: set[int], min_row: int
|
||||
) -> dict[tuple[int, int], list[str]]:
|
||||
"""
|
||||
Extract embedded worksheet images and map them to their anchor cell.
|
||||
|
||||
Images are stored with deterministic keys derived from the source upload file,
|
||||
sheet, anchor cell, and content hash so retried tasks can reuse the same
|
||||
UploadFile rows and storage objects.
|
||||
"""
|
||||
if not self._tenant_id or not self._user_id or not self._source_file_id:
|
||||
return {}
|
||||
|
||||
images = getattr(sheet, "_images", None) or []
|
||||
image_candidates: list[SheetImageCandidate] = []
|
||||
|
||||
for image in images:
|
||||
marker = getattr(getattr(image, "anchor", None), "_from", None)
|
||||
row_idx = getattr(marker, "row", None)
|
||||
col_idx = getattr(marker, "col", None)
|
||||
if row_idx is None or col_idx is None:
|
||||
continue
|
||||
if row_idx + 1 < min_row or col_idx + 1 not in valid_columns:
|
||||
continue
|
||||
|
||||
image_bytes = self._get_image_bytes(image)
|
||||
if not image_bytes:
|
||||
continue
|
||||
|
||||
image_ext = self._get_image_extension(image)
|
||||
if not image_ext:
|
||||
continue
|
||||
|
||||
anchor_row = row_idx + 1
|
||||
anchor_column = col_idx + 1
|
||||
content_hash = self._hash_image_bytes(image_bytes)
|
||||
image_candidates.append(
|
||||
{
|
||||
"anchor": (anchor_row, anchor_column),
|
||||
"content_hash": content_hash,
|
||||
"file_key": self._build_image_file_key(
|
||||
sheet_name=sheet_name,
|
||||
anchor_row=anchor_row,
|
||||
anchor_column=anchor_column,
|
||||
content_hash=content_hash,
|
||||
image_ext=image_ext,
|
||||
),
|
||||
"image_bytes": image_bytes,
|
||||
"image_ext": image_ext,
|
||||
}
|
||||
)
|
||||
|
||||
if not image_candidates:
|
||||
return {}
|
||||
|
||||
image_map: dict[tuple[int, int], list[str]] = {}
|
||||
base_url = dify_config.FILES_URL
|
||||
candidate_keys = sorted({candidate["file_key"] for candidate in image_candidates})
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
existing_upload_files = session.scalars(
|
||||
select(UploadFile).where(
|
||||
UploadFile.tenant_id == self._tenant_id,
|
||||
UploadFile.key.in_(candidate_keys),
|
||||
)
|
||||
).all()
|
||||
upload_files_by_key = {upload_file.key: upload_file for upload_file in existing_upload_files}
|
||||
new_upload_files: list[UploadFile] = []
|
||||
|
||||
for candidate in image_candidates:
|
||||
upload_file = upload_files_by_key.get(candidate["file_key"])
|
||||
if upload_file is None:
|
||||
storage.save(candidate["file_key"], candidate["image_bytes"])
|
||||
mime_type, _ = mimetypes.guess_type(candidate["file_key"])
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self._tenant_id,
|
||||
storage_type=StorageType(dify_config.STORAGE_TYPE),
|
||||
key=candidate["file_key"],
|
||||
name=candidate["file_key"],
|
||||
size=len(candidate["image_bytes"]),
|
||||
extension=candidate["image_ext"],
|
||||
mime_type=mime_type or "",
|
||||
created_by=self._user_id,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_at=naive_utc_now(),
|
||||
used=True,
|
||||
used_by=self._user_id,
|
||||
used_at=naive_utc_now(),
|
||||
hash=candidate["content_hash"],
|
||||
)
|
||||
upload_files_by_key[candidate["file_key"]] = upload_file
|
||||
new_upload_files.append(upload_file)
|
||||
|
||||
image_map.setdefault(candidate["anchor"], []).append(
|
||||
f""
|
||||
)
|
||||
|
||||
if new_upload_files:
|
||||
session.add_all(new_upload_files)
|
||||
session.commit()
|
||||
|
||||
return image_map
|
||||
|
||||
@staticmethod
|
||||
def _hash_image_bytes(image_bytes: bytes) -> str:
|
||||
"""Return a stable content hash for extracted image bytes."""
|
||||
return hashlib.sha256(image_bytes).hexdigest()
|
||||
|
||||
def _build_image_file_key(
|
||||
self,
|
||||
*,
|
||||
sheet_name: str,
|
||||
anchor_row: int,
|
||||
anchor_column: int,
|
||||
content_hash: str,
|
||||
image_ext: str,
|
||||
) -> str:
|
||||
"""Build a deterministic storage key for an embedded worksheet image."""
|
||||
assert self._tenant_id is not None, "tenant_id is required for image extraction"
|
||||
assert self._source_file_id is not None, "source_file_id is required for image extraction"
|
||||
|
||||
normalized_ext = image_ext.strip().lower()
|
||||
sheet_hash = hashlib.sha256(sheet_name.encode("utf-8")).hexdigest()[:16]
|
||||
return (
|
||||
f"image_files/{self._tenant_id}/{self._source_file_id}/"
|
||||
f"{sheet_hash}_r{anchor_row}_c{anchor_column}_{content_hash}.{normalized_ext}"
|
||||
)
|
||||
|
||||
def _get_image_bytes(self, image) -> bytes | None:
|
||||
"""Return embedded image bytes from an openpyxl image object."""
|
||||
data_loader = getattr(image, "_data", None)
|
||||
if not callable(data_loader):
|
||||
return None
|
||||
|
||||
try:
|
||||
data = data_loader()
|
||||
if isinstance(data, bytes):
|
||||
return data
|
||||
if isinstance(data, bytearray):
|
||||
return bytes(data)
|
||||
logger.warning("Unexpected embedded image payload type: %s", type(data).__name__)
|
||||
return None
|
||||
except Exception:
|
||||
logger.warning("Failed to read embedded image bytes from Excel sheet", exc_info=True)
|
||||
return None
|
||||
|
||||
def _get_image_extension(self, image) -> str | None:
|
||||
"""Resolve an image extension from openpyxl metadata."""
|
||||
image_format = getattr(image, "format", None)
|
||||
if isinstance(image_format, str) and image_format.strip():
|
||||
return image_format.strip().lower()
|
||||
|
||||
image_path = getattr(image, "path", None)
|
||||
if isinstance(image_path, str):
|
||||
_, extension = os.path.splitext(image_path)
|
||||
if extension:
|
||||
return extension.lstrip(".").lower()
|
||||
|
||||
return None
|
||||
|
||||
def _find_header_and_columns(self, sheet, scan_rows=10) -> tuple[int, dict[int, str], int]:
|
||||
"""
|
||||
Scan first N rows to find the most likely header row.
|
||||
|
||||
@ -113,7 +113,12 @@ class ExtractProcessor:
|
||||
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY or ""
|
||||
|
||||
if file_extension in {".xlsx", ".xls"}:
|
||||
extractor = ExcelExtractor(file_path)
|
||||
extractor = ExcelExtractor(
|
||||
file_path,
|
||||
upload_file.tenant_id,
|
||||
upload_file.created_by,
|
||||
upload_file.id,
|
||||
)
|
||||
elif file_extension == ".pdf":
|
||||
assert upload_file is not None
|
||||
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
@ -151,7 +156,12 @@ class ExtractProcessor:
|
||||
extractor = TextExtractor(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
if file_extension in {".xlsx", ".xls"}:
|
||||
extractor = ExcelExtractor(file_path)
|
||||
extractor = ExcelExtractor(
|
||||
file_path,
|
||||
upload_file.tenant_id,
|
||||
upload_file.created_by,
|
||||
upload_file.id,
|
||||
)
|
||||
elif file_extension == ".pdf":
|
||||
assert upload_file is not None
|
||||
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, TypedDict, cast
|
||||
from typing import Any, TypedDict, cast, override
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -61,6 +61,7 @@ class ParagraphFormatPreviewDict(TypedDict):
|
||||
|
||||
|
||||
class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
@override
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
text_docs = ExtractProcessor.extract(
|
||||
extract_setting=extract_setting,
|
||||
@ -71,6 +72,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
return text_docs
|
||||
|
||||
@override
|
||||
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
|
||||
process_rule = kwargs.get("process_rule")
|
||||
if not process_rule:
|
||||
@ -120,6 +122,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
all_documents.extend(split_documents)
|
||||
return all_documents
|
||||
|
||||
@override
|
||||
def load(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
@ -142,6 +145,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
keyword.add_texts(documents)
|
||||
|
||||
@override
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
|
||||
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
|
||||
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
|
||||
@ -178,6 +182,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
keyword.delete()
|
||||
|
||||
@override
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: RetrievalMethod,
|
||||
@ -206,6 +211,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@override
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
|
||||
documents: list[Any] = []
|
||||
all_multimodal_documents: list[Any] = []
|
||||
@ -271,6 +277,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
keyword = Keyword(dataset)
|
||||
keyword.add_texts(documents)
|
||||
|
||||
@override
|
||||
def format_preview(self, chunks: Any) -> ParagraphFormatPreviewDict:
|
||||
if isinstance(chunks, list):
|
||||
preview = []
|
||||
@ -285,6 +292,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
raise ValueError("Chunks is not a list")
|
||||
|
||||
@override
|
||||
def generate_summary_preview(
|
||||
self,
|
||||
tenant_id: str,
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, override
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
@ -44,6 +44,7 @@ class ParentChildFormatPreviewDict(TypedDict):
|
||||
|
||||
|
||||
class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
@override
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
text_docs = ExtractProcessor.extract(
|
||||
extract_setting=extract_setting,
|
||||
@ -54,6 +55,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
return text_docs
|
||||
|
||||
@override
|
||||
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
|
||||
process_rule = kwargs.get("process_rule")
|
||||
if not process_rule:
|
||||
@ -129,6 +131,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
return all_documents
|
||||
|
||||
@override
|
||||
def load(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
@ -149,6 +152,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
if multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(multimodal_documents)
|
||||
|
||||
@override
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
|
||||
# node_ids is segment's node_ids
|
||||
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
|
||||
@ -219,6 +223,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
@override
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: RetrievalMethod,
|
||||
@ -283,6 +288,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
child_nodes.append(child_document)
|
||||
return child_nodes
|
||||
|
||||
@override
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
|
||||
parent_childs = ParentChildStructureChunk.model_validate(chunks)
|
||||
documents = []
|
||||
@ -356,6 +362,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
if all_multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(all_multimodal_documents)
|
||||
|
||||
@override
|
||||
def format_preview(self, chunks: Any) -> ParentChildFormatPreviewDict:
|
||||
parent_childs = ParentChildStructureChunk.model_validate(chunks)
|
||||
preview = []
|
||||
@ -369,6 +376,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
}
|
||||
return result
|
||||
|
||||
@override
|
||||
def generate_summary_preview(
|
||||
self,
|
||||
tenant_id: str,
|
||||
|
||||
@ -4,7 +4,7 @@ import logging
|
||||
import re
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, override
|
||||
|
||||
import pandas as pd
|
||||
from flask import Flask, current_app
|
||||
@ -43,6 +43,7 @@ class QAFormatPreviewDict(TypedDict):
|
||||
|
||||
|
||||
class QAIndexProcessor(BaseIndexProcessor):
|
||||
@override
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
text_docs = ExtractProcessor.extract(
|
||||
extract_setting=extract_setting,
|
||||
@ -52,6 +53,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
)
|
||||
return text_docs
|
||||
|
||||
@override
|
||||
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
|
||||
preview = kwargs.get("preview")
|
||||
process_rule = kwargs.get("process_rule")
|
||||
@ -139,6 +141,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
raise ValueError(str(e))
|
||||
return text_docs
|
||||
|
||||
@override
|
||||
def load(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
@ -153,6 +156,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
if multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(multimodal_documents)
|
||||
|
||||
@override
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
|
||||
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
|
||||
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
|
||||
@ -183,6 +187,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
vector.delete()
|
||||
|
||||
@override
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: RetrievalMethod,
|
||||
@ -211,6 +216,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@override
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
|
||||
qa_chunks = QAStructureChunk.model_validate(chunks)
|
||||
documents = []
|
||||
@ -234,6 +240,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
raise ValueError("Indexing technique must be high quality.")
|
||||
|
||||
@override
|
||||
def format_preview(self, chunks: Any) -> QAFormatPreviewDict:
|
||||
qa_chunks = QAStructureChunk.model_validate(chunks)
|
||||
preview = []
|
||||
@ -246,6 +253,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
}
|
||||
return result
|
||||
|
||||
@override
|
||||
def generate_summary_preview(
|
||||
self,
|
||||
tenant_id: str,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import base64
|
||||
from typing import override
|
||||
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
@ -16,6 +17,7 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
def __init__(self, rerank_model_instance: ModelInstance):
|
||||
self.rerank_model_instance = rerank_model_instance
|
||||
|
||||
@override
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import math
|
||||
from collections import Counter
|
||||
from typing import override
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -19,6 +20,7 @@ class WeightRerankRunner(BaseRerankRunner):
|
||||
self.tenant_id = tenant_id
|
||||
self.weights = weights
|
||||
|
||||
@override
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
|
||||
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import codecs
|
||||
import re
|
||||
from collections.abc import Set as AbstractSet
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, override
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
|
||||
@ -51,6 +51,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
||||
self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape")
|
||||
self._separators = separators or ["\n\n", "\n", "。", ". ", " ", ""]
|
||||
|
||||
@override
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
if self._fixed_separator:
|
||||
|
||||
@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from collections.abc import Set as AbstractSet
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, override
|
||||
|
||||
from core.rag.models.document import BaseDocumentTransformer, Document
|
||||
|
||||
@ -148,10 +148,12 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
)
|
||||
return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
|
||||
|
||||
@override
|
||||
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
|
||||
"""Transform sequence of documents by splitting them."""
|
||||
return self.split_documents(list(documents))
|
||||
|
||||
@override
|
||||
async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
|
||||
"""Asynchronously transform a sequence of documents by splitting them."""
|
||||
raise NotImplementedError
|
||||
@ -211,6 +213,7 @@ class TokenTextSplitter(TextSplitter):
|
||||
self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special
|
||||
self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special
|
||||
|
||||
@override
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
def _encode(_text: str) -> list[int]:
|
||||
return self._tokenizer.encode(
|
||||
@ -287,5 +290,6 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
|
||||
return final_chunks
|
||||
|
||||
@override
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
return self._split_text(text, self._separators)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from abc import abstractmethod
|
||||
from os import listdir, path
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
@ -105,6 +105,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
return self.tools
|
||||
|
||||
@override
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
@ -182,6 +183,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
)
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AudioToolProvider(BuiltinToolProviderController):
|
||||
@override
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
pass
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import io
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
@ -14,6 +14,7 @@ from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
class ASRTool(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
@ -56,6 +57,7 @@ class ASRTool(BuiltinTool):
|
||||
items.append((provider, model.model))
|
||||
return items
|
||||
|
||||
@override
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: str | None = None,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import io
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
@ -12,6 +12,7 @@ from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
class TTSTool(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
@ -66,6 +67,7 @@ class TTSTool(BuiltinTool):
|
||||
items.append((provider, model.model, voices))
|
||||
return items
|
||||
|
||||
@override
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: str | None = None,
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class CodeToolProvider(BuiltinToolProviderController):
|
||||
@override
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
pass
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
@ -8,6 +8,7 @@ from core.tools.errors import ToolInvokeError
|
||||
|
||||
|
||||
class SimpleCode(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class WikiPediaProvider(BuiltinToolProviderController):
|
||||
@override
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
pass
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from pytz import timezone as pytz_timezone # type: ignore[import-untyped]
|
||||
|
||||
@ -9,6 +9,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class CurrentTimeTool(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
import pytz # type: ignore[import-untyped]
|
||||
|
||||
@ -10,6 +10,7 @@ from core.tools.errors import ToolInvokeError
|
||||
|
||||
|
||||
class LocaltimeToTimestampTool(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
import pytz # type: ignore[import-untyped]
|
||||
|
||||
@ -10,6 +10,7 @@ from core.tools.errors import ToolInvokeError
|
||||
|
||||
|
||||
class TimestampToLocaltimeTool(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
import pytz # type: ignore[import-untyped]
|
||||
|
||||
@ -10,6 +10,7 @@ from core.tools.errors import ToolInvokeError
|
||||
|
||||
|
||||
class TimezoneConversionTool(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
|
||||
@ -1,13 +1,14 @@
|
||||
import calendar
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class WeekdayTool(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user